# In [None]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

class WideDeep(nn.Module):
    """Wide & Deep model over embedded categorical fields and pooled sequences.

    ``input_config`` maps each feature name to a config dict. The keys read
    here are:

    * ``'category'``: ``'embedding'``, ``'sequence'``, ``'mask'`` or ``'value'``;
    * ``'seq_name'``: groups ``'sequence'`` features into one sequence;
    * ``'emb_shape'``: ``(num_embeddings, embedding_dim)`` of the feature's
      embedding table.

    ``shared_emb_config`` maps a table name to a list of feature names. Every
    listed ``'embedding'`` / ``'sequence'`` feature is looked up in that one
    shared table, whose shape comes from the ``emb_shape`` declared by the
    features mapped to it (they must agree).

    The model has two parts, whose logits are summed:

    * Deep part: the ``'embedding'`` field vectors and one sum-pooled vector
      per sequence group are concatenated and fed to an MLP.
    * Wide part: the same field vectors plus the inner product of every pair
      of fields with equal embedding width are fed to a bias-free linear layer.

    ``'mask'`` and ``'value'`` features are accepted as inputs but not used.
    """

    def __init__(self, input_config, shared_emb_config=None, use_moving_statistics=True,
                 dnn_hidden=80, num_outputs=2):
        """Build the embedding tables and size the wide/deep layers from the config.

        Args:
            input_config: feature name -> config dict (see class docstring).
            shared_emb_config: optional table name -> list of feature names that
                should share that table.
            use_moving_statistics: stored as ``_use_moving_statistics``; not read
                anywhere in this class.
            dnn_hidden: hidden width of the deep MLP (80 was the old hard-coded
                value).
            num_outputs: number of output logits (2 was the old hard-coded value).

        Raises:
            KeyError: a feature config lacks ``'category'``, or an embedding table
                has no ``'emb_shape'`` declared by any feature mapped to it.
            ValueError: features sharing a table declare different shapes, or the
                members of one sequence group have different embedding widths.
        """
        super().__init__()

        self._use_moving_statistics = use_moving_statistics
        self.input_config = input_config

        # Feature name -> name of the embedding table it is looked up in.
        self._emb_dict = {}
        # 'embedding' features in config order. This fixes the column order of
        # the concatenated field vector so layer weights line up across calls.
        self._field_names = []
        # seq_name -> its 'sequence' member features, in config order.
        self._seq_groups = {}
        seq_names = set()

        for feat_name, config in input_config.items():
            category = config['category']
            if category in ('embedding', 'sequence'):
                self._emb_dict[feat_name] = feat_name
            if category == 'embedding':
                self._field_names.append(feat_name)
            if 'seq_name' in config:
                seq_name = config['seq_name']
                seq_names.add(seq_name)
                if category == 'sequence':
                    self._seq_groups.setdefault(seq_name, []).append(feat_name)

        # Every seq_name mentioned by any feature config (mask features included).
        self._seq_names = sorted(seq_names)

        # Redirect shared features to their shared table.
        if shared_emb_config is not None:
            for emb_name, feat_names in shared_emb_config.items():
                for feat_name in feat_names:
                    if feat_name in self._emb_dict:
                        self._emb_dict[feat_name] = emb_name

        # One table per distinct table name, so shared features really share weights.
        self.embeddings = nn.ModuleDict({
            table_name: nn.Embedding(rows, dim)
            for table_name, (rows, dim) in self._table_shapes().items()
        })

        field_dims = [self._table_dim(name) for name in self._field_names]
        vec_dim = sum(field_dims)
        seq_dim_total = sum(self._seq_group_dim(seq_name) for seq_name in self._seq_groups)

        # An inner product only exists between fields of equal embedding width.
        self._wide_pairs = [
            (i, j)
            for (i, dim_i), (j, dim_j) in itertools.combinations(enumerate(field_dims), 2)
            if dim_i == dim_j
        ]

        # Deep part: MLP over [field vectors | pooled sequence vectors].
        self.dnn = nn.Sequential(
            nn.Linear(vec_dim + seq_dim_total, dnn_hidden),
            nn.PReLU(),
            nn.Linear(dnn_hidden, num_outputs),
        )

        # Wide part: linear layer over [field vectors | pairwise inner products].
        self.wide = nn.Linear(vec_dim + len(self._wide_pairs), num_outputs, bias=False)

    def _table_shapes(self):
        """Return table name -> (num_embeddings, embedding_dim), checking agreement."""
        shapes = {}
        for feat_name, table_name in self._emb_dict.items():
            config = self.input_config[feat_name]
            if 'emb_shape' not in config:
                continue
            shape = (config['emb_shape'][0], config['emb_shape'][1])
            known = shapes.setdefault(table_name, shape)
            if known != shape:
                raise ValueError(
                    f"embedding table {table_name!r}: feature {feat_name!r} declares "
                    f"emb_shape {shape}, but another feature sharing it declares {known}"
                )
        missing = [name for name in dict.fromkeys(self._emb_dict.values()) if name not in shapes]
        if missing:
            raise KeyError(f"no 'emb_shape' declared for embedding table(s) {missing}")
        return shapes

    def _table_dim(self, feat_name):
        """Return the embedding width of the table *feat_name* is looked up in."""
        return self.embeddings[self._emb_dict[feat_name]].embedding_dim

    def _seq_group_dim(self, seq_name):
        """Return the common embedding width of a sequence group's members."""
        dims = {self._table_dim(name) for name in self._seq_groups[seq_name]}
        if len(dims) != 1:
            raise ValueError(
                f"sequence {seq_name!r} mixes embedding widths {sorted(dims)}; "
                "members must share one width to be sum-pooled"
            )
        return dims.pop()

    def _embed(self, feat_name, features):
        """Look up *feat_name*'s input ids in its (possibly shared) embedding table."""
        if feat_name not in features:
            raise KeyError(f"missing input feature {feat_name!r}")
        return self.embeddings[self._emb_dict[feat_name]](features[feat_name])

    def forward(self, features, mode='train'):
        """Return the sum of the wide and deep logits.

        Args:
            features: feature name -> input tensor. Every name must appear in
                ``input_config``. Every ``'embedding'`` feature, and at least one
                member of each sequence group, must be present. Embedded inputs
                are presumably integer id tensors of equal shape, e.g. ``(B,)``;
                TODO confirm with callers.
            mode: accepted for API compatibility; not used.

        Returns:
            Logits of shape ``ids.shape + (num_outputs,)``, e.g. ``(B, 2)`` for
            ``(B,)`` id inputs.

        Raises:
            KeyError: an unknown feature name, a missing ``'embedding'`` feature,
                or a sequence group with no member present.
            ValueError: the config has no ``'embedding'`` feature.
        """
        # Unknown feature names are rejected, as before.
        for feat_name in features:
            if feat_name not in self.input_config:
                raise KeyError(feat_name)
        if not self._field_names:
            raise ValueError("WideDeep needs at least one feature with category 'embedding'")

        # Field embeddings in fixed config order.
        field_embs = [self._embed(name, features) for name in self._field_names]
        vec_cat = torch.cat(field_embs, dim=-1)

        # Sum-pool each sequence group (in sorted seq_name order) over the
        # members that were supplied.
        pooled = []
        for seq_name in sorted(self._seq_groups):
            members = [name for name in self._seq_groups[seq_name] if name in features]
            if not members:
                raise KeyError(f"no input features given for sequence {seq_name!r}")
            member_embs = [self._embed(name, features) for name in members]
            pooled.append(torch.stack(member_embs, dim=0).sum(dim=0))

        # Wide part: per-sample inner product between each pair of fields.
        crosses = [
            (field_embs[i] * field_embs[j]).sum(dim=-1, keepdim=True)
            for i, j in self._wide_pairs
        ]
        wide_logits = self.wide(torch.cat([vec_cat, *crosses], dim=-1))

        # Deep part: MLP over field vectors and pooled sequences.
        deep_logits = self.dnn(torch.cat([vec_cat, *pooled], dim=-1))

        return wide_logits + deep_logits