# User-user neighborhood model

This notebook present the user-user neighborhood model training and testing

Note: Because of the dataset size, we chose to adapt the model following these steps:

__Preprocessing:__ we replace the plays by binary interactions (1 if any plays, 0 otherwise). 


__Pipeline:__


1. For each user, compute a set of neighbors and compute the similarity. 
2. Using the neighbors, we compute predictions for the number of plays for a given artist.

__Summary__

1.[Load datasets](#1.-Load-datasets)
- Load train and test datasets. 

2.[Compute list of user neighbors](#2.-Find-user-neighbors)
- For each user, using the binary interactions, compute a list of neighbors with similarity score

3.[Compute predictions](#Compute-predictions)
- Given list of neighbors, compute predictions for a given user. 

4.[Evaluate model](#Evaluate-model)
- Compute metrics for a small subset of users to evaluate model predictions

5.[Fine tune model](#Fine-tuning)
- Compare different similarity score and prediction formula

In [1]:
import os
import sys
from tqdm.auto import tqdm
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import time

from IPython.display import clear_output

sys.path.append('src')
from train_test import train_test_split

from metrics import compute_metrics

tqdm.pandas()
%load_ext autoreload
%autoreload 2

## 1. Load datasets

In [2]:
DATA_PATH = './data.nosync/lastfm-dataset-360K/'

In [38]:
# Load data files
train = pd.read_csv(DATA_PATH + 'train.csv')
valid = pd.read_csv(DATA_PATH + 'valid.csv')
test = pd.read_csv(DATA_PATH + 'test.csv')
lastfm_360_behav = pd.read_csv(DATA_PATH + 'behav-360k-processed.csv')
lastfm_360_demo = pd.read_csv(DATA_PATH + 'demo-360k-processed.csv')
lastfm_360_demo = lastfm_360_demo.set_index('user_email')
test_users = np.load(DATA_PATH + 'test_users.npy')

In [4]:
train.shape, valid.shape, test.shape

((5644266, 3), (620749, 3), (30022346, 3))

In [6]:
train

Unnamed: 0,user_email,artist_id,rating
0,43941,68275,0.000000
1,62958,82754,3.429031
2,1931,1719,3.828571
3,52024,882,3.625399
4,5085,910,1.753127
...,...,...,...
5644261,61713,142,0.193042
5644262,63257,10599,0.000000
5644263,13378,302,0.021553
5644264,42288,17,0.313919


## 2. Split users 

The dataset is too big to compute the pairwise correlation. In order to fix this, we split the dataset based on users into different groups using age chunks. 

We try to make the groups smaller than 20k users. 


In [7]:
users = train['user_email'].unique()

In [8]:
# Return index given user age
age_columns = lastfm_360_demo.columns[3:93]
def get_age_chunk_index(user_email:int, chunk_size=4):
    user_age_vals = lastfm_360_demo.loc[user_email, age_columns]
    user_age = [i for i, v in user_age_vals.items() if v]
    
    if len(user_age) != 1:
        return -1
    
    min_age = int(float(user_age_vals.index[0]))
    age = int(float(user_age[0]))
    
    return (age - min_age) // chunk_size 

In [9]:
def compute_groups(df):
    users = df['user_email'].unique()
    # Compute chunk index given user email
    user_to_chunk = {u:get_age_chunk_index(u) for u in  tqdm(users)}
    
    # Get the chunk indices
    chunks = set(user_to_chunk.values())

    # Compute dict of chunks to users
    chunk_users = {}
    for c in chunks:
        chunk_users[c] = [u for u, v in user_to_chunk.items() if v == c]
        
    # Compute list of users group of less than 20k users (maximum for our correlation matrix).
    groups = []
    current_group = []
    
    chunks.remove(-1) # Remove the nan age group (added as an individual group)
    
    for i in chunks: # Build group of less than 20k users
        if len(current_group) > 10000 or (len(chunk_users[i]) > 10000): # If length above 10k, reset current group
            groups.append(current_group)
            current_group = []
        
        current_group.extend(chunk_users[i])
    groups.append(chunk_users[-1]) # Add the group of users without age 
    return groups

In [10]:
train_groups = compute_groups(train)

HBox(children=(FloatProgress(value=0.0, max=66928.0), HTML(value='')))




In [11]:
[len(train_groups[i]) for i in range(len(train_groups))]

[9079, 17226, 10001, 11977, 15131]

## 2. Find user neighbors

Because of the big size of the dataset 67k users, the pairwise correlation cannot be compute on all pairs, therefore, we are going to chunk the users dataset using the 'age' demographic parameter. We are going to split the users in chunk of 5 years. 

In [12]:
len(train['artist_id'].unique())

84497

In [13]:
def get_top_k_user_scores(corr_df, user, k=100):
    user_corr = corr_df.loc[user]
    top_k_indices = user_corr.values.argsort()[-(k+1):][::-1][1:]
    top_k_users = user_corr.index[top_k_indices]
    return list(user_corr[top_k_users].items())
    

In [14]:
def filter_artists(train_df, artist_threshold:int=100, verbose:bool=False):
    selected_artists = train_df['artist_id'].value_counts()
    selected_artists= selected_artists[selected_artists > artist_threshold].index.values
    if verbose: print(f"Number of selected artists: {len(selected_artists)}")
    return selected_artists

In [15]:
def compute_correlations(train_df, user_groups, artist_threshold:int=100, k_neighbors:int=100, binary_scores:bool=True, verbose=False):
    
    # In order to reduce the computation time, we make predictions from only a subset of artists. 
    # The subset contains artists listened by more than 100 users.
    selected_artists = filter_artists(train_df, artist_threshold, verbose)
    
    # Make sure there are no duplicates
    my_df = train_df.groupby(['user_email', 'artist_id'], as_index=False).sum()
    # Select row with selected artists
    my_df = my_df[my_df['artist_id'].isin(selected_artists)]
    
    # Select binary score or ratings
    score_col = 'binary' if binary_scores else 'rating'
    
    if score_col == 'binary':
        my_df['binary'] = my_df['rating'].apply(lambda x: 1 if x > 0 else 0)
        
    # Iterate over the user groups to reduce the correlation matrix dimention
    user_to_neighbors = {}
    groups_size = [len(train_groups[i]) for i in range(len(train_groups))]
    if verbose: print(f"User groups size: {groups_size}")
    for user_group in tqdm(user_groups):
        start = time.time()
        my_df_small = my_df[my_df['user_email'].isin(user_group)] # Get data for given user group
        # Build my with scores
        my_df_small = my_df_small.pivot(index='user_email', columns='artist_id', values=score_col).fillna(0)
        # Compute correlation matrix
        corr = np.corrcoef(my_df_small.values)
        corr_df = pd.DataFrame(data = corr, index=my_df_small.index, columns=my_df_small.index).fillna(0)
        end = time.time()
        if verbose: print(f"Correlation matrix computation: {end - start} seconds.")
        # Get top k neighbors
        for user in tqdm(corr_df.index):
            user_to_neighbors[user] = get_top_k_user_scores(corr_df, user)
    
    neighbors_df = pd.DataFrame({'user_email':user_to_neighbors.keys(), 
                                 'neighbors':user_to_neighbors.values()}).set_index('user_email')
    return neighbors_df
    

In [16]:
model = compute_correlations(train, train_groups, verbose=True)

Number of selected artists: 5354
User groups size: [9079, 17226, 10001, 11977, 15131]


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

  c /= stddev[:, None]
  c /= stddev[None, :]


Correlation matrix computation: 7.676588773727417 seconds.


HBox(children=(FloatProgress(value=0.0, max=9076.0), HTML(value='')))


Correlation matrix computation: 30.61698627471924 seconds.


HBox(children=(FloatProgress(value=0.0, max=17223.0), HTML(value='')))


Correlation matrix computation: 9.967627048492432 seconds.


HBox(children=(FloatProgress(value=0.0, max=10001.0), HTML(value='')))


Correlation matrix computation: 12.80122685432434 seconds.


HBox(children=(FloatProgress(value=0.0, max=11974.0), HTML(value='')))


Correlation matrix computation: 20.19828200340271 seconds.


HBox(children=(FloatProgress(value=0.0, max=15126.0), HTML(value='')))





In [22]:
# Save the model
model.to_csv("user_neighborhood_model.csv", index=False)
# Load the model 
#model = pd.read_csv("user_neighborhood_model.csv", index_col=None)

## 3. Compute user predictions


### Edge cases:

- 1. If the user is not in the model (in case all the listened artists are not in the selected artists)
    -> Predictions are 0 for all artists
- 2. For a given artist: If none of the neighbors has listened to it: predict 0

Note: We do not count the neighbors that do not have listened to the artist. 

Issue: How to we take in account the fact that the artist is not popular within the neighbors set. 

In [32]:
selected_artists = filter_artists(train, artist_threshold=300)

In [34]:
def get_neighbors_score(neighbors, n_df, artist_id):
    res_df = pd.DataFrame(neighbors, columns=['user_email'])
    a_df = n_df[n_df['artist_id'] == artist_id]
    res_df = res_df.merge(a_df, how='left', on='user_email')[['user_email', 'rating']].fillna(0).set_index('user_email')
    return res_df
    

In [35]:
def compute_var(remaining_n_rating, n_corr_df, avg_n_rating):
  
    var = 0
    corr_sum = 0
    
    for n in remaining_n_rating.index:
        corr_sum += n_corr_df.loc[n, 'corr']
        
        var += n_corr_df.loc[n, 'corr'] * (remaining_n_rating.loc[n, 'rating'] - avg_n_rating.loc[n, 'rating'])
        
    var = var / corr_sum
    
    #print(var)
    return var

In [36]:
def compute_predictions(train_df, user_email, selected_artists,  model, verbose=False):
    
    if user_email not in model.index: # If the user had no artists in the selected artists, return 0. 
        return [0]*selected_artists
    
    neighbors = model.loc[user_email]
    
    neighbors_email = np.array([n[0] for n in neighbors[0]])
    neighbors_corr = np.array([n[1] for n in neighbors[0]])
    
    n_corr_df = pd.DataFrame(data={'user_email':neighbors_email, 'corr':neighbors_corr}).set_index('user_email')
    
    
    neighbors_behav = train_df[train_df['user_email'].isin(neighbors_email)]
    
    avg_n_rating = pd.DataFrame(neighbors_behav.groupby('user_email')['rating'].mean())
    
    avg_user_rating = train_df.loc[train_df['user_email'] == user_email, 'rating'].mean()
    
    results = []
    
    for a in tqdm(selected_artists): 
        neighbors_rating = get_neighbors_score(neighbors_email, neighbors_behav, a)
        
        remaining_n_rating = neighbors_rating[neighbors_rating['rating'] > 0]
        
        if not len(remaining_n_rating) == 0:
            
            result = avg_user_rating + compute_var(remaining_n_rating, n_corr_df, avg_n_rating)
            
            if np.isnan(result):
                result = 0
            
        else:
            result = 0 # No neighbors have listened to the artist, return 0. 
            
        results.append(result)
        
        
    return results
    

In [37]:
compute_predictions(train, 10, selected_artists, model)

HBox(children=(FloatProgress(value=0.0, max=1810.0), HTML(value='')))




[2.4806858204751685,
 1.9535931081797093,
 1.3733997644466616,
 1.148740264741722,
 1.738766710488829,
 1.6911434765651554,
 1.6659272560658245,
 2.227210474797293,
 1.3817642927066078,
 1.5130840937868093,
 1.4898494648498226,
 1.544750031216121,
 0,
 1.3948442998974089,
 1.8490550268692933,
 1.5427332137479217,
 1.4612086283181467,
 0.9780663457074996,
 0.8364110448654098,
 1.607998766389449,
 1.1328521600768218,
 1.1371491454005906,
 1.7186772902948844,
 1.3123532829958107,
 1.6429687033688092,
 2.2015162922030105,
 1.1370647510129661,
 1.060846228605649,
 1.667942449536413,
 1.264158248622901,
 0.8739020643419618,
 0.7871438425810242,
 1.291929532016189,
 1.3728216195609502,
 2.255337769688647,
 0,
 0.6278200075698116,
 1.7679889968248172,
 0,
 2.157588829206044,
 1.5682343616043355,
 0.9654538215906876,
 1.615820091897218,
 1.271992455272434,
 0.7381308425528367,
 1.6368190931876843,
 1.8895327403905782,
 1.2757099344594316,
 0.8314791191954635,
 1.4943380685294305,
 0.98201475951

In [40]:
len(test)

30022346

In [44]:
train_subset = train[train['user_email'].isin(test_users)]

In [57]:
test_small = test[test['user_email'].isin(test_users)]

In [63]:
test_small.shape

(81918, 3)

In [67]:
len(test_small['user_email'].unique())

184

In [72]:
pred_ratings_dict = {}
true_dict = {}
for user, user_df in tqdm(test_small.groupby('user_email')):
    artists = user_df['artist_id'].values
    pred_ratings_dict[user] = [artists, np.array(compute_predictions(train, user, artists, model))]
    true_dict[user] = [artists, user_df['rating'].values]

HBox(children=(FloatProgress(value=0.0, max=184.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=309.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=491.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=501.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=573.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=36.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=546.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=501.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=546.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=146.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=337.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=601.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=501.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=501.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=537.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=546.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=355.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=491.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=573.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=537.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=546.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=528.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=582.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=555.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=455.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=337.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=519.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=510.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=473.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=464.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=364.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=409.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=482.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=683.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=564.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=419.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=446.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=437.0), HTML(value='')))





In [73]:

# Evaluate the model 
from sklearn.metrics import mean_squared_error

def rmse(y_true, y_pred):
    return mean_squared_error(y_true, y_pred, squared=False)

In [80]:
rmse_arr = []
for user in pred_ratings_dict:
    u_true = true_dict[user][1]
    u_pred = pred_ratings_dict[user][1]
    rmse_arr.append(rmse(u_true, u_pred))
    
print(f"Average RMSE: {np.mean(rmse_arr)}")

Average RMSE: 0.20082133414236836


In [106]:
for k in pred_ratings_dict:
    pred_ratings_dict[k]= np.stack(pred_ratings_dict[k])

In [112]:
k = 10
_, _, _, _, _  = compute_metrics(test.drop(test[test.rating == 0].index), test_users, pred_ratings_dict, k)

Computing precision & recall...


HBox(children=(FloatProgress(value=0.0, max=183.0), HTML(value='')))


Computing normalized discounted cumulative gain...


HBox(children=(FloatProgress(value=0.0, max=183.0), HTML(value='')))


Computing hit rate...


HBox(children=(FloatProgress(value=0.0, max=183.0), HTML(value='')))


Computing average reciprocal hit ranking...


HBox(children=(FloatProgress(value=0.0, max=183.0), HTML(value='')))



Metrics: 

Precision @ 10: 0.35081967213114756
Recall    @ 10: 0.7171155347384854
Ndcg @ 10: 0.6063062869201542
Hit rate: 3.480874316939891
Arhr: 1.1327131581229943
