In [2]:
# Load IPython's autoreload extension and switch it to mode 2, so edited project
# modules are re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [3]:
import sys  
# Put the parent directory first on the module search path. The path is relative,
# so it resolves against the kernel's working directory.
# NOTE(review): presumably this is what lets the project imports (utils, models,
# train, knowledge_graph) resolve - the training logs are tagged "..\train.py".
sys.path.insert(0, '..')

In [75]:
from utils import get_tafeng_graph, load_weights
from knowledge_graph.datasets import KgPosNegTriples, TimeSplittedDataset, KgCustomers
from knowledge_graph.layer_generators import LayerNodeGenerator
from utils import get_dates_for_split, get_graph_splits, get_test_interactions
from models.Model import Model
from models.config import Config
import numpy as np
import random
from tqdm.notebook import tqdm
import sys
import torch
import torch.optim as optim
from train import train_transR_one_epoch, train_lstm_one_epoch, evaluate
from datetime import datetime

In [5]:
# Use one seed for every random number generator the notebook relies on.
SEED = 42

# Python's stdlib RNG first, then NumPy's legacy global RNG.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(SEED)

# Kept as the final expression: the returned Generator is what the cell displays.
torch.manual_seed(SEED)

<torch._C.Generator at 0x16a6d66e228>

In [6]:
# Build the Ta-Feng knowledge graph through the project helper. In the recorded run
# this took roughly five minutes (see the timestamps in the log output below).
# NOTE(review): user_k_core / item_k_core look like k-core filtering thresholds for
# users and items - confirm against utils.get_tafeng_graph.
knowledge_graph = get_tafeng_graph(user_k_core=2, item_k_core=1)

2021-05-11 12:07:43,373 - TaFengGraph - [INFO] - loading entities
2021-05-11 12:07:43,631 - TaFengGraph - [INFO] - loading relations
2021-05-11 12:09:24,436 - TaFengGraph - [INFO] - loaded purchase
2021-05-11 12:10:27,438 - TaFengGraph - [INFO] - loaded bought_in
2021-05-11 12:11:35,236 - TaFengGraph - [INFO] - loaded belongs_to_age_group
2021-05-11 12:12:40,377 - TaFengGraph - [INFO] - loaded belongs_to_subclass


In [7]:
# Number of cut dates requested from the helper.
# NOTE(review): the recorded logs suggest 3 points produce 4 graph splits - confirm.
N_SPLIT_POINTS = 3

# Collect every timestamp stored on the graph's relations, then let the project
# helper choose the dates at which the timeline is cut.
timestamps = knowledge_graph.relation_set.get_all_timestamps()
splitting_points = get_dates_for_split(timestamps, n_points=N_SPLIT_POINTS)

In [8]:
# Cut the knowledge graph at the chosen dates using the project helper.
splits = get_graph_splits(knowledge_graph, splitting_points)
# Drop the intermediates once the splits exist. Because both names are deleted here,
# re-running this cell on its own raises NameError until the previous cell is re-run.
del timestamps, splitting_points

2021-05-11 12:12:41,988 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 12:12:42,686 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 12:12:43,072 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 12:12:44,386 - TaFengGraph - [INFO] - converting purchase
2021-05-11 12:12:47,163 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 12:12:47,771 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 12:12:48,033 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 12:12:49,139 - TaFengGraph - [INFO] - converting purchase
2021-05-11 12:12:51,486 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 12:12:51,731 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 12:12:51,926 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 12:12:53,027 - TaFengGraph - [INFO] - converting purchase
2021-05-11 12:12:55,556 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 12:12:55

In [9]:
# Every split except the last is used for training; the last one is held out for testing.
train_splits, test_split = splits[:-1], splits[-1]

In [22]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
# One positive/negative triple dataset per training split, in split order,
# then wrapped together in a single time-split dataset.
pos_neg_triples_ds = list(map(KgPosNegTriples, train_splits))
ts_ds = TimeSplittedDataset(pos_neg_triples_ds)

In [70]:
# Translate customer and product entities into their integer indices
# through the entity set's entity2idx lookup.
entity_set = knowledge_graph.entity_set
customer_indices = [entity_set.entity2idx[customer] for customer in entity_set.customer]
product_indices = [entity_set.entity2idx[product] for product in entity_set.product]

In [13]:
# Index of the 'purchase' relation, looked up once before building the dataset.
purchase_relation_idx = knowledge_graph.relation_set.relation2idx['purchase']

# Customer-level dataset built over the training splits only.
customer_ds = KgCustomers(
    splits=train_splits,
    customer_indices=customer_indices,
    product_indices=product_indices,
    purchase_relation_idx=purchase_relation_idx,
)

In [14]:
# Sanity check: how many customers the dataset will iterate over.
len(customer_ds.customers_to_iterate_over)

5791

In [76]:
# Model hyper-parameters; the entity and relation counts are read from the graph.
config = Config(
    entity_embedding_dim=20,
    relation_embedding_dim=20,
    n_entities=len(knowledge_graph.entity_set),
    n_relations=len(knowledge_graph.relation_set),
    n_layers=2,
    transR_l2_weight=1e-5,
    concat_layers=True,
)

# One neighbourhood generator per training split, in split order.
layer_generators = [LayerNodeGenerator(split, n_neighbours=8) for split in train_splits]

# Build the model, then move it onto the selected device.
model = Model(config, layer_generators=layer_generators, device=device)
model = model.to(device)

# Separate Adam optimizers (default settings) for the two trainable sub-modules.
agg_optimizer = optim.Adam(model.transR_aggregator.parameters())
lstm_optimizer = optim.Adam(model.lstm.parameters())

In [26]:
# Number of passes over the triple dataset for the TransR aggregator.
TRANSR_EPOCHS = 2

for i in range(TRANSR_EPOCHS):
    print(f'Epoch #{i+1}:')
    # One epoch of TransR training; per the recorded logs this reports every 50 iterations.
    train_transR_one_epoch(
        model,
        ts_ds,
        agg_optimizer,
        batch_size=64,
        verbose=50,
        use_tqdm=True,
    )

Epoch #1:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-11 12:24:58,635 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.4329060912132263, mean loss -> 0.5610813474655152
2021-05-11 12:31:27,363 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.2911437153816223, mean loss -> 0.4531164687871933
2021-05-11 12:37:54,022 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.28047630190849304, mean loss -> 0.3952358799179395
2021-05-11 12:44:31,008 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.262432336807251, mean loss -> 0.3598401653766632
2021-05-11 12:51:00,421 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.20597711205482483, mean loss -> 0.33574014925956724
2021-05-11 12:57:27,097 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.20776987075805664, mean loss -> 0.3175469387571017
2021-05-11 13:03:54,647 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.18854251503944397, mean loss -> 0.3035180486525808
2021-05-11 13:10:29,933 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.21498727798461914, mean loss -> 0.2929042910039


Epoch #2:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-11 14:17:34,245 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.18949273228645325, mean loss -> 0.1776418435573578
2021-05-11 14:23:58,344 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.16455230116844177, mean loss -> 0.17953832425177096
2021-05-11 14:30:12,259 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.16066722571849823, mean loss -> 0.18058854843179384
2021-05-11 14:36:23,236 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.20244310796260834, mean loss -> 0.17881967730820178
2021-05-11 14:42:42,736 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.10772552341222763, mean loss -> 0.17800150194764136
2021-05-11 14:49:00,000 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.21922358870506287, mean loss -> 0.1769365935275952
2021-05-11 14:55:15,106 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.14353381097316742, mean loss -> 0.17620552605816295
2021-05-11 15:01:21,196 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.16486482322216034, mean loss -> 0.17538




In [27]:
# Persist the trained TransR aggregator together with its optimizer state.
# A later cell calls load_weights(...) with this exact path, so on a fresh checkout
# the file has to be produced by this notebook. The write is skipped when the
# checkpoint already exists, so an existing file is never overwritten - delete it
# by hand to store freshly trained weights.
import os

TRANSR_CHECKPOINT_PATH = '../checkpoints/transR_aggregator_2_epochs'

if not os.path.exists(TRANSR_CHECKPOINT_PATH):
    # Build the payload first so a missing model fails before the filesystem is touched.
    transR_checkpoint = {
        'transR_aggregator': model.transR_aggregator.state_dict(),
        'agg_optimizer': agg_optimizer.state_dict()
    }
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(TRANSR_CHECKPOINT_PATH), exist_ok=True)
    torch.save(transR_checkpoint, TRANSR_CHECKPOINT_PATH)

In [63]:
# Display the module repr to inspect the TransR aggregator's sub-modules.
model.transR_aggregator

TransrAggregator(
  (kgat): KGAT(
    (relation_embedder): Embedding(4, 20)
    (aggregator): RelationAttentiveAggregator()
    (activation): LeakyReLU(negative_slope=0.01, inplace=True)
    (node_layer_updating_matrices): ModuleList(
      (0): Linear(in_features=40, out_features=20, bias=True)
      (1): Linear(in_features=40, out_features=20, bias=True)
    )
  )
  (time_entity_embeddings): ModuleList(
    (0): Embedding(55868, 20)
    (1): Embedding(55868, 20)
    (2): Embedding(55868, 20)
  )
)

In [64]:
# Number of passes over the customer dataset for the LSTM.
LSTM_EPOCHS = 4

# The three trailing arguments are passed positionally, exactly as before.
# NOTE(review): the recorded output (181 batches, a log line every 10 iterations,
# a progress bar) suggests they are batch size, log interval and the tqdm switch -
# confirm against train.train_lstm_one_epoch.
LSTM_BATCH_SIZE = 32
LSTM_LOG_EVERY = 10
LSTM_USE_TQDM = True

for i in range(LSTM_EPOCHS):
    print(f'Epoch #{i+1}:')
    train_lstm_one_epoch(
        model,
        customer_ds,
        lstm_optimizer,
        LSTM_BATCH_SIZE,
        LSTM_LOG_EVERY,
        LSTM_USE_TQDM,
    )

Epoch #1:


HBox(children=(FloatProgress(value=0.0, max=181.0), HTML(value='')))

2021-05-11 16:24:13,219 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.6823303699493408, mean loss -> 0.6867675542831421
2021-05-11 16:24:47,699 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.669663667678833, mean loss -> 0.6807757467031479
2021-05-11 16:25:21,519 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.6311533451080322, mean loss -> 0.6718840380509694
2021-05-11 16:25:55,238 - ..\train.py - [INFO] - Iter 40: batch loss -> 0.6097091436386108, mean loss -> 0.6589302226901055
2021-05-11 16:26:31,028 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.5187261700630188, mean loss -> 0.6383015930652618
2021-05-11 16:27:05,768 - ..\train.py - [INFO] - Iter 60: batch loss -> 0.5159697532653809, mean loss -> 0.6143056020140648
2021-05-11 16:27:39,619 - ..\train.py - [INFO] - Iter 70: batch loss -> 0.46740835905075073, mean loss -> 0.5968025450195585
2021-05-11 16:28:14,452 - ..\train.py - [INFO] - Iter 80: batch loss -> 0.5448513627052307, mean loss -> 0.5823419105261565
2021-05-


Epoch #2:


HBox(children=(FloatProgress(value=0.0, max=181.0), HTML(value='')))

2021-05-11 16:34:32,812 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.4137633740901947, mean loss -> 0.42937292754650114
2021-05-11 16:35:07,039 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.6064690351486206, mean loss -> 0.42957859486341476
2021-05-11 16:35:40,732 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.3989071846008301, mean loss -> 0.427229384581248
2021-05-11 16:36:14,989 - ..\train.py - [INFO] - Iter 40: batch loss -> 0.3335401713848114, mean loss -> 0.42970529571175575
2021-05-11 16:36:49,124 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.4611482620239258, mean loss -> 0.42758774638175967
2021-05-11 16:37:23,317 - ..\train.py - [INFO] - Iter 60: batch loss -> 0.3683856129646301, mean loss -> 0.4237172454595566
2021-05-11 16:37:57,657 - ..\train.py - [INFO] - Iter 70: batch loss -> 0.398837149143219, mean loss -> 0.4195801160165242
2021-05-11 16:38:33,034 - ..\train.py - [INFO] - Iter 80: batch loss -> 0.32841968536376953, mean loss -> 0.4146486181765795
2021-


Epoch #3:


HBox(children=(FloatProgress(value=0.0, max=181.0), HTML(value='')))

2021-05-11 16:44:54,753 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.6502814292907715, mean loss -> 0.4385721027851105
2021-05-11 16:45:28,853 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.4671427011489868, mean loss -> 0.41407445520162584
2021-05-11 16:46:02,908 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.35569655895233154, mean loss -> 0.4115783860286077
2021-05-11 16:46:36,910 - ..\train.py - [INFO] - Iter 40: batch loss -> 0.33238497376441956, mean loss -> 0.4014429472386837
2021-05-11 16:47:11,208 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.48731017112731934, mean loss -> 0.39144979149103165
2021-05-11 16:47:45,365 - ..\train.py - [INFO] - Iter 60: batch loss -> 0.3344988226890564, mean loss -> 0.3899109703799089
2021-05-11 16:48:19,491 - ..\train.py - [INFO] - Iter 70: batch loss -> 0.46345099806785583, mean loss -> 0.39322544102157864
2021-05-11 16:48:54,100 - ..\train.py - [INFO] - Iter 80: batch loss -> 0.3982779383659363, mean loss -> 0.39110668320208786



Epoch #4:


HBox(children=(FloatProgress(value=0.0, max=181.0), HTML(value='')))

2021-05-11 16:55:15,845 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.41007253527641296, mean loss -> 0.40030136704444885
2021-05-11 16:55:50,759 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.34083259105682373, mean loss -> 0.3834927171468735
2021-05-11 16:56:25,046 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.3755955100059509, mean loss -> 0.38216613233089447
2021-05-11 16:56:58,943 - ..\train.py - [INFO] - Iter 40: batch loss -> 0.3641565442085266, mean loss -> 0.3727788753807545
2021-05-11 16:57:33,465 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.267130970954895, mean loss -> 0.36852777898311617
2021-05-11 16:58:08,048 - ..\train.py - [INFO] - Iter 60: batch loss -> 0.4541900157928467, mean loss -> 0.37372165645162264
2021-05-11 16:58:42,465 - ..\train.py - [INFO] - Iter 70: batch loss -> 0.28186336159706116, mean loss -> 0.3831715709396771
2021-05-11 16:59:16,671 - ..\train.py - [INFO] - Iter 80: batch loss -> 0.4205564260482788, mean loss -> 0.3888462821021676
20




In [65]:
# Persist the trained LSTM together with its optimizer state.
# A later cell calls load_weights(...) with this exact path, so on a fresh checkout
# the file has to be produced by this notebook. The write is skipped when the
# checkpoint already exists, so an existing file is never overwritten - delete it
# by hand to store freshly trained weights.
import os

LSTM_CHECKPOINT_PATH = '../checkpoints/lstm_4_epochs'

if not os.path.exists(LSTM_CHECKPOINT_PATH):
    # Build the payload first so a missing model fails before the filesystem is touched.
    lstm_checkpoint = {
        'lstm': model.lstm.state_dict(),
        'lstm_optimizer': lstm_optimizer.state_dict()
    }
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(LSTM_CHECKPOINT_PATH), exist_ok=True)
    torch.save(lstm_checkpoint, LSTM_CHECKPOINT_PATH)

In [77]:
# Restore weights from the two checkpoint files through the project helper.
# NOTE(review): the execution counts show this ran right after the model was
# re-created (In [76] -> In [77]), so the evaluation below used a freshly built model
# with loaded weights. Both files must exist before this cell runs; presumably
# load_weights fails otherwise - confirm in utils.
load_weights(model, '../checkpoints/transR_aggregator_2_epochs', '../checkpoints/lstm_4_epochs')

In [78]:
# Display the full model repr to inspect its sub-modules and layer sizes.
model

Model(
  (transR_aggregator): TransrAggregator(
    (kgat): KGAT(
      (relation_embedder): Embedding(4, 20)
      (aggregator): RelationAttentiveAggregator()
      (activation): LeakyReLU(negative_slope=0.01, inplace=True)
      (node_layer_updating_matrices): ModuleList(
        (0): Linear(in_features=40, out_features=20, bias=True)
        (1): Linear(in_features=40, out_features=20, bias=True)
      )
    )
    (time_entity_embeddings): ModuleList(
      (0): Embedding(55868, 20)
      (1): Embedding(55868, 20)
      (2): Embedding(55868, 20)
    )
  )
  (lstm): LSTM(40, 40, batch_first=True)
)

In [102]:
# Collect the held-out interactions of the training customers from the test split,
# restricted to the 'purchase' relation.
purchase_idx = knowledge_graph.relation_set.relation2idx['purchase']
customers_under_test = customer_ds.customers_to_iterate_over
test_interactions = get_test_interactions(customers_under_test, test_split, purchase_idx)

In [103]:
# Run the evaluation and keep the returned metric tuple so it can be reused later.
# NOTE(review): the meaning of the positional 10, 20, True is not visible in this
# notebook - confirm against train.evaluate.
evaluation_metrics = evaluate(model, test_interactions, product_indices, 10, 20, True)
evaluation_metrics

HBox(children=(FloatProgress(value=0.0, max=467.0), HTML(value='')))




(0.012119588512644516, 0.015093678557731489, 0.07501472134104832)