In [1]:
import numpy as np
import pandas as pd
import random

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import os

In [2]:
#全局参数，随机种子，图像尺寸
seed = 114514
np.random.seed(seed)
random.seed(seed)
BATCH_SIZE = 512

hidden_dim = 16
epochs = 10
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cpu


## 数据准备

In [3]:
df = pd.read_csv('./train_dataset.csv')
print('共{}个用户，{}本图书，{}条记录'.format(max(df['user_id'])+1, max(df['item_id'])+1, len(df)))

共53424个用户，10000本图书，5869631条记录


In [4]:
df=df[:10000]

In [5]:
import tqdm
class Goodbooks(Dataset):
    def __init__(self, df, mode='training', negs = 99):
        super().__init__()

        self.df = df
        self.mode = mode

        self.book_nums = max(df['item_id'])+1
        self.user_nums = max(df['user_id'])+1

        self._init_dataset()
    
    def _init_dataset(self):
        self.Xs = []

        #下面的两个for循环建立了每一位用户与该用户看过的书籍之间的映射关系，即 :{用户1：用户1看过的书籍, 用户2：用户2看过的书籍, ...}
        self.user_book_map = {}
        for i in range(self.user_nums):
            self.user_book_map[i] = []
        for index, row in self.df.iterrows():
            user_id, book_id = row
            self.user_book_map[user_id].append(book_id)
            #self.user_book_map={用户1：用户1看过的书籍, 用户2：用户2看过的书籍, ...}
            
            
        #对于每一个用户的交互数据，训练集使用除了最后一个item之外的所有item(书籍)，而验证集只使用最后一个item(书籍)
        #训练集样本结构：(用户id，书籍id，label) 
        #label表示是否阅读，是：1，否：0
        if self.mode == 'training':
            for user, items in tqdm.tqdm(self.user_book_map.items()):
                for item in items[:-1]:
                    #构建正样本，对应label为1
                    self.Xs.append((user, item, 1))
                    #构建负样本，对应label为0
                    #正负样本比例为1:3，模拟真实情况下，用户已经阅读过的书籍数小于书籍总数
                    for _ in range(3):
                        while True:
                            neg_sample = random.randint(0, self.book_nums-1)
                            if neg_sample not in self.user_book_map[user]:
                                self.Xs.append((user, neg_sample, 0))
                                break
        #验证集样本结构：(用户id，已阅读书籍id，未阅读书籍id)
        elif self.mode == 'validation':
            for user, items in tqdm.tqdm(self.user_book_map.items()):
                if len(items) == 0:
                    continue
                self.Xs.append((user, items[-1]))
    
    def __getitem__(self, index):
        if self.mode == 'training':
            user_id, book_id, label = self.Xs[index]
            return user_id, book_id, label
        elif self.mode == 'validation':
            user_id, book_id = self.Xs[index]
            #在所有的当前用户没有看过的书籍中随机抽取99本，之后大概会对这99本排序，对当前用户进行推荐？
            negs = list(random.sample(
                list(set(range(self.book_nums)) - set(self.user_book_map[user_id])),
                k=99
            ))
            return user_id, book_id, torch.LongTensor(negs)
    
    def __len__(self):
        return len(self.Xs)

In [6]:
#建立训练和验证dataloader
traindataset = Goodbooks(df, 'training')
validdataset = Goodbooks(df, 'validation')

trainloader = DataLoader(traindataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=0)
validloader = DataLoader(validdataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=0)

100%|██████████████████████████████████████████████████████████████████████████████| 190/190 [00:00<00:00, 2182.48it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████| 190/190 [00:00<?, ?it/s]


In [7]:
for i in traindataset:
    print(i)
    break

(0, 257, 1)


In [8]:
for i in trainloader:
    print(len(i))
    print(i[0].shape)#user id [512,1]
    print(i[1].shape)#book id(pos or neg都有) [512,1]
    print(i[2].shape)#label   [512,1]
    break

3
torch.Size([512])
torch.Size([512])
torch.Size([512])


In [9]:
for i in validdataset:
    print(i)
    break

(0, 900, tensor([6741, 4685, 6776, 3157, 3494, 9160, 9501, 4916, 1507, 9623, 1820, 4317,
        6840, 6515, 9953,  549, 3001, 2845, 3998, 8035, 4845, 4913, 4526, 4265,
        8189, 8291, 8728, 5992, 4564, 5559, 4418,  692, 6184, 7579, 9062, 1441,
        3790, 6852,  501, 8727, 8651, 3277, 9140,  831, 8127,  444, 8670, 5291,
         743, 1152, 8327, 9056, 7927, 6598, 4233, 8374,  488, 7187, 7063, 3064,
        3980, 2439, 9851, 3373, 4533, 7862, 3650, 2650, 5464, 7416, 7847, 3919,
        6250, 4627,  969,  466, 5746,  796, 1398, 8729, 3194, 8980, 6416, 7091,
         519,  537, 4869, 1064, 6744, 1384, 5467, 6204,  410, 6041, 3081, 5036,
        6912, 1499, 8367]))


In [10]:
for i in validloader:
    print(len(i))
    print(i[0].shape)#user id,[99,1]
    print(i[1].shape)#book id(pos),[99,1]
    print(i[2].shape)#book id(neg),[99,99]，负样本随机抽取了99条，具体见Goodbooks类
    break

3
torch.Size([99])
torch.Size([99])
torch.Size([99, 99])


## 模型构建

In [11]:
# 构建模型
class NCFModel(torch.nn.Module):
    def __init__(self, hidden_dim, user_num, item_num, mlp_layer_num=4, weight_decay = 1e-5, dropout=0.5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.user_num = user_num
        self.item_num = item_num
        self.mlp_layer_num = mlp_layer_num
        self.weight_decay = weight_decay
        self.dropout=dropout
        
        #MLP的Embedding层
        self.mlp_user_embedding = torch.nn.Embedding(user_num, hidden_dim * (2 ** (self.mlp_layer_num - 1)))
        self.mlp_item_embedding = torch.nn.Embedding(item_num, hidden_dim * (2 ** (self.mlp_layer_num - 1)))
        
        #GMF的Embedding层
        self.gmf_user_embedding = torch.nn.Embedding(user_num, hidden_dim)    
        self.gmf_item_embedding = torch.nn.Embedding(item_num, hidden_dim)
        

        mlp_Layers = []
        input_size = int(hidden_dim*(2 ** (self.mlp_layer_num)))
        for i in range(self.mlp_layer_num):
            mlp_Layers.append(torch.nn.Linear(int(input_size), int(input_size / 2)))
            mlp_Layers.append(torch.nn.Dropout(self.dropout))
            mlp_Layers.append(torch.nn.ReLU())
            input_size /= 2
        self.mlp_layers = torch.nn.Sequential(*mlp_Layers)
        """
        Sequential(
          (0): Linear(in_features=256, out_features=128, bias=True)
          (1): Dropout(p=0.5, inplace=False)
          (2): ReLU()
          (3): Linear(in_features=128, out_features=64, bias=True)
          (4): Dropout(p=0.5, inplace=False)
          (5): ReLU()
          (6): Linear(in_features=64, out_features=32, bias=True)
          (7): Dropout(p=0.5, inplace=False)
          (8): ReLU()
          (9): Linear(in_features=32, out_features=16, bias=True)
          (10): Dropout(p=0.5, inplace=False)
          (11): ReLU()
        )
        """

        self.output_layer = torch.nn.Linear(2*self.hidden_dim, 1)

    def forward(self, user, item):
        
        user_gmf_embedding = self.gmf_user_embedding(user)
        item_gmf_embedding = self.gmf_item_embedding(item)
        
        user_mlp_embedding = self.mlp_user_embedding(user)
        item_mlp_embedding = self.mlp_item_embedding(item)

        #GMF执行element-wise product操作
        gmf_output = user_gmf_embedding * item_gmf_embedding
        
        #MLP块通过堆叠的全连接层+激活函数
        mlp_input = torch.cat([user_mlp_embedding, item_mlp_embedding], dim=-1)
        mlp_output = self.mlp_layers(mlp_input)
        
        #将GMF和MLP的输出结果concat起来，送入最后的全连接层预测结果，并使用sigmoid函数将输出结果映射到0与1之间
        output = torch.sigmoid(self.output_layer(torch.cat([gmf_output, mlp_output], dim=-1))).squeeze(-1)
        
        return output
    
    def predict(self, user, item):
        self.eval()
        #print(user.shape,item.shape)#torch.Size([512]) torch.Size([512, 100])
        with torch.no_grad():
            user_gmf_embedding = self.gmf_user_embedding(user)
            item_gmf_embedding = self.gmf_item_embedding(item)

            user_mlp_embedding = self.mlp_user_embedding(user)
            item_mlp_embedding = self.mlp_item_embedding(item)

            gmf_output = user_gmf_embedding.unsqueeze(1) * item_gmf_embedding

            user_mlp_embedding = user_mlp_embedding.unsqueeze(1).expand(-1, item_mlp_embedding.shape[1], -1)#[512, 128]->[512,1,128]->[512,100,128]
            mlp_input = torch.cat([user_mlp_embedding, item_mlp_embedding], dim=-1)
            mlp_output = self.mlp_layers(mlp_input)

        output = torch.sigmoid(self.output_layer(torch.cat([gmf_output, mlp_output], dim=-1))).squeeze(-1)
        return output

## 模型训练&模型评估

In [12]:
model = NCFModel(hidden_dim, traindataset.user_nums, traindataset.book_nums).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
crit = torch.nn.BCELoss()

loss_for_plot = []
hits_for_plot = []

for epoch in range(epochs):
    #训练
    losses = []
    for index, data in enumerate(tqdm.tqdm(trainloader)):
        user, item, label = data
        user, item, label = user.to(device), item.to(device), label.to(device).float()
        y_ = model(user, item).squeeze()

        loss = crit(y_, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu().item())
        
    #验证
    hits = []
    for index, data in enumerate(validloader):
        user, pos, neg = data
        #print(pos.shape,neg.shape)#torch.Size([512]) torch.Size([512, 99])
        pos = pos.unsqueeze(1)#[512->[512,1]
        all_data = torch.cat([pos, neg], dim=-1)
        #print(all_data)#torch.Size([512, 100])
        output = model.predict(user.to(device), all_data.to(device)).detach().cpu()##torch.Size([512, 100])
        
        for batch in output:
            #print('batch',batch)
            pred10=(batch).argsort(descending=True)[:10]#从大到小，取前10所在下标
            
            print('pred10',pred10)
            
            #索引0是正样本，如果预测的前10中没有0，那么说明预测错了
            if 0 not in pred10:
                hits.append(0)
            else:
                hits.append(1)
                
    print('Epoch {} finished, average loss {}, hits@20 {}'.format(epoch, sum(losses)/len(losses), sum(hits)/len(hits)))
    loss_for_plot.append(sum(losses)/len(losses))
    hits_for_plot.append(sum(hits)/len(hits))

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.16it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[6813, 6870, 8355,  ..., 3921, 4500, 9685],
        [ 170, 6294, 7788,  ..., 9411, 1451, 4902],
        [2840, 8702, 4207,  ..., 1121, 3270, 1666],
        ...,
        [ 154, 1136, 8977,  ..., 5602, 7667, 2621],
        [1872, 9289, 1737,  ..., 2925, 8020, 2848],
        [9511, 4276, 3353,  ..., 2549, 7756, 2935]])
pred10 tensor([74, 26,  2, 34,  5, 62, 28, 75, 98, 96])
pred10 tensor([56, 37, 92, 42, 15, 31, 93, 83, 14, 51])
pred10 tensor([39, 83, 96, 11, 56, 87, 22, 45, 72, 97])
pred10 tensor([34, 67, 71, 38, 26, 28, 39, 45, 41, 74])
pred10 tensor([13, 20, 48, 36, 80, 53, 65, 97, 41, 22])
pred10 tensor([47,  2, 41, 31, 74, 26, 49, 18, 51, 65])
pred10 tensor([62, 57, 72, 83, 34, 85, 27,  7, 74, 76])
pred10 tensor([79, 85, 19, 63, 73, 41,  3, 18, 74, 86])
pred10 tensor([ 7, 31, 75, 63,  6, 32,  1, 52,  5, 89])
pred10 tensor([95, 28, 22, 50, 45, 76, 66, 37,  5, 77])
pred10 tensor([66, 71, 78, 64, 21, 68, 74, 46, 67, 23])
pred10 tensor([86, 35, 90, 71, 34, 99, 26, 23, 38, 36])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.86it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[1075, 1027, 5295,  ..., 1367, 7672, 9779],
        [2746, 5305, 5662,  ..., 5979, 5063, 5695],
        [ 704,  942,   33,  ..., 9091, 5837, 9216],
        ...,
        [5493, 1227, 7454,  ..., 4598, 1565, 2108],
        [6798, 3145, 9804,  ..., 2063, 7461,  397],
        [1450, 8128, 1182,  ..., 9559, 7111, 7104]])
pred10 tensor([90, 45, 44, 55, 19, 37, 38, 65, 39, 83])
pred10 tensor([33,  9, 30, 60, 59, 70, 23, 35, 52, 56])
pred10 tensor([89,  8, 81, 86, 80,  1,  4, 42, 24, 45])
pred10 tensor([51, 16, 53, 77, 13, 85, 30, 76, 81, 50])
pred10 tensor([79, 71,  0, 46, 95, 50, 53, 92,  7, 32])
pred10 tensor([28, 49, 35, 92, 58, 11, 51, 23, 99,  5])
pred10 tensor([38, 21, 32, 12, 47,  4, 59, 79, 81, 65])
pred10 tensor([37, 67, 56, 64, 52, 92, 70, 63, 12, 84])
pred10 tensor([42, 30, 84, 43, 44, 39, 86, 36, 85, 18])
pred10 tensor([99, 19, 82, 88,  0,  8, 49, 22, 28, 76])
pred10 tensor([85,  9, 60, 89, 34, 68, 82,  1, 47,  2])
pred10 tensor([ 9, 82, 74, 79, 50, 29, 14, 86, 22, 32])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.36it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 215, 5337, 3562,  ..., 6916, 8780, 3326],
        [1872, 5062, 3010,  ..., 1252, 3094, 1349],
        [5953, 5973, 7829,  ..., 1156, 8473, 5507],
        ...,
        [2176, 1882, 5992,  ..., 6454,  763, 6488],
        [2235, 9533, 1897,  ..., 5198, 3288, 6676],
        [4530,  757, 8802,  ..., 6260, 7062,  491]])
pred10 tensor([14, 37, 11, 84, 66, 63, 55, 32,  3,  8])
pred10 tensor([73, 44, 66, 74, 58, 13, 84, 46, 60, 97])
pred10 tensor([39, 50, 57, 93, 28, 24, 12, 38, 48, 84])
pred10 tensor([ 0, 14, 77, 54, 15, 34, 86, 87,  7, 78])
pred10 tensor([61, 35,  4, 24,  8, 99, 81, 56, 47, 16])
pred10 tensor([ 5, 21, 89, 68, 10, 46, 65, 76, 16, 41])
pred10 tensor([42, 75,  0, 11, 82,  2, 19, 94, 25, 28])
pred10 tensor([32, 74, 88, 14, 35, 20, 36, 41,  0, 55])
pred10 tensor([43, 63, 26, 49, 85, 32, 36, 23, 25,  4])
pred10 tensor([59, 52, 45, 74,  9, 30,  0, 51, 24, 67])
pred10 tensor([ 6, 43,  0, 90, 32, 77, 85, 96, 63, 33])
pred10 tensor([32, 77,  1, 22, 28, 52, 61,  9, 69, 70])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.38it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[  93, 1624, 8021,  ..., 9011, 8409, 3176],
        [ 208,  921, 3084,  ..., 3134, 1655, 7146],
        [ 123, 7034,  673,  ..., 3107, 2725, 3621],
        ...,
        [1136,  813, 4593,  ..., 6167, 7225,  602],
        [5458, 7533, 8229,  ..., 5544, 8284, 9242],
        [7484, 3979, 1469,  ..., 5551, 6887, 8814]])
pred10 tensor([ 0, 31,  2, 33, 39, 72, 19,  9, 55, 11])
pred10 tensor([ 9, 75, 93, 98, 18,  0, 60,  7, 45, 92])
pred10 tensor([19, 11, 16, 25,  0, 69, 70, 62, 36, 43])
pred10 tensor([ 0, 88,  1, 69, 81, 83, 71, 75, 98, 80])
pred10 tensor([ 2, 36, 37, 79, 24, 23, 60, 89, 18, 77])
pred10 tensor([51, 83, 53, 23, 41, 12, 89, 57, 70, 80])
pred10 tensor([54,  0, 33, 95, 31, 16, 21, 46, 34,  8])
pred10 tensor([92,  9, 54, 75, 89, 28, 82, 47, 90,  0])
pred10 tensor([ 0, 65, 33, 24, 14, 95, 63, 66, 82, 26])
pred10 tensor([98,  0, 51, 91, 23, 54, 78,  3, 72, 55])
pred10 tensor([90, 45,  0, 43, 54,  2, 39, 26, 27, 94])
pred10 tensor([68,  4, 61, 50,  8, 15, 45, 90,  2,  3])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 27.13it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[1159, 2520, 9110,  ..., 9037,  284, 2097],
        [ 489, 7091, 6228,  ...,  339, 2520, 9203],
        [3867, 3228, 4727,  ..., 7803, 4955, 1426],
        ...,
        [ 653, 9257, 4103,  ..., 4200, 7321,  412],
        [2840, 5048, 6670,  ..., 1515, 1083, 1544],
        [ 166, 4915, 6004,  ..., 3990, 3228, 1399]])
pred10 tensor([46, 73, 28, 15, 62, 44, 84, 79,  6, 98])
pred10 tensor([ 4, 13,  0, 99, 31, 40, 34, 59, 10, 75])
pred10 tensor([78, 56, 79, 35, 68, 50, 11, 51, 99,  0])
pred10 tensor([63, 37,  7, 58, 22, 76, 64, 74, 12, 36])
pred10 tensor([ 6, 45, 59, 99, 80, 31, 77, 44, 42, 91])
pred10 tensor([82, 66, 98, 10,  0,  9, 71, 23,  1, 93])
pred10 tensor([ 6,  1, 58, 97,  0, 30, 36, 65,  4, 32])
pred10 tensor([78, 55,  9, 53, 81, 18, 42, 52, 19, 30])
pred10 tensor([51, 32, 25, 38, 39, 88, 34, 96,  8, 42])
pred10 tensor([15, 26, 41, 99, 38, 67,  7, 19, 72, 28])
pred10 tensor([27,  0, 51,  1, 73, 97, 96, 80, 81, 24])
pred10 tensor([43, 41,  7, 81, 19, 68, 58, 53, 28, 99])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.70it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[9511, 9823, 3185,  ..., 5903, 1032,  180],
        [2176, 7621, 7567,  ..., 8961, 9382, 2479],
        [ 154, 3614, 6998,  ..., 3940, 4415, 7413],
        ...,
        [3645, 2523, 7337,  ..., 7723, 7769, 2404],
        [1450,  595, 1632,  ...,  181, 2143, 8189],
        [ 170, 9422,  388,  ..., 9684, 1887, 4311]])
pred10 tensor([ 2, 52, 53, 96, 99, 88, 17, 87, 78, 71])
pred10 tensor([77, 36, 89, 78, 62,  4, 98, 32, 84, 63])
pred10 tensor([95,  0, 79, 59, 17, 84, 81, 61,  6, 62])
pred10 tensor([99, 82, 56, 34, 39, 88, 32, 40, 94, 41])
pred10 tensor([92, 20, 23, 47,  4, 86, 22, 10,  7, 57])
pred10 tensor([56, 59, 79, 37, 67, 95, 42, 68, 54, 64])
pred10 tensor([33, 30,  7, 71, 72, 82,  1, 12, 76,  9])
pred10 tensor([81, 46, 76, 52, 51, 38, 98, 15, 84, 31])
pred10 tensor([75, 19, 77,  8,  6, 27, 57, 37, 15, 12])
pred10 tensor([ 8,  1, 62, 36, 67, 71,  0, 24, 76, 37])
pred10 tensor([83, 59, 49, 81,  0, 27,  9, 54, 18, 94])
pred10 tensor([31, 16, 51, 49, 86, 26, 88,  0, 70, 73])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.65it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 224, 5338, 3456,  ..., 1313, 6463, 6420],
        [ 134, 9783, 9976,  ..., 1564, 4278, 4977],
        [ 146, 1564, 3546,  ..., 3555, 7698, 8210],
        ...,
        [3690, 6263, 1707,  ..., 7640, 5137, 9032],
        [  14, 9173, 3692,  ..., 4749,  731, 6995],
        [2534, 7524,  643,  ..., 6337, 5529, 5903]])
pred10 tensor([65, 39,  0, 50, 34, 80, 23, 19,  2, 90])
pred10 tensor([35, 52,  0, 73, 83,  6, 74, 90, 32, 97])
pred10 tensor([43, 25, 61, 83, 45, 65, 18, 49, 77,  1])
pred10 tensor([88, 17, 84, 10,  4, 27, 67, 51, 16, 54])
pred10 tensor([11,  0, 67, 46, 93, 16, 21, 95, 97, 10])
pred10 tensor([85, 33, 93, 50, 36, 14,  0, 22, 40, 27])
pred10 tensor([ 0, 68, 10, 90, 24, 80, 56, 67, 17, 65])
pred10 tensor([31, 56, 50, 45,  7, 55, 29, 37,  6, 64])
pred10 tensor([ 0, 36, 95, 17, 23, 46, 38, 50, 19, 73])
pred10 tensor([31, 26, 12, 52, 13, 14, 32, 18, 53, 87])
pred10 tensor([56, 82, 96, 49,  0, 46, 26, 41,  1, 91])
pred10 tensor([ 4, 79,  0, 19, 34, 61, 88, 23, 39, 58])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 26.64it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[3645, 5423, 6339,  ..., 5982, 4254, 6940],
        [5458, 3009, 3235,  ..., 4715, 4778, 4879],
        [ 982, 8220, 8735,  ..., 5032, 6200, 3032],
        ...,
        [ 396, 7073, 9736,  ..., 2851, 6254, 5803],
        [9511,  274, 4647,  ..., 8407, 8080, 3293],
        [ 249, 8164,  546,  ..., 6755,  325, 4478]])
pred10 tensor([79,  9, 65, 58, 34, 80, 42, 72,  3,  7])
pred10 tensor([16,  3, 56, 26, 75, 34, 12, 76, 13, 24])
pred10 tensor([72, 91, 75, 78, 92, 45, 21, 56, 47, 93])
pred10 tensor([92, 36, 69,  6, 96, 37, 74, 56, 19, 34])
pred10 tensor([ 0, 54, 52, 84, 53,  1, 92, 78, 62, 28])
pred10 tensor([57, 23, 59,  1, 45, 28, 64, 12, 36, 76])
pred10 tensor([56,  8, 97, 19, 87, 60, 62,  6, 66, 43])
pred10 tensor([71, 51,  0, 88, 73, 64, 43, 97, 16, 63])
pred10 tensor([59, 74, 82, 53, 87, 33,  6, 94, 47, 38])
pred10 tensor([92, 37,  7, 55, 57,  0, 63,  2, 84, 24])
pred10 tensor([56, 95, 57, 26, 54, 19, 96, 71, 82,  0])
pred10 tensor([94, 36, 28, 74, 29, 43, 76, 59,  9, 53])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.12it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 113, 1961, 5517,  ..., 9282, 2943, 1664],
        [ 919, 3461, 2369,  ..., 9981, 3632, 5923],
        [  20, 7422, 6854,  ..., 5508, 2586, 6690],
        ...,
        [ 397, 6077, 5916,  ...,  652, 1320, 6387],
        [ 982, 8262, 7777,  ..., 8282, 8114, 6578],
        [  13, 2624, 6107,  ..., 3843, 4709, 1814]])
pred10 tensor([34, 55,  6,  0, 95, 37, 66, 25, 76, 98])
pred10 tensor([69, 20, 15, 34, 32, 81, 31, 93, 44, 80])
pred10 tensor([ 0, 49, 61, 21,  4, 46, 63, 29, 76,  7])
pred10 tensor([19, 63, 52, 64, 61, 58, 97,  0, 15, 59])
pred10 tensor([67, 36,  4, 38, 52, 27, 98, 34, 70,  8])
pred10 tensor([74,  0, 75,  6, 65, 42, 70,  4, 38, 53])
pred10 tensor([ 6, 33,  0, 45, 16, 23, 77, 70, 78, 80])
pred10 tensor([45, 65, 79, 58, 27, 14, 21, 20,  0, 28])
pred10 tensor([75, 85, 65, 43, 70, 79, 41,  2, 89, 56])
pred10 tensor([10, 86, 18, 28, 88,  9, 62,  0, 74, 70])
pred10 tensor([63, 72, 97,  4, 16, 96,  3, 39, 43, 50])
pred10 tensor([75, 37, 79, 59, 12, 34, 45, 68, 93, 56])
pr

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.83it/s]


tensor([[ 365, 3506, 2860,  ..., 6413, 9968, 3771],
        [ 851,  758, 2326,  ...,  317,  210, 6984],
        [ 384,  354, 2299,  ..., 5416, 9370, 1986],
        ...,
        [2176, 9866, 3904,  ..., 6307, 3283, 5332],
        [ 284, 8515, 6282,  ..., 8316, 3443, 6406],
        [1075, 7281, 9909,  ...,  148, 2451, 8487]])
pred10 tensor([73, 58, 18, 11, 91, 83, 21, 74,  0, 90])
pred10 tensor([98,  5, 33, 80, 34, 72, 45, 27, 12, 10])
pred10 tensor([42, 76, 19, 89, 16, 21, 67, 41, 13, 69])
pred10 tensor([13, 81, 31,  0,  7,  4, 40, 15, 96, 75])
pred10 tensor([11, 10, 13, 47, 84, 78, 64, 15, 27, 39])
pred10 tensor([22, 97, 31, 66, 87, 68, 55, 90, 45, 33])
pred10 tensor([85, 10,  1, 17, 58, 66, 88,  3, 86, 51])
pred10 tensor([ 4, 91, 49, 23, 57, 53, 42, 77, 51, 27])
pred10 tensor([27, 16, 42, 71, 38, 30, 11, 31, 33, 14])
pred10 tensor([24, 34, 42, 31, 87,  7, 97, 73, 96, 20])
pred10 tensor([18, 14, 93, 77, 76, 91, 61, 23,  8,  0])
pred10 tensor([50, 11, 14, 48, 52,  1,  7, 62, 80,  8])
pr

In [14]:
model = NCFModel(hidden_dim, traindataset.user_nums, traindataset.book_nums).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
crit = torch.nn.BCELoss()

loss_for_plot = []
hits_for_plot = []

for epoch in range(epochs):
    #训练
    losses = []
    for index, data in enumerate(tqdm.tqdm(trainloader)):
        user, item, label = data
        user, item, label = user.to(device), item.to(device), label.to(device).float()
        y_ = model(user, item).squeeze()

        loss = crit(y_, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu().item())
        
    #验证
    hits = []
    for index, data in enumerate(validloader):
        user, pos, neg = data
        #print(pos.shape,neg.shape)#torch.Size([512]) torch.Size([512, 99])
        pos = pos.unsqueeze(1)#[512->[512,1]
        all_data = torch.cat([pos, neg], dim=-1)
        print(all_data)#torch.Size([512, 100])
        output = model.predict(user.to(device), all_data.to(device)).detach().cpu()##torch.Size([512, 100])
        
        #每一个用户预测的结果对应output中的一行(batch),被预测的相应item是all_data中的一行(batch_items)
        for batch ,batch_items in zip(output,all_data):
            
            pos_id=batch_items[0]#取出正样本对应的真实的item id
            
            pred10=(batch).argsort(descending=True)[:10]#预测值从大到小，取前10所在下标
            pred10=batch_items[pred10]#在batch_items中的真实下标，这才是item id
            print(pred10)
            #索引0是正样本，如果预测的前10中没有0，那么说明预测错了
            if pos_id not in pred10:
                hits.append(0)
            else:
                hits.append(1)
                
    print('Epoch {} finished, average loss {}, hits@20 {}'.format(epoch, sum(losses)/len(losses), sum(hits)/len(hits)))
    loss_for_plot.append(sum(losses)/len(losses))
    hits_for_plot.append(sum(hits)/len(hits))

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:03<00:00, 24.70it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 154, 7793, 7855,  ..., 5052,  441, 7805],
        [1728, 6290, 2170,  ..., 9121, 2102, 1346],
        [4530, 9667, 4790,  ..., 2918, 3704, 5545],
        ...,
        [7484, 3735, 8418,  ..., 5728, 9618,  907],
        [3435, 1830, 2604,  ..., 2331, 8905, 7993],
        [1075,  334, 8454,  ..., 3879, 6621, 8711]])
tensor([5677, 4348, 2675, 6947, 4948, 8039, 2001,  208, 6318, 9099])
tensor([ 353, 1878,  649, 8140, 2729, 6613, 4014, 1544, 2934,  523])
tensor([4790, 6921, 1173, 4973, 9620, 8479, 5804, 3704,  713, 8636])
tensor([4173, 8114, 3897, 4782, 6409, 3607, 3382, 5489, 6833, 8312])
tensor([5677, 8113, 7870, 7722, 5212, 3121, 1765, 4472, 8023,  153])
tensor([4833, 8267, 6679, 5005,  781, 4788, 5440, 2842, 3422, 9678])
tensor([5344, 1533, 8994, 2059, 3106, 2460, 8535, 1337, 2167,  960])
tensor([3203, 3411, 6204, 1210, 9172, 4823, 3966, 5366, 7802, 4574])
tensor([1491, 4739, 9315, 1723, 5056, 6627, 9451, 1950, 1170, 6453])
tensor([4617, 9320, 5266, 7584, 8036, 9847, 6385, 778

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 27.49it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 967, 7831, 7030,  ..., 4745, 3048, 9366],
        [2246, 3153, 4626,  ..., 9514, 7469, 1913],
        [ 166, 3003, 1714,  ..., 7968, 5834, 2264],
        ...,
        [ 115, 1850, 5299,  ..., 1237,  287, 1493],
        [3435, 7733,  883,  ..., 7343, 5614, 6687],
        [ 701,  990, 1047,  ..., 2278, 2231,  131]])
tensor([5449, 4274,  597, 3061, 1305, 3579, 7536, 6379, 4185, 9315])
tensor([7149,  940, 8841, 8558, 4198, 3070, 5207, 9716, 6979, 8724])
tensor([5439, 4835, 1714, 5921, 6835, 2662,  575, 7952, 2717, 4228])
tensor([3794, 4520, 2774, 4465, 5882, 7527, 4190, 2060, 2517, 3708])
tensor([9223, 1979, 3119,  941,  800, 8843,   49, 2116, 7149, 9570])
tensor([5213, 4393, 6190,  652, 2450, 7955, 7926, 7637,  329, 2423])
tensor([5509, 4150,   20, 1574, 3874, 6871, 8374, 8584, 1981, 5774])
tensor([2928, 6990, 6089, 1987, 6841,  696, 4072, 9077, 4310, 4128])
tensor([  61, 8603, 5555, 8457,  714, 1591, 7253, 5054,   49, 4525])
tensor([5082,  652,   26,  742,   65,  556, 2354,  27

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.78it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 154, 1338, 1639,  ..., 6822, 5938, 6906],
        [1211, 4199, 3233,  ..., 5876, 8651, 9772],
        [ 900, 7607, 9511,  ..., 3865, 7006, 1146],
        ...,
        [2246, 6067, 2621,  ..., 6476, 9724, 2512],
        [  29, 5603, 2494,  ..., 3293, 5570, 3819],
        [6180, 5087, 9902,  ..., 3308, 3660, 8477]])
tensor([ 723, 2532, 8820, 8948, 8638, 5255, 6254, 1017, 6684,  154])
tensor([7684,  611,  493, 6214, 2399, 7199,  254, 9742, 3007, 6663])
tensor([5102, 3910, 6190, 4331, 7006, 5243, 9759,  657, 2008,  183])
tensor([3066,  742, 5515,  228, 5661, 9605, 4973, 1463, 1074, 5243])
tensor([2399, 4305, 1388, 9228, 1210,  637, 9525, 2587, 9713, 5129])
tensor([1194, 3070, 4305, 6682, 1708, 1747, 1261, 1401, 8575, 2024])
tensor([1286, 8948, 1271, 2523,  706, 5570, 3806, 8843, 3121, 1341])
tensor([5555, 1054,   55, 1223,  969, 1923, 8200, 9451,  636,  521])
tensor([3155,  484, 6190, 7684, 2818, 4823, 4312, 6150,  227, 7052])
tensor([ 218,   32,  917, 1537, 4475,  322, 6129, 123

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.11it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 972, 1531,  965,  ..., 5282, 2160,  798],
        [ 218,  798, 5173,  ..., 8650,  896, 8233],
        [  31, 3755, 9187,  ..., 6602, 5795, 6070],
        ...,
        [ 170, 8461, 3615,  ..., 5708, 5895,  117],
        [3868, 3101, 8532,  ..., 1886, 4442, 1531],
        [ 672, 2580, 7729,  ..., 4424, 8077, 5930]])
tensor([ 126,  207, 1204, 1388, 6819, 5282,   75, 1852,  415, 9234])
tensor([2023, 6258,   35,  928,  218,  896, 8564, 5406, 1073, 3047])
tensor([ 653,   31, 1010, 9989, 9187,  366,  878, 2073, 8693, 1187])
tensor([4067, 1008,  331, 8169,  278,  111, 8280, 4058, 6559, 9718])
tensor([  38, 9178, 3048,  270, 2012, 3585, 1865,  268, 3597, 1655])
tensor([ 537, 2026, 6194, 4014, 5345, 1805, 4619, 1384, 3714,  611])
tensor([1858,   83,   43, 1089,   75, 2731, 4400, 2896,  113, 2992])
tensor([  37,  166,  866, 1388, 7839, 4725, 1264, 2271, 1427, 1248])
tensor([ 179, 4520, 3378,  688, 6871, 2034, 3714,  351, 1261, 1010])
tensor([2545,  929, 6194, 2023,  351,   36, 6149, 202

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.21it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 898, 3228, 4481,  ..., 7992, 8165, 7076],
        [  12, 6805, 6145,  ..., 1754, 9686, 8758],
        [1254, 9717, 1406,  ..., 3036, 4791, 3577],
        ...,
        [9511, 5469, 1715,  ..., 1135, 2861, 2720],
        [ 489, 8870, 7318,  ..., 3781, 1182, 1544],
        [6798, 3281, 8053,  ..., 7397, 1806, 4096]])
tensor([4283,  898, 4018, 4014, 2901, 7051, 2987, 1922, 1135,  848])
tensor([1644,   12,  460,  455, 4099, 9223, 2088, 2135, 9985, 7757])
tensor([1490, 2905,  455, 1198, 1254, 1223, 9486, 4484, 1645, 2674])
tensor([  38,  113,  687,    2, 2342, 4886, 4923, 9030, 6650,  314])
tensor([ 416,   19, 1134, 1380,   52,  878,   24, 1348, 5154, 4770])
tensor([ 300, 4972, 2273,  974, 4075, 6493, 6782, 1074,  221, 9742])
tensor([  86, 1380, 2003, 2629, 3073, 5154, 4107,  704, 8151, 1642])
tensor([ 224, 1249,  159,  149, 1205, 2583, 4136,  569, 3362, 8390])
tensor([ 315, 3515, 5085, 2551, 5240,  177, 7899, 5273, 4859, 2356])
tensor([ 575,   89,  281, 6069, 1249, 3292, 1089, 975

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.61it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[6813, 4855,  630,  ..., 2677, 1317, 2611],
        [3868, 3930, 3002,  ...,   75, 9913, 6378],
        [ 972, 5211, 5907,  ..., 2933,  484, 9268],
        ...,
        [  31, 2692, 6896,  ..., 6170, 6485, 1068],
        [  14, 8834, 2842,  ..., 1506, 6437, 1107],
        [ 141, 3466, 9795,  ..., 4691, 1958, 2281]])
tensor([7006,  630, 1596,  925, 1078, 4770,  620,  260, 5697,  338])
tensor([  19,   75, 1281, 8175, 7232, 4128, 6190, 1581, 1135, 1204])
tensor([2930,  337,  349, 7757,  386,  484,  972, 1134, 2738, 5907])
tensor([ 605,   91, 6258,   56,  248,  974, 2769, 5944, 2780, 2746])
tensor([  99, 1702,  688,  653, 1238,  821,  239,  147, 2440, 4890])
tensor([ 661,  191,  110,  732,  288,  331, 7027, 3662, 2551, 5637])
tensor([  10, 4393, 7199,  714, 9573,  221, 6089, 7095, 7720, 9873])
tensor([8558,  897,  814, 1536,  180,   42, 2905, 2028,  915,  290])
tensor([4362, 4393, 1101, 6326,   29, 1644, 2746, 9223, 7115, 1029])
tensor([3952,  466, 1503,  254, 1384, 5628, 2096,   3

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.50it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[ 919,  869, 5373,  ..., 9472, 2664, 7332],
        [ 161, 4935, 7877,  ..., 7913, 5821,  665],
        [ 378,  144, 1726,  ...,  170, 1936, 5444],
        ...,
        [  93, 2932, 8444,  ..., 8681, 9642, 8215],
        [ 885, 9873, 1204,  ..., 4760, 6181, 6053],
        [ 397, 8919, 7628,  ..., 1270, 9440, 6910]])
tensor([ 800, 5373, 5330,  869, 1020, 8841, 1714,  545, 3492, 9468])
tensor([ 161,  894,  856, 1614, 3578,  415, 1465, 2034, 9486,  246])
tensor([ 170, 4249, 1726,  859, 2930, 1249,  559, 1884, 6650, 6696])
tensor([9525, 8488, 1183, 6143, 9792, 2742, 1911, 1955, 7005, 8538])
tensor([ 126,  533,   11,  454, 2303, 1435, 4434,  143, 6938, 6399])
tensor([ 260,  112, 9761,  379, 4240,  438, 9497, 3543, 1427, 5111])
tensor([   0, 8138,  407, 4078, 8033, 9178, 4084,  190,  473, 4369])
tensor([ 309,   30, 5509, 5729, 3781, 9203,  338,   82, 6706,  313])
tensor([5044, 2073, 6860,  479, 8369, 7625, 1866, 1955, 3964, 1564])
tensor([6210,  497,  282, 5477, 9772, 5524, 8123, 794

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 30.04it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:02, 25.71it/s]

tensor([[2840, 4998, 6427,  ...,   23, 3799, 9463],
        [ 141,  158,  456,  ..., 4388, 3832, 4098],
        [ 984, 9956, 6073,  ..., 7797, 4358, 7310],
        ...,
        [ 123, 7033, 5860,  ..., 3568, 3786, 3192],
        [2246,  990, 6133,  ..., 9197, 8268, 8062],
        [ 145,  167, 3754,  ..., 8913, 1938, 7512]])
tensor([  23,   83, 4116,  837,  446,  594, 6943, 1459, 6560,  491])
tensor([ 546,  118, 5729,  141, 5829,  772,  158, 1260, 1083, 1441])
tensor([ 940,  157, 2452, 8430,  954,  665,  220, 2631, 4066, 5325])
tensor([ 915,  958, 2026, 4019, 9203, 9525, 1857, 1543,  394, 5733])
tensor([  93, 1205,  762, 9486, 2048, 1662,  229, 5207,  849, 5742])
tensor([1328,  491,  269, 7007, 2925, 6643, 5664, 2890, 1245, 3645])
tensor([ 940, 2534,   12, 3645,  174,  208, 3356,   85,  704, 7359])
tensor([1298,  653, 7872, 2461, 3585,  704, 8123, 5384, 2081, 1760])
tensor([ 123,   47, 2297, 4605, 4941, 9245, 4745, 7252,  658, 5118])
tensor([ 179,   19,  128, 7002, 2097, 3510, 2238, 185

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 29.36it/s]
  0%|                                                                                           | 0/78 [00:00<?, ?it/s]

tensor([[6813, 4262,  458,  ..., 8237, 6255, 6765],
        [1724, 6330,  719,  ...,  308,  876, 4783],
        [ 145, 3341, 9280,  ..., 3484, 6730, 6198],
        ...,
        [ 898, 4918, 3809,  ..., 7461, 8352, 8316],
        [ 123, 2508, 4598,  ..., 8667,  899,  584],
        [2610, 9385, 7522,  ..., 3512, 2825, 5061]])
tensor([ 255, 2102,  207,  713, 1013, 4099, 1282, 3432, 6628,  150])
tensor([1017, 4019, 6819,  345, 5445, 7765, 8603, 6684,  701,  366])
tensor([ 537,   94,  145,  303, 1010, 1219, 4613, 8877, 2492, 5708])
tensor([ 489,  214, 1579,  781, 6104, 5345,  957,  131,   14, 3450])
tensor([ 521, 1264,  140,   21,   31, 1955, 7952,  221, 5422, 2008])
tensor([ 605, 9077, 2023, 7285, 6480,  680, 9857, 1047, 6560, 1576])
tensor([8881, 3378, 2251, 7150, 1023, 7918, 7092,  742,  968,  727])
tensor([ 555,   35,  942,  489,   13, 1968,  750, 1348, 2545, 2088])
tensor([ 981,  147,  714, 2317, 1251,  188, 2673, 5599, 1508,  814])
tensor([  22,  159,  172,  180, 1537,  871, 4725,  68

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:02<00:00, 28.83it/s]


tensor([[1075, 8410, 7752,  ..., 5009, 3269, 6746],
        [ 378,  771, 3739,  ..., 4092, 9982, 1286],
        [ 384, 3259, 5066,  ..., 7196, 5458, 1329],
        ...,
        [1728, 1401, 6244,  ..., 2167, 7356, 2169],
        [2484, 8529, 6265,  ..., 4611,  630, 5468],
        [ 704, 6725, 7076,  ..., 5371, 6614, 4493]])
tensor([2317, 5490,  187, 1576,    6, 8653, 8020, 2549, 1025, 2673])
tensor([ 371, 6363,   95, 1040,  153, 7521, 8948, 2672,  340, 9587])
tensor([ 397, 7196, 2614,   42,  260,  948, 4744, 7918, 7480, 4827])
tensor([5524,  309, 2448, 1130,  242,   43, 8408, 3605,  974, 7839])
tensor([2346, 8835,  261,  224, 9301,  948, 3509, 6011, 6585, 2715])
tensor([  80, 8454,  969, 7873,  918, 9772,   33, 4254,  277,  436])
tensor([7261,  407,  102, 2662, 6593,  107, 2902, 4028, 1194,  922])
tensor([1153,  341,  952, 8108, 1514,  968, 8383, 5571, 6633, 9742])
tensor([ 765, 1074,  116,   91, 4393,  139, 2189,  360, 1499, 8776])
tensor([  17, 1004, 2797, 2780,   99, 1517, 1135,   9

In [None]:
# 模型保存
torch.save(model.state_dict(), './model.h5')

## 模型预测

In [16]:
df = pd.read_csv('./test_dataset.csv')
user_for_test = df['user_id'].tolist()

In [19]:
user_for_test

[0,
 1,
 3,
 5,
 7,
 8,
 10,
 14,
 17,
 21,
 23,
 24,
 27,
 31,
 28,
 33,
 39,
 30,
 42,
 48,
 45,
 49,
 52,
 54,
 57,
 60,
 61,
 64,
 65,
 53,
 55,
 68,
 69,
 62,
 71,
 72,
 58,
 73,
 74,
 75,
 79,
 37,
 82,
 84,
 77,
 44,
 88,
 92,
 93,
 99,
 25,
 102,
 104,
 107,
 111,
 112,
 113,
 114,
 116,
 85,
 89,
 115,
 122,
 123,
 124,
 125,
 127,
 128,
 134,
 121,
 136,
 135,
 140,
 141,
 142,
 148,
 150,
 155,
 129,
 157,
 163,
 35,
 166,
 168,
 172,
 173,
 174,
 176,
 175,
 179,
 177,
 182,
 170,
 183,
 184,
 186,
 188,
 167,
 189,
 105]

In [26]:
predict_item_id = []

def chunks(l, n):#n: batch size
    for i in range(0, len(l), n):
        yield l[i:i+n]

f = open('./submission.csv', 'w', encoding='utf-8')

#预测每一个用户user可能会点击的图书item
for user in tqdm.tqdm(user_for_test):
    #将用户已经交互过的物品排除
    user_visited_items = traindataset.user_book_map[user]
    items_for_predict = list(set(range(traindataset.book_nums)) - set(user_visited_items))

    results = []
    user = torch.Tensor([user]).to(device).long()

    for item_batch in chunks(items_for_predict, 512):
        item_batch = torch.Tensor(item_batch).unsqueeze(0).to(device).long()
        
        #print(user.shape)#torch.Size([1])
        #print(item_batch.shape)#torch.Size([1, 512])

        result = model.predict(user, item_batch).view(-1).detach().cpu()
        #print(result.shape)#torch.Size([512])
        results.append(result)
    #print(len(results),len(results[0]))#(20,512),注意results[-1]不一定是512，因为可能不足一个batch (batch size =512)
    results = torch.cat(results, dim=-1)#所有items_for_predict关于用户user的预测值

    #取得分前10的item在results(也在items_for_predict)中的下标
    predict_item_id = results.argsort(descending=True)[:10]#从大到小排序，取前10
    print('ind:',predict_item_id)
    
    #映射到真实的item id
    res=[]
    for i in predict_item_id:
        res.append(items_for_predict[i])
    print('res:',res)
    list(map(lambda x: f.write('{},{}\n'.format(user.cpu().item(), x)), predict_item_id))

f.flush()
f.close()

  3%|██▍                                                                               | 3/100 [00:00<00:04, 19.42it/s]

ind: tensor([  0,  75, 507,  82,  18, 301,  76,  12,  19,   6])
res: [0, 111, 586, 120, 24, 371, 112, 17, 25, 7]
ind: tensor([   0,  550,  336, 1111,   42,   93, 4308,  534,    6,   76])
res: [0, 586, 371, 1153, 57, 111, 4360, 569, 10, 93]
ind: tensor([  0, 486, 172, 586, 235, 289,  63,  71,  58, 231])
res: [0, 586, 237, 690, 309, 371, 111, 123, 105, 305]
ind: 

  7%|█████▋                                                                            | 7/100 [00:00<00:05, 17.03it/s]

tensor([   0,  114,    7,  154, 1211,   18,  667,   56,   17,   77])
res: [0, 121, 7, 161, 1255, 18, 700, 57, 17, 80]
ind: tensor([3345, 1323,  139, 5990,    0,  544,   47,  212,   15,  186])
res: [3424, 1384, 159, 6077, 0, 586, 57, 237, 20, 210]
ind: tensor([   3,   62,   27,   71,   28,   12,    0,  113, 2884,  191])
res: [7, 99, 56, 111, 57, 34, 3, 158, 3001, 248]
ind: tensor([   1,    0, 1305,  128,   41,  945,   12,  126,   86,  448])
res: [1, 0, 1384, 161, 57, 1020, 17, 159, 112, 508]


  9%|███████▍                                                                          | 9/100 [00:00<00:05, 15.97it/s]

ind: tensor([ 89,   0,  35, 633, 532,  65,  91, 325,  72,  81])
res: [121, 0, 57, 690, 586, 93, 123, 371, 102, 111]
ind: tensor([ 650,  548,    0,   51,  138,   52,    1,  615, 3980,   35])
res: [690, 586, 0, 69, 161, 70, 1, 653, 4067, 46]
ind: tensor([  7,  22,  50, 557, 517, 114,  16,  14,  46,  19])
res: [10, 25, 57, 586, 546, 123, 19, 17, 53, 22]
ind: tensor([  0,  18, 340,   9,  90,  69, 590, 283,   7, 516])


 13%|██████████▌                                                                      | 13/100 [00:00<00:05, 15.05it/s]

res: [0, 39, 404, 19, 124, 102, 663, 343, 16, 586]
ind: tensor([  0,   8, 569, 552,  21, 804, 870, 358,  43, 148])
res: [0, 10, 586, 569, 25, 827, 894, 371, 48, 160]
ind: tensor([ 542, 4677,   38,   29, 1298,    7,  820,   93, 1175,  121])
res: [605, 4787, 57, 44, 1384, 13, 894, 131, 1255, 159]


 17%|█████████████▊                                                                   | 17/100 [00:01<00:05, 16.15it/s]

ind: tensor([  68,    8,  350,    3, 1164,   12,   44, 1347,  511,  113])
res: [102, 17, 416, 10, 1255, 23, 69, 1440, 586, 161]
ind: tensor([   5,  365,   97,    0,  196,  134,  537,  637, 7325,  201])
res: [7, 404, 120, 0, 227, 161, 586, 690, 7415, 232]
ind: tensor([   0,  545, 1200,  311, 1595,   23,   92,    9,   32,    3])
res: [0, 586, 1255, 334, 1659, 25, 101, 10, 34, 3]
ind: tensor([254,  37, 190, 317, 115, 349,  11,  21,  55, 494])
res: [305, 69, 237, 371, 157, 404, 34, 48, 89, 555]


 21%|█████████████████                                                                | 21/100 [00:01<00:04, 17.05it/s]

ind: tensor([ 17,  37,   1,  90,  97, 359, 195,   8, 517,  40])
res: [25, 53, 3, 112, 120, 404, 227, 10, 573, 57]
ind: tensor([ 379,  550,  219, 3797, 6468, 1194,  210,   12, 1247,   30])
res: [404, 586, 237, 3877, 6560, 1255, 227, 18, 1309, 39]
ind: tensor([ 42, 329,   6, 215,  62, 294, 536, 890, 483,  10])
res: [57, 371, 10, 250, 80, 334, 586, 952, 533, 16]
ind: tensor([  20,    9,   46,   13, 1483,   97,  264, 2494,   23,  486])
res: [25, 10, 53, 17, 1536, 112, 290, 2560, 28, 519]


 23%|██████████████████▋                                                              | 23/100 [00:01<00:04, 17.44it/s]

ind: tensor([   0,   47,  142, 1090,   52, 1078,    6,    2,   81,  949])
res: [0, 57, 160, 1153, 63, 1141, 7, 3, 93, 1010]
ind: tensor([560,   0,   9,  51,  84,   8, 114, 380,  90, 912])
res: [586, 0, 10, 57, 94, 9, 126, 404, 100, 942]
ind: tensor([ 260,   90,   11,   37,  351, 5731, 1452,  131,   26,   97])
res: [309, 121, 25, 57, 404, 5850, 1538, 167, 44, 129]
ind: tensor([  98,  205,   76,  230,  544,  284,  333,  562, 1647,    5])
res: [126, 237, 102, 263, 586, 322, 371, 605, 1702, 15]


 28%|██████████████████████▋                                                          | 28/100 [00:01<00:04, 17.59it/s]

ind: tensor([ 27, 326, 233, 493,  79,  70, 105, 106, 465,  24])
res: [57, 404, 305, 586, 124, 111, 159, 160, 555, 53]
ind: tensor([   0,  537,  100,  363,  636,  201,   12, 5758,   86, 1628])
res: [0, 586, 126, 404, 690, 237, 19, 5850, 111, 1702]
ind: tensor([   0,  620, 1302,  519, 1174,   94,   92,   45,  468, 1617])
res: [0, 690, 1384, 586, 1255, 126, 123, 69, 533, 1702]
ind: tensor([ 564,    6,  148,    0,    7,  110,  351,  102,   20, 1349])
res: [586, 7, 160, 0, 10, 120, 371, 112, 25, 1384]


 33%|██████████████████████████▋                                                      | 33/100 [00:01<00:03, 17.58it/s]

ind: tensor([  4,   1, 310,  79, 516,  33, 485,   0,  20, 583])
res: [7, 3, 371, 112, 586, 53, 555, 1, 34, 653]
ind: tensor([   5, 3943,  377,  901,  250,  111, 5930,  654,  309, 1251])
res: [7, 4019, 404, 942, 272, 120, 6011, 690, 334, 1300]
ind: tensor([8244,  545,    0,  564,  102,    6,   19, 2045,  108,  642])
res: [8328, 586, 0, 605, 120, 10, 23, 2106, 126, 690]
ind: tensor([   0, 1213,    8,   14,    6,  560,  350,   20,   19,  507])
res: [0, 1255, 10, 17, 7, 586, 371, 23, 22, 533]


 35%|████████████████████████████▎                                                    | 35/100 [00:02<00:04, 15.70it/s]

ind: tensor([556,  23,   9,   3, 346,  61, 241,   7, 540, 778])
res: [586, 25, 10, 3, 371, 69, 263, 7, 569, 819]
ind: tensor([513,  87,   4,   9, 614, 927, 117,  70,  78,  12])
res: [586, 123, 10, 16, 690, 1010, 161, 102, 111, 19]
ind: tensor([   0,    4,  326,  630,  731,  188, 1312,  827,  127,  198])
res: [0, 10, 371, 690, 794, 227, 1384, 894, 159, 237]


 39%|███████████████████████████████▌                                                 | 39/100 [00:02<00:03, 15.89it/s]

ind: tensor([560,   0,  12, 143, 512,   7,   3,  87,  47,   1])
res: [586, 0, 17, 159, 537, 10, 3, 102, 57, 1]
ind: tensor([ 553,    6,  341,   10,   95,   86, 5977,  183,   77,    0])
res: [586, 10, 371, 25, 120, 111, 6032, 210, 101, 3]
ind: tensor([ 27, 494, 104, 833,  30, 131,  65, 894,  40,  70])
res: [53, 586, 159, 942, 57, 192, 105, 1004, 70, 112]
ind: 

 41%|█████████████████████████████████▏                                               | 41/100 [00:02<00:03, 15.34it/s]

tensor([ 311,  520,  589,   86,    0,   29,  503,   83, 1067,  537])
res: [371, 586, 661, 126, 3, 57, 569, 123, 1153, 605]
ind: tensor([ 10,   1, 268,  56, 562,   0, 383,  96,  11, 329])
res: [17, 1, 290, 69, 586, 0, 407, 112, 18, 352]
ind: tensor([  0, 920, 564, 110, 253, 286, 218,  38, 533,  57])
res: [0, 948, 586, 124, 272, 305, 237, 44, 555, 64]
ind: tensor([916, 944, 261, 324, 125,  36,   9,  93,  97,   6])
res: [982, 1010, 305, 371, 161, 53, 17, 123, 129, 10]


 46%|█████████████████████████████████████▎                                           | 46/100 [00:02<00:03, 16.21it/s]

ind: tensor([   0,   82,  418,  612,  183,   77,   61, 1290, 4028,    4])
res: [0, 121, 488, 690, 237, 111, 93, 1384, 4136, 10]
ind: tensor([   0,    6,   20,   98,  348,    9,   39,   18,   50, 1209])
res: [0, 7, 25, 111, 371, 10, 44, 23, 57, 1255]
ind: tensor([564,  15, 146, 222, 120,  87,  90, 111,   7, 474])
res: [586, 25, 161, 237, 135, 102, 105, 126, 9, 494]
ind: tensor([  26,    3,   30,  512,  483,   68,  877, 1099,  311, 1293])
res: [53, 10, 57, 586, 555, 101, 958, 1185, 371, 1384]


 50%|████████████████████████████████████████▌                                        | 50/100 [00:03<00:03, 15.20it/s]

ind: tensor([ 252,  108,    5,   43, 5484,  890,  454,  144, 1305, 4238])
res: [290, 129, 7, 57, 5585, 952, 500, 172, 1384, 4335]
ind: tensor([  16,    0,  559,    8, 4312,  347,   40, 1219,   11, 8162])
res: [19, 0, 586, 10, 4360, 371, 51, 1255, 13, 8215]
ind: tensor([ 559,    0,    7,    3, 7130,  111,  241,  259, 1665,   41])
res: [586, 0, 7, 3, 7199, 126, 263, 282, 1702, 48]


 52%|██████████████████████████████████████████                                       | 52/100 [00:03<00:03, 14.18it/s]

ind: tensor([   0, 1335,   22,    9, 1210,  114,  147,   51,   20,   13])
res: [0, 1384, 25, 10, 1255, 126, 161, 57, 23, 16]
ind: tensor([371,  14,  20, 277, 545, 305, 213, 291, 296, 104])
res: [404, 18, 25, 305, 586, 334, 237, 319, 325, 120]
ind: tensor([  12,  202,    6,   13, 1453,   15,   43,  830,  104,  438])
res: [17, 237, 10, 20, 1536, 23, 57, 894, 123, 492]


 56%|█████████████████████████████████████████████▎                                   | 56/100 [00:03<00:03, 13.53it/s]

ind: tensor([ 529,   11,  196,  601,    3,    5, 4160,  629,  262,   89])
res: [586, 18, 237, 661, 5, 7, 4250, 690, 305, 120]
ind: tensor([  44,    0,  548,  140, 8467, 9312,  190, 7013,  897,   18])
res: [53, 0, 586, 159, 8558, 9408, 210, 7102, 942, 25]
ind: tensor([2398, 5927, 5620,    0,  276,  387,  113, 1116,  918, 4605])
res: [2440, 6011, 5703, 0, 290, 404, 120, 1141, 942, 4682]


 58%|██████████████████████████████████████████████▉                                  | 58/100 [00:03<00:03, 12.88it/s]

ind: tensor([ 472, 1491, 1657,  347,    0,   17,  112,  110,  380,  874])
res: [500, 1536, 1702, 371, 0, 23, 126, 123, 404, 912]
ind: tensor([  18,  165,    0,  262,  868, 4476,  106, 4387,  370,   20])
res: [23, 191, 0, 290, 915, 4544, 128, 4454, 404, 25]
ind: tensor([  0, 567, 356,   2, 100, 915, 537, 222, 111,   7])
res: [0, 586, 371, 3, 111, 942, 555, 237, 123, 9]


 62%|██████████████████████████████████████████████████▏                              | 62/100 [00:03<00:02, 14.53it/s]

ind: tensor([  17,    0,  558,   12,   91,  347,  100, 1219,    2,  144])
res: [25, 0, 586, 17, 102, 371, 112, 1255, 3, 159]
ind: tensor([   0,  285,  100,  103,   15,  638,   55, 1976,  484, 1469])
res: [0, 322, 120, 123, 17, 690, 63, 2054, 533, 1536]
ind: tensor([   0,   46,   22,  214,    6,  344, 4173,  642, 4485,  113])
res: [0, 53, 25, 237, 7, 377, 4283, 690, 4601, 124]
ind: tensor([  66,    0,   35,  530, 2250,  150, 1646,  191,  122,  640])
res: [93, 0, 53, 586, 2338, 192, 1728, 237, 161, 700]


 66%|█████████████████████████████████████████████████████▍                           | 66/100 [00:04<00:02, 15.84it/s]

ind: tensor([  0,   6,  88, 524,  91, 260, 256,  56, 191, 878])
res: [0, 15, 121, 586, 124, 309, 305, 84, 237, 948]
ind: tensor([ 515,  310,   91,  174,    0, 3967,    9,  498,  184,   14])
res: [586, 371, 135, 227, 0, 4067, 13, 569, 237, 19]
ind: tensor([ 22,   0,  81, 514,  89, 105, 543, 156, 277,  16])
res: [25, 0, 93, 555, 102, 121, 586, 177, 309, 19]


 68%|███████████████████████████████████████████████████████                          | 68/100 [00:04<00:02, 14.82it/s]

ind: tensor([  49,   21,  518,   45,  276,    2, 1327,   97,  212,  532])
res: [57, 25, 555, 53, 305, 3, 1384, 111, 237, 569]
ind: tensor([ 25,  23,  53,   9, 102,   7,  50, 110, 514, 959])
res: [27, 25, 57, 10, 112, 7, 53, 120, 546, 1004]
ind: tensor([ 534,  132, 1308,  254,  452,   40,  459,  330,   10,    6])
res: [586, 159, 1384, 290, 500, 57, 508, 371, 17, 10]


 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [00:04<00:01, 14.57it/s]

ind: tensor([ 47,   0, 852, 138, 523,  43, 303,  68,   3,  12])
res: [57, 0, 894, 159, 555, 53, 334, 80, 3, 17]
ind: tensor([ 474,   36, 2433, 1303,  910,    0, 4763,    4,   92, 2531])
res: [533, 57, 2526, 1384, 982, 0, 4861, 10, 123, 2626]
ind: tensor([1499,    0,   48, 5942,  911,  940,  235,    7, 4393,   11])
res: [1538, 0, 53, 6011, 942, 972, 250, 10, 4454, 15]


 74%|███████████████████████████████████████████████████████████▉                     | 74/100 [00:04<00:01, 14.65it/s]

ind: tensor([  44,   40,  103, 6733,   76,   84,   95,  844,  105,  954])
res: [57, 53, 121, 6819, 93, 101, 112, 894, 123, 1004]
ind: tensor([ 638,  369,   20,    0,    1,  231,  203,   18, 3553, 3974])
res: [690, 404, 25, 0, 1, 257, 227, 23, 3642, 4067]
ind: tensor([  15,  304,  507,   17,  269,   81,  114,  242, 1189,   22])
res: [39, 371, 586, 41, 334, 124, 159, 305, 1298, 53]


 78%|███████████████████████████████████████████████████████████████▏                 | 78/100 [00:05<00:01, 14.75it/s]

ind: tensor([   9, 1319,   73,  447,  611,   13,    0,  129,  532,  127])
res: [17, 1384, 100, 500, 667, 22, 0, 161, 586, 159]
ind: tensor([531,  40,   6,  71, 863,  10,   0, 322, 217,   9])
res: [586, 57, 10, 93, 923, 17, 1, 371, 257, 15]
ind: tensor([  0, 332,  94,  40,  15, 536, 138,  84, 773,  75])
res: [0, 371, 112, 53, 25, 586, 161, 102, 827, 93]
ind: 

 80%|████████████████████████████████████████████████████████████████▊                | 80/100 [00:05<00:01, 14.10it/s]

tensor([5902, 1070,   80,   25,   74,   28,  205,  249,  518,   56])
res: [6011, 1153, 120, 53, 112, 57, 257, 305, 586, 93]
ind: tensor([   6,  325,    0, 1283,  136,  526,   68,  448, 1156,  738])
res: [10, 371, 0, 1384, 161, 586, 86, 500, 1255, 819]
ind: tensor([493,  42,  38,  15, 351,  90,  94, 253, 381, 820])
res: [555, 57, 53, 25, 404, 120, 124, 305, 436, 894]


 84%|████████████████████████████████████████████████████████████████████             | 84/100 [00:05<00:01, 13.61it/s]

ind: tensor([  35,   78,   87,   90,  855,  613,  497,   73, 1059,  338])
res: [57, 111, 123, 126, 942, 690, 569, 105, 1153, 399]
ind: tensor([  0, 223,   9,  95, 268, 115,  15,  53, 249,  35])
res: [0, 237, 10, 105, 282, 126, 17, 57, 263, 38]
ind: tensor([  0, 522, 314, 596,  75,  12,  11,  22,  73, 666])
res: [0, 586, 371, 663, 102, 20, 19, 34, 100, 735]


 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [00:05<00:01, 13.11it/s]

ind: tensor([   2,  465,  565, 1111,   32,   65, 2290, 5529,    4,    7])
res: [7, 586, 690, 1255, 65, 120, 2448, 5698, 10, 18]
ind: tensor([   0, 2908,   51, 4147,   14, 1485,  211,   45, 2492,   19])
res: [0, 3001, 65, 4250, 19, 1565, 250, 57, 2583, 25]
ind: tensor([552,   0, 108,  20,  15,   8, 142, 190,  90, 106])
res: [586, 0, 123, 22, 17, 10, 159, 210, 102, 121]


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [00:05<00:00, 13.13it/s]

ind: tensor([ 514,   32,   61,  162,  249, 5940,  595,   67,   16, 8849])
res: [586, 57, 93, 210, 309, 6032, 667, 102, 34, 8945]
ind: tensor([  0, 550, 336, 369,  77,  37, 205, 178,   7,  10])
res: [0, 586, 371, 404, 102, 57, 237, 210, 15, 19]
ind: tensor([ 500,   30,   75,  530, 3706,  193,   83,  244,   40,   34])
res: [555, 53, 102, 586, 3791, 231, 112, 290, 65, 57]


 92%|██████████████████████████████████████████████████████████████████████████▌      | 92/100 [00:06<00:00, 14.03it/s]

ind: tensor([   4,  251,   17,  101,   24,  137,   66, 3923,  205,  195])
res: [7, 290, 25, 126, 34, 167, 84, 4067, 237, 227]
ind: tensor([  20,   22,   18,  639, 1950,  892, 1315, 1585,   14,   46])
res: [25, 27, 23, 690, 2027, 952, 1384, 1659, 17, 53]
ind: tensor([   0,   25,  349, 7335,  117,  880,  120,   57,   23,  529])
res: [0, 25, 371, 7406, 123, 912, 126, 57, 23, 555]


 96%|█████████████████████████████████████████████████████████████████████████████▊   | 96/100 [00:06<00:00, 14.01it/s]

ind: tensor([   0,   35, 1501,   43,  153, 1321,   14,  497,   12,   15])
res: [0, 48, 1565, 57, 180, 1384, 19, 545, 16, 20]
ind: tensor([  43,  276,  370,   86,   15,  136,  647, 1199,  902,  624])
res: [53, 305, 404, 101, 23, 159, 690, 1255, 952, 667]
ind: tensor([   0,  521, 1180, 9510,  117, 1715,    4, 6965, 1078,  218])
res: [0, 586, 1255, 9605, 160, 1795, 9, 7058, 1153, 268]


 98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [00:06<00:00, 13.64it/s]

ind: tensor([   0,    8,   43, 1406,    6, 2450,  541, 9351,   17, 1314])
res: [0, 10, 53, 1478, 7, 2526, 586, 9437, 20, 1384]
ind: tensor([ 648,    6,    0, 1486, 2121,  297,  907, 6513,  281, 5774])
res: [690, 7, 0, 1536, 2176, 322, 952, 6592, 305, 5850]
ind: tensor([   0,   57,    7,  126,   17, 1701,  263,   48,  179,  100])
res: [0, 57, 7, 126, 17, 1702, 263, 48, 179, 100]


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.94it/s]

ind: tensor([  0,  25,   7, 371, 120,  57,  10,  23,  17, 161])
res: [0, 25, 7, 371, 120, 57, 10, 23, 17, 161]





In [24]:
predict_item_id

tensor([  0,  25,   7, 371, 120,  57,  10,  23,  17, 161])

In [25]:
res

[0, 25, 7, 371, 120, 57, 10, 23, 17, 161]