In [1]:
import argparse
from os import device_encoding
import numpy as np
from data_loader import load_data
from train import train
import torch


import random
import numpy as np
import torch

from model import RippleNet


# Seed every random number generator used by the notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value so runs are repeatable.
seed = 2020
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner, which
# could otherwise select different algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run configuration, built directly as a Namespace instead of parsing a
# command line. Later cells read these attributes (for example n_hop,
# batch_size, lr and use_cuda); `args` is also passed to load_data and RippleNet.
args = argparse.Namespace(
    dataset="movie",
    dim=16,
    n_hop=2,
    kge_weight=0.01,
    l2_weight=1e-7,
    lr=0.0001,
    batch_size=32,
    n_epoch=10,
    n_memory=32,
    item_update_mode="plus_transform",
    using_all_hops=True,
    use_cuda=True,
)

In [3]:
# Pick the torch device from the config flag. Both branches of the conditional
# are plain device strings, so torch.device() always receives a str (the
# original had a misplaced parenthesis that nested a torch.device(...) call
# inside the "cpu" branch).
device = torch.device("cuda" if args.use_cuda else "cpu")
show_loss = False
# load_data prints its progress and returns an indexable result that the next
# cell splits into the train/eval/test data, the entity and relation counts,
# and the ripple set.
data_info = load_data(args)

reading rating file ...
splitting dataset ...
reading KG file ...
constructing knowledge graph ...
constructing ripple set ...


In [4]:
# Split the result of load_data into its six components, taken by position
# (indices 0 through 5).
(
    train_data,
    eval_data,
    test_data,
    n_entity,
    n_relation,
    ripple_set,
) = (data_info[position] for position in range(6))

In [5]:
# Instantiate the model and, when requested by the config, move it to the GPU.
model = RippleNet(args, n_entity, n_relation)
if args.use_cuda:
    model.cuda()

# Adam only receives the parameters that require gradients.
optimizer = torch.optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=args.lr,
)


In [6]:
def get_feed_dict(args, model, data, ripple_set, start, end):
    """Build the input tensors for the mini-batch ``data[start:end]``.

    Args:
        args: configuration object; ``args.n_hop`` and ``args.use_cuda`` are read.
        model: accepted for interface compatibility; not used here.
        data: 2-D array indexed as ``data[rows, col]``. Column 0 is used as
            the key into ``ripple_set``, column 1 supplies the items and
            column 2 the labels.
        ripple_set: mapping such that ``ripple_set[user][hop]`` can be indexed
            with 0, 1 and 2 to obtain the entries collected into the h, r and
            t memories respectively.
        start: first row of the batch (inclusive).
        end: last row of the batch (exclusive).

    Returns:
        ``(items, labels, memories_h, memories_r, memories_t)`` where the first
        two are LongTensors and the last three are lists with one LongTensor
        per hop. Everything is moved to the GPU when ``args.use_cuda`` is set.
    """
    users = data[start:end, 0]
    items = torch.LongTensor(data[start:end, 1])
    labels = torch.LongTensor(data[start:end, 2])

    def gather(hop, slot):
        # One row per user of the batch, taken from that user's ripple set.
        return torch.LongTensor([ripple_set[user][hop][slot] for user in users])

    hops = range(args.n_hop)
    memories_h = [gather(hop, 0) for hop in hops]
    memories_r = [gather(hop, 1) for hop in hops]
    memories_t = [gather(hop, 2) for hop in hops]

    if args.use_cuda:
        items, labels = items.cuda(), labels.cuda()
        memories_h = [memory.cuda() for memory in memories_h]
        memories_r = [memory.cuda() for memory in memories_r]
        memories_t = [memory.cuda() for memory in memories_t]

    return items, labels, memories_h, memories_r, memories_t


In [7]:
# Build the first training mini-batch by hand (rows 0 .. batch_size - 1).
start = 0
items, labels, memories_h, memories_r, memories_t = get_feed_dict(
    args,
    model,
    train_data,
    ripple_set,
    start,
    start + args.batch_size,
)

In [8]:
# Display the item tensor of the hand-built batch.
items

tensor([   0, 1687, 1179, 1691,  670, 1696, 2082,  939,   91, 1883, 1760, 1889,
        1383,  745,  767, 1949, 1275,  334,  688,  151,   69, 1526,  735,  725,
        1925, 2011,  421,  768, 1554, 1175,  669,  418], device='cuda:0')

In [9]:
from data_loader import CustomDataset, CustomDataLoader

# Wrap the training array and the ripple sets in the project's CustomDataset.
dataset = CustomDataset(args, train_data, ripple_set)

In [10]:
# Batch the dataset using the configured batch size (args.batch_size) rather
# than a second hard-coded 32, so this loader and the manual get_feed_dict
# batch always use the same size. shuffle=False keeps the rows in their stored
# order; the identity collate_fn hands each batch through unchanged.
dataloader = CustomDataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    sampler=None,
    collate_fn=lambda batch: batch,
    pin_memory=False,
)

In [11]:
# Show the shape of the training array (rows, columns).
train_data.shape

(452257, 3)

In [14]:
# Display the label tensor of the hand-built batch.
# NOTE(review): this cell's execution count (14) is out of sequence with the
# cells around it; restart the kernel and run all cells before relying on the
# recorded output.
labels

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 1, 1, 1, 1], device='cuda:0')

In [12]:
# Fetch the first batch from the loader to compare with the hand-built batch.
# NOTE(review): the recorded output lists 31 items/labels, whereas
# get_feed_dict returned 32 for the same configuration — possible off-by-one
# in CustomDataLoader/CustomDataset; confirm against data_loader.
next(iter(dataloader))

(tensor([   0, 1687, 1179, 1691,  670, 1696, 2082,  939,   91, 1883, 1760, 1889,
         1383,  745,  767, 1949, 1275,  334,  688,  151,   69, 1526,  735,  725,
         1925, 2011,  421,  768, 1554, 1175,  669], device='cuda:0'),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1], device='cuda:0'),
 [tensor([[1760, 1179, 1696, 1383,  670,  745,    0, 1383, 1691,    0, 1696,  939,
            767, 1696,   91,  939,    0, 1687, 1383, 1696, 1179, 1687, 1179,    0,
              0,   91, 1687, 1883, 1889,    0,  745,  939],
          [1760, 1179, 1696, 1383,  670,  745,    0, 1383, 1691,    0, 1696,  939,
            767, 1696,   91,  939,    0, 1687, 1383, 1696, 1179, 1687, 1179,    0,
              0,   91, 1687, 1883, 1889,    0,  745,  939],
          [1760, 1179, 1696, 1383,  670,  745,    0, 1383, 1691,    0, 1696,  939,
            767, 1696,   91,  939,    0, 1687, 1383, 1696, 1179, 1687, 1179,    0,
              0,   9