# Import

In [1]:
import sys

root_dir = '../../'
if root_dir not in sys.path:
    sys.path.append(root_dir)

import torch
from torch import nn, optim
import numpy as np
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import csr_matrix

from modules import losses, models, samplers, searches, regularizers, evaluators, trainers, datasets, distributions

# DataSet

In [2]:
dataset = datasets.ML100k()
n_user = dataset.n_user
n_item = dataset.n_item
train_set, test_set = dataset.get_train_and_test_set()

# device setting
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_set = torch.LongTensor(train_set).to(device)
test_set = torch.LongTensor(test_set).to(device)

# Evaluator

In [3]:
ks = [5, 10, 50]
recall = evaluators.RecallEvaluator(test_set, ks)

# Sampler

In [4]:
sampler = samplers.BaseSampler(train_set, n_user, n_item, device=device, strict_negative=False)

# Model

In [5]:
# Hyperparameters
lr = 1e-3
n_dim = 10
n_batch = 256
n_epoch = 50
no_progressbar = True

model = models.CollaborativeMetricLearning(n_user, n_item, n_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = losses.SumTripletLoss(margin=1).to(device)
trainer = trainers.BaseTrainer(model, optimizer, criterion, sampler, no_progressbar)

# Search

In [6]:
# Hyperparameters
n_item_sample = 30
n_user_sample = 30

gaussian = distributions.Gaussian()
gamma = distributions.Gamma()

knn = searches.NearestNeighborhood(model)
mp_gaussian = searches.MutualProximity(model, gamma, n_item_sample, n_user_sample)
mp_gamma = searches.MutualProximity(model, gamma, n_item_sample, n_user_sample)

# Training

In [7]:
# only traing no validation
trainer.fit(n_batch, n_epoch)

# Result

## Base CML

In [8]:
trainer.valid(knn, recall)
display(trainer.valid_scores)

Unnamed: 0,Recall@5,Recall@10,Recall@50
0,0.402732,0.570996,0.890921


## Gaussian

In [9]:
trainer.valid(mp_gaussian, recall)
display(trainer.valid_scores)

Unnamed: 0,Recall@5,Recall@10,Recall@50
0,0.317188,0.48684,0.852539


## Gamma

In [10]:
trainer.valid(mp_gamma, recall)
display(trainer.valid_scores)

Unnamed: 0,Recall@5,Recall@10,Recall@50
0,0.319808,0.489005,0.851772
