In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
import pandas as pd
import numpy as np
from collections import defaultdict

In [8]:
# Load the training interactions and the full rating table, then build
# contiguous integer indices for users and animes (rows of nn.Embedding).
train_df = pd.read_csv("~/Thesis/Data/train.csv")
ratings = pd.read_csv('~/Data/clean_rating4.csv').drop(columns=["id"])

# Sorting the unique IDs makes the id -> index mapping identical on every
# kernel restart (set iteration order is not guaranteed, e.g. for string IDs
# under hash randomisation), so trained embedding rows stay meaningful.
idx_to_animes = sorted(set(ratings['anime_id'].tolist()))
idx_to_users = sorted(set(ratings['user_id'].tolist()))
anime_to_idx = {anime: idx for idx, anime in enumerate(idx_to_animes)}
user_to_idx = {user: idx for idx, user in enumerate(idx_to_users)}
num_users, num_animes = len(idx_to_users), len(idx_to_animes)

# NOTE(review): assumes column 0 of train.csv is the user ID and column 1 the
# anime ID — TODO confirm against the file's header.
train_data_raw = train_df.values.tolist()
train_users = list(set(train_df.user))
# user index -> list of anime indices that user has in train.csv
user_item_dic = defaultdict(list)

# [user_idx, anime_idx] positive pairs; an ID absent from `ratings` raises
# KeyError here rather than being silently dropped.
train_data = []
for row in train_data_raw:
    user_idx = user_to_idx[row[0]]
    anime_idx = anime_to_idx[row[1]]
    train_data.append([user_idx, anime_idx])
    user_item_dic[user_idx].append(anime_idx)

In [11]:
class BPRDataset(Data.Dataset):
    """Triplet dataset for Bayesian Personalised Ranking (BPR) training.

    Each positive ``[user, item_i]`` pair in ``data`` is expanded into
    ``num_ng`` triplets ``(user, item_i, item_j)``. Each ``item_j`` is drawn
    uniformly from ``range(num_item)`` and rejected while it appears among the
    user's positives in ``dic``. Call :meth:`select_ng` before indexing, and
    again each epoch to draw fresh negatives.

    Args:
        data: sequence of ``[user, item]`` positive pairs.
        num_item: number of items; negatives are drawn from ``range(num_item)``.
        num_ng: number of negatives sampled per positive pair.
        dic: mapping ``user -> iterable of that user's positive items``.
    """

    def __init__(self, data, num_item, num_ng, dic):
        super().__init__()
        self.data = data
        self.num_ng = num_ng
        self.dic = dic
        self.num_item = num_item
        # Populated by select_ng(); None means negatives were never sampled.
        self.new_data = None

    def select_ng(self):
        """(Re)sample ``num_ng`` negatives for every positive pair.

        Fills ``self.new_data`` with ``[user, item_i, item_j]`` triplets.

        Raises:
            ValueError: if a user already has every item in
                ``range(num_item)`` as a positive. No negative could be
                drawn, and rejection sampling would never terminate.
        """
        item_range = range(self.num_item)
        # user -> set of positives; set membership keeps each rejection test O(1).
        positives = {}
        self.new_data = []
        for d in self.data:
            user, item_i = d[0], d[1]
            if user not in positives:
                pos = set(self.dic[user])
                n_pos_in_range = sum(1 for it in pos if it in item_range)
                if self.num_ng > 0 and n_pos_in_range >= self.num_item:
                    raise ValueError(
                        f"user {user!r} has no negative items to sample from"
                    )
                positives[user] = pos
            pos = positives[user]
            ng_num = 0
            while ng_num < self.num_ng:
                item = np.random.randint(self.num_item)
                if item not in pos:
                    self.new_data.append([user, item_i, item])
                    ng_num += 1

    def __len__(self):
        # One triplet per (positive pair, sampled negative).
        return self.num_ng * len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(user, item_i, item_j)`` triplet."""
        if self.new_data is None:
            raise RuntimeError("call select_ng() before indexing BPRDataset")
        # Triplets live in new_data; self.data only holds [user, item] pairs.
        user, item_i, item_j = self.new_data[idx]
        return user, item_i, item_j

In [12]:
class BPR(nn.Module):
    """Matrix-factorisation scorer for BPR training.

    Users and animes share a ``num_hidden``-dimensional embedding space. A
    user's score for an anime is the dot product of their embeddings.

    Args:
        num_users: number of user indices (rows of the user embedding).
        num_animes: number of anime indices (rows of the anime embedding).
        num_hidden: embedding dimension.
    """

    def __init__(self, num_users, num_animes, num_hidden):
        # nn.Module.__init__ must run before any sub-module is assigned.
        super().__init__()
        self.user_embed = nn.Embedding(num_users, num_hidden)
        self.anime_embed = nn.Embedding(num_animes, num_hidden)

    def forward(self, user, anime_i, anime_j):
        """Score a positive and a negative anime for each user.

        Args:
            user, anime_i, anime_j: integer index tensors of matching shape.

        Returns:
            ``(point_i, point_j)``: per-example scores with the shape of
            ``user`` (the embedding dimension is summed away).
        """
        user_vec = self.user_embed(user)
        # Row-wise dot product. torch.mm would compute a (B,H) @ (B,H) matrix
        # product, which errors unless B == H and is not a per-pair score.
        point_i = (user_vec * self.anime_embed(anime_i)).sum(dim=-1)
        point_j = (user_vec * self.anime_embed(anime_j)).sum(dim=-1)

        return point_i, point_j

In [None]:
# Batch size — presumably for the training DataLoader built in a later cell.
batchsize = 5120
# NOTE(review): this binds the TensorDataset *class*, not a dataset instance.
# Presumably the intent is a BPRDataset over train_data / user_item_dic (with
# select_ng() called each epoch) — TODO confirm; this cell looks unfinished.
train_dataset = Data.TensorDataset