In [3]:
# ========= All-in-one: models + trainer (fixed for robust GAT/GATv2) =========
import os, numpy as np, pandas as pd, torch
from torch import nn
from torch_geometric.datasets import EllipticBitcoinDataset
from torch_geometric.utils import to_undirected, add_self_loops
from torch_geometric.nn import GATConv, GATv2Conv, GCNConv, SAGEConv

EMB_DIM = 64  # embedding width shared by every model below, so each one yields an (N, EMB_DIM) matrix

# ---------- Base ----------
class BaseGNN(nn.Module):
    """Common base for the node classifiers in this file.

    Subclasses implement ``_embed_and_logits(x, edge_index)`` and return a
    ``(z, logits)`` pair: ``z`` is the node embedding, ``logits`` the class
    scores. ``forward`` returns only the logits and remembers ``z``;
    ``embed`` returns only ``z``.
    """

    def __init__(self):
        super().__init__()
        # Embeddings from the most recent ``forward`` call. The tensor is stored
        # as produced, i.e. not detached from the autograd graph.
        self._last_z = None

    def _embed_and_logits(self, x, edge_index):
        """Return ``(z, logits)`` for all nodes; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _embed_and_logits(x, edge_index)"
        )

    def forward(self, x, edge_index):
        """Return class logits and cache the matching embeddings in ``_last_z``."""
        z, logits = self._embed_and_logits(x, edge_index)
        self._last_z = z
        return logits

    @torch.no_grad()
    def embed(self, x, edge_index):
        """Return embeddings computed in eval mode, without tracking gradients.

        The module's previous train/eval mode is restored afterwards, so calling
        this in the middle of training does not leave dropout and normalisation
        layers stuck in inference mode.
        """
        was_training = self.training
        self.eval()
        try:
            z, _ = self._embed_and_logits(x, edge_index)
        finally:
            self.train(was_training)
        return z


# ---------- GCN ----------
class GCNNet(BaseGNN):
    """Two-layer GCN whose second convolution emits the 2-class logits directly.

    The embedding is a separate Linear + BatchNorm branch taken from the hidden
    representation. NOTE: the logits do not depend on that branch, so a loss
    computed from the logits alone (as ``gnn_train`` in this file does) sends no
    gradient into ``emb_proj`` / ``emb_bn``.
    """

    def __init__(self, in_dim, hidden=128, drop=0.5):
        super().__init__()
        # Both convolutions share the same options. ``cached=True`` presumably
        # relies on ``edge_index`` staying fixed between calls -- TODO confirm
        # against the PyG GCNConv documentation.
        conv_opts = dict(cached=True, normalize=True)
        self.conv1 = GCNConv(in_dim, hidden, **conv_opts)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.drop = nn.Dropout(drop)
        self.act = nn.ReLU()
        self.conv2 = GCNConv(hidden, 2, **conv_opts)  # logits
        self.emb_proj = nn.Linear(hidden, EMB_DIM)
        self.emb_bn = nn.BatchNorm1d(EMB_DIM)

    def _embed_and_logits(self, x, edge_index):
        """Return ``(embedding, logits)``; both are computed from the same hidden state."""
        hidden_repr = self.drop(self.act(self.bn1(self.conv1(x, edge_index))))
        class_logits = self.conv2(hidden_repr, edge_index)
        embedding = self.emb_bn(self.emb_proj(hidden_repr))
        return embedding, class_logits


# ---------- Skip-GCN ----------
class SkipGCNNet(BaseGNN):
    """Two-layer GCN with a linear skip connection from the raw input features.

    The embedding is ``dropout(relu(bn2(g2(.)) + in_proj(x)))`` and the logits
    come from a linear head applied to that embedding.
    """

    def __init__(self, in_dim, hidden=128, drop=0.5):
        super().__init__()
        # Bias-free projection so the raw features can be added to the EMB_DIM-wide output.
        self.in_proj = nn.Linear(in_dim, EMB_DIM, bias=False)
        self.g1 = GCNConv(in_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.g2 = GCNConv(hidden, EMB_DIM)
        self.bn2 = nn.BatchNorm1d(EMB_DIM)
        self.drop = nn.Dropout(drop)
        self.act = nn.ReLU()
        self.head = nn.Linear(EMB_DIM, 2)

    def _embed_and_logits(self, x, edge_index):
        """Return ``(embedding, logits)`` for every node."""
        shortcut = self.in_proj(x)

        # First graph convolution: conv -> batch norm -> ReLU -> dropout.
        hidden_repr = self.g1(x, edge_index)
        hidden_repr = self.bn1(hidden_repr)
        hidden_repr = self.act(hidden_repr)
        hidden_repr = self.drop(hidden_repr)

        # Second graph convolution: the skip branch is added before the activation.
        hidden_repr = self.g2(hidden_repr, edge_index)
        hidden_repr = self.bn2(hidden_repr)
        embedding = self.drop(self.act(hidden_repr + shortcut))

        return embedding, self.head(embedding)


# ---------- GraphSAGE ----------
class SAGENet(BaseGNN):
    """Two-layer GraphSAGE encoder followed by a linear classification head.

    The second SAGE layer outputs the EMB_DIM-wide embedding; the logits are a
    linear function of that embedding.
    """

    def __init__(self, in_dim, hidden=256, drop=0.5):
        super().__init__()
        self.s1 = SAGEConv(in_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.s2 = SAGEConv(hidden, EMB_DIM)
        self.bn2 = nn.BatchNorm1d(EMB_DIM)
        self.drop = nn.Dropout(drop)
        self.act = nn.ReLU()
        self.head = nn.Linear(EMB_DIM, 2)

    def _embed_and_logits(self, x, edge_index):
        """Return ``(embedding, logits)`` for every node."""
        # Layer 1: conv -> batch norm -> ReLU -> dropout.
        hidden_repr = self.drop(self.act(self.bn1(self.s1(x, edge_index))))
        # Layer 2 uses the same pattern and produces the embedding.
        embedding = self.drop(self.act(self.bn2(self.s2(hidden_repr, edge_index))))
        class_logits = self.head(embedding)
        return embedding, class_logits


# ---------- Classic GAT (direct logits, LayerNorm, no double self-loops) ----------
class GATNet(BaseGNN):
    """Two-layer GAT: a multi-head hidden layer and a single-head layer emitting logits.

    Both attention layers are built with ``add_self_loops=False``; the trainer
    in this file adds self-loops to ``edge_index`` once, globally. The embedding
    is a separate Linear + LayerNorm projection of the penultimate
    representation and is not part of the classification path.
    """

    def __init__(self, in_dim, hidden=128, heads=4, drop=0.5):
        super().__init__()
        concat_width = hidden * heads  # feature width after the heads are concatenated
        self.g1 = GATConv(
            in_dim, hidden,
            heads=heads, concat=True, dropout=drop, add_self_loops=False,
        )
        self.n1 = nn.LayerNorm(concat_width)
        self.act = nn.ELU()
        self.drop = nn.Dropout(drop)
        # Classification layer: one head, two output channels, used directly as logits.
        self.g2 = GATConv(
            concat_width, 2,
            heads=1, concat=False, dropout=drop, add_self_loops=False,
        )
        # Embedding branch taken from the penultimate representation.
        self.emb = nn.Sequential(
            nn.Linear(concat_width, EMB_DIM, bias=False),
            nn.LayerNorm(EMB_DIM),
        )

    def _embed_and_logits(self, x, edge_index):
        """Return ``(embedding, logits)``; both are derived from the same hidden state."""
        penultimate = self.g1(x, edge_index)
        penultimate = self.n1(penultimate)
        penultimate = self.act(penultimate)
        penultimate = self.drop(penultimate)
        class_logits = self.g2(penultimate, edge_index)
        embedding = self.emb(penultimate)
        return embedding, class_logits


# ---------- GATv2 (direct logits, LayerNorm, no double self-loops) ----------
class GATv2Net(BaseGNN):
    """GATv2 counterpart of ``GATNet`` with the same layout.

    A multi-head ``GATv2Conv`` hidden layer is followed by a single-head
    ``GATv2Conv`` that emits the 2-class logits. Self-loops are expected to be
    present in ``edge_index`` already (``add_self_loops=False`` on both layers).
    The embedding is a Linear + LayerNorm projection of the penultimate
    representation, outside the classification path.
    """

    def __init__(self, in_dim, hidden=128, heads=4, drop=0.5):
        super().__init__()
        concat_width = hidden * heads  # feature width after the heads are concatenated
        self.g1 = GATv2Conv(
            in_dim, hidden,
            heads=heads, concat=True, dropout=drop, add_self_loops=False,
        )
        self.n1 = nn.LayerNorm(concat_width)
        self.act = nn.ELU()
        self.drop = nn.Dropout(drop)
        # Direct logits: one head, two output channels.
        self.g2 = GATv2Conv(
            concat_width, 2,
            heads=1, concat=False, dropout=drop, add_self_loops=False,
        )
        self.emb = nn.Sequential(
            nn.Linear(concat_width, EMB_DIM, bias=False),
            nn.LayerNorm(EMB_DIM),
        )

    def _embed_and_logits(self, x, edge_index):
        """Return ``(embedding, logits)``; both are derived from the same hidden state."""
        penultimate = self.g1(x, edge_index)
        penultimate = self.n1(penultimate)
        penultimate = self.act(penultimate)
        penultimate = self.drop(penultimate)
        class_logits = self.g2(penultimate, edge_index)
        embedding = self.emb(penultimate)
        return embedding, class_logits


# ---------- EvolveGCN (guarded) ----------
# torch_geometric_temporal is optional: record whether it imported instead of
# failing at module load, and let EvolveGCNWrapper complain when it is used.
try:
    from torch_geometric_temporal.nn.recurrent import EvolveGCNO
except Exception:
    HAVE_TGT = False
else:
    HAVE_TGT = True


class EvolveGCNWrapper(BaseGNN):
    """Placeholder wrapper around ``EvolveGCNO``.

    Construction requires ``torch_geometric_temporal``. The forward path is not
    implemented for a single static graph: ``_embed_and_logits`` always raises.
    """

    def __init__(self, in_dim):
        super().__init__()
        if not HAVE_TGT:
            raise RuntimeError("EvolveGCN requires torch_geometric_temporal + temporal snapshots.")
        self.cell = EvolveGCNO(in_dim, EMB_DIM)
        self.head = nn.Linear(EMB_DIM, 2)

    def _embed_and_logits(self, x, edge_index):
        """Always raise: this wrapper needs temporal snapshots, not one static graph."""
        raise RuntimeError("Provide temporal snapshots iterator for EvolveGCN.")


# ---------- Trainer ----------
def _positive_class_counts(yhat, y_true):
    """Return ``(tp, fp, fn)`` as Python ints, with label 1 as the positive class."""
    tp = ((yhat == 1) & (y_true == 1)).sum().item()
    fp = ((yhat == 1) & (y_true == 0)).sum().item()
    fn = ((yhat == 0) & (y_true == 1)).sum().item()
    return tp, fp, fn


def _precision_recall_f1(tp, fp, fn):
    """Return ``(P, R, F1)`` from confusion counts; each is 0.0 when undefined."""
    P = 0.0 if tp + fp == 0 else tp / (tp + fp)
    R = 0.0 if tp + fn == 0 else tp / (tp + fn)
    F1 = 0.0 if P + R == 0 else 2 * P * R / (P + R)
    return P, R, F1


def gnn_train(
    arch="gcn",
    root="./elliptic_data",
    seed=42,
    epochs=1000,
    patience=50,
    lr=None,
    wd=None,
    device=None,
    save_to_disk=False,
    out_dir="embeddings",
    use_rolled_val=True,   # use 31–34 for threshold tuning
    grad_clip=None         # None -> per-arch default; <= 0 disables; > 0 is the max grad norm
):
    """
    Train a GNN on Elliptic with temporal splits and save 64-d embeddings.

    Pipeline: load the dataset, build time-step based train/val/test masks over
    labeled nodes (``y != 2``), make the graph undirected and add self-loops
    once, z-score the features with train-only statistics, train with
    class-weighted cross-entropy and Adam, and evaluate on the validation split
    at epoch 1 and every 5th epoch. Each evaluation sweeps 37 probability
    thresholds in [0.05, 0.95] and keeps the best F1 of class 1. The weights of
    the best validation checkpoint are restored before the test report and the
    embedding extraction.

    Args:
        arch: One of gcn, skip_gcn, gat, gatv2, sage, evolvegcn (case-insensitive).
        root: Dataset root; the time-step column is read from
            ``<root>/raw/elliptic_txs_features.csv``.
        seed: Seed passed to ``torch.manual_seed`` and ``np.random.seed``.
        epochs: Maximum number of training epochs; must be >= 1.
        patience: Early-stopping budget. The counter grows by 5 at every
            evaluation that does not improve validation F1.
        lr, wd: Learning rate / weight decay; ``None`` selects the per-arch default.
        device: Torch device; ``None`` selects CUDA when available, else CPU.
        save_to_disk: When true, also write ``<out_dir>/<arch>_embeddings.npy``.
        out_dir: Directory for the optional ``.npy`` file (created if missing).
        use_rolled_val: True -> train on steps < 31, validate on 31–34.
            False -> train on steps < 32, validate on 32–34. Test is always >= 35.
        grad_clip: ``None`` selects the per-arch default (2.0 for gcn, skip_gcn
            and sage; no clipping for gat, gatv2, evolvegcn). A value <= 0
            disables clipping. A positive value is the max gradient norm.

    Returns:
        dict with keys ``arch``, ``val_best_F1``, ``best_threshold``,
        ``test_metrics`` (P, R, F1, Acc, microF1), ``embeddings_var``,
        ``embeddings_shape``, ``lr`` and ``weight_decay``.

    Raises:
        ValueError: Unknown ``arch`` or ``epochs < 1``. Both are checked before
            any data is loaded.
        RuntimeError: ``arch="evolvegcn"``; the wrapper in this file raises
            because it has no static-graph forward path.

    Side effects:
        Prints progress, and stores the embedding array in this module's
        globals under ``"<arch>_embeddings"``.

    Usage:
        results = gnn_train("gatv2", epochs=500)
        results = gnn_train("gcn", lr=0.01)
        results = gnn_train("gat", grad_clip=None)
    """
    # Validate cheap arguments first so a typo does not cost a dataset load.
    arch_lc = arch.lower()
    defaults = {
        "gatv2": dict(lr=3e-3, wd=5e-4, grad_clip=None),
        "gat":   dict(lr=3e-3, wd=5e-4, grad_clip=None),
        "gcn":   dict(lr=1e-2, wd=5e-4, grad_clip=2.0),
        "skip_gcn": dict(lr=1e-2, wd=5e-4, grad_clip=2.0),
        "sage":  dict(lr=1e-2, wd=5e-4, grad_clip=2.0),
        "evolvegcn": dict(lr=5e-3, wd=5e-4, grad_clip=None),
    }
    if arch_lc not in defaults:
        raise ValueError("arch must be one of: gcn, skip_gcn, gat, gatv2, sage, evolvegcn")
    if epochs < 1:
        # Without at least one epoch no checkpoint is ever recorded to restore.
        raise ValueError("epochs must be >= 1")

    torch.manual_seed(seed); np.random.seed(seed)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data: temporal split + preprocess
    dataset = EllipticBitcoinDataset(root=root)
    data = dataset[0]
    # Only column 1 (the time step) of the raw CSV is needed, so read just that.
    # Assumes CSV row order matches the node order of ``data`` -- TODO confirm
    # against the dataset's processing code.
    steps_path = os.path.join(root, "raw", "elliptic_txs_features.csv")
    time_col = pd.read_csv(steps_path, header=None, usecols=[1]).iloc[:, 0]
    time_step_cpu = torch.from_numpy(time_col.values)  # CPU tensor for masks
    labeled_cpu = (data.y != 2).cpu()

    # Wider validation window for robust threshold transfer
    # val: 31–34, train: <31, test: >=35
    val_start = 31 if use_rolled_val else 32
    val_mask_cpu = (time_step_cpu >= val_start) & (time_step_cpu < 35) & labeled_cpu
    train_mask_cpu = (time_step_cpu < (31 if use_rolled_val else 35)) & labeled_cpu & (~val_mask_cpu)
    test_mask_cpu = (time_step_cpu >= 35) & labeled_cpu

    # Graph + self-loops (added once globally; the GAT layers do not add their own)
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.num_nodes)

    # Train-only z-score: statistics come from training nodes, then apply to all nodes.
    with torch.no_grad():
        tr = train_mask_cpu
        mu = data.x[tr].mean(0, keepdim=True)
        std = data.x[tr].std(0, keepdim=True).clamp_min(1e-6)  # guard constant columns
        data.x = ((data.x - mu) / std).to(torch.float)

    # Send to device
    data = data.to(device); data.y = data.y.long().to(device)
    data.train_mask = train_mask_cpu.to(device)
    data.val_mask   = val_mask_cpu.to(device)
    data.test_mask  = test_mask_cpu.to(device)

    # Model factory
    factory = {
        "gatv2":     lambda: GATv2Net(in_dim=data.num_features, hidden=128, heads=4, drop=0.5),
        "gat":       lambda: GATNet(in_dim=data.num_features,  hidden=128, heads=4, drop=0.5),
        "gcn":       lambda: GCNNet(in_dim=data.num_features,  hidden=128, drop=0.5),
        "skip_gcn":  lambda: SkipGCNNet(in_dim=data.num_features, hidden=128, drop=0.5),
        "sage":      lambda: SAGENet(in_dim=data.num_features,  hidden=256, drop=0.5),
        "evolvegcn": lambda: EvolveGCNWrapper(in_dim=data.num_features),
    }
    model = factory[arch_lc]().to(device)
    lr = defaults[arch_lc]["lr"] if lr is None else lr
    wd = defaults[arch_lc]["wd"] if wd is None else wd
    if grad_clip is None:
        grad_clip = defaults[arch_lc]["grad_clip"]
    elif grad_clip <= 0:
        # Explicit opt-out. A max norm of 0 would otherwise scale every gradient to zero.
        grad_clip = None

    # Loss/optim (class imbalance): weight class 1 by the neg/pos ratio of the train split.
    y_tr = data.y[data.train_mask]
    pos = int((y_tr == 1).sum())
    neg = int((y_tr == 0).sum())
    weights = torch.tensor([1.0, neg / max(1, pos)], dtype=torch.float, device=device)
    criterion = nn.CrossEntropyLoss(weight=weights)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Helpers
    @torch.no_grad()
    def best_threshold(logits, y_true):
        """Sweep thresholds on P(class 1); return (threshold, F1) of the best one."""
        p1 = logits.softmax(1)[:, 1]
        best_f1, best_t = -1.0, 0.5
        for t in torch.linspace(0.05, 0.95, steps=37, device=p1.device):
            yhat = (p1 >= t).long()
            _, _, F1 = _precision_recall_f1(*_positive_class_counts(yhat, y_true))
            if F1 > best_f1:  # strict: ties keep the lowest threshold
                best_f1, best_t = F1, float(t.item())
        return best_t, best_f1

    @torch.no_grad()
    def report(split_mask, thr):
        """Return P/R/F1/Acc of the current model on ``split_mask`` at threshold ``thr``."""
        logits = model(data.x, data.edge_index)[split_mask]
        y = data.y[split_mask]
        p1 = logits.softmax(1)[:, 1]
        yhat = (p1 >= thr).long()
        tp, fp, fn = _positive_class_counts(yhat, y)
        tn = ((yhat == 0) & (y == 0)).sum().item()
        P, R, F1 = _precision_recall_f1(tp, fp, fn)
        Acc = (tp + tn) / max(1, tp + tn + fp + fn)
        # Accuracy is reported under both keys.
        return dict(P=P, R=R, F1=F1, Acc=Acc, microF1=Acc)

    # Train (early stop on val F1)
    best = {"val_f1": -1.0, "thr": 0.5, "state": None}
    waited = 0
    for epoch in range(1, epochs+1):
        model.train(); opt.zero_grad()
        logits_all = model(data.x, data.edge_index)
        loss = criterion(logits_all[data.train_mask], data.y[data.train_mask])
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()

        if epoch % 5 == 0 or epoch == 1:
            model.eval()
            # Evaluation needs no autograd graph; skip building one.
            with torch.no_grad():
                val_logits = model(data.x, data.edge_index)[data.val_mask]
            thr, f1 = best_threshold(val_logits, data.y[data.val_mask])
            if f1 > best["val_f1"]:
                best = {
                    "val_f1": f1,
                    "thr": thr,
                    # CPU copy so later optimizer steps cannot alter the snapshot.
                    "state": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                }
                waited = 0
            else:
                waited += 5
            print(f"[{arch.upper():6}] Epoch {epoch:04d} | Loss {loss.item():.4f} | Val F1(illicit) {f1:.3f} | Thr {thr:.2f}")
            if waited >= patience:
                break

    # Test + embeddings (from the best validation checkpoint)
    model.load_state_dict(best["state"]); model.to(device); model.eval()
    test = report(data.test_mask, thr=best["thr"])
    with torch.no_grad():
        z = model.embed(data.x, data.edge_index).detach().cpu().numpy()

    # Publish the embeddings as a module-level variable for interactive use.
    var_name = f"{arch_lc}_embeddings"
    globals()[var_name] = z
    nice = arch.replace("_"," ").title().replace("Gcn","GCN").replace("Gat","GAT")
    print("Embeddings ready:"); print(f"{nice}: {z.shape}")

    if save_to_disk:
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, f"{arch_lc}_embeddings.npy"), z)

    print(f"\n[{arch.upper()}] TEST → F1(illicit): {test['F1']:.3f}, Precision: {test['P']:.3f}, "
          f"Recall: {test['R']:.3f}, Acc: {test['Acc']:.3f}, Micro-F1: {test['microF1']:.3f}")

    print("TIP: Call gnn_train('arch_name') where arch_name ∈ [gcn, skip_gcn, gat, gatv2, sage, evolvegcn]")

    return {
        "arch": arch_lc,
        "val_best_F1": best["val_f1"],
        "best_threshold": best["thr"],
        "test_metrics": test,
        "embeddings_var": var_name,
        "embeddings_shape": z.shape,
        "lr": lr,
        "weight_decay": wd,
    }

In [4]:
import pandas as pd

# 1) Train every architecture except EvolveGCN, keeping each run's result dict
arch_list = ["gcn", "skip_gcn", "gat", "gatv2", "sage"]
all_results = {}
for arch in arch_list:
    print(f"\n=== Training {arch.upper()} ===")
    all_results[arch] = gnn_train(arch)

# 2) Flatten the results into one row per model
rows = [
    {
        "Model": name.upper(),
        "F1": res["test_metrics"]["F1"],
        "Precision": res["test_metrics"]["P"],
        "Recall": res["test_metrics"]["R"],
        "Accuracy": res["test_metrics"]["Acc"],
        "Micro-F1": res["test_metrics"]["microF1"],
        "Val best F1": res["val_best_F1"],
        "Best thr": res["best_threshold"],
        "Embeddings": res["embeddings_shape"],
    }
    for name, res in all_results.items()
]

# Best test F1 first; the trailing bare name renders the table in the notebook.
df = pd.DataFrame(rows)
df = df.sort_values("F1", ascending=False).reset_index(drop=True)
df


=== Training GCN ===
[GCN   ] Epoch 0001 | Loss 0.6614 | Val F1(illicit) 0.428 | Thr 0.88
[GCN   ] Epoch 0005 | Loss 0.3461 | Val F1(illicit) 0.417 | Thr 0.95
[GCN   ] Epoch 0010 | Loss 0.2892 | Val F1(illicit) 0.522 | Thr 0.85
[GCN   ] Epoch 0015 | Loss 0.2556 | Val F1(illicit) 0.629 | Thr 0.90
[GCN   ] Epoch 0020 | Loss 0.2280 | Val F1(illicit) 0.689 | Thr 0.88
[GCN   ] Epoch 0025 | Loss 0.2106 | Val F1(illicit) 0.718 | Thr 0.82
[GCN   ] Epoch 0030 | Loss 0.1947 | Val F1(illicit) 0.713 | Thr 0.77
[GCN   ] Epoch 0035 | Loss 0.1837 | Val F1(illicit) 0.740 | Thr 0.70
[GCN   ] Epoch 0040 | Loss 0.1730 | Val F1(illicit) 0.746 | Thr 0.68
[GCN   ] Epoch 0045 | Loss 0.1635 | Val F1(illicit) 0.773 | Thr 0.62
[GCN   ] Epoch 0050 | Loss 0.1552 | Val F1(illicit) 0.801 | Thr 0.57
[GCN   ] Epoch 0055 | Loss 0.1455 | Val F1(illicit) 0.801 | Thr 0.62
[GCN   ] Epoch 0060 | Loss 0.1446 | Val F1(illicit) 0.801 | Thr 0.60
[GCN   ] Epoch 0065 | Loss 0.1368 | Val F1(illicit) 0.797 | Thr 0.62
[GCN   ] Epo

[GAT   ] Epoch 0180 | Loss 0.2869 | Val F1(illicit) 0.768 | Thr 0.90
[GAT   ] Epoch 0185 | Loss 0.2922 | Val F1(illicit) 0.771 | Thr 0.90
[GAT   ] Epoch 0190 | Loss 0.2965 | Val F1(illicit) 0.791 | Thr 0.88
[GAT   ] Epoch 0195 | Loss 0.2822 | Val F1(illicit) 0.782 | Thr 0.85
[GAT   ] Epoch 0200 | Loss 0.2817 | Val F1(illicit) 0.794 | Thr 0.88
[GAT   ] Epoch 0205 | Loss 0.2875 | Val F1(illicit) 0.799 | Thr 0.88
[GAT   ] Epoch 0210 | Loss 0.2821 | Val F1(illicit) 0.789 | Thr 0.88
[GAT   ] Epoch 0215 | Loss 0.2804 | Val F1(illicit) 0.792 | Thr 0.88
[GAT   ] Epoch 0220 | Loss 0.2801 | Val F1(illicit) 0.780 | Thr 0.93
[GAT   ] Epoch 0225 | Loss 0.2863 | Val F1(illicit) 0.785 | Thr 0.88
[GAT   ] Epoch 0230 | Loss 0.2713 | Val F1(illicit) 0.799 | Thr 0.93
[GAT   ] Epoch 0235 | Loss 0.2827 | Val F1(illicit) 0.797 | Thr 0.90
[GAT   ] Epoch 0240 | Loss 0.2703 | Val F1(illicit) 0.803 | Thr 0.93
[GAT   ] Epoch 0245 | Loss 0.2732 | Val F1(illicit) 0.794 | Thr 0.88
[GAT   ] Epoch 0250 | Loss 0.2726 

[GATV2 ] Epoch 0245 | Loss 0.2541 | Val F1(illicit) 0.808 | Thr 0.90
[GATV2 ] Epoch 0250 | Loss 0.2464 | Val F1(illicit) 0.794 | Thr 0.85
[GATV2 ] Epoch 0255 | Loss 0.2444 | Val F1(illicit) 0.791 | Thr 0.90
[GATV2 ] Epoch 0260 | Loss 0.2466 | Val F1(illicit) 0.802 | Thr 0.88
[GATV2 ] Epoch 0265 | Loss 0.2478 | Val F1(illicit) 0.790 | Thr 0.90
[GATV2 ] Epoch 0270 | Loss 0.2506 | Val F1(illicit) 0.798 | Thr 0.90
Embeddings ready:
GATv2: (203769, 64)

[GATV2] TEST → F1(illicit): 0.537, Precision: 0.494, Recall: 0.588, Acc: 0.934, Micro-F1: 0.934
TIP: Call gnn_train('arch_name') where arch_name ∈ [gcn, skip_gcn, gat, gatv2, sage, evolvegcn]

=== Training SAGE ===
[SAGE  ] Epoch 0001 | Loss 0.7016 | Val F1(illicit) 0.420 | Thr 0.20
[SAGE  ] Epoch 0005 | Loss 0.3386 | Val F1(illicit) 0.554 | Thr 0.70
[SAGE  ] Epoch 0010 | Loss 0.2570 | Val F1(illicit) 0.593 | Thr 0.90
[SAGE  ] Epoch 0015 | Loss 0.2057 | Val F1(illicit) 0.754 | Thr 0.95
[SAGE  ] Epoch 0020 | Loss 0.1713 | Val F1(illicit) 0.75

Unnamed: 0,Model,F1,Precision,Recall,Accuracy,Micro-F1,Val best F1,Best thr,Embeddings
0,SKIP_GCN,0.606314,0.609711,0.602955,0.94913,0.94913,0.86747,0.45,"(203769, 64)"
1,GATV2,0.537099,0.494182,0.588181,0.934133,0.934133,0.816169,0.875,"(203769, 64)"
2,GCN,0.535433,0.485714,0.596491,0.932753,0.932753,0.826291,0.725,"(203769, 64)"
3,GAT,0.534131,0.580087,0.494922,0.943911,0.943911,0.860258,0.9,"(203769, 64)"
4,SAGE,0.446101,0.323493,0.718375,0.884103,0.884103,0.90411,0.925,"(203769, 64)"
