In [1]:
"""
Algorithm flow

while not converged:
    take s1 [indices of the positive examples], s2 [C1*len(s1) sampled negative examples], target [1...1, 0...0]
    get the user's item_vec from the model and extract the s1 and s2 entries as the model's output
    calculate the loss (BCEWithLogitsLoss; the trainer currently uses BPR_Loss)
    update the model
    

ISSUE: 
1. Batch?
2. Validation set?

"""

"\nAlgorithm flow\n\nwhile not converged:\n    take s1 [indices of the positive examples], s2 [C1*len(s1) sampled negative examples], target [1...1, 0...0]\n    get the user's item_vec from the model and extract the s1 and s2 entries as the model's output\n    calculate the loss (BCEWithLogitsLoss; the trainer currently uses BPR_Loss)\n    update the model\n    \n\nISSUE: \n1. Batch?\n2. Validation set?\n\n"

In [2]:
"""
Params
"""

# Path (relative to the notebook) of the training CSV handed to PIR_Dataset,
# which reads the space-separated item ids in its "ItemId" column.
train_file = "../dataset/train.csv"


In [3]:
"""Use gpu"""
import torch

# Print availability so the cell output records which device the run used.
use_cuda = torch.cuda.is_available()
print(use_cuda)
if use_cuda:
    # Only touch CUDA state when a device exists; calling these on a CPU-only
    # machine raises instead of falling back.
    torch.cuda.set_device(0)
    # Make every tensor created afterwards live on the GPU by default.
    # NOTE(review): newer PyTorch releases deprecate set_default_tensor_type in
    # favour of torch.set_default_device / torch.set_default_dtype - confirm
    # against the installed version before switching.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

True


In [3]:
"""
PIR_Dataset
"""
import torch
from torch.utils.data import Dataset
import pandas as pd
from random import sample, shuffle

C1 = 1  # number of negatives sampled per positive example


class PIR_Dataset(Dataset):
    """Per-user positive / negative item lists for implicit-feedback training.

    The CSV must have an "ItemId" column whose cells are whitespace-separated
    integer item ids; row k is treated as user k.  Each user's items are
    shuffled (with the global `random` state) and split into a validation
    part and a training part.

    Attributes:
        train_lists / valid_lists: sorted item ids per user for each split.
        neg_lists: per user, every item id in [0, num_items) that is not one
            of the user's *training* positives.
        train_targets / valid_targets: per user, a tensor of len(s1) ones
            followed by C1*len(s1) zeros.
        num_items: 1 + the largest item id seen in either split.
        num_users: number of CSV rows.
        is_train: selects which split __getitem__ serves.
    """

    def __init__(self, csv_file, valid_split=0.1):
        user_dataframe = pd.read_csv(csv_file)
        self.train_lists, self.valid_lists = self._get_pos_lists(user_dataframe, valid_split)
        # Look at BOTH splits: the largest id may have landed in the validation
        # part, and the model indexes its item matrix with validation ids too.
        # Empty lists (a single-item user keeps nothing for training) are skipped.
        self.num_items = max(
            (max(l) for l in self.train_lists + self.valid_lists if l),
            default=-1,
        ) + 1
        self.num_users = len(self.train_lists)
        self.neg_lists = self._get_neg_lists()
        self.train_targets, self.valid_targets = self._get_targets()
        self.is_train = True

    def set_is_train(self, is_train):
        """Switch __getitem__ between the training (True) and validation (False) split."""
        self.is_train = is_train

    def __len__(self):
        """Number of users."""
        return len(self.train_lists)

    def __getitem__(self, idx):
        """Return (s1, s2, target) for user `idx` from the active split.

        s1 is the user's positive item list, s2 holds C1*len(s1) negatives
        drawn without replacement from neg_lists[idx], and target is the
        matching ones/zeros tensor.  `random.sample` raises ValueError if the
        user has fewer negatives than requested.
        """
        if self.is_train:
            s1, targets = self.train_lists[idx], self.train_targets
        else:
            # Known limitation: neg_lists only excludes training positives, so
            # a held-out validation item can be drawn as a negative here.
            s1, targets = self.valid_lists[idx], self.valid_targets
        s2 = sample(self.neg_lists[idx], C1*len(s1))
        return s1, s2, targets[idx]

    def _get_pos_lists(self, df, valid_split):
        """Split each user's items into (train_lists, valid_lists), both sorted.

        Every user contributes int(len * valid_split) + 1 items to validation,
        so a user with a single item ends up with an empty training list.
        """
        tot_lists = [list(map(int, item_str.split())) for item_str in df["ItemId"]]
        train_lists, valid_lists = [], []
        for l in tot_lists:
            shuffle(l)
            idx = int(len(l)*valid_split) + 1
            valid_lists.append(sorted(l[:idx]))
            train_lists.append(sorted(l[idx:]))
        return train_lists, valid_lists

    def _get_neg_lists(self):
        """Per user, the ascending list of item ids that are not training positives.

        Set membership keeps this correct for users with an empty training
        list and for lists that contain the same item id more than once.
        """
        neg_lists = []
        for train_items in self.train_lists:
            seen = set(train_items)
            neg_lists.append([item_id for item_id in range(self.num_items) if item_id not in seen])
        return neg_lists

    def _get_targets(self):
        """Build the (train_targets, valid_targets) label tensors, one per user."""
        def build(pos_lists):
            # len(l) ones for the positives, then C1*len(l) zeros for the negatives.
            return [torch.cat((torch.ones(len(l)), torch.zeros(C1*len(l)))) for l in pos_lists]

        return build(self.train_lists), build(self.valid_lists)
    

In [4]:
"""
PIR_Model
"""
import torch

NUM_TOPICS = 256  # default number of latent dimensions


class PIR_Model(torch.nn.Module):
    """Matrix-factorisation scorer.

    score(user, item) = user_matrix[user] . item_matrix[:, item] + item_bias[item]

    NOTE(review): the three factors are plain tensors created with
    requires_grad=True, not torch.nn.Parameter, so the Module does not
    register them (they are absent from parameters() / state_dict()).
    Whoever builds the optimiser has to pass them in explicitly.
    """

    def __init__(self, num_users, num_items, topics=NUM_TOPICS):
        super().__init__()
        # Drawn in the order user -> item -> bias; with a seeded RNG this
        # order determines the initial values.
        user_shape = (num_users, topics)
        item_shape = (topics, num_items)
        bias_shape = (num_items,)
        self.user_matrix = torch.randn(user_shape, requires_grad=True)
        self.item_matrix = torch.randn(item_shape, requires_grad=True)
        self.item_bias = torch.randn(bias_shape, requires_grad=True)

    def _score(self, user_id, item_idx):
        """Scores of the selected user row(s) against the selected item columns."""
        user_vec = self.user_matrix[user_id]
        item_cols = self.item_matrix[:, item_idx]
        return user_vec @ item_cols + self.item_bias[item_idx]

    def forward(self, user_id, pos_idx, neg_idx):
        """Return the positive-item scores followed by the negative-item scores.

        A 1-D score vector is concatenated along dim 0; anything else is
        concatenated along dim 1.
        """
        pos_scores = self._score(user_id, pos_idx)
        neg_scores = self._score(user_id, neg_idx)
        concat_dim = 0 if pos_scores.dim() == 1 else 1
        return torch.cat((pos_scores, neg_scores), dim=concat_dim)
    

In [5]:
"""
PIR_Loss
"""



class BPR_Loss(torch.nn.Module):
    """Pairwise ranking loss: -mean(log sigmoid(pos_score - neg_score)).

    The input is expected to hold the positive scores in the first half of
    its last dimension and the paired negative scores in the second half,
    which is the layout PIR_Model.forward produces for both 1-D (one user)
    and 2-D (several users) outputs.
    """

    def __init__(self):
        super().__init__()
        self.m = torch.nn.LogSigmoid()

    def forward(self, x, y):
        """Compute the loss.

        Args:
            x: score tensor whose last dimension is [positives..., negatives...].
            y: unused; accepted so the call signature matches
                BCEWithLogitsLoss(output, target).

        Returns:
            A scalar tensor.  Empty input gives 0 (still attached to x's
            graph) instead of the NaN that mean() of nothing would produce.

        Raises:
            ValueError: if the last dimension has odd length, since positives
                and negatives could not be paired one-to-one.
        """
        width = x.shape[-1]
        if width % 2 != 0:
            raise ValueError(
                "BPR_Loss expects an even number of scores (positives then negatives), got {}".format(width)
            )
        half = width // 2
        # Split along the LAST dimension so batched (2-D) scores are paired
        # per user rather than being cut across the batch dimension.
        pos = x[..., :half]
        neg = x[..., half:]
        if pos.numel() == 0:
            return x.sum() * 0
        return -torch.mean(self.m(pos - neg))
    

In [47]:
"""
Calculate F1/MAP score
"""

import numpy as np
import random

def calc_score(dataset, model, n_recom=50, known_ratio=0.89):
    """Mean average precision of the model's top-`n_recom` items per user.

    Scores are user_matrix @ item_matrix (the item bias is not added).  The
    split that is scored follows `dataset.is_train`:
      * training mode: the golden set is the user's training items;
      * validation mode: the golden set is the user's validation items, and
        the user's training items are pushed down to the row minimum so they
        cannot occupy the top of the ranking.

    Args:
        dataset: object exposing `is_train`, `train_lists`, `valid_lists`
            and `__len__` (number of users).
        model: object exposing `user_matrix` and `item_matrix` tensors.
        n_recom: how many top-ranked items are examined per user.
        known_ratio: the visible positives (train + validation) are divided
            by this to estimate the user's total positive count, which
            normalises the AP.  Presumably 0.89 is the share of positives
            present in the training CSV (the rest being hidden test data) -
            TODO confirm.

    Returns:
        The AP summed over users divided by len(dataset), as a Python float.
    """
    with torch.no_grad():
        matrix = torch.matmul(model.user_matrix, model.item_matrix).cpu().detach().numpy()

    sum_AP = 0
    for user_id, row in enumerate(matrix):
        train_items = dataset.train_lists[user_id]
        valid_items = dataset.valid_lists[user_id]
        if dataset.is_train:
            golden_set = set(train_items)
        else:
            golden_set = set(valid_items)
            if train_items:
                # Already-seen items must not be recommended again.
                row[train_items] = np.min(row)
        rank = np.flip(np.argsort(row))  # item ids, best score first
        n_pos = int((len(train_items) + len(valid_items)) / known_ratio)
        AP = 0
        true_pos = 0
        # `position` is kept distinct from `user_id`; reusing one name for
        # both loops made the outer index easy to misread.
        for position, idx in enumerate(rank[:n_recom]):
            if idx in golden_set:
                true_pos += 1
                AP += (true_pos / (position + 1)) / n_pos
        sum_AP += AP
    return sum_AP / len(dataset)


# Example usage, left disabled because it loads checkpoint files from disk:
#
# torch.manual_seed(66)
# random.seed(66)
# dataset = PIR_Dataset(train_file, 0.1)
# dataset.is_train = True
# user_file = "../model/user0.196863.pt"
# item_file = "../model/item0.196863.pt"
# model = PIR_Model(dataset.num_users, dataset.num_items)
# model.user_matrix = torch.load(user_file)
# model.item_matrix = torch.load(item_file)
# print(calc_score(dataset, model))

0.3304835278205502


In [24]:
"""
PIR_Trainer
"""

import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
import random
import os

ITER = 1000   # upper bound on epochs
PATIENCE = 3  # consecutive epochs without validation improvement before stopping
torch.manual_seed(66)
random.seed(66)
checkpoint_dir = "../model/"
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)
dataset = PIR_Dataset(train_file, 0.1)
model = PIR_Model(dataset.num_users, dataset.num_items, 64)
loss_func = BPR_Loss()
# loss_func = BCEWithLogitsLoss()
# NOTE(review): model.item_bias takes part in the forward pass but is neither
# given to the optimizer nor checkpointed below, so it keeps its random
# initial values - confirm whether that is intended.
optimizer = optim.Adam([model.user_matrix, model.item_matrix], lr=1e-3, weight_decay=1e-6)

# Summed validation loss of the best epoch so far; inf means "nothing saved yet".
prev_valid_loss = float("inf")
valid_tolerance = PATIENCE
# Paths written for the best epoch, remembered so they can be deleted when a
# better epoch replaces them (instead of rebuilding the names from a float).
best_checkpoints = []
idxs = list(range(len(dataset)))


for i in range(ITER):
    if valid_tolerance == 0:
        break
    train_loss = 0
    valid_loss = 0
    random.shuffle(idxs)

    # One SGD step per user: its positives against freshly sampled negatives.
    dataset.set_is_train(True)
    print("training...")
    for uid in tqdm(idxs):
        optimizer.zero_grad()
        s1, s2, target = dataset[uid]
        output = model(uid, s1, s2)
        loss = loss_func(output, target)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    train_MAP = calc_score(dataset, model)

    print("validating...")
    dataset.set_is_train(False)
    with torch.no_grad():
        for uid in tqdm(idxs):
            s1, s2, target = dataset[uid]
            output = model(uid, s1, s2)
            loss = loss_func(output, target)
            valid_loss += loss.item()
    valid_MAP = calc_score(dataset, model)

    print("iter {}: training loss={:6.4f}, MAP={:6.4f}; validation loss={:6.4f}, MAP={:6.4f}\n" \
          .format(i, train_loss/len(idxs), train_MAP, valid_loss/len(idxs), valid_MAP))

    if valid_loss < prev_valid_loss:
        # New best epoch: drop the previous best files, then save this one.
        for stale_path in best_checkpoints:
            if os.path.exists(stale_path):
                os.remove(stale_path)
        mean_valid_loss = valid_loss/len(idxs)
        best_checkpoints = [
            "{}user{:8.6f}.pt".format(checkpoint_dir, mean_valid_loss),
            "{}item{:8.6f}.pt".format(checkpoint_dir, mean_valid_loss),
        ]
        torch.save(model.user_matrix, best_checkpoints[0])
        torch.save(model.item_matrix, best_checkpoints[1])
        prev_valid_loss = valid_loss
        valid_tolerance = PATIENCE
    else:
        valid_tolerance -= 1


training...


100%|█████████████████████████████████████████| 4454/4454 [00:32<00:00, 136.41it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:07<00:00, 581.04it/s]


iter 0: training loss=1.6217548194812668, validation loss=1.5775860477620205

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.24it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 540.35it/s]


iter 1: training loss=1.4668522855266093, validation loss=1.5279098685778965

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.38it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:10<00:00, 405.49it/s]


iter 2: training loss=1.3219534629729226, validation loss=1.4608534246893758

training...


100%|█████████████████████████████████████████| 4454/4454 [00:42<00:00, 104.49it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 525.90it/s]


iter 3: training loss=1.1863312087618534, validation loss=1.4048646048168767

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 130.72it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:07<00:00, 573.42it/s]


iter 4: training loss=1.0633735688220078, validation loss=1.3298278915629012

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 129.47it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:07<00:00, 557.62it/s]


iter 5: training loss=0.9192286016688012, validation loss=1.18978044054471

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 130.05it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 504.37it/s]


iter 6: training loss=0.7736241694892743, validation loss=1.0388876781655993

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 123.45it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 477.28it/s]


iter 7: training loss=0.6216919997587406, validation loss=0.8983222126094429

training...


100%|██████████████████████████████████████████| 4454/4454 [00:45<00:00, 96.92it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 517.72it/s]


iter 8: training loss=0.5006579540905781, validation loss=0.7895211712492823

training...


100%|██████████████████████████████████████████| 4454/4454 [00:47<00:00, 92.81it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:17<00:00, 260.67it/s]


iter 9: training loss=0.4035949996035031, validation loss=0.7040074361741516

training...


100%|█████████████████████████████████████████| 4454/4454 [00:40<00:00, 110.95it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 476.86it/s]


iter 10: training loss=0.3390468520599116, validation loss=0.622198081447732

training...


100%|█████████████████████████████████████████| 4454/4454 [00:40<00:00, 111.29it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 506.62it/s]


iter 11: training loss=0.2890980520362405, validation loss=0.5627782155855403

training...


100%|█████████████████████████████████████████| 4454/4454 [00:42<00:00, 105.71it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 534.32it/s]


iter 12: training loss=0.25389209473472163, validation loss=0.5150779192079588

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 122.95it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:14<00:00, 303.69it/s]


iter 13: training loss=0.22407812250382467, validation loss=0.4843094636171601

training...


100%|█████████████████████████████████████████| 4454/4454 [00:39<00:00, 113.95it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 521.11it/s]


iter 14: training loss=0.20334563778068696, validation loss=0.45050248399965437

training...


100%|█████████████████████████████████████████| 4454/4454 [00:35<00:00, 125.82it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 507.68it/s]


iter 15: training loss=0.18725851027007065, validation loss=0.42506419148213564

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 117.25it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 499.67it/s]


iter 16: training loss=0.17518341281893737, validation loss=0.39939708850512035

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 117.89it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 489.93it/s]


iter 17: training loss=0.16208721048614327, validation loss=0.3786516837120835

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 118.17it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 506.01it/s]


iter 18: training loss=0.15253366790021594, validation loss=0.3449664248587153

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 118.47it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 460.52it/s]


iter 19: training loss=0.14564369775922667, validation loss=0.33639412855294226

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.59it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:10<00:00, 427.44it/s]


iter 20: training loss=0.13754286962271917, validation loss=0.31891907643563455

training...


100%|█████████████████████████████████████████| 4454/4454 [00:42<00:00, 105.77it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 497.28it/s]


iter 21: training loss=0.13100512031149822, validation loss=0.3017358367947138

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.73it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 521.02it/s]


iter 22: training loss=0.12614036695544492, validation loss=0.30111449436178106

training...


100%|█████████████████████████████████████████| 4454/4454 [00:35<00:00, 124.39it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 476.80it/s]


iter 23: training loss=0.1224002966679602, validation loss=0.2831491658408245

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 122.82it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 507.35it/s]


iter 24: training loss=0.11761753275689595, validation loss=0.2721188021249071

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.29it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:10<00:00, 441.47it/s]


iter 25: training loss=0.11478977146942995, validation loss=0.26083441847837774

training...


100%|█████████████████████████████████████████| 4454/4454 [00:40<00:00, 110.52it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 476.26it/s]


iter 26: training loss=0.11273673193409645, validation loss=0.25255259348456827

training...


100%|█████████████████████████████████████████| 4454/4454 [00:38<00:00, 116.22it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 465.17it/s]


iter 27: training loss=0.11059025516295466, validation loss=0.24664432732290886

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 122.73it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:10<00:00, 423.08it/s]


iter 28: training loss=0.10769534905699969, validation loss=0.2415355773794373

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 118.31it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 531.17it/s]


iter 29: training loss=0.1051659430246499, validation loss=0.24275661426686437

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 119.61it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 535.49it/s]


iter 30: training loss=0.10474874216071646, validation loss=0.23329253488661314

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 130.64it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 474.78it/s]


iter 31: training loss=0.10278402832458, validation loss=0.23106560617694258

training...


100%|█████████████████████████████████████████| 4454/4454 [00:41<00:00, 107.13it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 525.23it/s]


iter 32: training loss=0.10146667064177178, validation loss=0.2299151933998553

training...


100%|█████████████████████████████████████████| 4454/4454 [00:35<00:00, 125.76it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 513.67it/s]


iter 33: training loss=0.10009242648168475, validation loss=0.22393614631051706

training...


100%|█████████████████████████████████████████| 4454/4454 [00:42<00:00, 104.68it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 495.28it/s]


iter 34: training loss=0.09935526418999184, validation loss=0.21616008139706488

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 117.68it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 536.63it/s]


iter 35: training loss=0.09922050018739424, validation loss=0.21271363759797474

training...


100%|█████████████████████████████████████████| 4454/4454 [00:35<00:00, 125.48it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:09<00:00, 481.67it/s]


iter 36: training loss=0.0970301750520864, validation loss=0.21189431844516404

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 128.02it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 516.02it/s]


iter 37: training loss=0.09698386388106525, validation loss=0.20724874349713457

training...


100%|█████████████████████████████████████████| 4454/4454 [00:34<00:00, 130.35it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 503.33it/s]


iter 38: training loss=0.0958092597005433, validation loss=0.20777309742345126

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 120.43it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 541.45it/s]


iter 39: training loss=0.09558831681568489, validation loss=0.20272552314476616

training...


100%|█████████████████████████████████████████| 4454/4454 [00:37<00:00, 118.44it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 500.00it/s]


iter 40: training loss=0.0950570256003067, validation loss=0.21080071772485506

training...


100%|█████████████████████████████████████████| 4454/4454 [00:35<00:00, 124.79it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 519.70it/s]


iter 41: training loss=0.09504774448553416, validation loss=0.2093686984077013

training...


100%|█████████████████████████████████████████| 4454/4454 [00:36<00:00, 122.84it/s]


validating...


100%|█████████████████████████████████████████| 4454/4454 [00:08<00:00, 532.90it/s]


iter 42: training loss=0.09445019559857369, validation loss=0.20847767490614091



In [48]:
"""
Evaluate and Test
"""

import numpy as np

output_file = "../outputs/out.csv"
user_file = ["../model/user0.196863.pt", "../model/user0.199941.pt"]
item_file = ["../model/item0.196863.pt", "../model/item0.199941.pt"]
N_RECOMMEND = 50  # items written per user

# zip() would silently drop the unmatched tail of the longer list.
assert len(user_file) == len(item_file), "each user checkpoint needs its item checkpoint"

# Ensemble: sum the score matrices of all checkpoint pairs.  The scores come
# straight from the saved factors, so no PIR_Model instance is built here and
# the trained global `model` is left untouched.
matrix = np.zeros((dataset.num_users, dataset.num_items))
with torch.no_grad():
    for f1, f2 in zip(user_file, item_file):
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        matrix += torch.matmul(torch.load(f1), torch.load(f2)).cpu().detach().numpy()
sum_true_pos = 0   # known positives skipped while filling the recommendation slots
sum_poss_pos = 0   # total recommendation slots (N_RECOMMEND per user)

with open(output_file, "w") as f:
    f.write("UserId,ItemId\n")
#     matrix = torch.matmul(model.user_matrix, model.item_matrix).cpu().detach().numpy()
    for i, row in enumerate(matrix):
        f.write("{},".format(i))
        rank = np.flip(np.argsort(row))  # item ids, best score first
        # Items the user already has (either split) are never recommended.
        pos_set = set(dataset.train_lists[i]+dataset.valid_lists[i])
        n_recom = N_RECOMMEND
        sum_poss_pos += n_recom
        for idx in rank:
            if n_recom == 0:
                break
            if idx not in pos_set:
                n_recom -= 1
                f.write("{} ".format(idx))
            else:
                sum_true_pos += 1

        f.write("\n")
print(sum_true_pos, sum_poss_pos)

180570 222700


In [None]:
"""
Test optimizing ability after indexing the tensor
"""

import torch
import torch.nn as nn
import torch.optim as optim

# Checks that an optimizer step still updates the leaf tensor when the loss is
# computed on elements picked out of it by advanced indexing.
# NOTE: the names `loss` and `input` below are kept from the original cell;
# `input` shadows the Python builtin for the rest of the session.
loss = nn.BCEWithLogitsLoss()
input_prev = torch.tensor([[1.0,1.0,0.0,0.0,1.0,1.0], [1.0,1.0,0.0,0.0,1.0,1.0]], requires_grad=True)
optimizer = optim.Adam([input_prev], lr=0.0001)

# Select elements (0,0), (0,3) and (0,4).  The row and column index lists are
# passed as two separate indices; wrapping them in one nested list relies on
# the library treating a sequence as a tuple of indices, which is ambiguous.
input = input_prev[[0,0,0], [0,3,4]]
target = torch.empty(3).random_(2)  # random 0/1 labels

print(input_prev)
print(input)
print(target)

output = loss(input, target)
output.backward()
optimizer.step()

# `input_prev` now differs only at the three selected positions; `input` is a
# copy taken before the step, so it still shows the old values.
print(input_prev)
print(input)
print(target)
print(output)