In [1]:
import ast
import os
import sys
import random
from time import time

import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from model.KGAT import KGAT
from parser.parser_kgat import *
from utils.log_helper import *
from utils.metrics import *
from utils.model_helper import *
from data_loader.loader_kgat import DataLoaderKGAT


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse

# Build the argument namespace by hand (instead of parsing a command line) so
# the notebook can be run top to bottom without CLI arguments.
args = argparse.Namespace()

# --- reproducibility, dataset location, pretraining ---
args.seed = 2020
args.data_name = "naver-toy"
args.data_dir = "datasets/"
# use_pretrain: the setup cell loads pre-trained user/item embeddings when this
# is 1 and a saved model from pretrain_model_path when it is 2; 0 does neither.
args.use_pretrain = 0
args.pretrain_embedding_dir = "datasets/pretrain/"
args.pretrain_model_path = "trained_model/model.pth"

# --- batch sizes ---
args.cf_batch_size = 32
args.kg_batch_size = 32
args.test_batch_size = 10000

# --- model options ---
# List-valued options are stored as strings; this notebook parses
# conv_dim_list and Ks itself where it needs the actual lists.
args.embed_dim = 64
args.relation_dim = 64
args.laplacian_type = "random-walk"
args.aggregation_type = "bi-interaction"
args.conv_dim_list = "[64,32,16]"
args.mess_dropout = "[0.1, 0.1, 0.1]"

# --- regularisation weights ---
args.kg_l2loss_lambda = 1e-5
args.cf_l2loss_lambda = 1e-5

# --- optimisation schedule ---
args.lr = 0.0001
args.n_epoch = 1000
args.stopping_steps = 10

# --- logging / evaluation cadence ---
args.cf_print_every = 1
args.kg_print_every = 1
args.evaluate_every = 1

# Cut-offs parsed into `Ks` by the setup cell.
args.Ks = "[20,40,60,80,100]"

# Output directory encodes the key hyper-parameters. The conv dims are parsed
# with ast.literal_eval (not eval) so the option string can only ever be read
# as a Python literal, never executed as arbitrary code.
save_dir = 'trained_model/KGAT/{}/embed-dim{}_relation-dim{}_{}_{}_{}_lr{}_pretrain{}/'.format(
    args.data_name, args.embed_dim, args.relation_dim, args.laplacian_type, args.aggregation_type,
    '-'.join(str(dim) for dim in ast.literal_eval(args.conv_dim_list)), args.lr, args.use_pretrain)
args.save_dir = save_dir

In [3]:
# Seed every RNG used here: Python's `random`, NumPy, PyTorch (CPU) and all
# CUDA devices.
# NOTE(review): `np`, `torch` and `logging` are not imported explicitly in this
# notebook; they presumably arrive through the star imports in the first cell.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Configure logging into save_dir; the id returned by create_log_id becomes the
# numeric suffix of the log file name ('log<id>').
log_save_id = create_log_id(args.save_dir)
logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
logging.info(args)

# GPU / CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load data
dataloader = DataLoaderKGAT(args, logging)
if args.use_pretrain == 1:
    # Pre-trained embeddings exposed by the data loader are handed to the model.
    user_pre_embed = torch.tensor(dataloader.user_pre_embed)
    item_pre_embed = torch.tensor(dataloader.item_pre_embed)
else:
    user_pre_embed, item_pre_embed = None, None

# construct model & optimizer
model = KGAT(args, dataloader.n_users, dataloader.n_entities, dataloader.n_relations, dataloader.A_in, user_pre_embed, item_pre_embed)
if args.use_pretrain == 2:
    # Start from a previously saved model instead of fresh weights.
    model = load_model(model, args.pretrain_model_path)

model.to(device)
logging.info(model)

# Two independent Adam instances, named for the CF and KG phases.
# NOTE(review): both wrap the same model.parameters() — confirm this is intended.
cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)

# initialize metrics
best_epoch = -1
best_recall = 0

# Parse the cut-off list with ast.literal_eval (not eval): the option string is
# only ever interpreted as a Python literal.
Ks = ast.literal_eval(args.Ks)
k_min = min(Ks)
k_max = max(Ks)

# Per-epoch history: one list per metric for every cut-off in Ks.
epoch_list = []
metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}


2022-05-03 22:10:51,414 - root - INFO - Namespace(seed=2020, data_name='naver-toy', data_dir='datasets/', use_pretrain=0, pretrain_embedding_dir='datasets/pretrain/', pretrain_model_path='trained_model/model.pth', cf_batch_size=32, kg_batch_size=32, test_batch_size=10000, embed_dim=64, relation_dim=64, laplacian_type='random-walk', aggregation_type='bi-interaction', conv_dim_list='[64,32,16]', mess_dropout='[0.1, 0.1, 0.1]', kg_l2loss_lambda=1e-05, cf_l2loss_lambda=1e-05, lr=0.0001, n_epoch=1000, stopping_steps=10, cf_print_every=1, kg_print_every=1, evaluate_every=1, Ks='[20,40,60,80,100]', save_dir='trained_model/KGAT/naver-toy/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain0/')


All logs will be saved to trained_model/KGAT/naver-toy/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain0/log38.log


2022-05-03 22:10:53,425 - root - INFO - n_users:           13498
2022-05-03 22:10:53,428 - root - INFO - n_items:           7285
2022-05-03 22:10:53,429 - root - INFO - n_entities:        7348
2022-05-03 22:10:53,429 - root - INFO - n_users_entities:  20846
2022-05-03 22:10:53,430 - root - INFO - n_relations:       8
2022-05-03 22:10:53,431 - root - INFO - n_h_list:          181028
2022-05-03 22:10:53,431 - root - INFO - n_t_list:          181028
2022-05-03 22:10:53,432 - root - INFO - n_r_list:          181028
2022-05-03 22:10:53,433 - root - INFO - n_cf_train:        68659
2022-05-03 22:10:53,433 - root - INFO - n_cf_test:         27569
2022-05-03 22:10:53,434 - root - INFO - n_kg_train:        181028
  d_inv = np.power(rowsum, -1.0).flatten()
2022-05-03 22:10:58,887 - root - INFO - KGAT(
  (entity_user_embed): Embedding(20846, 64)
  (relation_embed): Embedding(8, 64)
  (aggregator_layers): ModuleList(
    (0): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      

In [4]:
# Loss accumulator and number of CF mini-batches per epoch.
cf_total_loss = 0
n_cf_batch = 1 + dataloader.n_cf_train // dataloader.cf_batch_size

# Draw a single CF mini-batch with the original sampler; the full loop over
# range(1, n_cf_batch + 1) is deliberately not run in this notebook.
(
    cf_batch_user,
    cf_batch_pos_item,
    cf_batch_neg_item,
) = dataloader.generate_cf_batch(
    dataloader.train_user_dict,
    dataloader.cf_batch_size,
)

In [5]:
# User tensor of the sampled CF batch (first value returned by generate_cf_batch).
cf_batch_user

tensor([17500, 17497, 10207, 18331, 19938, 14901, 14597, 15116, 13387, 14218,
        16244,  9943, 17464,  8849, 14848, 20404, 11281, 10781, 14061, 18329,
        10711, 17252,  9787, 15072, 14759,  9068, 18909, 18326,  9997, 15394,
        17532, 16153])

In [6]:
# Positive-item tensor of the sampled CF batch (second value returned by generate_cf_batch).
cf_batch_pos_item

tensor([ 630, 2150, 1635, 1126, 2014,  211, 1200, 2744, 1792, 3589, 3609, 4722,
        1544, 2329,  421, 3297, 1209, 1769, 5617, 1126, 1219, 1847, 1160, 1585,
        1187,  130, 3148, 3274,   87, 3766, 3101, 6601])

In [7]:
from data_loader.cf_loader import CF_Dataset

# Wrap the already-loaded KGAT data in the project's CF_Dataset, reusing the
# notebook-wide seed.
# NOTE(review): the meaning of phase="train_user_set" is defined in
# data_loader.cf_loader — confirm it selects the training interactions.
dataset = CF_Dataset(
    dataloader,
    seed=args.seed,
    phase="train_user_set",
)

In [8]:
from data_loader.cf_loader import CF_DataLoader

# Use the same batch size that is passed to dataloader.generate_cf_batch
# (instead of a hard-coded 32) so both samplers stay comparable if the
# configured CF batch size changes. Shuffling is off and the collate function
# returns the batch unchanged.
train_loader = CF_DataLoader(
    dataset,
    batch_size=dataloader.cf_batch_size,
    shuffle=False,
    collate_fn=lambda batch: batch,
    pin_memory=False,
)

In [9]:
# Fetch the first batch from the new CF loader for comparison with the
# generate_cf_batch sample. In the saved run, its first two tensors equal the
# cf_batch_user and cf_batch_pos_item values displayed earlier.
next(iter(train_loader))

(tensor([17500, 17497, 10207, 18331, 19938, 14901, 14597, 15116, 13387, 14218,
         16244,  9943, 17464,  8849, 14848, 20404, 11281, 10781, 14061, 18329,
         10711, 17252,  9787, 15072, 14759,  9068, 18909, 18326,  9997, 15394,
         17532, 16153]),
 tensor([ 630, 2150, 1635, 1126, 2014,  211, 1200, 2744, 1792, 3589, 3609, 4722,
         1544, 2329,  421, 3297, 1209, 1769, 5617, 1126, 1219, 1847, 1160, 1585,
         1187,  130, 3148, 3274,   87, 3766, 3101, 6601]),
 tensor([4488, 6774, 3655, 4253, 5280, 1994, 4787, 4030,  870, 2608, 3208, 6100,
         2127, 3530, 3646,   29, 4870, 4246, 5766,  750, 3426, 1556, 6921,  339,
         2672, 5185,  834, 3664,  712,  475,  654, 7045]))

---
## generate_kg_batch: compare the original KG batch sampler with KG_DataLoader


In [10]:
# Number of KG mini-batches per epoch.
n_kg_batch = 1 + dataloader.n_kg_train // dataloader.kg_batch_size

# Draw a single KG mini-batch (head, relation, positive tail, negative tail)
# with the original sampler.
(
    kg_batch_head,
    kg_batch_relation,
    kg_batch_pos_tail,
    kg_batch_neg_tail,
) = dataloader.generate_kg_batch(
    dataloader.train_kg_dict,
    dataloader.kg_batch_size,
    dataloader.n_users_entities,
)

In [11]:
# Head tensor of the sampled KG batch (first value returned by generate_kg_batch).
kg_batch_head

tensor([20305, 20299,  5718, 15106, 14499, 15537, 12079, 13740, 17792,  5191,
        20233,  3003, 15001,  7866,  6867, 13426,  6727, 19809,  4879, 15449,
        14822,  3440,  5298, 16093, 20368, 17611, 15582,  1338, 14508,  1616,
        11303,  7111])

In [15]:
# Relation tensor of the sampled KG batch (second value returned by generate_kg_batch).
kg_batch_relation

tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0,
        0, 0, 0, 2, 0, 1, 0, 1])

In [12]:
from data_loader.kg_loader import KG_Dataset, KG_DataLoader

# Wrap the already-loaded KGAT data in the project's KG_Dataset, reusing the
# notebook-wide seed.
# NOTE(review): phase="train_user_set" is the same value used for the CF
# dataset — confirm it is the intended phase for the KG triples.
kg_dataset = KG_Dataset(
    dataloader,
    seed=args.seed,
    phase="train_user_set",
)


In [13]:
# Use the same batch size that is passed to dataloader.generate_kg_batch
# (instead of a hard-coded 32) so both samplers stay comparable if the
# configured KG batch size changes. Shuffling is off and the collate function
# returns the batch unchanged.
kg_dataloader = KG_DataLoader(
    kg_dataset,
    batch_size=dataloader.kg_batch_size,
    shuffle=False,
    collate_fn=lambda batch: batch,
    pin_memory=False,
)

In [14]:
# Fetch the first batch from the new KG loader for comparison with the
# generate_kg_batch sample.
# NOTE(review): in the outputs saved with this notebook the head tensor equals
# kg_batch_head, but the relation tensor differs from kg_batch_relation at 4 of
# the 32 positions — confirm whether KG_Dataset is expected to reproduce
# generate_kg_batch exactly.
next(iter(kg_dataloader))

(tensor([20305, 20299,  5718, 15106, 14499, 15537, 12079, 13740, 17792,  5191,
         20233,  3003, 15001,  7866,  6867, 13426,  6727, 19809,  4879, 15449,
         14822,  3440,  5298, 16093, 20368, 17611, 15582,  1338, 14508,  1616,
         11303,  7111]),
 tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 4, 0, 4, 0, 1, 0, 0, 1, 3, 0,
         0, 0, 0, 3, 0, 1, 0, 1]),
 tensor([ 5693,  5647, 16084,  1071,  1487,  6519,  1505,  3764,  2388, 11973,
          6858, 11443,   258,  1081,  7345,   151,  7345,  4223, 10844,   246,
           655,  8773,  7297,  5515,  4989,  3675,  5254,  7334,  2950, 14834,
           141, 20518]),
 tensor([ 9056,  7491, 18523,  1661, 20039, 13472,  1994, 12427, 17010,   870,
         18992,  6100,  8065, 18757, 10057, 16413, 20630,  3733, 13958, 10724,
          2691, 17975, 15668,  2672,  7801, 12202, 15287,  3664,  8904, 16104,
          7045, 20835]))