# Necessary modules and environment params

In [1]:
# Reload imported modules (e.g. the project's `src` package) automatically
# before executing code, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the parent directory importable so `src.*` modules resolve
# (assumes the notebook is run from a directory one level below the
# project root — TODO confirm). Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [3]:
from src.data_pipeline.DataLoader import DataLoader
from src.utility.sys_utils import get_spark

# import your model
from src.model.ALS_MF import ALS_MF

# import training pipeline
from src.data_pipeline.pipeline import cross_validation, test_evaluation

# import result checking tools
from src.utility.Summary import Summary

# Load the full dataset from the text file

In [4]:
# Spark session shared by the whole notebook; raise N_CORES (up to 6) if needed.
N_CORES = 4
spark = get_spark(cores=N_CORES)

# Loader for the "ml-1m-full" dataset; constructing it reports the
# train/test split it uses (see the printed output below).
data_loader = DataLoader(spark, "ml-1m-full")

Using split of range (0, 0.2), test set contains 202451 of 1000209 records in total.


In [5]:
# Inspect the session object returned by get_spark (presumably a
# SparkSession, whose rich repr shows version / master / app name — verify).
spark

# Get the train and test sets

In [6]:
# Fetch both splits from the loader. Note: the pipeline calls in the cells
# shown here take `data_loader` directly and do not use these variables.
train, test = data_loader.get_train_set(), data_loader.get_test_set()

# Construct the model

In [7]:
# Hyperparameters passed to ALS_MF for cross-validation. The cache directory
# name is built from these keys (e.g. maxIter_15-rank_64-regParam_0.05).
als_params = dict(
    rank=64,
    maxIter=15,
    regParam=0.05,
)

In [8]:
# Instantiate the ALS_MF model wrapper with the hyperparameters in `als_params`.
als = ALS_MF(als_params)

In [9]:
# Model name; it appears as the model directory in the cache paths and as
# the `model` column in the Summary tables further down.
als.get_name()

'ALS_MF'

# Cross-validation

In [10]:
# 5-fold cross-validation scored on the top-10 items per user; the returned
# mapping holds one score per fold for each metric (ndcg@10, precision@10).
cv_scores = cross_validation(data_loader, als, spark, k_fold=5, top_k=10)
cv_scores

Using split of range [0.0, 0.2], test set contains 162035 of 797758 records in total.
Using cached file from /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Ranking_fold_0.parquet
Using cached file from /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Rating_fold_0.parquet
Dummy printing of test set count in Evaluator.__evaluate_rating(): 162035
Using split of range [0.2, 0.4], test set contains 159638 of 797758 records in total.
Using cached file from /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Ranking_fold_1.parquet
Using cached file from /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Rating_fold_1.parquet
Dummy printing of test set count in Evaluator.__evaluate_rating(): 159638
Using split of range [0.4, 0.6], test set contains 159704 of 797758 records in total.
Using cached file from /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Ranking_fold_2.parquet
Using cached file from /ho

defaultdict(list,
            {'ndcg@10': [0.04029013955306538,
              0.04416826611809346,
              0.04287737968779719,
              0.0423988808380025,
              0.04344428067449856],
             'precision@10': [0.03773178807947021,
              0.04001655629139073,
              0.03956953642384105,
              0.03932119205298014,
              0.03990066225165566]})

# Check evaluation results

In [17]:
# Summarise the cross-validation runs recorded in the results database for
# this dataset: mean/std of ndcg@10 per hyperparameter setting, with a rank.
db_path = data_loader.get_config().db_path
summary = Summary(db_path)
summary.summarize_cv("ml-1m-full", ["ndcg@10"])

Unnamed: 0,model,hyper,metric,mean,std,rnk
1,ALS_MF,"[('maxIter', 15), ('rank', 64), ('regParam', 0...",ndcg@10,0.042636,0.001468,1.0
0,ALS_MF,"[('maxIter', 15), ('rank', 32), ('regParam', 0...",ndcg@10,0.03539,0.00086,2.0


# Finally, train the model on all training data and evaluate it on the test data

In [13]:
# Best setting according to the CV summary above (rank 64 ranked first on ndcg@10).
best_params = dict(rank=64, maxIter=15, regParam=0.05)

# Fresh model instance for the final train-on-all / evaluate-on-test run.
als_final = ALS_MF(best_params)

In [14]:
# Train on the full training set and score on the held-out test set.
# force_rewrite=True regenerates the cached ranking/rating files (fold -1)
# instead of reusing them.
test_scores = test_evaluation(
    data_loader,
    als_final,
    spark,
    top_k=10,
    force_rewrite=True,
    oracle_type=None,
)
test_scores

Rewriting files in /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Ranking_fold_-1.parquet
Rewriting files in /home/ds2019/log/ml-1m-full/ALS_MF/maxIter_15-rank_64-regParam_0.05/Rating_fold_-1.parquet
Dummy printing of test set count in Evaluator.__evaluate_rating(): 202451


ChainMap({'ndcg@10': 0.05632499121402725, 'precision@10': 0.052715231788079464}, {})

In [15]:
# All recorded test-set results for this model on this dataset
# (one row per metric per run, with a timestamp).
dataset_name, model_name = "ml-1m-full", "ALS_MF"
summary.get_model_test_perf(dataset_name, model_name)

Unnamed: 0_level_0,model,hyper,metric,value,ts
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
123,ALS_MF,"[('maxIter', 15), ('rank', 128), ('regParam', ...",ndcg@10,0.052973,2019-10-30 17:58:03.291079
124,ALS_MF,"[('maxIter', 15), ('rank', 128), ('regParam', ...",precision@10,0.050778,2019-10-30 17:58:03.291378
137,ALS_MF,"[('maxIter', 15), ('rank', 64), ('regParam', 0...",ndcg@10,0.056325,2019-10-30 18:28:24.082514
138,ALS_MF,"[('maxIter', 15), ('rank', 64), ('regParam', 0...",precision@10,0.052715,2019-10-30 18:28:24.082758
