In [15]:
import sys
import random
from time import time

import pandas as pd
from tqdm import tqdm
import torch.optim as optim

from KGAT import KGAT
from utils.parser_kgat import *
from utils.log_helper import *
from utils.metrics import *
from utils.model_helper import *
from data_loader import DataLoaderKGAT

from scipy.sparse import coo_matrix
import copy


# Original code
def evaluate(model, dataloader, Ks, device):
    test_batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()

    user_ids = list(test_user_dict.keys())
    user_ids_batches = [user_ids[i: i + test_batch_size] for i in range(0, len(user_ids), test_batch_size)]
    user_ids_batches = [torch.LongTensor(d) for d in user_ids_batches]

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)

    cf_scores = []
    metric_names = ['precision', 'recall', 'ndcg']
    metrics_dict = {k: {m: [] for m in metric_names} for k in Ks}

    with tqdm(total=len(user_ids_batches), desc='Evaluating Iteration') as pbar:
        for batch_user_ids in user_ids_batches:
            batch_user_ids = batch_user_ids.to(device)

            with torch.no_grad():
                batch_scores = model(batch_user_ids, item_ids, mode='predict')       # (n_batch_users, n_items)

            batch_scores = batch_scores.cpu()
            batch_metrics = calc_metrics_at_k(batch_scores, train_user_dict, test_user_dict, batch_user_ids.cpu().numpy(), item_ids.cpu().numpy(), Ks)

            cf_scores.append(batch_scores.numpy())
            for k in Ks:
                for m in metric_names:
                    metrics_dict[k][m].append(batch_metrics[k][m])
            pbar.update(1)

    cf_scores = np.concatenate(cf_scores, axis=0)
    for k in Ks:
        for m in metric_names:
            metrics_dict[k][m] = np.concatenate(metrics_dict[k][m]).mean()
    return cf_scores, metrics_dict

def predict(args):
    # GPU / CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    data = DataLoaderKGAT(args, logging)

    # load model
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations)
    checkpoint = torch.load(args.pretrain_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    model.to(device)

    # predict
    Ks = eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    cf_scores, metrics_dict = evaluate(model, data, Ks, device)
    np.save(args.save_dir + 'cf_scores.npy', cf_scores)
    print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

In [90]:
# kgat args (can adjust Ks and pretrain_model_path)
class KGAT_args():
    def __init__(self, 
                 seed=2019,
                 data_name='amazon-book', 
                 data_dir='datasets/', 
                 use_pretrain=1, 
                 pretrain_embedding_dir='datasets/pretrain/', 
                 pretrain_model_path='trained_model/KGAT/model_epoch280.pth', 
                 cf_batch_size=1024, 
                 kg_batch_size=2048, 
                 test_batch_size=10000, 
                 embed_dim=64, 
                 relation_dim=64, 
                 laplacian_type='random-walk', 
                 aggregation_type='bi-interaction', 
                 conv_dim_list='[64, 32, 16]', 
                 mess_dropout='[0.1, 0.1, 0.1]', 
                 kg_l2loss_lambda=1e-5, 
                 cf_l2loss_lambda=1e-5, 
                 lr=0.0001, 
                 n_epoch=1000, 
                 stopping_steps=10, 
                 cf_print_every=1, 
                 kg_print_every=1, 
                 evaluate_every=20, 
                 Ks='[10]'):
        
        self.seed = seed
        self.data_name = data_name
        self.data_dir = data_dir
        self.use_pretrain = use_pretrain
        self.pretrain_embedding_dir = pretrain_embedding_dir
        self.pretrain_model_path = pretrain_model_path
        self.cf_batch_size = cf_batch_size
        self.kg_batch_size = kg_batch_size
        self.test_batch_size = test_batch_size
        self.embed_dim = embed_dim
        self.relation_dim = relation_dim
        self.laplacian_type = laplacian_type
        self.aggregation_type = aggregation_type
        self.conv_dim_list = conv_dim_list
        self.mess_dropout = mess_dropout
        self.kg_l2loss_lambda = kg_l2loss_lambda
        self.cf_l2loss_lambda = cf_l2loss_lambda
        self.lr = lr
        self.n_epoch = n_epoch
        self.stopping_steps = stopping_steps
        self.cf_print_every = cf_print_every
        self.kg_print_every = kg_print_every
        self.evaluate_every = evaluate_every
        self.Ks = Ks
        save_dir = 'trained_model/KGAT/{}/embed-dim{}_relation-dim{}_{}_{}_{}_lr{}_pretrain{}/'.format(
        self.data_name, self.embed_dim, self.relation_dim, self.laplacian_type, self.aggregation_type,
        '-'.join([str(i) for i in eval(self.conv_dim_list)]), self.lr, self.use_pretrain)
        self.save_dir = save_dir

In [91]:
args = KGAT_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load data
data = DataLoaderKGAT(args, logging)

# load model
model = KGAT(args, data.n_users, data.n_entities, data.n_relations)
checkpoint = torch.load(args.pretrain_model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model.to(device)
pass

In [None]:
# Counts
print(data.n_items) # 0~24914
print(data.n_entities) # 24915~113486
print(data.n_users_entities) # 113487~184165

In [42]:
def find_all_path(start, end, target_r_1, target_r_2, adj_matrix):
    adj_first = adj_matrix[0].tocsr()
    adj_second = adj_matrix[target_r_1].tocsr()
    #adj_last = adj_matrix[target_r_2].tocsr()
    
    if target_r_2 == 0:
        reverse_adj_last = adj_matrix[1].tocsr()
    elif target_r_2 <= 40:
        reverse_adj_last = adj_matrix[target_r_2 + 39].tocsr()
    else:
        reverse_adj_last = adj_matrix[target_r_2 - 39].tocsr()
    
    # All items connect to start
    first_step_candidate = adj_first.indices[adj_first.indptr[start]:adj_first.indptr[start+1]]
    first_step_candidate = np.delete(first_step_candidate, np.where(first_step_candidate == end)[0])
    
    # All entity(uiei)/user(uiui) connect to end
    last_step_candidate = reverse_adj_last.indices[reverse_adj_last.indptr[end]:reverse_adj_last.indptr[end+1]]
    last_step_candidate = np.delete(last_step_candidate, np.where(last_step_candidate == start)[0])

    output_paths = []
    for h in first_step_candidate:
        tails = adj_second.indices[adj_second.indptr[h]:adj_second.indptr[h+1]]
        for t in tails:
            if t in last_step_candidate:
                output_paths.append([start, h, target_r_1, t, target_r_2, end])
    
    return output_paths

def find_all_path_all_r(start, end, relations, adj_matrix):
    all_paths = []
    for r in relations:
        if r == 1:
            r_1 = 1
            r_2 = 0
        else:
            r_1 = r
            r_2 = r + 39
        
        #t1 = time.perf_counter()
        paths = find_all_path(start, end, r_1, r_2, adj_matrix)
        all_paths += paths
        #t2 = time.perf_counter()
        #print("t = ", t2 - t1)
        paths = find_all_path(start, end, r_2, r_1, adj_matrix)
        all_paths += paths
        #t3 = time.perf_counter()
        #print("t2 = ", t3 - t2)
    return all_paths

In [87]:
# Get the score of a single path 
# 'sum', 'mul', 'max'
def path_score(path, A_in, mode='sum'):
    assert len(path) == 6  # [U, I, r_1, E, r_2, I]

    if mode == 'sum':
        return A_in[path[0]][path[1]] + A_in[path[1]][path[3]] + A_in[path[3]][path[5]]
    
    elif mode == 'mul':
        return A_in[path[0]][path[1]] * A_in[path[1]][path[3]] * A_in[path[3]][path[5]]
    
    elif mode == 'max':
        return max(A_in[path[0]][path[1]], A_in[path[1]][path[3]], A_in[path[3]][path[5]])
    
    else:
        raise ValueError('mode should be in ["sum", "mul", "max"].')

# Rerank the paths
def get_top_k_path(paths, A_in, mode='sum', k=5):
    assert len(paths) > 0
    
    scores = [path_score(path, A_in, mode) for path in paths]
    reranked_paths = [path for _, path in sorted(zip(scores,paths), reverse=True)]
    sorted_scores = sorted(scores, reverse=True)

    return reranked_paths[:k], sorted_scores[:k]

# Single user prediction
def predict_single(model, dataloader, user_id, relations, Ks, device):
    train_user_dict = dataloader.train_user_dict

    model.eval()
    input_user_id = torch.LongTensor([user_id]).to(device)

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)
    k_max = max(Ks)

    with torch.no_grad():
        matching_scores = model(input_user_id, item_ids, mode='predict')[0].cpu().numpy()       # (n_batch_users, n_items)

    train_pos_item_list = train_user_dict[user_id]
    matching_scores[train_pos_item_list] = -np.inf
    top_k_items = np.argsort(-matching_scores)[:k_max]
    top_k_all_att_scores = []
    top_k_all_paths = []
    top_k_paths = []
    top_k_att_scores = []
    
    for item_id in top_k_items:
        paths = find_all_path_all_r(user_id, item_id, relations, dataloader.adjacency_dict)
        reranked_paths, scores = get_top_k_path(paths, model.A_in, mode='max', k=5)
        top_k_all_att_scores.append(scores)
        top_k_all_paths.append(reranked_paths)
        top_k_att_scores.append(scores[0])
        top_k_paths.append(reranked_paths[0])
    
    output = {}
    for k in Ks:
        reranked_items = [id for _, id in sorted(zip(top_k_att_scores[:k], top_k_items[:k]), reverse=True)]
        explainations = [path for _, path in sorted(zip(top_k_att_scores[:k], top_k_paths[:k]), reverse=True)]
        output[k] = {'item_ids': reranked_items, 'explainations': explainations}
    
    return output, top_k_all_att_scores, top_k_all_paths, top_k_items

# Check original model predcition hit the GT in top-k (the result is in hits_10.txt, hits_20.txt)
def predict_ori(model, dataloader, Ks, device):
    test_batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict
    
    model.eval()
    user_ids = list(test_user_dict.keys())
    user_ids_batches = [user_ids[i: i + test_batch_size] for i in range(0, len(user_ids), test_batch_size)]
    user_ids_batches = [torch.LongTensor(d) for d in user_ids_batches]

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)
    k_max = max(Ks)

    hit_users = []
    count = 113487
    with tqdm(total=len(user_ids_batches), desc='Evaluating Iteration') as pbar:
        for batch_user_ids in user_ids_batches:
            batch_user_ids = batch_user_ids.to(device)

            with torch.no_grad():
                matching_scores = model(batch_user_ids, item_ids, mode='predict')       # (n_batch_users, n_items)

            batch_user_ids = batch_user_ids.cpu().numpy()
            test_pos_item_binary = np.zeros([len(batch_user_ids), len(item_ids.cpu().numpy())], dtype=np.float32)
            for idx, u in enumerate(batch_user_ids):
                train_pos_item_list = train_user_dict[u]
                test_pos_item_list = test_user_dict[u]
                matching_scores[idx][train_pos_item_list] = -np.inf
                test_pos_item_binary[idx][test_pos_item_list] = 1

            try:
                _, rank_indices = torch.sort(matching_scores.cuda(), descending=True)    # try to speed up the sorting process
            except:
                _, rank_indices = torch.sort(matching_scores, descending=True)

            rank_indices = rank_indices.cpu()
            binary_hit = []
            for i in range(len(batch_user_ids)):
                binary_hit.append(test_pos_item_binary[i][rank_indices[i]])
            binary_hit = np.array(binary_hit, dtype=np.float32)[:, :k_max]
            hit_user = np.where(np.sum(binary_hit, axis=1) > 0)[0]
            hit_users += list(hit_user + count)
            count += test_batch_size
            pbar.update(1)
    
    return hit_users

# Get results of all users (explainable version)
def explainable_evaluate(model, dataloader, relations, Ks, device):
    test_batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()

    user_ids = list(test_user_dict.keys())

    n_items = dataloader.n_items
    item_ids = torch.arange(n_items, dtype=torch.long).to(device)

    metric_names = ['precision', 'recall', 'ndcg']
    metrics_dict = {k: {m: [] for m in metric_names} for k in Ks}

    with tqdm(total=len(user_ids), desc='Evaluating Iteration') as pbar:
        for user_id in user_ids:

            output, _, _ = predict_single(model, dataloader, user_id, relations, Ks, device)
            batch_metrics = exp_calc_metrics_at_k(output[max(Ks)]['item_ids'], train_user_dict, test_user_dict, [user_id], item_ids.cpu().numpy(), Ks)

            for k in Ks:
                for m in metric_names:
                    metrics_dict[k][m].append(batch_metrics[k][m])
            pbar.update(1)

    for k in Ks:
        for m in metric_names:
            metrics_dict[k][m] = np.concatenate(metrics_dict[k][m]).mean()
    return metrics_dict
    
def explainable_recommend(model, data, relations, Ks):
    # predict
    Ks = eval(Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    metrics_dict = explainable_evaluate(model, data, relations, Ks, device)
    print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

In [93]:
#delete [2, 3, 6, 11, 21, 22, 32, 35, 37] -> [4, 5, 8, 13, 23, 24, 34, 37, 39] ([43, 44, 47, 52, 62, 63, 73, 76, 78])
all_relations = range(2, 41)
delete_relations = [4, 5, 8, 13, 23, 24, 34, 37, 39]
target_relations = list(set(all_relations) - set(delete_relations))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

output, all_paths, all_scores, ori_top_k = predict_single(model, data, 113511, target_relations, [20], device)
#hits = predict_ori(model, data, [10], device)

Evaluating Iteration: 100%|██████████| 8/8 [02:38<00:00, 19.79s/it]


In [89]:
print(list(ori_top_k))
print(output[20]['item_ids'])
print(data.test_user_dict[113511])

[805, 4693, 324, 1473, 548, 271, 2249, 291, 537, 2841, 642, 2142, 4016, 1078, 6897, 5401, 2040, 534, 6891, 4025]
[6897, 6891, 5401, 4693, 4025, 4016, 2841, 2249, 2142, 2040, 1473, 1078, 805, 642, 548, 537, 534, 324, 291, 271]
[  805 15038]
