In [1]:
# Re-import edited project modules (utils, models, train, ...) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project root (parent of this notebook's directory) importable.
# Drop any existing entry first so re-running this cell does not keep prepending
# duplicates, while still guaranteeing the project root has top priority.
PROJECT_ROOT = '..'
sys.path[:] = [path for path in sys.path if path != PROJECT_ROOT]
sys.path.insert(0, PROJECT_ROOT)

In [3]:
import random
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from tqdm.notebook import tqdm

from utils import get_tafeng_graph, load_weights
from knowledge_graph.datasets import KgPosNegTriples, TimeSplittedDataset, KgCustomers
from knowledge_graph.layer_generators import LayerNodeGenerator
from utils import get_dates_for_split, get_graph_splits, get_test_interactions
from models.Model import Model
from models.config import Config
from train import train_transR_one_epoch, train_lstm_one_epoch, evaluate

In [4]:
SEED = 42

# Seed every random-number source used in this notebook (Python, NumPy, PyTorch)
# from the single SEED constant.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)

<torch._C.Generator at 0x248ab620858>

In [5]:
# k-core thresholds forwarded to get_tafeng_graph (presumably minimum interaction counts
# per user / per item — confirm in utils).
USER_K_CORE = 2
ITEM_K_CORE = 1
knowledge_graph = get_tafeng_graph(user_k_core=USER_K_CORE, item_k_core=ITEM_K_CORE)

2021-05-11 18:36:55,723 - TaFengGraph - [INFO] - loading entities
2021-05-11 18:36:55,978 - TaFengGraph - [INFO] - loading relations
2021-05-11 18:38:37,854 - TaFengGraph - [INFO] - loaded purchase
2021-05-11 18:39:40,246 - TaFengGraph - [INFO] - loaded bought_in
2021-05-11 18:40:41,506 - TaFengGraph - [INFO] - loaded belongs_to_age_group
2021-05-11 18:41:42,857 - TaFengGraph - [INFO] - loaded belongs_to_subclass


In [6]:
# Number of cut-off dates to derive from the relation timestamps; the dates are
# handed to get_graph_splits further down.
N_SPLIT_POINTS = 5

timestamps = knowledge_graph.relation_set.get_all_timestamps()
splitting_points = get_dates_for_split(timestamps, n_points=N_SPLIT_POINTS)

In [7]:
# Inspect the cut-off dates that get_graph_splits receives in the next cell.
splitting_points

[datetime.datetime(2000, 11, 19, 3, 0),
 datetime.datetime(2000, 12, 8, 3, 0),
 datetime.datetime(2001, 1, 1, 3, 0),
 datetime.datetime(2001, 1, 21, 3, 0),
 datetime.datetime(2001, 2, 10, 3, 0)]

In [8]:
# Split the knowledge graph at the cut-off dates computed above.
splits = get_graph_splits(knowledge_graph, splitting_points)

# Release the full timestamp list: it was only needed to pick the cut-off dates.
# Rebinding (instead of `del timestamps, splitting_points`) keeps this cell re-runnable;
# splitting_points is only a handful of datetimes, so it is kept.
timestamps = None

2021-05-11 18:42:21,304 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 18:42:22,385 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 18:42:22,766 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 18:42:23,123 - TaFengGraph - [INFO] - converting purchase
2021-05-11 18:42:24,761 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 18:42:25,692 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 18:42:25,914 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 18:42:26,092 - TaFengGraph - [INFO] - converting purchase
2021-05-11 18:42:27,789 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 18:42:28,821 - TaFengGraph - [INFO] - converting belongs_to_age_group
2021-05-11 18:42:29,037 - TaFengGraph - [INFO] - converting belongs_to_subclass
2021-05-11 18:42:29,217 - TaFengGraph - [INFO] - converting purchase
2021-05-11 18:42:30,939 - TaFengGraph - [INFO] - converting bought_in
2021-05-11 18:42:31,595 - TaFe

In [9]:
# Hold out the most recent split for evaluation; every earlier split is used for training.
test_split = splits[-1]
train_splits = splits[:-1]

In [10]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
# One KgPosNegTriples dataset per training split, combined into a single TimeSplittedDataset
# that feeds the TransR training loop.
pos_neg_triples_ds = list(map(KgPosNegTriples, train_splits))
ts_ds = TimeSplittedDataset(pos_neg_triples_ds)

In [12]:
# Translate customer and product entities into their integer indices via entity2idx.
entity2idx = knowledge_graph.entity_set.entity2idx
customer_indices = [entity2idx[customer] for customer in knowledge_graph.entity_set.customer]
product_indices = [entity2idx[product] for product in knowledge_graph.entity_set.product]

In [13]:
# Customer dataset over the training splits, used by the LSTM training loop.
purchase_relation_idx = knowledge_graph.relation_set.relation2idx['purchase']
customer_ds = KgCustomers(
    splits=train_splits,
    customer_indices=customer_indices,
    product_indices=product_indices,
    purchase_relation_idx=purchase_relation_idx,
)

In [14]:
# How many customers customer_ds will iterate over during LSTM training.
len(customer_ds.customers_to_iterate_over)

2018

In [17]:
# Model hyper-parameters.
ENTITY_DIM = 40
RELATION_DIM = 40
N_KGAT_LAYERS = 2
N_NEIGHBOURS = 10  # passed as n_neighbours to every LayerNodeGenerator

config = Config(
    entity_embedding_dim=ENTITY_DIM,
    relation_embedding_dim=RELATION_DIM,
    n_entities=len(knowledge_graph.entity_set),
    n_relations=len(knowledge_graph.relation_set),
    n_layers=N_KGAT_LAYERS,
    transR_l2_weight=1e-5,
    concat_layers=True,
)

# One layer generator per training split.
layer_generators = [LayerNodeGenerator(split, n_neighbours=N_NEIGHBOURS) for split in train_splits]
model = Model(config, layer_generators=layer_generators, device=device).to(device)

# Separate optimizers, because the TransR aggregator and the LSTM are trained in separate loops.
agg_optimizer = optim.Adam(model.transR_aggregator.parameters())
lstm_optimizer = optim.Adam(model.lstm.parameters())

In [18]:
# Train the TransR aggregator, writing a checkpoint after every epoch so that a slow run
# can be interrupted without losing completed epochs.
N_TRANSR_EPOCHS = 10
CHECKPOINT_DIR = Path('../checkpoints')
# Suffix encoding entity dim, relation dim and neighbour count — keep in sync with the model config cell.
TRANSR_CHECKPOINT_TAG = '40_40_10nb'

# Create the directory up front: otherwise torch.save would only fail after a full epoch of training.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

for epoch in range(1, N_TRANSR_EPOCHS + 1):
    print(f'Epoch #{epoch}:')
    train_transR_one_epoch(model, ts_ds, agg_optimizer, batch_size=64, verbose=50, use_tqdm=True)
    # Label the file by the number of *completed* epochs, matching names like
    # 'transR_aggregator_2_epochs' used elsewhere in this notebook.
    torch.save({
        'transR_aggregator': model.transR_aggregator.state_dict(),
        'agg_optimizer': agg_optimizer.state_dict()
    }, CHECKPOINT_DIR / f'transR_aggregator_{epoch}_epochs_{TRANSR_CHECKPOINT_TAG}')

Epoch #1:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-11 18:57:31,221 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.23862336575984955, mean loss -> 0.4497533056139946
2021-05-11 19:09:02,037 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.19406430423259735, mean loss -> 0.3391272848844528
2021-05-11 19:21:08,044 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.1555260270833969, mean loss -> 0.2881516190369924
2021-05-11 19:33:08,699 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.205116868019104, mean loss -> 0.26186070501804354
2021-05-11 19:45:10,772 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.15304410457611084, mean loss -> 0.24451343035697937
2021-05-11 19:57:19,082 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.16409720480442047, mean loss -> 0.23085244578619799
2021-05-11 20:08:48,634 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.1751822978258133, mean loss -> 0.22017998591065407
2021-05-11 20:20:24,017 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.1601414829492569, mean loss -> 0.21156184747


Epoch #2:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-11 22:20:45,308 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.1389664262533188, mean loss -> 0.11977308943867683
2021-05-11 22:32:15,311 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.09873940795660019, mean loss -> 0.1200471780449152
2021-05-11 22:43:47,943 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.1084408089518547, mean loss -> 0.12148398816585541
2021-05-11 22:55:18,026 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.10475680977106094, mean loss -> 0.12020584288984537
2021-05-11 23:06:51,051 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.11463617533445358, mean loss -> 0.1200295937359333
2021-05-11 23:18:20,829 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.14164726436138153, mean loss -> 0.12014813462893169
2021-05-11 23:29:51,618 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.13513967394828796, mean loss -> 0.12007419426526342
2021-05-11 23:41:23,351 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.09856875985860825, mean loss -> 0.1197597


Epoch #3:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-12 01:41:43,108 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.10192755609750748, mean loss -> 0.10874966725707054
2021-05-12 01:53:12,940 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.10293164104223251, mean loss -> 0.10965817496180534
2021-05-12 02:04:43,100 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.10254883766174316, mean loss -> 0.10856145794192949
2021-05-12 02:16:13,673 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.12156584113836288, mean loss -> 0.10813035558909177
2021-05-12 02:27:46,670 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.11363563686609268, mean loss -> 0.1080582674741745
2021-05-12 02:39:17,426 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.0657184049487114, mean loss -> 0.10757759774724643
2021-05-12 02:50:48,568 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.11249148100614548, mean loss -> 0.10863693205373628
2021-05-12 03:02:18,105 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.08357327431440353, mean loss -> 0.10821


Epoch #4:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-12 05:02:30,795 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.13374115526676178, mean loss -> 0.10052919238805771
2021-05-12 05:14:01,496 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.09244241565465927, mean loss -> 0.10387442834675312
2021-05-12 05:25:31,075 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.0954880639910698, mean loss -> 0.10214481666684151
2021-05-12 05:37:00,069 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.07620261609554291, mean loss -> 0.10063076481223106
2021-05-12 05:48:32,469 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.10760956257581711, mean loss -> 0.10102837723493575
2021-05-12 06:00:04,043 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.0779295265674591, mean loss -> 0.10162569561352332
2021-05-12 06:11:34,481 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.09632392972707748, mean loss -> 0.10147434823215008
2021-05-12 06:23:02,156 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.09485574811697006, mean loss -> 0.10153


Epoch #5:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-12 08:23:16,137 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.06150885298848152, mean loss -> 0.10016116268932819
2021-05-12 08:34:52,844 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.11105694621801376, mean loss -> 0.09926087852567435
2021-05-12 08:46:29,163 - ..\train.py - [INFO] - Iter 150: batch loss -> 0.08770490437746048, mean loss -> 0.09843596145510673
2021-05-12 08:57:58,207 - ..\train.py - [INFO] - Iter 200: batch loss -> 0.09292431175708771, mean loss -> 0.09703299093991519
2021-05-12 09:09:32,149 - ..\train.py - [INFO] - Iter 250: batch loss -> 0.0992947444319725, mean loss -> 0.09698767648637295
2021-05-12 09:21:06,385 - ..\train.py - [INFO] - Iter 300: batch loss -> 0.06772222369909286, mean loss -> 0.09694320282588402
2021-05-12 09:32:37,693 - ..\train.py - [INFO] - Iter 350: batch loss -> 0.09673847258090973, mean loss -> 0.09687722657408034
2021-05-12 09:44:06,837 - ..\train.py - [INFO] - Iter 400: batch loss -> 0.10891266912221909, mean loss -> 0.0965


Epoch #6:


HBox(children=(FloatProgress(value=0.0, max=873.0), HTML(value='')))

2021-05-12 11:44:23,708 - ..\train.py - [INFO] - Iter 50: batch loss -> 0.06110602244734764, mean loss -> 0.09379192166030408
2021-05-12 11:55:55,148 - ..\train.py - [INFO] - Iter 100: batch loss -> 0.1005798727273941, mean loss -> 0.09453488323837518


KeyboardInterrupt: 

In [27]:
# NOTE(review): leftover one-off manual checkpoint from an earlier run, kept commented out.
# The TransR training loop already saves a checkpoint after every epoch; consider deleting this cell.
# torch.save({
#     'transR_aggregator': model.transR_aggregator.state_dict(),
#     'agg_optimizer': agg_optimizer.state_dict()
# }, '../checkpoints/transR_aggregator_2_epochs')

In [19]:
# Inspect the TransR aggregator's sub-modules.
model.transR_aggregator

TransrAggregator(
  (kgat): KGAT(
    (relation_embedder): Embedding(4, 40)
    (aggregator): RelationAttentiveAggregator()
    (activation): LeakyReLU(negative_slope=0.01, inplace=True)
    (node_layer_updating_matrices): ModuleList(
      (0): Linear(in_features=80, out_features=40, bias=True)
      (1): Linear(in_features=80, out_features=40, bias=True)
    )
  )
  (time_entity_embeddings): ModuleList(
    (0): Embedding(55868, 40)
    (1): Embedding(55868, 40)
    (2): Embedding(55868, 40)
    (3): Embedding(55868, 40)
    (4): Embedding(55868, 40)
  )
)

In [20]:
# Train the LSTM on the customer sequences.
N_LSTM_EPOCHS = 10
LSTM_BATCH_SIZE = 64   # 2018 customers / 64 -> the 32 iterations shown per epoch
LSTM_LOG_EVERY = 10    # iterations between logged loss lines
LSTM_USE_TQDM = True

for epoch_idx in range(N_LSTM_EPOCHS):
    print(f'Epoch #{epoch_idx + 1}:')
    train_lstm_one_epoch(model, customer_ds, lstm_optimizer, LSTM_BATCH_SIZE, LSTM_LOG_EVERY, LSTM_USE_TQDM)

Epoch #1:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:05:58,865 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.6504267454147339, mean loss -> 0.6723116338253021
2021-05-12 12:08:38,496 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.5211248397827148, mean loss -> 0.629296749830246
2021-05-12 12:11:18,147 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.4017091989517212, mean loss -> 0.5603638390700022



Epoch #2:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:14:23,561 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.45109236240386963, mean loss -> 0.35554543435573577
2021-05-12 12:17:03,622 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.32455986738204956, mean loss -> 0.3451751470565796
2021-05-12 12:19:42,407 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.283122718334198, mean loss -> 0.33316895067691804



Epoch #3:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:22:49,613 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.3607744872570038, mean loss -> 0.2960245326161385
2021-05-12 12:25:30,752 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.2563594579696655, mean loss -> 0.2914949171245098
2021-05-12 12:28:09,670 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.45373111963272095, mean loss -> 0.2953773121039073



Epoch #4:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:31:14,658 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.3153270483016968, mean loss -> 0.25811757445335387
2021-05-12 12:33:53,873 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.1972748339176178, mean loss -> 0.2551142893731594
2021-05-12 12:36:32,079 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.33990249037742615, mean loss -> 0.2608024929960569



Epoch #5:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:39:38,645 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.2647398114204407, mean loss -> 0.25233646482229233
2021-05-12 12:42:17,769 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.23195703327655792, mean loss -> 0.2545634150505066
2021-05-12 12:44:57,796 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.2962835729122162, mean loss -> 0.266706391175588



Epoch #6:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:48:04,755 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.29964479804039, mean loss -> 0.23459210395812988
2021-05-12 12:50:43,620 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.2221505343914032, mean loss -> 0.24852111041545868
2021-05-12 12:53:24,615 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.17506732046604156, mean loss -> 0.25122401416301726



Epoch #7:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 12:56:32,570 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.2305637151002884, mean loss -> 0.24491228759288788
2021-05-12 12:59:12,160 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.2657455801963806, mean loss -> 0.24114710763096808
2021-05-12 13:01:51,340 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.19023822247982025, mean loss -> 0.23994937390089036



Epoch #8:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 13:04:56,791 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.18884576857089996, mean loss -> 0.24196859449148178
2021-05-12 13:07:35,629 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.23561444878578186, mean loss -> 0.2281319208443165
2021-05-12 13:10:15,026 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.2811773121356964, mean loss -> 0.23388133943080902



Epoch #9:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

2021-05-12 13:13:20,133 - ..\train.py - [INFO] - Iter 10: batch loss -> 0.26523536443710327, mean loss -> 0.23100804090499877
2021-05-12 13:15:58,474 - ..\train.py - [INFO] - Iter 20: batch loss -> 0.1809316724538803, mean loss -> 0.24011314809322357
2021-05-12 13:18:46,875 - ..\train.py - [INFO] - Iter 30: batch loss -> 0.27311423420906067, mean loss -> 0.23443628698587418



Epoch #10:


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

KeyboardInterrupt: 

In [77]:
# Optionally evaluate previously saved checkpoints instead of the weights trained above.
# Off by default: running this unconditionally (e.g. on Restart & Run All) would silently
# replace the freshly trained weights with stale ones right before evaluation.
LOAD_SAVED_WEIGHTS = False
if LOAD_SAVED_WEIGHTS:
    load_weights(model, '../checkpoints/transR_aggregator_2_epochs', '../checkpoints/lstm_4_epochs')

In [21]:
# Inspect the full model: TransR aggregator followed by the LSTM.
model

Model(
  (transR_aggregator): TransrAggregator(
    (kgat): KGAT(
      (relation_embedder): Embedding(4, 40)
      (aggregator): RelationAttentiveAggregator()
      (activation): LeakyReLU(negative_slope=0.01, inplace=True)
      (node_layer_updating_matrices): ModuleList(
        (0): Linear(in_features=80, out_features=40, bias=True)
        (1): Linear(in_features=80, out_features=40, bias=True)
      )
    )
    (time_entity_embeddings): ModuleList(
      (0): Embedding(55868, 40)
      (1): Embedding(55868, 40)
      (2): Embedding(55868, 40)
      (3): Embedding(55868, 40)
      (4): Embedding(55868, 40)
    )
  )
  (lstm): LSTM(80, 80, batch_first=True)
)

In [22]:
# Test-split interactions for the customers in customer_ds, filtered by the 'purchase'
# relation index (semantics of the filter live in utils.get_test_interactions).
purchase_relation_idx = knowledge_graph.relation_set.relation2idx['purchase']
test_interactions = get_test_interactions(
    customer_ds.customers_to_iterate_over,
    test_split,
    purchase_relation_idx,
)

In [23]:
# Evaluate on the held-out split; the result is a 3-tuple of metrics.
# NOTE(review): the positional args are presumably (batch_size=32, top-k=20, use_tqdm=True),
# by analogy with the train_* calls — confirm against train.evaluate and pass them by keyword.
evaluate(model, test_interactions, product_indices, 32, 20, True)

HBox(children=(FloatProgress(value=0.0, max=56.0), HTML(value='')))




(0.0429541595925292, 0.08082151887958186, 0.572799475568759)