In [403]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

**Reading data.** 

In [404]:
df = pd.read_csv("jester-data-1.csv", header = None)

In [405]:
df.shape

(24983, 101)

**Checking missing values.** 

In [406]:
# as provided in data description we know that Ratings are real values ranging from -10.00 to +10.00 (the value "99" corresponds to "null" = "not rated").


In [407]:
number_of_user_ratings = df[0]

In [408]:
df = df[range(1, 101)].replace(99, np.nan)

In [409]:
df[0] = number_of_user_ratings

In [410]:
df.isnull().sum().sum() / df.shape[0]

27.532522115038226

**Dropping duration column.** 

In [411]:
# according to description this column doesn't contains the ratings
df.drop(0, axis = 1, inplace=True)

**Total number of best and worst jokes using pandas.** 

In [412]:
finding_best_rated_jokes = pd.DataFrame([{x: df[df[x] > 0].shape[0] for x in df.columns}]).T

In [413]:
finding_best_rated_jokes.sort_values(0, ascending=True).iloc[:10]

Unnamed: 0,0
74,3290
58,3504
71,3907
75,4343
79,4715
86,5088
77,5093
73,5207
84,5355
80,5439


In [415]:
df[95][3884]

nan

In [416]:
# From above are best jokes top-10

**Getting train, test and validation set** 

In [417]:
test_data = df[df.isnull().any(axis=1)]

In [418]:
validation_data = df.dropna().iloc[:int(df.dropna().shape[0] * 0.20)]

In [419]:
df_ =  df.dropna().iloc[int(df.dropna().shape[0] * 0.20):]

**Training** 

In [440]:
n_latent_factors = 12

user_ratings = df.values
# Initialise as random values
latent_user_preferences = np.random.random((user_ratings.shape[0], n_latent_factors))
latent_item_features = np.random.random((user_ratings.shape[1], n_latent_factors))

In [428]:
n_latent_factors = 2

user_ratings = df_.values
# Initialise as random values
latent_user_preferences = np.random.random((user_ratings.shape[0], n_latent_factors))
latent_item_features = np.random.random((user_ratings.shape[1], n_latent_factors))

In [340]:
def predict_rating(user_id, item_id):
    """ Predict a rating given a user_id and an item_id.
    """
    user_preference = latent_user_preferences[user_id]
    item_preference = latent_item_features[item_id]
    return user_preference.dot(item_preference)


def train(user_id, item_id, rating, alpha=0.0001):
    #print(item_id)
    predicted_rating = predict_rating(user_id, item_id)
    err =  predicted_rating - rating
    #print(err)
    user_pref_values = latent_user_preferences[user_id]
    latent_user_preferences[user_id] -= alpha * err * latent_item_features[item_id]
    latent_item_features[item_id] -= alpha * err * user_pref_values
    return err
    

def sgd(iterations):
    """ Iterate over all users and all items and train for 
        a certain number of iterations
    """
    mse_history = []
    for iteration in range(iterations):
        error = []
        for user_id in range(latent_user_preferences.shape[0]):
            for item_id in range(latent_item_features.shape[0]):
                rating = user_ratings[user_id, item_id]
                if not np.isnan(rating):
                    err = train(user_id, item_id, rating)
                    error.append(err)
        mse = (np.array(error) ** 2).mean()   
        if (iteration % 10000) == 0:
            print('Iteration %d/%d:\tMSE=%.6f' % (iteration, iterations, mse))
            mse_history.append(mse)
    return mse_history

In [341]:
num_iter = 5
hist = sgd(num_iter)  # Note how the MSE decreases with the number of iterations

Iteration 0/5:	MSE=26.774039


In [342]:
validation_data.shape

(1440, 100)

**Calculating mse for validation set** 

In [343]:
error = []
for user_id in validation_data.index:
    for item_id in range(validation_data.shape[1]):
        rating = user_ratings[user_id, item_id]
        if not np.isnan(rating):
            predicted_rating = predict_rating(user_id, item_id)
            err =  predicted_rating - rating
            error.append(err)

In [344]:
mse = (np.array(error) ** 2).mean()   

In [481]:
mse

23.711697056362453

**Getting prediction for test set** 

In [476]:
latent_user_preferences_test = np.random.random((test_data.shape[0], n_latent_factors))
latent_item_features_test = np.random.random((test_data.shape[1], n_latent_factors))

In [477]:
predictions = latent_user_preferences_test.dot(latent_item_features_test.T)
predictions

array([[3.97208155, 2.89785495, 3.34519859, ..., 3.17171666, 3.10653541,
        2.85064022],
       [4.26773193, 2.93315288, 3.56060188, ..., 2.78452089, 2.92609719,
        2.37742195],
       [3.4524392 , 3.13627038, 3.67852572, ..., 3.03622653, 3.54800912,
        3.04389289],
       ...,
       [3.55201386, 3.22130086, 3.37561295, ..., 3.32815742, 3.65200498,
        2.92732871],
       [4.00943293, 2.49203718, 2.8445397 , ..., 2.66441896, 2.26682102,
        2.38543605],
       [4.29930357, 2.86183994, 3.79706205, ..., 3.37155174, 3.62412604,
        2.74033903]])

In [478]:
test_data.shape

(17783, 100)

In [479]:
values = [zip(test_data.values[i], predictions[i]) for i in range(predictions.shape[0])]
comparison_data = pd.DataFrame(values)
comparison_data.columns = test_data.columns

In [480]:
comparison_data

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,91,92,93,94,95,96,97,98,99,100
0,"(-7.82, 3.9720815503913824)","(8.79, 2.8978549527682245)","(-9.66, 3.345198587536779)","(-8.16, 3.590587629169965)","(-7.52, 3.0188449409167917)","(-8.5, 3.517031509543339)","(-9.85, 2.310483483877972)","(4.17, 3.081748679125996)","(-8.98, 2.6968006276678707)","(-4.76, 3.0093927738499913)",...,"(2.82, 2.53475698654882)","(nan, 3.7219607188192665)","(nan, 3.521688261188994)","(nan, 3.5896631823887746)","(nan, 1.9085887821997192)","(nan, 2.2385185540495796)","(-5.63, 4.472758013364635)","(nan, 3.171716659117797)","(nan, 3.106535414139134)","(nan, 2.850640216878193)"
1,"(nan, 4.267731926087249)","(nan, 2.933152882309649)","(nan, 3.560601881264577)","(nan, 3.943770416077405)","(9.03, 2.973147530792298)","(9.27, 3.0843741060056233)","(9.03, 2.590439159086898)","(9.27, 3.323204777100451)","(nan, 2.6347772138289365)","(nan, 2.959260301614482)",...,"(nan, 2.9085807416979805)","(nan, 3.5940298603179004)","(nan, 3.8501486418483784)","(9.08, 3.300192969169696)","(nan, 1.9345375558643394)","(nan, 2.4263009237457167)","(nan, 5.033129425223906)","(nan, 2.7845208867368485)","(nan, 2.9260971940300573)","(nan, 2.377421947782539)"
2,"(nan, 3.452439198492598)","(8.35, 3.136270383889947)","(nan, 3.678525715236008)","(nan, 3.7276655550307067)","(1.8, 3.275476001382596)","(8.16, 4.003054748529952)","(-2.82, 2.6316126149740984)","(6.21, 2.8819134161188593)","(nan, 2.581905231419659)","(1.84, 2.999018676341429)",...,"(nan, 2.238884351075046)","(nan, 3.1236980668651384)","(nan, 3.4489033739822195)","(0.53, 3.154429923721124)","(nan, 2.2353705459306594)","(nan, 2.3758362649441027)","(nan, 4.6497978645517)","(nan, 3.0362265251052487)","(nan, 3.548009118287445)","(nan, 3.043892890646703)"
3,"(8.5, 3.055762207857918)","(4.61, 1.7499081579531228)","(-4.17, 2.7704036468965434)","(-5.39, 2.7548906509323863)","(1.36, 2.2050470829085254)","(1.6, 2.4840917673833904)","(7.04, 1.961874225738842)","(4.61, 2.505459117497918)","(-0.44, 2.076600258371243)","(5.73, 2.428746792074193)",...,"(5.19, 2.0692042671201065)","(5.58, 2.417904303220456)","(4.27, 2.7333386350528737)","(5.19, 2.3764305613658268)","(5.73, 1.585325101809795)","(1.55, 2.0243497211574617)","(3.11, 4.075322119901159)","(6.55, 2.3769949033759654)","(1.8, 2.4701556478454867)","(1.6, 2.293988422460752)"
4,"(nan, 2.486414980833757)","(nan, 2.4291438183636394)","(nan, 2.0406195669553724)","(nan, 2.471214880834291)","(8.59, 1.6633374880264116)","(-9.85, 1.714999564888002)","(7.72, 1.5609769453925144)","(8.79, 2.48655122847654)","(nan, 1.3970465344376084)","(nan, 1.5466770329089548)",...,"(nan, 1.5764251737770583)","(nan, 2.213146174635104)","(nan, 2.4814369478422003)","(nan, 2.416111041108622)","(nan, 0.9644516206816417)","(2.33, 1.8930170713370138)","(nan, 3.028774740529995)","(nan, 1.8287723003422247)","(nan, 1.6957745996799949)","(nan, 1.4527062426779815)"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17778,"(nan, 4.253231645713874)","(nan, 3.2868156024788107)","(nan, 3.896236168198102)","(nan, 3.928994597330417)","(7.67, 2.5695515464779053)","(nan, 3.613275216294128)","(1.02, 2.6541756459384542)","(-3.74, 3.814125780534466)","(nan, 2.7898593016177595)","(nan, 2.9889648079730864)",...,"(nan, 2.9198993803285815)","(nan, 3.371078988892632)","(nan, 3.879600545548267)","(nan, 3.191165558835543)","(nan, 2.343131474065769)","(nan, 2.821718982454611)","(nan, 5.263888429339908)","(nan, 3.2545961718247156)","(nan, 3.2349694015526995)","(nan, 2.546507312952199)"
17779,"(9.13, 3.534350694090974)","(-8.16, 3.227393342340868)","(8.59, 3.4657961468700287)","(9.08, 4.463647198033294)","(0.87, 3.2897840473168256)","(-8.93, 3.8580349909441245)","(-3.5, 3.1222038449292286)","(5.78, 3.4689152546848785)","(-8.11, 2.7519461446381634)","(4.9, 2.725965601699293)",...,"(-1.17, 2.604338805203881)","(-5.73, 3.304119384171563)","(-1.46, 3.2889675626185246)","(0.24, 3.337223944672344)","(9.22, 1.9239012411037413)","(-8.2, 2.6042700050369976)","(-7.23, 4.634679422622741)","(-8.59, 3.6864999702707943)","(9.13, 3.7036245312554628)","(8.45, 2.6449928670518643)"
17780,"(nan, 3.5520138618653214)","(nan, 3.2213008588845664)","(nan, 3.375612948210173)","(nan, 3.9961155060009506)","(-7.77, 3.255198316128628)","(nan, 4.077224586907583)","(6.7, 2.719243570790532)","(-6.75, 2.91029625945342)","(nan, 2.582857068175021)","(nan, 2.7846738681061534)",...,"(nan, 2.4252233063973545)","(nan, 3.5143581950845144)","(nan, 3.180064430145627)","(nan, 3.58693724997245)","(nan, 2.0823087601041257)","(nan, 1.8979766132253868)","(nan, 4.037658670060064)","(nan, 3.32815742166544)","(nan, 3.652004980156181)","(nan, 2.9273287109013157)"
17781,"(nan, 4.009432928591691)","(nan, 2.492037184987731)","(nan, 2.844539699805461)","(nan, 3.0989940727953735)","(-9.71, 2.1458025650006625)","(nan, 3.0889205877775714)","(4.56, 2.11499639208768)","(-8.3, 3.250798957377448)","(nan, 2.0364101893261792)","(nan, 3.1252839371205274)",...,"(nan, 2.3600124477777062)","(nan, 3.0840546029839064)","(nan, 3.13501880224413)","(nan, 3.0349885807524464)","(nan, 1.5470633646019638)","(nan, 1.712699247511586)","(nan, 4.1861919684252715)","(nan, 2.664418961391413)","(nan, 2.266821015191779)","(nan, 2.3854360485100097)"
