In [10]:
import torch
from torch import nn
from torchsummary import summary

# Deep&Cross

In [12]:
class BaseModel:
    
    pass

class DCN(BaseModel):
    """Instantiates the Deep&Cross Network architecture. Including DCN-V (parameterization='vector')
    and DCN-M (parameterization='matrix').

    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param cross_num: positive integet,cross layer number
    :param cross_parameterization: str, ``"vector"`` or ``"matrix"``, how to parameterize the cross network.
    :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param l2_reg_cross: float. L2 regularizer strength applied to cross net
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param init_std: float,to use as the initialize std of embedding vector
    :param seed: integer ,to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not DNN
    :param dnn_activation: Activation function to use in DNN
    :param task: str, ``"binary"`` for  binary logloss or  ``"regression"`` for regression loss
    :param device: str, ``"cpu"`` or ``"cuda:0"``
    :return: A PyTorch model instance.
    
    """

    def __init__(self, linear_feature_columns, dnn_feature_columns, cross_num=2, cross_parameterization='vector',
                 dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_cross=0.00001,
                 l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False,
                 task='binary', device='cpu'):

        super(DCN, self).__init__(linear_feature_columns=linear_feature_columns,
                                  dnn_feature_columns=dnn_feature_columns, l2_reg_embedding=l2_reg_embedding,
                                  init_std=init_std, seed=seed, task=task, device=device)
        self.dnn_hidden_units = dnn_hidden_units
        self.cross_num = cross_num

        self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
                       activation=dnn_activation, use_bn=dnn_use_bn, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout,
                       init_std=init_std, device=device)

        # 计算堆叠层的输入特征维度
        if len(self.dnn_hidden_units) > 0 and self.cross_num > 0:
            dnn_linear_in_feature = self.compute_input_dim(dnn_feature_columns) + dnn_hidden_units[-1]
        elif len(self.dnn_hidden_units) > 0:
            dnn_linear_in_feature = dnn_hidden_units[-1]
        elif self.cross_num > 0:
            dnn_linear_in_feature = self.compute_input_dim(dnn_feature_columns)

        #  logistic层
        self.dnn_linear = nn.Linear(dnn_linear_in_feature, 1, bias=False).to(
            device)

        #  crossnet
        self.crossnet = CrossNet(in_features=self.compute_input_dim(dnn_feature_columns),
                                 layer_num=cross_num, parameterization=cross_parameterization, device=device)
        self.add_regularization_weight(
            filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
        self.add_regularization_weight(self.dnn_linear.weight, l2_reg_linear)
        self.add_regularization_weight(self.crossnet.kernels, l2_reg_cross)
        self.to(device)

    def forward(self, X):

        logit = self.linear_model(X)
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                                                  self.embedding_dict)

        dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

        if len(self.dnn_hidden_units) > 0 and self.cross_num > 0:  # Deep & Cross
            # DNN和Cross网络中输入相同的数据
            deep_out = self.dnn(dnn_input)
            cross_out = self.crossnet(dnn_input)
            stack_out = torch.cat((cross_out, deep_out), dim=-1)
            logit += self.dnn_linear(stack_out)
        elif len(self.dnn_hidden_units) > 0:  # Only Deep
            deep_out = self.dnn(dnn_input)
            logit += self.dnn_linear(deep_out)
        elif self.cross_num > 0:  # Only Cross
            cross_out = self.crossnet(dnn_input)
            logit += self.dnn_linear(cross_out)
        else:  # Error
            pass
        y_pred = self.out(logit)
        return y_pred

## Cross网络
意义：Cross可以辅助Deep，减小了Deep的“工作量”，通过特殊的cross layer设计，用更少的参数量有效捕获有意义的、DNN难以捕捉的特征相关性.  
1) 有限高阶：叉乘阶数由网络深度决定，深度$L_c$对应最高 $L_c + 1$ 阶的叉乘
$$
\begin{aligned}
\boldsymbol{x}_{1} &=\boldsymbol{x}_{0} \boldsymbol{x}_{0}^{T} \boldsymbol{w}_{0}+\boldsymbol{x}_{0}=\left[\begin{array}{c}
x_{0,1} \\
x_{0,2}
\end{array}\right]\left[x_{0,1}, x_{0,2}\right]\left[\begin{array}{c}
w_{0,1} \\
w_{0,2}
\end{array}\right]+\left[\begin{array}{l}
x_{0,1} \\
x_{0,2}
\end{array}\right]=\left[\begin{array}{l}
w_{0,1} x_{0,1}^{2}+w_{0,2} x_{0,1} x_{0,2}+x_{0,1} \\
w_{0,1} x_{0,2} x_{0,1}+w_{0,2} x_{0,2}^{2}+x_{0,2}
\end{array}\right] \\
\boldsymbol{x}_{2} &=\boldsymbol{x}_{0} \boldsymbol{x}_{1}^{T} \boldsymbol{w}_{1}+\boldsymbol{x}_{1} \\
&=\left[\begin{array}{l}
w_{1,1} x_{0,1} x_{1,1}+w_{1,2} x_{0,1} x_{1,2}+x_{1,1} \\
w_{1,1} x_{0,2} x_{1,1}+w_{1,2} x_{0,2} x_{1,2}+x_{1,2} \\
\end{array}\right] 
\left[\begin{array}= & \left.w_{0,1} w_{1,1} x_{0,1}^{3}+\left(w_{0,2} w_{1,1}+w_{0,1} w_{1,2}\right) x_{0,1}^{2} x_{0,2}+w_{0,2} w_{1,2} x_{0,1} x_{0,2}^{2}+\left(w_{0,1}+w_{1,1}\right) x_{0,1}^{2}+\left(w_{0,2}+w_{1,2}\right) x_{0,1} x_{0,2}+x_{0,1}\right]
\end{array}\right.
\end{aligned}
$$  


2) 自动叉乘：Cross输出包含了原始特征从一阶（即本身）到 $L_c + 1$阶的所有叉乘组合，而模型参数量仅仅随输入维度成线性增长： $2*L_c*d$

3) 参数共享：不同叉乘项对应的权重不同，但并非每个叉乘组合对应独立的权重（指数数量级）， 通过参数共享，Cross有效降低了参数量。此外，参数共享还使得模型有更强的泛化性和鲁棒性。例如，如果独立训练权重，当训练集中$x_i!=0 and x_j!=0$这个叉乘特征没有出现 ，对应权重肯定是零，而参数共享则不会，类似地，数据集中的一些噪声可以由大部分正常样本来纠正权重参数的学习

In [13]:
class CrossNet(nn.Module):
    """The Cross Network part of Deep&Cross Network model,
    which leans both low and high degree cross feature.
      Input shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Output shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Arguments
        - **in_features** : Positive integer, dimensionality of input features.
        - **input_feature_num**: Positive integer, shape(Input tensor)[-1]
        - **layer_num**: Positive integer, the cross layer number
        - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix
        - **seed**: A Python integer to use as random seed.
      References
        - [Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]//Proceedings of the ADKDD'17. ACM, 2017: 12.](https://arxiv.org/abs/1708.05123)
    """

    def __init__(self, in_features, layer_num=5, seed=1024, device='cpu'):
        super(CrossNet, self).__init__()
        self.layer_num = layer_num
        self.kernels = torch.nn.ParameterList(
            [nn.Parameter(nn.init.xavier_normal_(torch.empty(in_features, 1))) for i in range(self.layer_num)])
        self.bias = torch.nn.ParameterList(
            [nn.Parameter(nn.init.zeros_(torch.empty(in_features, 1))) for i in range(self.layer_num)])
        self.to(device)

    def forward(self, inputs):
        x_0 = inputs.unsqueeze(2)
        print(x_0.shape)
        x_l = x_0 # [B, d, 1]
        for i in range(self.layer_num):
            # 对应维度做点乘运算
            # [B, d, 1] dot [d, 1] = [B, 1, 1]
            xl_w = torch.tensordot(x_l, self.kernels[i], dims=([1], [0]))
            print(xl_w.shape)
            dot_ = torch.matmul(x_0, xl_w)
            print("dot_", dot_.shape)
            x_l = dot_ + self.bias[i] + x_l
            
        x_l = torch.squeeze(x_l, dim=2)
        return x_l

In [14]:
CrossNet(100)

CrossNet(
  (kernels): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 100x1]
      (1): Parameter containing: [torch.FloatTensor of size 100x1]
      (2): Parameter containing: [torch.FloatTensor of size 100x1]
      (3): Parameter containing: [torch.FloatTensor of size 100x1]
      (4): Parameter containing: [torch.FloatTensor of size 100x1]
  )
  (bias): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 100x1]
      (1): Parameter containing: [torch.FloatTensor of size 100x1]
      (2): Parameter containing: [torch.FloatTensor of size 100x1]
      (3): Parameter containing: [torch.FloatTensor of size 100x1]
      (4): Parameter containing: [torch.FloatTensor of size 100x1]
  )
)

In [15]:
A  = CrossNet(100)(torch.rand((128, 100)))

torch.Size([128, 100, 1])
torch.Size([128, 1, 1])
dot_ torch.Size([128, 100, 1])
torch.Size([128, 1, 1])
dot_ torch.Size([128, 100, 1])
torch.Size([128, 1, 1])
dot_ torch.Size([128, 100, 1])
torch.Size([128, 1, 1])
dot_ torch.Size([128, 100, 1])
torch.Size([128, 1, 1])
dot_ torch.Size([128, 100, 1])


# 参考
[Deep Cross Network (深度交叉网络, DCN) 介绍与代码分析](https://blog.csdn.net/Eric_1993/article/details/105600937)  
[揭秘 Deep & Cross : 如何自动构造高阶交叉特征](https://zhuanlan.zhihu.com/p/55234968)  
[【论文导读】Wide&Deep模型的进阶---Cross&Deep模型，附TF2.0复现代码](https://mp.weixin.qq.com/s/DkoaMaXhlgQv1NhZHF-7og)

# 知识点

## tensordot计算　

In [19]:
a = torch.rand((1, 2, 1)) # 各个维度权重相乘再相加
b = torch.rand((2, 1))
c = torch.tensordot(a, b, dims=([1], [0]))
c.shape

torch.Size([1, 1, 1])

In [20]:
a

tensor([[[0.2487],
         [0.1692]]])

In [21]:
b

tensor([[0.4675],
        [0.2660]])

In [22]:
c

tensor([[[0.1613]]])

In [23]:
0.8825*0.5069+0.013*0.9459

0.45963594999999996

## torch.matmul

In [24]:
a = torch.rand((2, 2, 1))
a

tensor([[[0.6680],
         [0.4520]],

        [[0.5840],
         [0.1657]]])

In [25]:
b = torch.rand((2, 1, 1))
b

tensor([[[0.9874]],

        [[0.5792]]])

In [26]:
torch.matmul(a, b)

tensor([[[0.6596],
         [0.4463]],

        [[0.3383],
         [0.0960]]])

In [27]:
0.3604*0.4648

0.16751391999999998