In [None]:
# 序列推荐模型
# GRU4Rec、DIEN

In [1]:
# 加载数据集2成序列数据集，评分[0,1,2]为负反馈，评分[3,4,5]为正反馈，只保留正样本，构造简单序列推荐数据集
# 数据集：ml-100k

import os, random
import numpy as np
import pandas as pd
random.seed(100)

# 加载数据: >=3分为正，用户评分次数不低于50，只保留最后50个，拆分为40: 5 + 15负例 (随机采样): 5 + 15负例 (随机采样)
ratings = np.array([[int(x) for x in line.strip().split('\t')[:4]] for line in open('./data/ml-100k/ua.base','r').read().strip().split('\n')], dtype=np.int32)
ratings_pd = pd.DataFrame({feature_name: list(feature_data) for feature_name, feature_data in zip(['user_id','item_id','rating','timestamp'], ratings.T)})
pos_ratings_pd = ratings_pd[ratings_pd['rating']>2.9][['user_id','item_id','timestamp']].dropna().sort_values('timestamp') # 已经排序了
pos_ratings_pd = pos_ratings_pd.groupby('user_id').filter(lambda x: x['user_id'].count()>=50)
userid2id = {user_id: i for i, user_id in enumerate(sorted(list(set(pos_ratings_pd['user_id'].tolist()))))}
itemid2id = {item_id: i for i, item_id in enumerate(sorted(list(set(pos_ratings_pd['item_id'].tolist()))))}
print(len(userid2id), len(itemid2id))
del ratings, ratings_pd

# new id
user_train_validate_test = {}
for user,item,t in pos_ratings_pd.values:
    u, i = userid2id[user], itemid2id[item]
    if u not in user_train_validate_test:
        user_train_validate_test[u] = [i]
    else:
        user_train_validate_test[u].append(i)
    user_train_validate_test[u] = user_train_validate_test[u][-50:]
train_seq_len = 40
pos_num = 5
neg_sample_num = 15
def sample(low, high, notinset, num):
    nums = set([])
    n = num
    while n>0:
        id = random.randint(low, high)
        if id not in notinset and id not in nums:
            nums.add(id)
            n -= 1
    return list(nums)
data = np.zeros((len(user_train_validate_test), 81), dtype=np.int32)
i = 0
for user, train_validate_test in user_train_validate_test.items():
    train, validate, test = train_validate_test[:train_seq_len], train_validate_test[-pos_num*2:-pos_num], train_validate_test[-pos_num:]
    data[i, 0] = user
    data[i,1:train_seq_len+1] = np.array(train)
    samples = sample(0, len(itemid2id)-1, set(train_validate_test), neg_sample_num * 2)
    data[i,1+train_seq_len : 1+train_seq_len+pos_num+neg_sample_num] = np.array(validate + samples[:neg_sample_num])
    data[i,1+train_seq_len+pos_num+neg_sample_num : ] = np.array(test + samples[neg_sample_num:])
    i += 1
del user_train_validate_test
print(data.shape)
print(data[:2,:])

446 1548
(446, 81)
[[ 397  303  260  306  312  744  257  285  338  270  682  862  328 1543
   344  881  326  298  867  265  673 1491  301  337  353  261  300 1260
  1238  302  325  334  331  351  347 1090  683  901  897  272  345   49
   585  309  521  126  386 1282 1038  539  288  417  418  931 1444  804
   164  933  941 1326  686 1001  316  324 1128  900 1091 1349  710  716
   470 1499  225   98 1508  357  365  757  887  248  633]
 [ 344  773  365  397  445   48  374   62  431  981  231  780  384  109
    39  775 1021  929 1022  393   93  399   89  386  569  715   66 1059
   748  414  459  418  139  832  495  831 1413  413  398  396   77  784
   717  786   50  142  768  259 1159  394   11  917  793 1306 1307  810
   302  944 1328  564   53  293  141  212  778  645  825  839  328  332
  1234 1363  342 1247  485  107 1134 1137  369 1277  254]]


In [17]:
# GRU4Rec: 
# user_embedding = GRU(item_embedding_seq)
# y = user_embedding * item_embedding.T
# 数据集：ml-100k


import torch
from torch import nn
from torch.nn import Module, CrossEntropyLoss, Sequential, Linear, Sigmoid
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
dim=100

train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(data[:,1: 1+ train_seq_len]).long(), torch.from_numpy(data[:,1+ train_seq_len:-(pos_num+neg_sample_num)]).long()), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(data[:,1: 1+ train_seq_len + pos_num]).long(), torch.from_numpy(data[:,-(pos_num+neg_sample_num) : ]).long()), batch_size=batch_size, shuffle=False, pin_memory=True)

class GRU4Rec(nn.Module):
    def __init__(self, num_items, embedding_dim, gru_num_layers=1):
        super(GRU4Rec, self).__init__()
        self.num_items = num_items
        self.embedding_dim, self.gru_num_layers = embedding_dim, gru_num_layers
        self.item_embeddings = nn.Embedding(num_items, self.embedding_dim, padding_idx=-1)
        torch.nn.init.kaiming_normal_(self.item_embeddings.weight.data)
        self.gru = nn.GRU(input_size=self.embedding_dim, hidden_size=self.embedding_dim, num_layers=self.gru_num_layers, batch_first=True)
    # [batch, seq_len], [batch, label_len]
    def forward(self, item_seqs: torch.Tensor, test: torch.Tensor):
        batch_len = item_seqs.shape[0]
        # [batch, seq_len, dim]
        item_seqs_embeddings = self.item_embeddings(item_seqs)
        # [batch, label_len, dim]
        test_embeddings = self.item_embeddings(test)
        # gru输出最后的隐层输出当为user embedding
        _, user_emb = self.gru(item_seqs_embeddings)
        # [batch, dim * gru_num_layers]
        user_emb = user_emb.reshape((batch_len, self.gru_num_layers * self.embedding_dim))
        # predict
        scores = torch.sigmoid(torch.bmm(test_embeddings.repeat([1,1,self.gru_num_layers]), user_emb.unsqueeze(-1)).squeeze())
        return scores
model = GRU4Rec(num_items = len(itemid2id), embedding_dim = dim).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0003)
criterion = CrossEntropyLoss(reduction='sum').to(device)
label = torch.FloatTensor([1 for i in range(pos_num)] + [0 for i in range(neg_sample_num)]).to(device)

def DCG(batch_labels):
    dcgsum = np.zeros((batch_labels.shape[0]))
    for i in range(batch_labels.shape[-1]):
        dcg = (2 ** batch_labels[:,i] - 1) / np.math.log(i + 2, 2)
        dcgsum += dcg
    return dcgsum
def NDCG(output, labels):
    # ideal_dcg
    ideal_dcg = DCG(labels)
    # this
    dcg = DCG((np.argsort( - output, axis=-1)<pos_num).astype(np.float32))
    return np.sum(dcg/ideal_dcg)

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        item_seqs = inputs[0].to(device)
        test = inputs[1].to(device)
        output = model(item_seqs, test)
        labels = label.unsqueeze(0).repeat([item_seqs.shape[0],1])
        loss = criterion(output, labels)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([item_seqs.shape[0], loss.item(), NDCG(output.detach().numpy(), labels.detach().numpy())])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        item_seqs = inputs[0].to(device)
        test = inputs[1].to(device)
        output = model(item_seqs, test)
        labels = label.unsqueeze(0).repeat([item_seqs.shape[0],1])
        loss = criterion(output, labels)
        epoch_test_losses.append([item_seqs.shape[0], loss.item(), NDCG(output.detach().numpy(), labels.detach().numpy())])
    train_loss = sum([x[1] for x in epoch_train_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_train_losses])
    test_loss  = sum([x[1] for x in epoch_test_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_test_losses])
    train_ndcg = sum([x[2] for x in epoch_train_losses])/sum([x[0] for x in epoch_train_losses])
    test_ndcg  = sum([x[2] for x in epoch_test_losses])/sum([x[0] for x in epoch_test_losses])
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_ce_loss: {:.4f}, train_ndcg: {:.4f}, validate_ce_loss: {:.4f}, validate_ndcg: {:.4f}'.format(epoch+1, num_epochs,  train_loss, train_ndcg, test_loss, test_ndcg))


[2023-09-01 16:22:16] epoch=[1/10], train_ce_loss: 0.7405, train_ndcg: 0.6901, validate_ce_loss: 0.7556, validate_ndcg: 0.6248
[2023-09-01 16:22:20] epoch=[2/10], train_ce_loss: 0.6774, train_ndcg: 0.8633, validate_ce_loss: 0.7632, validate_ndcg: 0.6208
[2023-09-01 16:22:23] epoch=[3/10], train_ce_loss: 0.6616, train_ndcg: 0.8945, validate_ce_loss: 0.7653, validate_ndcg: 0.6203
[2023-09-01 16:22:27] epoch=[4/10], train_ce_loss: 0.6592, train_ndcg: 0.9063, validate_ce_loss: 0.7659, validate_ndcg: 0.6252
[2023-09-01 16:22:30] epoch=[5/10], train_ce_loss: 0.6581, train_ndcg: 0.9107, validate_ce_loss: 0.7662, validate_ndcg: 0.6275
[2023-09-01 16:22:34] epoch=[6/10], train_ce_loss: 0.6568, train_ndcg: 0.9145, validate_ce_loss: 0.7664, validate_ndcg: 0.6273
[2023-09-01 16:22:38] epoch=[7/10], train_ce_loss: 0.6549, train_ndcg: 0.9156, validate_ce_loss: 0.7665, validate_ndcg: 0.6279
[2023-09-01 16:22:41] epoch=[8/10], train_ce_loss: 0.6524, train_ndcg: 0.9182, validate_ce_loss: 0.7673, valida

In [None]:
# DIEN
# Deep Interest Evolution Network for Click-Through Rate Prediction, 2018
# 数据集：ml-100k


class InterestExtractor(nn.Module):
    def __init__(self, input_size, use_neg=False, dnn_hidden_units=[100, 50, 1], init_std=0.001):
        super(InterestExtractor, self).__init__()
        self.use_neg = use_neg
        self.gru = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        if self.use_neg:
            self.auxiliary_net = DNN(input_size * 2, dnn_hidden_units, 'sigmoid', init_std=init_std)
        for name, tensor in self.gru.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    def forward(self, keys, keys_length, neg_keys=None):
        """
        keys:        [btz, seq_len, hdsz]
        keys_length: [btz, 1]
        neg_keys:    [btz, seq_len, hdsz]   
        """
        btz, seq_len, hdsz = keys.shape
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        if keys_length.shape[0] == 0:
            return torch.zeros(btz, hdsz, device=keys.device)

        # 过RNN
        masked_keys = torch.masked_select(keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)  # 去除全为0序列的样本
        packed_keys = pack_padded_sequence(masked_keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
        packed_interests, _ = self.gru(packed_keys)
        interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, total_length=seq_len)

        # 计算auxiliary_loss
        if self.use_neg and neg_keys is not None:
            masked_neg_keys = torch.masked_select(neg_keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
            aux_loss = self._cal_auxiliary_loss(interests[:, :-1, :], masked_keys[:, 1:, :], 
                                                masked_neg_keys[:, 1:, :], keys_length - 1)
        return interests, aux_loss

    def _cal_auxiliary_loss(self, states, click_seq, noclick_seq, keys_length):
        """
        states:        [btz, seq_len, hdsz]
        click_seq:     [btz, seq_len, hdsz]   
        noclick_seq:   [btz, seq_len, hdsz]
        keys_length:   [btz, 1]
        """
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        if keys_length.shape[0] == 0:
            return torch.zeros((1,), device=states.device)
        
        # 去除全为0序列的样本
        btz, seq_len, hdsz = states.shape
        states = torch.masked_select(states, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
        click_seq = torch.masked_select(click_seq, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
        noclick_seq = torch.masked_select(noclick_seq, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)

        # 仅对非mask部分计算loss
        mask = torch.arange(seq_len, device=states.device) < keys_length[:, None]
        click_input = torch.cat([states, click_seq], dim=-1)  # [btz, seq_len, hdsz*2]
        noclick_input = torch.cat([states, noclick_seq], dim=-1)  # [btz, seq_len, hdsz*2]
        click_p = self.auxiliary_net(click_input.view(-1, hdsz*2)).view(btz, seq_len)[mask].view(-1, 1)
        noclick_p = self.auxiliary_net(noclick_input.view(-1, hdsz*2)).view(btz, seq_len)[mask].view(-1, 1)
        click_target = torch.ones_like(click_p)
        noclick_target = torch.zeros_like(click_p)

        loss = F.binary_cross_entropy(torch.cat([click_p, noclick_p], dim=0), torch.cat([click_target, noclick_target], dim=0))
        return loss


class AGRUCell(nn.Module):
    """ Attention based GRU (AGRU)

        Reference:
        -  Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(AGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # (W_ir|W_iz|W_ih)
        self.weight_ih = nn.Parameter(torch.Tensor(3 * hidden_size, input_size))
        self.register_parameter('weight_ih', self.weight_ih)
        # (W_hr|W_hz|W_hh)
        self.weight_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
        self.register_parameter('weight_hh', self.weight_hh)
        if bias:
            # (b_ir|b_iz|b_ih)
            self.bias_ih = nn.Parameter(torch.Tensor(3 * hidden_size))
            self.register_parameter('bias_ih', self.bias_ih)
            # (b_hr|b_hz|b_hh)
            self.bias_hh = nn.Parameter(torch.Tensor(3 * hidden_size))
            self.register_parameter('bias_hh', self.bias_hh)
            for tensor in [self.bias_ih, self.bias_hh]:
                nn.init.zeros_(tensor, )
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

    def forward(self, inputs, hx, att_score):
        gi = F.linear(inputs, self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_r, _, i_n = gi.chunk(3, 1)
        h_r, _, h_n = gh.chunk(3, 1)

        reset_gate = torch.sigmoid(i_r + h_r)
        # update_gate = torch.sigmoid(i_z + h_z)
        new_state = torch.tanh(i_n + reset_gate * h_n)

        att_score = att_score.view(-1, 1)
        hy = (1. - att_score) * hx + att_score * new_state
        return hy


class AUGRUCell(nn.Module):
    """ Effect of GRU with attentional update gate (AUGRU)

        Reference:
        -  Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(AUGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # (W_ir|W_iz|W_ih)
        self.weight_ih = nn.Parameter(torch.Tensor(3 * hidden_size, input_size))
        self.register_parameter('weight_ih', self.weight_ih)
        # (W_hr|W_hz|W_hh)
        self.weight_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
        self.register_parameter('weight_hh', self.weight_hh)
        if bias:
            # (b_ir|b_iz|b_ih)
            self.bias_ih = nn.Parameter(torch.Tensor(3 * hidden_size))
            self.register_parameter('bias_ih', self.bias_ih)
            # (b_hr|b_hz|b_hh)
            self.bias_hh = nn.Parameter(torch.Tensor(3 * hidden_size))
            self.register_parameter('bias_ih', self.bias_hh)
            for tensor in [self.bias_ih, self.bias_hh]:
                nn.init.zeros_(tensor, )
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

    def forward(self, inputs, hx, att_score):
        gi = F.linear(inputs, self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)

        reset_gate = torch.sigmoid(i_r + h_r)
        update_gate = torch.sigmoid(i_z + h_z)
        new_state = torch.tanh(i_n + reset_gate * h_n)

        att_score = att_score.view(-1, 1)
        update_gate = att_score * update_gate
        hy = (1. - update_gate) * hx + update_gate * new_state
        return hy


class DynamicGRU(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True, gru_type='AGRU'):
        super(DynamicGRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        if gru_type == 'AGRU':
            self.rnn = AGRUCell(input_size, hidden_size, bias)
        elif gru_type == 'AUGRU':
            self.rnn = AUGRUCell(input_size, hidden_size, bias)

    def forward(self, inputs, att_scores=None, hx=None):
        if not isinstance(inputs, PackedSequence) or not isinstance(att_scores, PackedSequence):
            raise NotImplementedError("DynamicGRU only supports packed input and att_scores")

        inputs, batch_sizes, sorted_indices, unsorted_indices = inputs
        att_scores, _, _, _ = att_scores

        max_batch_size = int(batch_sizes[0])
        if hx is None:
            hx = torch.zeros(max_batch_size, self.hidden_size,
                             dtype=inputs.dtype, device=inputs.device)

        outputs = torch.zeros(inputs.size(0), self.hidden_size,
                              dtype=inputs.dtype, device=inputs.device)

        begin = 0
        for batch in batch_sizes:
            new_hx = self.rnn(
                inputs[begin:begin + batch],
                hx[0:batch],
                att_scores[begin:begin + batch])
            outputs[begin:begin + batch] = new_hx
            hx = new_hx
            begin += batch
        return PackedSequence(outputs, batch_sizes, sorted_indices, unsorted_indices)


class InterestEvolving(nn.Module):
    """DIEN中的兴趣演化模块
    """
    def __init__(self, input_size, gru_type='GRU', use_neg=False, init_std=0.001, 
                 att_hidden_size=(64, 16), att_activation='sigmoid', att_weight_normalization=False):
        super(InterestEvolving, self).__init__()
        assert gru_type in {'GRU', 'AIGRU', 'AGRU', 'AUGRU'}, f"gru_type: {gru_type} is not supported"
        self.gru_type = gru_type

        return_score = True
        if gru_type == 'GRU':
            return_score = False
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AIGRU':
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AGRU' or gru_type == 'AUGRU':
            self.interest_evolution = DynamicGRU(input_size=input_size, hidden_size=input_size, gru_type=gru_type)

        self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size, att_hidden_units=att_hidden_size, att_activation=att_activation,
                                                       weight_normalization=att_weight_normalization, return_score=return_score)

        for name, tensor in self.interest_evolution.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    @staticmethod
    def _get_last_state(states, keys_length):
        # states [B, T, H]
        batch_size, max_seq_length, _ = states.size()

        mask = (torch.arange(max_seq_length, device=keys_length.device).repeat(
            batch_size, 1) == (keys_length.view(-1, 1) - 1))

        return states[mask]

    def forward(self, query, keys, keys_length, mask=None):
        """
        query:       [btz, 1, hdsz]
        keys:        [btz, seq_len ,hdsz]
        keys_length: [btz, 1]
        """
        btz, seq_len, hdsz = keys.shape
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        zero_outputs = torch.zeros(btz, hdsz, device=query.device)
        if keys_length.shape[0] == 0:
            return zero_outputs

        query = torch.masked_select(query, smp_mask.view(-1, 1, 1)).view(-1, 1, hdsz)
        keys = torch.masked_select(keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)  # 去除全为0序列的样本

        if self.gru_type == 'GRU':
            packed_keys = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            packed_interests, _ = self.interest_evolution(packed_keys)
            interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, total_length=seq_len)
            outputs = self.attention(query, interests, keys_length.unsqueeze(1))  # [btz1, 1, hdsz]
            outputs = outputs.squeeze(1)  # [btz1, hdsz]

        elif self.gru_type == 'AIGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1))  # [btz1, 1, seq_len]
            interests = keys * att_scores.transpose(1,2)  # [btz1, seq_len, hdsz]
            packed_interests = pack_padded_sequence(interests, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            _, outputs = self.interest_evolution(packed_interests)
            outputs = outputs.squeeze(0)  # [btz1, hdsz]

        elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1)  # [b, T]
            packed_interests = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            packed_scores = pack_padded_sequence(att_scores, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            outputs = self.interest_evolution(packed_interests, packed_scores)
            outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=seq_len)
            # pick last state
            outputs = InterestEvolving._get_last_state(outputs, keys_length) # [b, H]
            
        # [b, H] -> [B, H]
        zero_outputs[smp_mask.squeeze(1)] = outputs
        return zero_outputs

In [None]:

class InterestEvolving(nn.Module):
    """DIEN中的兴趣演化模块
    """
    def __init__(self, input_size, gru_type='GRU', use_neg=False, init_std=0.001, 
                 att_hidden_size=(64, 16), att_activation='sigmoid', att_weight_normalization=False):
        super(InterestEvolving, self).__init__()
        assert gru_type in {'GRU', 'AIGRU', 'AGRU', 'AUGRU'}, f"gru_type: {gru_type} is not supported"
        self.gru_type = gru_type

        return_score = True
        if gru_type == 'GRU':
            return_score = False
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AIGRU':
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AGRU' or gru_type == 'AUGRU':
            self.interest_evolution = DynamicGRU(input_size=input_size, hidden_size=input_size, gru_type=gru_type)

        self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size, att_hidden_units=att_hidden_size, att_activation=att_activation,
                                                       weight_normalization=att_weight_normalization, return_score=return_score)

        for name, tensor in self.interest_evolution.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    @staticmethod
    def _get_last_state(states, keys_length):
        # states [B, T, H]
        batch_size, max_seq_length, _ = states.size()

        mask = (torch.arange(max_seq_length, device=keys_length.device).repeat(
            batch_size, 1) == (keys_length.view(-1, 1) - 1))

        return states[mask]

    def forward(self, query, keys, keys_length, mask=None):
        """
        query:       [btz, 1, hdsz]
        keys:        [btz, seq_len ,hdsz]
        keys_length: [btz, 1]
        """
        btz, seq_len, hdsz = keys.shape
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        zero_outputs = torch.zeros(btz, hdsz, device=query.device)
        if keys_length.shape[0] == 0:
            return zero_outputs

        query = torch.masked_select(query, smp_mask.view(-1, 1, 1)).view(-1, 1, hdsz)
        keys = torch.masked_select(keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)  # 去除全为0序列的样本

        if self.gru_type == 'GRU':
            packed_keys = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            packed_interests, _ = self.interest_evolution(packed_keys)
            interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, total_length=seq_len)
            outputs = self.attention(query, interests, keys_length.unsqueeze(1))  # [btz1, 1, hdsz]
            outputs = outputs.squeeze(1)  # [btz1, hdsz]

        elif self.gru_type == 'AIGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1))  # [btz1, 1, seq_len]
            interests = keys * att_scores.transpose(1,2)  # [btz1, seq_len, hdsz]
            packed_interests = pack_padded_sequence(interests, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            _, outputs = self.interest_evolution(packed_interests)
            outputs = outputs.squeeze(0)  # [btz1, hdsz]

        elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1)  # [b, T]
            packed_interests = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            packed_scores = pack_padded_sequence(att_scores, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
            outputs = self.interest_evolution(packed_interests, packed_scores)
            outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=seq_len)
            # pick last state
            outputs = InterestEvolving._get_last_state(outputs, keys_length) # [b, H]
            
        # [b, H] -> [B, H]
        zero_outputs[smp_mask.squeeze(1)] = outputs
        return zero_outputs
    
class DIN(RecBase):
    """Deep Interest Network实现
    """
    def __init__(self, dnn_feature_columns, item_history_list, dnn_hidden_units=(256, 128),
                 att_hidden_units=(64, 16), att_activation='Dice', att_weight_normalization=False,
                 l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=1e-4,
                 dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, out_dim=1, **kwargs):
        super(DIN, self).__init__([], dnn_feature_columns, l2_reg_embedding=l2_reg_embedding, init_std=init_std, out_dim=out_dim, **kwargs)
        del self.linear_model  # 删除不必要的网络结构
        
        self.sparse_feature_columns, self.dense_feature_columns, self.varlen_sparse_feature_columns = split_columns(dnn_feature_columns)
        self.item_history_list = item_history_list

        # 把varlen_sparse_feature_columns分解成hist、neg_hist和varlen特征
        # 其实是DIEN的逻辑（为了避免多次执行），DIN中少了neg模块，DIEN是在deepctr是在forward中会重复执行多次
        self.history_feature_names = list(map(lambda x: "hist_"+x, item_history_list))
        self.neg_history_feature_names = list(map(lambda x: "neg_" + x, self.history_feature_names))
        self.history_feature_columns = []
        self.neg_history_feature_columns = []
        self.sparse_varlen_feature_columns = []
        for fc in self.varlen_sparse_feature_columns:
            feature_name = fc.name
            if feature_name in self.history_feature_names:
                self.history_feature_columns.append(fc)
            elif feature_name in self.neg_history_feature_names:
                self.neg_history_feature_columns.append(fc)
            else:
                self.sparse_varlen_feature_columns.append(fc)

        # Attn模块
        att_emb_dim = self._compute_interest_dim()
        self.attention = AttentionSequencePoolingLayer(att_hidden_units=att_hidden_units, embedding_dim=att_emb_dim, att_activation=att_activation,
                                                       return_score=False, supports_masking=False, weight_normalization=att_weight_normalization)

        # DNN模块
        self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units, activation=dnn_activation, 
                       dropout_rate=dnn_dropout, use_bn=dnn_use_bn, init_std=init_std)
        self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False)
        self.add_regularization_weight(filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
        self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)


    def forward(self, X):
        # 过embedding
        emb_lists, query_emb, keys_emb, keys_length, deep_input_emb = self._get_emb(X)

        # 获取变长稀疏特征pooling的结果， [[btz, 1, emb_size]
        sequence_embed_dict = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_varlen_feature_columns)
        sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index, self.sparse_varlen_feature_columns)
        
        # Attn部分
        hist = self.attention(query_emb, keys_emb, keys_length)  # [btz, 1, hdsz]

        # dnn部分
        dnn_input_emb_list = emb_lists[2]
        dnn_input_emb_list += sequence_embed_list
        deep_input_emb = torch.cat([deep_input_emb, hist], dim=-1)  # [btz, 1, hdsz]
        dnn_input = combined_dnn_input([deep_input_emb], emb_lists[-1])  # [btz, hdsz]
        dnn_output = self.dnn(dnn_input)
        dnn_logit = self.dnn_linear(dnn_output)

        # 输出
        y_pred = self.out(dnn_logit)

        return y_pred
        
    def _get_emb(self, X):
        # 过embedding，这里改造embedding_lookup使得只经过一次embedding, 加快训练速度
        # query_emb_list     [[btz, 1, emb_size], ...]
        # keys_emb_list      [[btz, seq_len, emb_size], ...]
        # dnn_input_emb_list [[btz, 1, emb_size], ...]
        return_feat_list = [self.item_history_list, self.history_feature_names, [fc.name for fc in self.sparse_feature_columns]]
        emb_lists = embedding_lookup(X, self.embedding_dict, self.feature_index, self.dnn_feature_columns, return_feat_list=return_feat_list)
        query_emb_list, keys_emb_list, dnn_input_emb_list = emb_lists
        dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in self.dense_feature_columns]
        emb_lists.append(dense_value_list)

        query_emb = torch.cat(query_emb_list, dim=-1)  # [btz, 1, hdsz]
        keys_emb = torch.cat(keys_emb_list, dim=-1)  # [btz, 1, hdsz]
        keys_length = maxlen_lookup(X, self.feature_index, self.history_feature_names)  # [btz, 1]
        deep_input_emb = torch.cat(dnn_input_emb_list, dim=-1)  # [btz, 1, hdsz]
        return emb_lists, query_emb, keys_emb, keys_length, deep_input_emb

    def _compute_interest_dim(self):
        """计算兴趣网络特征维度和
        """
        dim_list = [feat.embedding_dim for feat in self.sparse_feature_columns if feat.name in self.item_history_list]
        return sum(dim_list)



In [None]:

class DIEN(DIN):
    """Deep Interest Evolution Network
    """
    def __init__(self, dnn_feature_columns, item_history_list, gru_type="GRU", use_negsampling=False, alpha=1.0, 
                 dnn_use_bn=False, dnn_hidden_units=(256, 128), dnn_activation='relu', att_hidden_units=(64, 16), att_activation="relu", 
                 att_weight_normalization=True, l2_reg_embedding=1e-6, l2_reg_dnn=0, dnn_dropout=0, init_std=0.0001, out_dim=1, **kwargs):
        super(DIEN, self).__init__(dnn_feature_columns, item_history_list, dnn_hidden_units, att_hidden_units, att_activation, att_weight_normalization, 
                                   l2_reg_embedding, l2_reg_dnn, init_std, dnn_dropout, dnn_activation, dnn_use_bn, out_dim, **kwargs)
        del self.attention
        self.alpha = alpha

        # 兴趣提取层
        input_size = self._compute_interest_dim()
        self.interest_extractor = InterestExtractor(input_size=input_size, use_neg=use_negsampling, init_std=init_std)

        # 兴趣演变层
        self.interest_evolution = InterestEvolving(input_size=input_size, gru_type=gru_type, use_neg=use_negsampling, init_std=init_std,
                                                   att_hidden_size=att_hidden_units, att_activation=att_activation, att_weight_normalization=att_weight_normalization)
        
        # DNN
        dnn_input_size = self.compute_input_dim(dnn_feature_columns, [('sparse', 'dense')]) + input_size
        self.dnn = DNN(dnn_input_size, dnn_hidden_units, activation=dnn_activation, 
                       dropout_rate=dnn_dropout, use_bn=dnn_use_bn, init_std=init_std)

    def forward(self, X):
        # 过embedding
        emb_lists, query_emb, keys_emb, keys_length, deep_input_emb = self._get_emb(X)
        neg_keys_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.dnn_feature_columns, return_feat_list=self.neg_history_feature_names)
        neg_keys_emb = torch.cat(neg_keys_emb_list, dim=-1)  # [btz, 1, hdsz]

        # 过兴趣提取层
        # input shape: [btz, seq_len, hdsz],  [btz, 1], [btz, seq_len, hdsz]
        # masked_interest shape: [btz, seq_len, hdsz]
        masked_interest, aux_loss = self.interest_extractor(keys_emb, keys_length, neg_keys_emb)
        self.add_auxiliary_loss(aux_loss, self.alpha)

        # 过兴趣演变层
        hist = self.interest_evolution(query_emb, masked_interest, keys_length)  # [btz, hdsz]

        # dnn部分
        deep_input_emb = torch.cat([deep_input_emb.squeeze(1), hist], dim=-1)  # [btz, hdsz]
        dnn_input = combined_dnn_input([deep_input_emb], emb_lists[-1])  # [btz, hdsz]
        dnn_output = self.dnn(dnn_input)
        dnn_logit = self.dnn_linear(dnn_output)

        # 输出
        y_pred = self.out(dnn_logit)

        return y_pred

In [None]:

class SequencePoolingLayer(nn.Module):
    """seq输入转Pooling，支持多种pooling方式
    """
    def __init__(self, mode='mean', support_masking=False):
        super(SequencePoolingLayer, self).__init__()
        assert mode in {'sum', 'mean', 'max'}, 'parameter mode should in [sum, mean, max]'
        self.mode = mode
        self.support_masking = support_masking
    
    def forward(self, seq_value_len_list):
        # seq_value_len_list: [btz, seq_len, hdsz], [btz, seq_len]/[btz,1]
        seq_input, seq_len = seq_value_len_list

        if self.support_masking:  # 传入的是mask
            mask = seq_len.float()
            user_behavior_len = torch.sum(mask, dim=-1, keepdim=True)  # [btz, 1]
            mask = mask.unsqueeze(2)  # [btz, seq_len, 1]
        else:  # 传入的是behavior长度
            user_behavior_len = seq_len
            mask = torch.arange(0, seq_input.shape[1]) < user_behavior_len.unsqueeze(-1)
            mask = torch.transpose(mask, 1, 2)  # [btz, seq_len, 1]
        
        mask = torch.repeat_interleave(mask, seq_input.shape[-1], dim=2)  # [btz, seq_len, hdsz]
        mask = (1 - mask).bool()
        
        if self.mode == 'max':
            seq_input = torch.masked_fill(seq_input, mask, 1e-8)
            return torch.max(seq_input, dim=1, keepdim=True)  # [btz, 1, hdsz]
        elif self.mode == 'sum':
            seq_input = torch.masked_fill(seq_input, mask, 0)
            return torch.sum(seq_input, dim=1, keepdim=True)  # [btz, 1, hdsz]
        elif self.mode == 'mean':
            seq_input = torch.masked_fill(seq_input, mask, 0)
            seq_sum = torch.sum(seq_input, dim=1, keepdim=True)
            return seq_sum / (user_behavior_len.unsqueeze(-1) + 1e-8)

class AttentionSequencePoolingLayer(nn.Module):
    """DIN中使用的序列注意力
    """
    def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False,
                 return_score=False, embedding_dim=4, **kwargs):
        super(AttentionSequencePoolingLayer, self).__init__()
        self.return_score = return_score
        self.weight_normalization = weight_normalization
        # 局部注意力单元
        self.dnn = DNN(input_dim=4 * embedding_dim, hidden_units=att_hidden_units, activation=att_activation, 
                       dice_dim=kwargs.get('dice_dim', 3), use_bn=kwargs.get('dice_dim', False), dropout_rate=kwargs.get('dropout_rate', 0))
        self.dense = nn.Linear(att_hidden_units[-1], 1)

    def forward(self, query, keys, keys_length, mask=None):
        """
        query: 候选item, [btz, 1, emb_size]
        keys:  历史点击序列, [btz, seq_len, emb_size]
        keys_len: keys的长度, [btz, 1]
        mask: [btz, seq_len]
        """
        btz, seq_len, emb_size = keys.shape

        # 计算注意力分数
        queries = query.expand(-1, seq_len, -1)
        attn_input = torch.cat([queries, keys, queries-keys, queries*keys], dim=-1)  # [btz, seq_len, 4*emb_size]
        attn_output = self.dnn(attn_input)  # [btz, seq_len, hidden_units[-1]]
        attn_score = self.dense(attn_output)  # [btz, seq_len, 1]

        # Mask处理
        if mask is not None:
            keys_mask = mask.unsqueeze(1)  # [btz, 1, seq_len]
        else:
            keys_mask = torch.arange(seq_len, device=keys.device).repeat(btz, 1)  # [btz, seq_len]
            keys_mask = keys_mask < keys_length
            keys_mask = keys_mask.unsqueeze(1)  # [btz, 1, seq_len]

        attn_score = attn_score.transpose(1, 2)  # [btz, 1, seq_len]
        if self.weight_normalization:
            # padding置为-inf，这样softmax后就是0
            attn_score = torch.masked_fill(attn_score, keys_mask.bool(), -1e-7)
            attn_score = F.softmax(attn_score, dim=-1)  # [btz, 1, seq_len]
        else:
            # padding置为0
            attn_score = torch.masked_fill(attn_score, keys_mask.bool(), 0)
        
        if not self.return_score:
            return torch.matmul(attn_score, keys)  # [btz, 1, emb_size]
        return attn_score


class InterestExtractor(nn.Module):
    """DIEN中的兴趣提取模块
    """
    def __init__(self, input_size, use_neg=False, dnn_hidden_units=[100, 50, 1], init_std=0.001):
        super(InterestExtractor, self).__init__()
        self.use_neg = use_neg
        self.gru = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        if self.use_neg:
            self.auxiliary_net = DNN(input_size * 2, dnn_hidden_units, 'sigmoid', init_std=init_std)
        for name, tensor in self.gru.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    def forward(self, keys, keys_length, neg_keys=None):
        """
        keys:        [btz, seq_len, hdsz]
        keys_length: [btz, 1]
        neg_keys:    [btz, seq_len, hdsz]   
        """
        btz, seq_len, hdsz = keys.shape
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        if keys_length.shape[0] == 0:
            return torch.zeros(btz, hdsz, device=keys.device)

        # 过RNN
        masked_keys = torch.masked_select(keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)  # 去除全为0序列的样本
        packed_keys = pack_padded_sequence(masked_keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
        packed_interests, _ = self.gru(packed_keys)
        interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, total_length=seq_len)

        # 计算auxiliary_loss
        if self.use_neg and neg_keys is not None:
            masked_neg_keys = torch.masked_select(neg_keys, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
            aux_loss = self._cal_auxiliary_loss(interests[:, :-1, :], masked_keys[:, 1:, :], 
                                                masked_neg_keys[:, 1:, :], keys_length - 1)
        return interests, aux_loss

    def _cal_auxiliary_loss(self, states, click_seq, noclick_seq, keys_length):
        """
        states:        [btz, seq_len, hdsz]
        click_seq:     [btz, seq_len, hdsz]   
        noclick_seq:   [btz, seq_len, hdsz]
        keys_length:   [btz, 1]
        """
        smp_mask = keys_length > 0
        keys_length = keys_length[smp_mask]  # [btz1, 1]

        # keys全部为空
        if keys_length.shape[0] == 0:
            return torch.zeros((1,), device=states.device)
        
        # 去除全为0序列的样本
        btz, seq_len, hdsz = states.shape
        states = torch.masked_select(states, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
        click_seq = torch.masked_select(click_seq, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)
        noclick_seq = torch.masked_select(noclick_seq, smp_mask.view(-1, 1, 1)).view(-1, seq_len, hdsz)

        # 仅对非mask部分计算loss
        mask = torch.arange(seq_len, device=states.device) < keys_length[:, None]
        click_input = torch.cat([states, click_seq], dim=-1)  # [btz, seq_len, hdsz*2]
        noclick_input = torch.cat([states, noclick_seq], dim=-1)  # [btz, seq_len, hdsz*2]
        click_p = self.auxiliary_net(click_input.view(-1, hdsz*2)).view(btz, seq_len)[mask].view(-1, 1)
        noclick_p = self.auxiliary_net(noclick_input.view(-1, hdsz*2)).view(btz, seq_len)[mask].view(-1, 1)
        click_target = torch.ones_like(click_p)
        noclick_target = torch.zeros_like(click_p)

        loss = F.binary_cross_entropy(torch.cat([click_p, noclick_p], dim=0), torch.cat([click_target, noclick_target], dim=0))
        return loss