## Note
In this notebook we will load a trained GMF++ model and go over the evaluation procedure. GMF++ is based on the simple model introduced by [He et al](https://arxiv.org/abs/1708.05031). You can try to adapt other models such as MLP and NMF. The [original implementation](https://github.com/hexiangnan/neural_collaborative_filtering/tree/4aab159e81c44b062c091bdaed0ab54ac632371f), as well as other implementations, is available for single-market settings.

In [1]:
import argparse
import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, ConcatDataset

import os
from os import path
import json
import resource
import sys
import pickle
from zipfile import ZipFile

sys.path.insert(1, 'src')
from model import Model
from utils import *
from data import *
from train_baseline import *

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# Build the run configuration by feeding CLI-style flags to the project's argument parser.
parser = create_arg_parser()

tgt_market = 't2'
src_markets = 'none' # 'none' | 's1' | 's1_s2_s3'
exp_names = 'toytest'  # not passed to the parser; used later to build the checkpoint file names
tgt_market_valid = f'DATA/{tgt_market}/valid_run.tsv'
tgt_market_test = f'DATA/{tgt_market}/test_run.tsv'


# The backslash continues the string literal; .split() discards the extra whitespace it introduces.
args = parser.parse_args(f'--tgt_market {tgt_market} --src_markets {src_markets} \
            --tgt_market_valid {tgt_market_valid} --tgt_market_test {tgt_market_test} --cuda'.split()) #

# Preferred GPU ordinal. On machines exposing fewer GPUs, fall back to the first
# one instead of selecting a device that does not exist (which would only fail
# later, when tensors are moved to it).
preferred_gpu = 2
if torch.cuda.is_available() and args.cuda:
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    args.device = torch.device(f'cuda:{gpu_index}')
else:
    args.device = torch.device('cpu')
print("Device:", args.device)

Device: cuda:2


In [3]:
# Restore the trained model and the id bank stored next to it in checkpoints/.
id_bank_dir = f'checkpoints/{tgt_market}_{src_markets}_{exp_names}.pickle'
model_dir = f'checkpoints/{tgt_market}_{src_markets}_{exp_names}.model'

# NOTE: unpickling can execute arbitrary code; only load id-bank files you trust.
with open(id_bank_dir, 'rb') as id_bank_file:
    my_id_bank = pickle.load(id_bank_file)

# The model is constructed from the run arguments plus the id bank, then the
# saved weights are loaded from the checkpoint path.
mymodel = Model(args, my_id_bank)
mymodel.load(model_dir)

Model is NMF++!
NMF(
  (embedding_user_gmf): Embedding(5483, 64)
  (embedding_user_mlp): Embedding(5483, 64)
  (embedding_item_gmf): Embedding(2963, 64)
  (embedding_item_mlp): Embedding(2963, 64)
  (gmf): GMF(
    (affine_output): Sequential(
      (0): Linear(in_features=64, out_features=1, bias=True)
    )
    (logistic): Sigmoid()
  )
  (mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=32, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): ReLU(inplace=True)
      (4): Linear(in_features=32, out_features=32, bias=True)
      (5): ReLU(inplace=True)
      (6): Linear(in_features=32, out_features=1, bias=True)
    )
    (logistic): Sigmoid()
  )
)
Pretrained weights from checkpoints/t2_none_toytest.model are loaded!


In [4]:
# for name, param in mymodel.model.named_parameters():
# 	print(name,param)

In [5]:
# ---------------------------------------------------------------
# Target-market evaluation data (validation run + test run)
# ---------------------------------------------------------------
tgt_task_generator = TaskGenerator(None, my_id_bank)
args.batch_size = 10240

# Both run files go through the same loader factory; validation is built first, then test.
tgt_valid_dataloader, tgt_test_dataloader = (
    tgt_task_generator.instance_a_market_valid_dataloader(run_file, args.batch_size)
    for run_file in (args.tgt_market_valid, args.tgt_market_test)
)
print('loaded target test and validation data!')

loaded target test and validation data!


In [6]:
run_dir = './baseline_outputs/sample_run/'

def initia_write_run_file(run_mf, file_address):
    """Write a nested ``{user_id: {item_id: score}}`` run to a TSV file.

    The file starts with a ``userId<TAB>itemId<TAB>score`` header, followed by
    one line per (user, item) pair in the iteration order of the mappings.

    Args:
        run_mf: mapping of user id to a mapping of item id to score.
        file_address: destination path; missing parent directories are
            created so the function also works on a fresh checkout.
    """
    out_dir = path.dirname(file_address)
    if out_dir:  # empty when the file is written to the current directory
        os.makedirs(out_dir, exist_ok=True)
    with open(file_address, 'w') as fo:
        fo.write('userId\titemId\tscore\n')
        for u_id, item_scores in run_mf.items():
            for p_id, score in item_scores.items():
                fo.write('{}\t{}\t{}\n'.format(u_id, p_id, score))

# Timing notes left by the author for this step:
# before optimization: t1 5 min, t2 22 min
# after optimization:  t1 6.3 s, t2 26.7 s
# Score the validation and test dataloaders with the loaded model.
valid_run_mf = mymodel.predict(tgt_valid_dataloader)
test_run_mf = mymodel.predict(tgt_test_dataloader)

# Post-process the raw predictions with the id bank. The names are rebound, so the
# raw predict() outputs are no longer reachable after this point.
# NOTE(review): presumably conver_data maps internal indices back to original ids and
# returns a frame with `userId` / `user_sort_item` columns (that is what write_run_file
# reads in the next cell) — confirm against its definition.
valid_run_mf = conver_data(valid_run_mf, my_id_bank)
test_run_mf = conver_data(test_run_mf, my_id_bank)

In [7]:
def write_run_file(run_mf, file_address):
    """Write a ranked-run DataFrame to a TSV file.

    The file starts with a ``userId<TAB>itemId<TAB>score`` header. For every
    row of ``run_mf`` one line is written per item of ``user_sort_item``, in
    the order the items appear, with a score that decreases by 0.1 per
    position starting at 3.0.

    Args:
        run_mf: DataFrame with a ``userId`` column and a ``user_sort_item``
            column holding an iterable of item ids per user.
        file_address: destination path; missing parent directories are
            created so the function also works on a fresh checkout.
    """
    out_dir = path.dirname(file_address)
    if out_dir:  # empty when the file is written to the current directory
        os.makedirs(out_dir, exist_ok=True)
    with open(file_address, 'w') as fo:
        fo.write('userId\titemId\tscore\n')
        for row in run_mf.itertuples():
            u_id = row.userId
            for rank, i_id in enumerate(row.user_sort_item):
                # Any decreasing score will do here (original author's note).
                fo.write('{}\t{}\t{}\n'.format(u_id, i_id, (10-rank)/10+2))

# Persist both prediction sets under <run_dir>/<tgt_market>/.
market_run_dir = path.join(run_dir, tgt_market)
write_run_file(valid_run_mf, path.join(market_run_dir, 'valid_pred.tsv'))
write_run_file(test_run_mf, path.join(market_run_dir, 'test_pred.tsv'))

# Full evaluation on the validation set (pytrec_eval based, per the original note):
# read the validation qrels, then score the validation run against them.
valid_qrel_path = 'DATA/{}/valid_qrel.tsv'.format(tgt_market)
tgt_valid_qrel = read_qrel_file(valid_qrel_path)
task_ov, task_ind = get_evaluations_final(valid_run_mf, tgt_valid_qrel)

In [8]:
# Zip the run files into a single archive to prepare for submission    
# Zip the run files into a single archive to prepare for submission.
# {run_dir} is interpolated by IPython; the archive is written one level above it.
# NOTE(review): the recorded output shows "updating:" entries (including t1/), so an
# existing sample_run.zip was modified in place — presumably files from earlier runs
# stay in the archive; delete the old zip first if a clean archive is required.
! cd {run_dir} && zip -r ../sample_run.zip ./

print("*** Validating the submission Zip file ***")
# Run the validate_submission.py script to check that the file format is okay and to get the performance on the validation set.
! python validate_submission.py ./baseline_outputs/sample_run.zip

updating: t1/ (stored 0%)
updating: t1/test_pred.tsv (deflated 80%)
updating: t1/valid_pred.tsv (deflated 80%)
updating: t2/ (stored 0%)
updating: t2/test_pred.tsv (deflated 79%)
updating: t2/valid_pred.tsv (deflated 79%)
*** Validating the submission Zip file ***
Extracting the submission zip file
Validating the file structure of the submission
File structure validation successfully passed
Evaluating the validation set
