In [1]:
import pandas as pd
import numpy as np
from scipy import sparse
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import mean_squared_error
import pickle

In [2]:
with open('track_embeddings.pkl', 'rb') as f:
    track_embeddings = pickle.load(f)

In [3]:
track_embeddings

array([[-0.12850836,  0.05144424,  0.1324376 , ..., -0.0701643 ,
        -0.11606281, -0.1602558 ],
       [-0.12560181,  0.09551706,  0.10573198, ..., -0.11453907,
        -0.09670042, -0.14838946],
       [-0.13176899,  0.05360903,  0.10417848, ..., -0.09779049,
        -0.11299129, -0.15423141],
       ...,
       [-0.11718541,  0.15378292,  0.12776648, ..., -0.05870862,
        -0.10497359, -0.161321  ],
       [-0.10790123,  0.20756178,  0.10939675, ...,  0.12004782,
        -0.09874647, -0.13628687],
       [-0.12471136,  0.02359954,  0.11788659, ..., -0.09589922,
        -0.10256677, -0.17452659]], dtype=float32)

In [4]:
with open('user_embeddings.pkl', 'rb') as f:
    user_embeddings = pickle.load(f)

In [5]:
train_joke_df = pd.read_csv(r'..\data\recsys-in-practice\train_joke_df.csv')

In [6]:
train_joke_df["UID"] = train_joke_df["UID"].astype(int)
train_joke_df["JID"] = train_joke_df["JID"].astype(int)

In [7]:
train_joke_df["UID"] = train_joke_df["UID"] - 1
train_joke_df["JID"] = train_joke_df["JID"] - 1

In [8]:
jokes_df = pd.DataFrame(track_embeddings, index=range(0, track_embeddings.shape[0]), columns=[f'j_emb_{i}' for i in range(32)])
jokes_df

Unnamed: 0,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,j_emb_8,j_emb_9,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,-0.128508,0.051444,0.132438,-0.037298,-0.201775,-0.069636,-0.113742,-0.024279,-0.029537,-0.053748,...,-0.196729,-0.140467,0.230395,-0.003536,-0.060885,-0.165479,-0.068899,-0.070164,-0.116063,-0.160256
1,-0.125602,0.095517,0.105732,-0.028557,-0.183307,-0.077972,-0.131677,0.055064,0.003441,-0.082902,...,-0.188492,-0.215398,0.215273,0.228148,-0.045374,-0.168430,-0.083771,-0.114539,-0.096700,-0.148389
2,-0.131769,0.053609,0.104178,-0.046955,-0.200002,0.016845,-0.251092,0.009414,-0.036045,-0.055575,...,-0.190417,-0.176878,0.221158,-0.021510,-0.072628,-0.171978,-0.068281,-0.097790,-0.112991,-0.154231
3,-0.104805,0.194723,0.106647,-0.044433,-0.163179,-0.154253,-0.264780,0.097174,-0.029896,-0.076351,...,-0.150741,-0.001954,0.195477,0.139092,-0.032957,-0.137452,-0.067610,-0.081869,-0.098548,-0.121358
4,-0.110405,0.015444,0.141151,-0.056691,-0.191257,0.052613,0.081804,0.049767,-0.012462,-0.087378,...,-0.190938,-0.172349,0.216513,0.333815,-0.054110,-0.173943,-0.054134,-0.013408,-0.104885,-0.163853
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,-0.142546,0.170686,0.125764,-0.063433,-0.220362,-0.036655,0.155781,0.014916,-0.052736,-0.086768,...,-0.207480,0.127158,0.234253,-0.107324,-0.099028,-0.204563,-0.098598,0.085497,-0.128372,-0.155286
96,-0.134356,0.132201,0.131486,-0.014839,-0.215539,0.029771,0.132847,0.096979,-0.090910,-0.071125,...,-0.202462,0.137410,0.239188,-0.027634,-0.074429,-0.189961,-0.078969,0.093237,-0.127856,-0.168573
97,-0.117185,0.153783,0.127766,-0.031634,-0.203616,-0.118698,0.015570,0.082631,-0.013541,-0.076063,...,-0.190877,-0.027533,0.231235,0.337730,-0.040640,-0.173552,-0.079145,-0.058709,-0.104974,-0.161321
98,-0.107901,0.207562,0.109397,-0.037075,-0.197782,-0.025386,0.032928,0.020050,-0.057390,-0.064281,...,-0.178827,0.125883,0.221786,-0.048951,-0.055854,-0.173269,-0.081489,0.120048,-0.098746,-0.136287


In [9]:
users_df = pd.DataFrame(user_embeddings, index=range(0, user_embeddings.shape[0]), columns=[f'u_{i}' for i in range(32)])
users_df

Unnamed: 0,u_0,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,...,u_22,u_23,u_24,u_25,u_26,u_27,u_28,u_29,u_30,u_31
0,-0.562434,-0.512613,-0.210158,0.165727,0.507139,0.360304,0.409235,0.033714,-0.116103,-0.248208,...,0.883278,0.027182,1.326357,-0.205095,-0.686394,0.532235,0.145234,0.016586,0.288479,0.606274
1,-0.485959,0.945893,0.445113,-0.715890,-0.301412,-0.052740,-0.397748,0.270035,-0.460695,-0.294081,...,-0.046065,-0.164151,-0.277525,0.261794,-0.068311,0.431705,-0.213186,-0.573583,-0.078941,-0.387077
2,-0.645697,0.363637,0.928545,-0.177133,0.375063,-0.378120,-0.074192,-0.027364,-0.166625,-0.699849,...,-0.513778,0.132884,0.726032,-0.051234,-0.201419,-0.679559,-0.039691,-0.154259,0.821360,-1.029511
3,0.007982,-0.119265,-0.254778,0.039290,0.067811,-0.407250,-0.092729,0.063339,0.138287,-0.449985,...,-0.004080,-0.041485,-0.006391,0.209681,-0.204729,0.020591,-0.041234,-0.371177,-0.036550,-0.304493
4,-0.250110,0.206184,0.878673,-0.060909,-0.693472,-0.310396,0.017901,0.020361,-0.313942,0.615974,...,-0.239595,0.581819,0.511321,0.227073,0.481617,0.599454,0.029895,-0.010553,-0.475891,0.718033
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24978,-0.730169,-0.031571,-0.836192,0.320712,-1.085958,0.083085,-0.194057,0.048488,-0.620317,0.487865,...,-1.524487,-0.219132,-0.041619,-0.207920,-0.555294,-0.965212,0.402377,0.100585,-0.021638,0.255030
24979,-0.982884,0.180405,0.403154,0.012332,-0.062073,-0.872497,-0.862941,0.223579,-0.649529,-0.375198,...,0.121970,-0.411345,0.242665,0.175302,0.417094,-0.348651,0.322378,0.578813,0.035689,-0.771677
24980,0.076153,0.251545,-0.007040,0.047573,-1.213524,0.029352,0.175674,0.039051,0.012186,0.117219,...,-0.200420,0.014490,1.286112,0.054143,-0.031784,0.691570,0.635157,-0.513831,0.657975,-0.789508
24981,0.411928,0.236942,0.213315,-0.197184,-1.051007,-0.055444,-0.132759,-0.166720,-0.274337,-0.292379,...,0.004268,-0.295684,0.553627,0.221775,0.300616,0.508309,-0.202422,0.232041,0.589790,-0.712487


In [10]:
train_joke_df = train_joke_df.merge(jokes_df, how='inner', left_on='JID',  right_index=True) \
                            .merge(users_df, how='inner', left_on='UID',  right_index=True)
train_joke_df

Unnamed: 0,UID,JID,Rating,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,...,u_22,u_23,u_24,u_25,u_26,u_27,u_28,u_29,u_30,u_31
0,18028,5,-1.26,-0.136766,-0.013906,0.140304,-0.045568,-0.215760,0.001899,-0.029915,...,-0.733468,-0.105453,0.499175,0.066219,0.111256,-0.133982,0.096381,0.435898,-0.561148,-0.483245
205981,18028,37,-0.92,-0.120649,0.078239,0.129345,-0.060937,-0.214939,-0.121421,-0.018706,...,-0.733468,-0.105453,0.499175,0.066219,0.111256,-0.133982,0.096381,0.435898,-0.561148,-0.483245
832596,18028,35,3.98,-0.152154,-0.229791,0.154801,-0.059190,-0.251570,-0.048316,0.178035,...,-0.733468,-0.105453,0.499175,0.066219,0.111256,-0.133982,0.096381,0.435898,-0.561148,-0.483245
924751,18028,34,3.11,-0.147116,-0.175355,0.151641,-0.095654,-0.246760,0.014102,0.156275,...,-0.733468,-0.105453,0.499175,0.066219,0.111256,-0.133982,0.096381,0.435898,-0.561148,-0.483245
1302061,18028,28,8.88,-0.160405,-0.226419,0.133310,-0.068568,-0.244091,0.063908,0.106469,...,-0.733468,-0.105453,0.499175,0.066219,0.111256,-0.133982,0.096381,0.435898,-0.561148,-0.483245
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
242709,21057,19,7.14,-0.109361,0.023283,0.119035,-0.081532,-0.172330,0.099342,0.115429,...,0.406925,0.073443,0.599645,0.445132,-0.282644,-0.531429,0.155384,-0.120837,-0.092186,-0.533658
957801,21057,66,3.50,-0.122767,0.244425,0.110366,-0.008108,-0.183211,-0.151318,-0.136290,...,0.406925,0.073443,0.599645,0.445132,-0.282644,-0.531429,0.155384,-0.120837,-0.092186,-0.533658
1189585,21057,87,3.45,-0.140342,0.102317,0.157228,-0.015092,-0.217698,0.061223,0.163278,...,0.406925,0.073443,0.599645,0.445132,-0.282644,-0.531429,0.155384,-0.120837,-0.092186,-0.533658
927620,21057,4,6.02,-0.110405,0.015444,0.141151,-0.056691,-0.191257,0.052613,0.081804,...,0.406925,0.073443,0.599645,0.445132,-0.282644,-0.531429,0.155384,-0.120837,-0.092186,-0.533658


In [19]:

test_joke_df_nofactrating = pd.read_csv(r'..\data\recsys-in-practice\test_joke_df_nofactrating.csv', index_col=0)
len(test_joke_df_nofactrating)

362091

In [21]:
test_joke_df_nofactrating['UID'] = test_joke_df_nofactrating['UID'] - 1
test_joke_df_nofactrating['JID'] = test_joke_df_nofactrating['JID'] - 1

In [22]:
test_joke_df_nofactrating = test_joke_df_nofactrating.merge(jokes_df, how='inner', left_on='JID',  right_index=True) \
                                .merge(users_df, how='inner', left_on='UID',  right_index=True)
test_joke_df_nofactrating

Unnamed: 0_level_0,UID,JID,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,...,u_22,u_23,u_24,u_25,u_26,u_27,u_28,u_29,u_30,u_31
InteractionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,11227,38,-0.116937,0.027235,0.129899,-0.063150,-0.206269,-0.071847,-0.080700,0.055479,...,-0.301830,-0.000620,1.433958,0.365661,0.144065,-1.793244,-0.620851,-1.447682,0.076965,1.634137
122911,11227,55,-0.135140,-0.086692,0.132088,-0.096544,-0.227490,-0.065795,0.083030,-0.016752,...,-0.301830,-0.000620,1.433958,0.365661,0.144065,-1.793244,-0.620851,-1.447682,0.076965,1.634137
326318,11227,8,-0.115149,0.204732,0.108475,-0.032418,-0.186207,-0.069652,-0.172419,-0.038763,...,-0.301830,-0.000620,1.433958,0.365661,0.144065,-1.793244,-0.620851,-1.447682,0.076965,1.634137
352687,11227,15,-0.084877,0.131114,0.077526,-0.048014,-0.132597,-0.055589,0.089108,0.138736,...,-0.301830,-0.000620,1.433958,0.365661,0.144065,-1.793244,-0.620851,-1.447682,0.076965,1.634137
291864,11227,34,-0.147116,-0.175355,0.151641,-0.095654,-0.246760,0.014102,0.156275,-0.051353,...,-0.301830,-0.000620,1.433958,0.365661,0.144065,-1.793244,-0.620851,-1.447682,0.076965,1.634137
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40924,6132,16,-0.098334,-0.140107,0.106744,-0.031099,-0.169286,-0.149478,0.018066,0.047286,...,-0.527641,0.135909,0.521594,-0.187243,0.029391,-0.291993,-0.209998,-0.177303,-0.061121,-0.383874
250562,2267,16,-0.098334,-0.140107,0.106744,-0.031099,-0.169286,-0.149478,0.018066,0.047286,...,-1.102729,0.113644,0.915779,-0.215404,-0.390778,-0.870952,-0.095459,-0.055836,0.239556,-0.318991
341892,2267,70,-0.104122,0.178185,0.110333,-0.061403,-0.187474,0.103858,0.059013,0.122808,...,-1.102729,0.113644,0.915779,-0.215404,-0.390778,-0.870952,-0.095459,-0.055836,0.239556,-0.318991
88621,8918,10,-0.145565,-0.037493,0.132847,-0.058355,-0.222131,-0.028390,-0.117221,-0.041754,...,-0.925111,-0.458085,1.452385,0.144862,0.016681,-0.679180,-0.464443,0.055869,-0.087316,0.399645


In [23]:
len(test_joke_df_nofactrating)

362091

Unnamed: 0,j_0,j_1,j_2,j_3,j_4,j_5,j_6,j_7,j_8,j_9,...,j_90,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99
0,-7.820312,8.789062,-9.656250,-8.156250,-7.519531,-8.500000,-9.851562,4.171875,-8.976562,-4.761719,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-5.628906,0.000000,0.000000,0.000000
1,4.078125,-0.290039,6.359375,4.371094,0.000000,-9.656250,-0.729980,-5.339844,8.882812,9.218750,...,2.820312,-4.949219,-0.290039,7.859375,-0.189941,0.000000,3.060547,0.000000,-4.320312,0.000000
2,0.000000,0.000000,0.000000,0.000000,9.031250,9.273438,9.031250,9.273438,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3,0.000000,8.351562,0.000000,0.000000,1.799805,0.000000,-2.820312,6.210938,0.000000,1.839844,...,0.000000,0.000000,0.000000,0.529785,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
4,8.500000,4.609375,-4.171875,-5.390625,0.000000,0.000000,7.039062,0.000000,-0.439941,0.000000,...,0.000000,5.578125,4.269531,5.191406,5.730469,1.549805,0.000000,6.550781,1.799805,1.599609
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24978,0.439941,0.000000,0.000000,2.330078,0.000000,6.750000,-8.789062,-0.529785,0.000000,0.000000,...,8.828125,-1.209961,9.218750,-6.699219,0.000000,9.031250,6.550781,8.687500,0.000000,7.429688
24979,0.000000,-8.156250,8.593750,9.078125,0.870117,-8.929688,-3.500000,5.781250,-8.109375,0.000000,...,-1.169922,-5.730469,0.000000,0.239990,9.218750,-8.203125,0.000000,-8.593750,9.132812,8.453125
24980,0.000000,0.000000,0.000000,0.000000,-7.769531,0.000000,0.000000,-6.750000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
24981,0.000000,0.000000,0.000000,0.000000,-9.710938,0.000000,4.558594,-8.296875,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_90,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99
0,18028,5,-1.26,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
18990,18028,39,8.30,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
50486,18028,6,3.30,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
69643,18028,15,-9.27,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
124589,18028,14,8.11,0.0,0.0,0.0,0.0,1.120117,-1.259766,3.300781,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1303561,21378,61,0.15,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1331118,21378,45,1.31,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1368576,21378,48,-0.68,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0
1376073,21378,53,2.33,0.0,0.0,0.0,0.0,-2.820312,0.000000,-8.007812,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.189941,0.0,0.0


Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,18028,5,-1.26,0.000000,0.000000,0.000000,0.000000,1.120117,-1.259766,3.300781,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
120304,3365,5,3.35,5.050781,7.621094,8.929688,4.320312,0.000000,3.349609,1.700195,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
492994,12734,5,1.26,-2.669922,0.000000,-1.500000,0.000000,-3.500000,1.259766,-0.629883,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
81927,11364,5,-5.87,0.000000,0.000000,0.000000,0.000000,0.000000,-5.871094,1.070312,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
650652,17989,5,-0.29,-0.290039,-3.789062,2.039062,0.000000,8.453125,-0.290039,-7.089844,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501169,19067,84,1.60,0.000000,0.000000,0.000000,0.000000,2.140625,8.062500,-0.099976,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
326575,24909,84,-1.80,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-2.570312,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1040976,9016,84,3.93,0.000000,0.000000,0.000000,0.000000,2.820312,0.000000,0.000000,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1148245,21967,84,3.54,0.000000,0.000000,0.000000,0.000000,1.799805,2.619141,-3.109375,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


  0%|          | 0/1448364 [00:00<?, ?it/s]

  0%|          | 0/1448364 [00:00<?, ?it/s]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
0,18028,5,-1.26,0.000000,0.000000,0.000000,0.000000,1.120117,0.000000,3.300781,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
120304,3365,5,3.35,5.050781,7.621094,8.929688,4.320312,0.000000,0.000000,1.700195,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
492994,12734,5,1.26,-2.669922,0.000000,-1.500000,0.000000,-3.500000,0.000000,-0.629883,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
81927,11364,5,-5.87,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,1.070312,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
650652,17989,5,-0.29,-0.290039,-3.789062,2.039062,0.000000,8.453125,0.000000,-7.089844,...,-0.213626,-0.049605,0.246616,-0.196301,-0.056399,-0.182334,-0.088423,-0.072930,-0.110614,-0.178068
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501169,19067,84,1.60,0.000000,0.000000,0.000000,0.000000,2.140625,8.062500,-0.099976,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
326575,24909,84,-1.80,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-2.570312,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1040976,9016,84,3.93,0.000000,0.000000,0.000000,0.000000,2.820312,0.000000,0.000000,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1148245,21967,84,3.54,0.000000,0.000000,0.000000,0.000000,1.799805,2.619141,-3.109375,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
1273601,6647,84,1.07,0.0,0.0,0.0,0.0,0.0,0.0,-0.680176,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1133216,11958,84,-6.89,0.0,0.0,0.0,0.0,0.0,0.0,-0.48999,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
913962,8582,84,4.9,0.0,0.0,-9.898438,0.0,1.410156,0.72998,2.230469,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
630768,7849,84,-6.41,0.0,0.0,0.0,0.0,1.839844,-9.757812,3.009766,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1205897,12319,84,3.79,0.0,0.0,0.0,0.0,0.0,0.0,-1.549805,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
735943,24881,84,-7.96,0.0,0.0,0.0,0.0,1.169922,0.0,0.0,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1437776,3543,84,3.11,0.0,0.0,0.0,0.0,0.779785,0.0,-6.171875,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
1306126,12924,84,4.9,0.0,0.0,0.0,0.0,6.210938,0.0,0.0,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
137588,3196,84,3.06,0.0,0.0,0.0,0.0,-5.390625,0.0,1.30957,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
512106,4710,84,5.34,3.5,3.539062,1.839844,3.980469,-7.820312,4.269531,-4.609375,...,-0.197037,-0.01147,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882


In [16]:
rows, r_pos = np.unique(train_joke_df.values[:,0], return_inverse=True)
cols, c_pos = np.unique(train_joke_df.values[:,1], return_inverse=True)

In [17]:
rows

array([0.0000e+00, 1.0000e+00, 2.0000e+00, ..., 2.4980e+04, 2.4981e+04,
       2.4982e+04])

In [18]:
cols

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
       52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.,
       65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77.,
       78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90.,
       91., 92., 93., 94., 95., 96., 97., 98., 99.])

In [19]:
train_df, valid_df, train_queries, valid_queries = train_test_split(train_joke_df, r_pos, test_size=0.15, random_state=42)

In [20]:
len(valid_df)

14484

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31


Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
337935,44,55,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313
1072225,44,11,7.04,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.209348,-0.090472,0.253714,-0.023258,-0.070321,-0.189687,-0.081262,-0.079108,-0.113483,-0.156067
1310937,44,42,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.174013,-0.000910,0.192058,-0.175867,-0.054576,-0.146486,-0.046585,-0.003952,-0.090521,-0.140018
1190578,44,77,3.54,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.206961,0.058762,0.248765,-0.056504,-0.066395,-0.178045,-0.094212,0.032349,-0.105475,-0.168493
551466,44,35,8.50,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.236904,-0.117266,0.281734,0.032962,-0.064384,-0.221286,-0.083540,-0.077467,-0.120540,-0.192433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
464562,44,84,1.31,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
693981,44,22,8.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.182864,-0.000172,0.222849,-0.183920,-0.049851,-0.167913,-0.060558,-0.049183,-0.100267,-0.148523
243050,44,27,8.01,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
1244955,44,15,3.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.131025,0.058989,0.154818,0.073774,-0.055937,-0.112853,-0.049320,0.012145,-0.068114,-0.099235


  0%|          | 0/14484 [00:00<?, ?it/s]

Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
337935,44,55,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313
1072225,44,11,7.04,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.209348,-0.090472,0.253714,-0.023258,-0.070321,-0.189687,-0.081262,-0.079108,-0.113483,-0.156067
1310937,44,42,7.91,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.174013,-0.000910,0.192058,-0.175867,-0.054576,-0.146486,-0.046585,-0.003952,-0.090521,-0.140018
1190578,44,77,3.54,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.206961,0.058762,0.248765,-0.056504,-0.066395,-0.178045,-0.094212,0.032349,-0.105475,-0.168493
551466,44,35,8.50,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.236904,-0.117266,0.281734,0.032962,-0.064384,-0.221286,-0.083540,-0.077467,-0.120540,-0.192433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
464562,44,84,1.31,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.197037,-0.011470,0.221943,0.054523,-0.081366,-0.173583,-0.074338,-0.039511,-0.101258,-0.150882
693981,44,22,8.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.182864,-0.000172,0.222849,-0.183920,-0.049851,-0.167913,-0.060558,-0.049183,-0.100267,-0.148523
243050,44,27,8.01,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
1244955,44,15,3.83,8.789062,0.0,5.488281,0.919922,0.0,0.0,4.511719,...,-0.131025,0.058989,0.154818,0.073774,-0.055937,-0.112853,-0.049320,0.012145,-0.068114,-0.099235


Unnamed: 0,UID,JID,Rating,j_0,j_1,j_2,j_3,j_4,j_5,j_6,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31


In [26]:
from catboost import CatBoostRanker, Pool, MetricVisualizer, CatBoostRegressor
from copy import deepcopy

In [27]:
train_df.drop(columns=['Rating'])

Unnamed: 0,UID,JID,j_0,j_1,j_2,j_3,j_4,j_5,j_6,j_7,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
1315628,14232,13,4.128906,-7.179688,3.880859,-5.679688,-9.320312,0.000000,1.209961,-4.371094,...,-0.219153,-0.062274,0.247847,0.110737,-0.055423,-0.186064,-0.078129,-0.099043,-0.112296,-0.165175
1393766,8113,45,5.578125,4.371094,4.421875,0.000000,2.330078,4.371094,-2.429688,-3.789062,...,-0.215624,-0.081347,0.247605,-0.001797,-0.042055,-0.191130,-0.068965,-0.114132,-0.109277,-0.169667
108972,18801,27,0.000000,0.000000,0.000000,0.000000,-5.781250,0.000000,-1.750000,0.000000,...,-0.217150,0.062103,0.245869,-0.273162,-0.056651,-0.187698,-0.077184,0.012077,-0.119872,-0.177158
199734,3512,1,0.870117,0.000000,-1.500000,0.000000,-3.160156,0.099976,-2.820312,1.309570,...,-0.188492,-0.215398,0.215273,0.228148,-0.045374,-0.168430,-0.083771,-0.114539,-0.096700,-0.148389
1035265,14831,67,0.000000,0.000000,0.000000,0.000000,3.740234,-0.150024,6.019531,2.519531,...,-0.234150,0.020325,0.271443,-0.232583,-0.072568,-0.212612,-0.089987,0.008103,-0.126860,-0.182385
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
633634,11081,21,0.870117,0.000000,0.580078,0.000000,2.960938,1.169922,0.000000,-5.921875,...,-0.196241,-0.090978,0.227133,0.065467,-0.044386,-0.185955,-0.080973,-0.110324,-0.102693,-0.157451
951565,613,76,-7.378906,-8.789062,4.128906,-7.230469,7.179688,0.000000,-4.949219,5.531250,...,-0.206986,0.132555,0.220478,-0.072370,-0.064179,-0.168671,-0.090839,0.048417,-0.101191,-0.162371
407318,6382,7,-0.340088,-9.273438,-0.239990,4.558594,4.078125,1.209961,0.000000,0.000000,...,-0.172836,-0.078154,0.211503,-0.105241,-0.045931,-0.144566,-0.062419,-0.027954,-0.094376,-0.141616
567497,1737,55,1.309570,-0.580078,7.960938,-3.300781,-1.799805,0.000000,3.349609,0.000000,...,-0.217123,-0.072647,0.254742,0.037862,-0.048185,-0.195705,-0.074436,-0.072623,-0.124790,-0.166313


In [28]:
cat_features = ['UID', 'JID']

In [29]:
train_pool = Pool(train_df.drop(columns=['Rating']), label=train_df['Rating'], cat_features=cat_features)
valid_pool = Pool(valid_df.drop(columns=['Rating']), label=valid_df['Rating'], cat_features=cat_features)
main_pool = Pool(train_joke_df.drop(columns=['Rating']), label=train_joke_df['Rating'], cat_features=cat_features)

test_pool = Pool(test_joke_df_nofactrating, cat_features=cat_features)

In [30]:
#train = Pool(
#    data=X_train,
#    label=y_train,
#    group_id=queries_train,
#    cat_features=[0, 1]
#)

#test = Pool(
#    data=X_test,
#    label=y_test,
#    group_id=queries_test,
#    cat_features=[0, 1]
#)

In [34]:
default_parameters = {
    'iterations': 1000,
    'custom_metric': 'RMSE',
    'random_seed': 0,
    'train_dir':'RMSE',
    'objective':'RMSE',
    'loss_function':'RMSE',
    'eval_metric':'RMSE',
}


In [35]:
model = CatBoostRegressor(**default_parameters)
model.fit(train_pool, eval_set=valid_pool, plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

Learning rate set to 0.159393
0:	learn: 5.0587689	test: 5.0007497	best: 5.0007497 (0)	total: 425ms	remaining: 7m 4s
1:	learn: 4.9225808	test: 4.8534109	best: 4.8534109 (1)	total: 824ms	remaining: 6m 51s
2:	learn: 4.8220048	test: 4.7438272	best: 4.7438272 (2)	total: 1.07s	remaining: 5m 56s
3:	learn: 4.7439467	test: 4.6590053	best: 4.6590053 (3)	total: 1.51s	remaining: 6m 15s
4:	learn: 4.6832274	test: 4.5939083	best: 4.5939083 (4)	total: 1.97s	remaining: 6m 31s
5:	learn: 4.6342774	test: 4.5425347	best: 4.5425347 (5)	total: 2.44s	remaining: 6m 43s
6:	learn: 4.5950590	test: 4.5019889	best: 4.5019889 (6)	total: 2.74s	remaining: 6m 28s
7:	learn: 4.5644862	test: 4.4682614	best: 4.4682614 (7)	total: 3.14s	remaining: 6m 29s
8:	learn: 4.5369596	test: 4.4386635	best: 4.4386635 (8)	total: 3.62s	remaining: 6m 39s
9:	learn: 4.5100353	test: 4.4204260	best: 4.4204260 (9)	total: 4.15s	remaining: 6m 51s
10:	learn: 4.4859270	test: 4.4031564	best: 4.4031564 (10)	total: 4.73s	remaining: 7m 5s
11:	learn: 4.

93:	learn: 3.9536046	test: 4.2110865	best: 4.2110865 (93)	total: 46.3s	remaining: 7m 26s
94:	learn: 3.9495800	test: 4.2102079	best: 4.2102079 (94)	total: 46.9s	remaining: 7m 26s
95:	learn: 3.9463649	test: 4.2096610	best: 4.2096610 (95)	total: 47.3s	remaining: 7m 25s
96:	learn: 3.9428369	test: 4.2099697	best: 4.2096610 (95)	total: 47.9s	remaining: 7m 26s
97:	learn: 3.9373417	test: 4.2089812	best: 4.2089812 (97)	total: 48.6s	remaining: 7m 27s
98:	learn: 3.9320030	test: 4.2087957	best: 4.2087957 (98)	total: 49.1s	remaining: 7m 26s
99:	learn: 3.9284220	test: 4.2079422	best: 4.2079422 (99)	total: 49.6s	remaining: 7m 26s
100:	learn: 3.9259477	test: 4.2069626	best: 4.2069626 (100)	total: 50.1s	remaining: 7m 25s
101:	learn: 3.9242408	test: 4.2066579	best: 4.2066579 (101)	total: 50.6s	remaining: 7m 25s
102:	learn: 3.9170589	test: 4.2053927	best: 4.2053927 (102)	total: 51.2s	remaining: 7m 25s
103:	learn: 3.9131824	test: 4.2052181	best: 4.2052181 (103)	total: 51.8s	remaining: 7m 26s
104:	learn: 3

184:	learn: 3.6214558	test: 4.1775287	best: 4.1770739 (183)	total: 1m 31s	remaining: 6m 44s
185:	learn: 3.6191946	test: 4.1770564	best: 4.1770564 (185)	total: 1m 32s	remaining: 6m 43s
186:	learn: 3.6118576	test: 4.1764820	best: 4.1764820 (186)	total: 1m 32s	remaining: 6m 43s
187:	learn: 3.6088850	test: 4.1768109	best: 4.1764820 (186)	total: 1m 33s	remaining: 6m 43s
188:	learn: 3.6065964	test: 4.1763811	best: 4.1763811 (188)	total: 1m 33s	remaining: 6m 43s
189:	learn: 3.6029030	test: 4.1762258	best: 4.1762258 (189)	total: 1m 34s	remaining: 6m 42s
190:	learn: 3.5994551	test: 4.1760228	best: 4.1760228 (190)	total: 1m 34s	remaining: 6m 41s
191:	learn: 3.5964387	test: 4.1758658	best: 4.1758658 (191)	total: 1m 35s	remaining: 6m 41s
192:	learn: 3.5919682	test: 4.1753414	best: 4.1753414 (192)	total: 1m 36s	remaining: 6m 41s
193:	learn: 3.5884125	test: 4.1754185	best: 4.1753414 (192)	total: 1m 36s	remaining: 6m 40s
194:	learn: 3.5837032	test: 4.1753761	best: 4.1753414 (192)	total: 1m 36s	remain

274:	learn: 3.3596649	test: 4.1702618	best: 4.1696171 (267)	total: 2m 16s	remaining: 6m
275:	learn: 3.3580131	test: 4.1701978	best: 4.1696171 (267)	total: 2m 17s	remaining: 5m 59s
276:	learn: 3.3563404	test: 4.1701925	best: 4.1696171 (267)	total: 2m 17s	remaining: 5m 59s
277:	learn: 3.3526618	test: 4.1702828	best: 4.1696171 (267)	total: 2m 18s	remaining: 5m 59s
278:	learn: 3.3500403	test: 4.1702026	best: 4.1696171 (267)	total: 2m 18s	remaining: 5m 58s
279:	learn: 3.3490437	test: 4.1698467	best: 4.1696171 (267)	total: 2m 19s	remaining: 5m 57s
280:	learn: 3.3467679	test: 4.1698222	best: 4.1696171 (267)	total: 2m 19s	remaining: 5m 57s
281:	learn: 3.3441967	test: 4.1693310	best: 4.1693310 (281)	total: 2m 20s	remaining: 5m 56s
282:	learn: 3.3418722	test: 4.1691697	best: 4.1691697 (282)	total: 2m 20s	remaining: 5m 56s
283:	learn: 3.3380946	test: 4.1702733	best: 4.1691697 (282)	total: 2m 21s	remaining: 5m 55s
284:	learn: 3.3370318	test: 4.1698342	best: 4.1691697 (282)	total: 2m 21s	remaining:

364:	learn: 3.1550914	test: 4.1752164	best: 4.1683660 (315)	total: 3m 1s	remaining: 5m 14s
365:	learn: 3.1542489	test: 4.1753672	best: 4.1683660 (315)	total: 3m 1s	remaining: 5m 14s
366:	learn: 3.1531026	test: 4.1757277	best: 4.1683660 (315)	total: 3m 2s	remaining: 5m 13s
367:	learn: 3.1508054	test: 4.1762335	best: 4.1683660 (315)	total: 3m 2s	remaining: 5m 13s
368:	learn: 3.1488758	test: 4.1770892	best: 4.1683660 (315)	total: 3m 2s	remaining: 5m 12s
369:	learn: 3.1470976	test: 4.1775668	best: 4.1683660 (315)	total: 3m 3s	remaining: 5m 12s
370:	learn: 3.1462808	test: 4.1777110	best: 4.1683660 (315)	total: 3m 4s	remaining: 5m 12s
371:	learn: 3.1447571	test: 4.1778077	best: 4.1683660 (315)	total: 3m 4s	remaining: 5m 11s
372:	learn: 3.1426920	test: 4.1777171	best: 4.1683660 (315)	total: 3m 5s	remaining: 5m 11s
373:	learn: 3.1408800	test: 4.1782417	best: 4.1683660 (315)	total: 3m 5s	remaining: 5m 10s
374:	learn: 3.1392863	test: 4.1785204	best: 4.1683660 (315)	total: 3m 6s	remaining: 5m 10s

454:	learn: 2.9793343	test: 4.1906715	best: 4.1683660 (315)	total: 3m 47s	remaining: 4m 32s
455:	learn: 2.9782297	test: 4.1904078	best: 4.1683660 (315)	total: 3m 47s	remaining: 4m 31s
456:	learn: 2.9769530	test: 4.1905309	best: 4.1683660 (315)	total: 3m 48s	remaining: 4m 31s
457:	learn: 2.9749713	test: 4.1900881	best: 4.1683660 (315)	total: 3m 48s	remaining: 4m 30s
458:	learn: 2.9738210	test: 4.1903944	best: 4.1683660 (315)	total: 3m 49s	remaining: 4m 30s
459:	learn: 2.9722269	test: 4.1904621	best: 4.1683660 (315)	total: 3m 49s	remaining: 4m 29s
460:	learn: 2.9708155	test: 4.1905361	best: 4.1683660 (315)	total: 3m 49s	remaining: 4m 28s
461:	learn: 2.9703522	test: 4.1906820	best: 4.1683660 (315)	total: 3m 50s	remaining: 4m 28s
462:	learn: 2.9656846	test: 4.1914481	best: 4.1683660 (315)	total: 3m 50s	remaining: 4m 27s
463:	learn: 2.9638929	test: 4.1914399	best: 4.1683660 (315)	total: 3m 51s	remaining: 4m 27s
464:	learn: 2.9615940	test: 4.1919996	best: 4.1683660 (315)	total: 3m 51s	remain

544:	learn: 2.8414872	test: 4.2054276	best: 4.1683660 (315)	total: 4m 32s	remaining: 3m 47s
545:	learn: 2.8390518	test: 4.2053107	best: 4.1683660 (315)	total: 4m 32s	remaining: 3m 46s
546:	learn: 2.8374473	test: 4.2052475	best: 4.1683660 (315)	total: 4m 33s	remaining: 3m 46s
547:	learn: 2.8352090	test: 4.2058146	best: 4.1683660 (315)	total: 4m 33s	remaining: 3m 45s
548:	learn: 2.8331563	test: 4.2063189	best: 4.1683660 (315)	total: 4m 34s	remaining: 3m 45s
549:	learn: 2.8311660	test: 4.2069941	best: 4.1683660 (315)	total: 4m 35s	remaining: 3m 45s
550:	learn: 2.8291156	test: 4.2070940	best: 4.1683660 (315)	total: 4m 35s	remaining: 3m 44s
551:	learn: 2.8276971	test: 4.2079877	best: 4.1683660 (315)	total: 4m 36s	remaining: 3m 44s
552:	learn: 2.8247887	test: 4.2083558	best: 4.1683660 (315)	total: 4m 36s	remaining: 3m 43s
553:	learn: 2.8236956	test: 4.2088789	best: 4.1683660 (315)	total: 4m 37s	remaining: 3m 43s
554:	learn: 2.8206924	test: 4.2090203	best: 4.1683660 (315)	total: 4m 37s	remain

634:	learn: 2.7195636	test: 4.2221499	best: 4.1683660 (315)	total: 5m 18s	remaining: 3m 3s
635:	learn: 2.7158391	test: 4.2224272	best: 4.1683660 (315)	total: 5m 19s	remaining: 3m 2s
636:	learn: 2.7151049	test: 4.2225619	best: 4.1683660 (315)	total: 5m 19s	remaining: 3m 2s
637:	learn: 2.7142910	test: 4.2229588	best: 4.1683660 (315)	total: 5m 20s	remaining: 3m 1s
638:	learn: 2.7135415	test: 4.2232610	best: 4.1683660 (315)	total: 5m 20s	remaining: 3m 1s
639:	learn: 2.7091470	test: 4.2235723	best: 4.1683660 (315)	total: 5m 21s	remaining: 3m
640:	learn: 2.7040651	test: 4.2234160	best: 4.1683660 (315)	total: 5m 21s	remaining: 3m
641:	learn: 2.6999611	test: 4.2233665	best: 4.1683660 (315)	total: 5m 21s	remaining: 2m 59s
642:	learn: 2.6983380	test: 4.2230754	best: 4.1683660 (315)	total: 5m 22s	remaining: 2m 58s
643:	learn: 2.6953511	test: 4.2236763	best: 4.1683660 (315)	total: 5m 22s	remaining: 2m 58s
644:	learn: 2.6945016	test: 4.2238148	best: 4.1683660 (315)	total: 5m 23s	remaining: 2m 58s
6

724:	learn: 2.5896458	test: 4.2416434	best: 4.1683660 (315)	total: 6m 4s	remaining: 2m 18s
725:	learn: 2.5859653	test: 4.2416449	best: 4.1683660 (315)	total: 6m 5s	remaining: 2m 17s
726:	learn: 2.5845756	test: 4.2421328	best: 4.1683660 (315)	total: 6m 5s	remaining: 2m 17s
727:	learn: 2.5832381	test: 4.2421909	best: 4.1683660 (315)	total: 6m 6s	remaining: 2m 16s
728:	learn: 2.5798404	test: 4.2420168	best: 4.1683660 (315)	total: 6m 6s	remaining: 2m 16s
729:	learn: 2.5784656	test: 4.2421269	best: 4.1683660 (315)	total: 6m 7s	remaining: 2m 15s
730:	learn: 2.5782226	test: 4.2423621	best: 4.1683660 (315)	total: 6m 7s	remaining: 2m 15s
731:	learn: 2.5763906	test: 4.2427244	best: 4.1683660 (315)	total: 6m 8s	remaining: 2m 14s
732:	learn: 2.5757313	test: 4.2430427	best: 4.1683660 (315)	total: 6m 8s	remaining: 2m 14s
733:	learn: 2.5733017	test: 4.2431453	best: 4.1683660 (315)	total: 6m 9s	remaining: 2m 13s
734:	learn: 2.5725732	test: 4.2433107	best: 4.1683660 (315)	total: 6m 9s	remaining: 2m 13s

814:	learn: 2.4673831	test: 4.2655024	best: 4.1683660 (315)	total: 6m 51s	remaining: 1m 33s
815:	learn: 2.4669723	test: 4.2658018	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 32s
816:	learn: 2.4667856	test: 4.2659813	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 32s
817:	learn: 2.4659557	test: 4.2657238	best: 4.1683660 (315)	total: 6m 52s	remaining: 1m 31s
818:	learn: 2.4654367	test: 4.2659142	best: 4.1683660 (315)	total: 6m 53s	remaining: 1m 31s
819:	learn: 2.4646870	test: 4.2658638	best: 4.1683660 (315)	total: 6m 54s	remaining: 1m 30s
820:	learn: 2.4632741	test: 4.2659559	best: 4.1683660 (315)	total: 6m 54s	remaining: 1m 30s
821:	learn: 2.4593483	test: 4.2663456	best: 4.1683660 (315)	total: 6m 55s	remaining: 1m 29s
822:	learn: 2.4582956	test: 4.2665758	best: 4.1683660 (315)	total: 6m 55s	remaining: 1m 29s
823:	learn: 2.4569880	test: 4.2667025	best: 4.1683660 (315)	total: 6m 55s	remaining: 1m 28s
824:	learn: 2.4512348	test: 4.2682991	best: 4.1683660 (315)	total: 6m 56s	remain

904:	learn: 2.3533299	test: 4.2940351	best: 4.1683660 (315)	total: 7m 38s	remaining: 48.1s
905:	learn: 2.3511675	test: 4.2941498	best: 4.1683660 (315)	total: 7m 38s	remaining: 47.6s
906:	learn: 2.3503868	test: 4.2944133	best: 4.1683660 (315)	total: 7m 39s	remaining: 47.1s
907:	learn: 2.3499226	test: 4.2946030	best: 4.1683660 (315)	total: 7m 39s	remaining: 46.6s
908:	learn: 2.3486949	test: 4.2950759	best: 4.1683660 (315)	total: 7m 40s	remaining: 46.1s
909:	learn: 2.3482436	test: 4.2955499	best: 4.1683660 (315)	total: 7m 40s	remaining: 45.6s
910:	learn: 2.3466237	test: 4.2958444	best: 4.1683660 (315)	total: 7m 41s	remaining: 45.1s
911:	learn: 2.3434634	test: 4.2963961	best: 4.1683660 (315)	total: 7m 41s	remaining: 44.6s
912:	learn: 2.3420296	test: 4.2981702	best: 4.1683660 (315)	total: 7m 42s	remaining: 44.1s
913:	learn: 2.3407932	test: 4.2979007	best: 4.1683660 (315)	total: 7m 42s	remaining: 43.6s
914:	learn: 2.3394342	test: 4.2980770	best: 4.1683660 (315)	total: 7m 43s	remaining: 43s
9

995:	learn: 2.2519877	test: 4.3150498	best: 4.1683660 (315)	total: 8m 25s	remaining: 2.03s
996:	learn: 2.2517407	test: 4.3152867	best: 4.1683660 (315)	total: 8m 25s	remaining: 1.52s
997:	learn: 2.2509928	test: 4.3158729	best: 4.1683660 (315)	total: 8m 26s	remaining: 1.01s
998:	learn: 2.2506826	test: 4.3160053	best: 4.1683660 (315)	total: 8m 26s	remaining: 507ms
999:	learn: 2.2504299	test: 4.3162907	best: 4.1683660 (315)	total: 8m 27s	remaining: 0us

bestTest = 4.16836603
bestIteration = 315

Shrink model to first 316 iterations.


<catboost.core.CatBoostRegressor at 0x23cb0904520>

In [None]:
assert False

In [48]:
predict = model.predict(valid_pool)
print(mean_squared_error(valid_df['Rating'].values, predict, squared=False))


4.168366025985882


In [42]:

test_joke_df_nofactrating = pd.read_csv(r'..\data\recsys-in-practice\test_joke_df_nofactrating.csv', index_col=0)

In [43]:
test_joke_df_nofactrating = test_joke_df_nofactrating.merge(jokes_df, how='left', left_on='JID',  right_index=True)
test_joke_df_nofactrating

Unnamed: 0_level_0,UID,JID,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,...,j_emb_22,j_emb_23,j_emb_24,j_emb_25,j_emb_26,j_emb_27,j_emb_28,j_emb_29,j_emb_30,j_emb_31
InteractionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,11228,39,-0.119995,0.018505,0.148837,-0.064350,-0.206992,-0.003035,-0.106557,0.064814,...,-0.187694,-0.096055,0.238734,0.032480,-0.047392,-0.190207,-0.081313,-0.056617,-0.113464,-0.170966
1,21724,85,-0.116442,0.197225,0.111184,-0.013174,-0.191459,-0.060305,0.054180,0.034010,...,-0.175970,0.123419,0.218063,-0.197847,-0.074726,-0.167683,-0.088025,0.039383,-0.107756,-0.149503
2,16782,56,-0.094673,0.287483,0.085811,-0.022507,-0.153191,-0.084600,-0.182968,0.080868,...,-0.150467,0.055504,0.166492,-0.112180,-0.041445,-0.136276,-0.044525,-0.031849,-0.103089,-0.116734
3,12105,42,-0.117512,0.179030,0.102959,-0.041321,-0.176818,-0.084887,-0.225433,-0.059977,...,-0.174013,-0.000910,0.192058,-0.175867,-0.054576,-0.146486,-0.046585,-0.003952,-0.090521,-0.140018
4,14427,2,-0.131769,0.053609,0.104178,-0.046955,-0.200002,0.016845,-0.251092,0.009414,...,-0.190417,-0.176878,0.221158,-0.021510,-0.072628,-0.171978,-0.068281,-0.097790,-0.112991,-0.154231
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
362086,3085,66,-0.122767,0.244425,0.110366,-0.008108,-0.183211,-0.151318,-0.136290,-0.093442,...,-0.172353,0.021159,0.198583,-0.183121,-0.049692,-0.144667,-0.076905,-0.075946,-0.083077,-0.131158
362087,13765,31,-0.147003,-0.216622,0.149209,-0.078921,-0.247421,0.087036,0.184531,-0.005167,...,-0.232575,-0.054796,0.284917,-0.067532,-0.064056,-0.231741,-0.084266,-0.037552,-0.127999,-0.193191
362088,10341,29,-0.112254,0.189472,0.113607,-0.052574,-0.182198,-0.012057,-0.170762,0.065305,...,-0.180020,0.004690,0.206803,-0.159529,-0.046954,-0.173159,-0.064026,-0.018296,-0.096136,-0.141635
362089,3553,8,-0.115149,0.204732,0.108475,-0.032418,-0.186207,-0.069652,-0.172419,-0.038763,...,-0.173986,0.033112,0.205515,-0.258891,-0.056810,-0.164397,-0.073520,-0.043038,-0.089910,-0.127763


In [44]:
test_joke_df_nofactrating = test_joke_df_nofactrating.merge(users_df, how='left', left_on='UID',  right_index=True)
test_joke_df_nofactrating

Unnamed: 0_level_0,UID,JID,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,...,j_90,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99
InteractionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,11228,39,-0.119995,0.018505,0.148837,-0.064350,-0.206992,-0.003035,-0.106557,0.064814,...,0.000000,9.171875,5.531250,0.000000,0.000000,0.000000,9.218750,0.000000,0.000000,0.000000
1,21724,85,-0.116442,0.197225,0.111184,-0.013174,-0.191459,-0.060305,0.054180,0.034010,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2,16782,56,-0.094673,0.287483,0.085811,-0.022507,-0.153191,-0.084600,-0.182968,0.080868,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3,12105,42,-0.117512,0.179030,0.102959,-0.041321,-0.176818,-0.084887,-0.225433,-0.059977,...,0.000000,-4.558594,5.148438,3.060547,-4.078125,4.660156,3.689453,0.489990,5.101562,0.000000
4,14427,2,-0.131769,0.053609,0.104178,-0.046955,-0.200002,0.016845,-0.251092,0.009414,...,2.039062,3.980469,0.000000,-9.851562,8.687500,-0.529785,-0.439941,-5.339844,0.000000,3.109375
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
362086,3085,66,-0.122767,0.244425,0.110366,-0.008108,-0.183211,-0.151318,-0.136290,-0.093442,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-0.049988,0.000000
362087,13765,31,-0.147003,-0.216622,0.149209,-0.078921,-0.247421,0.087036,0.184531,-0.005167,...,-0.529785,0.000000,0.729980,-0.389893,1.019531,-1.549805,-3.880859,0.000000,1.839844,-0.099976
362088,10341,29,-0.112254,0.189472,0.113607,-0.052574,-0.182198,-0.012057,-0.170762,0.065305,...,2.720703,1.650391,0.000000,1.500000,0.000000,1.650391,3.199219,5.050781,2.859375,-2.330078
362089,3553,8,-0.115149,0.204732,0.108475,-0.032418,-0.186207,-0.069652,-0.172419,-0.038763,...,6.460938,0.000000,6.359375,0.000000,7.140625,6.750000,6.988281,-8.687500,5.531250,0.000000


In [45]:

test_pool = Pool(test_joke_df_nofactrating, cat_features=cat_features)

predict = model.predict(test_pool)

test_joke_df_nofactrating['Rating'] = predict

display(test_joke_df_nofactrating['Rating'].to_frame().head(5))
test_joke_df_nofactrating['Rating'].to_frame().to_csv('catboost_with_rating_and_item_emb.csv')

Unnamed: 0_level_0,Rating
InteractionID,Unnamed: 1_level_1
0,2.219415
1,-0.918924
2,-4.171209
3,1.776806
4,-0.698705


In [46]:
test_joke_df_nofactrating

Unnamed: 0_level_0,UID,JID,j_emb_0,j_emb_1,j_emb_2,j_emb_3,j_emb_4,j_emb_5,j_emb_6,j_emb_7,...,j_91,j_92,j_93,j_94,j_95,j_96,j_97,j_98,j_99,Rating
InteractionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,11228,39,-0.119995,0.018505,0.148837,-0.064350,-0.206992,-0.003035,-0.106557,0.064814,...,9.171875,5.531250,0.000000,0.000000,0.000000,9.218750,0.000000,0.000000,0.000000,2.219415
1,21724,85,-0.116442,0.197225,0.111184,-0.013174,-0.191459,-0.060305,0.054180,0.034010,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-0.918924
2,16782,56,-0.094673,0.287483,0.085811,-0.022507,-0.153191,-0.084600,-0.182968,0.080868,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-4.171209
3,12105,42,-0.117512,0.179030,0.102959,-0.041321,-0.176818,-0.084887,-0.225433,-0.059977,...,-4.558594,5.148438,3.060547,-4.078125,4.660156,3.689453,0.489990,5.101562,0.000000,1.776806
4,14427,2,-0.131769,0.053609,0.104178,-0.046955,-0.200002,0.016845,-0.251092,0.009414,...,3.980469,0.000000,-9.851562,8.687500,-0.529785,-0.439941,-5.339844,0.000000,3.109375,-0.698705
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
362086,3085,66,-0.122767,0.244425,0.110366,-0.008108,-0.183211,-0.151318,-0.136290,-0.093442,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-0.049988,0.000000,0.007615
362087,13765,31,-0.147003,-0.216622,0.149209,-0.078921,-0.247421,0.087036,0.184531,-0.005167,...,0.000000,0.729980,-0.389893,1.019531,-1.549805,-3.880859,0.000000,1.839844,-0.099976,3.336360
362088,10341,29,-0.112254,0.189472,0.113607,-0.052574,-0.182198,-0.012057,-0.170762,0.065305,...,1.650391,0.000000,1.500000,0.000000,1.650391,3.199219,5.050781,2.859375,-2.330078,0.588080
362089,3553,8,-0.115149,0.204732,0.108475,-0.032418,-0.186207,-0.069652,-0.172419,-0.038763,...,0.000000,6.359375,0.000000,7.140625,6.750000,6.988281,-8.687500,5.531250,0.000000,-1.757688


In [47]:

display(test_joke_df_nofactrating['Rating'].to_frame().head(30))

Unnamed: 0_level_0,Rating
InteractionID,Unnamed: 1_level_1
0,2.219415
1,-0.918924
2,-4.171209
3,1.776806
4,-0.698705
5,2.577563
6,3.184241
7,2.238143
8,1.018018
9,-2.958534


In [33]:
parameters = {
    'iterations': 1000,
    'custom_metric': 'RMSE',
    'random_seed': 0,
    'train_dir':'RMSE',
    'objective':'RMSE',
    'loss_function':'RMSE',
    'eval_metric':'RMSE',
}

model = CatBoostRegressor(**parameters)
model.fit(main_pool, eval_set=valid_pool, plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

Learning rate set to 0.159644
0:	learn: 5.0573094	test: 4.9897510	best: 4.9897510 (0)	total: 391ms	remaining: 6m 30s
1:	learn: 4.9230318	test: 4.8397172	best: 4.8397172 (1)	total: 965ms	remaining: 8m 1s
2:	learn: 4.8201945	test: 4.7233051	best: 4.7233051 (2)	total: 1.24s	remaining: 6m 52s
3:	learn: 4.7413556	test: 4.6325059	best: 4.6325059 (3)	total: 1.85s	remaining: 7m 39s
4:	learn: 4.6794453	test: 4.5621262	best: 4.5621262 (4)	total: 2.11s	remaining: 7m
5:	learn: 4.6315501	test: 4.5120370	best: 4.5120370 (5)	total: 2.44s	remaining: 6m 44s
6:	learn: 4.5927317	test: 4.4711692	best: 4.4711692 (6)	total: 2.93s	remaining: 6m 55s
7:	learn: 4.5604341	test: 4.4343370	best: 4.4343370 (7)	total: 3.21s	remaining: 6m 38s
8:	learn: 4.5336764	test: 4.4022881	best: 4.4022881 (8)	total: 3.71s	remaining: 6m 48s
9:	learn: 4.5118377	test: 4.3792587	best: 4.3792587 (9)	total: 4.04s	remaining: 6m 39s
10:	learn: 4.4923950	test: 4.3612634	best: 4.3612634 (10)	total: 4.48s	remaining: 6m 42s
11:	learn: 4.468

93:	learn: 3.9549474	test: 3.8800950	best: 3.8800950 (93)	total: 46.7s	remaining: 7m 29s
94:	learn: 3.9476413	test: 3.8733198	best: 3.8733198 (94)	total: 47.2s	remaining: 7m 29s
95:	learn: 3.9450710	test: 3.8713711	best: 3.8713711 (95)	total: 47.7s	remaining: 7m 29s
96:	learn: 3.9406671	test: 3.8675562	best: 3.8675562 (96)	total: 48.2s	remaining: 7m 29s
97:	learn: 3.9370694	test: 3.8640788	best: 3.8640788 (97)	total: 48.9s	remaining: 7m 30s
98:	learn: 3.9336525	test: 3.8617196	best: 3.8617196 (98)	total: 49.3s	remaining: 7m 29s
99:	learn: 3.9296740	test: 3.8575888	best: 3.8575888 (99)	total: 50s	remaining: 7m 29s
100:	learn: 3.9242294	test: 3.8536524	best: 3.8536524 (100)	total: 50.4s	remaining: 7m 28s
101:	learn: 3.9201787	test: 3.8500103	best: 3.8500103 (101)	total: 50.9s	remaining: 7m 28s
102:	learn: 3.9152997	test: 3.8465540	best: 3.8465540 (102)	total: 51.3s	remaining: 7m 26s
103:	learn: 3.9100816	test: 3.8423329	best: 3.8423329 (103)	total: 51.9s	remaining: 7m 27s
104:	learn: 3.9

184:	learn: 3.6132747	test: 3.5616655	best: 3.5616655 (184)	total: 1m 32s	remaining: 6m 49s
185:	learn: 3.6102218	test: 3.5589154	best: 3.5589154 (185)	total: 1m 33s	remaining: 6m 48s
186:	learn: 3.6078082	test: 3.5565991	best: 3.5565991 (186)	total: 1m 33s	remaining: 6m 47s
187:	learn: 3.6065143	test: 3.5555380	best: 3.5555380 (187)	total: 1m 34s	remaining: 6m 47s
188:	learn: 3.6020293	test: 3.5518481	best: 3.5518481 (188)	total: 1m 34s	remaining: 6m 47s
189:	learn: 3.6006107	test: 3.5499939	best: 3.5499939 (189)	total: 1m 35s	remaining: 6m 46s
190:	learn: 3.5968585	test: 3.5466790	best: 3.5466790 (190)	total: 1m 35s	remaining: 6m 45s
191:	learn: 3.5928080	test: 3.5434139	best: 3.5434139 (191)	total: 1m 36s	remaining: 6m 45s
192:	learn: 3.5864604	test: 3.5371579	best: 3.5371579 (192)	total: 1m 36s	remaining: 6m 44s
193:	learn: 3.5808929	test: 3.5325545	best: 3.5325545 (193)	total: 1m 37s	remaining: 6m 43s
194:	learn: 3.5795976	test: 3.5305621	best: 3.5305621 (194)	total: 1m 37s	remain

274:	learn: 3.3695485	test: 3.3346832	best: 3.3346832 (274)	total: 2m 17s	remaining: 6m 3s
275:	learn: 3.3659300	test: 3.3321192	best: 3.3321192 (275)	total: 2m 18s	remaining: 6m 3s
276:	learn: 3.3642348	test: 3.3304725	best: 3.3304725 (276)	total: 2m 18s	remaining: 6m 2s
277:	learn: 3.3629485	test: 3.3295020	best: 3.3295020 (277)	total: 2m 19s	remaining: 6m 2s
278:	learn: 3.3606948	test: 3.3277926	best: 3.3277926 (278)	total: 2m 19s	remaining: 6m 1s
279:	learn: 3.3556507	test: 3.3239712	best: 3.3239712 (279)	total: 2m 20s	remaining: 6m
280:	learn: 3.3519899	test: 3.3209416	best: 3.3209416 (280)	total: 2m 20s	remaining: 6m
281:	learn: 3.3511678	test: 3.3204159	best: 3.3204159 (281)	total: 2m 21s	remaining: 5m 59s
282:	learn: 3.3479504	test: 3.3184955	best: 3.3184955 (282)	total: 2m 21s	remaining: 5m 59s
283:	learn: 3.3459817	test: 3.3168053	best: 3.3168053 (283)	total: 2m 22s	remaining: 5m 58s
284:	learn: 3.3433071	test: 3.3136904	best: 3.3136904 (284)	total: 2m 22s	remaining: 5m 58s
2

364:	learn: 3.1370242	test: 3.1239197	best: 3.1239197 (364)	total: 3m 2s	remaining: 5m 18s
365:	learn: 3.1361216	test: 3.1225019	best: 3.1225019 (365)	total: 3m 3s	remaining: 5m 17s
366:	learn: 3.1346144	test: 3.1211435	best: 3.1211435 (366)	total: 3m 3s	remaining: 5m 17s
367:	learn: 3.1331420	test: 3.1198082	best: 3.1198082 (367)	total: 3m 4s	remaining: 5m 16s
368:	learn: 3.1307370	test: 3.1180800	best: 3.1180800 (368)	total: 3m 4s	remaining: 5m 15s
369:	learn: 3.1292793	test: 3.1165756	best: 3.1165756 (369)	total: 3m 5s	remaining: 5m 15s
370:	learn: 3.1284531	test: 3.1159119	best: 3.1159119 (370)	total: 3m 5s	remaining: 5m 15s


KeyboardInterrupt: 

In [None]:
predict = model.predict(valid_pool)
print(mean_squared_error(valid_df['Rating'].values, predict, squared=False))

In [None]:
predict = model.predict(test_pool)

test_joke_df_nofactrating['Rating'] = predict

display(test_joke_df_nofactrating['Rating'].to_frame().head(5))
test_joke_df_nofactrating['Rating'].to_frame().to_csv('catboost_with_rating_and_item_emb.csv')