In [1]:
from matrix_factorization import BaselineModel, KernelMF, train_update_test_split

import pandas as pd
from sklearn.metrics import mean_squared_error

# Movie data found here https://grouplens.org/datasets/movielens/
# Ratings file: tab-separated. Only the first three columns are loaded; the
# timestamp column is named here but excluded through `usecols`.
cols = ["user_id", "item_id", "rating", "timestamp"]
movie_data = pd.read_csv(
    "u.data", names=cols, sep="\t", usecols=[0, 1, 2], engine="python"
)

# Item file: pipe-separated. Only the id and the title columns are loaded.
# NOTE(review): the MovieLens 100K `u.item` file is presumably Latin-1 encoded;
# pass `encoding="latin-1"` if titles with accents come out garbled -- confirm.
cols2 = ["item_id", "item_name"]
movie_names = pd.read_csv(
    "u.item", names=cols2, sep="|", usecols=[0, 1], engine="python"
)

# Attach the title to every rating. A left join keeps exactly one row per
# rating; an outer join would also emit rows with NaN user_id / rating for any
# title that has no ratings, and those rows would leak into the split below.
# `validate` makes the merge fail loudly if an item_id appears twice in
# `movie_names`, which would otherwise silently duplicate ratings.
# The variable name is kept so that other cells referring to it keep working.
df_outer = pd.merge(
    movie_data, movie_names, on="item_id", how="left", validate="many_to_one"
)

# Convenience views of the merged data. They are not used by the split below,
# which receives the full merged frame.
X = df_outer[["user_id", "item_name"]]
y = df_outer["rating"]

# Split data for learning: an initial training set, plus an update/test pair
# for the fraction of users held out as "new" users.
(
    X_train_initial,
    y_train_initial,
    X_train_update,
    y_train_update,
    X_test_update,
    y_test_update,
) = train_update_test_split(df_outer, frac_new_users=0.2)

# Initial training
matrix_fact = KernelMF(n_epochs=20, n_factors=20, verbose=1, lr=0.001, reg=0.005)
matrix_fact.fit(X_train_initial, y_train_initial)

# Update model with new users
matrix_fact.update_users(
    X_train_update, y_train_update, lr=0.001, n_epochs=20, verbose=1
)
pred = matrix_fact.predict(X_test_update)

# The `squared=False` keyword of mean_squared_error was deprecated in
# scikit-learn 1.4 and removed in 1.6; taking the square root of the MSE
# explicitly gives the same RMSE on every version.
rmse = mean_squared_error(y_test_update, pred) ** 0.5
print(f"\nTest RMSE: {rmse:.4f}")

# Get recommendations
user = 189
# The user may have landed in either the initial or the update training set,
# so collect the items they already rated from both before recommending.
items_known = pd.concat([X_train_initial, X_train_update]).query(
    "user_id == @user"
)["item_id"]

matrix_fact.recommend(user=user, items_known=items_known)




Epoch  1 / 20  -  train_rmse: 1.075977554752216
Epoch  2 / 20  -  train_rmse: 1.0443785553314529
Epoch  3 / 20  -  train_rmse: 1.0227139084617995
Epoch  4 / 20  -  train_rmse: 1.0070379600392083
Epoch  5 / 20  -  train_rmse: 0.9951675056922648
Epoch  6 / 20  -  train_rmse: 0.9858290656984885
Epoch  7 / 20  -  train_rmse: 0.9782459270834505
Epoch  8 / 20  -  train_rmse: 0.9719347382647087
Epoch  9 / 20  -  train_rmse: 0.9665716720910427
Epoch  10 / 20  -  train_rmse: 0.9619287566067772
Epoch  11 / 20  -  train_rmse: 0.9578498596062578
Epoch  12 / 20  -  train_rmse: 0.9542230504081175
Epoch  13 / 20  -  train_rmse: 0.9509621157989471
Epoch  14 / 20  -  train_rmse: 0.9480035927360065
Epoch  15 / 20  -  train_rmse: 0.9452993575843485
Epoch  16 / 20  -  train_rmse: 0.9428087112759618
Epoch  17 / 20  -  train_rmse: 0.9405028790419802
Epoch  18 / 20  -  train_rmse: 0.9383566463459216
Epoch  19 / 20  -  train_rmse: 0.9363476936785782
Epoch  20 / 20  -  train_rmse: 0.9344585398531063
Epoch  1 /

Unnamed: 0,user_id,item_id,rating_pred
44,189,12,4.758838
337,189,64,4.688385
417,189,408,4.645301
66,189,272,4.591768
54,189,169,4.588638
69,189,98,4.561279
128,189,357,4.514342
284,189,302,4.486102
196,189,515,4.470198
87,189,427,4.459109


In [7]:
# Show the first component of every entry in the fitted model's item features,
# one value per line.
for item_vector in matrix_fact.item_features:
    first_component = item_vector[0]
    print(first_component)

0.1801954706368779
0.017277040741413053
0.0948449359538893
0.02608941553884216
-0.04064570897390858
0.04240479209959975
-0.03601127012167758
-0.010092113842582419
-0.01200317195177999
-0.06562917877341251
-0.17401865360376978
-0.1327925535659596
0.16888669230422657
0.00860763643687947
-0.00616620582939435
0.05429356787683381
0.11159770034764893
-0.07253604985831627
0.0650008865442658
0.11488632540178141
0.03718822952459383
-0.0567844055717825
0.14071752539080631
0.008837643962751083
-0.10373806242106433
0.07537102115150375
0.01798970726037658
-0.1362591188394343
0.18234416700111453
0.11907176020538233
0.020018177632210424
-0.021683155978316347
0.08113123515464427
0.03874388845990045
0.05292292809723514
-0.06111990778524852
0.008832455879646907
0.1705588792667721
-0.1334913317501972
0.03918568447661011
0.004942323270354779
0.15079136516243044
-0.02743463014824802
-0.036738475870584764
-0.08136319815731437
0.1094337706238884
-0.008451861988579269
0.104889981610805
0.18717975906573506
0.1

0.032524232185721036
-0.16197176349802578
0.16038426909178305
-0.11637947172959728
-0.19569587260279292
0.09634979099959291
-0.03941368115745308
-0.10942554975018165
0.0061888311398245795
-0.11433464900255574
-0.003775403641032757
-0.07009080830014326
0.16313409358797498
-0.265537986085222
0.035925408684312904
0.07878129157181182
-0.1700671616091005
-0.05265763669326539
-0.002724050062180999
0.1582593958066154
0.0066026997850012785
-0.09408871611004677
0.029035840311799087
-0.11047987447530323
0.14118248951463977
-0.01901910893920593
-0.09582037392035148
-0.12274154739500431
0.0231255932987355
0.13259457087096718
-0.04681383008078467
-0.11020512101020197
0.03709631759428879
-0.13204887145375868
0.0247073014076454
-0.015541423118127918
0.04513525224915122
-0.018819854708921193
-0.11863461358039085
0.0827793744706836
0.13058776564550975
0.1487530431030996
-0.030321896453839668
0.019892464121167006
-0.07116803867325637
0.11119208059266106
-0.08550794212894029
0.08887297539561775
-0.083969