# Training a PyG Temporal Graph

This notebook is for training a PyTorch Geometric Temporal (PYG-TGN/TGAT)

> !NOTE
> MUST REFERENCE THIS PAPER: https://arxiv.org/abs/2006.10637

In [1]:
import sys
import os

# Go up 3 levels: rat → ibm_transactions_scripts → scripts → ROOT
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
print("Root directory:", ROOT_DIR)

# Add root directory to Python path
sys.path.append(ROOT_DIR)


Root directory: c:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection


In [2]:
import subprocess
print(subprocess.getoutput("nvidia-smi"))


Sun Nov 23 00:10:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.29                 Driver Version: 581.29         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   56C    P5             27W /  159W |    1076MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [3]:
import torch
print("Torch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("Is CUDA available:", torch.cuda.is_available())
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Torch version: 2.5.1+cu121
CUDA version: 12.1
Is CUDA available: True
Device: NVIDIA GeForce RTX 4080 Laptop GPU


In [4]:
import torch
from tgn.modules.memory import Memory
from tgn.model.tgn import TGN

edge_index = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_index.pt")
x = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\x.pt")
edge_attr = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_attr.pt")
timestamps = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\timestamps.pt")
y_edge = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\y_edge.pt")
y_node = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\y_node.pt")


  edge_index = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_index.pt")
  x = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\x.pt")
  edge_attr = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_attr.pt")
  timestamps = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\timestamps.pt")
  y_edge = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Det

In [5]:
print("edge_index:", edge_index.shape)
print("x:", x.shape)
print("edge_attr:", edge_attr.shape)
print("timestamps:", timestamps.shape)
print("y_edge:", y_edge.shape)
print("y_node:", y_node.shape)

print("labels distribution:", y_edge.bincount())

edge_index: torch.Size([2, 5082345])
x: torch.Size([515080, 27])
edge_attr: torch.Size([5082345, 3])
timestamps: torch.Size([5082345])
y_edge: torch.Size([5082345])
y_node: torch.Size([515080])
labels distribution: tensor([5073168,    9177])


## Create a time-sorted view + train/val/test split

TGN is temporal, so splitting must be chronological, not random.

In [6]:
import torch

# Sort edges by time (important for temporal models)
sorted_idx = torch.argsort(timestamps)

edge_index_sorted = edge_index[:, sorted_idx]
edge_attr_sorted  = edge_attr[sorted_idx]
timestamps_sorted = timestamps[sorted_idx]
y_edge_sorted     = y_edge[sorted_idx]

num_edges = edge_index_sorted.size(1)
print("Num edges (sorted):", num_edges)

# Temporal split: 70% train, 15% val, 15% test
train_ratio = 0.7
val_ratio   = 0.15

train_end = int(train_ratio * num_edges)
val_end   = int((train_ratio + val_ratio) * num_edges)

train_slice = slice(0, train_end)
val_slice   = slice(train_end, val_end)
test_slice  = slice(val_end, num_edges)

print("Train edges:", train_end)
print("Val edges:", val_end - train_end)
print("Test edges:", num_edges - val_end)


Num edges (sorted): 5082345
Train edges: 3557641
Val edges: 762352
Test edges: 762352


## Build a tiny helper “dataset” class for TGN

In [7]:
class EdgeSequence:
    """
    Thin wrapper around your sorted edge tensors.
    This matches the typical TGN-style interface:
      - sources
      - destinations
      - timestamps
      - edge_features
      - labels
    """
    def __init__(self, edge_index, edge_attr, timestamps, labels):
        # edge_index: [2, E]
        self.sources      = edge_index[0]      # [E]
        self.destinations = edge_index[1]      # [E]
        self.timestamps   = timestamps         # [E]
        self.edge_features = edge_attr         # [E, D_e]
        self.labels       = labels             # [E]

        self.num_edges = self.sources.size(0)

    def get_batch(self, start, end):
        """
        Return a slice batch [start:end] as
        (src, dst, t, edge_feat, labels)
        """
        s = slice(start, end)
        return (self.sources[s],
                self.destinations[s],
                self.timestamps[s],
                self.edge_features[s],
                self.labels[s])


In [8]:
train_data = EdgeSequence(
    edge_index_sorted[:, train_slice],
    edge_attr_sorted[train_slice],
    timestamps_sorted[train_slice],
    y_edge_sorted[train_slice]
)

val_data = EdgeSequence(
    edge_index_sorted[:, val_slice],
    edge_attr_sorted[val_slice],
    timestamps_sorted[val_slice],
    y_edge_sorted[val_slice]
)

test_data = EdgeSequence(
    edge_index_sorted[:, test_slice],
    edge_attr_sorted[test_slice],
    timestamps_sorted[test_slice],
    y_edge_sorted[test_slice]
)

print("Train edges:", train_data.num_edges)
print("Val edges:", val_data.num_edges)
print("Test edges:", test_data.num_edges)


Train edges: 3557641
Val edges: 762352
Test edges: 762352


## Plugging into TGN implementation

#### Creating a tiny MLP link predictor

In [10]:
import torch
import torch.nn as nn

class LinkPredictor(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, 1)
        )

    def forward(self, src_emb, dst_emb):
        x = torch.cat([src_emb, dst_emb], dim=-1)
        return torch.sigmoid(self.mlp(x))


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert node & edge features to numpy for TGN
node_features_np = x.cpu().numpy().astype(np.float32)
edge_features_np = edge_attr_sorted.cpu().numpy().astype(np.float32)

num_nodes = node_features_np.shape[0]
num_edges = edge_features_np.shape[0]

# ----------------------------------------------------
# 1. Build NeighborFinder (REQUIRED BY TGN)
# ----------------------------------------------------
from tgn.utils.neighbor_finder import NeighborFinder

src_np = edge_index_sorted[0].cpu().numpy()
dst_np = edge_index_sorted[1].cpu().numpy()
ts_np  = timestamps_sorted.cpu().numpy()
eid_np = np.arange(num_edges)

adj_list = [[] for _ in range(num_nodes)]
for s, d, t, eid in zip(src_np, dst_np, ts_np, eid_np):
    adj_list[s].append((d, eid, t))
    adj_list[d].append((s, eid, t))   # IBM is undirected temporal

neighbor_finder = NeighborFinder(adj_list, uniform=True)

# ----------------------------------------------------
# 2. Create TGN model properly
# ----------------------------------------------------
from tgn.model.tgn import TGN

tgn = TGN(
    neighbor_finder=neighbor_finder,
    node_features=node_features_np,
    edge_features=edge_features_np,
    device=device,
    n_layers=2,
    n_heads=2,
    dropout=0.1,
    use_memory=True,
    memory_update_at_start=True,

    # MUST MATCH feature dimensions
    message_dimension=edge_features_np.shape[1],     # e.g., ~27–50
    memory_dimension=node_features_np.shape[1],      # = 27

    embedding_module_type="graph_attention",
    message_function="mlp",
    n_neighbors=20,
).to(device)


# ----------------------------------------------------
# 3. Link predictor MLP
# ----------------------------------------------------
class LinkPredictor(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1)
        )

    def forward(self, src, dst):
        return self.mlp(torch.cat([src, dst], dim=-1)).squeeze(-1)

link_predictor = LinkPredictor(emb_dim=node_features_np.shape[1]).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(list(tgn.parameters()) + list(link_predictor.parameters()), lr=1e-3)


## Building training dataset with negative sampling

TGN requires training triplets:
(source_node, positive_destination_node, negative_destination_node)
at timestamp t

This means we need to construct:

- src_np — source nodes

- dst_np — destination nodes

- neg_dst_np — randomly sampled negative nodes

- ts_np — timestamps

- eid_np — edge indices

- y_np — labels (1 for real edges, 0 for negative edges)

#### Creating training triplets

In [12]:
import numpy as np

num_edges = edge_index_sorted.shape[1]
num_nodes = x.shape[0]

# Positive edges
src = edge_index_sorted[0].cpu().numpy()
dst = edge_index_sorted[1].cpu().numpy()
ts  = timestamps_sorted.cpu().numpy()
eid = np.arange(num_edges)

# Negative sampling: sample random destination nodes
neg_dst = np.random.randint(0, num_nodes, size=num_edges)

# Convert to tensors for TGN
src_t = torch.from_numpy(src).long().to(device)
dst_t = torch.from_numpy(dst).long().to(device)
neg_t = torch.from_numpy(neg_dst).long().to(device)
ts_t  = torch.from_numpy(ts).float().to(device)
eid_t = torch.from_numpy(eid).long().to(device)

### Build a DataLoader for batching
TGN must process interactions in temporal order, so a simple sequential batch loader works.

In [13]:
def get_batch_indices(batch_size, total):
    return [(i, min(i+batch_size, total)) for i in range(0, total, batch_size)]

batches = get_batch_indices(20000, num_edges)
len(batches)

255

## Training Loop

In [18]:
import math
import time
import torch.nn as nn

# ============================================================
# CONFIG: much smaller subset + smaller batches for debugging
# ============================================================
MAX_EDGES_DEBUG = 20_000    # <-- try 20k first, then you can scale up
BATCH_SIZE      = 1_000     # <-- smaller batch so each step is visible
NUM_EPOCHS      = 1         # <-- start with 1 epoch just to verify it runs
LR              = 1e-4

# Slice tensors to a manageable size
src_debug = src_t[:MAX_EDGES_DEBUG]
dst_debug = dst_t[:MAX_EDGES_DEBUG]
neg_debug = neg_t[:MAX_EDGES_DEBUG]
ts_debug  = ts_t[:MAX_EDGES_DEBUG]
eid_debug = eid_t[:MAX_EDGES_DEBUG]

num_edges_debug = src_debug.size(0)
print(f"Using {num_edges_debug:,} edges for debug training")

# Build batches over the debug subset
batches = [
    (start, min(start + BATCH_SIZE, num_edges_debug))
    for start in range(0, num_edges_debug, BATCH_SIZE)
]
print(f"Total batches per epoch: {len(batches)}")

# ============================================================
# LOSS + OPTIMIZER
# ============================================================
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=LR)

tgn.train()
losses = []

for epoch in range(NUM_EPOCHS):
    print(f"\n=== Epoch {epoch + 1}/{NUM_EPOCHS} ===")

    # Reset memory at epoch start
    if tgn.use_memory:
        tgn.memory.__init_memory__()

    epoch_loss = 0.0
    epoch_start = time.perf_counter()

    for batch_idx, (start, end) in enumerate(batches):
        batch_start = time.perf_counter()

        # Select batch window from the *debug* tensors
        src_b = src_debug[start:end]
        dst_b = dst_debug[start:end]
        neg_b = neg_debug[start:end]
        ts_b  = ts_debug[start:end]
        eid_b = eid_debug[start:end]

        # ----------------------------------------------------
        # Forward pass through TGN (TGN expects numpy arrays)
        # ----------------------------------------------------
        pos_scores, neg_scores = tgn.compute_edge_probabilities(
            src_b.cpu().numpy(),
            dst_b.cpu().numpy(),
            neg_b.cpu().numpy(),
            ts_b.cpu().numpy(),
            eid_b.cpu().numpy(),
        )

        # Move back to device for loss computation
        pos_scores = pos_scores.to(device)
        neg_scores = neg_scores.to(device)

        # Labels
        pos_y = torch.ones_like(pos_scores)
        neg_y = torch.zeros_like(neg_scores)

        # Combine
        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([pos_y, neg_y], dim=0)

        # Loss + backward
        loss = criterion(scores, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        batch_time = time.perf_counter() - batch_start
        print(
            f"  Batch {batch_idx + 1}/{len(batches)} "
            f"[edges {start}:{end}] "
            f"loss={loss.item():.4f}  "
            f"time={batch_time:.3f}s"
        )

    epoch_time = time.perf_counter() - epoch_start
    print(f"Epoch {epoch + 1} total loss: {epoch_loss:.4f}")
    print(f"Epoch {epoch + 1} time: {epoch_time:.2f}s")

    losses.append(epoch_loss)


Using 20,000 edges for debug training
Total batches per epoch: 20

=== Epoch 1/1 ===


KeyboardInterrupt: 

In [17]:
print("node_raw_features:", tgn.node_raw_features.device)
print("edge_raw_features:", tgn.edge_raw_features.device)
print("time_encoder:", next(tgn.time_encoder.parameters()).device)

print("memory.memory:", tgn.memory.memory.device)
print("memory.last_update:", tgn.memory.last_update.device)

print("message_aggregator:", tgn.message_aggregator.device)
print("message_function device (if any):", 
      tgn.message_function.mlp[0].weight.device if hasattr(tgn.message_function, 'mlp') else "N/A")

print("embedding_module type:", type(tgn.embedding_module))


node_raw_features: cuda:0
edge_raw_features: cuda:0
time_encoder: cuda:0
memory.memory: cuda:0
memory.last_update: cuda:0
message_aggregator: cuda
message_function device (if any): cuda:0
embedding_module type: <class 'tgn.modules.embedding_module.GraphAttentionEmbedding'>
