In [1]:
import sys
import random
from time import time

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.optim as optim

from KGAT import KGAT
from utils.parser_kgat import *
from utils.log_helper import *
#from utils.metrics import *
from utils.model_helper import *
from data_loader import DataLoaderKGAT

from scipy.sparse import coo_matrix
import copy


# Original code
def evaluate(model, dataloader, user_ids, Ks, device):
    test_batch_size = len(user_ids) #dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()

    #user_ids = list(test_user_dict.keys())
    user_ids_batches = [user_ids[i: i + test_batch_size] for i in range(0, len(user_ids), test_batch_size)]
    user_ids_batches = [torch.LongTensor(d) for d in user_ids_batches]

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)

    cf_scores = []
    metric_names = ['precision', 'recall', 'ndcg']
    metrics_dict = {k: {m: [] for m in metric_names} for k in Ks}

    with tqdm(total=len(user_ids_batches), desc='Evaluating Iteration') as pbar:
        for batch_user_ids in user_ids_batches:
            batch_user_ids = batch_user_ids.to(device)

            with torch.no_grad():
                batch_scores = model(batch_user_ids, item_ids, mode='predict')       # (n_batch_users, n_items)

            batch_scores = batch_scores.cpu()
            batch_metrics = calc_metrics_at_k(batch_scores, train_user_dict, test_user_dict, batch_user_ids.cpu().numpy(), item_ids.cpu().numpy(), Ks)

            cf_scores.append(batch_scores.numpy())
            for k in Ks:
                for m in metric_names:
                    metrics_dict[k][m].append(batch_metrics[k][m])
            pbar.update(1)

    cf_scores = np.concatenate(cf_scores, axis=0)
    for k in Ks:
        for m in metric_names:
            metrics_dict[k][m] = np.concatenate(metrics_dict[k][m]).mean()
    return cf_scores, metrics_dict

def predict(args):
    # GPU / CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    data = DataLoaderKGAT(args, logging)

    # load model
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations)
    checkpoint = torch.load(args.pretrain_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    model.to(device)

    # predict
    Ks = eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    cf_scores, metrics_dict = evaluate(model, data, Ks, device)
    np.save(args.save_dir + 'cf_scores.npy', cf_scores)
    print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

In [2]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, log_loss


def calc_recall(rank, ground_truth, k):
    """
    calculate recall of one example
    """
    return len(set(rank[:k]) & set(ground_truth)) / float(len(set(ground_truth)))


def precision_at_k(hit, k):
    """
    calculate Precision@k
    hit: list, element is binary (0 / 1)
    """
    hit = np.asarray(hit)[:k]
    return np.mean(hit)


def precision_at_k_batch(hits, k):
    """
    calculate Precision@k
    hits: array, element is binary (0 / 1), 2-dim
    """
    res = hits[:, :k].mean(axis=1)
    return res


def average_precision(hit, cut):
    """
    calculate average precision (area under PR curve)
    hit: list, element is binary (0 / 1)
    """
    hit = np.asarray(hit)
    precisions = [precision_at_k(hit, k + 1) for k in range(cut) if len(hit) >= k]
    if not precisions:
        return 0.
    return np.sum(precisions) / float(min(cut, np.sum(hit)))


def dcg_at_k(rel, k):
    """
    calculate discounted cumulative gain (dcg)
    rel: list, element is positive real values, can be binary
    """
    rel = np.asfarray(rel)[:k]
    dcg = np.sum((2 ** rel - 1) / np.log2(np.arange(2, rel.size + 2)))
    return dcg


def ndcg_at_k(rel, k):
    """
    calculate normalized discounted cumulative gain (ndcg)
    rel: list, element is positive real values, can be binary
    """
    idcg = dcg_at_k(sorted(rel, reverse=True), k)
    if not idcg:
        return 0.
    return dcg_at_k(rel, k) / idcg


def ndcg_at_k_batch(hits, k):
    """
    calculate NDCG@k
    hits: array, element is binary (0 / 1), 2-dim
    """
    hits_k = hits[:, :k]
    dcg = np.sum((2 ** hits_k - 1) / np.log2(np.arange(2, k + 2)), axis=1)

    sorted_hits_k = np.flip(np.sort(hits), axis=1)[:, :k]
    idcg = np.sum((2 ** sorted_hits_k - 1) / np.log2(np.arange(2, k + 2)), axis=1)

    idcg[idcg == 0] = np.inf
    ndcg = (dcg / idcg)
    return ndcg


def recall_at_k(hit, k, all_pos_num):
    """
    calculate Recall@k
    hit: list, element is binary (0 / 1)
    """
    hit = np.asfarray(hit)[:k]
    return np.sum(hit) / all_pos_num


def recall_at_k_batch(hits, k):
    """
    calculate Recall@k
    hits: array, element is binary (0 / 1), 2-dim
    """
    res = (hits[:, :k].sum(axis=1) / hits.sum(axis=1))
    return res


def F1(pre, rec):
    if pre + rec > 0:
        return (2.0 * pre * rec) / (pre + rec)
    else:
        return 0.


def calc_auc(ground_truth, prediction):
    try:
        res = roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        res = 0.
    return res


def logloss(ground_truth, prediction):
    logloss = log_loss(np.asarray(ground_truth), np.asarray(prediction))
    return logloss


def calc_metrics_at_k(cf_scores, train_user_dict, test_user_dict, user_ids, item_ids, Ks):
    """
    cf_scores: (n_users, n_items)
    """
    test_pos_item_binary = np.zeros([len(user_ids), len(item_ids)], dtype=np.float32)
    for idx, u in enumerate(user_ids):
        train_pos_item_list = train_user_dict[u]
        test_pos_item_list = test_user_dict[u]
        cf_scores[idx][train_pos_item_list] = -np.inf
        test_pos_item_binary[idx][test_pos_item_list] = 1

    try:
        _, rank_indices = torch.sort(cf_scores.cuda(), descending=True)    # try to speed up the sorting process
    except:
        _, rank_indices = torch.sort(cf_scores, descending=True)
    rank_indices = rank_indices.cpu()

    binary_hit = []
    for i in range(len(user_ids)):
        binary_hit.append(test_pos_item_binary[i][rank_indices[i]])
    binary_hit = np.array(binary_hit, dtype=np.float32)
    metrics_dict = {}
    for k in Ks:
        metrics_dict[k] = {}
        metrics_dict[k]['precision'] = precision_at_k_batch(binary_hit, k)
        metrics_dict[k]['recall']    = recall_at_k_batch(binary_hit, k)
        metrics_dict[k]['ndcg']      = ndcg_at_k_batch(binary_hit, k)
    return metrics_dict


def exp_calc_metrics_at_k(rank_indices, train_user_dict, test_user_dict, user_ids, item_ids, Ks):
    test_pos_item_binary = np.zeros([len(user_ids), len(item_ids)], dtype=np.float32)
    for idx, u in enumerate(user_ids):
        train_pos_item_list = train_user_dict[u]
        test_pos_item_list = test_user_dict[u]
        test_pos_item_binary[idx][test_pos_item_list] = 1

    rank_indices = torch.LongTensor([rank_indices])
    binary_hit = []
    for i in range(len(user_ids)):
        binary_hit.append(test_pos_item_binary[i][rank_indices[i]])
    binary_hit = np.array(binary_hit, dtype=np.float32)

    metrics_dict = {}
    for k in Ks:
        metrics_dict[k] = {}
        metrics_dict[k]['precision'] = precision_at_k_batch(binary_hit, k)
        metrics_dict[k]['recall']    = recall_at_k_batch(binary_hit, k)
        metrics_dict[k]['ndcg']      = ndcg_at_k_batch(binary_hit, k)
    return metrics_dict



In [3]:
# kgat args (can adjust Ks and pretrain_model_path)
class KGAT_args():
    def __init__(self, 
                 seed=2019,
                 data_name='amazon-book', 
                 data_dir='datasets/', 
                 use_pretrain=1, 
                 pretrain_embedding_dir='datasets/pretrain/', 
                 pretrain_model_path='trained_model/KGAT/model_epoch280.pth', 
                 cf_batch_size=1024, 
                 kg_batch_size=2048, 
                 test_batch_size=10000, 
                 embed_dim=64, 
                 relation_dim=64, 
                 laplacian_type='random-walk', 
                 aggregation_type='bi-interaction', 
                 conv_dim_list='[64, 32, 16]', 
                 mess_dropout='[0.1, 0.1, 0.1]', 
                 kg_l2loss_lambda=1e-5, 
                 cf_l2loss_lambda=1e-5, 
                 lr=0.0001, 
                 n_epoch=1000, 
                 stopping_steps=10, 
                 cf_print_every=1, 
                 kg_print_every=1, 
                 evaluate_every=20, 
                 Ks='[10]'):
        
        self.seed = seed
        self.data_name = data_name
        self.data_dir = data_dir
        self.use_pretrain = use_pretrain
        self.pretrain_embedding_dir = pretrain_embedding_dir
        self.pretrain_model_path = pretrain_model_path
        self.cf_batch_size = cf_batch_size
        self.kg_batch_size = kg_batch_size
        self.test_batch_size = test_batch_size
        self.embed_dim = embed_dim
        self.relation_dim = relation_dim
        self.laplacian_type = laplacian_type
        self.aggregation_type = aggregation_type
        self.conv_dim_list = conv_dim_list
        self.mess_dropout = mess_dropout
        self.kg_l2loss_lambda = kg_l2loss_lambda
        self.cf_l2loss_lambda = cf_l2loss_lambda
        self.lr = lr
        self.n_epoch = n_epoch
        self.stopping_steps = stopping_steps
        self.cf_print_every = cf_print_every
        self.kg_print_every = kg_print_every
        self.evaluate_every = evaluate_every
        self.Ks = Ks
        save_dir = 'trained_model/KGAT/{}/embed-dim{}_relation-dim{}_{}_{}_{}_lr{}_pretrain{}/'.format(
        self.data_name, self.embed_dim, self.relation_dim, self.laplacian_type, self.aggregation_type,
        '-'.join([str(i) for i in eval(self.conv_dim_list)]), self.lr, self.use_pretrain)
        self.save_dir = save_dir

In [4]:
args = KGAT_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load data
data = DataLoaderKGAT(args, logging)

# load model
model = KGAT(args, data.n_users, data.n_entities, data.n_relations)
checkpoint = torch.load(args.pretrain_model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model.to(device)
pass

  d_inv = np.power(rowsum, -1.0).flatten()


In [None]:
# Counts
print(data.n_items) # 0~24914
print(data.n_entities) # 24915~113486
print(data.n_users_entities) # 113487~184165

In [5]:
def find_all_path(start, end, target_r_1, target_r_2, adj_matrix):
    adj_first = adj_matrix[0].tocsr()
    adj_second = adj_matrix[target_r_1].tocsr()
    #adj_last = adj_matrix[target_r_2].tocsr()
    
    if target_r_2 == 0:
        reverse_adj_last = adj_matrix[1].tocsr()
    elif target_r_2 <= 40:
        reverse_adj_last = adj_matrix[target_r_2 + 39].tocsr()
    else:
        reverse_adj_last = adj_matrix[target_r_2 - 39].tocsr()
    
    # All items connect to start
    first_step_candidate = adj_first.indices[adj_first.indptr[start]:adj_first.indptr[start+1]]
    first_step_candidate = np.delete(first_step_candidate, np.where(first_step_candidate == end)[0])
    
    # All entity(uiei)/user(uiui) connect to end
    last_step_candidate = reverse_adj_last.indices[reverse_adj_last.indptr[end]:reverse_adj_last.indptr[end+1]]
    last_step_candidate = np.delete(last_step_candidate, np.where(last_step_candidate == start)[0])

    output_paths = []
    for h in first_step_candidate:
        tails = adj_second.indices[adj_second.indptr[h]:adj_second.indptr[h+1]]
        for t in tails:
            if t in last_step_candidate:
                output_paths.append([start, h, target_r_1, t, target_r_2, end])
    
    return output_paths

def find_all_path_all_r(start, end, relations, adj_matrix):
    all_paths = []
    for r in relations:
        if r == 1:
            r_1 = 1
            r_2 = 0
        else:
            r_1 = r
            r_2 = r + 39
        
        #t1 = time.perf_counter()
        paths = find_all_path(start, end, r_1, r_2, adj_matrix)
        all_paths += paths
        #t2 = time.perf_counter()
        #print("t = ", t2 - t1)
        paths = find_all_path(start, end, r_2, r_1, adj_matrix)
        all_paths += paths
        #t3 = time.perf_counter()
        #print("t2 = ", t3 - t2)
    return all_paths

In [18]:
# Get the score of a single path 
# 'sum', 'mul', 'max'
def path_score(path, A_in, mode='sum'):
    assert len(path) == 6  # [U, I, r_1, E, r_2, I]

    if mode == 'sum':
        return A_in[path[0]][path[1]] + A_in[path[1]][path[3]] + A_in[path[3]][path[5]]
    
    elif mode == 'mul':
        return A_in[path[0]][path[1]] * A_in[path[1]][path[3]] * A_in[path[3]][path[5]]
    
    elif mode == 'max':
        return max(A_in[path[0]][path[1]], A_in[path[1]][path[3]], A_in[path[3]][path[5]])
    
    elif mode == '12':
        return A_in[path[0]][path[1]] * A_in[path[1]][path[3]]
    
    elif mode == '23':
        return A_in[path[1]][path[3]] * A_in[path[3]][path[5]]

    elif mode == '13':
        return A_in[path[0]][path[1]] * A_in[path[3]][path[5]]

    else:
        raise ValueError('mode should be in ["sum", "mul", "max"].')

# Rerank the paths
def get_top_k_path(paths, A_in, mode='sum', k=5):
    assert len(paths) > 0
    
    scores = [path_score(path, A_in, mode) for path in paths]
    reranked_paths = [path for _, path in sorted(zip(scores,paths), reverse=True)]
    sorted_scores = sorted(scores, reverse=True)

    return reranked_paths[:k], sorted_scores[:k]

# Single user prediction
def predict_single(model, dataloader, user_id, relations, Ks, device, mode):
    train_user_dict = dataloader.train_user_dict

    model.eval()
    input_user_id = torch.LongTensor([user_id]).to(device)

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)
    k_max = max(Ks)

    with torch.no_grad():
        matching_scores = model(input_user_id, item_ids, mode='predict')[0].cpu().numpy()       # (n_batch_users, n_items)

    train_pos_item_list = train_user_dict[user_id]
    matching_scores[train_pos_item_list] = -np.inf
    top_k_items = np.argsort(-matching_scores)[:k_max]
    top_k_all_att_scores = []
    top_k_all_paths = []
    top_k_paths = []
    top_k_att_scores = []
    
    for item_id in top_k_items:
        paths = find_all_path_all_r(user_id, item_id, relations, dataloader.adjacency_dict)
        reranked_paths, scores = get_top_k_path(paths, model.A_in, mode=mode, k=5)
        top_k_all_att_scores.append(scores)
        top_k_all_paths.append(reranked_paths)
        top_k_att_scores.append(scores[0])
        top_k_paths.append(reranked_paths[0])
    
    output = {}
    for k in Ks:
        reranked_items = [id for _, id in sorted(zip(top_k_att_scores[:k], top_k_items[:k]), reverse=True)]
        explainations = [path for _, path in sorted(zip(top_k_att_scores[:k], top_k_paths[:k]), reverse=True)]
        output[k] = {'item_ids': reranked_items, 'explainations': explainations}
    
    #print(top_k_items)
    #print(reranked_items)
    #print(dataloader.test_user_dict[user_id])
    return output, top_k_all_att_scores, top_k_all_paths, top_k_items

# Check original model predcition hit the GT in top-k (the result is in hits_10.txt, hits_20.txt)
def predict_ori(model, dataloader, Ks, device):
    test_batch_size = 3000 #dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict
    
    model.eval()
    user_ids = list(test_user_dict.keys())
    user_ids_batches = [user_ids[i: i + test_batch_size] for i in range(0, len(user_ids), test_batch_size)]
    user_ids_batches = [torch.LongTensor(d) for d in user_ids_batches]

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)
    k_max = max(Ks)

    hit_users = []
    with tqdm(total=len(user_ids_batches), desc='Evaluating Iteration') as pbar:
        for batch_user_ids in user_ids_batches:
            batch_user_ids = batch_user_ids.to(device)

            with torch.no_grad():
                matching_scores = model(batch_user_ids, item_ids, mode='predict')       # (n_batch_users, n_items)

            batch_user_ids = batch_user_ids.cpu().numpy()
            test_pos_item_binary = np.zeros([len(batch_user_ids), len(item_ids.cpu().numpy())], dtype=np.float32)
            for idx, u in enumerate(batch_user_ids):
                train_pos_item_list = train_user_dict[u]
                test_pos_item_list = test_user_dict[u]
                matching_scores[idx][train_pos_item_list] = -np.inf
                test_pos_item_binary[idx][test_pos_item_list] = 1

            try:
                _, rank_indices = torch.sort(matching_scores.cuda(), descending=True)    # try to speed up the sorting process
            except:
                _, rank_indices = torch.sort(matching_scores, descending=True)

            rank_indices = rank_indices.cpu()
            binary_hit = []
            for i in range(len(batch_user_ids)):
                binary_hit.append(test_pos_item_binary[i][rank_indices[i]])
            binary_hit = np.array(binary_hit, dtype=np.float32)[:, :k_max]
            hit_user = np.where(np.sum(binary_hit, axis=1) >= 1)[0]
            hit_users += list(batch_user_ids[hit_user])
            pbar.update(1)
    
    return hit_users

# Get results of all users (explainable version)
def explainable_evaluate(model, dataloader, user_ids, relations, Ks, device, mode):
    test_batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()

    #user_ids = list(test_user_dict.keys())
    
    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)

    metric_names = ['precision', 'recall', 'ndcg']
    metrics_dict = {k: {m: [] for m in metric_names} for k in Ks}

    with tqdm(total=len(user_ids), desc='Evaluating Iteration') as pbar:
        for user_id in user_ids:

            output, _, _, ori_top_k = predict_single(model, dataloader, user_id, relations, Ks, device, mode)
            batch_metrics = exp_calc_metrics_at_k(output[max(Ks)]['item_ids'], train_user_dict, test_user_dict, [user_id], item_ids.cpu().numpy(), Ks)
            ori_metrics = exp_calc_metrics_at_k(ori_top_k, train_user_dict, test_user_dict, [user_id], item_ids.cpu().numpy(), Ks)
            for k in Ks:
                for m in metric_names:
                    metrics_dict[k][m].append(batch_metrics[k][m])
            pbar.update(1)

    for k in Ks:
        for m in metric_names:
            metrics_dict[k][m] = np.concatenate(metrics_dict[k][m]).mean()
    return metrics_dict
    
def explainable_recommend(model, data, relations, Ks):
    # predict
    #Ks = eval(Ks)
    k_min = min(Ks)
    k_max = max(Ks)
    metrics_dicts_sum = []
    metrics_dicts_mul = []
    metrics_dicts_max = []
    metrics_dicts = []
    with open('hits_10.txt', 'r') as f:
        all_hits = f.readlines()
    with open('results/Ks10_5_400.txt', 'w') as f:
        pass
    for i in range(len(all_hits)):
        all_hits[i] = int(all_hits[i])
    for i in range(5):
        user_ids = random.sample(all_hits, 400)
        cf_scores, metrics_dict = evaluate(model, data, user_ids, Ks, device)
        metrics_dicts.append(metrics_dict)
        print("Original")
        print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        with open('results/Ks10_5_400.txt', 'a') as f:
            f.write("Original\n")
            f.write('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        metrics_dict = explainable_evaluate(model, data, user_ids, relations, Ks, device, mode='sum')
        metrics_dicts_sum.append(metrics_dict)
        print("Explainable sum")
        print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        with open('results/Ks10_5_400.txt', 'a') as f:
            f.write("Explainable sum\n")
            f.write('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        metrics_dict = explainable_evaluate(model, data, user_ids, relations, Ks, device, mode='mul')
        metrics_dicts_mul.append(metrics_dict)
        print("Explainable mul")
        print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        with open('results/Ks10_5_400.txt', 'a') as f:
            f.write("Explainable mul\n")
            f.write('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        metrics_dict = explainable_evaluate(model, data, user_ids, relations, Ks, device, mode='max')
        metrics_dicts_max.append(metrics_dict)
        with open('results/Ks10_5_400.txt', 'a') as f:
            f.write("Explainable max\n")
            f.write('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
           metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))
        
    return metrics_dicts, metrics_dicts_sum, metrics_dicts_mul, metrics_dicts_max

In [61]:
np.where(np.sum(np.array([[0, 0, 0, 0, 0]]), axis=1) > 0)

(array([], dtype=int64),)

In [9]:
#delete [2, 3, 6, 11, 21, 22, 32, 35, 37] -> [4, 5, 8, 13, 23, 24, 34, 37, 39] ([43, 44, 47, 52, 62, 63, 73, 76, 78])
all_relations = range(2, 41)
delete_relations = [4, 5, 8, 13, 23, 24, 34, 37, 39]
target_relations = list(set(all_relations) - set(delete_relations))

Ks = [10]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#output, all_paths, all_scores, ori_top_k = predict_single(model, data, 113511, target_relations, [20], device)
#hits = predict_ori(model, data, [20], device)

In [106]:
with open('hits_20.txt', 'w') as f:
    for hit in hits:
        f.write(str(hit)+'\n')

In [89]:
print(list(ori_top_k))
print(output[20]['item_ids'])
print(data.test_user_dict[113511])

[805, 4693, 324, 1473, 548, 271, 2249, 291, 537, 2841, 642, 2142, 4016, 1078, 6897, 5401, 2040, 534, 6891, 4025]
[6897, 6891, 5401, 4693, 4025, 4016, 2841, 2249, 2142, 2040, 1473, 1078, 805, 642, 548, 537, 534, 324, 291, 271]
[  805 15038]


In [16]:
dicts = explainable_recommend(model, data, target_relations, Ks)

Evaluating Iteration: 100%|██████████| 1/1 [00:00<00:00,  3.04it/s]


Original
CF Evaluation: Precision [0.1270, 0.1270], Recall [0.6123, 0.6123], NDCG [0.3798, 0.3798]


  rank_indices = torch.LongTensor([rank_indices])
Evaluating Iteration:   1%|          | 1/100 [00:15<26:21, 15.98s/it]

[ 1220 13815  3461 16248 14322 18412 16374 20344  8904 15746]
[14322, 18412, 20344, 16374, 16248, 15746, 13815, 8904, 3461, 1220]
[1220]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:   2%|▏         | 2/100 [00:32<26:15, 16.07s/it]

[13643  9003 10778 15825 14031 16662 19762 19942 19940  9001]
[10778, 19762, 19942, 13643, 15825, 16662, 14031, 19940, 9001, 9003]
[19942]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.31546488])}}


Evaluating Iteration:   3%|▎         | 3/100 [00:48<26:06, 16.15s/it]

[23799 24124 24281 23539 23740 24400 23538 24311 19243 24283]
[24311, 23539, 24400, 24283, 24281, 24124, 23799, 23740, 23538, 19243]
[19243 23799]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38162168])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.79772289])}}


Evaluating Iteration:   4%|▍         | 4/100 [01:06<27:10, 16.98s/it]

[ 6350 17544 19780 16580 18465   290  6343 19649   783 21025]
[783, 19649, 6350, 19780, 18465, 17544, 290, 16580, 21025, 6343]
[2670 6350 7157 9790]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:   5%|▌         | 5/100 [01:24<27:07, 17.13s/it]

[ 242  240  209  309 2696 3369 1342 5238 1955  781]
[5238, 1955, 1342, 209, 3369, 2696, 781, 309, 242, 240]
[ 242 2699]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:   6%|▌         | 6/100 [01:39<26:05, 16.66s/it]

[ 6110  9347  7455 12698 12523 22486 13099  6162 12729 14286]
[9347, 6110, 12698, 12523, 14286, 7455, 22486, 6162, 12729, 13099]
[9347]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:   7%|▋         | 7/100 [01:55<25:21, 16.36s/it]

[13815  1220  3380  1214  2919  5315 18695 22571  3461  4441]
[22571, 18695, 2919, 1214, 13815, 5315, 4441, 3461, 3380, 1220]
[ 1220 20779]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:   8%|▊         | 8/100 [02:11<25:09, 16.40s/it]

[ 6346 23377  6360 23427 23620 21602 23599 21132 17762 14330]
[23620, 17762, 21602, 23599, 23427, 23377, 21132, 14330, 6360, 6346]
[12645 23620 24445]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38685281])}}


Evaluating Iteration:   9%|▉         | 9/100 [02:29<25:19, 16.70s/it]

[  514  4680  4852 13815  1918   498  1919  2085  4144  3435]
[2085, 13815, 4852, 4680, 4144, 3435, 1919, 1918, 514, 498]
[  595  1802 13815]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}


Evaluating Iteration:  10%|█         | 10/100 [02:57<30:27, 20.30s/it]

[11344 14605 20540  9928 17714 17144 17236 14222 22948 11171]
[22948, 20540, 17714, 17236, 17144, 14605, 14222, 11344, 11171, 9928]
[ 9562 11117 14218 16246 17144 17779 17855 19352 21223 22948]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.85034491])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.40298313])}}


Evaluating Iteration:  11%|█         | 11/100 [03:14<28:18, 19.08s/it]

[18520 16977 17723 23302 17776 17782 21156 20113 24078   283]
[24078, 23302, 21156, 20113, 18520, 17782, 17776, 17723, 16977, 283]
[21156 21603 22096]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}


Evaluating Iteration:  12%|█▏        | 12/100 [03:30<26:52, 18.32s/it]

[ 3626 21255  3631 15701 18694  4076  4079  3612  7591  9038]
[15701, 18694, 21255, 9038, 7591, 4079, 4076, 3631, 3626, 3612]
[ 7591 18079]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38685281])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}


Evaluating Iteration:  13%|█▎        | 13/100 [03:46<25:25, 17.53s/it]

[16374  3461 19264  4852 18695 15746  1115 13815  3434 18689]
[18695, 19264, 18689, 16374, 15746, 13815, 4852, 3461, 3434, 1115]
[18695]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38685281])}}


Evaluating Iteration:  14%|█▍        | 14/100 [04:02<24:29, 17.09s/it]

[24245 23623 24535 21156 23799 22723 24247 24022 18520 24466]
[23623, 24022, 24245, 24535, 24466, 24247, 23799, 22723, 21156, 18520]
[23623 24535]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.87721532])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.6934264])}}


Evaluating Iteration:  15%|█▌        | 15/100 [04:19<24:02, 16.97s/it]

[ 747  710  296   45 2776 3657  650  402  625 1923]
[2776, 650, 3657, 1923, 747, 710, 625, 402, 296, 45]
[ 710 3828]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:  16%|█▌        | 16/100 [04:35<23:22, 16.69s/it]

[17776 10765 17775 18520 17777 18409 18086 17784 21601 10968]
[21601, 18520, 18409, 18086, 17784, 17777, 17776, 17775, 10968, 10765]
[10765 18154]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:  17%|█▋        | 17/100 [04:55<24:27, 17.68s/it]

[ 7782   402  3007  4314  1823  4794  1914  9529  1268 18742]
[9529, 18742, 7782, 4794, 4314, 3007, 1914, 1823, 402, 1268]
[ 7561  7778 10368 18742]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}


Evaluating Iteration:  18%|█▊        | 18/100 [05:10<23:22, 17.11s/it]

[21656  9813 21854   195 18254 21400  4540 21522 17806 19324]
[9813, 195, 21854, 17806, 21656, 21400, 19324, 18254, 4540, 21522]
[21854]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}


Evaluating Iteration:  19%|█▉        | 19/100 [05:26<22:36, 16.75s/it]

[ 6181 14757 12378 20389 14755 14758 14771 13419 16533 12222]
[14758, 20389, 14757, 14755, 12378, 16533, 14771, 13419, 12222, 6181]
[6181]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:  20%|██        | 20/100 [05:47<24:00, 18.01s/it]

[10587 10027 15775  5323  6601 21066 14606 21577  9587 10579]
[21577, 21066, 15775, 14606, 10587, 10579, 10027, 9587, 6601, 5323]
[ 3609 10579 23389 24762]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}


Evaluating Iteration:  21%|██        | 21/100 [06:06<23:48, 18.08s/it]

[10765  6348 18520 17723 17289 10968 17782 18083 21752 17777]
[21752, 18520, 18083, 17782, 17777, 17723, 17289, 10968, 10765, 6348]
[10968]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.31546488])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  22%|██▏       | 22/100 [06:23<23:06, 17.78s/it]

[ 748  545  606  721  827 2303 1711 2381  562 4238]
[562, 606, 748, 827, 721, 545, 2303, 2381, 1711, 4238]
[562]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}


Evaluating Iteration:  23%|██▎       | 23/100 [06:39<22:07, 17.24s/it]

[ 1919  4680 16374   733  3434  1122  5172 17608  1920  1918]
[5172, 17608, 16374, 4680, 3434, 1920, 1919, 1918, 1122, 733]
[ 1919 16441]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:  24%|██▍       | 24/100 [06:55<21:27, 16.94s/it]

[  92  263  305 1171  324  313 2005 1077 1473 3422]
[324, 3422, 2005, 1473, 1171, 1077, 313, 305, 263, 92]
[305]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.31546488])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}


Evaluating Iteration:  25%|██▌       | 25/100 [07:21<24:41, 19.76s/it]

[17782 16981 22866 18085 18087 17776 18086 22940 17784 21752]
[22940, 22866, 21752, 18087, 18086, 18085, 17784, 17782, 17776, 16981]
[16977 16981 18085 18086 20186 20188 22767 22940 23455 24100]
{10: {'precision': array([0.4], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.79330099])}}
{10: {'precision': array([0.4], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.6677078])}}


Evaluating Iteration:  26%|██▌       | 26/100 [07:37<22:57, 18.62s/it]

[ 9732  4540 16980  8907 18000  1700 17144  6591 13815  9494]
[8907, 18000, 17144, 16980, 13815, 9732, 9494, 6591, 4540, 1700]
[4540]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:  27%|██▋       | 27/100 [07:55<22:29, 18.49s/it]

[18086 21603 24028 19813 16981 23882 21601 23421 23046 16979]
[23882, 24028, 23421, 23046, 21603, 21601, 19813, 18086, 16981, 16979]
[18086 22940 23046 23566]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.45749453])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.79772289])}}


Evaluating Iteration:  28%|██▊       | 28/100 [08:12<21:27, 17.89s/it]

[ 1503 19746  4541  9003  2158 23053  9001 11343  4219 14612]
[1503, 19746, 4219, 9001, 23053, 14612, 11343, 9003, 4541, 2158]
[ 2158 16282]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38685281])}}


Evaluating Iteration:  29%|██▉       | 29/100 [08:28<20:33, 17.37s/it]

[17776 21156 10968 24281 24124 23538 18085 18086 21601 24050]
[24281, 24124, 24050, 23538, 21601, 21156, 18086, 18085, 17776, 10968]
[24248 24281]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}


Evaluating Iteration:  30%|███       | 30/100 [08:43<19:37, 16.83s/it]

[ 7263 22562  2051  6649 19342 18103  8019 15996  9285  8034]
[9285, 22562, 19342, 18103, 8034, 8019, 7263, 6649, 2051, 15996]
[8019]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}


Evaluating Iteration:  31%|███       | 31/100 [09:44<34:34, 30.06s/it]

[ 267  265 1917 1945   53   45  548 4011 5632  281]
[265, 5632, 4011, 1945, 1917, 548, 281, 267, 53, 45]
[   45   357   358   606   622   759  1085  1101  1122  1184  1604  1919
  3277  3434  5155  5161  5948  6371  6470  6751  6893  7417  7500  8151
  8783  9845 14257 18782]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  32%|███▏      | 32/100 [10:13<33:42, 29.74s/it]

[3864 2700 3669 3337 2442  286 2696 6809 3733 3654]
[6809, 2442, 3864, 3733, 3669, 3654, 3337, 2700, 2696, 286]
[  286  1493  2932  2934  2952  3104  6273  9278 19405]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  33%|███▎      | 33/100 [10:30<28:55, 25.90s/it]

[ 4919  1493  2970  2969  1091  9495  2974  2964  1507 18555]
[1507, 2964, 2970, 2974, 2969, 1493, 9495, 4919, 1091, 18555]
[ 2970 12528]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}


Evaluating Iteration:  34%|███▍      | 34/100 [10:50<26:18, 23.92s/it]

[23441 21996 22080 23862 23307 19343 22148 19345 22716 22140]
[19343, 22080, 22148, 23862, 22140, 19345, 22716, 23307, 21996, 23441]
[21996]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:  35%|███▌      | 35/100 [11:05<23:13, 21.44s/it]

[23995 20326 18438 23906 17837 22465 18164 22528 14137 20477]
[17837, 23995, 22528, 18438, 23906, 22465, 20326, 14137, 20477, 18164]
[20477]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}


Evaluating Iteration:  36%|███▌      | 36/100 [11:23<21:45, 20.39s/it]

[ 1102  1088  1021  9039  1136  4075  4079 10587  9792  6583]
[9039, 1088, 10587, 9792, 6583, 4079, 4075, 1136, 1102, 1021]
[ 1021  6583 22463]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.414437])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.48381288])}}


Evaluating Iteration:  37%|███▋      | 37/100 [11:40<20:07, 19.17s/it]

[ 4542  7782  1981  6765   930  4445  7902 11217 11409  6085]
[930, 1981, 4445, 7902, 6765, 11409, 11217, 7782, 6085, 4542]
[1671 4445]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  38%|███▊      | 38/100 [11:58<19:42, 19.07s/it]

[ 795 1559 1560 3938 1568 1055 5524 1785 1097  355]
[1559, 795, 5524, 3938, 1785, 1568, 1560, 1097, 1055, 355]
[ 1097 21215 21346 24012]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.31546488])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}


Evaluating Iteration:  39%|███▉      | 39/100 [12:29<22:54, 22.54s/it]

[ 884 3677 2063  886 4641  841  972 1625  919 6778]
[2063, 3677, 886, 884, 1625, 6778, 4641, 919, 972, 841]
[  863  1224  2063  3373  3677  7047  7934 10680 10723 19015]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.6934264])}}


Evaluating Iteration:  40%|████      | 40/100 [12:45<20:29, 20.50s/it]

[13815 20344  3461 18695  1214  5849 16374  4680 19706 20743]
[20743, 18695, 1214, 20344, 19706, 16374, 13815, 5849, 4680, 3461]
[13815]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:  41%|████      | 41/100 [13:01<18:52, 19.20s/it]

[ 1214  1223  7400 13815 19706  3461 15669 23718  7512 17086]
[7400, 15669, 7512, 1214, 23718, 19706, 13815, 3461, 1223, 17086]
[1214 1223]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.44864382])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}


Evaluating Iteration:  42%|████▏     | 42/100 [13:17<17:37, 18.23s/it]

[21255   736  4076  3626  3998  1674   709  3612  4447  2611]
[709, 1674, 21255, 3998, 4447, 4076, 3626, 3612, 2611, 736]
[1674]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  43%|████▎     | 43/100 [13:34<17:03, 17.96s/it]

[10967 17776 18520 18083 17782 18368 17775 17777 17784 22801]
[22801, 18520, 18368, 18083, 17784, 17782, 17777, 17776, 17775, 10967]
[18520 22767 22870]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.5])}}


Evaluating Iteration:  44%|████▍     | 44/100 [13:54<17:14, 18.47s/it]

[2417 3600 2421 3453 2411 9298 4703 2414 9294 9292]
[9292, 4703, 2414, 3600, 9298, 9294, 2417, 3453, 2421, 2411]
[ 2414  2417  2421  4703 10217]
{10: {'precision': array([0.4], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.6891352])}}
{10: {'precision': array([0.4], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.83884795])}}


Evaluating Iteration:  45%|████▌     | 45/100 [14:11<16:30, 18.01s/it]

[ 9933 18000 14612  4541 14790 22948 22130  8168 17236  4540]
[22948, 18000, 17236, 14790, 14612, 9933, 8168, 4541, 4540, 22130]
[ 9813  9941 18000]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}


Evaluating Iteration:  46%|████▌     | 46/100 [14:27<15:40, 17.42s/it]

[ 8443 10309  6249  6250  2421  5022 10005  2410  9151  4715]
[2410, 9151, 6249, 4715, 6250, 5022, 2421, 8443, 10309, 10005]
[10005 21215]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}


Evaluating Iteration:  47%|████▋     | 47/100 [14:43<15:07, 17.13s/it]

[ 650 2005 1668 1438 3012  302  441  512 5401  416]
[1438, 416, 1668, 512, 3012, 5401, 2005, 650, 441, 302]
[441 998]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.30103])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.33333333])}}


Evaluating Iteration:  48%|████▊     | 48/100 [15:00<14:39, 16.90s/it]

[24490 23537 24491 23565 24357 22007 24300 18087 18520 24880]
[23537, 24491, 24490, 24880, 24300, 24357, 23565, 22007, 18520, 18087]
[24491 24880]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.65092093])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.48381288])}}


Evaluating Iteration:  49%|████▉     | 49/100 [15:17<14:22, 16.92s/it]

[18000 18257 17663 14790 18254 21021 20230 14222  9936  9942]
[20230, 21021, 18257, 18254, 18000, 17663, 14790, 14222, 9942, 9936]
[ 9936 21021 23061]
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.56409209])}}
{10: {'precision': array([0.2], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.40298313])}}


Evaluating Iteration:  50%|█████     | 50/100 [15:33<13:51, 16.62s/it]

[ 1214  1216 13815 15669  3461  7400 20344 10576 20743  4441]
[7400, 15669, 20743, 1214, 20344, 13815, 10576, 4441, 3461, 1216]
[7400]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([1.])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.35620719])}}


Evaluating Iteration:  51%|█████     | 51/100 [15:51<13:54, 17.04s/it]

[ 267 3724 1039 1945 2802 5537 1248  266 1196   92]
[5537, 3724, 2802, 1945, 1248, 1196, 1039, 267, 266, 92]
[ 330 1211 1945]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}


Evaluating Iteration:  52%|█████▏    | 52/100 [16:09<13:50, 17.31s/it]

[ 2709  3255   579  3269  2802 15193  1513   323  1568  2055]
[1513, 579, 15193, 3269, 3255, 2802, 2709, 2055, 1568, 323]
[ 3269 10145 11541 13079]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.43067656])}}


Evaluating Iteration:  53%|█████▎    | 53/100 [16:25<13:17, 16.97s/it]

[ 9621  1214 17673  1216 19706  6408 16245  8907 20019 15883]
[6408, 20019, 15883, 16245, 8907, 1214, 19706, 17673, 9621, 1216]
[8270 8907]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.38685281])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.31546488])}}


Evaluating Iteration:  54%|█████▍    | 54/100 [16:41<12:55, 16.86s/it]

[ 9100  3404  8073 17800  3403  8074  1440 11105  3402 11426]
[11105, 11426, 9100, 3404, 17800, 8074, 8073, 3403, 1440, 3402]
[11426 15515]
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.63092975])}}
{10: {'precision': array([0.1], dtype=float32), 'recall': array([1.], dtype=float32), 'ndcg': array([0.28906483])}}


Evaluating Iteration:  54%|█████▍    | 54/100 [16:47<14:18, 18.66s/it]


KeyboardInterrupt: 