In [None]:
# 多兴趣建模：大部分模型将⽤户的兴趣表达为⼀个 user embedding，然而单个embedding来表达⽤户的多种兴趣是很困难的。尤其是⻓期⾏为序列推荐的场景。
# 可用于多兴趣召回。
# MIND、ComirecSA、SINE、

In [1]:
import copy
# 加载数据集2成序列数据集，评分[0,1,2]为负反馈，评分[3,4,5]为正反馈，只保留正样本，构造简单序列推荐数据集
# 数据集：ml-100k

import os, random
import numpy as np
import pandas as pd
random.seed(100)

# 加载数据: >=3分为正，用户评分次数不低于50，只保留最后50个，拆分为40: 5 + 15负例 (随机采样): 5 + 15负例 (随机采样)
ratings = np.array([[int(x) for x in line.strip().split('\t')[:4]] for line in open('./data/ml-100k/ua.base','r').read().strip().split('\n')], dtype=np.int32)
ratings_pd = pd.DataFrame({feature_name: list(feature_data) for feature_name, feature_data in zip(['user_id','item_id','rating','timestamp'], ratings.T)})
pos_ratings_pd = ratings_pd[ratings_pd['rating']>2.9][['user_id','item_id','timestamp']].dropna().sort_values('timestamp') # 已经排序了
pos_ratings_pd = pos_ratings_pd.groupby('user_id').filter(lambda x: x['user_id'].count()>=50)
userid2id = {user_id: i for i, user_id in enumerate(sorted(list(set(pos_ratings_pd['user_id'].tolist()))))}
itemid2id = {item_id: i for i, item_id in enumerate(sorted(list(set(pos_ratings_pd['item_id'].tolist()))))}
print(len(userid2id), len(itemid2id))
del ratings, ratings_pd

# new id
user_train_validate_test = {}
for user,item,t in pos_ratings_pd.values:
    u, i = userid2id[user], itemid2id[item]
    if u not in user_train_validate_test:
        user_train_validate_test[u] = [i]
    else:
        user_train_validate_test[u].append(i)
    user_train_validate_test[u] = user_train_validate_test[u][-50:]
train_seq_len = 40
pos_num = 5
neg_sample_num = 15
def sample(low, high, notinset, num):
    nums = set([])
    n = num
    while n>0:
        id = random.randint(low, high)
        if id not in notinset and id not in nums:
            nums.add(id)
            n -= 1
    return list(nums)
data = np.zeros((len(user_train_validate_test), 81), dtype=np.int32)
i = 0
for user, train_validate_test in user_train_validate_test.items():
    train, validate, test = train_validate_test[:train_seq_len], train_validate_test[-pos_num*2:-pos_num], train_validate_test[-pos_num:]
    data[i, 0] = user
    data[i,1:train_seq_len+1] = np.array(train)
    samples = sample(0, len(itemid2id)-1, set(train_validate_test), neg_sample_num * 2)
    data[i,1+train_seq_len : 1+train_seq_len+pos_num+neg_sample_num] = np.array(validate + samples[:neg_sample_num])
    data[i,1+train_seq_len+pos_num+neg_sample_num : ] = np.array(test + samples[neg_sample_num:])
    i += 1
del user_train_validate_test
print(data.shape)
print(data[:2,:])

# 继续加载info特征信息，内容特征
occupation_dict = {'administrator':0, 'artist':1, 'doctor':2, 'educator':3, 'engineer':4, 'entertainment':5, 'executive':6, 'healthcare':7, 'homemaker':8, 'lawyer':9, 'librarian':10, 'marketing':11, 'none':12, 'other':13, 'programmer':14, 'retired':15, 'salesman':16, 'scientist':17, 'student':18, 'technician':19, 'writer':20}
gender_dict={'M':1,'F':0}
user_info = {}
for line in open('./data/ml-100k/u.user','r', encoding='utf-8').read().strip().split('\n'):
    phs = line.strip().split('|')
    if int(phs[0]) not in userid2id:
        continue
    uid = userid2id[int(phs[0])]
    user_info[uid] = [gender_dict[phs[2]], occupation_dict[phs[3]]] # int(phs[1]) 为了方便，不要连续型特征
user_num_features = 2
item_info = {}
for line in open('./data/ml-100k/u.item','r', encoding='ISO-8859-1').read().strip().split('\n'):
    phs = line.strip().split('|')
    if int(phs[0]) not in itemid2id:
        continue
    iid = itemid2id[int(phs[0])]
    item_info[iid] = phs[5:]
item_num_features = 19
num_users = len(user_info)
num_items = len(item_info)
num_features = 21



446 1548
(446, 81)
[[ 397  303  260  306  312  744  257  285  338  270  682  862  328 1543
   344  881  326  298  867  265  673 1491  301  337  353  261  300 1260
  1238  302  325  334  331  351  347 1090  683  901  897  272  345   49
   585  309  521  126  386 1282 1038  539  288  417  418  931 1444  804
   164  933  941 1326  686 1001  316  324 1128  900 1091 1349  710  716
   470 1499  225   98 1508  357  365  757  887  248  633]
 [ 344  773  365  397  445   48  374   62  431  981  231  780  384  109
    39  775 1021  929 1022  393   93  399   89  386  569  715   66 1059
   748  414  459  418  139  832  495  831 1413  413  398  396   77  784
   717  786   50  142  768  259 1159  394   11  917  793 1306 1307  810
   302  944 1328  564   53  293  141  212  778  645  825  839  328  332
  1234 1363  342 1247  485  107 1134 1137  369 1277  254]]


In [19]:
# MIND： Multi-interest network with dynamic routing for recommendation at Tmall
# 采⽤了hinton提出的胶囊⽹络(动态路由)作为多兴趣提取层（指定胶囊个数，可以类似于K-Means聚类来理解）
# 我这里用了user_profile和item_profile，基于item_profile的行为序列，采用胶囊网络来建模行为序列的多兴趣偏好。
# 胶囊网络迭代过程中增大top-1兴趣节点的激活，降低其他兴趣节点的激活，本质类似聚类switch expert。
# 这种是hard选择路由。
# 多头注意力则是soft路由。
# 数据集：ml-100k

import torch
from torch import nn
from torch.nn import Module, CrossEntropyLoss, Sequential, Linear, Sigmoid
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
dim=50

user_feature_vals = {}
for i in range(user_num_features):
    user_feature_vals[i] = sorted(list(set([val[i] for val in user_info.values()])))
    for user, info in user_info.items():
        user_info[user][i] = user_feature_vals[i].index(info[i])
item_feature_vals = {}
for i in range(item_num_features):
    item_feature_vals[i] = sorted(list(set([val[i] for val in item_info.values()])))
    for item, info in item_info.items():
        item_info[item][i] = item_feature_vals[i].index(info[i])

user_profile_data = np.array([user_info[u] for u in data[:,0]]) # [data_len, ufeature]
item_seq_profile_data = np.array([[item_info[item] for item in item_seq] for item_seq in data[:,1:]]) # [data_len, seq_len, ufeature]

train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(user_profile_data).long(), 
                                                torch.from_numpy(item_seq_profile_data[:,:train_seq_len,:]).long(),
                                                torch.from_numpy(item_seq_profile_data[:,train_seq_len:(train_seq_len + pos_num + neg_sample_num),:]).long()
                                                ), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(user_profile_data).long(), 
                                                torch.from_numpy(item_seq_profile_data[:,:train_seq_len,:]).long(), # 这里图简便，懒得改seq_len了
                                                torch.from_numpy(item_seq_profile_data[:,-(pos_num + neg_sample_num):,:]).long()
                                               ), batch_size=batch_size, shuffle=False, pin_memory=True)

# 胶囊网络结构
class CapsuleNet(nn.Module):
    def __init__(self, hidden_dim: int, seq_len: int, num_interests: int = 4, routing_times: int = 3):
        super(CapsuleNet, self).__init__()
        self.hidden_dim, self.seq_len = hidden_dim, seq_len
        self.num_interests, self.routing_times = num_interests, routing_times
        self.to_interest_linear = nn.Linear(hidden_dim, hidden_dim * num_interests, bias=False)
        self.relu_linear = nn.Sequential(nn.Linear(hidden_dim, hidden_dim, bias=False), nn.ReLU())
    # [batch_len, seq_len, profile_embedding], [batch_len, seq_len]
    def forward(self, history_item_embeddings, mask):
        batch_len = history_item_embeddings.shape[0]
        # 计算u, [batch_len, interest_num, seq_len, hidden_dim]
        interest_item_embeddings = self.to_interest_linear(history_item_embeddings)
        interest_item_embeddings = interest_item_embeddings.reshape((batch_len, self.seq_len, self.num_interests, self.hidden_dim))
        interest_item_embeddings = interest_item_embeddings.permute((0,2,1,3))
        # 随机初始化胶囊权重b
        capsule_weight = torch.randn((batch_len, self.num_interests, self.seq_len), device=device, requires_grad=False)
        # 动态路由传播3次
        for i in range(self.routing_times):
            # mask，最后shape=[b, in, 1, s]
            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.num_interests, 1)
            paddings = torch.zeros_like(atten_mask, dtype=torch.float)
            # 计算c
            capsule_softmax_weight = torch.softmax(capsule_weight, dim=-1)
            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)  # mask位置填充0
            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
            if i <= 1: # 前两次
                # 计算s
                # [batch_len, num_interests, 1, seq_len]  [batch_len, num_interests, seq_len, hidden_dim] -> [batch_len, num_interests, 1, hidden_dim]
                interest_capsule = torch.matmul(capsule_softmax_weight, interest_item_embeddings)
                # 计算v = squash(s)
                cap_norm = torch.sum(interest_capsule.square(), -1, keepdim=True)  # [batch_len, num_interests, 1, 1]
                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule  # [batch_len, num_interests, 1, hidden_dim]
                # 计算b = b + u * v
                # [batch_len, interest_num, seq_len, hidden_dim], [batch_len, num_interests, hidden_dim, 1] -> [batch_len, interest_num, seq_len, 1]
                delta_weight = torch.matmul(interest_item_embeddings, interest_capsule.transpose(2, 3).contiguous())
                delta_weight = delta_weight.squeeze()
                # 更新时候，正益正，负益负（向量点积作用），迭代多次。
                capsule_weight = capsule_weight + delta_weight
            else:
                interest_capsule = torch.matmul(capsule_softmax_weight, interest_item_embeddings)
                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule
        interest_capsule = self.relu_linear(interest_capsule.squeeze()) # [batch_len, num_interests, hidden_dim]
        return interest_capsule
    def parameters(self, recurse: bool = True):
        return [para for para in self.to_interest_linear.parameters()] + [para for para in self.relu_linear.parameters()]
class MIND(nn.Module):
    def __init__(self, hidden_dim: int, seq_len: int, num_interests: int, routing_times: int, user_profile_feature: [tuple], item_profile_feature: [tuple], profile_feature_embedding_dim: int, dnn_layer_dims: list[int]):
        super(MIND, self).__init__()
        self.dnn_layer_dims, self.hidden_dim = dnn_layer_dims, hidden_dim
        self.num_interests = num_interests
        # 内容特征
        self.user_profile_feature, self.item_profile_feature, self.profile_feature_embedding_dim = user_profile_feature, item_profile_feature, profile_feature_embedding_dim
        self.user_profile_embed = nn.ModuleDict({'user_embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=profile_feature_embedding_dim) for i, valcount in user_profile_feature})
        self.item_profile_embed = nn.ModuleDict({'item_embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=profile_feature_embedding_dim) for i, valcount in item_profile_feature})
        self.user_profile_all_embed_dim = profile_feature_embedding_dim * len(user_profile_feature)
        self.item_profile_all_embed_dim = profile_feature_embedding_dim * len(item_profile_feature)
        # item pooling layer
        self.pooling_layer = nn.Sequential(nn.Linear(self.item_profile_all_embed_dim, hidden_dim), nn.ReLU())
        # capsule_net
        self.capsule_net = CapsuleNet(hidden_dim, seq_len, num_interests, routing_times)
        # final dnn
        self.all_embedding_dim = self.user_profile_all_embed_dim + hidden_dim
        self.final_dnn_network = nn.Sequential(nn.Linear(self.all_embedding_dim, dnn_layer_dims[0]), nn.ReLU())
        if len(dnn_layer_dims) > 1:
            for i, layer_dim in enumerate(dnn_layer_dims[1:]):
                self.final_dnn_network.append(nn.Linear(dnn_layer_dims[i], layer_dim))
                self.final_dnn_network.append(nn.ReLU())
        self.final_dnn_network.append(nn.Linear(dnn_layer_dims[-1], hidden_dim))
        self.final_dnn_network.append(nn.ReLU())
    def forward(self, user_profiles, item_history_list_profile, item_future_list_profile):
        batch_len = user_profiles.shape[0]
        # user profile: [batch, feature * embed_dim]
        user_profile_embeddings = torch.cat([self.user_profile_embed['user_embed_' + str(i)](user_profiles[:,i].long()) for i in range(user_profiles.shape[-1])], axis=-1)
        user_profile_embeddings = user_profile_embeddings.reshape((batch_len, len(self.user_profile_feature) * self.profile_feature_embedding_dim)) # [batch, feature, embed_dim]
        # item_history_list_profile: torch.Tensor([batch, seq_len, feature * embed_dim])
        seq_len = item_history_list_profile.shape[1]
        item_history_list_profile_embeddings = torch.cat([self.item_profile_embed['item_embed_' + str(i)](item_history_list_profile[:,:,i].long()) for i in range(item_history_list_profile.shape[-1])], axis=-1)
        item_history_list_profile_embeddings = item_history_list_profile_embeddings.reshape((batch_len, seq_len, len(self.item_profile_feature) * self.profile_feature_embedding_dim)) # [batch, seq_len, feature * embed_dim]
        # 以上处理user profile和行为历史，下面进行与candidate组合预测， item_future_list 和 item_future_list_profile
        seq_len_ = item_future_list_profile.shape[1]
        item_future_list_profile_embeddings = torch.cat([self.item_profile_embed['item_embed_' + str(i)](item_future_list_profile[:,:,i].long()) for i in range(item_future_list_profile.shape[-1])], axis=-1)
        item_future_list_profile_embeddings = item_future_list_profile_embeddings.reshape((batch_len, seq_len_, len(self.item_profile_feature) * self.profile_feature_embedding_dim)) # [batch, seq_len_, feature * embed_dim]
        # [batch_len, test_len, hidden_dim]
        # [batch, seq_len, embed_dim]
        item_history_pool = self.pooling_layer(item_history_list_profile_embeddings)
        mask = torch.ones((batch_len, seq_len)) # 目前我的数据集整齐没有mask
        multi_interest_capsule = self.capsule_net(item_history_pool, mask) # [batch_len, num_interests, hidden_dim]
        user_multi_interest_cat = torch.cat([user_profile_embeddings.unsqueeze(1).repeat((1,self.num_interests,1)), multi_interest_capsule], dim=-1)
        user_history_multi_interest_embed = self.final_dnn_network(user_multi_interest_cat) # [batch_len, num_interests, hidden_dim]
        # future prediction
        item_future_pool = self.pooling_layer(item_future_list_profile_embeddings)# [batch, seq_len_, hidden_dim]
        item_future_multi_interest_scores = torch.sigmoid(torch.bmm(item_future_pool, user_history_multi_interest_embed.permute((0,2,1)))) # [batch_len, seq_len_, num_interests]
        # find the best capsule
        best_interest_index = torch.argmax(item_future_multi_interest_scores, dim=-1) # [batch_len, seq_len_]
        # print(best_interest_index)
        best_item_future_multi_interest_score  = item_future_multi_interest_scores.take(best_interest_index).squeeze()
        return best_item_future_multi_interest_score
    def parameters(self, recurse: bool = True):
        return [para for para in self.user_profile_embed.parameters()] + [para for para in self.item_profile_embed.parameters()] + [para for para in self.pooling_layer.parameters()] + [para for para in self.capsule_net.parameters()] + [para for para in self.final_dnn_network.parameters()] 
model = MIND(hidden_dim=dim, seq_len=train_seq_len, num_interests=4, routing_times=3, user_profile_feature = [(i,len(list_)) for i, list_ in user_feature_vals.items()], item_profile_feature= [(i,len(list_)) for i, list_ in item_feature_vals.items()], profile_feature_embedding_dim=dim, dnn_layer_dims=[16])

optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=0.0003)
criterion = CrossEntropyLoss(reduction='sum').to(device)
label = torch.FloatTensor([1 for i in range(pos_num)] + [0 for i in range(neg_sample_num)]).to(device)

def DCG(batch_labels):
    dcgsum = np.zeros((batch_labels.shape[0]))
    for i in range(batch_labels.shape[-1]):
        dcg = (2 ** batch_labels[:,i] - 1) / np.math.log(i + 2, 2)
        dcgsum += dcg
    return dcgsum
def NDCG(output, labels):
    # ideal_dcg
    ideal_dcg = DCG(labels)
    # this
    dcg = DCG((np.argsort( - output, axis=-1)<pos_num).astype(np.float32))
    return np.sum(dcg/ideal_dcg)

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        user_profiles, item_history_list_profile, item_future_list_profile = inputs
        batch_len = user_profiles.shape[0]
        # print(item_history_list_profile.shape, item_future_list_profile.shape)
        user_profiles = user_profiles.to(device)
        item_history_list_profile = item_history_list_profile.to(device)
        item_future_list_profile = item_future_list_profile.to(device)
        output = model(user_profiles, item_history_list_profile, item_future_list_profile)
        labels = label.unsqueeze(0).repeat([batch_len,1])
        loss = criterion(output, labels)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([batch_len, loss.item(), NDCG(output.cpu().detach().numpy(), labels.cpu().detach().numpy())])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        user_profiles, item_history_list_profile, item_future_list_profile = inputs
        batch_len = user_profiles.shape[0]
        user_profiles = user_profiles.to(device)
        item_history_list_profile = item_history_list_profile.to(device)
        item_future_list_profile = item_future_list_profile.to(device)
        output = model(user_profiles, item_history_list_profile, item_future_list_profile)
        labels = label.unsqueeze(0).repeat([batch_len,1])
        loss = criterion(output, labels)
        epoch_test_losses.append([batch_len, loss.item(), NDCG(output.cpu().detach().numpy(), labels.cpu().detach().numpy())])
    train_loss = sum([x[1] for x in epoch_train_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_train_losses])
    test_loss  = sum([x[1] for x in epoch_test_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_test_losses])
    train_ndcg = sum([x[2] for x in epoch_train_losses])/sum([x[0] for x in epoch_train_losses])
    test_ndcg  = sum([x[2] for x in epoch_test_losses])/sum([x[0] for x in epoch_test_losses])
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_ce_loss: {:.4f}, train_ndcg: {:.4f}, validate_ce_loss: {:.4f}, validate_ndcg: {:.4f}'.format(epoch+1, num_epochs,  train_loss, train_ndcg, test_loss, test_ndcg))

# hard路由方式其实效果不佳，目前采用较为简单的网络结构。用的是item profile 没有去学item id的embedding。实际上，item id的embedding很有用。尊重原文。

[2023-09-05 11:39:33] epoch=[1/10], train_ce_loss: 0.7489, train_ndcg: 0.6662, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:39:38] epoch=[2/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:39:43] epoch=[3/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:39:48] epoch=[4/10], train_ce_loss: 0.7489, train_ndcg: 0.6640, validate_ce_loss: 0.7489, validate_ndcg: 0.6677
[2023-09-05 11:39:53] epoch=[5/10], train_ce_loss: 0.7489, train_ndcg: 0.6688, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:39:58] epoch=[6/10], train_ce_loss: 0.7489, train_ndcg: 0.6682, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:40:04] epoch=[7/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 11:40:09] epoch=[8/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, valida

In [27]:
# ComirecSA：
# Comirec：Controllable Multi-Interest Framework for Recommendation， KDD 2020
# 改进了MIND中的动态路由算法，采用注意力机制，并选择交互得分最大的兴趣点。
# hard的路由模式。
# 数据集：ml-100k

import torch
from torch import nn
from torch.nn import Module, CrossEntropyLoss, Sequential, Linear, Sigmoid
from torch.utils.data import Dataset, DataLoader, TensorDataset 
from sklearn.model_selection import train_test_split
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps:0' if torch.backends.mps.is_available() else "cpu"))
batch_size = 100
num_epochs = 10
dim=50

user_feature_vals = {}
for i in range(user_num_features):
    user_feature_vals[i] = sorted(list(set([val[i] for val in user_info.values()])))
    for user, info in user_info.items():
        user_info[user][i] = user_feature_vals[i].index(info[i])
item_feature_vals = {}
for i in range(item_num_features):
    item_feature_vals[i] = sorted(list(set([val[i] for val in item_info.values()])))
    for item, info in item_info.items():
        item_info[item][i] = item_feature_vals[i].index(info[i])

user_profile_data = np.array([user_info[u] for u in data[:,0]]) # [data_len, ufeature]
item_seq_profile_data = np.array([[item_info[item] for item in item_seq] for item_seq in data[:,1:]]) # [data_len, seq_len, ufeature]

train_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(user_profile_data).long(), 
                                                torch.from_numpy(item_seq_profile_data[:,:train_seq_len,:]).long(),
                                                torch.from_numpy(item_seq_profile_data[:,train_seq_len:(train_seq_len + pos_num + neg_sample_num),:]).long()
                                                ), batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=TensorDataset(torch.from_numpy(user_profile_data).long(), 
                                                torch.from_numpy(item_seq_profile_data[:,:train_seq_len,:]).long(), # 这里图简便，懒得改seq_len了
                                                torch.from_numpy(item_seq_profile_data[:,-(pos_num + neg_sample_num):,:]).long()
                                               ), batch_size=batch_size, shuffle=False, pin_memory=True)

# 输入序列嵌入，得到多兴趣嵌入
class MultiInterestSelfAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_interests: int):
        super(MultiInterestSelfAttention, self).__init__()
        self.hidden_dim, self.num_interests = hidden_dim, num_interests
        # Create trainable parameters
        self.W1 = nn.Parameter(torch.rand(hidden_dim, hidden_dim * 4), requires_grad=True)
        self.W2 = nn.Parameter(torch.rand(hidden_dim * 4, num_interests), requires_grad=True)
    def forward(self, item_seq_embeds: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        H = torch.einsum('bse, ed -> bsd', item_seq_embeds, self.W1).tanh() # [batch_len, seq_len, hidden_dim * 4]
        attention = torch.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.unsqueeze(-1).float()), dim=1) # [batch_len, seq_len, num_interests]
        multi_interest_emb = torch.matmul(attention.permute(0, 2, 1), item_seq_embeds) # [batch_len, num_interests, hidden_dim]
        return multi_interest_emb
class ComirecSA(nn.Module):
    def __init__(self, hidden_dim, seq_len, num_interests, user_profile_feature: [tuple], item_profile_feature: [tuple], profile_feature_embedding_dim: int, dnn_layer_dims: list[int]):
        super(ComirecSA, self).__init__()
        self.hidden_dim, self.seq_len, self.num_interests = hidden_dim, seq_len, num_interests
        # 内容特征
        self.user_profile_feature, self.item_profile_feature, self.profile_feature_embedding_dim = user_profile_feature, item_profile_feature, profile_feature_embedding_dim
        self.user_profile_embed = nn.ModuleDict({'user_embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=profile_feature_embedding_dim) for i, valcount in user_profile_feature})
        self.item_profile_embed = nn.ModuleDict({'item_embed_' + str(i): nn.Embedding(num_embeddings=valcount, embedding_dim=profile_feature_embedding_dim) for i, valcount in item_profile_feature})
        self.user_profile_all_embed_dim = profile_feature_embedding_dim * len(user_profile_feature)
        self.item_profile_all_embed_dim = profile_feature_embedding_dim * len(item_profile_feature)
        # user/item pooling layer
        self.user_pooling_layer = nn.Sequential(nn.Linear(self.user_profile_all_embed_dim, hidden_dim), nn.ReLU())
        self.item_pooling_layer = nn.Sequential(nn.Linear(self.item_profile_all_embed_dim, hidden_dim), nn.ReLU())
        # MultiInterestSelfAttention
        self.multi_interest_sa = MultiInterestSelfAttention(hidden_dim=hidden_dim, num_interests=num_interests)
    def forward(self, user_profiles, item_history_list_profile, item_future_list_profile):
        batch_len = user_profiles.shape[0]
        # user profile: [batch, feature * embed_dim]
        user_profile_embeddings = torch.cat([self.user_profile_embed['user_embed_' + str(i)](user_profiles[:,i].long()) for i in range(user_profiles.shape[-1])], axis=-1)
        user_profile_embeddings = self.user_pooling_layer(user_profile_embeddings.reshape((batch_len, len(self.user_profile_feature) * self.profile_feature_embedding_dim))) # [batch, embed_dim]
        seq_len = item_history_list_profile.shape[1]
        item_history_list_profile_embeddings = torch.cat([self.item_profile_embed['item_embed_' + str(i)](item_history_list_profile[:,:,i].long()) for i in range(item_history_list_profile.shape[-1])], axis=-1)
        item_history_list_profile_embeddings = self.item_pooling_layer(item_history_list_profile_embeddings.reshape((batch_len, seq_len, len(self.item_profile_feature) * self.profile_feature_embedding_dim))) # [batch, seq_len, embed_dim]
        # 以上处理user profile和行为历史，下面进行与candidate组合预测， item_future_list 和 item_future_list_profile
        seq_len_ = item_future_list_profile.shape[1]
        item_future_list_profile_embeddings = torch.cat([self.item_profile_embed['item_embed_' + str(i)](item_future_list_profile[:,:,i].long()) for i in range(item_future_list_profile.shape[-1])], axis=-1)
        item_future_list_profile_embeddings = self.item_pooling_layer(item_future_list_profile_embeddings.reshape((batch_len, seq_len_, len(self.item_profile_feature) * self.profile_feature_embedding_dim))) # [batch, seq_len_, embed_dim]
        mask = torch.ones((batch_len, seq_len)) # 目前我的数据集整齐没有mask
        
        # [batch_len, seq_len, embed_dim] + [batch, seq_len, embed_dim] ->  [batch_len, num_interests, hidden_dim]
        multi_interest_embeds = self.multi_interest_sa(user_profile_embeddings.unsqueeze(1) * item_history_list_profile_embeddings, mask)
        # [batch, seq_len_, hidden_dim], [batch_len, num_interests, hidden_dim] -> [batch, seq_len_, num_interests]
        future_interest_scores = torch.sigmoid(torch.bmm(item_future_list_profile_embeddings, multi_interest_embeds.permute((0,2,1))))
        best_future_index = torch.argmax(future_interest_scores, dim=-1)
        best_item_future_multi_interest_score  = future_interest_scores.take(best_future_index).squeeze()
        return best_item_future_multi_interest_score
model = ComirecSA(hidden_dim=dim, seq_len=train_seq_len, num_interests=4, user_profile_feature = [(i,len(list_)) for i, list_ in user_feature_vals.items()], item_profile_feature= [(i,len(list_)) for i, list_ in item_feature_vals.items()], profile_feature_embedding_dim=dim, dnn_layer_dims=[16])

optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=0.0003)
criterion = CrossEntropyLoss(reduction='sum').to(device)
label = torch.FloatTensor([1 for i in range(pos_num)] + [0 for i in range(neg_sample_num)]).to(device)

def DCG(batch_labels):
    dcgsum = np.zeros((batch_labels.shape[0]))
    for i in range(batch_labels.shape[-1]):
        dcg = (2 ** batch_labels[:,i] - 1) / np.math.log(i + 2, 2)
        dcgsum += dcg
    return dcgsum
def NDCG(output, labels):
    # ideal_dcg
    ideal_dcg = DCG(labels)
    # this
    dcg = DCG((np.argsort( - output, axis=-1)<pos_num).astype(np.float32))
    return np.sum(dcg/ideal_dcg)

for epoch in range(num_epochs):
    # train:
    epoch_train_losses = []
    model.train()
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        user_profiles, item_history_list_profile, item_future_list_profile = inputs
        batch_len = user_profiles.shape[0]
        # print(item_history_list_profile.shape, item_future_list_profile.shape)
        user_profiles = user_profiles.to(device)
        item_history_list_profile = item_history_list_profile.to(device)
        item_future_list_profile = item_future_list_profile.to(device)
        output = model(user_profiles, item_history_list_profile, item_future_list_profile)
        labels = label.unsqueeze(0).repeat([batch_len,1])
        loss = criterion(output, labels)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1, norm_type=2)
        optimizer.step()
        epoch_train_losses.append([batch_len, loss.item(), NDCG(output.cpu().detach().numpy(), labels.cpu().detach().numpy())])
    # validate:
    model.eval()
    epoch_test_losses = []
    for i, inputs in enumerate(test_loader):
        user_profiles, item_history_list_profile, item_future_list_profile = inputs
        batch_len = user_profiles.shape[0]
        user_profiles = user_profiles.to(device)
        item_history_list_profile = item_history_list_profile.to(device)
        item_future_list_profile = item_future_list_profile.to(device)
        output = model(user_profiles, item_history_list_profile, item_future_list_profile)
        labels = label.unsqueeze(0).repeat([batch_len,1])
        loss = criterion(output, labels)
        epoch_test_losses.append([batch_len, loss.item(), NDCG(output.cpu().detach().numpy(), labels.cpu().detach().numpy())])
    train_loss = sum([x[1] for x in epoch_train_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_train_losses])
    test_loss  = sum([x[1] for x in epoch_test_losses])/sum([x[0] * (pos_num + neg_sample_num) for x in epoch_test_losses])
    train_ndcg = sum([x[2] for x in epoch_train_losses])/sum([x[0] for x in epoch_train_losses])
    test_ndcg  = sum([x[2] for x in epoch_test_losses])/sum([x[0] for x in epoch_test_losses])
    # print
    print('['+datetime.now().strftime("%Y-%m-%d %H:%M:%S")+']', 'epoch=[{}/{}], train_ce_loss: {:.4f}, train_ndcg: {:.4f}, validate_ce_loss: {:.4f}, validate_ndcg: {:.4f}'.format(epoch+1, num_epochs,  train_loss, train_ndcg, test_loss, test_ndcg))


# hard路由方式其实效果不佳，目前采用较为简单的网络结构。用的是item profile 没有去学item id的embedding。实际上，item id的embedding很有用。

[2023-09-05 14:25:55] epoch=[1/10], train_ce_loss: 0.7489, train_ndcg: 0.6628, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:00] epoch=[2/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:05] epoch=[3/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:09] epoch=[4/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:14] epoch=[5/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:18] epoch=[6/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, validate_ndcg: 0.6676
[2023-09-05 14:26:23] epoch=[7/10], train_ce_loss: 0.7489, train_ndcg: 0.6560, validate_ce_loss: 0.7489, validate_ndcg: 0.6681
[2023-09-05 14:26:28] epoch=[8/10], train_ce_loss: 0.7489, train_ndcg: 0.6681, validate_ce_loss: 0.7489, valida