In [1]:
import logging
import math
import os

import pandas as pd
from tqdm.auto import tqdm
from matplotlib import pyplot as plt

from data.data_reader import DataReader
from tools.restaurant_profiles_manager import RestaurantProfilesManager
from tools.user_profiles_manager import UserProfilesManager

logging.basicConfig(level=logging.INFO)
tqdm.pandas()
while str(os.getcwd())[-3:] != 'src':  # Execute from src-directory root
    os.chdir('..')

In [2]:
train_data, test_data = DataReader().read_data()
(b_train, r_train, u_train), (b_test, r_test, u_test) = train_data, test_data
user_profiles = UserProfilesManager().get_best()
business_profiles = RestaurantProfilesManager().get_best()

                                                                                           

In [19]:
def get_bias(row) -> float:
    return global_mean + row['restaurant_bias'] + row['user_bias']

In [3]:
global_mean = r_train['stars_normalised'].mean()
restaurant_bias = (r_train.groupby("business_id").mean(numeric_only=True)['stars_normalised'] - global_mean).rename("restaurant_bias")
ratings_in_testset_per_user = r_test.groupby("user_id").count()['funny_cool'].rename('review_count')

In [None]:
verwerk_data = r_test[['stars_normalised', 'user_id', "business_id"]].join(restaurant_bias, on='business_id').join(ratings_in_testset_per_user, on='user_id')
verwerk_data = verwerk_data[verwerk_data['review_count'] > 1]

all_user_biases_list = []
to_drop = []
all_users = r_test['user_id'].unique()
for current_user_id in tqdm(all_users, desc='Calculating User Biases'):
    reviews_of_user_all = r_test[r_test['user_id'] == current_user_id]
    reviews_of_user_gen = reviews_of_user_all.sample(frac=0.5)
    to_drop.append(reviews_of_user_gen.index)
    user_bias = reviews_of_user_gen.join(restaurant_bias, on='business_id').apply(lambda row: row['stars_normalised'] - global_mean - row['restaurant_bias'], axis=1).mean()
    all_user_biases_list.append(user_bias)
user_bias = pd.Series(data=all_user_biases_list, index=all_users).rename("user_bias").fillna(0)
user_bias

In [16]:
drops = pd.concat([pd.Series(index=x) for x in to_drop])
verwerk_data = verwerk_data.drop(drops.index)

In [17]:
verwerk_data = verwerk_data.join(user_bias, on='user_id')
verwerk_data

Unnamed: 0_level_0,stars_normalised,user_id,business_id,restaurant_bias,review_count,user_bias
review_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,0.50,1157506,4603,-0.099899,31,0.178958
2,1.00,1157507,2239,0.114750,4,0.023634
7,0.75,1157508,1052,0.094565,2,0.298278
17,0.75,1157509,5033,-0.041176,101,-0.033191
32,1.00,1157512,813,0.169507,5,0.249256
...,...,...,...,...,...,...
4730915,0.50,1258447,49669,0.080511,10,0.022816
4730917,0.75,1202479,50485,0.205263,165,-0.047763
4730920,0.25,1178381,47556,-0.029410,170,0.004191
4730928,0.25,1196912,49942,-0.137737,52,-0.085497


In [20]:
predictions = verwerk_data.progress_apply(get_bias, axis=1).rename('predicted')
predictions

  0%|          | 0/383064 [00:00<?, ?it/s]

review_id
1          0.777643
2          0.836967
7          1.091426
17         0.624216
32         1.117346
             ...   
4730915    0.801910
4730917    0.856083
4730920    0.673364
4730928    0.475350
4730988    0.903824
Name: predicted, Length: 383064, dtype: float64

In [24]:
scaled_predictions = round((predictions * 4) + 1)
actual = (verwerk_data['stars_normalised'] * 4) + 1
difference = abs(scaled_predictions - actual).rename("difference")

mse = (predictions - verwerk_data['stars_normalised']).transform(lambda x: x * x).mean()
print(f"MSE: {mse}")
print(f"adjusted RMSE: {math.sqrt(mse) * 4}")
results = pd.concat([scaled_predictions, actual, difference], axis=1)
results

MSE: 0.10365450248234899
adjusted RMSE: 1.2878167725719307


Unnamed: 0_level_0,predicted,stars_normalised,difference
review_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,4.0,3.0,1.0
2,4.0,5.0,1.0
7,5.0,4.0,1.0
17,3.0,4.0,1.0
32,5.0,5.0,0.0
...,...,...,...
4730915,4.0,3.0,1.0
4730917,4.0,4.0,0.0
4730920,4.0,2.0,2.0
4730928,3.0,2.0,1.0
