In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [31]:
## AUGRU implementation
class AUGRUCell(nn.Module):
    """GRU cell with an attentional update gate (AUGRU), as used by DIEN's
    interest-evolving layer.

    The update gate ``z`` of a standard GRU is multiplied by a per-sample
    attention score ``a`` before mixing, so a step with ``a == 0`` leaves the
    hidden state unchanged.

    Args:
        input_size: feature dimension of the per-step input ``x``
            (e.g. the embedding / previous GRU output dimension).
        hidden_size: dimension of the hidden state (chosen freely).
        bias: if ``False`` the gates are computed without bias terms.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Input-to-hidden weights, column blocks (W_xr | W_xz | W_xh).
        # Values are overwritten by reset_parameters() below.
        self.weight_xrzh = nn.Parameter(torch.empty(input_size, 3 * hidden_size))
        # Hidden-to-hidden weights, column blocks (W_hr | W_hz | W_hh).
        self.weight_hrzh = nn.Parameter(torch.empty(hidden_size, 3 * hidden_size))
        if bias:
            # One bias vector per gate: reset (r), update (z), candidate (h).
            self.bias_r = nn.Parameter(torch.zeros(hidden_size))
            self.bias_z = nn.Parameter(torch.zeros(hidden_size))
            self.bias_h = nn.Parameter(torch.zeros(hidden_size))
        else:
            # Registered as None so the attributes exist; forward() skips them.
            self.register_parameter('bias_r', None)
            self.register_parameter('bias_z', None)
            self.register_parameter('bias_h', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Weights ~ U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)); biases set to zero."""
        stdv = 1.0 / self.hidden_size ** 0.5
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param, -stdv, stdv)
            elif 'bias' in name:
                nn.init.zeros_(param)

    @staticmethod
    def _add_bias(pre_activation, bias):
        """Return ``pre_activation + bias``, or ``pre_activation`` when there is no bias."""
        return pre_activation if bias is None else pre_activation + bias

    def forward(self, x, hidden_state, att_score):
        """Run one AUGRU step.

        Args:
            x: [B, input_size] input at the current step (presumably the
                interest-extractor GRU output for that step).
            hidden_state: [B, hidden_size] previous hidden state.
            att_score: B attention scores in any shape with B elements
                (e.g. [B] or [B, 1]); reshaped to [B, 1].

        Returns:
            [B, hidden_size] updated hidden state.
        """
        W_xr, W_xz, W_xh = self.weight_xrzh.chunk(3, dim=1)
        W_hr, W_hz, W_hh = self.weight_hrzh.chunk(3, dim=1)

        reset_gate = torch.sigmoid(
            self._add_bias(torch.matmul(x, W_xr) + torch.matmul(hidden_state, W_hr), self.bias_r))
        update_gate_pre = torch.sigmoid(
            self._add_bias(torch.matmul(x, W_xz) + torch.matmul(hidden_state, W_hz), self.bias_z))
        # Attention scales the update gate: low attention => the step barely moves h.
        update_gate = att_score.reshape(-1, 1) * update_gate_pre
        candidate = torch.tanh(
            self._add_bias(torch.matmul(x, W_xh) + torch.matmul(reset_gate * hidden_state, W_hh),
                           self.bias_h))

        return (1 - update_gate) * hidden_state + update_gate * candidate

# Bilinear attention
class BLAttention(nn.Module):
    """Bilinear attention layer.

    For each position ``l`` the score is ``query^T W keys[:, l]``; scores are
    masked and softmax-normalised over the sequence, and the weights are used
    to pool ``keys``.

    Args:
        hidden_size: feature dimension of ``query`` (rows of ``W``).
        embed_size: feature dimension of ``keys`` (columns of ``W``);
            defaults to ``hidden_size``.
    """

    def __init__(self, hidden_size, embed_size=None):
        super(BLAttention, self).__init__()
        if embed_size is None:
            embed_size = hidden_size
        self.attention_W = nn.Parameter(torch.empty(hidden_size, embed_size))
        nn.init.xavier_uniform_(self.attention_W)

    def forward(self, query, keys, mask=None):
        """Compute attention of ``query`` over ``keys``.

        Args:
            query: [B, hidden_size].
            keys: [B, L, embed_size].
            mask: optional [B, L]; positions equal to 0 are excluded.

        Returns:
            context: [B, embed_size] attention-weighted sum of ``keys``.
            attn_weights: [B, L] weights summing to 1 over L.
        """
        projected = torch.matmul(query, self.attention_W)                  # [B, E]
        scores = torch.bmm(keys, projected.unsqueeze(-1)).squeeze(-1)      # [B, L]

        if mask is not None:
            # Most negative finite value of the dtype: safe for fp16 too, and a
            # fully-masked row degrades to uniform weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)                           # [B, L]
        context = torch.bmm(attn_weights.unsqueeze(1), keys).squeeze(1)    # [B, E]
        return context, attn_weights


class InterestExtractor(nn.Module):
    """Interest-extractor layer of DIEN: a GRU over the padded behaviour sequence.

    Args:
        input_size: feature dimension of each behaviour embedding.
        hidden_size: GRU hidden dimension (dimension of the extracted interests).
    """

    def __init__(self, input_size, hidden_size):
        super(InterestExtractor, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        # Auxiliary head: maps the interest at step t to a prediction of behaviour t+1.
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, x, lengths):
        """Run the GRU over the valid prefix of each sequence.

        Args:
            x: [B, L, input_size] padded behaviour embeddings.
            lengths: [B] number of valid steps per sample (1 <= length <= L).

        Returns:
            outputs: [B, L, hidden_size] GRU outputs, zero at padded steps.
            aux_loss: scalar tensor; 0 when not training.
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        outputs, _ = self.gru(packed)
        # total_length keeps the time axis equal to x.size(1); without it the
        # result is truncated to max(lengths) and no longer lines up with x or
        # with the caller's [B, L] mask.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, total_length=x.size(1))

        aux_loss = x.new_zeros(())
        if self.training:
            aux_loss = self._auxiliary_loss(outputs, x, lengths)
        return outputs, aux_loss

    def _auxiliary_loss(self, outputs, x, lengths):
        """Masked MSE between fc(h_t) and the embedding of the next behaviour t+1.

        Only transitions inside the valid prefix (t + 1 < length) contribute.
        NOTE(review): the DIEN paper uses a click / non-click loss with sampled
        negative behaviours; this regression target is a stand-in because no
        negatives are provided. BCE-with-logits against raw embeddings (the
        previous version) is unbounded below for targets outside [0, 1].
        """
        pred = self.fc(outputs[:, :-1, :])          # [B, L-1, input_size]
        target = x[:, 1:, :]                        # next behaviour as label
        steps = torch.arange(x.size(1) - 1, device=x.device)
        valid = (steps[None, :] < (lengths.to(x.device) - 1)[:, None]).to(x.dtype)  # [B, L-1]
        per_step = ((pred - target) ** 2).mean(dim=-1)                              # [B, L-1]
        # clamp_min avoids 0/0 when no sample has two valid steps.
        return (per_step * valid).sum() / valid.sum().clamp_min(1.0)

class InterestEvolving(nn.Module):
    """Interest-evolving layer of DIEN.

    Scores every extracted interest against the target with bilinear attention,
    then runs an AUGRU over the sequence with those scores gating each update.

    Args:
        input_size: feature dimension of the interests (AUGRU input).
        hidden_size: AUGRU hidden dimension (dimension of the returned state).
        target_size: feature dimension of the target vector used as the
            attention query; defaults to ``input_size``.
    """

    def __init__(self, input_size, hidden_size, target_size=None):
        super(InterestEvolving, self).__init__()
        if target_size is None:
            target_size = input_size
        self.augru_cell = AUGRUCell(input_size, hidden_size)
        # Query = target (target_size features), keys = interests (input_size features).
        self.attention = BLAttention(target_size, input_size)

    def forward(self, interests, target, mask):
        """Evolve the interest sequence towards the target.

        Args:
            interests: [B, L, input_size] output of the interest extractor.
            target: [B, target_size] target representation.
            mask: [B, L], nonzero at valid (non-padded) steps.

        Returns:
            [B, hidden_size] final AUGRU hidden state.
        """
        batch_size, seq_len, _ = interests.size()

        # The state must have the AUGRU's hidden size (not the interests' feature
        # size) and the interests' device/dtype.
        h = interests.new_zeros(batch_size, self.augru_cell.hidden_size)

        _, att_weights = self.attention(target, interests, mask)  # [B, L]

        for t in range(seq_len):
            # Masked (padded) steps get ~zero attention, which zeroes the AUGRU
            # update gate and carries h through unchanged.
            h = self.augru_cell(interests[:, t, :], h, att_weights[:, t].unsqueeze(1))

        return h

class DIEN(nn.Module):
    """Deep Interest Evolution Network (DIEN) for CTR prediction.

    Args:
        user_vocab_size, item_vocab_size, cat_vocab_size: embedding table sizes.
        emb_dim: embedding dimension E shared by the user / item / category tables.
        hidden_size: hidden dimension H of the interest-extractor GRU and the AUGRU.
        max_seq_len: nominal padded history length; stored on the module but the
            sequence mask is built from the actual interest tensor length.
    """

    def __init__(self, user_vocab_size, item_vocab_size, cat_vocab_size,
                 emb_dim=32, hidden_size=64, max_seq_len=50):
        super(DIEN, self).__init__()

        # Embedding tables
        self.user_emb = nn.Embedding(user_vocab_size, emb_dim)
        self.item_emb = nn.Embedding(item_vocab_size, emb_dim)
        self.cat_emb = nn.Embedding(cat_vocab_size, emb_dim)

        # Interest extractor over [item_emb | cat_emb] behaviour vectors (2E features)
        self.interest_extractor = InterestExtractor(emb_dim * 2, hidden_size)

        # Interest evolving layer.
        # NOTE(review): its attention compares target_emb (2E features) with the
        # interests (H features) but is sized from hidden_size only, so this
        # presumably requires hidden_size == 2 * emb_dim — confirm.
        self.interest_evolving = InterestEvolving(hidden_size, hidden_size)

        # MLP head. Its input is the concat built in forward():
        # user (E) + final interest (H) + target item (E) + target cat (E)
        # + target_emb (2E) = H + 5E. (The previous 2H + 3E only matched when H == 2E.)
        # NOTE(review): target_emb duplicates the target item/cat embeddings; kept so
        # the layer shape is unchanged for existing configs.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size + emb_dim * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self.max_seq_len = max_seq_len

    def forward(self, inputs):
        """Predict click probability.

        Args:
            inputs: dict with keys 'user_id', 'target_item', 'target_cat'
                (the ``.squeeze(1)`` calls below assume shape [B, 1]),
                'hist_items', 'hist_cats' ([B, L] padded histories) and
                'seq_lengths' ([B] valid history lengths).

        Returns:
            (probability [B, 1], auxiliary loss from the interest extractor)
        """
        user_ids = inputs['user_id']
        target_item = inputs['target_item']
        target_cat = inputs['target_cat']
        hist_items = inputs['hist_items']
        hist_cats = inputs['hist_cats']
        seq_lengths = inputs['seq_lengths']

        # Embedding lookups
        user_emb = self.user_emb(user_ids)
        target_item_emb = self.item_emb(target_item)
        target_cat_emb = self.cat_emb(target_cat)
        hist_item_emb = self.item_emb(hist_items)
        hist_cat_emb = self.cat_emb(hist_cats)

        # Behaviour and target vectors
        hist_emb = torch.cat([hist_item_emb, hist_cat_emb], dim=-1)  # [B, L, 2E]
        target_emb = torch.cat([target_item_emb, target_cat_emb], dim=-1).squeeze(1)

        # Interest extraction
        interests, aux_loss = self.interest_extractor(hist_emb, seq_lengths)

        # Build the mask over the interests' actual time axis so the two always line
        # up (a hard-coded max_seq_len breaks when the batch length differs).
        steps = torch.arange(interests.size(1), device=interests.device)
        mask = (steps[None, :] < seq_lengths.to(interests.device)[:, None]).float()

        # Interest evolution
        final_interest = self.interest_evolving(interests, target_emb, mask)

        # Concatenate all features (order must match the fc input size above)
        concat = torch.cat([
            user_emb.squeeze(1),
            final_interest,
            target_item_emb.squeeze(1),
            target_cat_emb.squeeze(1),
            target_emb
        ], dim=1)

        output = torch.sigmoid(self.fc(concat))

        return output, aux_loss

# Example usage
if __name__ == "__main__":
    # Seed so the demo's random parameters and inputs are reproducible.
    torch.manual_seed(0)

    # Demo configuration
    config = {
        'user_vocab_size': 10000,
        'item_vocab_size': 50000,
        'cat_vocab_size': 500,
        'emb_dim': 32,
        'hidden_size': 64,
        'max_seq_len': 20
    }

    # Build the model
    model = DIEN(**config)

    # Random input batch, sized from the config rather than repeated literals
    batch_size = 32
    hist_len = config['max_seq_len']
    inputs = {
        'user_id': torch.randint(0, config['user_vocab_size'], (batch_size, 1)),
        'target_item': torch.randint(0, config['item_vocab_size'], (batch_size, 1)),
        'target_cat': torch.randint(0, config['cat_vocab_size'], (batch_size, 1)),
        'hist_items': torch.randint(0, config['item_vocab_size'], (batch_size, hist_len)),
        'hist_cats': torch.randint(0, config['cat_vocab_size'], (batch_size, hist_len)),
        # randint's upper bound is exclusive: +1 lets full-length histories occur.
        'seq_lengths': torch.randint(1, hist_len + 1, (batch_size,))
    }

    # Forward pass
    output, aux_loss = model(inputs)
    print(f"Output shape: {output.shape}")
    # float() works whether aux_loss is a 0-dim tensor or a plain number.
    print(f"Aux loss: {float(aux_loss)}")


In [115]:
batch_size = 2  # small batch for interactively inspecting the model inputs

In [116]:
# Random batch matching the demo config (vocab sizes 10000 / 50000 / 500,
# padded history length 20).
inputs = {
    'user_id': torch.randint(0, 10000, (batch_size, 1)),
    'target_item': torch.randint(0, 50000, (batch_size, 1)),
    'target_cat': torch.randint(0, 500, (batch_size, 1)),
    'hist_items': torch.randint(0, 50000, (batch_size, 20)),
    'hist_cats': torch.randint(0, 500, (batch_size, 20)),
    # randint's upper bound is exclusive: 20 + 1 allows full-length (20-step) histories.
    'seq_lengths': torch.randint(1, 20 + 1, (batch_size,))
}

In [117]:
# Show the sampled batch (bare last expression -> rendered as the cell output).
inputs

{'user_id': tensor([[9401],
         [7751]]),
 'target_item': tensor([[2413],
         [1085]]),
 'target_cat': tensor([[359],
         [334]]),
 'hist_items': tensor([[26272, 13704, 38834,  7080, 13441,  6823,  1643, 47795,  4518, 27885,
          32883,  5716, 47214, 30686, 48031,    77,  6285, 22474, 31309, 21887],
         [26810, 47436, 26162,  8322, 49154,  7238,  2166,  8375, 11906, 49889,
           1998, 11137,  7820, 45308, 27316, 42816, 20566, 46410, 46107, 38008]]),
 'hist_cats': tensor([[117,  10, 154, 349, 397, 439, 340, 156, 226, 331, 322,  53, 206,  36,
          465, 255, 487, 451, 217, 338],
         [291, 493, 298, 221, 132, 384, 164,   8, 390, 162,  16, 120, 292, 294,
          235, 232, 198, 368, 389,  12]]),
 'seq_lengths': tensor([15,  1])}