In [None]:
!pip install pytorch-lightning

In [None]:
!pip install neptune-client

In [1]:
import pandas as pd
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

In [3]:
# Load the ratings table plus the two pickled lookups used in the next cell to
# translate raw movie/user ids into index values.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): the CSV is read from Drive while the pickles are read from the
# current working directory — confirm both locations are intended.
ratings = pd.read_csv('drive/MyDrive/Colab Notebooks/data/ratings_new.csv')
with open('movie_to_index.pkl', 'rb') as movie_mapping:
    movie_to_index = pickle.load(movie_mapping)
with open('user_to_index.pkl', 'rb') as user_mapping:
    user_to_index = pickle.load(user_mapping)

In [4]:
# Replace raw movie/user ids with their index values from the loaded lookups.
# An id missing from a lookup raises KeyError. The columns are overwritten in
# place, so run this cell once per fresh load of `ratings`.
ratings['movieId'] = ratings['movieId'].apply(movie_to_index.__getitem__)
ratings['userId'] = ratings['userId'].apply(user_to_index.__getitem__)
ratings.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,0,0,2.0,1256677210
1,0,1,3.5,1256677486
2,1,2,3.5,1113766176
3,1,3,4.5,1113766820
4,1,4,3.5,1113766824


In [5]:
# Basic dataset figures: distinct users/movies and the observed rating range.
n_users = int(ratings['userId'].nunique())
n_movies = int(ratings['movieId'].nunique())
min_rating = ratings['rating'].min()
max_rating = ratings['rating'].max()

In [6]:
# Report the figures computed in the previous cell.
print(
    "Number of users: {}, Number of Movies: {}, Min rating: {}, Max rating: {}"
    .format(n_users, n_movies, min_rating, max_rating)
)

Number of users: 181664, Number of Movies: 21639, Min rating: 0.5, Max rating: 5.0


In [7]:
# Rank every user's ratings from most recent (rank 1) to oldest.
# method='first' resolves equal timestamps by order of appearance.
ratings['rank_latest'] = (
    ratings
    .groupby(['userId'])['timestamp']
    .rank(method='first', ascending=False)
)

In [8]:
# Leave-one-out split: each user's most recent rating (rank 1) is held out for
# testing, everything else is used for training.
is_latest = ratings['rank_latest'] == 1
feature_columns = ['userId', 'movieId', 'rating']

# Take explicit copies: both frames are assigned into later, and writing to a
# slice of `ratings` would raise SettingWithCopyWarning.
train_ratings = ratings.loc[~is_latest, feature_columns].copy()
test_ratings = ratings.loc[is_latest, feature_columns].copy()

In [9]:
# Convert to implicit feedback: every observed training rating becomes 1,
# whatever its original score.
train_ratings.loc[:, 'rating'] = 1

In [10]:
# Same implicit-feedback conversion for the held-out interactions.
test_ratings.loc[:, 'rating'] = 1

In [11]:
def generate_data(ratings, train_ratings, num_negatives = 4, seed=None):
    """Build (user, movie, label) training triples with sampled negatives.

    Every distinct (userId, movieId) pair in ``train_ratings`` is emitted once
    with label 1, immediately followed by ``num_negatives`` pairs for the same
    user with label 0. Negative movies are drawn uniformly from the distinct
    ``movieId`` values of ``ratings`` and re-drawn until the pair is not a
    training positive. The same negative may be drawn more than once.

    Args:
        ratings: DataFrame with a ``movieId`` column; defines the candidate pool.
        train_ratings: DataFrame with ``userId`` and ``movieId`` columns.
        num_negatives: Number of negatives sampled per positive pair.
        seed: Optional seed for a private NumPy generator. ``None`` (default)
            samples from NumPy's global random state, as before.

    Returns:
        Three parallel lists: users, items, labels.

    Raises:
        ValueError: If negatives are requested for a user whose training
            positives cover every candidate movie (sampling would never end).
    """
    all_movieIds = ratings['movieId'].unique()
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Set of (user, movie) pairs observed in the training data.
    user_item_set = set(zip(train_ratings['userId'], train_ratings['movieId']))
    # NOTE(review): only training positives are excluded here, so a user's
    # held-out test movie can be drawn as a negative — confirm this is intended.

    if num_negatives > 0:
        # A user with fewer positives than there are candidates always has a
        # free movie; only users at or above that count need the exact check.
        positives_per_user = {}
        for u, _ in user_item_set:
            positives_per_user[u] = positives_per_user.get(u, 0) + 1
        n_candidates = len(all_movieIds)
        for u, n_positive in positives_per_user.items():
            if n_positive >= n_candidates and all(
                (u, m) in user_item_set for m in all_movieIds
            ):
                raise ValueError(
                    "Cannot sample negatives for user {}: every movie is a "
                    "training positive".format(u)
                )

    users, items, labels = [], [], []
    for (u, i) in user_item_set:
        users.append(u)
        items.append(i)
        labels.append(1) # movies that the user has interacted with are positive
        for _ in range(num_negatives):
            # randomly select a movie
            negative_item = rng.choice(all_movieIds)
            # re-draw until the user has not interacted with this movie
            while (u, negative_item) in user_item_set:
                negative_item = rng.choice(all_movieIds)
            users.append(u)
            items.append(negative_item)
            labels.append(0) # movies not interacted with are negative
    return users, items, labels

In [12]:
# Build the training triples: each training positive plus its sampled
# negatives (num_negatives defaults to 4).
users_train, movies_train, labels_train = generate_data(ratings, train_ratings)

In [13]:
# Total number of training samples (positives and sampled negatives).
len(users_train)

45365260

In [14]:
class MovieDataset(Dataset):
    """Map-style dataset over three parallel sequences: users, movies, labels."""

    def __init__(self, users, movies, labels):
        # Stored as given; item `idx` is read from the same position in each.
        self.users = users
        self.movies = movies
        self.labels = labels

    def __len__(self):
        # The user sequence determines the dataset size.
        n_samples = len(self.users)
        return n_samples

    def __getitem__(self, idx):
        # One sample is the (user, movie, label) triple at position `idx`.
        sample = (self.users[idx], self.movies[idx], self.labels[idx])
        return sample

In [15]:
# Wrap the training triples in a Dataset/DataLoader.
# shuffle=True: generate_data emits each positive immediately followed by its
# negatives for the same user, so unshuffled batches would be strongly
# correlated and every epoch would see the same order.
train_data = MovieDataset(users_train, movies_train, labels_train)
datasets = {'train':train_data}
dataloaders = {x: DataLoader(datasets[x], batch_size=512, shuffle=True, num_workers=2)
              for x in ['train']}


In [22]:
class LightningEmbeddingModel(pl.LightningModule):
    """Embedding-based interaction classifier.

    User and movie indices are each embedded into ``n_factors`` dimensions.
    The two embeddings are concatenated and passed through two ReLU hidden
    layers (128 and 64 units) with dropout, then a single sigmoid output.

    Args:
        num_users: Number of rows in the user embedding table.
        num_movies: Number of rows in the movie embedding table.
        n_factors: Embedding dimension for both users and movies.
        embedding_dropout: Dropout probability on the concatenated embeddings.
        dropouts: Dropout probability after the second hidden layer; twice
            this value is applied after the first hidden layer.
    """

    def __init__(self, num_users, num_movies, n_factors=100,
                 embedding_dropout=0.1, dropouts=0.2):
        super().__init__()
        self.user_embedding = nn.Embedding(num_embeddings=num_users, embedding_dim=n_factors)
        self.item_embedding = nn.Embedding(num_embeddings=num_movies, embedding_dim=n_factors)
        self.drop_embedding = nn.Dropout(embedding_dropout)
        self.drop_1 = nn.Dropout(dropouts*2)
        self.drop_2 = nn.Dropout(dropouts)
        self.fc1 = nn.Linear(in_features=2*n_factors, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.output = nn.Linear(in_features=64, out_features=1)
        # NOTE(review): `pl.metrics` only exists in older Lightning releases;
        # newer ones expect the separate `torchmetrics` package — confirm the
        # pinned version.
        self.train_accuracy = pl.metrics.Accuracy()

    def forward(self, user_input, item_input):
        """Return interaction probabilities with a trailing dimension of size 1."""
        # Pass through embedding layers
        user_embedded = self.user_embedding(user_input)
        item_embedded = self.item_embedding(item_input)
        # Concat the two embedding layers
        vector = torch.cat([user_embedded, item_embedded], dim=-1)

        vector = self.drop_embedding(vector)
        # Functional activations: no module is instantiated per forward call.
        vector = torch.relu(self.fc1(vector))
        vector = self.drop_1(vector)
        vector = torch.relu(self.fc2(vector))
        vector = self.drop_2(vector)
        pred = torch.sigmoid(self.output(vector))
        return pred

    def training_step(self, train_batch, batch_idx):
        """Compute BCE loss and accuracy for one (user, movie, labels) batch."""
        user, movie, labels = train_batch
        # Flatten to 1-D so predictions and labels have the same shape for
        # both the loss and the accuracy metric (forward returns (N, 1)).
        predicted_labels = self(user, movie).reshape(-1)
        labels = labels.reshape(-1)
        loss = nn.functional.binary_cross_entropy(predicted_labels, labels.float())
        train_acc_batch = self.train_accuracy(predicted_labels, labels)
        self.log('train_acc_batch', train_acc_batch)
        self.log('train_loss_batch', loss)
        return {'loss' : loss, 'accuracy' : train_acc_batch}

    def configure_optimizers(self):
        """Use Adam with its default hyperparameters over all parameters."""
        return torch.optim.Adam(self.parameters())

In [25]:

# Size each embedding table by the largest index + 1 so every index that
# occurs in `ratings` has a row.
num_users = ratings.userId.max() + 1
num_movies = ratings.movieId.max() + 1

model = LightningEmbeddingModel(num_users=num_users, num_movies=num_movies)

In [26]:
import os
from pytorch_lightning.loggers.neptune import NeptuneLogger

# Credentials and project come from the environment when set. The fallbacks
# keep the previous values: the "ANONYMOUS" key and the shared project named
# below, so runs are not logged to a private project by default.
neptune_logger = NeptuneLogger(
    api_key=os.environ.get("NEPTUNE_API_TOKEN", "ANONYMOUS"),
    project_name=os.environ.get("NEPTUNE_PROJECT", "shared/pytorch-lightning-integration"))

NeptuneLogger will work in online mode


In [34]:
# TensorBoard logger writing under logs/.
# NOTE(review): the Trainer below is given only `neptune_logger`; this logger
# is not passed to it anywhere in the notebook — remove it or wire it in.
tb_logger = pl_loggers.TensorBoardLogger('logs/')

In [27]:
# Train for 5 epochs on all visible GPUs (gpus=-1), logging to Neptune.
trainer = pl.Trainer(
    logger=neptune_logger,
    gpus=-1,
    max_epochs=5,
)

trainer.fit(model, dataloaders['train'])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


https://ui.neptune.ai/shared/pytorch-lightning-integration/e/PYTOR-165573



  | Name           | Type      | Params
---------------------------------------------
0 | user_embedding | Embedding | 18.2 M
1 | item_embedding | Embedding | 2.2 M 
2 | drop_embedding | Dropout   | 0     
3 | drop_1         | Dropout   | 0     
4 | drop_2         | Dropout   | 0     
5 | fc1            | Linear    | 25.7 K
6 | fc2            | Linear    | 8.3 K 
7 | output         | Linear    | 65    
8 | train_accuracy | Accuracy  | 0     
---------------------------------------------
20.4 M    Trainable params
0         Non-trainable params
20.4 M    Total params
81.457    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [28]:
# Inspect the latest metric values the trainer holds for callbacks.
trainer.callback_metrics

{'train_acc_batch': tensor(1., device='cuda:0'),
 'train_loss_batch': tensor(0.0541, device='cuda:0')}

In [37]:
# Inspect the latest metric values sent to the logger.
# NOTE(review): the saved output lists `train_loss_epoch`/`train_loss_step`,
# but training_step as written logs `train_loss_batch`/`train_acc_batch`;
# the output looks like it came from a different run — re-execute.
trainer.logged_metrics

{'epoch': tensor(2.),
 'train_loss_epoch': tensor(0.1244),
 'train_loss_step': tensor(0.1442, device='cuda:0')}

In [38]:
# Write a checkpoint of the trainer's current state to Drive.
trainer.save_checkpoint("drive/MyDrive/Colab Notebooks/data/new_model.ckpt")

In [29]:
from tqdm import tqdm

# Hit Ratio @ 10: for each user, rank the held-out movie against up to 99
# movies the user never interacted with, and count a hit when the held-out
# movie lands in the top 10.
all_movieIds = ratings['movieId'].unique()
all_movie_set = set(all_movieIds)  # built once, not once per user
test_user_item_set = set(zip(test_ratings['userId'], test_ratings['movieId']))

# All movies each user interacted with (train and held-out), so that sampled
# candidates are movies the user has not seen.
user_interacted_movies = ratings.groupby('userId')['movieId'].apply(list).to_dict()

# Inference mode: disable dropout and skip gradient tracking.
model.eval()
device = next(model.parameters()).device

hits = []
with torch.no_grad():
    for (u, i) in tqdm(test_user_item_set):
        not_interacted_movies = list(all_movie_set - set(user_interacted_movies[u]))
        # Sample without replacement so the negatives are distinct; cap the
        # sample size for users with fewer than 99 unseen movies.
        n_negatives = min(99, len(not_interacted_movies))
        if n_negatives:
            selected_not_interacted = list(
                np.random.choice(not_interacted_movies, n_negatives, replace=False)
            )
        else:
            selected_not_interacted = []
        test_items = selected_not_interacted + [i]

        # Build inputs on the model's device and bring scores back to the CPU.
        user_tensor = torch.tensor([u] * len(test_items), device=device)
        item_tensor = torch.tensor(test_items, device=device)
        predicted_labels = model(user_tensor, item_tensor).reshape(-1).cpu().numpy()

        ranked_positions = np.argsort(predicted_labels)[::-1][0:10].tolist()
        top10_items = [test_items[position] for position in ranked_positions]

        hits.append(1 if i in top10_items else 0)

print("The Hit Ratio @ 10 is {:.2f}".format(np.average(hits)))

100%|██████████| 181664/181664 [18:01<00:00, 167.92it/s]

The Hit Ratio @ 10 is 0.77





In [30]:
# Export the learned movie embedding matrix (one row per movie index).
# detach().cpu() makes this work whether the weights sit on GPU or CPU.
movie_embeddings = pd.DataFrame(model.item_embedding.weight.detach().cpu().numpy())

In [31]:
# Display the embedding table: one row per movie index, one column per factor.
movie_embeddings

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
0,1.535194,-0.247348,-0.080621,0.564229,-2.510557,-0.695133,-0.832231,1.938959,1.345536,0.885206,2.645245,0.138863,-2.592949,1.071037,1.222089,-3.347285,-3.434695,1.587943,-2.145657,0.989560,-2.131338,-0.267061,-1.324327,-0.989419,-0.590134,2.474216,-1.932369,-2.207471,-2.257832,-1.999413,-2.182904,-2.070303,1.095430,-0.386761,2.240860,1.078167,0.708249,-1.897508,-2.313262,0.762308,...,-0.792037,-3.043008,-0.399016,3.191906,-1.866617,0.879337,-0.522625,1.725829,-2.891703,-1.821911,2.664574,-0.295783,-1.320566,1.644067,0.627649,1.567352,2.504978,-0.724636,-0.291628,1.466641,-1.921248,-3.403442,-1.582265,-0.416403,2.387349,-0.107905,1.551539,1.341488,0.217530,3.491016,-1.069215,-0.937633,-0.377377,-0.814743,0.631861,-1.306406,-0.316799,0.456169,-0.273433,0.749216
1,4.802566,-0.965853,-0.696936,0.930465,-2.632655,-0.511551,-0.980518,1.194111,1.196879,0.620863,0.424983,-0.342816,-2.515524,-0.019239,2.082138,-0.501346,-1.875413,1.317409,-0.334371,0.231722,-0.937812,0.554754,-1.300228,-0.546218,-2.071005,0.743710,-2.568744,-3.037861,-2.540641,-1.018213,-2.144077,-2.382326,1.706227,-1.349982,2.108010,0.515623,1.268672,-1.698326,-1.047991,-0.704154,...,-0.869843,-2.013683,-0.191287,1.470560,-0.475458,0.086576,1.061602,-0.490968,-3.657546,-3.161427,1.495805,0.036516,-1.494187,0.513044,0.494483,1.363697,2.288846,-0.541973,0.564606,-0.161824,-2.462306,-1.284633,-1.211442,-0.122130,3.106641,-0.018089,3.061797,0.749517,-1.624903,1.471355,-2.367860,-1.580792,-0.352224,-0.380344,-0.435751,-2.304036,1.231536,0.700240,-1.105132,-0.044702
2,1.215444,1.695535,0.703967,0.973830,-0.354751,0.938603,-0.504604,-0.798399,0.596563,0.052290,0.274865,-0.224446,-1.404365,0.781034,2.533462,-2.219657,0.769805,-1.131046,-1.279021,-1.980291,1.889334,1.618809,0.063525,1.917970,-0.414986,-2.355111,-1.537288,1.018568,-1.154796,0.712420,0.097756,-0.821441,0.432416,0.842387,1.498330,-0.557821,0.608409,1.380282,-0.828706,-0.900027,...,1.143994,-1.063845,-0.924260,3.000073,1.812496,1.445639,1.618661,-0.698550,0.058844,-1.935270,0.302517,-0.825807,-1.158709,-1.102472,-1.584448,0.294946,1.356493,-1.324837,-1.911882,-1.031229,1.168112,-1.408330,-0.095454,0.486864,2.467602,-0.697069,1.115962,-0.137960,-1.262290,0.252917,-1.316303,0.742727,-1.074956,0.183158,-2.139648,-0.032420,0.568077,0.222712,-0.504295,-1.297538
3,3.028320,0.979472,0.042140,-0.257693,-2.112701,0.156050,0.607013,0.855045,0.897843,0.537421,-0.117037,1.940066,-1.955510,-0.047782,0.704797,-1.891490,-1.164394,2.091318,-1.096848,1.396205,-0.657765,-0.112710,-1.102690,0.020219,-2.260381,0.517397,-2.166209,-2.341950,-2.360511,-1.503836,-0.576013,-2.592741,1.720390,-0.806419,1.324085,0.092807,0.704546,-2.436803,-1.154420,-0.469121,...,-1.075142,-3.170782,-1.147929,2.426361,-0.622375,-0.086361,-1.624184,0.576353,-1.385191,-3.087209,1.813762,-1.306380,-2.040669,-1.100465,1.095960,2.708210,2.896514,-0.948269,1.119571,-0.046128,-1.165171,-1.920288,-1.781693,-0.198736,2.484753,0.673247,1.070982,1.813790,0.103914,1.361632,-1.414584,-2.095924,-0.549448,-0.202776,-1.663242,-2.648712,0.919777,1.948047,0.148528,0.936682
4,0.525346,-1.349836,-0.019433,0.213419,-0.627803,-0.100843,-1.413438,2.548865,1.445349,1.046214,0.482103,-0.308023,-1.549433,1.316298,0.827802,-2.863286,-3.000908,1.577743,-1.876366,1.438588,-2.304731,-0.364120,-1.646048,-2.596822,-1.804161,2.458990,-1.995863,-1.514827,-2.119350,-1.988471,-3.156528,-1.381348,2.156200,-1.564521,2.114737,0.566602,2.146625,-2.417741,-2.725510,0.780879,...,-2.270933,-3.015146,-0.122677,2.234872,-1.448798,2.682798,-0.696697,1.676100,-2.002702,-1.423327,2.056692,-0.124541,-1.347327,1.487660,1.772339,1.465119,2.021623,-0.732103,-1.185861,1.460667,-1.350118,-2.997282,-1.532703,1.392549,1.546805,-0.719284,1.693005,1.538164,-0.560690,4.154436,-2.964317,-1.457283,-2.769279,-0.505201,-0.464938,-1.487550,-1.442978,1.070293,-2.468941,0.347724
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21634,-11.303205,5.615102,6.172951,-9.769722,8.741179,5.868194,5.046060,-11.340105,-11.380955,-9.319210,-7.434369,-5.888227,6.089105,6.758451,6.699721,10.505979,6.465243,-6.831612,-3.365293,-7.211526,7.044685,-5.801210,8.278599,3.362726,12.692915,-4.725628,-1.420968,11.400815,0.387438,6.971975,-1.801395,4.011356,-3.105956,7.508659,-8.918922,-8.051274,-1.077327,10.474669,7.546896,13.330481,...,8.025473,1.554785,8.032415,1.552525,1.879607,7.589856,-5.351361,-7.543920,10.609083,0.193940,-4.409284,-5.315081,7.667352,-0.654638,0.546198,-7.801945,2.979832,-4.204182,6.674282,-8.739772,-4.206005,1.885089,4.664048,-0.408699,1.279529,8.464443,-10.567370,-6.936335,13.398509,-7.662632,9.986078,8.105583,4.648882,10.182254,-5.264226,5.972640,3.031159,-9.490811,0.087133,1.045691
21635,-9.123408,2.804004,8.111215,-12.364523,7.996700,5.090131,6.384393,-10.584096,-11.179126,-7.189492,-7.739944,-4.946179,6.573339,6.364190,4.480616,9.470349,5.918701,-7.442729,-3.740669,-9.995952,7.837955,-3.970417,7.213356,4.790495,11.450585,-6.471291,-3.495595,12.639543,-1.037396,7.363169,-1.516567,3.126568,-3.832046,6.927645,-7.744203,-7.276289,-0.869065,9.229241,9.643314,10.681025,...,10.087626,1.969173,10.163468,1.954554,2.990142,5.981826,-5.474865,-7.935812,9.543193,0.123811,-4.652300,-4.651728,9.106400,-3.031849,-0.664036,-6.318029,2.333277,-4.049113,4.636392,-8.990394,-4.414220,-1.021333,5.734004,-0.049520,0.644961,6.694693,-10.426223,-8.720874,11.186080,-5.754755,9.126072,6.359188,4.232862,9.080841,-5.694338,5.790191,4.774803,-8.202873,0.733221,-0.884144
21636,-12.715288,3.733294,7.960297,-10.445810,7.174436,5.419531,4.349464,-10.734626,-9.516499,-9.154610,-8.643875,-3.589076,7.638786,6.044703,5.576465,12.067415,6.160419,-6.179085,-6.104634,-8.646148,9.128655,-3.742596,9.931428,4.640338,12.062830,-4.516827,-2.318299,11.093092,-3.391605,7.310643,-0.599780,5.487050,-2.772444,6.414431,-9.061742,-6.649461,-0.535908,8.349678,9.519620,12.203557,...,8.608383,3.129779,10.252966,1.051087,1.803296,7.693744,-6.579625,-6.409641,10.528800,0.287034,-3.864629,-4.740696,8.639822,-1.828822,-0.857889,-5.910066,1.746161,-5.183580,5.646160,-8.396484,-6.458504,0.509669,6.057875,-2.770545,-0.671455,7.844072,-10.866656,-8.455945,12.742754,-7.841545,9.098660,6.450333,5.126356,8.248777,-4.971012,4.168180,3.462578,-8.175447,-0.436564,0.120593
21637,-9.813241,4.370050,7.051245,-10.836305,8.856269,4.481215,6.209445,-9.832881,-10.965780,-7.103457,-9.672645,-3.859169,8.445089,6.020241,4.718162,10.750411,6.901471,-4.504827,-4.590376,-11.723133,7.521195,-3.445080,8.399127,4.054757,11.293388,-7.408207,-3.011177,9.182123,-3.314842,7.214219,-2.194689,3.602414,-4.072712,7.339193,-9.554373,-6.314574,0.359389,7.770813,8.921612,12.085795,...,7.421478,2.517490,9.447018,1.712163,0.803138,6.056864,-5.032150,-8.689059,9.031939,-1.200055,-3.569475,-5.291827,7.455212,-1.939491,-1.876697,-6.236553,-0.202723,-4.872121,5.775003,-5.739169,-4.903299,-1.901483,8.355788,-0.660290,0.281076,6.496253,-11.399314,-8.610718,13.164948,-7.393217,10.758722,8.522037,5.411080,6.507087,-5.129093,5.160239,4.306411,-10.584814,-0.314322,0.767515


In [32]:
# Save the movie embedding DataFrame to Drive as a pickle.
movie_embeddings.to_pickle("drive/MyDrive/Colab Notebooks/data/nn_embeddings_new_2.pkl")