In [1]:
import numpy  as np
import pandas as pd
from metrics import MSE
from sklearn.linear_model import Ridge

In [2]:
X_train = pd.read_pickle("../Datasets/final/X_train.pkl").drop(columns=["gPlusUserId", "gPlusPlaceId"])
c = X_train.columns.tolist()
X_train = X_train.to_numpy()
X_val = pd.read_pickle("../Datasets/final/X_val.pkl").drop(columns=["gPlusUserId", "gPlusPlaceId"]).to_numpy()
X_test = pd.read_pickle("../Datasets/final/X_test.pkl").drop(columns=["gPlusUserId", "gPlusPlaceId"]).to_numpy()
y_train = pd.read_pickle("../Datasets/final/y_train.pkl").to_numpy()
y_val = pd.read_pickle("../Datasets/final/y_val.pkl").to_numpy()
y_test = pd.read_pickle("../Datasets/final/y_test.pkl").to_numpy()

In [3]:
best_model = None
best_mse = np.Inf

for alpha in [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3]:
    r = Ridge(alpha=alpha)
    r.fit(X_train, y_train)
    
    y_pred_val = r.predict(X_val)
    mse_val = MSE(predictions=y_pred_val, labels=y_val)[0]
    print(f"Alpha: {alpha}, MSE: {mse_val}")
    
    if mse_val < best_mse:
        best_mse = mse_val
        best_model = r

Alpha: 0.001, MSE: 0.4198221535644171
Alpha: 0.01, MSE: 0.41982215540059326
Alpha: 0.1, MSE: 0.4198221731651225
Alpha: 1.0, MSE: 0.41982230166146717
Alpha: 10.0, MSE: 0.4198221700407061
Alpha: 100.0, MSE: 0.4198195635573689
Alpha: 1000.0, MSE: 0.41982414476216634


In [4]:
best_model, best_mse

(Ridge(alpha=100.0), 0.4198195635573689)

In [5]:
y_pred = best_model.predict(X_test)

In [6]:
MSE(predictions=y_pred, labels=y_test)[0]

0.4171984394646411

In [7]:
dict(zip(c, r.coef_[0]))

{'userCategoryAvgRating': 0.9979447363339872,
 'year_1990': 0.03630609241807717,
 'year_2002': 0.0003773941404473378,
 'year_2003': 0.0006178872993894914,
 'year_2004': -0.01330845997373743,
 'year_2005': -0.008164493601531568,
 'year_2006': 0.00045301678640148054,
 'year_2007': 0.0028358145726326404,
 'year_2008': -0.018343134023336687,
 'year_2009': -0.004486869275126905,
 'year_2010': 0.0006056686100361295,
 'year_2011': -0.016616621027220443,
 'year_2012': 0.018896782313159173,
 'year_2013': 0.0028657041537730682,
 'year_2014': -0.0020387824410434403,
 'month_01': -0.003274546082343076,
 'month_02': 0.00021056368455661675,
 'month_03': -0.0011856370443454783,
 'month_04': -0.0031093838538653077,
 'month_05': -0.0018412432355680492,
 'month_06': 0.015025128535098596,
 'month_07': 0.012533218264568016,
 'month_08': 0.007338020814773842,
 'month_09': 0.0030616400128928817,
 'month_10': -0.0080703143978739,
 'month_11': -0.009574727788732447,
 'month_12': -0.011112718874720283,
 'final