# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Attention layer scoring each key against a target vector.

    Every key is concatenated with the target, passed through a linear
    layer with ``tanh`` and then projected to one scalar score.  The scores
    are soft-maxed over the sequence axis and used to pool ``keys``.

    Args:
        hidden_size: Feature size of both ``target`` and ``keys``.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.linear = nn.Linear(hidden_size * 2, hidden_size)
        # NOTE: query/key/value are not used by forward(); they are kept so
        # that parameter names and seeded initialisation order are unchanged.
        self.query = nn.Linear(hidden_size, hidden_size, bias=False)
        self.key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.value = nn.Linear(hidden_size, hidden_size, bias=False)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, target, keys, mask=None):
        """Compute attention weights and the attention-pooled context.

        Args:
            target: Tensor of shape [B, H].
            keys: Tensor of shape [B, L, H].
            mask: Optional tensor of shape [B, L]; positions equal to 0 are
                excluded from the softmax.

        Returns:
            Tuple ``(context, attn_weights)`` with shapes [B, H] and [B, L].
        """
        target = target.unsqueeze(1)  # [B, 1, H]

        # Score every (target, key) pair.
        paired = torch.cat([target.expand_as(keys), keys], dim=-1)
        energy = torch.tanh(self.linear(paired))
        scores = self.fc(energy).squeeze(-1)  # [B, L]

        if mask is not None:
            # Use the most negative value representable in the score dtype:
            # a literal such as -1e9 overflows float16 and makes masked_fill
            # fail under half precision.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attn_weights = F.softmax(scores, dim=-1)  # [B, L]
        context = torch.bmm(attn_weights.unsqueeze(1), keys).squeeze(1)  # [B, H]
        return context, attn_weights

class AUGRUCell(nn.Module):
    """GRU cell whose update gate is scaled by an attention score."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        fused_size = input_size + hidden_size
        # Layers are created in the order update -> reset -> candidate.
        self.update_gate = nn.Linear(fused_size, hidden_size)
        self.reset_gate = nn.Linear(fused_size, hidden_size)
        self.candidate = nn.Linear(fused_size, hidden_size)

    def forward(self, x, h_prev, att_score):
        """Advance the hidden state by one step.

        Args:
            x: Current input, [B, H].
            h_prev: Previous hidden state, [B, H].
            att_score: Attention score for this step, [B, 1].

        Returns:
            The next hidden state.
        """
        x_and_h = torch.cat((x, h_prev), dim=1)

        # Plain GRU gates.
        update = torch.sigmoid(self.update_gate(x_and_h))
        reset = torch.sigmoid(self.reset_gate(x_and_h))
        proposal = torch.tanh(
            self.candidate(torch.cat((x, reset * h_prev), dim=1))
        )

        # Scale the update gate by the attention score before blending, so a
        # score of 0 leaves the previous state untouched.
        gate = att_score * update
        return (1 - gate) * h_prev + gate * proposal

class InterestExtractor(nn.Module):
    """Interest extraction layer: a GRU over the behaviour sequence.

    Args:
        input_size: Feature size of each behaviour vector.
        hidden_size: Hidden size of the GRU.
    """

    def __init__(self, input_size, hidden_size):
        super(InterestExtractor, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, input_size)  # auxiliary prediction head

    def forward(self, x, lengths):
        """Run the GRU and, in training mode, compute the auxiliary loss.

        Args:
            x: Padded behaviour features, [B, L, input_size].
            lengths: Valid length of every sequence, [B].

        Returns:
            Tuple ``(outputs, aux_loss)``.  ``outputs`` has shape
            [B, L, hidden_size] (same L as ``x``, zeros at padded steps).
            ``aux_loss`` is a scalar tensor; it is zero outside training
            mode or when no sequence has a next step to predict.
        """
        total_len = x.size(1)
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.gru(packed)
        # total_length keeps the time axis equal to the input's.  Without it
        # the output shrinks to the longest sequence in the batch and no
        # longer lines up with `x` (or with masks built from the padded
        # length).
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=total_len)

        # Always return a tensor so callers can use .item() / add it to a
        # loss regardless of the module's mode.
        aux_loss = x.new_zeros(())
        if self.training and total_len > 1:
            # The hidden state at step t predicts the input at step t + 1.
            # NOTE(review): the targets are the raw input features, which are
            # not guaranteed to lie in [0, 1] -- confirm this is the intended
            # auxiliary objective.
            shifted = x[:, 1:, :]
            pred = self.fc(outputs[:, :-1, :])
            per_elem = F.binary_cross_entropy_with_logits(
                pred, shifted, reduction='none')

            # Only transitions t -> t + 1 that lie fully inside the valid
            # part of a sequence contribute; padded steps are ignored.
            steps = torch.arange(total_len - 1, device=x.device)
            valid = steps[None, :] < (lengths.to(x.device)[:, None] - 1)
            valid = valid.to(per_elem.dtype)
            denom = (valid.sum() * x.size(-1)).clamp(min=1)
            aux_loss = (per_elem * valid.unsqueeze(-1)).sum() / denom

        return outputs, aux_loss

class InterestEvolving(nn.Module):
    """Interest evolving layer: an attention-gated GRU over interest states.

    Args:
        input_size: Feature size of each interest state.
        hidden_size: Hidden size of the AUGRU cell and of the attention layer.
    """

    def __init__(self, input_size, hidden_size):
        super(InterestEvolving, self).__init__()
        self.augru_cell = AUGRUCell(input_size, hidden_size)
        self.attention = Attention(hidden_size)

    def forward(self, interests, target, mask):
        """Evolve the interest states towards the target.

        Args:
            interests: Interest states, [B, L, H].
            target: Target representation, [B, H].
            mask: Validity mask, [B, L]; zeros mark padded steps.

        Returns:
            The hidden state after the last time step, [B, H].
        """
        batch_size, _, hidden_size = interests.size()

        # Start from a zero state allocated directly with the dtype and
        # device of the input (no CPU allocation followed by a copy).
        h = interests.new_zeros(batch_size, hidden_size)

        # Attention of the target over all steps; only the weights are used.
        _, att_weights = self.attention(target, interests, mask)

        # Step through time.  Masked steps receive (near-)zero attention, so
        # the scaled update gate leaves `h` essentially unchanged there.
        for x_t, score_t in zip(interests.unbind(dim=1),
                                att_weights.unbind(dim=1)):
            h = self.augru_cell(x_t, h, score_t.unsqueeze(1))

        return h

class DIEN(nn.Module):
    """Complete DIEN model: embeddings, interest extraction/evolution, MLP.

    The target representation is the concatenation of the item and category
    embeddings (size ``2 * emb_dim``) and is attended against interest states
    of size ``hidden_size``; the attention layer expands the target to the
    keys' shape, so ``hidden_size`` must equal ``2 * emb_dim``.

    Args:
        user_vocab_size: Number of rows of the user embedding table.
        item_vocab_size: Number of rows of the item embedding table.
        cat_vocab_size: Number of rows of the category embedding table.
        emb_dim: Size of every embedding vector.
        hidden_size: Hidden size of the interest layers.
        max_seq_len: Stored on the module as ``self.max_seq_len``.
    """

    def __init__(self, user_vocab_size, item_vocab_size, cat_vocab_size,
                 emb_dim=32, hidden_size=64, max_seq_len=50):
        super(DIEN, self).__init__()

        # Embedding tables
        self.user_emb = nn.Embedding(user_vocab_size, emb_dim)
        self.item_emb = nn.Embedding(item_vocab_size, emb_dim)
        self.cat_emb = nn.Embedding(cat_vocab_size, emb_dim)

        # Interest extraction layer
        self.interest_extractor = InterestExtractor(emb_dim*2, hidden_size)

        # Interest evolving layer
        self.interest_evolving = InterestEvolving(hidden_size, hidden_size)

        # Prediction MLP.  Its input is exactly what forward() concatenates:
        # user (E) + final interest (H) + target item (E) + target category
        # (E) + target item/category pair (2E) = H + 5E.  This equals the
        # former 2H + 3E whenever H == 2E (e.g. the defaults).
        self.fc = nn.Sequential(
            nn.Linear(hidden_size + emb_dim*5, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self.max_seq_len = max_seq_len

    def forward(self, inputs):
        """Predict the click probability for a batch.

        Args:
            inputs: Dict with the keys ``'user_id'``, ``'target_item'``,
                ``'target_cat'``, ``'hist_items'``, ``'hist_cats'`` and
                ``'seq_lengths'``.  The module's own demo passes [B, 1] id
                tensors, [B, L] history tensors and a [B] length tensor.

        Returns:
            Tuple ``(output, aux_loss)``: sigmoid probabilities of shape
            [B, 1] and the auxiliary loss from the interest extractor.
        """
        # Unpack the inputs
        user_ids = inputs['user_id']
        target_item = inputs['target_item']
        target_cat = inputs['target_cat']
        hist_items = inputs['hist_items']
        hist_cats = inputs['hist_cats']
        seq_lengths = inputs['seq_lengths']

        # Embedding lookups
        user_emb = self.user_emb(user_ids)
        target_item_emb = self.item_emb(target_item)
        target_cat_emb = self.cat_emb(target_cat)
        hist_item_emb = self.item_emb(hist_items)
        hist_cat_emb = self.cat_emb(hist_cats)

        # Concatenate item and category features
        hist_emb = torch.cat([hist_item_emb, hist_cat_emb], dim=-1)  # [B, L, 2*E]
        target_emb = torch.cat([target_item_emb, target_cat_emb], dim=-1).squeeze(1)

        # Interest extraction
        interests, aux_loss = self.interest_extractor(hist_emb, seq_lengths)

        # Sequence mask.  It is sized from the actual time dimension of
        # `interests` (not from max_seq_len) and created on the same device,
        # so it always lines up with the attention scores.
        steps = torch.arange(interests.size(1), device=interests.device)
        mask = (steps[None, :]
                < seq_lengths.to(interests.device)[:, None]).float()

        # Interest evolution
        final_interest = self.interest_evolving(interests, target_emb, mask)

        # Concatenate all features
        concat = torch.cat([
            user_emb.squeeze(1),
            final_interest,
            target_item_emb.squeeze(1),
            target_cat_emb.squeeze(1),
            target_emb
        ], dim=1)

        # Final prediction
        output = torch.sigmoid(self.fc(concat))

        return output, aux_loss

# Example usage
if __name__ == "__main__":
    # Example hyper-parameters
    config = {
        'user_vocab_size': 10000,
        'item_vocab_size': 50000,
        'cat_vocab_size': 500,
        'emb_dim': 32,
        'hidden_size': 64,
        'max_seq_len': 20
    }

    # Build the model
    model = DIEN(**config)

    # Random inputs whose ranges are derived from the config, so the demo
    # stays consistent when the config is edited.
    batch_size = 32
    max_len = config['max_seq_len']
    inputs = {
        'user_id': torch.randint(0, config['user_vocab_size'], (batch_size, 1)),
        'target_item': torch.randint(0, config['item_vocab_size'], (batch_size, 1)),
        'target_cat': torch.randint(0, config['cat_vocab_size'], (batch_size, 1)),
        'hist_items': torch.randint(0, config['item_vocab_size'], (batch_size, max_len)),
        'hist_cats': torch.randint(0, config['cat_vocab_size'], (batch_size, max_len)),
        # randint's upper bound is exclusive: +1 allows full-length sequences.
        'seq_lengths': torch.randint(1, max_len + 1, (batch_size,))
    }

    # Forward pass
    output, aux_loss = model(inputs)
    print(f"Output shape: {output.shape}")
    print(f"Aux loss: {aux_loss.item()}")
