# Required modules and environment parameters

In [1]:
# Reload all imported modules (including the project code under src/) before each
# cell runs, so edits to that code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the project root importable so the `src.*` imports below resolve.
# The project root is the parent of the current working directory, which is
# normally the notebook's own directory (assumption, as in the original "../").
# Resolve it to an absolute path once: a relative "../" entry in sys.path is
# re-interpreted against whatever the working directory is at import time.
# The membership check keeps re-runs of this cell from stacking duplicate entries.
PROJECT_ROOT = os.path.abspath(os.pardir)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from src.data_pipeline.DataLoader import DataLoader
from src.utility.sys_utils import get_spark

# import your model
from src.model.ALS_MF import ALS_MF
from src.model.KNN import KNN
from src.model.BaseLine import BaseLine

# import training pipeline
from src.data_pipeline.pipeline import cross_validation, test_evaluation

# import result checking tools
from src.utility.Summary import Summary

# Load the data from the text file (the config may select a sampled subset — see the output below)

In [4]:
# Local cores handed to the Spark session; raise this (up to 6) if needed.
SPARK_CORES = 4
spark = get_spark(cores=SPARK_CORES)

In [5]:
# Experiment settings are read from a JSON file under the config/ dir (here
# default_config.json); make sure all settings in it are the intended ones
# before running. Judging by the output below, the config appears to control
# sub-sampling of the data and the held-out test split.
data_loader = DataLoader(spark, "ml-20m-5p", config_name="default_config.json")

Using sampled subset with 6.896900E+04 records
Using split of range (0, 0.2), test set contains 19365 of 68969 records in total.


# Construct models

In [6]:
# Hyper-parameters for the ALS matrix-factorisation model evaluated in the
# cross-validation below. The keys are passed straight to ALS_MF; "num_neg" is
# presumably the number of sampled negatives (0 = none) — TODO confirm in src/model/ALS_MF.
als_params = dict(
    rank=64,
    maxIter=15,
    regParam=0.05,
    num_neg=0,
)

In [7]:
# ALS model built from the hyper-parameters above; this is the model passed to
# cross_validation in the next section.
als = ALS_MF(als_params)

In [8]:
# KNN model with k = 20 (presumably the neighbourhood size).
# NOTE(review): `knn` is not passed to cross_validation or test_evaluation in
# this notebook; the KNN rows in the summary below come from results already in
# the database. Those rows also show identical scores for every k (5-30), which
# suggests `k` may not be taking effect — worth checking src/model/KNN.
knn = KNN(dict(k=20))

In [9]:
# BaseLine model, "count_rank_rating" variant.
# NOTE(review): `bl` is not evaluated in this notebook either; in the summary
# below the "avg_rating" and "count_rank_rating" variants report identical
# scores, which is worth double-checking.
bl = BaseLine(dict(model="count_rank_rating"))

# Cross validation

In [10]:
# 5-fold cross-validation of the ALS model, scored on the top-10 recommendations
# (ndcg@10 and precision@10 per fold, per the output below). Cached
# ranking/rating parquet files are reused when present ("Using cached file").
# The per-fold metrics are kept in `cv_results` rather than only displayed, so
# later cells can use them without relying on `_` / Out[...].
# NOTE(review): in the saved output, fold 0's test split (14810 records) is much
# larger than the other folds (~8400 each); the fold-splitting logic may deserve a look.
cv_results = cross_validation(data_loader, als, spark, k_fold=5, top_k=10)
cv_results

Using split of range [0.0, 0.2], test set contains 14810 of 49604 records in total.
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Ranking_fold_0.parquet
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Rating_fold_0.parquet
Using split of range [0.2, 0.4], test set contains 8402 of 49604 records in total.
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Ranking_fold_1.parquet
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Rating_fold_1.parquet
Using split of range [0.4, 0.6], test set contains 8407 of 49604 records in total.
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Ranking_fold_2.parquet
Using cached file from /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_64-regParam_0.05/Rating_fold_2.parquet
Using split of 

defaultdict(list,
            {'ndcg@10': [0.01483122753971261,
              0.007585024522262664,
              0.009230603021243594,
              0.007088368306926822,
              0.013283914161411723],
             'precision@10': [0.0041849840578039406,
              0.0023120511609765345,
              0.0027059623454624006,
              0.0021220195587044913,
              0.0037072842985693124]})

# Check evaluation results from the database

In [11]:
# Summarise the cross-validation results stored in the results database (its
# path comes from the loader's config) for the ml-20m-5p dataset on ndcg@10.
results_db_path = data_loader.get_config().db_path
summary = Summary(results_db_path)
summary.summarize_cv("ml-20m-5p", ["ndcg@10"])

Unnamed: 0,model,hyper,metric,mean,std,rnk
57,KNN,"[('k', 30)]",ndcg@10,0.015035,0.002261,1.0
53,KNN,"[('k', 10)]",ndcg@10,0.015035,0.002261,2.0
54,KNN,"[('k', 15)]",ndcg@10,0.015035,0.002261,3.0
56,KNN,"[('k', 25)]",ndcg@10,0.015035,0.002261,4.0
55,KNN,"[('k', 20)]",ndcg@10,0.015035,0.002261,5.0
58,KNN,"[('k', 5)]",ndcg@10,0.015035,0.002261,6.0
51,BaseLine,"[('model', 'avg_rating')]",ndcg@10,0.010987,0.001144,7.0
52,BaseLine,"[('model', 'count_rank_rating')]",ndcg@10,0.010987,0.001144,8.0
9,ALS_MF,"[('maxIter', 15), ('num_neg', 0), ('rank', 32)...",ndcg@10,0.009711,0.001018,9.0
4,ALS_MF,"[('maxIter', 15), ('num_neg', 0), ('rank', 16)...",ndcg@10,0.009705,0.001017,10.0


# Finally, train the model on all training data and evaluate it on the test data

In [12]:
# Hyper-parameters for the final ALS model, which is trained on all training data.
# NOTE(review): these are labelled "best", but in the CV summary above:
#   - the rank-32 ALS run (rnk 9) edges out the rank-16 run (rnk 10);
#   - KNN and BaseLine both score higher than any ALS run.
# The regParam of those summary rows is truncated in the output, so confirm the
# chosen values against the results database before relying on them.
best_params = dict(
    rank=16,
    maxIter=15,
    regParam=0.05,
    num_neg=0,
)

als_final = ALS_MF(best_params)

In [13]:
# Train the final model on the full training set and score it on the held-out
# test set with top-10 metrics. force_rewrite=True regenerates the cached
# ranking/rating parquet files instead of reusing them (see the "Rewriting files"
# output below).
test_evaluation(
    data_loader,
    als_final,
    spark,
    oracle_type=None,
    force_rewrite=True,
    top_k=10,
)

Rewriting files in /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_16-regParam_0.05/Ranking_fold_-1.parquet
Rewriting files in /home/ds2019/log/ml-20m-5p/ALS_MF/maxIter_15-num_neg_0-rank_16-regParam_0.05/Rating_fold_-1.parquet


ChainMap({'ndcg@10': 0.011559062754521145, 'precision@10': 0.003460713589393645}, {})