# Collaborative Memory Network for Recommendation Systems
**_Ebesu, Shen, Fang - The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval - SIGIR '18_**

[This](https://github.com/IamAdiSri/cmn4recosys) notebook by [**Aditya Srivastava**](https://github.com/IamAdiSri/) is a PyTorch port to the original TensorFlow [project](https://github.com/tebesu/CollaborativeMemoryNetwork).

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pickle
import random
import numpy as np
from tqdm import tqdm_notebook as tqdm

from collections import defaultdict

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Configuration

In [3]:
class config:
    ssdir = 'snapshots/'
    logdir = 'logs/'
    dataset = 'data/citeulike-a.npz'
    version = 'model_v2'
    pretrain = 'pretrain/citeulike-a_e50_v2.npz' # output/input location for pretrained embeddings
    embed_size = 50
    pretraining_epochs = 15
    pretraining_l2_lambda = 0.001 # l2 regularization for pretraining
    epochs = 30 # training epochs (originally 30)
    batch_size = 128
    hops = 2 # number of hops/layers
    training_l2_lambda = 0.1 # l2 regularization for training
    neg_count = 4 # negative samples count
    learning_rate = 0.001
    decay_rate = 0.9
    momentum = 0.9
    grad_clip = 5.0

### Utility Functions

In [4]:
def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

### Data Loader

In [5]:
class Dataset(object):

    def __init__(self, filename):
        """
        Wraps dataset and produces batches for the model to consume

        :param filename: path to training data for npz file
        """
        self._data = np.load(filename, allow_pickle=True)
        self.train_data = self._data['train_data'][:, :2]
        self.test_data = self._data['test_data'].tolist()
        self._train_index = np.arange(len(self.train_data), dtype=np.uint)
        self._n_users, self._n_items = self.train_data.max(axis=0) + 1

        # Neighborhoods
        self.user_items = defaultdict(set)
        self.item_users = defaultdict(set)
        for u, i in self.train_data:
            self.user_items[u].add(i)
            self.item_users[i].add(u)
        # Get a list version so we do not need to perform type casting
        self.item_users_list = {k: list(v) for k, v in self.item_users.items()}
        self._max_user_neighbors = max([len(x) for x in self.item_users.values()])
        self.user_items = dict(self.user_items)
        self.item_users = dict(self.item_users)

    @property
    def train_size(self):
        """
        :return: number of examples in training set
        :rtype: int
        """
        return len(self.train_data)

    @property
    def user_count(self):
        """
        Number of users in dataset
        """
        return self._n_users

    @property
    def item_count(self):
        """
        Number of items in dataset
        """
        return self._n_items

    def _sample_item(self):
        """
        Draw an item uniformly
        """
        return np.random.randint(0, self.item_count)

    def _sample_negative_item(self, user_id):
        """
        Uniformly sample a negative item
        """
        if user_id > self.user_count:
            raise ValueError("Trying to sample user id: {} > user count: {}".format(
                user_id, self.user_count))

        n = self._sample_item()
        positive_items = self.user_items[user_id]

        if len(positive_items) >= self.item_count:
            raise ValueError("The User has rated more items than possible %s / %s" % (
                len(positive_items), self.item_count))
        while n in positive_items or n not in self.item_users:
            n = self._sample_item()
        return n

    def _generate_data(self, neg_count):
        idx = 0
        self._examples = np.zeros((self.train_size*neg_count, 3),
                                  dtype=np.uint32)
        self._examples[:, :] = 0
        for user_idx, item_idx in self.train_data:
            for _ in range(neg_count):
                neg_item_idx = self._sample_negative_item(user_idx)
                self._examples[idx, :] = [user_idx, item_idx, neg_item_idx]
                idx += 1

    def get_data(self, batch_size: int, neighborhood: bool, neg_count: int):
        """
        Batch data together as (user, item, negative item), pos_neighborhood,
        length of neighborhood, negative_neighborhood, length of negative neighborhood

        if neighborhood is False returns only user, item, negative_item so we
        can reuse this for non-neighborhood-based methods.

        :param batch_size: size of the batch
        :param neighborhood: return the neighborhood information or not
        :param neg_count: number of negative samples to uniformly draw per a pos
                          example
        :return: generator
        """
        # Allocate inputs
        batch = np.zeros((batch_size, 3), dtype=np.uint32)
        pos_neighbor = np.zeros((batch_size, self._max_user_neighbors), dtype=np.int32)
        pos_length = np.zeros(batch_size, dtype=np.int32)
        neg_neighbor = np.zeros((batch_size, self._max_user_neighbors), dtype=np.int32)
        neg_length = np.zeros(batch_size, dtype=np.int32)

        # Shuffle index
        np.random.shuffle(self._train_index)

        idx = 0
        for user_idx, item_idx in self.train_data[self._train_index]:
            # TODO: set positive values outside of for loop
            for _ in range(neg_count):
                neg_item_idx = self._sample_negative_item(user_idx)
                batch[idx, :] = [user_idx, item_idx, neg_item_idx]

                # Get neighborhood information
                if neighborhood:
                    if len(self.item_users.get(item_idx, [])) > 0:
                        pos_length[idx] = len(self.item_users[item_idx])
                        pos_neighbor[idx, :pos_length[idx]] = self.item_users_list[item_idx]
                    else:
                        # Length defaults to 1
                        pos_length[idx] = 1
                        pos_neighbor[idx, 0] = item_idx

                    if len(self.item_users.get(neg_item_idx, [])) > 0:
                        neg_length[idx] = len(self.item_users[neg_item_idx])
                        neg_neighbor[idx, :neg_length[idx]] = self.item_users_list[neg_item_idx]
                    else:
                        # Length defaults to 1
                        neg_length[idx] = 1
                        neg_neighbor[idx, 0] = neg_item_idx

                idx += 1
                # Yield batch if we filled queue
                if idx == batch_size:
                    if neighborhood:
                        max_length = max(neg_length.max(), pos_length.max())
                        yield batch, pos_neighbor[:, :max_length], pos_length, \
                              neg_neighbor[:, :max_length], neg_length
                        pos_length[:] = 1
                        neg_length[:] = 1
                    else:
                        yield batch
                    # Reset
                    idx = 0

        # Provide remainder
        if idx > 0:
            if neighborhood:
                max_length = max(neg_length[:idx].max(), pos_length[:idx].max())
                yield batch[:idx], pos_neighbor[:idx, :max_length], pos_length[:idx], \
                      neg_neighbor[:idx, :max_length], neg_length[:idx]
            else:
                yield batch[:idx]

In [6]:
dataset = Dataset(config.dataset)

config.item_count = dataset.item_count
config.user_count = dataset.user_count
config.max_neighbors = dataset._max_user_neighbors

print(dataset.item_count, dataset.user_count, dataset._max_user_neighbors)

16980 5551 311


## Loss

The loss function is common between both, pairwise GMF pretraining and CMN training.

In [7]:
class LossLayer(nn.Module):
    def __init__(self):
        super(LossLayer, self).__init__()

    def forward(self, X, y):
        """
        :param X: predicted value
        :param y: ground truth
        :returns: Loss
        """        
        bprl = torch.squeeze(self.bpr_loss(X, y))        
        return bprl
    
    def bpr_loss(self, positive, negative):
        r"""
        Pairwise Loss from Bayesian Personalized Ranking.

        \log \sigma(pos - neg)

        where \sigma is the sigmoid function, we try to set the ranking

        if pos > neg = + number
        if neg < pos = - number

        Then applying the sigmoid to obtain a monotonically increasing function. Any
        monotonically increasing function could be used, eg piecewise or probit.

        :param positive: Score of prefered example
        :param negative: Score of negative example
        :param name: str, name scope
        :returns: mean loss
        """
        difference = positive - negative
        # Numerical stability
        eps = 1e-12
        loss = -1*torch.log(torch.sigmoid(difference) + eps)
        return torch.mean(loss)

## Part 1: Pretraining User and Item Embeddings

### Model

In [8]:
class PairwiseGMF(nn.Module):
    
    def __init__(self):
        """
        Constructs the user/item memories and user/item external memory/outputs

        Also add the embedding lookups
        """
        super(PairwiseGMF, self).__init__()
        
        # MemoryEmbed
        self.user_memory = nn.Embedding(config.user_count, config.embed_size)
        truncated_normal_(self.user_memory.weight, std=0.01)
        self.user_memory.weight.requires_grad = True

        # ItemMemory
        self.item_memory = nn.Embedding(config.item_count, config.embed_size)        
        truncated_normal_(self.item_memory.weight, std=0.01)
        self.item_memory.weight.requires_grad = True
        
        self.v = nn.Linear(config.embed_size, 1, bias=False)
        nn.init.xavier_uniform_(self.v.weight)
        self.v.weight.requires_grad = True

    def forward(self, input_users, input_items, input_items_negative):
        """
        Construct the model; main part of it goes here
        """
        # [batch, embedding size]
        cur_user = self.user_memory(input_users)

        # Item memories a query
        cur_item = self.item_memory(input_items)
        cur_item_negative = self.item_memory(input_items_negative)

        score = F.relu(self.v(cur_user * cur_item))
        negative_output = F.relu(self.v(cur_user * cur_item_negative))
        
        return score, negative_output

In [9]:
model = PairwiseGMF().to(device)

### Training Loop

In [10]:
for name, param in model.named_parameters():
    print(name)

user_memory.weight
item_memory.weight
v.weight


In [11]:
criterion = LossLayer()
optimizer = torch.optim.Adam(model.parameters(), 
                             lr=config.learning_rate)

for i in range(config.pretraining_epochs):
    model.train()
    model.zero_grad()
    
    progress = tqdm(enumerate(dataset.get_data(config.batch_size, False, config.neg_count)),
                    dynamic_ncols=True, total=(dataset.train_size * config.neg_count) // config.batch_size)
    loss = []
    for k, batch in progress:
        
        optimizer.zero_grad()
        
        input_users = torch.LongTensor(np.array(batch[:, 0], dtype=np.int32)).to(device)
        input_items = torch.LongTensor(np.array(batch[:, 1], dtype=np.int32)).to(device)
        input_items_negative = torch.LongTensor(np.array(batch[:, 2], dtype=np.int32)).to(device)
        
        score, negative_output = model(input_users, input_items, input_items_negative)
        batch_loss = criterion(score, negative_output)
        
        # adding l2 regularisation
        for name, param in model.named_parameters():
            if name in ['v.weight']:
                l2 = torch.sqrt(param.pow(2).sum())
                batch_loss += (config.pretraining_l2_lambda * l2)
        
        batch_loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        
        optimizer.step()
        
        loss.append(batch_loss.item())
        progress.set_description(u"[{}] Loss: {:,.4f} » » » » ".format(i, batch_loss.item()))

    print("Epoch {}: Avg Loss/Batch {:<20,.6f}".format(i, np.mean(loss)))


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 0: Avg Loss/Batch 0.693170            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 1: Avg Loss/Batch 0.693148            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 2: Avg Loss/Batch 0.693148            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 3: Avg Loss/Batch 0.693138            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 4: Avg Loss/Batch 0.682225            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 5: Avg Loss/Batch 0.517125            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 6: Avg Loss/Batch 0.244349            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 7: Avg Loss/Batch 0.137857            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 8: Avg Loss/Batch 0.099885            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 9: Avg Loss/Batch 0.080794            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 10: Avg Loss/Batch 0.068308            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 11: Avg Loss/Batch 0.059544            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 12: Avg Loss/Batch 0.053044            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 13: Avg Loss/Batch 0.047668            


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 14: Avg Loss/Batch 0.043235            


In [12]:
print('Saving embeddings to: %s' % config.pretrain)
user_embed, item_embed, v = (model.user_memory.weight.detach().cpu(), model.item_memory.weight.detach().cpu(), model.v.weight.detach().cpu())
np.savez(config.pretrain, user=user_embed, item=item_embed, v=v)

Saving embeddings to: pretrain/citeulike-a_e50_v2.npz


## Part 2: Training CMN

### Model

In [13]:
class VariableLengthMemoryLayer(nn.Module):
    def __init__(self, hops, embed_size):
        super(VariableLengthMemoryLayer, self).__init__()
        
        self.hops = hops
        self.embed_size = embed_size
        
        self.hop_mapping = {}
        for h in range(hops-1):
            self.hop_mapping[str(h+1)] = nn.Linear(self.embed_size, self.embed_size, bias=True).to(device)
            self.hop_mapping[str(h+1)].weight.requires_grad = True
            self.hop_mapping[str(h+1)].bias.requires_grad = True
            nn.init.kaiming_normal_(self.hop_mapping[str(h+1)].weight)
            self.hop_mapping[str(h+1)].bias.data.fill_(1.0)    
        self.hop_mapping = nn.ModuleDict(self.hop_mapping)
    
    def mask_mod(self, inputs, mask_length, maxlen=None):
        """
        Apply a memory mask such that the values we mask result in being the
        minimum possible value we can represent with a float32.

        :param inputs: [batch size, length], dtype=tf.float32
        :param memory_mask: [batch_size] shape Tensor of ints indicating the
            length of inputs
        :param maxlen: Sets the maximum length of the sequence; if None infered
            from inputs
        :returns: [batch size, length] dim Tensor with the mask applied
        """
        # [batch_size, length] => Sequence Mask
        memory_mask = torch.arange(maxlen).to(device).expand(len(mask_length), maxlen) < mask_length.unsqueeze(1)
        memory_mask = memory_mask.float()

        num_remaining_memory_slots = torch.sum(memory_mask, 1)

        # Get the numerical limits of a float
        finfo = np.finfo(np.float32)
        # print(finfo)

        # If True = 1 = Keep that memory slot
        kept_indices = memory_mask

        # Inverse
        ignored_indices = memory_mask < 1
        ignored_indices = ignored_indices.float()

        # If we keep the indices its the max float value else its the
        # minimum float value. Then we can take the minimum
        lower_bound = finfo.max * kept_indices + finfo.min * ignored_indices
        slice_length = torch.max(mask_length)
        
        # Return the elementwise
        return torch.min(inputs[:, :slice_length], lower_bound[:, :slice_length])
        
    def apply_attention_memory(self, memory, output_memory, query, memory_mask=None, maxlen=None):
        """
            :param memory: [batch size, max length, embedding size],
                typically Matrix M
            :param output_memory: [batch size, max length, embedding size],
                typically Matrix C
            :param query: [batch size, embed size], typically u
            :param memory_mask: [batch size] dim Tensor, the length of each
                sequence if variable length
            :param maxlen: int/Tensor, the maximum sequence padding length; if None it
                infers based on the max of memory_mask
            :returns: AttentionOutput
                 output: [batch size, embedding size]
                 weight: [batch size, max length], the attention weights applied to
                         the output representation.
        """
        # query = [batch size, embeddings] => expand => [batch size, embeddings, 1]
        # transpose => [batch size, 1, embeddings]
        query_expanded = query.unsqueeze(-1).transpose(2, 1)

        # Apply batched dot product
        # memory = [batch size, <Max Length>, Embeddings]
        # Broadcast the same memory across each dimension of max length
        # We obtain an attention value for each memory,
        # ie a_0 p_0, a_1 p_1, .. a_n p_n, which equates to the max length
        #    because our query is only 1 dim, we only get attention over memory
        #    for that query. If our query was 2-d then we would obtain a matrix.
        # Return: [batch size, max length]
        batched_dot_prod = query_expanded * memory
        scores = batched_dot_prod.sum(2)

        if memory_mask is not None:
            scores = self.mask_mod(scores, memory_mask, maxlen)

        # Attention over memories: [Batch Size, <Max Length>]
        # equation 2
        attention = F.softmax(scores, dim=-1)

        # [Batch Size, <Max Length>] => [Batch Size, 1, <Max Length>]
        probs_temp = attention.unsqueeze(1)

        # Output_Memories = [batch size, <Max Length>, Embeddings]
        # Transpose = [Batch Size, Embedding Size, <Max Length>]
        c_temp = output_memory.transpose(2, 1)

        # Apply a weighted scalar or attention to the external memory 
        # to get weighted neighborhood
        # [batch size, 1, <max length>] * [batch size, embedding size, <max length>]
        neighborhood = c_temp * probs_temp

        # Sum the weighted memories together
        # Input:  [batch Size, embedding size, <max length>]
        # Output: [Batch Size, Embedding Size]
        # Weighted output vector
        # equation 3
        weighted_output = neighborhood.sum(2)

        return {'weight':attention, 'output':weighted_output}
    
    def forward(self, query, memory, output_memory, seq_length, maxlen=32):
        # find maximum length of sequences in this batch
        cur_max = torch.max(seq_length).item()
        # slice to max length
        memory = memory[:, :cur_max]
        output_memory = output_memory[:, :cur_max]
        
        user_query, item_query = query
        hop_outputs = []
        
        # hop 0
        # z = m_u + e_i
        z = user_query + item_query
        
        for hop_k in range(self.hops):
            # hop 1, ... , hop self.hops-1
            if hop_k != 0:                
                # f(Wz + o + b)
                # equation 6
                z = F.relu(self.hop_mapping[str(hop_k)](z) + memory_hop['output'])
            
            # apply attention
            memory_hop = self.apply_attention_memory(memory, 
                                               output_memory,
                                               z, 
                                               seq_length, 
                                               maxlen)
            hop_outputs.append(memory_hop)
        
        return hop_outputs

In [14]:
class OutputModule(nn.Module):
    
    def __init__(self, embed_size):
        super(OutputModule, self).__init__()
        
        self.embed_size = embed_size
        
        self.dense = nn.Linear(self.embed_size*2, self.embed_size, bias=True)
        self.dense.weight.requires_grad = True
        self.dense.bias.requires_grad = True
        nn.init.kaiming_normal_(self.dense.weight)
        self.dense.bias.data.fill_(1.0)
        
        self.out = nn.Linear(self.embed_size, 1, bias = False)
        self.out.weight.requires_grad = True
        nn.init.xavier_uniform_(self.out.weight)
        
    def forward(self, inputs):
        output = F.relu(self.dense(inputs))
        output = self.out(output)
        return output.squeeze()

In [15]:
class CollaborativeMemoryNetwork(nn.Module):
    
    def __init__(self, user_embeddings, item_embeddings):
        super(CollaborativeMemoryNetwork, self).__init__()

        # MemoryEmbed
        self.user_memory = nn.Embedding(user_embeddings.shape[0], user_embeddings.shape[1])
        self.user_memory.weight = nn.Parameter(torch.from_numpy(user_embeddings))
        self.user_memory.weight.requires_grad = True
        
        # ItemMemory
        self.item_memory = nn.Embedding(item_embeddings.shape[0], item_embeddings.shape[1])
        self.item_memory.weight = nn.Parameter(torch.from_numpy(item_embeddings))
        self.item_memory.weight.requires_grad = True

        # MemoryOutput
        self.user_output = nn.Embedding(user_embeddings.shape[0], user_embeddings.shape[1])
        truncated_normal_(self.user_output.weight, std=0.01)
        self.user_output.weight.requires_grad = True

        self.mem_layer = VariableLengthMemoryLayer(2, config.embed_size)

        self.output_module = OutputModule(config.embed_size)

    
    def forward(self, input_users, input_items, input_items_negative, 
                input_neighborhoods, input_neighborhood_lengths, 
                input_neighborhoods_negative, input_neighborhood_lengths_negative, evaluation=False):
        
        # get embeddings from user memory
        cur_user = self.user_memory(input_users)
        cur_user_output = self.user_output(input_users)

        # get embeddings from item memory
        cur_item = self.item_memory(input_items)
        
        # queries
        query = (cur_user, cur_item)
        
        # positive
        neighbor = self.mem_layer(query, 
                                  self.user_memory(input_neighborhoods), 
                                  self.user_output(input_neighborhoods), 
                                  input_neighborhood_lengths, 
                                  config.max_neighbors)[-1]['output']
        
        score = self.output_module(torch.cat((cur_user * cur_item, neighbor), 1))
        
        
        if evaluation:
            return score
        
        cur_item_negative = self.item_memory(input_items_negative)
        neg_query = (cur_user, cur_item_negative)
            
        # negative
        neighbor_negative = self.mem_layer(neg_query, 
                                           self.user_memory(input_neighborhoods_negative), 
                                           self.user_output(input_neighborhoods_negative), 
                                           input_neighborhood_lengths_negative, 
                                           config.max_neighbors)[-1]['output']
        
        negative_output = self.output_module(torch.cat((cur_user * cur_item_negative, 
                                                        neighbor_negative), 1))
        
        return score, negative_output
    

In [16]:
# loading pretrained embeddings
embeddings = np.load(config.pretrain, allow_pickle=True)

# initialize model
model = CollaborativeMemoryNetwork(embeddings['user']*0.5, embeddings['item']*0.5).to(device)

### Evaluation Functions

In [17]:
def get_model_scores(test_data, neighborhood, max_neighbors, return_scores=False):
    """
    test_data = dict([positive, np.array[negatives]])
    """
    out = ''
    scores = []
    progress = tqdm(test_data.items(), total=len(test_data),
                    leave=False, desc=u'Evaluate || ')
    for user, (pos, neg) in progress:
        item_indices = list(neg) + [pos]

        input_users = torch.LongTensor([user] * (len(neg) + 1)).to(device)
        input_items = torch.LongTensor(item_indices).to(device)

        if neighborhood is not None:
            neighborhoods, neighborhood_length = (np.zeros((len(neg) + 1, max_neighbors), dtype=np.int32), 
                                                  np.ones(len(neg) + 1, dtype=np.int32))

            for _idx, item in enumerate(item_indices):
                _len = min(len(neighborhood.get(item, [])), max_neighbors)
                if _len > 0:
                    neighborhoods[_idx, :_len] = neighborhood[item][:_len]
                    neighborhood_length[_idx] = _len
                else:
                    neighborhoods[_idx, :1] = user
                    
            input_neighborhoods = torch.LongTensor(neighborhoods).to(device)
            input_neighborhood_lengths = torch.LongTensor(neighborhood_length).to(device)

        score = model(input_users, input_items, None, input_neighborhoods, 
                      input_neighborhood_lengths, None, None, evaluation=True).cpu()
        
        scores.append(score.detach().numpy().ravel())
        if return_scores:
            s = ' '.join(["{}:{}".format(n, s) for s, n in zip(score.detach().numpy().ravel().tolist(), item_indices)])
            out += "{}\t{}\n".format(user, s)
    if return_scores:
        return scores, out
    return scores


def evaluate_model(test_data, neighborhood, max_neighbors, EVAL_AT=[1, 5, 10]):
    scores = get_model_scores(test_data, neighborhood, max_neighbors)
    hrs = []
    ndcgs = []
    s = '\n'
    for k in EVAL_AT:
        hr, ndcg = get_eval(scores, len(scores[0]) - 1, k)
        s += "{:<14} {:<14.6f}{:<14} {:.6f}\n".format('HR@%s' % k, hr, 'NDCG@%s' % k, ndcg)
        hrs.append(hr)
        ndcgs.append(ndcg)
    print(s + '\n')

    return hrs, ndcgs


def get_eval(scores, index, top_n=10):
    """
    if the last element is the correct one, then
    index = len(scores[0])-1
    """
    ndcg = 0.0
    hr = 0.0
    assert len(scores[0]) > index and index >= 0

    for score in scores:
        # Get the top n indices
        arg_index = np.argsort(-score)[:top_n]
        if index in arg_index:
            # Get the position
            ndcg += np.log(2.0) / np.log(arg_index.tolist().index(index) + 2.0)
            # Increment
            hr += 1.0

    return hr / len(scores), ndcg / len(scores)

### Training Loop

In [18]:
for name, param in model.named_parameters():
    print(name)

user_memory.weight
item_memory.weight
user_output.weight
mem_layer.hop_mapping.1.weight
mem_layer.hop_mapping.1.bias
output_module.dense.weight
output_module.dense.bias
output_module.out.weight


In [19]:
# %%capture training_loop_output

optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate, 
                                momentum=config.momentum)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.decay_rate)

criterion = LossLayer()

loss = []
for i in range(config.epochs):    
    model.train()
    model.zero_grad()
    
    # Decay Learning Rate
    scheduler.step()
    # Print Learning Rate
    print("[Epoch: {}] [LR: {}]".format(i, scheduler.get_lr()[0]))
    
    progress = tqdm(enumerate(dataset.get_data(config.batch_size, True, config.neg_count)), 
                    dynamic_ncols=True, total=(dataset.train_size * config.neg_count) // config.batch_size)
    
    for k, batch in progress:
        
        optimizer.zero_grad()
        
        ratings, pos_neighborhoods, pos_neighborhood_length, neg_neighborhoods, neg_neighborhood_length = batch
        
        input_users = torch.LongTensor(np.array(ratings[:, 0], dtype=np.int32)).to(device)
        input_items = torch.LongTensor(np.array(ratings[:, 1], dtype=np.int32)).to(device)
        input_items_negative = torch.LongTensor(np.array(ratings[:, 2], dtype=np.int32)).to(device)
        input_neighborhoods = torch.LongTensor(np.array(pos_neighborhoods, dtype=np.int32)).to(device)
        input_neighborhood_lengths = torch.LongTensor(np.array(pos_neighborhood_length, dtype=np.int32)).to(device)
        input_neighborhoods_negative = torch.LongTensor(np.array(neg_neighborhoods, dtype=np.int32)).to(device)
        input_neighborhood_lengths_negative = torch.LongTensor(np.array(neg_neighborhood_length, dtype=np.int32)).to(device)
        
        score_pos, score_neg = model(input_users, input_items, input_items_negative, 
                                     input_neighborhoods, input_neighborhood_lengths, 
                                     input_neighborhoods_negative, input_neighborhood_lengths_negative)
        
        batch_loss = criterion(score_pos, score_neg)
        
        # adding l2 regularisation
        for name, param in model.named_parameters():
            if name in ['mem_layer.hop_mapping.1.weight', 
                        'output_module.dense.weight', 
                        'output_module.out.weight']:
                l2 = torch.sqrt(param.pow(2).sum())
                batch_loss += (config.training_l2_lambda * l2)

        batch_loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        
        optimizer.step()
        
        loss.append(batch_loss.item())
        progress.set_description(u"[{}] Loss: {:,.4f} » » » » ".format(i, batch_loss.item()))
    
    print("Epoch {}: Avg Loss/Batch {:<20,.6f}".format(i, np.mean(loss)))
    model.eval()
    evaluate_model(dataset.test_data, dataset.item_users_list, config.max_neighbors)

[Epoch: 0] [LR: 0.0008100000000000001]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 0: Avg Loss/Batch 0.536484            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.304450      NDCG@1         0.304450
HR@5           0.638984      NDCG@5         0.480291
HR@10          0.776257      NDCG@10        0.524932


[Epoch: 1] [LR: 0.000729]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 1: Avg Loss/Batch 0.422990            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.336876      NDCG@1         0.336876
HR@5           0.697712      NDCG@5         0.527333
HR@10          0.824536      NDCG@10        0.568500


[Epoch: 2] [LR: 0.0006561000000000001]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 2: Avg Loss/Batch 0.363152            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.348946      NDCG@1         0.348946
HR@5           0.709602      NDCG@5         0.539700
HR@10          0.838768      NDCG@10        0.581711


[Epoch: 3] [LR: 0.00059049]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 3: Avg Loss/Batch 0.324254            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.354170      NDCG@1         0.354170
HR@5           0.718249      NDCG@5         0.547642
HR@10          0.839849      NDCG@10        0.587239


[Epoch: 4] [LR: 0.000531441]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 4: Avg Loss/Batch 0.296596            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.361917      NDCG@1         0.361917
HR@5           0.713385      NDCG@5         0.548750
HR@10          0.843812      NDCG@10        0.591349


[Epoch: 5] [LR: 0.0004782969]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 5: Avg Loss/Batch 0.275701            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.367862      NDCG@1         0.367862
HR@5           0.742569      NDCG@5         0.567055
HR@10          0.862007      NDCG@10        0.605902


[Epoch: 6] [LR: 0.00043046721]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 6: Avg Loss/Batch 0.258946            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.368402      NDCG@1         0.368402
HR@5           0.740407      NDCG@5         0.565543
HR@10          0.857323      NDCG@10        0.603527


[Epoch: 7] [LR: 0.000387420489]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 7: Avg Loss/Batch 0.245210            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.377409      NDCG@1         0.377409
HR@5           0.750315      NDCG@5         0.575086
HR@10          0.862007      NDCG@10        0.611604


[Epoch: 8] [LR: 0.0003486784401]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 8: Avg Loss/Batch 0.233615            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.381553      NDCG@1         0.381553
HR@5           0.749054      NDCG@5         0.576298
HR@10          0.863808      NDCG@10        0.613759


[Epoch: 9] [LR: 0.00031381059609000004]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 9: Avg Loss/Batch 0.223697            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.365340      NDCG@1         0.365340
HR@5           0.728878      NDCG@5         0.559591
HR@10          0.853540      NDCG@10        0.600536


[Epoch: 10] [LR: 0.00028242953648100003]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 10: Avg Loss/Batch 0.215119            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.386237      NDCG@1         0.386237
HR@5           0.752297      NDCG@5         0.581444
HR@10          0.870654      NDCG@10        0.620073


[Epoch: 11] [LR: 0.00025418658283290005]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 11: Avg Loss/Batch 0.207548            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.389479      NDCG@1         0.389479
HR@5           0.754819      NDCG@5         0.583843
HR@10          0.874077      NDCG@10        0.622627


[Epoch: 12] [LR: 0.00022876792454961005]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 12: Avg Loss/Batch 0.200885            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.387678      NDCG@1         0.387678
HR@5           0.759142      NDCG@5         0.585451
HR@10          0.875518      NDCG@10        0.623158


[Epoch: 13] [LR: 0.00020589113209464906]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 13: Avg Loss/Batch 0.194852            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.394524      NDCG@1         0.394524
HR@5           0.762926      NDCG@5         0.590503
HR@10          0.876419      NDCG@10        0.627431


[Epoch: 14] [LR: 0.00018530201888518417]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 14: Avg Loss/Batch 0.189326            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.393443      NDCG@1         0.393443
HR@5           0.762565      NDCG@5         0.590215
HR@10          0.874617      NDCG@10        0.626840


[Epoch: 15] [LR: 0.00016677181699666576]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 15: Avg Loss/Batch 0.184343            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.389479      NDCG@1         0.389479
HR@5           0.767069      NDCG@5         0.590476
HR@10          0.878040      NDCG@10        0.626466


[Epoch: 16] [LR: 0.0001500946352969992]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 16: Avg Loss/Batch 0.179814            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.384976      NDCG@1         0.384976
HR@5           0.758962      NDCG@5         0.584064
HR@10          0.873176      NDCG@10        0.621357


[Epoch: 17] [LR: 0.0001350851717672993]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 17: Avg Loss/Batch 0.175646            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.382634      NDCG@1         0.382634
HR@5           0.757881      NDCG@5         0.583172
HR@10          0.873176      NDCG@10        0.620791


[Epoch: 18] [LR: 0.00012157665459056936]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 18: Avg Loss/Batch 0.171798            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.397946      NDCG@1         0.397946
HR@5           0.768330      NDCG@5         0.595030
HR@10          0.879661      NDCG@10        0.631262


[Epoch: 19] [LR: 0.00010941898913151243]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 19: Avg Loss/Batch 0.168240            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.390921      NDCG@1         0.390921
HR@5           0.764907      NDCG@5         0.589925
HR@10          0.876959      NDCG@10        0.626519


[Epoch: 20] [LR: 9.847709021836118e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 20: Avg Loss/Batch 0.164955            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.384255      NDCG@1         0.384255
HR@5           0.755900      NDCG@5         0.582501
HR@10          0.871735      NDCG@10        0.620442


[Epoch: 21] [LR: 8.862938119652506e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 21: Avg Loss/Batch 0.161907            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.398126      NDCG@1         0.398126
HR@5           0.766709      NDCG@5         0.593961
HR@10          0.877860      NDCG@10        0.630421


[Epoch: 22] [LR: 7.976644307687256e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 22: Avg Loss/Batch 0.159058            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.396685      NDCG@1         0.396685
HR@5           0.764907      NDCG@5         0.592226
HR@10          0.878580      NDCG@10        0.629335


[Epoch: 23] [LR: 7.17897987691853e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 23: Avg Loss/Batch 0.156400            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.386957      NDCG@1         0.386957
HR@5           0.763106      NDCG@5         0.587451
HR@10          0.876419      NDCG@10        0.624552


[Epoch: 24] [LR: 6.461081889226677e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 24: Avg Loss/Batch 0.153930            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.397946      NDCG@1         0.397946
HR@5           0.769411      NDCG@5         0.595003
HR@10          0.880202      NDCG@10        0.631079


[Epoch: 25] [LR: 5.81497370030401e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 25: Avg Loss/Batch 0.151596            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.398667      NDCG@1         0.398667
HR@5           0.769771      NDCG@5         0.596117
HR@10          0.877319      NDCG@10        0.631399


[Epoch: 26] [LR: 5.233476330273609e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 26: Avg Loss/Batch 0.149401            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.398487      NDCG@1         0.398487
HR@5           0.768150      NDCG@5         0.595113
HR@10          0.878941      NDCG@10        0.631363


[Epoch: 27] [LR: 4.7101286972462485e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 27: Avg Loss/Batch 0.147341            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.396865      NDCG@1         0.396865
HR@5           0.771933      NDCG@5         0.596800
HR@10          0.879121      NDCG@10        0.631736


[Epoch: 28] [LR: 4.239115827521624e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 28: Avg Loss/Batch 0.145400            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.397226      NDCG@1         0.397226
HR@5           0.765448      NDCG@5         0.592837
HR@10          0.879481      NDCG@10        0.630157


[Epoch: 29] [LR: 3.8152042447694614e-05]


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6232), HTML(value='')), layout=Layout(display…


Epoch 29: Avg Loss/Batch 0.143573            


HBox(children=(IntProgress(value=0, description='Evaluate || ', max=5551, style=ProgressStyle(description_widt…


HR@1           0.397406      NDCG@1         0.397406
HR@5           0.767790      NDCG@5         0.594351
HR@10          0.880022      NDCG@10        0.631091




In [26]:
# %%capture final_eval_output

EVAL_AT = range(1, 11)
hrs, ndcgs = [], []
s = ""
scores, out = get_model_scores(dataset.test_data, dataset.item_users_list, config.max_neighbors, True)

for k in EVAL_AT:
    hr, ndcg = get_eval(scores, len(scores[0])-1, k)
    hrs.append(hr)
    ndcgs.append(ndcg)
    s += "{:<14} {:<14.6f}{:<14} {:.6f}\n".format('HR@%s' % k, hr,
                                                  'NDCG@%s' % k, ndcg)
print(s)

HR@1           0.397406      NDCG@1         0.397406
HR@2           0.561881      NDCG@2         0.501178
HR@3           0.656638      NDCG@3         0.548557
HR@4           0.720411      NDCG@4         0.576022
HR@5           0.767790      NDCG@5         0.594351
HR@6           0.803819      NDCG@6         0.607185
HR@7           0.829760      NDCG@7         0.615832
HR@8           0.850477      NDCG@8         0.622367
HR@9           0.865790      NDCG@9         0.626977
HR@10          0.880022      NDCG@10        0.631091



In [21]:
print('Saving training log...')
with open("{}{}".format(config.logdir, config.version+'.log'), 'w') as fout:
    header = ','.join([str(k) for k in EVAL_AT])
    fout.write("{},{}\n".format('metric', header))
    ndcg = ','.join([str(x) for x in ndcgs])
    hr = ','.join([str(x) for x in hrs])
    fout.write("ndcg,{}\n".format(ndcg))
    fout.write("hr,{}".format(hr))

Saving training log...


In [22]:
# save model weights
print("Saving model...")
torch.save(model.state_dict(), config.ssdir+config.version+'.ss')

Saving model...
