# Training a GRAPH-SAGE-T

This notebook is for training a GRAPHSAGE-T GNN

In [1]:
import sys
import os

# Go up 3 levels: rat → ibm_transactions_scripts → scripts → ROOT
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
print("Root directory:", ROOT_DIR)

# Add root directory to Python path
sys.path.append(ROOT_DIR)


Root directory: c:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection


In [2]:
import subprocess
print(subprocess.getoutput("nvidia-smi"))


Mon Nov 24 01:03:09 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.29                 Driver Version: 581.29         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   54C    P0             29W /  159W |    1293MiB /  12282MiB |     22%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [3]:
import torch
print("Torch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("Is CUDA available:", torch.cuda.is_available())
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Torch version: 2.5.1+cu121
CUDA version: 12.1
Is CUDA available: True
Device: NVIDIA GeForce RTX 4080 Laptop GPU


In [4]:
import torch
from tgn.modules.memory import Memory
from tgn.model.tgn import TGN

edge_index = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_index.pt")
x = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\x.pt")
edge_attr = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_attr.pt")
timestamps = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\timestamps.pt")
y_edge = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\y_edge.pt")
y_node = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\y_node.pt")


  edge_index = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_index.pt")
  x = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\x.pt")
  edge_attr = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\edge_attr.pt")
  timestamps = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Detection\ibm_transcations_datasets\RAT\pyg_graph_hismall\timestamps.pt")
  y_edge = torch.load(r"C:\Users\yasmi\OneDrive\Desktop\Uni - Master's\Fall 2025\MLR 570\Motif-Aware-Temporal-GNNs-for-Anti-Money-Laundering-Det

In [5]:
print("edge_index:", edge_index.shape)
print("x:", x.shape)
print("edge_attr:", edge_attr.shape)
print("timestamps:", timestamps.shape)
print("y_edge:", y_edge.shape)
print("y_node:", y_node.shape)

print("labels distribution:", y_edge.bincount())

edge_index: torch.Size([2, 5082345])
x: torch.Size([515080, 27])
edge_attr: torch.Size([5082345, 3])
timestamps: torch.Size([5082345])
y_edge: torch.Size([5082345])
y_node: torch.Size([515080])
labels distribution: tensor([5073168,    9177])


## Subsample to 1-2M edges + stratified split

In [6]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --------------------------------------------------------
# 1. Choose how many edges you want to debug on
#    (set to 1_000_000 or 2_000_000 as you like)
# --------------------------------------------------------
MAX_EDGES = 1_000_000   # change to 2_000_000 if you want

E_total = edge_index.size(1)
print("Total edges in full graph:", E_total)

num_keep = min(MAX_EDGES, E_total)
print(f"Using {num_keep:,} edges for debug subset")

# --------------------------------------------------------
# 2. Stratified subsampling by label (keep fraction of edges)
# --------------------------------------------------------
y_edge = y_edge.long()
pos_mask = (y_edge == 1)
neg_mask = (y_edge == 0)

pos_idx_all = pos_mask.nonzero(as_tuple=True)[0]
neg_idx_all = neg_mask.nonzero(as_tuple=True)[0]

num_pos = pos_idx_all.numel()
num_neg = neg_idx_all.numel()

print(f"Total positives: {num_pos:,}, negatives: {num_neg:,}")

# keep same positive ratio in the subset
pos_ratio = num_pos / float(E_total)
num_pos_keep = int(pos_ratio * num_keep)
num_neg_keep = num_keep - num_pos_keep

pos_perm = pos_idx_all[torch.randperm(num_pos)]
neg_perm = neg_idx_all[torch.randperm(num_neg)]

pos_keep = pos_perm[:min(num_pos_keep, num_pos)]
neg_keep = neg_perm[:min(num_neg_keep, num_neg)]

subset_idx = torch.cat([pos_keep, neg_keep])
subset_idx = subset_idx[torch.randperm(subset_idx.size(0))]

print(f"Subset edges: {subset_idx.numel():,}")
print("  Positives in subset:", (y_edge[subset_idx] == 1).sum().item())
print("  Negatives in subset:", (y_edge[subset_idx] == 0).sum().item())

# --------------------------------------------------------
# 3. Build subset tensors
# --------------------------------------------------------
edge_index_sub = edge_index[:, subset_idx]
timestamps_sub = timestamps[subset_idx]
y_edge_sub     = y_edge[subset_idx].float()

print("edge_index_sub shape:", edge_index_sub.shape)
print("timestamps_sub shape:", timestamps_sub.shape)
print("y_edge_sub shape:", y_edge_sub.shape)

# --------------------------------------------------------
# 4. Normalize node features (important)
# --------------------------------------------------------
x_d = x.to(device).float()
edge_index_d = edge_index_sub.to(device).long()
ts_d = timestamps_sub.to(device).float()
y_d = y_edge_sub.to(device).float()

with torch.no_grad():
    mean = x_d.mean(dim=0, keepdim=True)
    std  = x_d.std(dim=0, keepdim=True) + 1e-6
    x_d  = (x_d - mean) / std

# --------------------------------------------------------
# 5. Stratified train/val/test split on the subset
# --------------------------------------------------------
num_edges = y_d.shape[0]
print("Subset edges:", num_edges)

pos_idx = (y_d == 1).nonzero(as_tuple=True)[0]
neg_idx = (y_d == 0).nonzero(as_tuple=True)[0]

print("Positives in subset:", pos_idx.numel())
print("Negatives in subset:", neg_idx.numel())

# Split each class separately: 70% / 10% / 20%
def stratified_split(idx_tensor):
    n = idx_tensor.numel()
    perm = idx_tensor[torch.randperm(n)]
    n_train = int(0.7 * n)
    n_val   = int(0.1 * n)
    train = perm[:n_train]
    val   = perm[n_train : n_train + n_val]
    test  = perm[n_train + n_val :]
    return train, val, test

train_pos, val_pos, test_pos = stratified_split(pos_idx)
train_neg, val_neg, test_neg = stratified_split(neg_idx)

train_idx = torch.cat([train_pos, train_neg])
val_idx   = torch.cat([val_pos,   val_neg])
test_idx  = torch.cat([test_pos,  test_neg])

# shuffle within each split
train_idx = train_idx[torch.randperm(train_idx.size(0))]
val_idx   = val_idx[torch.randperm(val_idx.size(0))]
test_idx  = test_idx[torch.randperm(test_idx.size(0))]

print(f"Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
print("  Train positives:", (y_d[train_idx] == 1).sum().item())
print("  Val positives:  ", (y_d[val_idx]   == 1).sum().item())
print("  Test positives: ", (y_d[test_idx]  == 1).sum().item())


Device: cuda
Total edges in full graph: 5082345
Using 1,000,000 edges for debug subset
Total positives: 9,177, negatives: 5,073,168
Subset edges: 1,000,000
  Positives in subset: 1805
  Negatives in subset: 998195
edge_index_sub shape: torch.Size([2, 1000000])
timestamps_sub shape: torch.Size([1000000])
y_edge_sub shape: torch.Size([1000000])
Subset edges: 1000000
Positives in subset: 1805
Negatives in subset: 998195
Train=699999, Val=99999, Test=200002
  Train positives: 1263
  Val positives:   180
  Test positives:  362


## GraphSAGE-T Implementation

In [7]:
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv

# ------------------------------------------------------------
# GraphSAGE-T  (GraphSAGE + sinusoidal time encoding)
# ------------------------------------------------------------
class TimeEncode(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Frequencies (log-spaced)
        self.w = nn.Parameter(torch.exp(
            torch.linspace(0, 3, dim)
        ), requires_grad=True)

    def forward(self, t):
        # t: shape [E] or [N]
        t = t.unsqueeze(-1)                     # [E, 1]
        out = torch.cat([torch.sin(t * self.w),
                         torch.cos(t * self.w)], dim=-1)
        return out.float()                      # [E, 2*dim]


class GraphSageT(nn.Module):
    def __init__(self, node_feat_dim, time_dim, hidden_dim=64, num_layers=2):
        super().__init__()
        self.time_dim = time_dim
        self.time_encoder = TimeEncode(time_dim)

        # first layer takes node features + time features
        self.convs = nn.ModuleList()
        self.convs.append(
            SAGEConv(node_feat_dim + 2*time_dim, hidden_dim)
        )

        for _ in range(num_layers - 1):
            self.convs.append(
                SAGEConv(hidden_dim, hidden_dim)
            )

    def forward(self, x, edge_index, timestamps):
        """
        x:           [N, F]
        edge_index:  [2, E]
        timestamps:  [E] on same device
        """

        # Compute time encodings per edge
        t_enc = self.time_encoder(timestamps)        # [E, 2*time_dim]

        # Scatter time encodings to nodes by averaging incoming edges
        N = x.size(0)
        T = torch.zeros((N, t_enc.size(1)),
                        device=x.device)
        dst = edge_index[1]

        # accumulate encodings per node
        T.index_add_(0, dst, t_enc)

        # normalize by degree
        deg = torch.bincount(dst, minlength=N).clamp(min=1).unsqueeze(-1)
        T = T / deg

        # concatenate: x || time
        h = torch.cat([x, T], dim=1)

        # run GraphSAGE layers
        for conv in self.convs:
            h = conv(h, edge_index)
            h = torch.relu(h)

        return h


  from .autonotebook import tqdm as notebook_tqdm


## Training + evaluation (full-batch per epoch)

In [8]:
import torch
import torch.nn as nn
from sklearn.metrics import (
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --------------------------------------------------------
# 1. Model + link predictor
# --------------------------------------------------------
hidden_dim = 64

model = GraphSageT(
    node_feat_dim=x_d.shape[1],
    time_dim=8,        # matches your TimeEncode(dim=8) usage
    hidden_dim=hidden_dim,
    num_layers=2,
).to(device)

link_pred = nn.Sequential(
    nn.Linear(2 * hidden_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(link_pred.parameters()),
    lr=5e-4,
    weight_decay=1e-5,
)

EPOCHS = 20

# --------------------------------------------------------
# 2. Helper: evaluate on a given split
# --------------------------------------------------------
def evaluate(indices):
    model.eval()
    link_pred.eval()

    with torch.no_grad():
        # full-graph forward once
        h = model(x_d, edge_index_d, ts_d)

        src = edge_index_d[0, indices]
        dst = edge_index_d[1, indices]
        labels = y_d[indices]

        logits = link_pred(torch.cat([h[src], h[dst]], dim=-1)).squeeze(-1)
        probs = torch.sigmoid(logits)

        loss = criterion(logits, labels).item()

        scores = probs.cpu().numpy()
        y_true = labels.cpu().numpy()

    # 0.5 threshold for hard predictions
    y_pred = (scores >= 0.5).astype(int)

    # metrics (binary)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )

    try:
        roc = roc_auc_score(y_true, scores)
    except ValueError:
        roc = float("nan")

    try:
        aupr = average_precision_score(y_true, scores)
    except ValueError:
        aupr = float("nan")

    return loss, prec, rec, f1, roc, aupr

# --------------------------------------------------------
# 3. Training loop (full-batch over train edges)
# --------------------------------------------------------
for epoch in range(1, EPOCHS + 1):
    model.train()
    link_pred.train()

    optimizer.zero_grad()

    # forward once over whole subset graph
    h = model(x_d, edge_index_d, ts_d)

    src = edge_index_d[0, train_idx]
    dst = edge_index_d[1, train_idx]
    labels = y_d[train_idx]

    logits = link_pred(torch.cat([h[src], h[dst]], dim=-1)).squeeze(-1)
    loss = criterion(logits, labels)

    loss.backward()
    optimizer.step()

    # validation
    val_loss, prec, rec, f1, roc, aupr = evaluate(val_idx)

    print(
        f"Epoch {epoch:02d} | "
        f"train_loss={loss.item():.4f} val_loss={val_loss:.4f} "
        f"P={prec:.3f} R={rec:.3f} F1={f1:.3f} "
        f"ROC-AUC={roc:.3f} AUPR={aupr:.3f}"
    )

# --------------------------------------------------------
# 4. Final test evaluation
# --------------------------------------------------------
test_loss, prec_t, rec_t, f1_t, roc_t, aupr_t = evaluate(test_idx)

print("\n=== FINAL TEST METRICS (subset) ===")
print(f"Test loss : {test_loss:.4f}")
print(f"Precision : {prec_t:.3f}")
print(f"Recall    : {rec_t:.3f}")
print(f"F1-score  : {f1_t:.3f}")
print(f"ROC-AUC   : {roc_t:.3f}")
print(f"AUC-PR    : {aupr_t:.3f}")


Device: cuda
Epoch 01 | train_loss=0.9363 val_loss=0.7760 P=0.002 R=0.844 F1=0.003 ROC-AUC=0.578 AUPR=0.003
Epoch 02 | train_loss=0.7757 val_loss=0.6952 P=0.002 R=0.761 F1=0.003 ROC-AUC=0.488 AUPR=0.002
Epoch 03 | train_loss=0.6954 val_loss=0.6501 P=0.001 R=0.533 F1=0.002 ROC-AUC=0.413 AUPR=0.002
Epoch 04 | train_loss=0.6508 val_loss=0.6198 P=0.001 R=0.439 F1=0.002 ROC-AUC=0.385 AUPR=0.002
Epoch 05 | train_loss=0.6207 val_loss=0.5974 P=0.001 R=0.383 F1=0.002 ROC-AUC=0.365 AUPR=0.001
Epoch 06 | train_loss=0.5984 val_loss=0.5791 P=0.001 R=0.311 F1=0.002 ROC-AUC=0.365 AUPR=0.001
Epoch 07 | train_loss=0.5802 val_loss=0.5634 P=0.001 R=0.278 F1=0.002 ROC-AUC=0.371 AUPR=0.001
Epoch 08 | train_loss=0.5646 val_loss=0.5500 P=0.001 R=0.200 F1=0.002 ROC-AUC=0.371 AUPR=0.001
Epoch 09 | train_loss=0.5513 val_loss=0.5371 P=0.001 R=0.122 F1=0.002 ROC-AUC=0.375 AUPR=0.001
Epoch 10 | train_loss=0.5384 val_loss=0.5259 P=0.001 R=0.056 F1=0.002 ROC-AUC=0.376 AUPR=0.001
Epoch 11 | train_loss=0.5273 val_loss