In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
import pandas as pd
import numpy as np
from collections import defaultdict

In [2]:
train_df = pd.read_csv("~/Thesis/Data/train.csv")
ratings = pd.read_csv('~/Data/clean_rating4.csv').drop(["id"], axis = 1)

idx_to_animes = list(set(ratings['anime_id'].tolist()))
idx_to_users = list(set(ratings['user_id'].tolist()))
anime_to_idx = {anime: idx for idx, anime in enumerate(idx_to_animes)}
user_to_idx = {user: idx for idx, user in enumerate(idx_to_users)}
num_users, num_animes = len(idx_to_users), len(idx_to_animes)

In [3]:
train_data_raw = train_df.values.tolist()
train_users = list(set(train_df.user))
user_item_dic = defaultdict(list)

train_data = []
for d in train_data_raw:
    train_data.append([user_to_idx[d[0]], anime_to_idx[d[1]]])
    user_item_dic[user_to_idx[d[0]]].append(anime_to_idx[d[1]])

In [3]:
class BPRDataset(Data.Dataset):
    def __init__(self, data, num_item, num_ng, dic):
        super(BPRDataset, self).__init__()
        self.data = data
        self.num_ng = num_ng
        self.dic = dic
        self.num_item = num_item
        
    def select_ng(self):
        self.new_data = []
        for idx, d in enumerate(self.data):
            if not idx % 100000: print(idx)
            ng_num = 0
            while ng_num < self.num_ng:
                item = np.random.randint(self.num_item)   
                if item not in self.dic[d[0]]:
                    self.new_data.append([d[0], d[1], item])
                    ng_num += 1
                    
    def __len__(self):
        return self.num_ng * len(self.data)
    
    def __getitem__(self, idx):
        user = self.new_data[idx][0]
        item_i = self.new_data[idx][1]
        item_j = self.new_data[idx][2]
        
        return user, item_i, item_j

In [4]:
class BPR(nn.Module):
    def __init__(self, num_users, num_animes, num_hidden):
        super(BPR, self).__init__()
        self.user_embed = nn.Embedding(num_users, num_hidden)
        self.anime_embed = nn.Embedding(num_animes, num_hidden)
        
        nn.init.normal_(self.user_embed.weight, std = 0.01)
        nn.init.normal_(self.anime_embed.weight, std = 0.01)
        
    def forward(self, user, anime_i, anime_j):
        point_i = torch.mm(self.user_embed(user), self.anime_embed(anime_i).permute(1, 0)).sum(dim = -1)
        point_j = torch.mm(self.user_embed(user), self.anime_embed(anime_j).permute(1, 0)).sum(dim = -1)
        
        return point_i, point_j

In [5]:
batch_size = 10000
train_dataset = BPRDataset(train_data, num_animes, 5, user_item_dic)
train_dataset.select_ng()
data_iter = Data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
for x, y, z in data_iter:
    print(x.shape, y.shape, z.shape)
    break

0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
torch.Size([10000]) torch.Size([10000]) torch.Size([10000])


In [9]:
def train(net, lr, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("train on ", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr = lr)
    for epoch in range(num_epochs):
        l_sum, n = 0, 0
        for user, item_i, item_j in data_iter:
            user = user.to(device)
            item_i = item_i.to(device)
            item_j = item_j.to(device)
            
            point_i, point_j = net(user, item_i, item_j)
#             print(point_i.shape, point_j.shape)
            loss = - (point_i - point_j).sigmoid().log().sum()
#             print(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            l_sum += loss.item()
            n += 1
            
        print(epoch + 1, l_sum / n)

In [10]:
bpr_net = BPR(num_users, num_animes, 100)
train(bpr_net, 0.001, 30)

train on  cuda
1 33.288747597236984
2 0.0015058380634647877
3 0.0011498136252565117
4 0.0012791689498957261
5 0.0012105081835363666
6 0.0009266369508736299
7 0.0008918527281764662
8 0.0005438231723212497
9 0.00043889454432896206
10 0.0002939776040986635
11 0.00019839053043370136
12 0.00013277045609167936
13 0.00013184452813769142
14 5.71275541561136e-05
15 4.364641085358039e-05
16 3.4593921313937913e-05
17 1.7141461627096192e-05
18 1.7754453824553298e-05
19 8.77377278056628e-06
20 6.5206203232172496e-06
21 3.6710467285070663e-06
22 2.128887132392407e-06
23 9.958860288136463e-07
24 1.6032993879910645e-06
25 8.65328022175365e-07
26 2.3681753083622318e-07
27 1.3551144126941414e-07
28 1.6156573932258565e-07
29 1.2910704387713482e-07
30 3.9445356009880554e-08


In [11]:
torch.save(bpr_net.state_dict(), "BPR2.pt")

In [5]:
test_df = pd.read_csv("~/Thesis/Data/test.csv")
users = list(set(test_df.user))
bpr_net = BPR(4701, 9775, 100).cuda()
bpr_net.load_state_dict(torch.load("BPR.pt"))

<All keys matched successfully>

In [7]:
def metric(net, bound):
    precision = []
    for user in users:
        animes = test_df[test_df.user == user].anime.tolist()
        user_input = torch.LongTensor([user_to_idx[user] for _ in range(len(animes))]).cuda()
        anime_input = torch.LongTensor([anime_to_idx[i] for i in animes]).cuda()
        
#         print(user_input, anime_input)
        point_i, point_j = net(user_input, anime_input, anime_input)
#         print(point_i.shape)
        _, idx = torch.topk(point_i, 5)
        
        target = test_df[(test_df.user == user) & (test_df.rating > bound)].anime.tolist()
        idx = [animes[i] for i in idx]
        
        precision.append(len(set(idx) & set(target)) / len(idx))
        
    return np.mean(precision)

In [8]:
result = []
for k in range(10):
    result.append(metric(bpr_net, k))
    print(result[-1])

1.0
0.9069559668155712
0.7984258668368431
0.6593065305254201
0.44981918740693466
0.2057009146990002
0.05207402680280791
0.011231652839821315
0.005360561582641991
0.0043395022335673255


In [9]:
print(result)

[1.0, 0.9069559668155712, 0.7984258668368431, 0.6593065305254201, 0.44981918740693466, 0.2057009146990002, 0.05207402680280791, 0.011231652839821315, 0.005360561582641991, 0.0043395022335673255]
