In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [16]:
# AUGRU implementation
class AUGRUCell(nn.Module):
    """GRU cell whose update gate is scaled by an attention score (AUGRU).

    Compared with a standard GRU cell, the update gate is multiplied by a
    per-sample attention score before it blends the previous hidden state
    with the candidate state, so a score of 0 leaves the hidden state
    untouched.

    Args:
        input_size: Size of the last dimension of the input ``x``.
        hidden_size: Size of the hidden state.
        bias: When ``False`` the three gates are computed without bias terms.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Input-to-hidden weights stored side by side: (W_xr | W_xz | W_xh).
        # Allocated uninitialised; reset_parameters() fills them below.
        self.weight_xrzh = nn.Parameter(
            torch.empty(input_size, 3 * hidden_size, dtype=torch.float32))
        # Hidden-to-hidden weights stored side by side: (W_hr | W_hz | W_hh).
        self.weight_hrzh = nn.Parameter(
            torch.empty(hidden_size, 3 * hidden_size, dtype=torch.float32))
        if bias:
            self.bias_r = nn.Parameter(torch.zeros(hidden_size))
            self.bias_z = nn.Parameter(torch.zeros(hidden_size))
            self.bias_h = nn.Parameter(torch.zeros(hidden_size))
        else:
            # Registered as None so the attributes exist but hold no parameter.
            self.register_parameter('bias_r', None)
            self.register_parameter('bias_z', None)
            self.register_parameter('bias_h', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights uniformly in ``[-1/sqrt(H), 1/sqrt(H)]``; zero biases."""
        stdv = 1.0 / self.hidden_size ** 0.5
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param, -stdv, stdv)
            elif 'bias' in name:
                nn.init.zeros_(param)

    @staticmethod
    def _add_bias(pre_activation, bias):
        """Add ``bias`` to ``pre_activation`` unless the cell was built without bias."""
        if bias is None:
            return pre_activation
        return pre_activation + bias

    def forward(self, x, hidden_state, att_score):
        """Run one AUGRU step.

        Args:
            x: Input for this step; its last dimension must be ``input_size``
                (expected shape ``[B, input_size]``).
            hidden_state: Previous hidden state (expected ``[B, hidden_size]``).
            att_score: One attention score per sample; it is reshaped to a
                column ``[B, 1]`` and broadcast over the hidden dimension.

        Returns:
            The new hidden state, same shape as ``hidden_state``.
        """
        W_xr, W_xz, W_xh = self.weight_xrzh.chunk(3, 1)
        W_hr, W_hz, W_hh = self.weight_hrzh.chunk(3, 1)

        reset_gate = torch.sigmoid(self._add_bias(
            torch.matmul(x, W_xr) + torch.matmul(hidden_state, W_hr),
            self.bias_r))
        update_gate_pre = torch.sigmoid(self._add_bias(
            torch.matmul(x, W_xz) + torch.matmul(hidden_state, W_hz),
            self.bias_z))
        # Attentional update gate: the attention score scales how much of the
        # candidate state is allowed to replace the previous hidden state.
        update_gate = att_score.reshape(-1, 1) * update_gate_pre
        hidden_gate = torch.tanh(self._add_bias(
            torch.matmul(x, W_xh)
            + torch.matmul(reset_gate * hidden_state, W_hh),
            self.bias_h))
        return (1 - update_gate) * hidden_state + update_gate * hidden_gate


# Bilinear attention
class BLAttention(nn.Module):
    """Bilinear attention layer.

    Scores every key against the query with the bilinear form
    ``query @ W @ key^T`` and normalises the scores with a softmax over the
    keys.

    Args:
        embed_size: Size of the last dimension of ``keys``.
        hidden_size: Size of the last dimension of ``query``.
    """

    def __init__(self, embed_size,hidden_size):
        super(BLAttention, self).__init__()
        # Bilinear weight W of shape [hidden_size, embed_size], Xavier-initialised.
        weight = torch.zeros(hidden_size, embed_size)
        nn.init.xavier_uniform_(weight)
        self.attention_W = nn.Parameter(weight)

    def forward(self, query, keys):
        """Return softmax-normalised bilinear scores of ``keys`` for ``query``.

        Args:
            query: Expected shape ``[B, hidden_size]``.
            keys: Expected shape ``[B, K, embed_size]``.

        Returns:
            Attention weights of shape ``[B, K]`` that sum to 1 over ``K``.
            NOTE: with a single key (``K == 1``) the softmax is taken over one
            element, so the result is always 1.
        """
        # Project the query into the key space: [B, H] @ [H, E] -> [B, 1, E].
        projected_query = (query @ self.attention_W).unsqueeze(1)
        # Dot the projected query with every key: [B, 1, E] @ [B, E, K] -> [B, K].
        keys_t = keys.transpose(1, 2)
        logits = torch.matmul(projected_query, keys_t).squeeze(1)
        return logits.softmax(dim=-1)


class InterestExtractor(nn.Module):
    """Interest extraction layer: a GRU run over the behaviour sequence.

    Takes a behaviour sequence ``[B, L, input_size]`` and returns the GRU
    hidden-state sequence ``[B, L, hidden_size]``.

    NOTE: the auxiliary loss is not computed; ``forward`` returns only the
    hidden states. ``self.fc`` is created for that auxiliary head but is not
    used by ``forward``.

    Args:
        input_size: Feature size of each behaviour in the sequence.
        hidden_size: Hidden size of the GRU.
    """

    def __init__(self, input_size, hidden_size):
        super(InterestExtractor, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        # Projection for the auxiliary classification head (currently unused).
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, x, mask=None):
        """Encode the behaviour sequence and zero out padded positions.

        Args:
            x: Behaviour sequence, ``[B, L, input_size]``.
            mask: Optional ``[B, L]`` tensor; positions equal to 0 are treated
                as padding. When ``None`` no masking is applied.

        Returns:
            Hidden states of shape ``[B, L, hidden_size]``; positions where
            ``mask == 0`` are set to 0.
        """
        outputs, _ = self.gru(x)
        if mask is not None:
            # Expand to [B, L, 1] only once we know a mask was supplied, so it
            # broadcasts over the hidden dimension.
            outputs = outputs.masked_fill(mask.unsqueeze(-1) == 0, 0)
        return outputs


class InterestEvolving(nn.Module):
    """Interest evolving layer: attention-weighted AUGRU over the interest sequence.

    Takes the hidden-state sequence of the extraction GRU ``[B, L, hidden_size]``
    together with the target embedding and returns the final evolved interest
    state ``[B, hidden_size]``.

    Args:
        input_size: Size of the target embedding (last dimension of ``target``).
        hidden_size: Size of the interest states and of the AUGRU hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(InterestEvolving, self).__init__()
        self.hidden_size = hidden_size
        self.augru_cell = AUGRUCell(hidden_size, hidden_size)
        self.attention = BLAttention(input_size,hidden_size)

    def _attention_weights(self, interests, target, mask):
        """Return bilinear attention weights ``[B, L]`` normalised over time steps."""
        # [B, L, H] @ [H, E] -> [B, L, E]
        projected = torch.matmul(interests, self.attention.attention_W)
        # [B, L, E] @ [B, E, 1] -> [B, L]
        scores = torch.matmul(projected, target.transpose(1, 2)).squeeze(-1)
        if mask is None:
            return torch.softmax(scores, dim=-1)
        valid = mask != 0
        # Push padded steps to the most negative value so they get ~0 weight.
        scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)
        # Force padded steps to exactly 0; this also covers fully padded rows,
        # where the softmax alone would return a uniform distribution.
        return weights * valid.to(weights.dtype)

    def forward(self, interests, target, mask=None):
        """Evolve the interest states towards the target item.

        Args:
            interests: Interest states, ``[B, L, hidden_size]``.
            target: Target embedding, ``[B, 1, input_size]``.
            mask: Optional ``[B, L]`` tensor; positions equal to 0 are padding.

        Returns:
            Final AUGRU hidden state of shape ``[B, hidden_size]``.
        """
        if mask is not None:
            interests = interests.masked_fill(mask.unsqueeze(-1) == 0, 0)
        batch_size, seq_len, _ = interests.shape

        # Scores are normalised across the whole sequence. Scoring one step at
        # a time against a single target would make every softmax output 1 and
        # reduce the AUGRU to a plain GRU.
        att_weights = self._attention_weights(interests, target, mask)

        # Start from zeros on the same device and dtype as the inputs.
        h = interests.new_zeros(batch_size, self.hidden_size)

        # Padded steps carry weight 0, so they leave the hidden state unchanged.
        for t in range(seq_len):
            h = self.augru_cell(interests[:, t, :], h, att_weights[:, t])
        return h


class DIEN(nn.Module):
    """Complete DIEN model: embeddings, interest extraction, interest evolving, MLP.

    Args:
        user_vocab_size: Number of rows in the user embedding table.
        item_vocab_size: Number of rows in the item embedding table; also the
            width of the output layer.
        cat_vocab_size: Number of rows in the category embedding table.
        emb_dim: Embedding size shared by users, items and categories.
        hidden_size: Hidden size of the extraction GRU and the AUGRU.
        max_seq_len: Stored on the module as ``self.max_seq_len``; the sequence
            mask itself is built from the actual history length in the batch.
    """

    def __init__(self, user_vocab_size, item_vocab_size, cat_vocab_size,
                 emb_dim=32, hidden_size=64, max_seq_len=50):
        super(DIEN, self).__init__()

        # Embedding tables
        self.user_emb = nn.Embedding(user_vocab_size, emb_dim)
        self.item_emb = nn.Embedding(item_vocab_size, emb_dim)
        self.cat_emb = nn.Embedding(cat_vocab_size, emb_dim)

        # Each behaviour is the concatenation of item and category embeddings,
        # hence the 2 * emb_dim input size for both interest layers.
        self.interest_extractor = InterestExtractor(emb_dim * 2, hidden_size)
        self.interest_evolving = InterestEvolving(emb_dim * 2, hidden_size)

        # MLP input: user (E) + final interest (H) + target item and category (2E).
        self.fc = nn.Sequential(
            nn.Linear(hidden_size + emb_dim * 3, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, item_vocab_size)
        )

        self.max_seq_len = max_seq_len

    def forward(self, inputs):
        """Score a batch.

        Args:
            inputs: Dict with keys ``'user_id'``, ``'target_item'``,
                ``'target_cat'``, ``'hist_items'``, ``'hist_cats'`` and
                ``'seq_lengths'``. Expected shapes (assumed from the
                ``squeeze(1)`` calls and the mask construction below): ids
                ``[B, 1]``, histories ``[B, L]``, ``seq_lengths`` ``[B]``.

        Returns:
            Tensor of shape ``[B, item_vocab_size]`` with sigmoid-activated
            scores.
        """
        user_ids = inputs['user_id']
        target_item = inputs['target_item']
        target_cat = inputs['target_cat']
        hist_items = inputs['hist_items']
        hist_cats = inputs['hist_cats']
        seq_lengths = inputs['seq_lengths']

        # Embedding lookups
        user_emb = self.user_emb(user_ids)
        target_item_emb = self.item_emb(target_item)
        target_cat_emb = self.cat_emb(target_cat)
        hist_item_emb = self.item_emb(hist_items)
        hist_cat_emb = self.cat_emb(hist_cats)

        # Concatenate item and category features
        hist_emb = torch.cat([hist_item_emb, hist_cat_emb], dim=-1)  # [B, L, 2*E]
        target_emb = torch.cat([target_item_emb, target_cat_emb], dim=-1)

        # Sequence mask [B, L]: 1 for real behaviours, 0 for padding. It uses
        # the history length actually present in the batch, so inputs whose
        # length differs from max_seq_len still line up with hist_emb.
        seq_len = hist_items.size(1)
        mask = (torch.arange(seq_len, device=seq_lengths.device)[None, :]
                < seq_lengths[:, None]).float()

        # Interest extraction
        interests = self.interest_extractor(hist_emb, mask)

        # Interest evolving
        final_interest = self.interest_evolving(interests, target_emb, mask)

        # Concatenate all features
        concat = torch.cat([
            user_emb.squeeze(1),
            final_interest,
            target_emb.squeeze(1)
        ], dim=1)

        # Final prediction
        output = torch.sigmoid(self.fc(concat))

        return output


In [17]:
# Example configuration (toy sizes for a smoke test)
config = {
    'user_vocab_size': 100,
    'item_vocab_size': 200,
    'cat_vocab_size': 100,
    'emb_dim': 10,
    'hidden_size': 64,
    'max_seq_len': 20
}

# Build the model
model = DIEN(**config)

# Simulated input batch; index ranges and history length are taken from
# ``config`` so they cannot drift out of sync with the embedding tables.
batch_size = 10
seq_len = config['max_seq_len']
inputs = {
    # Only the first 10 user ids are sampled (a subset of user_vocab_size).
    'user_id': torch.randint(0, 10, (batch_size, 1)),
    'target_item': torch.randint(0, config['item_vocab_size'], (batch_size, 1)),
    'target_cat': torch.randint(0, config['cat_vocab_size'], (batch_size, 1)),
    'hist_items': torch.randint(0, config['item_vocab_size'], (batch_size, seq_len)),
    'hist_cats': torch.randint(0, config['cat_vocab_size'], (batch_size, seq_len)),
    'seq_lengths': torch.randint(1, 10, (batch_size,))
}

# Forward pass: run the model once and reuse the result for the shape.
output = model(inputs)
output, output.shape

(tensor([[0.4827, 0.4631, 0.5054,  ..., 0.5055, 0.5014, 0.5385],
         [0.4758, 0.4592, 0.5154,  ..., 0.5167, 0.5246, 0.5384],
         [0.4828, 0.4390, 0.5018,  ..., 0.5005, 0.5245, 0.5383],
         ...,
         [0.4801, 0.4497, 0.5000,  ..., 0.4990, 0.5174, 0.5312],
         [0.4848, 0.4545, 0.4964,  ..., 0.4998, 0.5270, 0.5214],
         [0.4764, 0.4445, 0.4855,  ..., 0.5044, 0.5354, 0.5508]],
        grad_fn=<SigmoidBackward0>),
 torch.Size([10, 200]))