# Recommendations in pytorch using triplet loss
Along the lines of BPR [1]. 

[1] Rendle, Steffen, et al. "BPR: Bayesian personalized ranking from implicit feedback." Proceedings of the Twenty-Fifth Conference on Uncertainty in Artificial Intelligence. AUAI Press, 2009.

This is implemented (more efficiently) in LightFM (https://github.com/lyst/lightfm). See the MovieLens example (https://github.com/lyst/lightfm/blob/master/examples/movielens/example.ipynb) for results comparable to this notebook.

## Set up the architecture
A simple dense layer for both users and items: this is exactly equivalent to latent factor matrix when multiplied by binary user and item indices. There are three inputs: users, positive items, and negative items. In the triplet objective we try to make the positive item rank higher than the negative item for that user.

Because we want just one single embedding for the items, we use shared weights for the positive and negative item inputs (a siamese architecture).

This is all very simple but could be made arbitrarily complex, with more layers, conv layers and so on. I expect we'll be seeing a lot of papers doing just that.


In [1]:
from __future__ import print_function

import numpy as np
import itertools
import data
import metrics

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
n_latent = 100
batch_size = 64

# Read data
train, test = data.get_movielens_data()
n_users, n_items = train.shape

uid, pid, nid = data.get_triplets(train)  # user, positive_item, negative_item
test_uid, test_pid, test_nid = data.get_triplets(test)
print(n_users, n_items)
print(len(uid), len(test_uid))

944 1683
49906 5469


In [3]:
class FactNet(nn.Module):
    def __init__(self,
                 n_users, n_items,
                 n_latent,
                ):
        super(FactNet, self).__init__()
        self.user_embedding_layer = nn.Embedding(n_users, n_latent)
        self.item_embedding_layer = nn.Embedding(n_items, n_latent)  # both pos and neg items share these params
        init.uniform(net.user_embedding_layer.weight, -0.5, 0.5)
        init.uniform(net.item_embedding_layer.weight, -0.5, 0.5)  # default was normal
        
    def predict_score(self, uid, iid):
        # TODO: check to see if this handles multiple users and multiple items correctly
        user_embedding = self.user_embedding_layer(uid)
        item_embedding = self.item_embedding_layer(iid)
        score = (user_embedding * item_embedding).sum(dim=1)
        return score
    
        
    def forward(self, uid, pid, nid):
        # lulzy forward for loss computation
        user_embedding = self.user_embedding_layer(uid)
        pos_item_embedding = self.item_embedding_layer(pid)
        neg_item_embedding = self.item_embedding_layer(nid)
        
        # torch.dot doesnt take in axis :sadface:
        pos_pred = (user_embedding * pos_item_embedding).sum(dim=1)
        neg_pred = (user_embedding * neg_item_embedding).sum(dim=1)
        return pos_pred, neg_pred

    
class TripletBPRLoss(nn.Module):
    def __init__(self):
        super(TripletBPRLoss, self).__init__()

    def forward(self, pos_pred, neg_pred):
        loss = 1.0 - torch.sigmoid(pos_pred - neg_pred)
        return loss.mean()

net = FactNet(n_users, n_items, n_latent)
criterion = TripletBPRLoss()
optimizer = optim.Adam(net.parameters())
# optimizer = optim.Adadelta(net.parameters())
optimizer = optim.Adagrad(net.parameters())

In [6]:
from sklearn.metrics import roc_auc_score
def full_auc(net, ground_truth):
    ground_truth = ground_truth.tocsr()

    no_users, no_items = ground_truth.shape

    pid_arr = np.arange(no_items, dtype=np.int64)

    scores = []

    for user_id, row in enumerate(ground_truth):
        user_arr = np.array([user_id]*len(pid_arr), dtype=np.int64)
        user_input = Variable(torch.LongTensor(user_arr))
        items_input = Variable(torch.LongTensor(pid_arr))
        
        predictions = net.predict_score(user_input, items_input)
        preds_arr = np.squeeze(predictions.data.numpy())

        true_pids = row.indices[row.data == 1]
        if len(true_pids):
            scores.append(roc_auc_score(row.toarray()[0].astype(bool),
                                        preds_arr))

    return np.mean(scores)

In [7]:
def batcher(iterable, batch_size):
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield np.array(chunk, dtype=np.int64)

In [8]:
%%time
num_epochs = 10

for epoch in range(num_epochs):

    print('Epoch %s' % epoch, end='\t')

    # Sample triplets from the training data
    uid, pid, nid = data.get_triplets(train)
    shuffle_ind = np.arange(len(uid))
#     np.random.shuffle(shuffle_ind)
    uid_gen = batcher(uid[shuffle_ind], batch_size)
    pid_gen = batcher(pid[shuffle_ind], batch_size)
    nid_gen = batcher(nid[shuffle_ind], batch_size)
    for uid_batch, pid_batch, nid_batch in zip(uid_gen, pid_gen, nid_gen):
    
        user_input = Variable(torch.LongTensor(uid_batch))
        p_item_input = Variable(torch.LongTensor(pid_batch))
        n_item_input = Variable(torch.LongTensor(nid_batch))

        out = net(user_input, p_item_input, n_item_input)
        loss = criterion(*out)
        loss.backward()
        optimizer.step()

    print('AUC: train:{}\t test:{}'.format(full_auc(net, train), full_auc(net, test)))

Epoch 0	AUC: train:0.851437926336	 test:0.708176292295
Epoch 1	AUC: train:0.905246953711	 test:0.786664742913
Epoch 2	AUC: train:0.917485025308	 test:0.813266294433
Epoch 3	AUC: train:0.921757169089	 test:0.824214443
Epoch 4	AUC: train:0.924242200149	 test:0.830692094282
Epoch 5	AUC: train:0.925970550503	 test:0.834588743087
Epoch 6	AUC: train:0.92756677899	 test:0.837254668446
Epoch 7	AUC: train:0.929186087453	 test:0.839445715136
Epoch 8	AUC: train:0.930479078246	 test:0.841002026467
Epoch 9	AUC: train:0.931517076017	 test:0.842274550028
CPU times: user 51.9 s, sys: 6.24 s, total: 58.1 s
Wall time: 41.7 s
