In [1]:
import pandas as pd
import numpy as np
from scipy import sparse
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import mean_squared_error
import pickle

In [2]:
with open('track_embeddings.pkl', 'rb') as f:
    track_embeddings = pickle.load(f)

In [3]:
track_embeddings

array([[-0.12850836,  0.05144424,  0.1324376 , ..., -0.0701643 ,
        -0.11606281, -0.1602558 ],
       [-0.12560181,  0.09551706,  0.10573198, ..., -0.11453907,
        -0.09670042, -0.14838946],
       [-0.13176899,  0.05360903,  0.10417848, ..., -0.09779049,
        -0.11299129, -0.15423141],
       ...,
       [-0.11718541,  0.15378292,  0.12776648, ..., -0.05870862,
        -0.10497359, -0.161321  ],
       [-0.10790123,  0.20756178,  0.10939675, ...,  0.12004782,
        -0.09874647, -0.13628687],
       [-0.12471136,  0.02359954,  0.11788659, ..., -0.09589922,
        -0.10256677, -0.17452659]], dtype=float32)

In [4]:
train_joke_df = pd.read_csv(r'..\data\recsys-in-practice\train_joke_df.csv')
test_joke_df_nofactrating = pd.read_csv(r'..\data\recsys-in-practice\test_joke_df_nofactrating.csv', index_col=0)

In [5]:
train_joke_df["UID"] = train_joke_df["UID"].astype(int)
train_joke_df["JID"] = train_joke_df["JID"].astype(int)
train_joke_df["UID"] = train_joke_df["UID"] - 1
train_joke_df["JID"] = train_joke_df["JID"] - 1

In [6]:
test_joke_df_nofactrating["UID"] = test_joke_df_nofactrating["UID"].astype(int)
test_joke_df_nofactrating["JID"] = test_joke_df_nofactrating["JID"].astype(int)

test_joke_df_nofactrating["UID"] = test_joke_df_nofactrating["UID"] - 1
test_joke_df_nofactrating["JID"] = test_joke_df_nofactrating["JID"] - 1

In [7]:
jokes_df = pd.DataFrame(track_embeddings, index=range(0, track_embeddings.shape[0]), columns=[f'j_emb_{i}' for i in range(32)])
jokes_df

Unnamed: 0,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,j_emb_8,j_emb_9,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,-0.128508,0.051444,0.132438,-0.037298,-0.201775,-0.069636,-0.113742,-0.024279,-0.029537,-0.053748,...,-0.196729,-0.140467,0.230395,-0.003536,-0.060885,-0.165479,-0.068899,-0.070164,-0.116063,-0.160256
1,-0.125602,0.095517,0.105732,-0.028557,-0.183307,-0.077972,-0.131677,0.055064,0.003441,-0.082902,...,-0.188492,-0.215398,0.215273,0.228148,-0.045374,-0.168430,-0.083771,-0.114539,-0.096700,-0.148389
2,-0.131769,0.053609,0.104178,-0.046955,-0.200002,0.016845,-0.251092,0.009414,-0.036045,-0.055575,...,-0.190417,-0.176878,0.221158,-0.021510,-0.072628,-0.171978,-0.068281,-0.097790,-0.112991,-0.154231
3,-0.104805,0.194723,0.106647,-0.044433,-0.163179,-0.154253,-0.264780,0.097174,-0.029896,-0.076351,...,-0.150741,-0.001954,0.195477,0.139092,-0.032957,-0.137452,-0.067610,-0.081869,-0.098548,-0.121358
4,-0.110405,0.015444,0.141151,-0.056691,-0.191257,0.052613,0.081804,0.049767,-0.012462,-0.087378,...,-0.190938,-0.172349,0.216513,0.333815,-0.054110,-0.173943,-0.054134,-0.013408,-0.104885,-0.163853
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,-0.142546,0.170686,0.125764,-0.063433,-0.220362,-0.036655,0.155781,0.014916,-0.052736,-0.086768,...,-0.207480,0.127158,0.234253,-0.107324,-0.099028,-0.204563,-0.098598,0.085497,-0.128372,-0.155286
96,-0.134356,0.132201,0.131486,-0.014839,-0.215539,0.029771,0.132847,0.096979,-0.090910,-0.071125,...,-0.202462,0.137410,0.239188,-0.027634,-0.074429,-0.189961,-0.078969,0.093237,-0.127856,-0.168573
97,-0.117185,0.153783,0.127766,-0.031634,-0.203616,-0.118698,0.015570,0.082631,-0.013541,-0.076063,...,-0.190877,-0.027533,0.231235,0.337730,-0.040640,-0.173552,-0.079145,-0.058709,-0.104974,-0.161321
98,-0.107901,0.207562,0.109397,-0.037075,-0.197782,-0.025386,0.032928,0.020050,-0.057390,-0.064281,...,-0.178827,0.125883,0.221786,-0.048951,-0.055854,-0.173269,-0.081489,0.120048,-0.098746,-0.136287


In [8]:
npusers_arr = np.zeros((24983, 100), dtype=np.half)

for row in tqdm(train_joke_df.values):
    npusers_arr[int(row[0]), int(row[1])] = row[2]

  0%|          | 0/1448364 [00:00<?, ?it/s]

In [9]:
users_df = pd.DataFrame(npusers_arr, index=range(0, npusers_arr.shape[0]), columns=[f'j_{i}' for i in range(100)])
users_df

Unnamed: 0,j_0,j_1,j_2,j_3,j_4,j_5,j_6,j_7,j_8,j_9,...,j_90,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99
0,-7.820312,8.789062,-9.656250,-8.156250,-7.519531,-8.500000,-9.851562,4.171875,-8.976562,-4.761719,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-5.628906,0.000000,0.000000,0.000000
1,4.078125,-0.290039,6.359375,4.371094,0.000000,-9.656250,-0.729980,-5.339844,8.882812,9.218750,...,2.820312,-4.949219,-0.290039,7.859375,-0.189941,0.000000,3.060547,0.000000,-4.320312,0.000000
2,0.000000,0.000000,0.000000,0.000000,9.031250,9.273438,9.031250,9.273438,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3,0.000000,8.351562,0.000000,0.000000,1.799805,0.000000,-2.820312,6.210938,0.000000,1.839844,...,0.000000,0.000000,0.000000,0.529785,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
4,8.500000,4.609375,-4.171875,-5.390625,0.000000,0.000000,7.039062,0.000000,-0.439941,0.000000,...,0.000000,5.578125,4.269531,5.191406,5.730469,1.549805,0.000000,6.550781,1.799805,1.599609
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24978,0.439941,0.000000,0.000000,2.330078,0.000000,6.750000,-8.789062,-0.529785,0.000000,0.000000,...,8.828125,-1.209961,9.218750,-6.699219,0.000000,9.031250,6.550781,8.687500,0.000000,7.429688
24979,0.000000,-8.156250,8.593750,9.078125,0.870117,-8.929688,-3.500000,5.781250,-8.109375,0.000000,...,-1.169922,-5.730469,0.000000,0.239990,9.218750,-8.203125,0.000000,-8.593750,9.132812,8.453125
24980,0.000000,0.000000,0.000000,0.000000,-7.769531,0.000000,0.000000,-6.750000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
24981,0.000000,0.000000,0.000000,0.000000,-9.710938,0.000000,4.558594,-8.296875,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [10]:
train_joke_df = train_joke_df.merge(users_df, how='inner', left_on='UID',  right_index=True)
train_joke_df

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_90,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99
0,18028,5,-1.26,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
18990,18028,39,8.30,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
50486,18028,6,3.30,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
69643,18028,15,-9.27,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
124589,18028,14,8.11,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1303561,21378,61,0.15,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1331118,21378,45,1.31,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1368576,21378,48,-0.68,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1376073,21378,53,2.33,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0


In [11]:
train_joke_df = train_joke_df.merge(jokes_df, how='inner', left_on='JID',  right_index=True)
train_joke_df

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,18028,5,-1.26,0.000000,0.000000,0.000000,0.000000,1.120117,-1.259766,3.300781,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
120304,3365,5,3.35,5.050781,7.621094,8.929688,4.320312,0.000000,3.349609,1.700195,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
492994,12734,5,1.26,-2.669922,0.000000,-1.500000,0.000000,-3.500000,1.259766,-0.629883,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
81927,11364,5,-5.87,0.000000,0.000000,0.000000,0.000000,0.000000,-5.871094,1.070312,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
650652,17989,5,-0.29,-0.290039,-3.789062,2.039062,0.000000,8.453125,-0.290039,-7.089844,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501169,19067,84,1.60,0.000000,0.000000,0.000000,0.000000,2.140625,8.062500,-0.099976,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
326575,24909,84,-1.80,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-2.570312,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1040976,9016,84,3.93,0.000000,0.000000,0.000000,0.000000,2.820312,0.000000,0.000000,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1148245,21967,84,3.54,0.000000,0.000000,0.000000,0.000000,1.799805,2.619141,-3.109375,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


In [12]:
test_joke_df_nofactrating = test_joke_df_nofactrating.merge(users_df, how='inner', left_on='UID',  right_index=True)
test_joke_df_nofactrating = test_joke_df_nofactrating.merge(jokes_df, how='inner', left_on='JID',  right_index=True)
test_joke_df_nofactrating

Unnamed: 0_level_0,UID,JID,j_0,j_1,j_2,j_3,j_4,j_5,j_6,j_7,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
InteractionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,11227,38,-9.468750,9.320312,-2.089844,-9.710938,0.000000,9.273438,9.031250,-3.060547,...,-0.218777,-0.011786,0.245097,0.065426,-0.047003,-0.180288,-0.071317,-0.105609,-0.114432,-0.162669
14,13991,38,0.000000,0.000000,0.000000,0.000000,0.000000,2.279297,0.000000,-7.621094,...,-0.218777,-0.011786,0.245097,0.065426,-0.047003,-0.180288,-0.071317,-0.105609,-0.114432,-0.162669
230989,13099,38,-1.410156,0.000000,-5.101562,-5.101562,0.000000,-2.669922,3.539062,0.000000,...,-0.218777,-0.011786,0.245097,0.065426,-0.047003,-0.180288,-0.071317,-0.105609,-0.114432,-0.162669
226103,11003,38,0.000000,-9.757812,5.730469,0.000000,0.970215,6.359375,3.060547,1.259766,...,-0.218777,-0.011786,0.245097,0.065426,-0.047003,-0.180288,-0.071317,-0.105609,-0.114432,-0.162669
42319,22190,38,-2.279297,-0.629883,0.000000,-4.660156,-9.031250,-4.421875,-9.132812,0.000000,...,-0.218777,-0.011786,0.245097,0.065426,-0.047003,-0.180288,-0.071317,-0.105609,-0.114432,-0.162669
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
204228,8614,79,0.000000,0.000000,0.000000,0.000000,7.179688,0.000000,7.820312,0.000000,...,-0.197151,0.014363,0.240012,-0.044904,-0.075044,-0.180220,-0.078847,0.009047,-0.114279,-0.173267
131692,2408,79,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.340088,0.000000,...,-0.197151,0.014363,0.240012,-0.044904,-0.075044,-0.180220,-0.078847,0.009047,-0.114279,-0.173267
134883,2965,79,0.000000,0.000000,0.000000,0.000000,-9.273438,0.000000,-0.830078,-6.800781,...,-0.197151,0.014363,0.240012,-0.044904,-0.075044,-0.180220,-0.078847,0.009047,-0.114279,-0.173267
265032,41,79,0.000000,0.000000,0.000000,0.000000,8.250000,0.000000,2.720703,0.529785,...,-0.197151,0.014363,0.240012,-0.044904,-0.075044,-0.180220,-0.078847,0.009047,-0.114279,-0.173267


In [13]:
for i, row in enumerate(tqdm(train_joke_df.values)):
    j = int(row[1])
    #print(train_joke_df.iloc[i, 3 + j])
    assert np.isclose(row[2], train_joke_df.iloc[i, 3 + j], atol=1e-2)
    train_joke_df.iloc[i, 3 + j] = 0
    #print(train_joke_df.iloc[i, 3 + j])


  0%|          | 0/1448364 [00:00<?, ?it/s]

In [14]:
for i, row in enumerate(tqdm(train_joke_df.values)):
    j = int(row[1])
    #print(train_joke_df.iloc[i, 3 + j])
    #print(row[3 + joke_id])
    assert np.isclose(0, train_joke_df.iloc[i, 3 + j])


  0%|          | 0/1448364 [00:00<?, ?it/s]

In [15]:
train_joke_df

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,18028,5,-1.26,0.000000,0.000000,0.000000,0.000000,1.120117,0.000000,3.300781,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
120304,3365,5,3.35,5.050781,7.621094,8.929688,4.320312,0.000000,0.000000,1.700195,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
492994,12734,5,1.26,-2.669922,0.000000,-1.500000,0.000000,-3.500000,0.000000,-0.629883,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
81927,11364,5,-5.87,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,1.070312,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
650652,17989,5,-0.29,-0.290039,-3.789062,2.039062,0.000000,8.453125,0.000000,-7.089844,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501169,19067,84,1.60,0.000000,0.000000,0.000000,0.000000,2.140625,8.062500,-0.099976,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
326575,24909,84,-1.80,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-2.570312,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1040976,9016,84,3.93,0.000000,0.000000,0.000000,0.000000,2.820312,0.000000,0.000000,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1148245,21967,84,3.54,0.000000,0.000000,0.000000,0.000000,1.799805,2.619141,-3.109375,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


In [16]:
train_joke_df.tail(30)

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
1273601,6647,84,1.07,0.0,0.0,0.0,0.0,0.0,0.0,-0.680176,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1133216,11958,84,-6.89,0.0,0.0,0.0,0.0,0.0,0.0,-0.48999,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
913962,8582,84,4.9,0.0,0.0,-9.898438,0.0,1.410156,0.72998,2.230469,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
630768,7849,84,-6.41,0.0,0.0,0.0,0.0,1.839844,-9.757812,3.009766,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1205897,12319,84,3.79,0.0,0.0,0.0,0.0,0.0,0.0,-1.549805,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
735943,24881,84,-7.96,0.0,0.0,0.0,0.0,1.169922,0.0,0.0,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1437776,3543,84,3.11,0.0,0.0,0.0,0.0,0.779785,0.0,-6.171875,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1306126,12924,84,4.9,0.0,0.0,0.0,0.0,6.210938,0.0,0.0,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
137588,3196,84,3.06,0.0,0.0,0.0,0.0,-5.390625,0.0,1.30957,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
512106,4710,84,5.34,3.5,3.539062,1.839844,3.980469,-7.820312,4.269531,-4.609375,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


In [17]:
rows, r_pos = np.unique(train_joke_df.values[:,0], return_inverse=True)
cols, c_pos = np.unique(train_joke_df.values[:,1], return_inverse=True)

In [18]:
rows

array([0.0000e+00, 1.0000e+00, 2.0000e+00, ..., 2.4980e+04, 2.4981e+04,
       2.4982e+04])

In [19]:
cols

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
       52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.,
       65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77.,
       78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90.,
       91., 92., 93., 94., 95., 96., 97., 98., 99.])

In [20]:
train_df, valid_df, train_queries, valid_queries = train_test_split(train_joke_df, r_pos, test_size=0.01, random_state=42)

In [21]:
len(valid_df)

14484

In [22]:
valid_df[valid_df['UID'] == 44]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31


In [23]:
train_df[train_df['UID'] == 44]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
337935,44,55,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313
1072225,44,11,7.04,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.209348,-0.090472,0.253714,-0.023258,-0.070321,-0.189687,-0.081262,-0.079108,-0.113483,-0.156067
1310937,44,42,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.174013,-0.000910,0.192058,-0.175867,-0.054576,-0.146486,-0.046585,-0.003952,-0.090521,-0.140018
1190578,44,77,3.54,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.206961,0.058762,0.248765,-0.056504,-0.066395,-0.178045,-0.094212,0.032349,-0.105475,-0.168493
551466,44,35,8.50,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.236904,-0.117266,0.281734,0.032962,-0.064384,-0.221286,-0.083540,-0.077467,-0.120540,-0.192433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
464562,44,84,1.31,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
693981,44,22,8.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.182864,-0.000172,0.222849,-0.183920,-0.049851,-0.167913,-0.060558,-0.049183,-0.100267,-0.148523
243050,44,27,8.01,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
1244955,44,15,3.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.131025,0.058989,0.154818,0.073774,-0.055937,-0.112853,-0.049320,0.012145,-0.068114,-0.099235


In [24]:
for i, row in enumerate(tqdm(valid_df.values)):
    u = int(row[0])
    j = int(row[1])
    
    train_df.loc[train_df['UID'] == u, f'j_{j}'] = 0
    valid_df.loc[valid_df['UID'] == u, f'j_{j}'] = 0

  0%|          | 0/14484 [00:00<?, ?it/s]

In [25]:
train_df[train_df['UID'] == 44]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
337935,44,55,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313
1072225,44,11,7.04,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.209348,-0.090472,0.253714,-0.023258,-0.070321,-0.189687,-0.081262,-0.079108,-0.113483,-0.156067
1310937,44,42,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.174013,-0.000910,0.192058,-0.175867,-0.054576,-0.146486,-0.046585,-0.003952,-0.090521,-0.140018
1190578,44,77,3.54,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.206961,0.058762,0.248765,-0.056504,-0.066395,-0.178045,-0.094212,0.032349,-0.105475,-0.168493
551466,44,35,8.50,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.236904,-0.117266,0.281734,0.032962,-0.064384,-0.221286,-0.083540,-0.077467,-0.120540,-0.192433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
464562,44,84,1.31,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
693981,44,22,8.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.182864,-0.000172,0.222849,-0.183920,-0.049851,-0.167913,-0.060558,-0.049183,-0.100267,-0.148523
243050,44,27,8.01,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
1244955,44,15,3.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.131025,0.058989,0.154818,0.073774,-0.055937,-0.112853,-0.049320,0.012145,-0.068114,-0.099235


In [26]:
valid_df[valid_df['UID'] == 44]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31


In [27]:
from catboost import CatBoostRanker, Pool, MetricVisualizer, CatBoostRegressor
from copy import deepcopy

In [28]:
train_df.drop(columns=['Rating'])

Unnamed: 0,UID,JID,j_0,j_1,j_2,j_3,j_4,j_5,j_6,j_7,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
1315628,14232,13,4.128906,-7.179688,3.880859,-5.679688,-9.320312,0.000000,1.209961,-4.371094,...,-0.219153,-0.062274,0.247847,0.110737,-0.055423,-0.186064,-0.078129,-0.099043,-0.112296,-0.165175
1393766,8113,45,5.578125,4.371094,4.421875,0.000000,2.330078,4.371094,-2.429688,-3.789062,...,-0.215624,-0.081347,0.247605,-0.001797,-0.042055,-0.191130,-0.068965,-0.114132,-0.109277,-0.169667
108972,18801,27,0.000000,0.000000,0.000000,0.000000,-5.781250,0.000000,-1.750000,0.000000,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
199734,3512,1,0.870117,0.000000,-1.500000,0.000000,-3.160156,0.099976,-2.820312,1.309570,...,-0.188492,-0.215398,0.215273,0.228148,-0.045374,-0.168430,-0.083771,-0.114539,-0.096700,-0.148389
1035265,14831,67,0.000000,0.000000,0.000000,0.000000,3.740234,-0.150024,6.019531,2.519531,...,-0.234150,0.020325,0.271443,-0.232583,-0.072568,-0.212612,-0.089987,0.008103,-0.126860,-0.182385
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
633634,11081,21,0.870117,0.000000,0.580078,0.000000,2.960938,1.169922,0.000000,-5.921875,...,-0.196241,-0.090978,0.227133,0.065467,-0.044386,-0.185955,-0.080973,-0.110324,-0.102693,-0.157451
951565,613,76,-7.378906,-8.789062,4.128906,-7.230469,7.179688,0.000000,-4.949219,5.531250,...,-0.206986,0.132555,0.220478,-0.072370,-0.064179,-0.168671,-0.090839,0.048417,-0.101191,-0.162371
407318,6382,7,-0.340088,-9.273438,-0.239990,4.558594,4.078125,1.209961,0.000000,0.000000,...,-0.172836,-0.078154,0.211503,-0.105241,-0.045931,-0.144566,-0.062419,-0.027954,-0.094376,-0.141616
567497,1737,55,1.309570,-0.580078,7.960938,-3.300781,-1.799805,0.000000,3.349609,0.000000,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313


In [29]:
cat_features = ['UID', 'JID']

In [30]:
train_pool = Pool(train_df.drop(columns=['Rating']), label=train_df['Rating'], cat_features=cat_features)
valid_pool = Pool(valid_df.drop(columns=['Rating']), label=valid_df['Rating'], cat_features=cat_features)
main_pool = Pool(train_joke_df.drop(columns=['Rating']), label=train_joke_df['Rating'], cat_features=cat_features)

test_pool = Pool(test_joke_df_nofactrating, cat_features=cat_features)

In [31]:
#train = Pool(
#    data=X_train,
#    label=y_train,
#    group_id=queries_train,
#    cat_features=[0, 1]
#)

#test = Pool(
#    data=X_test,
#    label=y_test,
#    group_id=queries_test,
#    cat_features=[0, 1]
#)

In [32]:
default_parameters = {
    'iterations': 1000,
    'custom_metric': 'RMSE',
    'random_seed': 0,
    'train_dir':'RMSE',
    'objective':'RMSE',
    'loss_function':'RMSE',
    'eval_metric':'RMSE',
}


In [33]:
model = CatBoostRegressor(**default_parameters)
model.fit(train_pool, eval_set=valid_pool, plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

Learning rate set to 0.159393
0:	learn: 5.0587689	test: 5.0007497	best: 5.0007497 (0)	total: 577ms	remaining: 9m 36s
1:	learn: 4.9225808	test: 4.8534109	best: 4.8534109 (1)	total: 975ms	remaining: 8m 6s
2:	learn: 4.8220048	test: 4.7438272	best: 4.7438272 (2)	total: 1.24s	remaining: 6m 52s
3:	learn: 4.7439467	test: 4.6590053	best: 4.6590053 (3)	total: 1.68s	remaining: 6m 57s
4:	learn: 4.6832274	test: 4.5939083	best: 4.5939083 (4)	total: 2.14s	remaining: 7m 5s
5:	learn: 4.6342774	test: 4.5425347	best: 4.5425347 (5)	total: 2.63s	remaining: 7m 15s
6:	learn: 4.5950590	test: 4.5019889	best: 4.5019889 (6)	total: 2.94s	remaining: 6m 56s
7:	learn: 4.5644862	test: 4.4682614	best: 4.4682614 (7)	total: 3.34s	remaining: 6m 53s
8:	learn: 4.5369596	test: 4.4386635	best: 4.4386635 (8)	total: 3.83s	remaining: 7m 1s
9:	learn: 4.5100353	test: 4.4204260	best: 4.4204260 (9)	total: 4.36s	remaining: 7m 11s
10:	learn: 4.4859270	test: 4.4031564	best: 4.4031564 (10)	total: 4.94s	remaining: 7m 23s
11:	learn: 4.4

93:	learn: 3.9536046	test: 4.2110865	best: 4.2110865 (93)	total: 45.8s	remaining: 7m 21s
94:	learn: 3.9495800	test: 4.2102079	best: 4.2102079 (94)	total: 46.4s	remaining: 7m 21s
95:	learn: 3.9463649	test: 4.2096610	best: 4.2096610 (95)	total: 46.8s	remaining: 7m 20s
96:	learn: 3.9428369	test: 4.2099697	best: 4.2096610 (95)	total: 47.4s	remaining: 7m 21s
97:	learn: 3.9373417	test: 4.2089812	best: 4.2089812 (97)	total: 48s	remaining: 7m 21s
98:	learn: 3.9320030	test: 4.2087957	best: 4.2087957 (98)	total: 48.5s	remaining: 7m 21s
99:	learn: 3.9284220	test: 4.2079422	best: 4.2079422 (99)	total: 49s	remaining: 7m 21s
100:	learn: 3.9259477	test: 4.2069626	best: 4.2069626 (100)	total: 49.5s	remaining: 7m 20s
101:	learn: 3.9242408	test: 4.2066579	best: 4.2066579 (101)	total: 49.9s	remaining: 7m 19s
102:	learn: 3.9170589	test: 4.2053927	best: 4.2053927 (102)	total: 50.5s	remaining: 7m 19s
103:	learn: 3.9131824	test: 4.2052181	best: 4.2052181 (103)	total: 51.1s	remaining: 7m 20s
104:	learn: 3.909

184:	learn: 3.6214558	test: 4.1775287	best: 4.1770739 (183)	total: 1m 30s	remaining: 6m 39s
185:	learn: 3.6191946	test: 4.1770564	best: 4.1770564 (185)	total: 1m 31s	remaining: 6m 39s
186:	learn: 3.6118576	test: 4.1764820	best: 4.1764820 (186)	total: 1m 31s	remaining: 6m 39s
187:	learn: 3.6088850	test: 4.1768109	best: 4.1764820 (186)	total: 1m 32s	remaining: 6m 38s
188:	learn: 3.6065964	test: 4.1763811	best: 4.1763811 (188)	total: 1m 32s	remaining: 6m 38s
189:	learn: 3.6029030	test: 4.1762258	best: 4.1762258 (189)	total: 1m 33s	remaining: 6m 38s
190:	learn: 3.5994551	test: 4.1760228	best: 4.1760228 (190)	total: 1m 33s	remaining: 6m 37s
191:	learn: 3.5964387	test: 4.1758658	best: 4.1758658 (191)	total: 1m 34s	remaining: 6m 37s
192:	learn: 3.5919682	test: 4.1753414	best: 4.1753414 (192)	total: 1m 34s	remaining: 6m 36s
193:	learn: 3.5884125	test: 4.1754185	best: 4.1753414 (192)	total: 1m 35s	remaining: 6m 35s
194:	learn: 3.5837032	test: 4.1753761	best: 4.1753414 (192)	total: 1m 35s	remain

274:	learn: 3.3596649	test: 4.1702618	best: 4.1696171 (267)	total: 2m 15s	remaining: 5m 57s
275:	learn: 3.3580131	test: 4.1701978	best: 4.1696171 (267)	total: 2m 16s	remaining: 5m 56s
276:	learn: 3.3563404	test: 4.1701925	best: 4.1696171 (267)	total: 2m 16s	remaining: 5m 56s
277:	learn: 3.3526618	test: 4.1702828	best: 4.1696171 (267)	total: 2m 17s	remaining: 5m 56s
278:	learn: 3.3500403	test: 4.1702026	best: 4.1696171 (267)	total: 2m 17s	remaining: 5m 55s
279:	learn: 3.3490437	test: 4.1698467	best: 4.1696171 (267)	total: 2m 18s	remaining: 5m 54s
280:	learn: 3.3467679	test: 4.1698222	best: 4.1696171 (267)	total: 2m 18s	remaining: 5m 54s
281:	learn: 3.3441967	test: 4.1693310	best: 4.1693310 (281)	total: 2m 18s	remaining: 5m 53s
282:	learn: 3.3418722	test: 4.1691697	best: 4.1691697 (282)	total: 2m 19s	remaining: 5m 53s
283:	learn: 3.3380946	test: 4.1702733	best: 4.1691697 (282)	total: 2m 19s	remaining: 5m 52s
284:	learn: 3.3370318	test: 4.1698342	best: 4.1691697 (282)	total: 2m 20s	remain

364:	learn: 3.1550914	test: 4.1752164	best: 4.1683660 (315)	total: 2m 59s	remaining: 5m 12s
365:	learn: 3.1542489	test: 4.1753672	best: 4.1683660 (315)	total: 3m	remaining: 5m 12s
366:	learn: 3.1531026	test: 4.1757277	best: 4.1683660 (315)	total: 3m	remaining: 5m 11s
367:	learn: 3.1508054	test: 4.1762335	best: 4.1683660 (315)	total: 3m 1s	remaining: 5m 10s
368:	learn: 3.1488758	test: 4.1770892	best: 4.1683660 (315)	total: 3m 1s	remaining: 5m 10s
369:	learn: 3.1470976	test: 4.1775668	best: 4.1683660 (315)	total: 3m 1s	remaining: 5m 9s
370:	learn: 3.1462808	test: 4.1777110	best: 4.1683660 (315)	total: 3m 2s	remaining: 5m 9s
371:	learn: 3.1447571	test: 4.1778077	best: 4.1683660 (315)	total: 3m 3s	remaining: 5m 9s
372:	learn: 3.1426920	test: 4.1777171	best: 4.1683660 (315)	total: 3m 3s	remaining: 5m 8s
373:	learn: 3.1408800	test: 4.1782417	best: 4.1683660 (315)	total: 3m 4s	remaining: 5m 8s
374:	learn: 3.1392863	test: 4.1785204	best: 4.1683660 (315)	total: 3m 4s	remaining: 5m 7s
375:	learn

454:	learn: 2.9793343	test: 4.1906715	best: 4.1683660 (315)	total: 3m 45s	remaining: 4m 29s
455:	learn: 2.9782297	test: 4.1904078	best: 4.1683660 (315)	total: 3m 45s	remaining: 4m 29s
456:	learn: 2.9769530	test: 4.1905309	best: 4.1683660 (315)	total: 3m 46s	remaining: 4m 28s
457:	learn: 2.9749713	test: 4.1900881	best: 4.1683660 (315)	total: 3m 46s	remaining: 4m 28s
458:	learn: 2.9738210	test: 4.1903944	best: 4.1683660 (315)	total: 3m 47s	remaining: 4m 27s
459:	learn: 2.9722269	test: 4.1904621	best: 4.1683660 (315)	total: 3m 47s	remaining: 4m 27s
460:	learn: 2.9708155	test: 4.1905361	best: 4.1683660 (315)	total: 3m 47s	remaining: 4m 26s
461:	learn: 2.9703522	test: 4.1906820	best: 4.1683660 (315)	total: 3m 48s	remaining: 4m 25s
462:	learn: 2.9656846	test: 4.1914481	best: 4.1683660 (315)	total: 3m 48s	remaining: 4m 25s
463:	learn: 2.9638929	test: 4.1914399	best: 4.1683660 (315)	total: 3m 49s	remaining: 4m 24s
464:	learn: 2.9615940	test: 4.1919996	best: 4.1683660 (315)	total: 3m 49s	remain

544:	learn: 2.8414872	test: 4.2054276	best: 4.1683660 (315)	total: 4m 30s	remaining: 3m 45s
545:	learn: 2.8390518	test: 4.2053107	best: 4.1683660 (315)	total: 4m 30s	remaining: 3m 45s
546:	learn: 2.8374473	test: 4.2052475	best: 4.1683660 (315)	total: 4m 31s	remaining: 3m 44s
547:	learn: 2.8352090	test: 4.2058146	best: 4.1683660 (315)	total: 4m 31s	remaining: 3m 44s
548:	learn: 2.8331563	test: 4.2063189	best: 4.1683660 (315)	total: 4m 32s	remaining: 3m 43s
549:	learn: 2.8311660	test: 4.2069941	best: 4.1683660 (315)	total: 4m 32s	remaining: 3m 43s
550:	learn: 2.8291156	test: 4.2070940	best: 4.1683660 (315)	total: 4m 33s	remaining: 3m 42s
551:	learn: 2.8276971	test: 4.2079877	best: 4.1683660 (315)	total: 4m 33s	remaining: 3m 42s
552:	learn: 2.8247887	test: 4.2083558	best: 4.1683660 (315)	total: 4m 34s	remaining: 3m 41s
553:	learn: 2.8236956	test: 4.2088789	best: 4.1683660 (315)	total: 4m 34s	remaining: 3m 41s
554:	learn: 2.8206924	test: 4.2090203	best: 4.1683660 (315)	total: 4m 35s	remain

634:	learn: 2.7195636	test: 4.2221499	best: 4.1683660 (315)	total: 5m 16s	remaining: 3m 1s
635:	learn: 2.7158391	test: 4.2224272	best: 4.1683660 (315)	total: 5m 16s	remaining: 3m 1s
636:	learn: 2.7151049	test: 4.2225619	best: 4.1683660 (315)	total: 5m 17s	remaining: 3m
637:	learn: 2.7142910	test: 4.2229588	best: 4.1683660 (315)	total: 5m 18s	remaining: 3m
638:	learn: 2.7135415	test: 4.2232610	best: 4.1683660 (315)	total: 5m 18s	remaining: 2m 59s
639:	learn: 2.7091470	test: 4.2235723	best: 4.1683660 (315)	total: 5m 19s	remaining: 2m 59s
640:	learn: 2.7040651	test: 4.2234160	best: 4.1683660 (315)	total: 5m 19s	remaining: 2m 58s
641:	learn: 2.6999611	test: 4.2233665	best: 4.1683660 (315)	total: 5m 19s	remaining: 2m 58s
642:	learn: 2.6983380	test: 4.2230754	best: 4.1683660 (315)	total: 5m 20s	remaining: 2m 57s
643:	learn: 2.6953511	test: 4.2236763	best: 4.1683660 (315)	total: 5m 20s	remaining: 2m 57s
644:	learn: 2.6945016	test: 4.2238148	best: 4.1683660 (315)	total: 5m 21s	remaining: 2m 56

724:	learn: 2.5896458	test: 4.2416434	best: 4.1683660 (315)	total: 6m 2s	remaining: 2m 17s
725:	learn: 2.5859653	test: 4.2416449	best: 4.1683660 (315)	total: 6m 2s	remaining: 2m 16s
726:	learn: 2.5845756	test: 4.2421328	best: 4.1683660 (315)	total: 6m 3s	remaining: 2m 16s
727:	learn: 2.5832381	test: 4.2421909	best: 4.1683660 (315)	total: 6m 3s	remaining: 2m 15s
728:	learn: 2.5798404	test: 4.2420168	best: 4.1683660 (315)	total: 6m 4s	remaining: 2m 15s
729:	learn: 2.5784656	test: 4.2421269	best: 4.1683660 (315)	total: 6m 4s	remaining: 2m 14s
730:	learn: 2.5782226	test: 4.2423621	best: 4.1683660 (315)	total: 6m 5s	remaining: 2m 14s
731:	learn: 2.5763906	test: 4.2427244	best: 4.1683660 (315)	total: 6m 5s	remaining: 2m 13s
732:	learn: 2.5757313	test: 4.2430427	best: 4.1683660 (315)	total: 6m 6s	remaining: 2m 13s
733:	learn: 2.5733017	test: 4.2431453	best: 4.1683660 (315)	total: 6m 7s	remaining: 2m 13s
734:	learn: 2.5725732	test: 4.2433107	best: 4.1683660 (315)	total: 6m 7s	remaining: 2m 12s

814:	learn: 2.4673831	test: 4.2655024	best: 4.1683660 (315)	total: 6m 48s	remaining: 1m 32s
815:	learn: 2.4669723	test: 4.2658018	best: 4.1683660 (315)	total: 6m 49s	remaining: 1m 32s
816:	learn: 2.4667856	test: 4.2659813	best: 4.1683660 (315)	total: 6m 49s	remaining: 1m 31s
817:	learn: 2.4659557	test: 4.2657238	best: 4.1683660 (315)	total: 6m 50s	remaining: 1m 31s
818:	learn: 2.4654367	test: 4.2659142	best: 4.1683660 (315)	total: 6m 50s	remaining: 1m 30s
819:	learn: 2.4646870	test: 4.2658638	best: 4.1683660 (315)	total: 6m 51s	remaining: 1m 30s
820:	learn: 2.4632741	test: 4.2659559	best: 4.1683660 (315)	total: 6m 51s	remaining: 1m 29s
821:	learn: 2.4593483	test: 4.2663456	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 29s
822:	learn: 2.4582956	test: 4.2665758	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 28s
823:	learn: 2.4569880	test: 4.2667025	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 28s
824:	learn: 2.4512348	test: 4.2682991	best: 4.1683660 (315)	total: 6m 53s	remain

904:	learn: 2.3533299	test: 4.2940351	best: 4.1683660 (315)	total: 7m 34s	remaining: 47.7s
905:	learn: 2.3511675	test: 4.2941498	best: 4.1683660 (315)	total: 7m 35s	remaining: 47.2s
906:	learn: 2.3503868	test: 4.2944133	best: 4.1683660 (315)	total: 7m 35s	remaining: 46.7s
907:	learn: 2.3499226	test: 4.2946030	best: 4.1683660 (315)	total: 7m 36s	remaining: 46.2s
908:	learn: 2.3486949	test: 4.2950759	best: 4.1683660 (315)	total: 7m 36s	remaining: 45.7s
909:	learn: 2.3482436	test: 4.2955499	best: 4.1683660 (315)	total: 7m 37s	remaining: 45.2s
910:	learn: 2.3466237	test: 4.2958444	best: 4.1683660 (315)	total: 7m 37s	remaining: 44.7s
911:	learn: 2.3434634	test: 4.2963961	best: 4.1683660 (315)	total: 7m 38s	remaining: 44.2s
912:	learn: 2.3420296	test: 4.2981702	best: 4.1683660 (315)	total: 7m 38s	remaining: 43.7s
913:	learn: 2.3407932	test: 4.2979007	best: 4.1683660 (315)	total: 7m 39s	remaining: 43.2s
914:	learn: 2.3394342	test: 4.2980770	best: 4.1683660 (315)	total: 7m 39s	remaining: 42.7s

995:	learn: 2.2519877	test: 4.3150498	best: 4.1683660 (315)	total: 8m 20s	remaining: 2.01s
996:	learn: 2.2517407	test: 4.3152867	best: 4.1683660 (315)	total: 8m 21s	remaining: 1.51s
997:	learn: 2.2509928	test: 4.3158729	best: 4.1683660 (315)	total: 8m 21s	remaining: 1s
998:	learn: 2.2506826	test: 4.3160053	best: 4.1683660 (315)	total: 8m 22s	remaining: 503ms
999:	learn: 2.2504299	test: 4.3162907	best: 4.1683660 (315)	total: 8m 22s	remaining: 0us

bestTest = 4.16836603
bestIteration = 315

Shrink model to first 316 iterations.


<catboost.core.CatBoostRegressor at 0x233095bebe0>

In [34]:
predict = model.predict(valid_pool)
print(mean_squared_error(valid_df['Rating'].values, predict, squared=False))


4.168366025985882


In [38]:

predict = model.predict(test_pool)

test_joke_df_nofactrating['Rating'] = predict

display(test_joke_df_nofactrating['Rating'].to_frame().head(5))
test_joke_df_nofactrating['Rating'].to_frame().to_csv('catboost_with_rating_and_item_emb_01.csv')

Unnamed: 0_level_0,Rating
InteractionID,Unnamed: 1_level_1
0,1.652449
14,2.459946
230989,-1.809826
226103,-4.728921
42319,-1.679739


In [39]:
default_parameters = {
    'iterations': 333,
    'custom_metric': 'RMSE',
    'random_seed': 0,
    'train_dir':'RMSE',
    'objective':'RMSE',
    'loss_function':'RMSE',
    'eval_metric':'RMSE',
}


In [40]:
model = CatBoostRegressor(**default_parameters)
model.fit(main_pool, plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

Learning rate set to 0.316119
0:	learn: 4.9085621	total: 460ms	remaining: 2m 32s
1:	learn: 4.7275898	total: 1.07s	remaining: 2m 56s
2:	learn: 4.6214393	total: 1.41s	remaining: 2m 34s
3:	learn: 4.5535655	total: 2.15s	remaining: 2m 57s
4:	learn: 4.5079648	total: 2.71s	remaining: 2m 57s
5:	learn: 4.4738238	total: 3.25s	remaining: 2m 56s
6:	learn: 4.4482785	total: 3.6s	remaining: 2m 47s
7:	learn: 4.4179661	total: 4.13s	remaining: 2m 47s
8:	learn: 4.4038769	total: 4.63s	remaining: 2m 46s
9:	learn: 4.3876662	total: 5.02s	remaining: 2m 42s
10:	learn: 4.3705488	total: 5.58s	remaining: 2m 43s
11:	learn: 4.3579961	total: 6.04s	remaining: 2m 41s
12:	learn: 4.3452540	total: 6.57s	remaining: 2m 41s
13:	learn: 4.3340014	total: 7.16s	remaining: 2m 43s
14:	learn: 4.3136176	total: 7.66s	remaining: 2m 42s
15:	learn: 4.2989083	total: 8.33s	remaining: 2m 45s
16:	learn: 4.2852841	total: 8.94s	remaining: 2m 46s
17:	learn: 4.2731754	total: 9.42s	remaining: 2m 44s
18:	learn: 4.2635254	total: 9.98s	remaining: 

157:	learn: 3.3121529	total: 1m 19s	remaining: 1m 27s
158:	learn: 3.3078294	total: 1m 19s	remaining: 1m 27s
159:	learn: 3.3053359	total: 1m 20s	remaining: 1m 26s
160:	learn: 3.3010411	total: 1m 20s	remaining: 1m 26s
161:	learn: 3.2976533	total: 1m 21s	remaining: 1m 25s
162:	learn: 3.2940072	total: 1m 21s	remaining: 1m 25s
163:	learn: 3.2913342	total: 1m 22s	remaining: 1m 25s
164:	learn: 3.2859638	total: 1m 23s	remaining: 1m 24s
165:	learn: 3.2797641	total: 1m 23s	remaining: 1m 24s
166:	learn: 3.2756253	total: 1m 24s	remaining: 1m 23s
167:	learn: 3.2667066	total: 1m 24s	remaining: 1m 23s
168:	learn: 3.2628021	total: 1m 25s	remaining: 1m 22s
169:	learn: 3.2604366	total: 1m 25s	remaining: 1m 22s
170:	learn: 3.2554882	total: 1m 26s	remaining: 1m 21s
171:	learn: 3.2497163	total: 1m 27s	remaining: 1m 21s
172:	learn: 3.2475022	total: 1m 27s	remaining: 1m 20s
173:	learn: 3.2421940	total: 1m 28s	remaining: 1m 20s
174:	learn: 3.2366272	total: 1m 28s	remaining: 1m 20s
175:	learn: 3.2348062	total:

312:	learn: 2.7675446	total: 2m 38s	remaining: 10.1s
313:	learn: 2.7481086	total: 2m 39s	remaining: 9.64s
314:	learn: 2.7437812	total: 2m 39s	remaining: 9.12s
315:	learn: 2.7364019	total: 2m 40s	remaining: 8.61s
316:	learn: 2.7348928	total: 2m 40s	remaining: 8.1s
317:	learn: 2.7324016	total: 2m 40s	remaining: 7.58s
318:	learn: 2.7306322	total: 2m 41s	remaining: 7.08s
319:	learn: 2.7300181	total: 2m 41s	remaining: 6.57s
320:	learn: 2.7285387	total: 2m 42s	remaining: 6.07s
321:	learn: 2.7270278	total: 2m 42s	remaining: 5.56s
322:	learn: 2.7224587	total: 2m 43s	remaining: 5.05s
323:	learn: 2.7194448	total: 2m 43s	remaining: 4.55s
324:	learn: 2.7180130	total: 2m 44s	remaining: 4.04s
325:	learn: 2.7153150	total: 2m 44s	remaining: 3.54s
326:	learn: 2.7130203	total: 2m 45s	remaining: 3.03s
327:	learn: 2.7114530	total: 2m 45s	remaining: 2.53s
328:	learn: 2.7041733	total: 2m 46s	remaining: 2.02s
329:	learn: 2.7033102	total: 2m 47s	remaining: 1.52s
330:	learn: 2.7010815	total: 2m 47s	remaining: 

<catboost.core.CatBoostRegressor at 0x2330959c160>

In [41]:

predict = model.predict(test_pool)

test_joke_df_nofactrating['Rating'] = predict

display(test_joke_df_nofactrating['Rating'].to_frame().head(5))
test_joke_df_nofactrating['Rating'].to_frame().to_csv('catboost_with_rating_and_item_emb_alldata.csv')

Unnamed: 0_level_0,Rating
InteractionID,Unnamed: 1_level_1
0,1.293766
14,1.450526
230989,-0.874008
226103,-4.05437
42319,0.27871


In [42]:
len(test_joke_df_nofactrating)

362091