# Transfer Learning: Multi-GNN for Node-Level Sanctions Classification

Adapting the [IBM Multi-GNN](https://github.com/IBM/Multi-GNN) architecture (originally for **edge-level** AML detection) to perform **node-level** classification of sanctioned entities on blockchain transaction graphs.

**Key Architectural Changes:**
- **Edge-level → Node-level** prediction
- **CSV → Parquet** data format
- **Final MLP**: `3×hidden` → `1×hidden` (no src+dst+edge concatenation — use node embeddings directly)
- **Class-weighted loss** for severe imbalance (~82 sanctioned vs ~19,000 non-sanctioned)

In [None]:
# Step 1-2: Install required dependencies
!pip install torch torch-geometric pandas numpy scikit-learn matplotlib seaborn tqdm pyarrow

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, GATConv, BatchNorm, Linear
from torch_geometric.data import Data

import pandas as pd
import numpy as np
import json
import os

from sklearn.metrics import (
    f1_score, precision_score, recall_score, accuracy_score,
    confusion_matrix, classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
# The model and the graph data object are moved to this device before training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Data Loading

Loading three data files:
- `formatted_transactions.parquet` — edge list with transaction features
- `node_labels.parquet` — node-level labels (sanctioned vs non-sanctioned)
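- `data_splits.json` — train/validation split given as lists of edge IDs (row positions in the transactions table); node-level train/val masks are derived from the endpoints of those edges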

In [None]:
# Step 3: Upload data files (Colab) or set local paths

# --- For Google Colab: uncomment these lines to upload files ---
# from google.colab import files
# uploaded = files.upload()

# --- Set data directory ---
DATA_DIR = "."  # adjust if your files are in a subfolder

transactions_path = os.path.join(DATA_DIR, "formatted_transactions.parquet")
labels_path = os.path.join(DATA_DIR, "node_labels.parquet")
splits_path = os.path.join(DATA_DIR, "data_splits.json")

# Load data: edge list, node labels, and the train/val split definition.
df_edges = pd.read_parquet(transactions_path)
df_labels = pd.read_parquet(labels_path)
# JSON is UTF-8 by specification; do not rely on the platform's default encoding.
with open(splits_path, 'r', encoding='utf-8') as f:
    data_splits = json.load(f)

# Print a quick overview of each input so the column names configured in the
# next cell can be checked against what was actually loaded.
print("=== Transactions ===")
print(f"Shape: {df_edges.shape}")
print(f"Columns: {df_edges.columns.tolist()}")
print(df_edges.head())

print("\n=== Node Labels ===")
print(f"Shape: {df_labels.shape}")
print(f"Columns: {df_labels.columns.tolist()}")
print(df_labels.head())
# The last column of the label table is treated as the label column.
print(f"\nLabel distribution:\n{df_labels.iloc[:, -1].value_counts()}")

print("\n=== Data Splits ===")
print(f"Keys: {list(data_splits.keys())}")
for k, v in data_splits.items():
    if isinstance(v, list):
        # Lists can be long; show only their size and the first few entries.
        print(f"  {k}: {len(v)} entries (first 5: {v[:5]})")
    else:
        print(f"  {k}: {v}")

In [None]:
# Step 4-5: Configure column names based on the output above
# *** ADJUST THESE IF YOUR COLUMNS HAVE DIFFERENT NAMES ***

SRC_COL = "from_id"       # Source node column
DST_COL = "to_id"         # Destination node column

# Every column of the edge table other than the two endpoint columns is
# used as an edge feature.
# NOTE(review): this would also pick up ID/label-like columns if the edge
# table contains any — confirm against the printed column list.
EXCLUDE_COLS = {SRC_COL, DST_COL}
EDGE_FEATURE_COLS = [name for name in df_edges.columns if name not in EXCLUDE_COLS]
print(f"Edge feature columns ({len(EDGE_FEATURE_COLS)}): {EDGE_FEATURE_COLS}")

# Node label table layout: first column is taken as the node ID,
# last column as the label.
NODE_ID_COL, LABEL_COL = df_labels.columns[0], df_labels.columns[-1]
print(f"Node ID column: {NODE_ID_COL}")
print(f"Label column: {LABEL_COL}")
print(f"\nSplit keys available: {list(data_splits)}")

In [None]:
# Step 6-8: Build the PyTorch Geometric Data object

def z_norm(data):
    """Z-score normalization along dim 0 (same idea as the original Multi-GNN).

    Args:
        data: Float tensor whose first dimension indexes samples
            (called here with 2-D node and edge feature matrices).

    Returns:
        Tensor of the same shape: per-column mean removed and divided by the
        per-column standard deviation. Columns whose standard deviation is
        zero (constant column) or undefined (a single row, where the unbiased
        estimate is NaN) are only centered, so they come out as zeros rather
        than NaN/inf.
    """
    mean = data.mean(0, keepdim=True)
    std = data.std(0, keepdim=True)
    # Guard against division by zero AND against NaN std (single-sample input).
    # ones_like keeps the replacement on the same device/dtype as the data.
    degenerate = (std == 0) | torch.isnan(std)
    std = torch.where(degenerate, torch.ones_like(std), std)
    return (data - mean) / std


def _edge_rows_to_node_mask(df_edges, edge_ids, src_col, dst_col, num_nodes):
    """Return a boolean node mask marking every endpoint of the given edge rows."""
    rows = df_edges.iloc[edge_ids]  # edge IDs are used as row positions
    node_ids = np.union1d(rows[src_col].to_numpy(), rows[dst_col].to_numpy())
    # Ignore IDs outside [0, num_nodes); a negative ID would otherwise wrap around.
    node_ids = node_ids[(node_ids >= 0) & (node_ids < num_nodes)]
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[torch.as_tensor(node_ids.astype(np.int64))] = True
    return mask


def build_pyg_data(df_edges, df_labels, data_splits,
                   src_col, dst_col, edge_feature_cols,
                   node_id_col, label_col):
    """
    Build a PyG Data object for node-level classification.

    Original IBM Multi-GNN builds an edge-level graph with per-edge labels.
    Here we keep the same graph structure but attach per-node labels and masks.

    Args:
        df_edges: Edge table; one row per transaction.
        df_labels: Node label table.
        data_splits: Mapping loaded from the splits JSON. A key containing
            "train" and a key containing "val" are looked up (case-insensitive);
            their values are expected to be lists of edge row positions.
        src_col, dst_col: Names of the source / destination node ID columns.
        edge_feature_cols: Columns of ``df_edges`` used as edge features.
        node_id_col, label_col: Node ID and label columns of ``df_labels``.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``edge_attr``, ``y``,
        ``train_mask`` and ``val_mask``.

    Notes:
        * Nodes without an entry in ``df_labels`` keep label 0.
        * ``x`` is a constant placeholder; after z-normalization a constant
          column becomes all zeros.
        * If either split key cannot be found, a random 80/20 split over all
          nodes is used instead (unseeded, via ``torch.randperm``).
    """
    # --- Edge index [2, num_edges] ---
    src = torch.LongTensor(df_edges[src_col].values)
    dst = torch.LongTensor(df_edges[dst_col].values)
    edge_index = torch.stack([src, dst], dim=0)

    # --- Edge attributes ---
    edge_feat_df = df_edges[edge_feature_cols].copy()
    for col in edge_feat_df.columns:
        # String/object columns are replaced by integer category codes.
        if edge_feat_df[col].dtype == object:
            edge_feat_df[col] = edge_feat_df[col].astype('category').cat.codes
    edge_attr = torch.tensor(edge_feat_df.values, dtype=torch.float32)

    # --- Node features: placeholder of all 1s (same as original) ---
    max_node_id = max(int(src.max()), int(dst.max())) + 1
    x = torch.ones(max_node_id, 1, dtype=torch.float32)

    # --- Node labels (vectorized; unlabeled nodes stay 0) ---
    y = torch.zeros(max_node_id, dtype=torch.long)
    label_series = pd.Series(
        df_labels[label_col].to_numpy(),
        index=df_labels[node_id_col].to_numpy(),
    )
    # If a node ID appears more than once, the last row wins, as a sequential
    # assignment loop would do.
    label_series = label_series[~label_series.index.duplicated(keep='last')]
    label_ids = label_series.index.to_numpy()
    in_range = (label_ids >= 0) & (label_ids < max_node_id)
    y[torch.as_tensor(label_ids[in_range].astype(np.int64))] = torch.as_tensor(
        label_series.to_numpy()[in_range].astype(np.int64)
    )

    # --- Train/Val node masks derived from edge-based splits ---
    train_mask = torch.zeros(max_node_id, dtype=torch.bool)
    val_mask = torch.zeros(max_node_id, dtype=torch.bool)

    # Find train/val keys in data_splits (last matching key wins).
    train_key = val_key = None
    for key in data_splits:
        lowered = key.lower()
        if 'train' in lowered:
            train_key = key
        elif 'val' in lowered:  # also matches "valid"/"validation"
            val_key = key

    if train_key and isinstance(data_splits[train_key], list):
        train_mask = _edge_rows_to_node_mask(
            df_edges, data_splits[train_key], src_col, dst_col, max_node_id
        )

    if val_key and isinstance(data_splits[val_key], list):
        val_mask = _edge_rows_to_node_mask(
            df_edges, data_splits[val_key], src_col, dst_col, max_node_id
        )
        # Nodes in both splits stay in train only
        val_mask = val_mask & ~train_mask

    if train_key is None or val_key is None:
        print("WARNING: Could not auto-detect train/val keys in data_splits.json.")
        print(f"  Available keys: {list(data_splits.keys())}")
        print("  Falling back to random 80/20 split on all nodes.")
        # Start from empty masks: one of the two may already have been filled
        # from its split key above, and layering the random split on top of it
        # would put the same nodes in both train and val.
        train_mask = torch.zeros(max_node_id, dtype=torch.bool)
        val_mask = torch.zeros(max_node_id, dtype=torch.bool)
        perm = torch.randperm(max_node_id)
        split_idx = int(0.8 * max_node_id)
        train_mask[perm[:split_idx]] = True
        val_mask[perm[split_idx:]] = True

    # --- Normalize features ---
    edge_attr = z_norm(edge_attr)
    x = z_norm(x)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
    )
    return data


# Assemble the graph from the loaded tables and the configured column names.
data = build_pyg_data(
    df_edges=df_edges,
    df_labels=df_labels,
    data_splits=data_splits,
    src_col=SRC_COL,
    dst_col=DST_COL,
    edge_feature_cols=EDGE_FEATURE_COLS,
    node_id_col=NODE_ID_COL,
    label_col=LABEL_COL,
)

# Summary of sizes and class balance, overall and per split.
sanctioned = data.y == 1
n_train_nodes = data.train_mask.sum().item()
n_val_nodes = data.val_mask.sum().item()

print("Graph Data Object:")
print(f"  Nodes:           {data.num_nodes:,}")
print(f"  Edges:           {data.num_edges:,}")
print(f"  Node features:   {data.x.shape}")
print(f"  Edge features:   {data.edge_attr.shape}")
print(f"  Labels:          {data.y.shape}")
print(f"  Train nodes:     {n_train_nodes:,}")
print(f"  Val nodes:       {n_val_nodes:,}")
print(f"  Sanctioned (y=1):{sanctioned.sum().item():,}")
print(f"  Non-sanctioned:  {(data.y == 0).sum().item():,}")
print(f"  Train sanctioned:{sanctioned[data.train_mask].sum().item()}")
print(f"  Val sanctioned:  {sanctioned[data.val_mask].sum().item()}")

## 2. Model Architecture

Adapting the IBM Multi-GNN models (GINe / GATe) from **edge-level** to **node-level** prediction.

| | Original (edge-level) | Adapted (node-level) |
|---|---|---|
| **After GNN layers** | Concatenate `[src ‖ dst ‖ edge]` | Use node embeddings directly |
| **MLP input dim** | `3 × n_hidden` | `1 × n_hidden` |
| **Output** | One prediction per edge | One prediction per node |

In [None]:
# Step 9-12: Adapted GNN models for node-level classification

class NodeGINe(nn.Module):
    """
    GIN with edge features (GINe), adapted for node-level classification.

    Original IBM Multi-GNN outputs edge predictions by concatenating
    src + dst + edge embeddings (3*n_hidden -> MLP).
    This version outputs node predictions directly from node embeddings
    after message passing (n_hidden -> MLP).

    Args:
        num_features: Number of raw node features.
        num_gnn_layers: Number of GINEConv layers.
        n_classes: Number of output classes.
        n_hidden: Width of node and edge embeddings.
        edge_updates: If True, edge embeddings are refined by an MLP between
            GNN layers.
        residual: Accepted for signature compatibility; not used by this class.
        edge_dim: Number of raw edge features (required).
        dropout: Accepted for signature compatibility; not used by this class.
        final_dropout: Dropout probability inside the classification MLP.

    Raises:
        ValueError: If ``edge_dim`` is not provided.
    """
    def __init__(self, num_features, num_gnn_layers, n_classes=2,
                 n_hidden=100, edge_updates=False, residual=True,
                 edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()
        if edge_dim is None:
            # Fail early with a clear message instead of an obscure error
            # from nn.Linear(None, ...).
            raise ValueError("edge_dim (number of raw edge features) must be provided")

        self.n_hidden = n_hidden
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.final_dropout = final_dropout

        # Project raw node / edge features to the shared hidden width.
        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for _ in range(self.num_gnn_layers):
            conv = GINEConv(nn.Sequential(
                nn.Linear(n_hidden, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, n_hidden)
            ), edge_dim=n_hidden)
            if self.edge_updates:
                # One edge-update MLP per layer, kept for every layer so the
                # parameter layout (state_dict keys) does not change.
                self.emlps.append(nn.Sequential(
                    nn.Linear(3 * n_hidden, n_hidden),
                    nn.ReLU(),
                    nn.Linear(n_hidden, n_hidden),
                ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        # CHANGED: MLP input is n_hidden (node embedding only)
        # Original was 3*n_hidden (src_emb + dst_emb + edge_emb)
        self.mlp = nn.Sequential(
            Linear(n_hidden, 50), nn.ReLU(), nn.Dropout(self.final_dropout),
            Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
            Linear(25, n_classes)
        )

    def forward(self, x, edge_index, edge_attr):
        """Return per-node class logits, one row per node in ``x``."""
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        last_layer = self.num_gnn_layers - 1
        for i in range(self.num_gnn_layers):
            # Averaged residual update of the node embeddings.
            h = self.convs[i](x, edge_index, edge_attr)
            x = (x + F.relu(self.batch_norms[i](h))) / 2
            # Edge embeddings only feed the NEXT conv layer; the node-level head
            # never reads them, so the update after the final layer is skipped
            # (same output, less computation).
            if self.edge_updates and i < last_layer:
                edge_attr = edge_attr + self.emlps[i](
                    torch.cat([x[src], x[dst], edge_attr], dim=-1)
                ) / 2

        # CHANGED: Return node-level predictions directly
        # Original did: x[edge_index.T].reshape(-1, 2*n_hidden) then cat with edge_attr
        return self.mlp(x)


class NodeGATe(nn.Module):
    """
    GAT with edge features (GATe), adapted for node-level classification.
    Same structural change as NodeGINe: node embeddings -> MLP instead of
    edge readout -> MLP.

    Args:
        num_features: Number of raw node features.
        num_gnn_layers: Number of GATConv layers.
        n_classes: Number of output classes.
        n_hidden: Requested hidden width; rounded DOWN to a multiple of
            ``n_heads`` so the concatenated heads give exactly that width.
        n_heads: Number of attention heads per layer.
        edge_updates: If True, edge embeddings are refined by an MLP between
            GNN layers.
        edge_dim: Number of raw edge features (required).
        dropout: Dropout probability passed to each GATConv.
        final_dropout: Dropout probability inside the classification MLP.

    Raises:
        ValueError: If ``edge_dim`` is not provided or ``n_hidden < n_heads``.
    """
    def __init__(self, num_features, num_gnn_layers, n_classes=2,
                 n_hidden=100, n_heads=4, edge_updates=False,
                 edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()
        if edge_dim is None:
            # Fail early with a clear message instead of an obscure error
            # from nn.Linear(None, ...).
            raise ValueError("edge_dim (number of raw edge features) must be provided")

        # Per-head width; the effective hidden width is tmp_out * n_heads.
        tmp_out = n_hidden // n_heads
        if tmp_out == 0:
            # Would otherwise silently build zero-width layers.
            raise ValueError(
                f"n_hidden ({n_hidden}) must be at least n_heads ({n_heads})"
            )
        n_hidden = tmp_out * n_heads

        self.n_hidden = n_hidden
        self.n_heads = n_heads
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.dropout = dropout
        self.final_dropout = final_dropout

        # Project raw node / edge features to the shared hidden width.
        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for _ in range(self.num_gnn_layers):
            conv = GATConv(n_hidden, tmp_out, n_heads, concat=True,
                           dropout=dropout, add_self_loops=True, edge_dim=n_hidden)
            if self.edge_updates:
                # One edge-update MLP per layer, kept for every layer so the
                # parameter layout (state_dict keys) does not change.
                self.emlps.append(nn.Sequential(
                    nn.Linear(3 * n_hidden, n_hidden),
                    nn.ReLU(),
                    nn.Linear(n_hidden, n_hidden),
                ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        # CHANGED: MLP input is n_hidden (node embedding only)
        self.mlp = nn.Sequential(
            Linear(n_hidden, 50), nn.ReLU(), nn.Dropout(self.final_dropout),
            Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
            Linear(25, n_classes)
        )

    def forward(self, x, edge_index, edge_attr):
        """Return per-node class logits, one row per node in ``x``."""
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        last_layer = self.num_gnn_layers - 1
        for i in range(self.num_gnn_layers):
            # Averaged residual update of the node embeddings.
            h = self.convs[i](x, edge_index, edge_attr)
            x = (x + F.relu(self.batch_norms[i](h))) / 2
            # Edge embeddings only feed the NEXT conv layer; the node-level head
            # never reads them, so the update after the final layer is skipped
            # (same output, less computation).
            if self.edge_updates and i < last_layer:
                edge_attr = edge_attr + self.emlps[i](
                    torch.cat([x[src], x[dst], edge_attr], dim=-1)
                ) / 2

        # CHANGED: Return node-level predictions directly
        return self.mlp(x)


print("NodeGINe and NodeGATe models defined.")

## 3. Training

- **Loss**: Weighted `CrossEntropyLoss` — automatically computes class weight from the train set imbalance ratio
- **Metrics**: Loss, Accuracy, Precision, Recall, F1 (tracked per epoch for train and val)
- **Checkpointing**: Best model saved based on validation F1 score

In [None]:
# Step 13-17: Training with class-weighted loss, metrics tracking, and checkpointing

def compute_metrics(pred, target):
    """Return accuracy, precision, recall and F1 for predicted vs. true labels.

    Both arguments are tensors of class indices; they are moved to the CPU
    and converted to NumPy before scoring. Precision, recall and F1 report 0
    instead of warning when their denominator is zero.
    """
    y_pred = pred.cpu().numpy()
    y_true = target.cpu().numpy()

    scores = {'accuracy': accuracy_score(y_true, y_pred)}
    zero_safe_scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    for name, scorer in zero_safe_scorers:
        scores[name] = scorer(y_true, y_pred, zero_division=0)
    return scores


def train_epoch(model, data, optimizer, loss_fn):
    """Run one full-graph optimization step and return train-node metrics.

    The forward pass covers the whole graph (transductive); only nodes in
    ``data.train_mask`` contribute to the loss. The reported metrics come
    from the logits of that same forward pass, i.e. before the weight update.
    """
    model.train()
    optimizer.zero_grad()

    mask = data.train_mask
    train_logits = model(data.x, data.edge_index, data.edge_attr)[mask]
    train_targets = data.y[mask]

    loss = loss_fn(train_logits, train_targets)
    loss.backward()
    optimizer.step()

    metrics = compute_metrics(train_logits.argmax(dim=-1), train_targets)
    metrics['loss'] = loss.item()
    return metrics


@torch.no_grad()
def evaluate(model, data, loss_fn, mask):
    """Score the model on the nodes selected by ``mask`` (no gradients).

    Returns:
        Tuple ``(metrics, pred, logits)`` where ``metrics`` is the metric
        dict including ``'loss'``, ``pred`` holds the argmax class per
        selected node and ``logits`` the raw outputs for those nodes.
    """
    model.eval()

    selected_logits = model(data.x, data.edge_index, data.edge_attr)[mask]
    selected_targets = data.y[mask]
    loss = loss_fn(selected_logits, selected_targets)

    pred = selected_logits.argmax(dim=-1)
    metrics = compute_metrics(pred, selected_targets)
    metrics['loss'] = loss.item()
    return metrics, pred, selected_logits


def train(model, data, n_epochs=100, lr=0.006, save_path="best_model.pt"):
    """Full training loop with validation and checkpointing.

    Args:
        model: Node classifier called as ``model(x, edge_index, edge_attr)``.
        data: Graph data with ``x``, ``edge_index``, ``edge_attr``, ``y``,
            ``train_mask`` and ``val_mask``; expected on the model's device.
        n_epochs: Number of full-graph training epochs.
        lr: Adam learning rate.
        save_path: Where the best checkpoint (by validation F1) is written.

    Returns:
        Tuple ``(history, best_val_f1)``: per-epoch train/val metric lists and
        the best validation F1 observed.

    Notes:
        A checkpoint is always written after the first epoch, even when the
        validation F1 is 0, so that ``save_path`` exists and belongs to this
        run; later epochs overwrite it only on a strict F1 improvement.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Step 13: Class-weighted loss — weight the minority class by the imbalance ratio
    train_labels = data.y[data.train_mask]
    n_sanctioned = (train_labels == 1).sum().float()
    n_non_sanctioned = (train_labels == 0).sum().float()
    # With no positive training nodes the ratio is undefined; use equal weights.
    weight_ratio = (n_non_sanctioned / n_sanctioned).item() if n_sanctioned > 0 else 1.0

    # Place the weights on the data's device rather than relying on a global.
    class_weights = torch.tensor(
        [1.0, weight_ratio], dtype=torch.float32, device=data.y.device
    )
    print(f"Class weights: [non-sanctioned: {class_weights[0]:.2f}, sanctioned: {class_weights[1]:.2f}]")
    print(f"  (ratio = {weight_ratio:.1f}:1)\n")

    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    # Tracking: metric name as returned by the epoch functions -> history suffix.
    short_names = {'loss': 'loss', 'accuracy': 'acc', 'precision': 'prec',
                   'recall': 'rec', 'f1': 'f1'}
    history = {
        f'{split}_{short}': []
        for split in ('train', 'val')
        for short in short_names.values()
    }
    best_val_f1 = 0.0
    best_epoch = 0
    checkpoint_written = False

    for epoch in tqdm(range(n_epochs), desc="Training"):
        # Step 14: Train
        train_metrics = train_epoch(model, data, optimizer, loss_fn)

        # Step 15: Validate
        val_metrics, _, _ = evaluate(model, data, loss_fn, data.val_mask)

        # Step 16: Track metrics
        for name, short in short_names.items():
            history[f'train_{short}'].append(train_metrics[name])
            history[f'val_{short}'].append(val_metrics[name])

        # Step 17: Save best checkpoint based on val F1.
        # The first epoch is always saved: with severe class imbalance the val
        # F1 can stay at 0, and without this no checkpoint would ever exist.
        if not checkpoint_written or val_metrics['f1'] > best_val_f1:
            # Plain Python float: the metric may be a NumPy scalar, which is
            # best kept out of a checkpoint file.
            best_val_f1 = float(val_metrics['f1'])
            best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': best_val_f1,
            }, save_path)
            checkpoint_written = True

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"\nEpoch {epoch+1:3d} | "
                  f"Train Loss: {train_metrics['loss']:.4f} F1: {train_metrics['f1']:.4f} | "
                  f"Val Loss: {val_metrics['loss']:.4f} F1: {val_metrics['f1']:.4f} "
                  f"Prec: {val_metrics['precision']:.4f} Rec: {val_metrics['recall']:.4f}")

    print(f"\nBest validation F1: {best_val_f1:.4f} at epoch {best_epoch + 1}")
    return history, best_val_f1

In [None]:
# Initialize model and run training

# Hyperparameters (based on IBM Multi-GNN model_settings.json for GIN)
MODEL_TYPE = "gin"       # "gin" or "gat"
N_HIDDEN = 66
N_GNN_LAYERS = 2
EDGE_UPDATES = True      # Use edge update MLPs (--emlps flag in original)
DROPOUT = 0.01
FINAL_DROPOUT = 0.1
LR = 0.006
N_EPOCHS = 100
SAVE_PATH = "best_model.pt"

# Input widths are read from the graph that was actually built.
num_node_features = data.x.shape[1]
num_edge_features = data.edge_attr.shape[1]

if MODEL_TYPE == "gin":
    model = NodeGINe(
        num_features=num_node_features,
        num_gnn_layers=N_GNN_LAYERS,
        n_classes=2,
        n_hidden=N_HIDDEN,
        edge_updates=EDGE_UPDATES,
        edge_dim=num_edge_features,
        dropout=DROPOUT,
        final_dropout=FINAL_DROPOUT,
    )
elif MODEL_TYPE == "gat":
    model = NodeGATe(
        num_features=num_node_features,
        num_gnn_layers=N_GNN_LAYERS,
        n_classes=2,
        n_hidden=N_HIDDEN,
        n_heads=4,
        edge_updates=EDGE_UPDATES,
        edge_dim=num_edge_features,
        dropout=DROPOUT,
        final_dropout=FINAL_DROPOUT,
    )
else:
    # Without this, a typo would either raise a NameError below or — in a
    # notebook — silently reuse a `model` left over from an earlier run.
    raise ValueError(f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'gin' or 'gat'")

model = model.to(device)
data = data.to(device)

print(f"Model: {MODEL_TYPE.upper()}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(model)
print()

# Run training
history, best_val_f1 = train(model, data, n_epochs=N_EPOCHS, lr=LR, save_path=SAVE_PATH)

## 4. Evaluation

Load the best checkpoint and run final evaluation on the validation set with confusion matrix and classification report.

In [None]:
# Step 18-19: Load best checkpoint and run final evaluation

checkpoint = torch.load(SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1} "
      f"(val F1: {checkpoint['val_f1']:.4f})\n")

# Reconstruct loss function for evaluation (same class weighting as training:
# minority class weighted by the train-set imbalance ratio).
n_sanctioned = (data.y[data.train_mask] == 1).sum().float()
n_non_sanctioned = (data.y[data.train_mask] == 0).sum().float()
weight_ratio = (n_non_sanctioned / n_sanctioned).item() if n_sanctioned > 0 else 1.0
loss_fn = nn.CrossEntropyLoss(
    weight=torch.FloatTensor([1.0, weight_ratio]).to(device)
)

# Final validation evaluation
val_metrics, val_pred, val_logits = evaluate(model, data, loss_fn, data.val_mask)
val_true = data.y[data.val_mask].cpu().numpy()
val_pred_np = val_pred.cpu().numpy()

print("=== Final Validation Results ===")
for k, v in val_metrics.items():
    print(f"  {k:>10s}: {v:.4f}")

# Both classes are listed explicitly: with this much imbalance the validation
# nodes (or the predictions) may contain only one class, and without `labels`
# the report raises on the 2 target names and the confusion matrix shrinks
# to 1x1, mislabeling the heatmap.
CLASS_IDS = [0, 1]
CLASS_NAMES = ["Non-sanctioned", "Sanctioned"]

# Classification report
print("\n=== Classification Report ===")
print(classification_report(
    val_true, val_pred_np,
    labels=CLASS_IDS,
    target_names=CLASS_NAMES,
    digits=4,
    zero_division=0,
))

# --- Plots ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confusion matrix heatmap (rows = actual, columns = predicted)
cm = confusion_matrix(val_true, val_pred_np, labels=CLASS_IDS)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES,
            ax=axes[0])
axes[0].set_xlabel("Predicted")
axes[0].set_ylabel("Actual")
axes[0].set_title("Confusion Matrix")

# Training curves
epochs_range = range(1, len(history['train_f1']) + 1)
axes[1].plot(epochs_range, history['train_f1'], label='Train F1', alpha=0.8)
axes[1].plot(epochs_range, history['val_f1'], label='Val F1', alpha=0.8)
axes[1].plot(epochs_range, history['val_rec'], label='Val Recall', alpha=0.6, linestyle='--')
axes[1].plot(epochs_range, history['val_prec'], label='Val Precision', alpha=0.6, linestyle='--')
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Score")
axes[1].set_title("Training Curves")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("training_results.png", dpi=150, bbox_inches='tight')
plt.show()

print("Plot saved to training_results.png")

In [None]:
# Step 20: Save trained model weights for later use on cross-chain test data

FINAL_SAVE_PATH = "trained_node_gnn_weights.pt"

# Record what is needed to rebuild the exact architecture. `n_heads` only
# exists on the GAT variant, so it is stored as None for GIN.
model_class_name = type(model).__name__
saved_n_heads = getattr(model, 'n_heads', None)

torch.save({
    'model_type': MODEL_TYPE,
    'model_state_dict': model.state_dict(),
    'num_node_features': num_node_features,
    'num_edge_features': num_edge_features,
    'n_hidden': N_HIDDEN,
    'n_gnn_layers': N_GNN_LAYERS,
    'edge_updates': EDGE_UPDATES,
    'n_classes': 2,
    'n_heads': saved_n_heads,
    'dropout': DROPOUT,
    'final_dropout': FINAL_DROPOUT,
    # Plain Python scalars: these values may be NumPy scalars, which are best
    # kept out of a checkpoint file.
    'best_val_f1': float(best_val_f1),
    'best_epoch': int(checkpoint['epoch']),
}, FINAL_SAVE_PATH)

print(f"Model weights saved to: {FINAL_SAVE_PATH}")
print()
# The instructions name the class that was actually trained (GIN or GAT).
print("To reload for cross-chain inference:")
print("  ckpt = torch.load('trained_node_gnn_weights.pt')")
print(f"  model = {model_class_name}(")
print("      num_features=ckpt['num_node_features'],")
print("      num_gnn_layers=ckpt['n_gnn_layers'],")
print("      n_classes=ckpt['n_classes'],")
print("      n_hidden=ckpt['n_hidden'],")
if saved_n_heads is not None:
    print("      n_heads=ckpt['n_heads'],")
print("      edge_updates=ckpt['edge_updates'],")
print("      edge_dim=ckpt['num_edge_features'],")
print("      dropout=ckpt['dropout'],")
print("      final_dropout=ckpt['final_dropout'],")
print("  )")
print("  model.load_state_dict(ckpt['model_state_dict'])")