In [1]:
import pickle
import torch
from torch.utils.data import DataLoader

from utils.dataloader import *
from utils import graph
from utils.loss import *
from utils.evaluator import *
from models.lightgcn import LightGCN

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tmall: evaluate the pretrained LightGCN checkpoint (HR@10 / NDCG@10)

In [2]:
# Behavior types, in the order their interaction matrices are loaded.
# The target behavior ('buy') is listed last: later cells select it with index -1.
behaviors = [
    'click',
    'fav',
    'cart',
    'buy',
]

# Training data is stored as one file per behavior, named `trn_<behavior>`;
# `train_file` is that shared prefix. The test set is a single file.
train_file, test_file = './data/Tmall/trn_', './data/Tmall/tst_int'


In [3]:
# Load one user-item interaction matrix per behavior, in `behaviors` order.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
train_u2i = []
for behavior in behaviors:
    with open(train_file + behavior, 'rb') as f:
        train_u2i.append(pickle.load(f))

# The 'buy' matrix defines the user/item counts. Fail loudly here instead of
# with a NameError further down when 'buy' is missing from `behaviors`.
u2i_by_behavior = dict(zip(behaviors, train_u2i))
if 'buy' not in u2i_by_behavior:
    raise ValueError(f"'buy' must be one of the behaviors, got {behaviors}")
# `.shape` is the standard accessor; the original used `get_shape()`, which
# suggests these are scipy sparse matrices (TODO confirm).
user_num, item_num = u2i_by_behavior['buy'].shape

# Every behavior matrix is fed into one shared user/item graph below, so all
# of them must have the same (user_num, item_num) shape.
assert all(m.shape == (user_num, item_num) for m in train_u2i), \
    'all behavior matrices must share the buy matrix shape'

train_dataset = TrainDataset(train_u2i, behaviors, item_num)
# NOTE(review): train_loader is not referenced by the evaluation cell that
# follows; confirm it is still needed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2048, num_workers=8)

with open(test_file, 'rb') as f:
    test_dataset = TestDataset(pickle.load(f))
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=2048, num_workers=4, pin_memory=True)


In [4]:
# Build the multi-behavior adjacency matrices directly on the evaluation device.
adj_matrix = graph.create_adj_mats(train_u2i, user_num, item_num, behaviors, device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded when only a CPU is available.
checkpoint = torch.load('./checkpoints/tmall.pth', map_location=device)

gcn = LightGCN(user_num, item_num, behaviors, adj_matrix).to(device)
gcn.load_state_dict(checkpoint)
# Disable any training-only behavior (e.g. dropout) before inference; this is
# a no-op if LightGCN has none.
gcn.eval()

print('Tmall Evaluation Metrics:')
with torch.no_grad():
    # Outputs are presumably one embedding tensor per behavior, indexed like
    # `behaviors`; [-1] selects the target behavior 'buy' (TODO confirm).
    # Under no_grad these tensors carry no autograd history, so no .detach().
    user_embs, item_embs = gcn()
    # `test` presumably prints HR@10 / NDCG@10 (see the cell output below);
    # the buy interaction matrix is passed alongside the embeddings.
    test_res = test(
        test_loader,
        train_u2i[-1],
        user_embs[-1],
        item_embs[-1],
    )


Tmall Evaluation Metrics:
hr@10 = 0.617431
ndcg@10 = 0.398516
