In [1]:
import torch
import os
import numpy as np
import tglite as tg

from tgat import TGAT
import support

In [2]:
# --- Experiment configuration (everything a reader might tune) -------------
DATA: str = 'wiki'  # 'wiki', 'reddit', 'mooc', 'mag', 'lastfm', 'gdelt', 'wiki-talk'
DATA_PATH: str = '/shared'  # root dir; the load cell reads <DATA_PATH>/data/<DATA>/edges.csv
EPOCHS: int = 10
BATCH_SIZE: int = 200
LEARN_RATE: float = 0.0001
DROPOUT: float = 0.1
N_LAYERS: int = 2
N_HEADS: int = 2
N_NBRS: int = 20  # first argument to tg.TSampler (presumably neighbors per node per hop)
DIM_TIME: int = 100
DIM_EMBED: int = 100
N_THREADS: int = 32  # num_threads passed to tg.TSampler
SAMPLING: str = 'recent'  # 'recent' or 'uniform'

# Optimization toggles. OPT_ALL is a master switch: while it is True, every
# individual OPT_* flag below ends up True regardless of its own setting.
# Set OPT_ALL = False to control the flags one by one.
OPT_DEDUP: bool = True
OPT_CACHE: bool = True
OPT_TIME: bool = True
OPT_ALL: bool = True
OPT_DEDUP = OPT_DEDUP or OPT_ALL
OPT_CACHE = OPT_CACHE or OPT_ALL
OPT_TIME = OPT_TIME or OPT_ALL
CACHE_LIMIT: int = int(2e6)  # passed to ctx.set_cache_limit
TIME_WINDOW: int = int(1e4)  # passed to ctx.set_time_window

MOVE: bool = True  # if True, the load cell also calls g.move_data(device)
GPU: int = 0       # passed to support.make_device
SEED: int = 1      # a negative value skips support.set_seed
PREFIX: str = ''   # passed to support.make_model_path

In [3]:
# Resolve the compute device from GPU and the checkpoint path for this run.
device = support.make_device(GPU)
model_path = support.make_model_path('tgat', PREFIX, DATA)

# Seeding is opt-out: any negative SEED skips the call entirely.
should_seed = SEED >= 0
if should_seed:
    support.set_seed(SEED)

In [4]:
### load data

# Edge list for the selected dataset. Join each path component separately
# rather than embedding '/' inside one argument, so the path is portable.
g = support.load_graph(os.path.join(DATA_PATH, 'data', DATA, 'edges.csv'))
# Attaches feature tensors to `g`; edge features may be absent (None).
support.load_feats(g, DATA, DATA_PATH)

# Edge features are optional: a dataset without them contributes width 0.
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
# This cell requires node features; fail with a clear message instead of an
# opaque AttributeError on `None.shape`.
if g.nfeat is None:
    raise ValueError(f'dataset {DATA!r} has no node features (g.nfeat is None)')
dim_nfeat = g.nfeat.shape[1]

# Compute on `device`; with MOVE the graph's data is moved there as well.
g.set_compute(device)
if MOVE:
    g.move_data(device)

# Context object handed to the model: sampling enabled, plus the optional
# embedding cache and time precompute controlled by the OPT_* flags.
ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_embed_caching(OPT_CACHE, DIM_EMBED)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_cache_limit(CACHE_LIMIT)
ctx.set_time_window(TIME_WINDOW)

num edges: 157474
num nodes: 9228
edge feat: torch.Size([157474, 172])
node feat: torch.Size([9228, 172])


In [5]:
### model

# Neighbor sampler handed to the model (count, strategy, and thread pool from config).
sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)

# TGAT sized from the loaded feature widths; nn.Module.to returns the module
# itself, so the placed model can be bound directly.
model = TGAT(
    ctx,
    dim_node=dim_nfeat,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME,
    dim_embed=DIM_EMBED,
    sampler=sampler,
    num_layers=N_LAYERS,
    num_heads=N_HEADS,
    dropout=DROPOUT,
    dedup=OPT_DEDUP,
).to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so it expects raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)

In [6]:
### training

# Split boundaries over the edge list: 70% train, 15% validation
# (presumably the remaining 15% is the test split — see support.data_split).
train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)


def neg_sampler(size):
    """Return `size` random node ids drawn uniformly from [0, g.num_nodes()).

    Used by the trainer as the negative-sample generator. It draws from NumPy's
    global RNG (np.random), so reproducibility depends on that RNG being seeded
    (presumably by support.set_seed — confirm).
    """
    return np.random.randint(0, g.num_nodes(), size)


trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, None,  # final argument left as None; meaning defined by LinkPredTrainer
)

trainer.train()
trainer.test()

epoch 0:
  loss:295.7572 val ap:0.9739 val auc:0.9782
  epoch | total:14.08s loop:12.34s eval:1.74s
   loop | forward:7.82s backward:4.46s sample:0.49s prep_batch:0.05s prep_input:1.03s post_update:0.00s
   comp | mem_update:0.00s time_zero:1.15s time_nbrs:0.93s self_attn:3.64s
epoch 1:
  loss:170.5700 val ap:0.9828 val auc:0.9853
  epoch | total:13.69s loop:11.75s eval:1.94s
   loop | forward:6.86s backward:4.83s sample:0.50s prep_batch:0.05s prep_input:0.93s post_update:0.00s
   comp | mem_update:0.00s time_zero:0.21s time_nbrs:1.18s self_attn:3.71s
epoch 2:
  loss:142.4620 val ap:0.9828 val auc:0.9855
  epoch | total:14.43s loop:12.56s eval:1.86s
   loop | forward:7.24s backward:5.26s sample:0.48s prep_batch:0.05s prep_input:0.82s post_update:0.00s
   comp | mem_update:0.00s time_zero:0.25s time_nbrs:1.30s self_attn:3.99s
epoch 3:
  loss:128.9173 val ap:0.9850 val auc:0.9872
  epoch | total:13.71s loop:12.58s eval:1.12s
   loop | forward:6.61s backward:5.91s sample:0.52s prep_batch: