In [1]:
import sys

# The project's `modules` package is expected two directories above the kernel's
# working directory (normally this notebook's folder). Register that directory
# on the import path at most once, so re-running this cell adds no duplicates.
root_dir = '../../'
sys.path.extend([] if root_dir in sys.path else [root_dir])

import torch
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import csr_matrix

from modules import losses, models, samplers, searches, regularizers, evaluators, trainers, datasets, distributions

In [2]:
# Reproducibility: seed the numpy and torch RNGs before anything stochastic runs.
# This covers the train/test split, negative sampling and embedding init, if those
# draw from these RNGs — not verifiable from this notebook. Without it,
# Restart & Run All gives a different score table each time.
# torch.manual_seed seeds every device, including CUDA.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = datasets.ML100k()
n_user = dataset.n_user
n_item = dataset.n_item
train_set, test_set = dataset.get_train_and_test_set()

# device setting: first GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Training data is cast to int64 (presumably user/item ids — TODO confirm).
# The test set is kept as float32. Both are moved to the chosen device.
train_set = torch.LongTensor(train_set).to(device)
test_set = torch.FloatTensor(test_set).to(device)

In [3]:
# Cut-off depths k: every metric below is reported once per k (e.g. "Recall@5").
ks = [5, 10, 50]

# Metric name -> scoring function. Each name, suffixed with "@k", becomes a
# column header of trainer.valid_scores (see the table further down).
score_function_dict = dict([
    ("Recall", evaluators.recall),
    ("Unpopularity", evaluators.unpopularity),
    ("Unpopularity2", evaluators.unpopularity2),
    ("Unpopularity3", evaluators.unpopularity3),
    ("F1-score", evaluators.f1_score),
    ("F1-score2", evaluators.f1_score2),
    ("F1-score3", evaluators.f1_score3),
])
userwise = evaluators.UserwiseEvaluator(test_set, score_function_dict, ks)

# Additional evaluators, currently disabled:
# coverage = evaluators.CoverageEvaluator(test_set, ks)
# hubness = evaluators.HubnessEvaluator(test_set, ks)

In [4]:
# Training-pair sampler handed to the trainer below.
# NOTE(review): the exact meaning of strict_negative=False lives in
# modules.samplers and is not visible here — TODO confirm.
sampler = samplers.BaseSampler(
    train_set, n_user, n_item,
    device=device,
    strict_negative=False,
)

In [5]:
# --- Hyperparameters ---------------------------------------------------------
lr = 1e-3             # Adam learning rate
n_dim = 10            # latent size passed to CollaborativeMetricLearning
n_batch = 256         # first argument of trainer.fit (logs show 256 steps/epoch)
n_epoch = 30          # number of training epochs
valid_per_epoch = 10  # validation interval in epochs (rows at 0/10/20/30 below)
n_item_sample = 30    # NOTE(review): not referenced by any cell shown here
n_user_sample = 30    # NOTE(review): not referenced by any cell shown here
no_progressbar = False

search_range = 30     # forwarded to searches.Mymp

# --- Model -------------------------------------------------------------------
model = models.CollaborativeMetricLearning(n_user, n_item, n_dim).to(device)

# --- Distributions (fed to the mutual-proximity searches) ----------------------
gaussian = distributions.Gaussian()
gamma = distributions.Gamma()

# --- Retrieval strategies, all sharing the same model --------------------------
knn = searches.NearestNeighborhood(model)
mp = searches.MutualProximity(model, gamma)
mymp = searches.Mymp(model, search_range)

# --- Learning-rate optimizer and loss ----------------------------------------
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = losses.SumTripletLoss(margin=1).to(device)

# --- Trainer -----------------------------------------------------------------
trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler, no_progressbar)

In [6]:
# Train the model. The kNN search and the userwise evaluator are used for
# validation every `valid_per_epoch` epochs; judging by the saved output, the
# scores are collected in trainer.valid_scores.
trainer.fit(
    n_batch,
    n_epoch,
    knn,
    userwise,
    valid_per_epoch,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 940/940 [00:30<00:00, 30.64it/s]
epoch1 avg_loss:0.936: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:02<00:00, 92.80it/s]
epoch2 avg_loss:0.778: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:02<00:00, 87.63it/s]
epoch3 avg_loss:0.685: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:02<00:00, 92.85it/s]
epoch4 avg_loss:0.625: 100%|████████████████████████████████████████████████████████████

In [8]:
# Validation table from training: one row per validation pass. The `epoch`
# column marks when it ran; `losses` is empty for the pre-training row.
trainer.valid_scores

Unnamed: 0,Recall@5,Unpopularity@5,Unpopularity2@5,Unpopularity3@5,F1-score@5,F1-score2@5,F1-score3@5,Recall@10,Unpopularity@10,Unpopularity2@10,...,F1-score3@10,Recall@50,Unpopularity@50,Unpopularity2@50,Unpopularity3@50,F1-score@50,F1-score2@50,F1-score3@50,epoch,losses
0,0.060601,0.969027,6.604254,0.062687,0.095637,0.115694,0.030846,0.124173,0.969471,6.611089,...,0.093141,0.522777,0.917356,6.282453,0.525055,0.576298,0.920117,0.514344,0,
0,0.199701,0.904484,3.800076,0.015353,0.282366,0.357725,0.017315,0.376784,0.898409,3.74017,...,0.054212,0.812421,0.883639,4.674655,0.415116,0.822434,1.372523,0.472772,10,0.470439
0,0.319594,0.872723,3.240209,0.011262,0.416526,0.552662,0.015858,0.492825,0.883725,3.457373,...,0.058786,0.856281,0.884782,4.618482,0.425086,0.851559,1.435081,0.492658,20,0.301154
0,0.376933,0.872549,3.250793,0.013281,0.47149,0.645983,0.022692,0.539528,0.887942,3.565128,...,0.073861,0.877089,0.887067,4.703218,0.434348,0.865,1.469327,0.507297,30,0.254217


In [None]:
# Score the already-trained model again, this time retrieving with the Mymp search.
trainer.valid(
    mymp,
    userwise,
)

In [None]:
# Inspect the score table after the Mymp validation run.
# NOTE(review): whether trainer.valid appends a row here is not visible from this notebook.
trainer.valid_scores

In [None]:
# Mutual-proximity search built on the Gaussian distribution (`mp` uses Gamma).
# NOTE(review): no cell shown here passes `mp2` to trainer.valid.
mp2 = searches.MutualProximity(
    model,
    gaussian,
)

In [None]:
# Validate the Gamma-based mutual-proximity search with the same userwise
# evaluator used by every other validation run in this notebook.
# NOTE(review): pass `mp2` instead to score the Gaussian variant.
trainer.valid(mp, userwise)

In [None]:
# Score table after the mutual-proximity validation run.
trainer.valid_scores