In [1]:
import os
import argparse
import json
from os import path
from copy import deepcopy

from utils import *

from dataloaders import dataloader_factory
from scheduler.utils import *

from collections import Counter
from tqdm import tqdm

The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!


In [2]:
# config_path = "/data01/wushiguang-slurm/storage01/___lfms_base_2022/base_gru4rec_02-25_11:06_0/config.json"
# data_path = "./Data/Cache/lastfm_small-5-5.pkl"
# base_model_path = "/data01/wushiguang-slurm/storage/___lfms_base_2022/base_gru4rec_02-25_11:06_0/gru4rec_logs/checkpoint/best_acc_model.pth"
# soft_model_path = "/data01/wushiguang-slurm/storage/___lfms_soft_2022/pop_T-3.0-al-0.5-gru4rec_02-25_12:49_t5/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# ed_model_path = "/data01/wushiguang-slurm/storage/___lfms_ensemble_distill_diff_model_2020/gru_gru4rec-en-0.1-T-3.0-al-0.5_02-25_15:55_t17/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# ep_model_path = "/data01/wushiguang-slurm/storage/___lfms_ensemble_distill_partial_trained_2022/ed_gru4rec-en-0.5-T-3.0-al-0.75-sr-0.8_02-27_18:13_t21/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# dvae_model_path = "/data01/wushiguang-slurm/storage/___lfms_dvae_2022/partial_gru4rec-T-6.0-al-0.75-dal-0.75-sr-0.8_03-01_04:24_t33/student_gru4rec_logs/checkpoint/best_acc_model.pth"

# config_path = "/data01/wushiguang-slurm/storage01/___yelp_base_2022/base_gru4rec_03-06_19:36_0/config.json"
# data_path = "./Data/Cache/yelp-5-5.pkl"
# base_model_path = "/data01/wushiguang-slurm/storage01/___yelp_base_2022/base_gru4rec_03-06_19:36_0/gru4rec_logs/checkpoint/best_acc_model.pth"
# soft_model_path = "/data01/wushiguang-slurm/storage/___yelp_soft_2022/pop_gru4rec-T-3.0-al-0.25_03-13_15:27_t4/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# ed_model_path = "/data01/wushiguang-slurm/storage/___yelp_ed_2022/ed_diff_2022_gru4rec-en-0.5-T-1.0-al-0.5_03-08_18:50_t2/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# ep_model_path = "/data01/wushiguang-slurm/storage/___yelp_ensemble_distill_partial_trained_2022/ed_gru4rec-en-0.5-T-6.0-al-0.25-sr-0.8_03-09_07:52_t7/student_gru4rec_logs/checkpoint/best_acc_model.pth"
# dvae_model_path = "/data01/wushiguang-slurm/storage/___yelp_dvae_2022/teacher_gru4rec-T-3.0-al-0.5-dal-0.5-sr-0.8_03-12_06:15_t11/student_gru4rec_logs/checkpoint/best_acc_model.pth"

config_path = "/data01/wushiguang-slurm/storage01/___el_base_2022/base_gru4rec_03-11_20:50_0/config.json"
data_path = "./Data/Cache/electronics-5-5.pkl"
base_model_path = "/data01/wushiguang-slurm/storage01/___el_base_2022/base_gru4rec_03-11_20:50_0/gru4rec_logs/checkpoint/best_acc_model.pth"
soft_model_path = "/data01/wushiguang-slurm/storage01/___el_soft_2022/pop_gru4rec-T-6.0-al-0.75_03-15_10:57_t9/student_gru4rec_logs/checkpoint/best_acc_model.pth"
ed_model_path = "/data01/wushiguang-slurm/storage01/___el_ensemble_distill_2022/ed_diff_2022_gru4rec-en-0.5-T-3.0-al-0.75_03-16_12:13_t6/student_gru4rec_logs/checkpoint/best_acc_model.pth"
ep_model_path = "/data01/wushiguang-slurm/storage01/___el_test_2022/ed_partial_gru4rec-en-0.5-T-3.0-al-0.75-sr-0.8_03-16_11:12_t6/student_gru4rec_logs/checkpoint/best_acc_model.pth"
dvae_model_path = "/data01/wushiguang-slurm/storage01/___el_dvae_2022/teacher_gru4rec-T-3.0-al-0.75-dal-0.5-sr-0.8_03-19_13:19_t12/student_gru4rec_logs/checkpoint/best_acc_model.pth"


In [3]:
with open(path.normpath(config_path), 'r') as f:
    args = argparse.Namespace()
    args.__dict__.update(json.load(f))
    if args.kwargs is not None:
        args.__dict__update(args.kwargs)

args.do_sampling = False

In [4]:
train_loader, val_loader, test_loader, dataset = dataloader_factory(args)

item_train, item_valid, item_test, usernum, itemnum, rating_train, rating_valid, rating_test = dataset

In [5]:
popularity = Counter() 

for user in range(0, usernum):
    popularity.update(item_train[user])
    popularity.update(item_valid[user])
    popularity.update(item_test[user])

pop = [0] * (itemnum + 1)

for i, v in popularity.items():
    pop[i] = v

item_inter = torch.tensor(pop).to(torch.float)

def _calculate_ARP(batch, model):
    seqs = batch[0]
    answer = batch[1]
    
    batch = [x.to('cuda') for x in batch]

    batch_size = len(seqs)
    with torch.no_grad():
        scores = model.full_sort_predict(batch)

    row = []
    col = []

    for i in range(batch_size):
        seq = list(set(seqs[i].tolist()) | set(answer[i].tolist()))
        seq.remove(answer[i][0].item())
        if itemnum + 1 in seq:
            seq.remove(itemnum + 1)
        row += [i] * len(seq)
        col += seq
    scores[row, col] = -1e9

    rank = (-scores).argsort(dim=1)[:, :10]

    return  item_inter[rank].mean()

def calculate_ARP(tag: str, model):
    model.eval()

    rate = 0.
    t = 1
#     tq = tqdm(test_loader)

    for batch in test_loader:
        rate += _calculate_ARP(batch, model)

#         tq.set_description(f"tag: {tag}, rate: {rate / t}")

        t += 1

    rate /= len(test_loader)

    print(tag, ":", rate.item())

In [6]:
# model = generate_model(args, args.model_code, dataset, 'cpu')

# load_state_from_given_path(model, model_path, 'cpu', must_exist=True)

path_list = [base_model_path, soft_model_path, ed_model_path, ep_model_path, dvae_model_path]

model_list = [generate_model(args, args.model_code, dataset, 'cuda') for p in path_list]

tag_list = ['base', 'soft', 'post-model', 'post-data', 'dvae']

for m, p in zip(model_list, path_list):
    load_state_from_given_path(m, p, 'cuda', must_exist=True)

In [7]:
for t, m in zip(tag_list, model_list):
    calculate_ARP(t, m)

base : 1148.6190185546875
soft : 749.2748413085938
post-model : 728.935546875
post-data : 703.6705322265625
dvae : 745.12841796875


In [8]:
rand_model = generate_model(args, args.model_code, dataset, 'cuda')
calculate_ARP('rand', rand_model)

rand : 16.378864288330078
