## Common Functions

### Load data and split

In [105]:
# data split, copied from Sev's experiment ipynb
import pandas as pd

txs_classes = pd.read_csv('../data/elliptic_txs_classes.csv')
txs_edges = pd.read_csv('../data/elliptic_txs_edgelist.csv')
txs_features = pd.read_csv('../data/elliptic_txs_features.csv', header=None)


# join features with classes using tx id (1st column of txs_features)
txs_data = txs_features.merge(txs_classes, left_on=0, right_on='txId', how='left')

# convert class labels to integers
# 1: licit (0), 2: illicit (1), unknown: -1
label_mapping = {'1': 0, '2': 1, 'unknown': -1}
txs_data['class'] = txs_data['class'].map(label_mapping).astype(int)

# split data and edges into train and test according to timestep (2nd column of txs_features)
train_data_all = txs_data[txs_data[1] <= 34]
test_data_all = txs_data[txs_data[1] > 34]

# separate datasets with labels(1 or 2) from those without labels(class=unknown)
train_data_labeled = train_data_all[train_data_all['class'].isin([0, 1])]
test_data_labeled = test_data_all[test_data_all['class'].isin([0, 1])]

# process edges like data: add timestep info and split into train and test
txs_edges = txs_edges.merge(txs_features[[0, 1]], left_on='txId1', right_on=0, how='left').rename(columns={1: 'timestep'}).drop(columns=[0])
train_edges_all = txs_edges[txs_edges['timestep'] <= 34]
test_edges_all = txs_edges[txs_edges['timestep'] > 34]
train_edges_labeled = train_edges_all[train_edges_all['txId1'].isin(train_data_labeled['txId']) & train_edges_all['txId2'].isin(train_data_labeled['txId'])]
test_edges_labeled = test_edges_all[test_edges_all['txId1'].isin(test_data_labeled['txId']) & test_edges_all['txId2'].isin(test_data_labeled['txId'])]

# print sizes of datasets
print(f"Train data all: {train_data_all.shape}, Train data labeled: {train_data_labeled.shape}")
print(f"Test data all: {test_data_all.shape}, Test data labeled: {test_data_labeled.shape}")
print(f"Train edges all: {train_edges_all.shape}, Train edges labeled: {train_edges_labeled.shape}")
print(f"Test edges all: {test_edges_all.shape}, Test edges labeled: {test_edges_labeled.shape}")


Train data all: (136265, 169), Train data labeled: (29894, 169)
Test data all: (67504, 169), Test data labeled: (16670, 169)
Train edges all: (156843, 3), Train edges labeled: (22898, 3)
Test edges all: (77512, 3), Test edges labeled: (13726, 3)


In [106]:

# test data group dict, group by timestep for evaluation on each timestep
test_data_labeled_timestep = {}
test_edges_labeled_timestep = {}
for t in range(35, 50):
    test_data_labeled_timestep[t] = test_data_labeled[test_data_labeled[1] == t]
    test_edges_labeled_timestep[t] = test_edges_labeled[test_edges_labeled['timestep'] == t]
print(f"Test data labeled grouped by timestep: {[ (t, df.shape[0]) for t, df in test_data_labeled_timestep.items() ]}")
print(f"Test edges labeled grouped by timestep: {[ (t, df.shape[0]) for t, df in test_edges_labeled_timestep.items() ]}")

Test data labeled grouped by timestep: [(35, 1341), (36, 1708), (37, 498), (38, 756), (39, 1183), (40, 1211), (41, 1132), (42, 2154), (43, 1370), (44, 1591), (45, 1221), (46, 712), (47, 846), (48, 471), (49, 476)]
Test edges labeled grouped by timestep: [(35, 1002), (36, 1148), (37, 423), (38, 653), (39, 1055), (40, 1180), (41, 1048), (42, 1443), (43, 935), (44, 1497), (45, 1346), (46, 388), (47, 822), (48, 371), (49, 415)]


### Create DGL graph

In [107]:
# create DGL graphs - copied from Sev's experiment ipynb
import warnings
import dgl
import torch

# create DGL graphs for train and test data
def create_dgl_graph(data, edges, features="all"):
    # features: all or local
    # all: all features except txId, timestep, class; 
    # local: only local features (first 94 features except timestep) (column 2 to 95)
    node_ids = data['txId'].tolist()
    id_to_idx = {node_id: idx for idx, node_id in enumerate(node_ids)}
    
    src = edges['txId1'].map(id_to_idx).tolist()
    dst = edges['txId2'].map(id_to_idx).tolist()
    
    g = dgl.graph((src, dst), num_nodes=len(node_ids))
    if features == "local": # local features
        features = torch.tensor(data.iloc[:, 2: 96].values, dtype=torch.float32)
    else: # all features include local + one hop features
        features = torch.tensor(data.iloc[:, 2:-2].values, dtype=torch.float32)
    labels = torch.tensor(data['class'].values, dtype=torch.long)
    
    g.ndata['feat'] = features
    g.ndata['label'] = labels

    timestep_col = next((col for col in ('timestep', 1, '1') if col in data.columns), None)
    if timestep_col is not None:
        timesteps = torch.tensor(data[timestep_col].to_numpy(dtype='int64'), dtype=torch.long)
        g.ndata['timestep'] = timesteps
    else:
        warnings.warn(
            'create_dgl_graph could not find a timestep column; downstream modules requiring time context will fail.',
            RuntimeWarning,
        )
    
    return g

# graph with all features
train_labeled_graph = create_dgl_graph(train_data_labeled, train_edges_labeled)
test_labeled_graph = create_dgl_graph(test_data_labeled, test_edges_labeled)
train_all_graph = create_dgl_graph(train_data_all, train_edges_all)
test_all_graph = create_dgl_graph(test_data_all, test_edges_all)

print(f"Train labeled graph: {train_labeled_graph}")
print(f"Test labeled graph: {test_labeled_graph}")
print(f"Train all graph: {train_all_graph}")
print(f"Test all graph: {test_all_graph}")


# graph with local features only
train_labeled_graph_local = create_dgl_graph(train_data_labeled, train_edges_labeled, features="local")
test_labeled_graph_local = create_dgl_graph(test_data_labeled, test_edges_labeled, features="local")
train_all_graph_local = create_dgl_graph(train_data_all, train_edges_all, features="local")
test_all_graph_local = create_dgl_graph(test_data_all, test_edges_all, features="local")

print(f"Train labeled graph local features: {train_labeled_graph_local}")
print(f"Test labeled graph local features: {test_labeled_graph_local}")
print(f"Train all graph local features: {train_all_graph_local}")
print(f"Test all graph local features: {test_all_graph_local}")


Train labeled graph: Graph(num_nodes=29894, num_edges=22898,
      ndata_schemes={'feat': Scheme(shape=(165,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'timestep': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
Test labeled graph: Graph(num_nodes=16670, num_edges=13726,
      ndata_schemes={'feat': Scheme(shape=(165,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'timestep': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
Train all graph: Graph(num_nodes=136265, num_edges=156843,
      ndata_schemes={'feat': Scheme(shape=(165,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'timestep': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
Test all graph: Graph(num_nodes=67504, num_edges=77512,
      ndata_schemes={'feat': Scheme(shape=(165,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'timestep': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})


In [117]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [109]:
# Prepare_training
def prepare_train_graphs(labeled_only=True, local_feature=False):
    print(f"============= Preparing graphs: labeled_only={labeled_only}, local_feature={local_feature} =================")
    if labeled_only:
        if local_feature:
            train_graph = dgl.to_bidirected(train_labeled_graph_local, copy_ndata=True)
            test_graph = dgl.to_bidirected(test_labeled_graph_local, copy_ndata=True)
        else:
            train_graph = dgl.to_bidirected(train_labeled_graph, copy_ndata=True)
            test_graph = dgl.to_bidirected(test_labeled_graph, copy_ndata=True)
    else:
        if local_feature:
            train_graph = dgl.to_bidirected(train_all_graph_local, copy_ndata=True)
            test_graph = dgl.to_bidirected(test_all_graph_local, copy_ndata=True)
        else:
            train_graph = dgl.to_bidirected(train_all_graph, copy_ndata=True)
            test_graph = dgl.to_bidirected(test_all_graph, copy_ndata=True)

    # --- Check Train Graph ---
    train_degrees = train_graph.in_degrees()
    train_isolated_nodes = (train_degrees == 0).sum().item()
    train_total_nodes = train_graph.num_nodes()
    train_percent_isolated = (train_isolated_nodes / train_total_nodes) * 100

    print(f"Training Graph: Percentage isolated: {train_percent_isolated:.2f}%; Isolated nodes: {train_isolated_nodes}; Total nodes: {train_total_nodes}")

    # --- Check Test Graph ---
    test_degrees = test_graph.in_degrees()
    test_isolated_nodes = (test_degrees == 0).sum().item()
    test_total_nodes = test_graph.num_nodes()
    test_percent_isolated = (test_isolated_nodes / test_total_nodes) * 100

    print(f"Test Graph: Percentage isolated: {test_percent_isolated:.2f}%; Total nodes: {test_total_nodes}; Isolated nodes: {test_isolated_nodes}\n")

    train_graph = dgl.add_self_loop(train_graph)
    test_graph = dgl.add_self_loop(test_graph)

    train_features = train_graph.ndata['feat']
    train_labels = train_graph.ndata['label']
    train_mask = (train_labels >= 0)
    test_features = test_graph.ndata['feat']
    test_labels = test_graph.ndata['label']
    test_mask = (test_labels >= 0)

    train_graph = train_graph.to(device)
    test_graph = test_graph.to(device)
    train_features = train_features.to(device)
    train_labels = train_labels.to(device)
    test_features = test_features.to(device)
    test_labels = test_labels.to(device)

    return train_graph, train_features, train_labels, train_mask, test_graph, test_features, test_labels, test_mask

train_graph_labeled, train_feature_labeled, train_labels_labeled, train_mask_labeled, \
    test_graph_labeled, test_features_labeled, test_labels_labeled, test_mask_labeled \
        = prepare_train_graphs(labeled_only=True, local_feature=False)

train_graph_all, train_feature_all, train_labels_all, train_mask_all, \
    test_graph_all, test_features_all, test_labels_all, test_mask_all \
        = prepare_train_graphs(labeled_only=False, local_feature=False)

train_graph_labeled_local, train_feature_labeled_local, train_labels_labeled_local, train_mask_labeled_local, \
    test_graph_labeled_local, test_features_labeled_local, test_labels_labeled_local, test_mask_labeled_local \
        = prepare_train_graphs(labeled_only=True, local_feature=True)

train_graph_all_local, train_feature_all_local, train_labels_all_local, train_mask_all_local, \
    test_graph_all_local, test_features_all_local, test_labels_all_local, test_mask_all_local \
        = prepare_train_graphs(labeled_only=False, local_feature=True)


Training Graph: Percentage isolated: 21.46%; Isolated nodes: 6415; Total nodes: 29894
Test Graph: Percentage isolated: 25.64%; Total nodes: 16670; Isolated nodes: 4275

Training Graph: Percentage isolated: 0.00%; Isolated nodes: 0; Total nodes: 136265
Test Graph: Percentage isolated: 0.00%; Total nodes: 67504; Isolated nodes: 0

Training Graph: Percentage isolated: 21.46%; Isolated nodes: 6415; Total nodes: 29894
Test Graph: Percentage isolated: 25.64%; Total nodes: 16670; Isolated nodes: 4275

Training Graph: Percentage isolated: 0.00%; Isolated nodes: 0; Total nodes: 136265
Test Graph: Percentage isolated: 0.00%; Total nodes: 67504; Isolated nodes: 0



In [110]:
# time steped testing graphs
test_labeled_graphs_timestep = {}
for t in range(35, 50):
    data_t = test_data_labeled_timestep[t]
    edges_t = test_edges_labeled_timestep[t]
    test_graph_t = create_dgl_graph(data_t, edges_t)
    test_graph_t = dgl.to_bidirected(test_graph_t, copy_ndata=True)
    test_graph_t = dgl.add_self_loop(test_graph_t)
    test_graph_t = test_graph_t.to(device)
    test_feature_t = test_graph_t.ndata['feat'].to(device)
    test_label_t = test_graph_t.ndata['label'].to(device)
    test_mask_t = (test_label_t >= 0).to(device)
    test_labeled_graphs_timestep[t] = (test_graph_t, test_feature_t, test_label_t, test_mask_t)

# test stepped testing graphs local features
# test_labeled_graphs_timestep_local = {}
# for t in range(35, 50):
#     data_t = test_data_labeled_timestep[t]
#     edges_t = test_edges_labeled_timestep[t]
#     test_labeled_graphs_timestep_local[t] = create_dgl_graph(data_t, edges_t, features="local")

### Training function

In [111]:
from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_fscore_support
import torch.nn as nn
import torch
import matplotlib.pyplot as plt

def evaluate_model(model, test_graph, test_features, test_labels, test_mask, criterion):
    model.eval()
    with torch.no_grad():
        test_logits = model(test_graph, test_features)
        test_preds = test_logits.argmax(dim=1)
        test_loss = criterion(test_logits, test_labels).item()
        
        # Get masked predictions and labels
        masked_preds = test_preds[test_mask].cpu().numpy()
        masked_labels = test_labels[test_mask].cpu().numpy()
        
        # Use sklearn's precision_recall_fscore_support with zero_division handling
        # pos_label=0 for illicit class, average=None to get per-class metrics
        precision, recall, f1, support = precision_recall_fscore_support(
            masked_labels, masked_preds, 
            labels=[0, 1],  # illicit=0, licit=1
            average=None,
            zero_division=0.0
        )
        
        # Extract illicit (class 0) metrics
        test_precision = precision[0] if len(precision) > 0 else 0.0
        test_recall = recall[0] if len(recall) > 0 else 0.0
        test_f1 = f1[0] if len(f1) > 0 else 0.0
        
        test_report = classification_report(
            masked_labels, masked_preds, 
            target_names=['illicit', 'licit'],
            zero_division=0
        )
        
        return test_loss, test_precision, test_recall, test_f1, test_report

def plot_training_history(history, model_name):

    # --- Create X-Axes ---
    
    # 1. X-axis for 'train_loss' (1 point per epoch)
    # We add a small offset (0.5) to starting_epoch for clearer plotting
    # if we resume training, so the first point isn't hidden.
    starting_epoch = history.get('test_epochs', [1])[0] - 1
    total_train_epochs = len(history.get('train_loss', []))
    train_loss_epochs = list(range(starting_epoch + 1, starting_epoch + 1 + total_train_epochs))

    # 2. X-axis for all test/validation metrics (sparse)
    # This list is saved directly in our history object
    eval_epochs = history.get('test_epochs', [])

    plt.figure(figsize=(15, 6))

    # --- Plot 1: Loss (Train vs. Test) ---
    plt.subplot(1, 2, 1)
    
    # Plot Train Loss (dense)
    if train_loss_epochs:
        plt.plot(train_loss_epochs, history['train_loss'], label='Train Loss', alpha=0.7, zorder=1)
    
    # Plot Test Loss (sparse)
    if eval_epochs:
        plt.plot(eval_epochs, history['test_loss'], label='Test Loss', 
                 marker='o', linestyle='--', linewidth=2, markersize=5, zorder=2)
        
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(f'{model_name} Loss over Epochs')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)

    # --- Plot 2: F1 Score (Train vs. Test) ---
    plt.subplot(1, 2, 2)
    
    # Plot Train F1 (sparse, from eval steps)
    if eval_epochs and 'train_f1' in history:
        plt.plot(eval_epochs, history['train_f1'], label='Train F1 Score', 
                 marker='s', linestyle=':', linewidth=2, markersize=5)
    
    # Plot Test F1 (sparse, from eval steps)
    if eval_epochs and 'test_f1' in history:
        plt.plot(eval_epochs, history['test_f1'], label='Test F1 Score', 
                 color='orange', marker='o', linestyle='--', linewidth=2, markersize=5)
        
    plt.xlabel('Epochs')
    plt.ylabel('F1 Score')
    plt.title(f'{model_name} F1 Score at Evaluation Steps')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()

def train_model(model, model_name, optimizer, criterion, train_graph, train_features, train_labels, train_mask,
                test_graph, test_features, test_labels, test_mask,
                num_epochs, test_every=100, previous_history=None, print_best_report=True, show_plots=False,
                early_stopping_patience=None, checkpoint_path=None):
    # if previous_history is provided, resume training from there
    if previous_history is not None:
        history = previous_history
        starting_epoch = len(history["train_loss"]) + 1
    else:
        history = {"train_loss": [], "train_f1": [], "train_precision": [], "train_recall": [],
               "test_loss": [], "test_f1": [], "test_precision": [], "test_recall": [],
               "test_epochs": [],
               "best_test_f1": 0.0, "best_report": None, "best_model_state": None, "best_epoch": -1, 
               "last_test_f1": 0.0, "last_report": None, "latest_model_state": None}
        starting_epoch = 1

    epochs_since_improvement = 0
    checkpoint_target = checkpoint_path
    if checkpoint_target:
        checkpoint_dir = os.path.dirname(checkpoint_target)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    final_epoch_of_training = starting_epoch + num_epochs - 1
    # training 
    for epoch in range(starting_epoch, final_epoch_of_training + 1):
        model.train()
        logits = model(train_graph, train_features)
        loss = criterion(logits, train_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        history["train_loss"].append(loss.item())

        if epoch == 1 or epoch % test_every == 0 or epoch == final_epoch_of_training:
            history["test_epochs"].append(epoch)
            train_pred = logits.argmax(dim=1)
            train_precision = ((train_pred[train_mask] == 0) & (train_labels[train_mask] == 0)).sum().item() / (train_pred[train_mask] == 0).sum().item()
            train_recall = ((train_pred[train_mask] == 0) & (train_labels[train_mask] == 0)).sum().item() / (train_labels[train_mask] == 0).sum().item()
            train_f1 = 2 * train_precision * train_recall / (train_precision + train_recall)
            history["train_f1"].append(train_f1)
            history["train_precision"].append(train_precision)
            history["train_recall"].append(train_recall)
            
            test_loss, test_precision, test_recall, test_f1, test_report = evaluate_model(model, test_graph, test_features, test_labels, test_mask, criterion)
            history["test_loss"].append(test_loss)
            history["test_f1"].append(test_f1)
            history["test_precision"].append(test_precision)
            history["test_recall"].append(test_recall)
            print(f"Epoch {epoch:03d}: Loss {loss.item():.4f}, Train F1 {train_f1:.4f}, Test Loss {test_loss:.4f}, Test F1 {test_f1:.4f}")
            improved = False
        
            if test_f1 > history["best_test_f1"]:
                history["best_test_f1"] = test_f1
                history["best_report"] = test_report
                history["best_epoch"] = epoch
                history["best_model_state"] = model.state_dict()
                improved = True
                if checkpoint_target:
                    torch.save(history["best_model_state"], checkpoint_target)
            if early_stopping_patience is not None:
                if improved:
                    epochs_since_improvement = 0
                else:
                    epochs_since_improvement += 1
                    if epochs_since_improvement >= early_stopping_patience:
                        print(f"Early stopping at epoch {epoch} after {early_stopping_patience} eval steps without improvement")
                        break

    # final evaluation on test set
    # final_test_loss, final_test_precision, final_test_recall, final_test_f1, final_test_report = evaluate_model(model, test_graph, test_features, test_labels, test_mask, criterion)
    history["latest_model_state"] = model.state_dict()
    history["last_test_f1"] = history["test_f1"][-1]
    history["last_report"] = test_report

    print(f"{model_name} Last Classification Report on Labeled Test Graph:")
    print(history["last_report"])

    if print_best_report:
        print(f"{model_name} Best Classification Report on Labeled Test Graph at epoch {history['best_epoch']}:")
        print(history["best_report"])

    if show_plots:
        plot_training_history(history, model_name)

    return history

In [112]:

from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

# Evaluate a trained model on per-timestep labeled test graphs and aggregate before/after t=43.
# Uses existing variables: test_labeled_graphs_timestep, gt (model), device, classification_report, dgl, torch, criterion
def evaluate_timestep(model_to_test, feature_builder=None):
    model_to_test.eval()

    if feature_builder is None:
        def _identity_builder(_, feats):
            return feats
        feature_builder = _identity_builder

    per_t_metrics = {}
    all_preds_before = []
    all_trues_before = []
    all_preds_after = []
    all_trues_after = []
    f1_list = []

    for t in sorted(test_labeled_graphs_timestep.keys()):
        g_eval, feats, labels, mask = test_labeled_graphs_timestep[t]
        if g_eval is None or g_eval.num_nodes() == 0:
            print(f"t={t}: empty graph, skipping")
            continue

        feats_for_model = feature_builder(g_eval, feats)
        loss, precision, recall, f1, report = evaluate_model(model_to_test, g_eval, feats_for_model, labels, mask, criterion)
        f1_list.append(f1)
        per_t_metrics[t] = {
            "n_nodes": int(g_eval.num_nodes()),
            "loss": loss,
            "precision_illicit": precision,
            "recall_illicit": recall,
            "f1_illicit": f1,
            "classification_report": report
        }

        # record masked count
        per_t_metrics[t]["n_masked"] = int(mask.sum().item())

        # get predictions for masked nodes
        with torch.no_grad():
            logits = model_to_test(g_eval, feats_for_model)
            preds = logits.argmax(dim=1)

        masked_preds = preds[mask]
        masked_labels = labels[mask]

        # accumulate for before/after 43
        arr_preds = masked_preds.cpu().numpy()
        arr_trues = masked_labels.cpu().numpy()
        if t < 43:
            all_preds_before.append(arr_preds)
            all_trues_before.append(arr_trues)
        else:
            all_preds_after.append(arr_preds)
            all_trues_after.append(arr_trues)

    # --- After the loop: aggregate and print reports ---
    def aggregate_and_report(preds_list, trues_list, label=""):
        if not preds_list:
            print(f"No data for {label}")
            return None
        preds_all = np.concatenate(preds_list)
        trues_all = np.concatenate(trues_list)
        # use sklearn to compute illicit-focused metrics (class 0 == illicit)

        precision = float(precision_score(trues_all, preds_all, pos_label=0, zero_division=0))
        recall = float(recall_score(trues_all, preds_all, pos_label=0, zero_division=0))
        f1 = float(f1_score(trues_all, preds_all, pos_label=0, zero_division=0))
        report = classification_report(trues_all, preds_all, target_names=['illicit', 'licit'], zero_division=0)
        return {"precision": precision, "recall": recall, "f1": f1, "report": report}

    agg_before = aggregate_and_report(all_preds_before, all_trues_before, label="t < 43")
    agg_after  = aggregate_and_report(all_preds_after,  all_trues_after,  label="t >= 43")

    return {
        "f1_list": f1_list,
        "agg_before_f1": agg_before["f1"] if agg_before else None,
        "agg_after_f1": agg_after["f1"] if agg_after else None,
        "per_t_metrics": per_t_metrics,
        "agg_before": agg_before,
        "agg_after": agg_after
    }


In [113]:

def report_timestep_performance(model_name, model, feature_builder=None):
    metrics = evaluate_timestep(model, feature_builder=feature_builder)
    print(f"[{model_name}] per-timestep illicit metrics:")
    for t in sorted(metrics['per_t_metrics'].keys()):
        metric = metrics['per_t_metrics'][t]
        print(f"  t={t:02d} | nodes={metric['n_masked']:5d} | Loss={metric['loss']:.4f} | P={metric['precision_illicit']:.4f} | R={metric['recall_illicit']:.4f} | F1={metric['f1_illicit']:.4f}")
    agg_before = metrics.get('agg_before')
    if agg_before:
        print(f"  Aggregate t<43 | P={agg_before['precision']:.4f} | R={agg_before['recall']:.4f} | F1={agg_before['f1']:.4f}")
    agg_after = metrics.get('agg_after')
    if agg_after:
        print(f"  Aggregate t>=43 | P={agg_after['precision']:.4f} | R={agg_after['recall']:.4f} | F1={agg_after['f1']:.4f}")
    print()
    return metrics


In [114]:
import json
import os
import numpy as np
import torch
import torch


def save_history(history, sub_dir, save_dir = "checkpoints"):
    # Save training history (JSON-safe) and model states (.pt)
    os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)

    # Extract and remove model-state entries from history before JSON serialization
    best_state = history.get("best_model_state", None)
    latest_state = history.get("latest_model_state", None)

    history_copy = {k: v for k, v in history.items() if k not in ("best_model_state", "latest_model_state")}

    history_path = os.path.join(save_dir, sub_dir, "history.json")
    with open(history_path, "w") as f:
        json.dump(history_copy, f, indent=2)

    # Save model state dicts (if present)

    if best_state is not None:
        torch.save(best_state, os.path.join(save_dir, sub_dir, "best_model_state.pt"))
    if latest_state is not None:
        torch.save(latest_state, os.path.join(save_dir, sub_dir, "latest_model_state.pt"))

    print("Saved:")
    print(f" - History JSON: {history_path}")
    if best_state is not None:
        print(f" - Best state: {os.path.join(save_dir, sub_dir, 'best_model_state.pt')}")
    if latest_state is not None:
        print(f" - Latest state: {os.path.join(save_dir, sub_dir, 'latest_model_state.pt')}")

def load_history(sub_dir, save_dir="checkpoints", model=None, map_location=None):
    """
    Load saved training history and model state dicts from disk.

    Args:
        sub_dir (str): subdirectory under save_dir where files are stored.
        save_dir (str): root checkpoints directory (default "checkpoints").
        model (nn.Module, optional): if provided and a gt_state_dict exists, it will be loaded into this model.
        map_location (str or torch.device, optional): passed to torch.load (default "cpu" if None).

    Returns:
        dict: {
            "history": dict or None,
            "best_model_state": state_dict or None,
            "latest_model_state": state_dict or None,
            "gt_state_dict": state_dict or None,
            "model_loaded": bool
        }
    """
    base_path = os.path.join(save_dir, sub_dir)
    if map_location is None:
        map_location = "cpu"

    result = {}

    if not os.path.isdir(base_path):
        raise FileNotFoundError(f"Directory not found: {base_path}")

    # history JSON (name used by save_history)
    history_path = os.path.join(base_path, "history.json")
    if os.path.exists(history_path):
        with open(history_path, "r") as f:
            result = json.load(f)

    def _load_pt(fname):
        p = os.path.join(base_path, fname)
        if os.path.exists(p):
            return torch.load(p, map_location=map_location)
        return None

    result["best_model_state"] = _load_pt("best_model_state.pt")
    result["latest_model_state"] = _load_pt("latest_model_state.pt")

    # brief prints for confirmation
    print(f"Loaded history from {base_path}")
    return result

Random Forest Baseline

In [69]:
from sklearn.ensemble import RandomForestClassifier

# licit node class=1, illicit node class=0
# train and evaluate a random forest classifier on the train and test data with labels
# n_estimators=50, max_features=50
# evaluate on both licit and illicit nodes' precision and recall and f1-score, also include micro and macro averages
clf = RandomForestClassifier(n_estimators=50, max_features=50, random_state=42)
clf.fit(train_data_labeled.iloc[:, 1:-2], train_data_labeled['class'])
test_preds = clf.predict(test_data_labeled.iloc[:, 1:-2]) 
from sklearn.metrics import classification_report
report = classification_report(test_data_labeled['class'], test_preds, target_names=['illicit', 'licit'])
print("Random Forest Classifier Report on Labeled Test Data:")
print(report)

Random Forest Classifier Report on Labeled Test Data:
              precision    recall  f1-score   support

     illicit       0.91      0.73      0.81      1083
       licit       0.98      1.00      0.99     15587

    accuracy                           0.98     16670
   macro avg       0.95      0.86      0.90     16670
weighted avg       0.98      0.98      0.98     16670



## GCN

### Model Definition

In [70]:
from dgl.nn import GraphConv

embedding_dim = 100

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hidden_size, allow_zero_in_degree=True)
        self.conv2 = GraphConv(hidden_size, num_classes, allow_zero_in_degree=True)

    def forward(self, g, feat):
        h = self.conv1(g, feat)
        h = torch.relu(h)
        h = self.conv2(g, h)
        return h

### Training without unknown nodes

In [71]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
gcn_labeled = GCN(train_feature_labeled.shape[1], embedding_dim, 2).to(device)
gcn_labeled_optimizer = torch.optim.Adam(gcn_labeled.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_gcn_labeled = train_model(
    gcn_labeled,
    "GCN - Labeled Only",
    gcn_labeled_optimizer,
    criterion,
    train_graph_labeled,
    train_feature_labeled,
    train_labels_labeled,
    train_mask_labeled,
    test_graph_labeled,
    test_features_labeled,
    test_labels_labeled,
    test_mask_labeled,
    num_epochs=800,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path="checkpoints/GCN/base_labeled/best_model_state.pt"
)

save_history(history_gcn_labeled, sub_dir="GCN_base_labeled")

gcn_labeled_timestep_metrics = report_timestep_performance("GCN - Labeled Only", gcn_labeled)


Epoch 001: Loss 0.4449, Train F1 0.0086, Test Loss 0.3614, Test F1 0.2572
Epoch 050: Loss 0.1756, Train F1 0.7916, Test Loss 0.2952, Test F1 0.4281
Epoch 100: Loss 0.1390, Train F1 0.8346, Test Loss 0.2781, Test F1 0.5230
Epoch 150: Loss 0.1190, Train F1 0.8612, Test Loss 0.2864, Test F1 0.5432
Epoch 200: Loss 0.1053, Train F1 0.8763, Test Loss 0.2971, Test F1 0.5420
Epoch 250: Loss 0.0947, Train F1 0.8900, Test Loss 0.3105, Test F1 0.5469
Epoch 300: Loss 0.0862, Train F1 0.8976, Test Loss 0.3254, Test F1 0.5529
Epoch 350: Loss 0.0794, Train F1 0.9048, Test Loss 0.3406, Test F1 0.5605
Epoch 400: Loss 0.0739, Train F1 0.9127, Test Loss 0.3553, Test F1 0.5593
Epoch 450: Loss 0.0694, Train F1 0.9191, Test Loss 0.3679, Test F1 0.5606
Epoch 500: Loss 0.0657, Train F1 0.9235, Test Loss 0.3789, Test F1 0.5544
Epoch 550: Loss 0.0625, Train F1 0.9267, Test Loss 0.3875, Test F1 0.5489
Epoch 600: Loss 0.0597, Train F1 0.9299, Test Loss 0.3940, Test F1 0.5508
Epoch 650: Loss 0.0572, Train F1 0.933

### Training with unknown nodes

In [72]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
gcn_all = GCN(train_feature_all.shape[1], embedding_dim, 2).to(device)
gcn_all_optimizer = torch.optim.Adam(gcn_all.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_gcn_all = train_model(
    gcn_all,
    "GCN - All Nodes",
    gcn_all_optimizer,
    criterion,
    train_graph_all,
    train_feature_all,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    test_features_all,
    test_labels_all,
    test_mask_all,
    num_epochs=800,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path="checkpoints/GCN/base_all/best_model_state.pt"
)

save_history(history_gcn_all, sub_dir="GCN_base_all")

gcn_all_timestep_metrics = report_timestep_performance("GCN - All Nodes", gcn_all)


Epoch 001: Loss 0.9807, Train F1 0.2198, Test Loss 0.9828, Test F1 0.1355
Epoch 050: Loss 0.2360, Train F1 0.7031, Test Loss 0.4017, Test F1 0.3109
Epoch 100: Loss 0.1978, Train F1 0.7648, Test Loss 0.3351, Test F1 0.3862
Epoch 150: Loss 0.1752, Train F1 0.7952, Test Loss 0.2977, Test F1 0.4368
Epoch 200: Loss 0.1600, Train F1 0.8126, Test Loss 0.2815, Test F1 0.4541
Epoch 250: Loss 0.1487, Train F1 0.8240, Test Loss 0.2761, Test F1 0.4643
Epoch 300: Loss 0.1399, Train F1 0.8364, Test Loss 0.2757, Test F1 0.4779
Epoch 350: Loss 0.1327, Train F1 0.8470, Test Loss 0.2772, Test F1 0.4849
Epoch 400: Loss 0.1266, Train F1 0.8544, Test Loss 0.2788, Test F1 0.4980
Epoch 450: Loss 0.1212, Train F1 0.8628, Test Loss 0.2795, Test F1 0.5211
Epoch 500: Loss 0.1164, Train F1 0.8675, Test Loss 0.2806, Test F1 0.5505
Epoch 550: Loss 0.1120, Train F1 0.8740, Test Loss 0.2839, Test F1 0.5559
Epoch 600: Loss 0.1080, Train F1 0.8787, Test Loss 0.2886, Test F1 0.5580
Epoch 650: Loss 0.1043, Train F1 0.882

## BGRL + GCN

### Pretraining Setup

In [73]:
import copy

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv

# BGRL self-supervised pretraining on the full training set
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ssl_embedding_dim = 100
pretrain_epochs = 30
edge_drop_prob = 0.2
feat_drop_prob = 0.1
momentum = 0.99
ssl_lr = 1e-3
raw_feature_count = 94

base_pretrain_graph = train_all_graph.to(device)
pretrain_features = base_pretrain_graph.ndata['feat']
num_total_features = pretrain_features.shape[1]
agg_feature_count = num_total_features - raw_feature_count
if agg_feature_count < 0:
    raise ValueError("raw_feature_count is larger than available feature dimensions")

print(f"Train graph for SSL: {base_pretrain_graph}")

# Pre-compute adjacency helper so edge dropout never isolates a node
src_full, dst_full = base_pretrain_graph.edges()
edge_indices = torch.arange(base_pretrain_graph.num_edges(), device=device)
node_edge_nodes = torch.cat([src_full, dst_full])
node_edge_edges = torch.cat([edge_indices, edge_indices])
order = torch.argsort(node_edge_nodes)
node_edge_nodes = node_edge_nodes[order]
node_edge_edges = node_edge_edges[order]
node_ptr = torch.searchsorted(node_edge_nodes, torch.arange(base_pretrain_graph.num_nodes() + 1, device=device))
edge_helper = {
    'src': src_full,
    'dst': dst_full,
    'node_edge_nodes': node_edge_nodes,
    'node_edge_edges': node_edge_edges,
    'node_ptr': node_ptr,
}

Train graph for SSL: Graph(num_nodes=136265, num_edges=156843,
      ndata_schemes={'feat': Scheme(shape=(165,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})


### Augmentation Utilities

In [74]:
def apply_feature_masks(features):
    raw_part = features[:, :raw_feature_count]
    raw_mask = torch.empty_like(raw_part).uniform_(0.2, 0.4)
    masked_raw = raw_part * raw_mask
    if agg_feature_count == 0:
        return masked_raw
    agg_part = features[:, raw_feature_count:]
    agg_mask = torch.empty_like(agg_part).uniform_(0.05, 0.2)
    masked_agg = agg_part * agg_mask
    return torch.cat([masked_raw, masked_agg], dim=1)


def ensure_connectivity(keep_mask):
    kept_idx = torch.nonzero(keep_mask, as_tuple=False).squeeze(1)
    if kept_idx.numel() == 0:
        random_edge = torch.randint(0, keep_mask.shape[0], (1,), device=keep_mask.device)
        keep_mask[random_edge] = True
        kept_idx = random_edge
    num_nodes = base_pretrain_graph.num_nodes()
    deg = torch.zeros(num_nodes, device=keep_mask.device, dtype=torch.int64)
    deg.scatter_add_(0, src_full[kept_idx], torch.ones_like(kept_idx, dtype=torch.int64))
    deg.scatter_add_(0, dst_full[kept_idx], torch.ones_like(kept_idx, dtype=torch.int64))
    zero_nodes = (deg == 0).nonzero(as_tuple=False).squeeze(1)
    if zero_nodes.numel() == 0:
        return keep_mask
    node_edge_nodes = edge_helper['node_edge_nodes']
    node_edge_edges = edge_helper['node_edge_edges']
    node_ptr = edge_helper['node_ptr']
    for node in zero_nodes.tolist():
        start = int(node_ptr[node].item())
        end = int(node_ptr[node + 1].item())
        if start == end:
            continue
        candidates = node_edge_edges[start:end]
        chosen = candidates[torch.randint(0, candidates.shape[0], (1,), device=candidates.device)]
        keep_mask[chosen] = True
    return keep_mask


def random_edge_dropout_preserve(drop_prob):
    if drop_prob <= 0 or base_pretrain_graph.num_edges() == 0:
        return base_pretrain_graph
    mask = torch.rand(src_full.shape[0], device=device) > drop_prob
    mask = ensure_connectivity(mask)
    kept_idx = torch.nonzero(mask, as_tuple=False).squeeze(1)
    aug = dgl.graph((src_full[kept_idx], dst_full[kept_idx]), num_nodes=base_pretrain_graph.num_nodes(), device=device)
    return aug


def graph_augment():
    aug_graph = random_edge_dropout_preserve(edge_drop_prob)
    aug_graph = dgl.add_self_loop(aug_graph)
    masked_feats = apply_feature_masks(pretrain_features)
    if feat_drop_prob > 0:
        masked_feats = F.dropout(masked_feats, p=feat_drop_prob, training=True)
    return aug_graph, masked_feats

### BGRL Modules

In [75]:
class Projector(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.PReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)


class Predictor(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.PReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        return self.net(x)


class BGRL(nn.Module):
    def __init__(self, encoder, hidden_dim, momentum=0.99):
        super().__init__()
        self.online_encoder = encoder
        self.online_projector = Projector(hidden_dim)
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.target_projector.parameters():
            p.requires_grad = False
        self.predictor = Predictor(hidden_dim)
        self.momentum = momentum

    @torch.no_grad()
    def update_target(self):
        for target_param, online_param in zip(self.target_encoder.parameters(), self.online_encoder.parameters()):
            target_param.data = self.momentum * target_param.data + (1 - self.momentum) * online_param.data
        for target_param, online_param in zip(self.target_projector.parameters(), self.online_projector.parameters()):
            target_param.data = self.momentum * target_param.data + (1 - self.momentum) * online_param.data

    def loss_fn(self, p, z):
        p = F.normalize(p, dim=1)
        z = F.normalize(z.detach(), dim=1)
        return 2 - 2 * (p * z).sum(dim=1).mean()

    def forward(self, g1, x1, g2, x2):
        h1 = self.online_encoder(g1, x1)
        h2 = self.online_encoder(g2, x2)
        z1 = self.online_projector(h1)
        z2 = self.online_projector(h2)
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        with torch.no_grad():
            t1 = self.target_projector(self.target_encoder(g1, x1))
            t2 = self.target_projector(self.target_encoder(g2, x2))
        return self.loss_fn(p1, t2) + self.loss_fn(p2, t1)

### Pretraining Loop

In [76]:
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

ssl_model = BGRL(GCN(num_total_features, ssl_embedding_dim, ssl_embedding_dim).to(device), ssl_embedding_dim, momentum=momentum).to(device)
optimizer = torch.optim.Adam(ssl_model.parameters(), lr=ssl_lr, weight_decay=1e-4)

print("Starting BGRL pretraining...")
for epoch in range(1, pretrain_epochs + 1):
    g1, x1 = graph_augment()
    g2, x2 = graph_augment()
    loss = ssl_model(g1, x1, g2, x2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ssl_model.update_target()

    if epoch == 1 or epoch % 20 == 0:
        print(f"SSL Epoch {epoch:04d} | Loss: {loss.item():.4f}")

pretrained_encoder = ssl_model.online_encoder
pretrained_encoder.eval()
for param in pretrained_encoder.parameters():
    param.requires_grad = False

print("Finished BGRL pretraining. The frozen encoder is available as `pretrained_encoder`.")

Starting BGRL pretraining...
SSL Epoch 0001 | Loss: 4.0515
SSL Epoch 0020 | Loss: 0.2649
Finished BGRL pretraining. The frozen encoder is available as `pretrained_encoder`.


### BGRL Embeddings + Raw Features + GCN

In [77]:

import torch
import torch.nn as nn

if 'pretrained_encoder' not in globals():
    raise RuntimeError('Run the BGRL pretraining cells first to populate `pretrained_encoder`.')

pretrained_encoder = pretrained_encoder.to(device).eval()
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)

with torch.no_grad():
    ssl_train_embeddings = pretrained_encoder(train_graph_all, train_feature_all).detach()
    ssl_test_embeddings = pretrained_encoder(test_graph_all, test_features_all).detach()

ssl_aug_train = torch.cat([train_feature_all, ssl_train_embeddings], dim=1)
ssl_aug_test = torch.cat([test_features_all, ssl_test_embeddings], dim=1)

bgrl_concat_gcn = GCN(ssl_aug_train.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(bgrl_concat_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_bgrl_concat = train_model(
    bgrl_concat_gcn,
    'BGRL Embeddings + Raw Features',
    optimizer,
    criterion,
    train_graph_all,
    ssl_aug_train,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    ssl_aug_test,
    test_labels_all,
    test_mask_all,
    num_epochs=600,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path='checkpoints/BGRL/gcn_concat/best_model_state.pt'
)

save_history(history_bgrl_concat, sub_dir='BGRL_gcn_concat')


def build_bgrl_concat_features(g_eval, feats):
    g_eval = g_eval.to(device)
    feats = feats.to(device)
    with torch.no_grad():
        ssl_embeds = pretrained_encoder(g_eval, feats)
    return torch.cat([feats, ssl_embeds], dim=1)


bgrl_concat_timestep_metrics = report_timestep_performance(
    "BGRL Embeddings + Raw Features",
    bgrl_concat_gcn,
    feature_builder=build_bgrl_concat_features
)



Epoch 001: Loss 0.5833, Train F1 0.2572, Test Loss 0.4547, Test F1 0.1174
Epoch 050: Loss 0.2051, Train F1 0.7528, Test Loss 0.2981, Test F1 0.3854
Epoch 100: Loss 0.1682, Train F1 0.7954, Test Loss 0.2755, Test F1 0.4434
Epoch 150: Loss 0.1438, Train F1 0.8290, Test Loss 0.2742, Test F1 0.4808
Epoch 200: Loss 0.1254, Train F1 0.8533, Test Loss 0.2736, Test F1 0.5224
Epoch 250: Loss 0.1114, Train F1 0.8728, Test Loss 0.2776, Test F1 0.5821
Epoch 300: Loss 0.1002, Train F1 0.8881, Test Loss 0.2850, Test F1 0.5979
Epoch 350: Loss 0.0913, Train F1 0.8989, Test Loss 0.2916, Test F1 0.6043
Epoch 400: Loss 0.0843, Train F1 0.9078, Test Loss 0.2998, Test F1 0.6057
Epoch 450: Loss 0.0784, Train F1 0.9145, Test Loss 0.3056, Test F1 0.5996
Epoch 500: Loss 0.0735, Train F1 0.9219, Test Loss 0.3088, Test F1 0.5999
Epoch 550: Loss 0.0694, Train F1 0.9280, Test Loss 0.3113, Test F1 0.6037
Epoch 600: Loss 0.0657, Train F1 0.9306, Test Loss 0.3179, Test F1 0.6018
BGRL Embeddings + Raw Features Last Cl

### BGRL Node Score + Raw Features + GCN

In [78]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)

if 'pretrained_encoder' not in globals():
    raise RuntimeError('Run the BGRL pretraining cells first to populate `pretrained_encoder`.')

if 'ssl_train_embeddings' not in globals() or 'ssl_test_embeddings' not in globals():
    raise RuntimeError('Generate SSL embeddings before training the score-based head.')

with torch.no_grad():
    bgrl_train_scores = torch.norm(ssl_train_embeddings, dim=1, keepdim=True)
    bgrl_test_scores = torch.norm(ssl_test_embeddings, dim=1, keepdim=True)

bgrl_score_mean = bgrl_train_scores.mean()
bgrl_score_std = bgrl_train_scores.std().clamp_min(1e-6)
bgrl_train_scores = (bgrl_train_scores - bgrl_score_mean) / bgrl_score_std
bgrl_test_scores = (bgrl_test_scores - bgrl_score_mean) / bgrl_score_std

bgrl_score_train = torch.cat([train_feature_all, bgrl_train_scores.to(device)], dim=1)
bgrl_score_test = torch.cat([test_features_all, bgrl_test_scores.to(device)], dim=1)

bgrl_score_gcn = GCN(bgrl_score_train.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(bgrl_score_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_bgrl_score = train_model(
    bgrl_score_gcn,
    "BGRL Node Score + Raw Features",
    optimizer,
    criterion,
    train_graph_all,
    bgrl_score_train,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    bgrl_score_test,
    test_labels_all,
    test_mask_all,
    num_epochs=600,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path='checkpoints/BGRL/gcn_scores/best_model_state.pt'
)

save_history(history_bgrl_score, sub_dir='BGRL_gcn_scores')


def build_bgrl_score_features(g_eval, feats):
    g_eval = g_eval.to(device)
    feats = feats.to(device)
    with torch.no_grad():
        ssl_embeds = pretrained_encoder(g_eval, feats)
        scores = torch.norm(ssl_embeds, dim=1, keepdim=True)
    scores = (scores - bgrl_score_mean) / bgrl_score_std
    return torch.cat([feats, scores], dim=1)


bgrl_score_timestep_metrics = report_timestep_performance(
    "BGRL Node Score + Raw Features",
    bgrl_score_gcn,
    feature_builder=build_bgrl_score_features
)



Epoch 001: Loss 0.5961, Train F1 0.1110, Test Loss 0.4229, Test F1 0.0649
Epoch 050: Loss 0.2151, Train F1 0.7518, Test Loss 0.3368, Test F1 0.3586
Epoch 100: Loss 0.1766, Train F1 0.7866, Test Loss 0.2932, Test F1 0.4241
Epoch 150: Loss 0.1564, Train F1 0.8105, Test Loss 0.2869, Test F1 0.4455
Epoch 200: Loss 0.1424, Train F1 0.8278, Test Loss 0.2862, Test F1 0.4560
Epoch 250: Loss 0.1313, Train F1 0.8448, Test Loss 0.2859, Test F1 0.4851
Epoch 300: Loss 0.1221, Train F1 0.8591, Test Loss 0.2859, Test F1 0.5142
Epoch 350: Loss 0.1143, Train F1 0.8688, Test Loss 0.2888, Test F1 0.5482
Epoch 400: Loss 0.1075, Train F1 0.8769, Test Loss 0.2935, Test F1 0.5651
Epoch 450: Loss 0.1017, Train F1 0.8834, Test Loss 0.2975, Test F1 0.5793
Epoch 500: Loss 0.0965, Train F1 0.8901, Test Loss 0.3007, Test F1 0.5926
Epoch 550: Loss 0.0917, Train F1 0.8950, Test Loss 0.3042, Test F1 0.5943
Epoch 600: Loss 0.0874, Train F1 0.9019, Test Loss 0.3085, Test F1 0.5938
BGRL Node Score + Raw Features Last Cl

## GAE

### Pretraining Setup

In [79]:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

if 'train_all_graph' not in globals():
    raise RuntimeError('Load the graphs first so `train_all_graph` is defined.')
if 'GCN' not in globals():
    raise RuntimeError('Run the GCN baseline cell to define the `GCN` encoder class.')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

gae_hidden_dim = 256
gae_latent_dim = 128
gae_epochs = 30
gae_lr = 1e-3
alpha_attr = 0.5
neg_sample_ratio = 1.0
max_struct_samples = 200_000


def prepare_graph_for_gae(graph):
    graph = dgl.to_bidirected(graph, copy_ndata=True)
    return dgl.add_self_loop(graph)


gae_graph = prepare_graph_for_gae(train_all_graph).to(device)
gae_features = gae_graph.ndata['feat'].float().to(device)
num_nodes = gae_graph.num_nodes()
in_feats = gae_features.shape[1]


### Encoder and Decoders

In [80]:

class InnerProductDecoder(nn.Module):
    def forward(self, z, src, dst):
        return (z[src] * z[dst]).sum(dim=1)


gae_encoder = GCN(in_feats, gae_hidden_dim, gae_latent_dim).to(device)
gae_structure_decoder = InnerProductDecoder().to(device)
gae_attribute_decoder = nn.Sequential(
    nn.Linear(gae_latent_dim, gae_hidden_dim),
    nn.ReLU(),
    nn.Linear(gae_hidden_dim, in_feats)
).to(device)

optimizer = torch.optim.Adam(
    list(gae_encoder.parameters()) + list(gae_attribute_decoder.parameters()),
    lr=gae_lr,
    weight_decay=5e-4,
)

src_all, dst_all = gae_graph.edges()
src_all = src_all.to(device)
dst_all = dst_all.to(device)

bce_loss = nn.BCEWithLogitsLoss()
mse_loss = nn.MSELoss()


### Sampling Utilities

In [81]:

def sample_positive_edges(num_samples):
    if num_samples >= src_all.shape[0]:
        return src_all, dst_all
    idx = torch.randint(0, src_all.shape[0], (num_samples,), device=device)
    return src_all[idx], dst_all[idx]


def sample_negative_edges(num_samples):
    neg_src = torch.randint(0, num_nodes, (num_samples,), device=device)
    neg_dst = torch.randint(0, num_nodes, (num_samples,), device=device)
    return neg_src, neg_dst


### Training Loop

In [82]:

print('Starting Graph Autoencoder (GAE) pretraining on train_all_graph...')
for epoch in range(1, gae_epochs + 1):
    gae_encoder.train()
    gae_attribute_decoder.train()

    latent_z = gae_encoder(gae_graph, gae_features)

    num_pos_samples = min(src_all.shape[0], max_struct_samples)
    pos_src, pos_dst = sample_positive_edges(num_pos_samples)
    pos_logits = gae_structure_decoder(latent_z, pos_src, pos_dst)
    pos_labels = torch.ones_like(pos_logits)

    num_neg_samples = max(1, int(neg_sample_ratio * num_pos_samples))
    neg_src, neg_dst = sample_negative_edges(num_neg_samples)
    neg_logits = gae_structure_decoder(latent_z, neg_src, neg_dst)
    neg_labels = torch.zeros_like(neg_logits)

    struct_loss = bce_loss(
        torch.cat([pos_logits, neg_logits], dim=0),
        torch.cat([pos_labels, neg_labels], dim=0)
    )

    recon_features = gae_attribute_decoder(latent_z)
    attr_loss = mse_loss(recon_features, gae_features)

    loss = struct_loss + alpha_attr * attr_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch == 1 or epoch % 20 == 0 or epoch == gae_epochs:
        print(
            f'GAE Epoch {epoch:04d} | Total: {loss.item():.4f} | '
            f'Struct: {struct_loss.item():.4f} | Attr: {attr_loss.item():.4f}'
        )

gae_encoder.eval()
gae_structure_decoder.eval()
gae_attribute_decoder.eval()

with torch.no_grad():
    gae_train_embeddings = gae_encoder(gae_graph, gae_features).detach()
    gae_train_attr_recon = gae_attribute_decoder(gae_train_embeddings).detach()

print('Finished GAE pretraining. Saved embeddings for downstream tasks.')


Starting Graph Autoencoder (GAE) pretraining on train_all_graph...


GAE Epoch 0001 | Total: 4.0057 | Struct: 3.5143 | Attr: 0.9829
GAE Epoch 0020 | Total: 0.9296 | Struct: 0.5007 | Attr: 0.8580
GAE Epoch 0030 | Total: 0.8355 | Struct: 0.4401 | Attr: 0.7909
Finished GAE pretraining. Saved embeddings for downstream tasks.


### Reconstruction Utility

In [83]:

def score_graph_with_gae(graph, feature_key='feat'):
    graph = prepare_graph_for_gae(graph).to(device)
    features = graph.ndata[feature_key].float().to(device)

    with torch.no_grad():
        z = gae_encoder(graph, features)
        feat_recon = gae_attribute_decoder(z)

        attr_err = ((feat_recon - features) ** 2).sum(dim=1)

        src, dst = graph.edges()
        src = src.to(device)
        dst = dst.to(device)
        logits = gae_structure_decoder(z, src, dst)
        labels = torch.ones_like(logits)
        edge_errors = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')

        node_struct_error = torch.zeros(graph.num_nodes(), device=device)
        node_struct_error.scatter_add_(0, src, edge_errors)
        node_struct_error.scatter_add_(0, dst, edge_errors)
        degrees = (graph.out_degrees() + graph.in_degrees()).float().to(device)
        degrees = torch.clamp(degrees, min=1.0)
        node_struct_error = node_struct_error / degrees

        combined_error = node_struct_error + alpha_attr * attr_err

    return {
        'struct_error': node_struct_error.detach().cpu(),
        'attr_error': attr_err.detach().cpu(),
        'combined_error': combined_error.detach().cpu(),
    }


### GAE Embeddings + GCN (All Nodes)

In [84]:
if 'gae_train_embeddings' not in globals():
    raise RuntimeError('Run the GAE pretraining cell before launching this training step.')

class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
gae_train_features = gae_train_embeddings.to(device)

_gae_test_graph = prepare_graph_for_gae(test_all_graph).to(device)
with torch.no_grad():
    gae_test_embeddings = gae_encoder(_gae_test_graph, _gae_test_graph.ndata['feat'].float()).detach()

gae_gcn = GCN(gae_train_features.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(gae_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_gae_gcn = train_model(
    gae_gcn,
    "GAE Embeddings -> GCN",
    optimizer,
    criterion,
    train_graph_all,
    gae_train_features,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    gae_test_embeddings.to(device),
    test_labels_all,
    test_mask_all,
    num_epochs=800,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path="checkpoints/GAE/gcn_embeddings/best_model_state.pt"
)

save_history(history_gae_gcn, sub_dir="GAE_gcn_embeddings")


def build_gae_embedding_features(g_eval, feats):
    g_eval = g_eval.to(device)
    feats = feats.float().to(device)
    with torch.no_grad():
        embeddings = gae_encoder(g_eval, feats)
    return embeddings


gae_gcn_timestep_metrics = report_timestep_performance(
    "GAE Embeddings -> GCN",
    gae_gcn,
    feature_builder=build_gae_embedding_features
)



Epoch 001: Loss 0.6896, Train F1 0.2153, Test Loss 0.5461, Test F1 0.1260
Epoch 050: Loss 0.3138, Train F1 0.6457, Test Loss 0.3916, Test F1 0.2688
Epoch 100: Loss 0.2479, Train F1 0.7197, Test Loss 0.3534, Test F1 0.3219
Epoch 150: Loss 0.2190, Train F1 0.7429, Test Loss 0.3270, Test F1 0.3544
Epoch 200: Loss 0.2027, Train F1 0.7567, Test Loss 0.3168, Test F1 0.3654
Epoch 250: Loss 0.1914, Train F1 0.7712, Test Loss 0.3107, Test F1 0.3712
Epoch 300: Loss 0.1829, Train F1 0.7817, Test Loss 0.3072, Test F1 0.3753
Epoch 350: Loss 0.1760, Train F1 0.7911, Test Loss 0.3042, Test F1 0.3841
Epoch 400: Loss 0.1701, Train F1 0.7989, Test Loss 0.3025, Test F1 0.3887
Epoch 450: Loss 0.1651, Train F1 0.8040, Test Loss 0.3018, Test F1 0.3997
Epoch 500: Loss 0.1606, Train F1 0.8100, Test Loss 0.3009, Test F1 0.4031
Epoch 550: Loss 0.1566, Train F1 0.8153, Test Loss 0.3001, Test F1 0.4109
Epoch 600: Loss 0.1530, Train F1 0.8190, Test Loss 0.3002, Test F1 0.4202
Epoch 650: Loss 0.1496, Train F1 0.824

### GAE Node Score + Raw Features + GCN

In [85]:
from sklearn.preprocessing import StandardScaler

if 'gae_encoder' not in globals():
    raise RuntimeError('Run the GAE pretraining cell before launching this training step.')

class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
train_scores = score_graph_with_gae(train_all_graph)
test_scores = score_graph_with_gae(test_all_graph)

gae_score_scaler = StandardScaler()
train_anomaly_np = gae_score_scaler.fit_transform(train_scores['combined_error'].numpy().reshape(-1, 1))
test_anomaly_np = gae_score_scaler.transform(test_scores['combined_error'].numpy().reshape(-1, 1))

train_score_features = torch.tensor(train_anomaly_np, dtype=torch.float32)
test_score_features = torch.tensor(test_anomaly_np, dtype=torch.float32)

train_aug_features = torch.cat([train_feature_all.cpu(), train_score_features], dim=1).to(device)
test_aug_features = torch.cat([test_features_all.cpu(), test_score_features], dim=1).to(device)

gae_score_gcn = GCN(train_aug_features.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(gae_score_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_gae_score = train_model(
    gae_score_gcn,
    "GAE Node Score + Raw Features",
    optimizer,
    criterion,
    train_graph_all,
    train_aug_features,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    test_aug_features,
    test_labels_all,
    test_mask_all,
    num_epochs=800,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path='checkpoints/GAE/gcn_scores/best_model_state.pt'
)

save_history(history_gae_score, sub_dir='GAE_gcn_scores')


def build_gae_score_features(g_eval, feats):
    graph_cpu = g_eval.to('cpu')
    score_outputs = score_graph_with_gae(graph_cpu)
    anomaly_np = gae_score_scaler.transform(score_outputs['combined_error'].numpy().reshape(-1, 1))
    anomaly_tensor = torch.tensor(anomaly_np, dtype=torch.float32, device=feats.device)
    feats = feats.to(device)
    return torch.cat([feats, anomaly_tensor], dim=1)


gae_score_timestep_metrics = report_timestep_performance(
    "GAE Node Score + Raw Features",
    gae_score_gcn,
    feature_builder=build_gae_score_features
)



Epoch 001: Loss 1.2533, Train F1 0.1792, Test Loss 0.6454, Test F1 0.0439
Epoch 050: Loss 0.2427, Train F1 0.7036, Test Loss 0.4302, Test F1 0.3104
Epoch 100: Loss 0.2024, Train F1 0.7656, Test Loss 0.3564, Test F1 0.3670
Epoch 150: Loss 0.1804, Train F1 0.7919, Test Loss 0.3101, Test F1 0.4309
Epoch 200: Loss 0.1659, Train F1 0.8074, Test Loss 0.2856, Test F1 0.4784
Epoch 250: Loss 0.1551, Train F1 0.8175, Test Loss 0.2734, Test F1 0.5086
Epoch 300: Loss 0.1465, Train F1 0.8289, Test Loss 0.2670, Test F1 0.5247
Epoch 350: Loss 0.1394, Train F1 0.8372, Test Loss 0.2640, Test F1 0.5459
Epoch 400: Loss 0.1333, Train F1 0.8432, Test Loss 0.2639, Test F1 0.5600
Epoch 450: Loss 0.1281, Train F1 0.8501, Test Loss 0.2656, Test F1 0.5758
Epoch 500: Loss 0.1234, Train F1 0.8561, Test Loss 0.2682, Test F1 0.5870
Epoch 550: Loss 0.1191, Train F1 0.8621, Test Loss 0.2713, Test F1 0.5891
Epoch 600: Loss 0.1151, Train F1 0.8653, Test Loss 0.2745, Test F1 0.5875
Epoch 650: Loss 0.1114, Train F1 0.869

### Raw + GAE Embeddings Concatenation

In [86]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
train_concat_features = torch.cat([train_feature_all, gae_train_embeddings.to(device)], dim=1)

_gae_test_graph = prepare_graph_for_gae(test_all_graph).to(device)
with torch.no_grad():
    gae_test_embeddings = gae_encoder(_gae_test_graph, _gae_test_graph.ndata['feat'].float()).detach()

test_concat_features = torch.cat([test_features_all, gae_test_embeddings.to(device)], dim=1)

concat_gcn = GCN(train_concat_features.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(concat_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_concat_gcn = train_model(
    concat_gcn,
    "Raw + GAE Embeddings",
    optimizer,
    criterion,
    train_graph_all,
    train_concat_features,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    test_concat_features,
    test_labels_all,
    test_mask_all,
    num_epochs=800,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path="checkpoints/GAE/gcn_concat/best_model_state.pt"
)

save_history(history_concat_gcn, sub_dir="GAE_gcn_concat")


def build_gae_concat_features(g_eval, feats):
    g_eval = g_eval.to(device)
    feats = feats.float().to(device)
    with torch.no_grad():
        embeddings = gae_encoder(g_eval, feats)
    return torch.cat([feats, embeddings], dim=1)


gae_concat_timestep_metrics = report_timestep_performance(
    "Raw + GAE Embeddings",
    concat_gcn,
    feature_builder=build_gae_concat_features
)



Epoch 001: Loss 0.6485, Train F1 0.1848, Test Loss 0.6000, Test F1 0.1321
Epoch 050: Loss 0.2101, Train F1 0.7509, Test Loss 0.3079, Test F1 0.3762
Epoch 100: Loss 0.1703, Train F1 0.7993, Test Loss 0.2837, Test F1 0.4059
Epoch 150: Loss 0.1483, Train F1 0.8265, Test Loss 0.2862, Test F1 0.4230
Epoch 200: Loss 0.1330, Train F1 0.8463, Test Loss 0.2868, Test F1 0.4501
Epoch 250: Loss 0.1213, Train F1 0.8586, Test Loss 0.2857, Test F1 0.4759
Epoch 300: Loss 0.1118, Train F1 0.8706, Test Loss 0.2864, Test F1 0.5119
Epoch 350: Loss 0.1037, Train F1 0.8825, Test Loss 0.2875, Test F1 0.5445
Epoch 400: Loss 0.0967, Train F1 0.8907, Test Loss 0.2894, Test F1 0.5537
Epoch 450: Loss 0.0907, Train F1 0.8988, Test Loss 0.2935, Test F1 0.5675
Epoch 500: Loss 0.0854, Train F1 0.9046, Test Loss 0.2993, Test F1 0.5682
Epoch 550: Loss 0.0808, Train F1 0.9098, Test Loss 0.3044, Test F1 0.5668
Epoch 600: Loss 0.0768, Train F1 0.9152, Test Loss 0.3108, Test F1 0.5651
Epoch 650: Loss 0.0734, Train F1 0.919

## Local-subgraph GAE

### Pretraining Setup

In [121]:

import dgl
import dgl.sampling
import torch
import torch.nn as nn
from dgl.nn import GraphConv

if 'train_all_graph' not in globals():
    raise RuntimeError('Run the data loading and graph construction cells first so `train_all_graph` exists.')

local_gae_config = {
    'rw_length': 24,
    'rw_trials': 4,
    'rw_restart_prob': 0.3,
    'hidden_dim': 128,
    'epochs': 20,
    'samples_per_epoch': 4000,
    'lr': 5e-4,
    'weight_decay': 1e-4,
    'val_sample_size': 1024,
    'val_interval': 5,
}

local_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

base_graph_for_local_gae = dgl.add_self_loop(dgl.to_bidirected(train_all_graph, copy_ndata=True))
base_graph_for_local_gae.ndata['feat'] = base_graph_for_local_gae.ndata['feat'].float()
local_train_feats = base_graph_for_local_gae.ndata['feat']
local_num_nodes = base_graph_for_local_gae.num_nodes()

print(base_graph_for_local_gae.ndata.keys())

samples_per_epoch = local_gae_config['samples_per_epoch']
if samples_per_epoch is None:
    samples_per_epoch = local_num_nodes
samples_per_epoch = max(1, min(local_num_nodes, samples_per_epoch))

val_sample_size = max(0, int(local_gae_config.get('val_sample_size', 0)))
if val_sample_size > 0:
    validation_nodes = torch.randperm(local_num_nodes)[:min(local_num_nodes, val_sample_size)]
else:
    validation_nodes = torch.tensor([], dtype=torch.long)
validation_interval = max(1, int(local_gae_config.get('val_interval', 5)))


dict_keys(['feat', 'label', 'timestep'])


### Random-Walk Sampling Utilities

In [96]:

def anonymized_random_walk_subgraph(
    graph,
    feature_tensor,
    target_nid,
    *,
    walk_length,
    num_traces,
    restart_prob,
    fallback_khop=2,
):
    """Sample a random-walk-with-restart ego graph and anonymize the target node."""
    seeds = torch.full((num_traces,), int(target_nid), dtype=torch.long)
    try:
        traces, _ = dgl.sampling.random_walk(
            graph,
            seeds,
            length=walk_length,
            restart_prob=restart_prob,
        )
        visited = traces.reshape(-1)
        visited = visited[visited >= 0]
        if visited.numel() == 0:
            visited = torch.tensor([int(target_nid)], dtype=torch.long)
        else:
            visited = visited.unique()
    except dgl.DGLError:
        khop_result = dgl.khop_in_subgraph(graph, torch.tensor([int(target_nid)]), fallback_khop)
        if isinstance(khop_result, tuple):
            subg_tmp, induced = khop_result
            if isinstance(induced, (list, tuple)):
                visited = torch.as_tensor(induced[0], dtype=torch.long)
            else:
                visited = torch.as_tensor(induced, dtype=torch.long)
            graph = subg_tmp
        else:
            graph = khop_result
            visited = graph.ndata[dgl.NID].long()
    target_tensor = torch.tensor([int(target_nid)], dtype=torch.long)
    unique_nodes = torch.unique(torch.cat([visited, target_tensor]))

    subg = dgl.node_subgraph(graph, unique_nodes)
    parent_nids = subg.ndata[dgl.NID].long()
    sub_feats = feature_tensor[parent_nids].clone()
    target_mask = (parent_nids == int(target_nid))
    target_indices = target_mask.nonzero(as_tuple=False).view(-1)
    if target_indices.numel() == 0:
        raise RuntimeError(f'Unable to locate target node {target_nid} inside the sampled subgraph.')
    target_idx = target_indices[0].item()
    target_feat = sub_feats[target_idx].clone()
    sub_feats[target_idx] = 0.0
    subg = dgl.add_self_loop(subg)
    return subg, sub_feats, target_idx, target_feat


### Model Definition

In [97]:

class LocalSubgraphGAE(nn.Module):
    def __init__(self, in_feats, hidden_dim):
        super().__init__()
        self.encoder = GraphConv(in_feats, hidden_dim, allow_zero_in_degree=True)
        self.decoder = GraphConv(hidden_dim, in_feats, allow_zero_in_degree=True)

    def forward(self, g, feat):
        h = torch.relu(self.encoder(g, feat))
        recon = self.decoder(g, h)
        return recon, h


local_subgraph_gae = LocalSubgraphGAE(local_train_feats.shape[1], local_gae_config['hidden_dim']).to(local_device)
local_subgraph_optimizer = torch.optim.Adam(
    local_subgraph_gae.parameters(),
    lr=local_gae_config['lr'],
    weight_decay=local_gae_config['weight_decay'],
)
target_recon_loss = nn.MSELoss()
local_subgraph_history = []
local_subgraph_val_history = []


### Node Scoring Helper

In [98]:

def score_nodes_with_local_subgraph_gae(
    graph,
    target_nodes=None,
    *,
    walk_length=None,
    num_traces=None,
    restart_prob=None,
    model=None,
):
    if model is None:
        model = local_subgraph_gae
    if walk_length is None:
        walk_length = local_gae_config['rw_length']
    if num_traces is None:
        num_traces = local_gae_config['rw_trials']
    if restart_prob is None:
        restart_prob = local_gae_config['rw_restart_prob']

    graph_cpu = dgl.add_self_loop(dgl.to_bidirected(graph, copy_ndata=True))
    feature_tensor = graph_cpu.ndata['feat'].float()

    if target_nodes is None:
        target_nodes = torch.arange(graph_cpu.num_nodes())
    else:
        target_nodes = torch.as_tensor(target_nodes, dtype=torch.long)

    model.eval()
    l2_scores = torch.zeros(target_nodes.shape[0])
    mse_scores = torch.zeros_like(l2_scores)
    with torch.no_grad():
        for idx, node_id in enumerate(target_nodes.tolist()):
            subg, anonymized_feat, target_idx, target_feat = anonymized_random_walk_subgraph(
                graph_cpu,
                feature_tensor,
                int(node_id),
                walk_length=walk_length,
                num_traces=num_traces,
                restart_prob=restart_prob,
            )
            recon, _ = model(subg.to(local_device), anonymized_feat.to(local_device))
            diff = recon[target_idx] - target_feat.to(local_device)
            l2_scores[idx] = torch.norm(diff, p=2).item()
            mse_scores[idx] = torch.mean(diff.pow(2)).item()
    return {
        'node_ids': target_nodes,
        'l2_recon_error': l2_scores,
        'per_node_mse': mse_scores,
    }


### Training Loop

In [99]:

print('Training local-subgraph anonymized GAE (random-walk views)...')
all_train_nodes = torch.arange(local_num_nodes)
for epoch in range(1, local_gae_config['epochs'] + 1):
    epoch_nodes = all_train_nodes[torch.randperm(local_num_nodes)[:samples_per_epoch]]
    running_loss = 0.0
    local_subgraph_gae.train()
    for node_id in epoch_nodes.tolist():
        subg, anonymized_feat, target_idx, target_feat = anonymized_random_walk_subgraph(
            base_graph_for_local_gae,
            local_train_feats,
            int(node_id),
            walk_length=local_gae_config['rw_length'],
            num_traces=local_gae_config['rw_trials'],
            restart_prob=local_gae_config['rw_restart_prob'],
        )
        subg = subg.to(local_device)
        anonymized_feat = anonymized_feat.to(local_device)
        target_feat = target_feat.to(local_device)

        recon, _ = local_subgraph_gae(subg, anonymized_feat)
        loss = target_recon_loss(recon[target_idx], target_feat)

        local_subgraph_optimizer.zero_grad()
        loss.backward()
        local_subgraph_optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(epoch_nodes)
    local_subgraph_history.append(avg_loss)
    print(f"[Local-GAE] Epoch {epoch:03d} | Targets {len(epoch_nodes)} | MSE {avg_loss:.4f}")

    if validation_nodes.numel() > 0 and (epoch % validation_interval == 0):
        val_scores = score_nodes_with_local_subgraph_gae(
            base_graph_for_local_gae,
            target_nodes=validation_nodes,
            model=local_subgraph_gae,
            walk_length=local_gae_config['rw_length'],
            num_traces=local_gae_config['rw_trials'],
            restart_prob=local_gae_config['rw_restart_prob'],
        )
        val_mse = val_scores['per_node_mse'].mean().item()
        local_subgraph_val_history.append((epoch, val_mse))
        print(f"    Validation ({validation_nodes.numel()} nodes) MSE: {val_mse:.4f}")

print('Stored `local_subgraph_gae`, `local_subgraph_history`, and `score_nodes_with_local_subgraph_gae` for downstream evaluation.')


Training local-subgraph anonymized GAE (random-walk views)...
[Local-GAE] Epoch 001 | Targets 4000 | MSE 0.8302
[Local-GAE] Epoch 002 | Targets 4000 | MSE 0.6930
[Local-GAE] Epoch 003 | Targets 4000 | MSE 0.9344
[Local-GAE] Epoch 004 | Targets 4000 | MSE 0.7114
[Local-GAE] Epoch 005 | Targets 4000 | MSE 0.6895
    Validation (1024 nodes) MSE: 0.7323
[Local-GAE] Epoch 006 | Targets 4000 | MSE 0.7172
[Local-GAE] Epoch 007 | Targets 4000 | MSE 0.7084
[Local-GAE] Epoch 008 | Targets 4000 | MSE 0.6718
[Local-GAE] Epoch 009 | Targets 4000 | MSE 0.7491
[Local-GAE] Epoch 010 | Targets 4000 | MSE 0.7281
    Validation (1024 nodes) MSE: 0.7404
[Local-GAE] Epoch 011 | Targets 4000 | MSE 0.6645
[Local-GAE] Epoch 012 | Targets 4000 | MSE 0.7693
[Local-GAE] Epoch 013 | Targets 4000 | MSE 0.7336
[Local-GAE] Epoch 014 | Targets 4000 | MSE 0.6663
[Local-GAE] Epoch 015 | Targets 4000 | MSE 0.7629
    Validation (1024 nodes) MSE: 0.7225
[Local-GAE] Epoch 016 | Targets 4000 | MSE 0.6901
[Local-GAE] Epoch 

### Local GAE Embeddings + Raw Features + GCN

In [92]:

import torch
import torch.nn as nn

if 'local_subgraph_gae' not in globals():
    raise RuntimeError('Train the local-subgraph GAE cells before running this step.')

class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)
local_subgraph_gae = local_subgraph_gae.to(local_device).eval()

with torch.no_grad():
    base_graph = base_graph_for_local_gae.to(local_device)
    base_feats = base_graph.ndata['feat'].to(local_device)
    _, local_train_hidden = local_subgraph_gae(base_graph, base_feats)

    test_graph_local = dgl.add_self_loop(dgl.to_bidirected(test_all_graph, copy_ndata=True))
    test_graph_local = test_graph_local.to(local_device)
    test_graph_local.ndata['feat'] = test_graph_local.ndata['feat'].float()
    _, local_test_hidden = local_subgraph_gae(test_graph_local, test_graph_local.ndata['feat'])

local_train_hidden = local_train_hidden.to(device)
local_test_hidden = local_test_hidden.to(device)

local_aug_train = torch.cat([train_feature_all, local_train_hidden], dim=1)
local_aug_test = torch.cat([test_features_all, local_test_hidden], dim=1)

local_concat_gcn = GCN(local_aug_train.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(local_concat_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_local_concat = train_model(
    local_concat_gcn,
    'Local GAE Embeddings + Raw Features',
    optimizer,
    criterion,
    train_graph_all,
    local_aug_train,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    local_aug_test,
    test_labels_all,
    test_mask_all,
    num_epochs=600,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path='checkpoints/LocalGAE/gcn_concat/best_model_state.pt'
)

save_history(history_local_concat, sub_dir='LocalGAE_gcn_concat')




def build_local_concat_features(g_eval, feats):
    g_local = g_eval.to(local_device)
    feats_local = g_local.ndata['feat'].float()
    with torch.no_grad():
        _, hidden = local_subgraph_gae(g_local, feats_local)
    hidden = hidden.to(device)
    feats = feats.to(device)
    return torch.cat([feats, hidden], dim=1)


local_concat_timestep_metrics = report_timestep_performance(
    "Local GAE Embeddings + Raw Features",
    local_concat_gcn,
    feature_builder=build_local_concat_features
)



Epoch 001: Loss 0.7805, Train F1 0.2749, Test Loss 0.7512, Test F1 0.1666
Epoch 050: Loss 0.2076, Train F1 0.7472, Test Loss 0.2943, Test F1 0.4182
Epoch 100: Loss 0.1713, Train F1 0.7946, Test Loss 0.2684, Test F1 0.5194
Epoch 150: Loss 0.1510, Train F1 0.8204, Test Loss 0.2646, Test F1 0.5582
Epoch 200: Loss 0.1358, Train F1 0.8402, Test Loss 0.2635, Test F1 0.5692
Epoch 250: Loss 0.1234, Train F1 0.8577, Test Loss 0.2653, Test F1 0.5746
Epoch 300: Loss 0.1128, Train F1 0.8702, Test Loss 0.2710, Test F1 0.5811
Epoch 350: Loss 0.1039, Train F1 0.8811, Test Loss 0.2805, Test F1 0.5848
Epoch 400: Loss 0.0963, Train F1 0.8896, Test Loss 0.2917, Test F1 0.5872
Epoch 450: Loss 0.0898, Train F1 0.8970, Test Loss 0.3005, Test F1 0.5918
Epoch 500: Loss 0.0844, Train F1 0.9044, Test Loss 0.3066, Test F1 0.5938
Epoch 550: Loss 0.0796, Train F1 0.9114, Test Loss 0.3115, Test F1 0.5971
Epoch 600: Loss 0.0755, Train F1 0.9151, Test Loss 0.3156, Test F1 0.5996
Local GAE Embeddings + Raw Features La

### Local GAE Node Score + Raw Features + GCN

In [93]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)

if 'local_subgraph_gae' not in globals() or 'score_nodes_with_local_subgraph_gae' not in globals():
    raise RuntimeError('Train the local-subgraph GAE cells before running this step.')

train_score_dict = score_nodes_with_local_subgraph_gae(
    train_all_graph,
    model=local_subgraph_gae,
    walk_length=local_gae_config['rw_length'],
    num_traces=local_gae_config['rw_trials'],
    restart_prob=local_gae_config['rw_restart_prob'],
)

test_score_dict = score_nodes_with_local_subgraph_gae(
    test_all_graph,
    model=local_subgraph_gae,
    walk_length=local_gae_config['rw_length'],
    num_traces=local_gae_config['rw_trials'],
    restart_prob=local_gae_config['rw_restart_prob'],
)

train_scores = train_score_dict['per_node_mse'].view(-1, 1)
test_scores = test_score_dict['per_node_mse'].view(-1, 1)

local_score_mean = train_scores.mean()
local_score_std = train_scores.std().clamp_min(1e-6)
train_scores = (train_scores - local_score_mean) / local_score_std
test_scores = (test_scores - local_score_mean) / local_score_std

local_score_train = torch.cat([train_feature_all.cpu(), train_scores], dim=1).to(device)
local_score_test = torch.cat([test_features_all.cpu(), test_scores], dim=1).to(device)

local_score_gcn = GCN(local_score_train.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(local_score_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_local_score = train_model(
    local_score_gcn,
    "Local GAE Node Score + Raw Features",
    optimizer,
    criterion,
    train_graph_all,
    local_score_train,
    train_labels_all,
    train_mask_all,
    test_graph_all,
    local_score_test,
    test_labels_all,
    test_mask_all,
    num_epochs=600,
    test_every=50,
    early_stopping_patience=20,
    checkpoint_path='checkpoints/LocalGAE/gcn_scores/best_model_state.pt'
)

save_history(history_local_score, sub_dir='LocalGAE_gcn_scores')


def build_local_score_features(g_eval, feats):
    graph_cpu = g_eval.to('cpu')
    score_dict = score_nodes_with_local_subgraph_gae(
        graph_cpu,
        model=local_subgraph_gae,
        walk_length=local_gae_config['rw_length'],
        num_traces=local_gae_config['rw_trials'],
        restart_prob=local_gae_config['rw_restart_prob'],
    )
    scores = score_dict['per_node_mse'].view(-1, 1)
    scores = (scores - local_score_mean) / local_score_std
    scores = scores.to(device)
    feats = feats.to(device)
    return torch.cat([feats, scores], dim=1)


local_score_timestep_metrics = report_timestep_performance(
    "Local GAE Node Score + Raw Features",
    local_score_gcn,
    feature_builder=build_local_score_features
)



Epoch 001: Loss 0.6842, Train F1 0.1954, Test Loss 0.9821, Test F1 0.1193
Epoch 050: Loss 0.2252, Train F1 0.7286, Test Loss 0.3711, Test F1 0.3175
Epoch 100: Loss 0.1839, Train F1 0.7871, Test Loss 0.3033, Test F1 0.4139
Epoch 150: Loss 0.1621, Train F1 0.8102, Test Loss 0.2842, Test F1 0.4532
Epoch 200: Loss 0.1481, Train F1 0.8255, Test Loss 0.2823, Test F1 0.4722
Epoch 250: Loss 0.1374, Train F1 0.8398, Test Loss 0.2856, Test F1 0.4892
Epoch 300: Loss 0.1286, Train F1 0.8482, Test Loss 0.2894, Test F1 0.5078
Epoch 350: Loss 0.1214, Train F1 0.8591, Test Loss 0.2926, Test F1 0.5298
Epoch 400: Loss 0.1150, Train F1 0.8665, Test Loss 0.2972, Test F1 0.5494
Epoch 450: Loss 0.1095, Train F1 0.8757, Test Loss 0.3015, Test F1 0.5622
Epoch 500: Loss 0.1045, Train F1 0.8818, Test Loss 0.3050, Test F1 0.5667
Epoch 550: Loss 0.1000, Train F1 0.8881, Test Loss 0.3074, Test F1 0.5664
Epoch 600: Loss 0.0958, Train F1 0.8922, Test Loss 0.3092, Test F1 0.5686
Local GAE Node Score + Raw Features La

## CoLA + GCN

### Pretraining Setup

In [130]:
import dgl
import dgl.sampling
import torch
import torch.nn as nn
from dgl.nn import GraphConv

if 'base_graph_for_local_gae' not in globals() or 'anonymized_random_walk_subgraph' not in globals():
    raise RuntimeError('Run the local-subgraph GAE cells first to set up the sampling utilities.')

cola_config = {
    'hidden_dim': 128,
    'epochs': 30,
    'samples_per_epoch': 4000,
    'rw_length': local_gae_config['rw_length'],
    'rw_trials': local_gae_config['rw_trials'],
    'rw_restart_prob': local_gae_config['rw_restart_prob'],
    'lr': 5e-4,
    'weight_decay': 0.0,
    'neg_trials': 1,
    'eval_neg_trials': 2,
}

cola_graph = base_graph_for_local_gae
cola_features = cola_graph.ndata['feat']
cola_num_nodes = cola_graph.num_nodes()
cola_samples_per_epoch = min(cola_num_nodes, cola_config['samples_per_epoch'])


def _group_nodes_by_timestep(timestep_tensor):
    sanitized = timestep_tensor.to(torch.long).view(-1).cpu()
    groups = {}
    for ts in torch.unique(sanitized).tolist():
        mask = sanitized == ts
        groups[int(ts)] = mask.nonzero(as_tuple=False).view(-1)
    return sanitized, groups


if 'timestep' not in cola_graph.ndata:
    raise KeyError(
        "cola_graph.ndata is missing 'timestep'. Re-run the data loading cell that calls create_dgl_graph so timesteps are attached to each node."
    )

cola_node_timesteps, cola_nodes_by_timestep = _group_nodes_by_timestep(cola_graph.ndata['timestep'])
cola_device = local_device


### Contrastive Modules and Utilities

In [131]:

class CoLAEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv = GraphConv(in_dim, hidden_dim, allow_zero_in_degree=True)

    def forward(self, g, feat):
        return torch.relu(self.conv(g, feat))


class BilinearDiscriminator(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.scorer = nn.Bilinear(hidden_dim, hidden_dim, 1)

    def forward(self, node_embed, graph_embed):
        return self.scorer(node_embed, graph_embed).squeeze(-1)


def _graph_readout(node_embeddings):
    return node_embeddings.mean(dim=0)


def _sample_negative_node(exclude_node):
    candidate = None
    group = cola_nodes_by_timestep.get(int(cola_node_timesteps[exclude_node].item()))
    if group is not None and group.numel() > 0:
        rand_idx = torch.randint(0, group.numel(), (1,)).item()
        candidate = int(group[rand_idx])
        if candidate == exclude_node and group.numel() > 1:
            candidate = int(group[(rand_idx + 1) % group.numel()])
    if candidate is None:
        candidate = int(torch.randint(0, cola_num_nodes, (1,)))
        if candidate == exclude_node:
            candidate = int((candidate + 1) % cola_num_nodes)
    return candidate


cola_encoder = CoLAEncoder(cola_features.shape[1], cola_config['hidden_dim']).to(cola_device)
cola_discriminator = BilinearDiscriminator(cola_config['hidden_dim']).to(cola_device)
cola_optimizer = torch.optim.Adam(
    list(cola_encoder.parameters()) + list(cola_discriminator.parameters()),
    lr=cola_config['lr'],
    weight_decay=cola_config['weight_decay'],
)
cola_bce = nn.BCEWithLogitsLoss()
cola_train_history = []


def cola_forward_step(node_id):
    subg_pos, feat_pos, idx_pos, _ = anonymized_random_walk_subgraph(
        cola_graph,
        cola_features,
        node_id,
        walk_length=cola_config['rw_length'],
        num_traces=cola_config['rw_trials'],
        restart_prob=cola_config['rw_restart_prob'],
    )
    subg_pos = subg_pos.to(cola_device)
    feat_pos = feat_pos.to(cola_device)
    embeds_pos = cola_encoder(subg_pos, feat_pos)
    target_embed = embeds_pos[idx_pos]
    graph_embed_pos = _graph_readout(embeds_pos)

    neg_scores = []
    neg_labels = []
    for _ in range(cola_config['neg_trials']):
        neg_node = _sample_negative_node(node_id)
        subg_neg, feat_neg, _, _ = anonymized_random_walk_subgraph(
            cola_graph,
            cola_features,
            neg_node,
            walk_length=cola_config['rw_length'],
            num_traces=cola_config['rw_trials'],
            restart_prob=cola_config['rw_restart_prob'],
        )
        subg_neg = subg_neg.to(cola_device)
        feat_neg = feat_neg.to(cola_device)
        embeds_neg = cola_encoder(subg_neg, feat_neg)
        graph_embed_neg = _graph_readout(embeds_neg)
        neg_scores.append(cola_discriminator(target_embed, graph_embed_neg).unsqueeze(0))
        neg_labels.append(torch.zeros(1, device=cola_device))

    pos_score = cola_discriminator(target_embed, graph_embed_pos)
    pos_label = torch.ones(1, device=cola_device)
    neg_scores = torch.cat(neg_scores) if neg_scores else torch.tensor([], device=cola_device)
    neg_labels = torch.cat(neg_labels) if neg_labels else torch.tensor([], device=cola_device)
    return pos_score, pos_label, neg_scores, neg_labels


### Training Loop

In [132]:

print('Training CoLA contrastive model...')
all_nodes = torch.arange(cola_num_nodes)
for epoch in range(1, cola_config['epochs'] + 1):
    epoch_nodes = all_nodes[torch.randperm(cola_num_nodes)[:cola_samples_per_epoch]]
    running_loss = 0.0
    for node_id in epoch_nodes.tolist():
        cola_encoder.train()
        cola_discriminator.train()
        pos_score, pos_label, neg_scores, neg_labels = cola_forward_step(int(node_id))

        loss = cola_bce(pos_score.unsqueeze(0), pos_label)
        if neg_scores.numel() > 0:
            loss = loss + cola_bce(neg_scores, neg_labels)

        cola_optimizer.zero_grad()
        loss.backward()
        cola_optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(epoch_nodes)
    cola_train_history.append(avg_loss)
    print(f"[CoLA] Epoch {epoch:03d} | Nodes {len(epoch_nodes)} | Loss {avg_loss:.4f}")

print('Stored `cola_encoder`, `cola_discriminator`, and `cola_train_history` for downstream scoring.')


Training CoLA contrastive model...
[CoLA] Epoch 001 | Nodes 4000 | Loss 1.0934
[CoLA] Epoch 002 | Nodes 4000 | Loss 0.8198
[CoLA] Epoch 003 | Nodes 4000 | Loss 0.6772
[CoLA] Epoch 004 | Nodes 4000 | Loss 0.6095
[CoLA] Epoch 005 | Nodes 4000 | Loss 0.7126
[CoLA] Epoch 006 | Nodes 4000 | Loss 0.5519
[CoLA] Epoch 007 | Nodes 4000 | Loss 0.5419
[CoLA] Epoch 008 | Nodes 4000 | Loss 0.4986
[CoLA] Epoch 009 | Nodes 4000 | Loss 0.5730
[CoLA] Epoch 010 | Nodes 4000 | Loss 0.5166
[CoLA] Epoch 011 | Nodes 4000 | Loss 0.5217
[CoLA] Epoch 012 | Nodes 4000 | Loss 0.5966
[CoLA] Epoch 013 | Nodes 4000 | Loss 0.5128
[CoLA] Epoch 014 | Nodes 4000 | Loss 0.5035
[CoLA] Epoch 015 | Nodes 4000 | Loss 0.5033
[CoLA] Epoch 016 | Nodes 4000 | Loss 0.5220
[CoLA] Epoch 017 | Nodes 4000 | Loss 0.4844
[CoLA] Epoch 018 | Nodes 4000 | Loss 0.5948
[CoLA] Epoch 019 | Nodes 4000 | Loss 0.4952
[CoLA] Epoch 020 | Nodes 4000 | Loss 0.4948
[CoLA] Epoch 021 | Nodes 4000 | Loss 1.3308
[CoLA] Epoch 022 | Nodes 4000 | Loss 0.54

### CoLA Embeddings + Raw Features

In [133]:
import dgl
import torch
import torch.nn as nn

if 'cola_encoder' not in globals():
    raise RuntimeError('Train the CoLA encoder cell before running this step.')

class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)

def _prep_graph(base_graph):
    return dgl.add_self_loop(dgl.to_bidirected(base_graph, copy_ndata=True)).to(device)

cola_train_graph = _prep_graph(train_all_graph)
cola_test_graph = _prep_graph(test_all_graph)
train_labels = cola_train_graph.ndata['label']
train_mask = (train_labels >= 0)
test_labels = cola_test_graph.ndata['label']
test_mask = (test_labels >= 0)

with torch.no_grad():
    base_train_feats = cola_train_graph.ndata['feat'].float()
    base_test_feats = cola_test_graph.ndata['feat'].float()
    cola_train_embeddings = cola_encoder(cola_train_graph, base_train_feats)
    cola_test_embeddings = cola_encoder(cola_test_graph, base_test_feats)

aug_train_feats = torch.cat([base_train_feats, cola_train_embeddings], dim=1)
aug_test_feats = torch.cat([base_test_feats, cola_test_embeddings], dim=1)

cola_gcn = GCN(aug_train_feats.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(cola_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_cola_gcn = train_model(
    cola_gcn,
    "CoLA Embeddings + Raw",
    optimizer,
    criterion,
    cola_train_graph,
    aug_train_feats,
    train_labels.to(device),
    train_mask,
    cola_test_graph,
    aug_test_feats,
    test_labels.to(device),
    test_mask,
    num_epochs=300,
    test_every=50,
    early_stopping_patience=10,
    checkpoint_path="checkpoints/CoLA/gcn_concat/best_model_state.pt"
)

save_history(history_cola_gcn, sub_dir="CoLA_gcn_concat")


def build_cola_concat_features(g_eval, feats):
    g_for_encoder = g_eval.to(cola_device)
    feats_for_encoder = g_for_encoder.ndata['feat'].float()
    with torch.no_grad():
        embeddings = cola_encoder(g_for_encoder, feats_for_encoder)
    embeddings = embeddings.to(device)
    feats = feats.to(device)
    return torch.cat([feats, embeddings], dim=1)


cola_gcn_timestep_metrics = report_timestep_performance(
    "CoLA Embeddings + Raw Features",
    cola_gcn,
    feature_builder=build_cola_concat_features
)



Epoch 001: Loss 1.1419, Train F1 0.2184, Test Loss 0.8135, Test F1 0.1335
Epoch 050: Loss 0.2274, Train F1 0.7221, Test Loss 0.3582, Test F1 0.3316
Epoch 100: Loss 0.1875, Train F1 0.7741, Test Loss 0.3155, Test F1 0.3901
Epoch 150: Loss 0.1639, Train F1 0.8060, Test Loss 0.3022, Test F1 0.4100
Epoch 200: Loss 0.1466, Train F1 0.8292, Test Loss 0.2979, Test F1 0.4183
Epoch 250: Loss 0.1321, Train F1 0.8484, Test Loss 0.2958, Test F1 0.4563
Epoch 300: Loss 0.1195, Train F1 0.8635, Test Loss 0.2936, Test F1 0.4977
CoLA Embeddings + Raw Last Classification Report on Labeled Test Graph:
              precision    recall  f1-score   support

     illicit       0.67      0.40      0.50      1083
       licit       0.96      0.99      0.97     15587

    accuracy                           0.95     16670
   macro avg       0.82      0.69      0.74     16670
weighted avg       0.94      0.95      0.94     16670

CoLA Embeddings + Raw Best Classification Report on Labeled Test Graph at epoch 300

### CoLA Node Score + Raw Features + GCN

In [134]:
class_weights = torch.tensor([0.7, 0.3], dtype=torch.float32, device=device)

if 'cola_encoder' not in globals() or 'cola_discriminator' not in globals():
    raise RuntimeError('Train the CoLA encoder cell before running this step.')

def compute_cola_scores(graph):
    graph_proc = dgl.add_self_loop(dgl.to_bidirected(graph, copy_ndata=True)).to(cola_device)
    feats = graph_proc.ndata['feat'].float()
    with torch.no_grad():
        node_embeddings = cola_encoder(graph_proc, feats)
        graph_embedding = _graph_readout(node_embeddings)
        repeated = graph_embedding.unsqueeze(0).repeat(node_embeddings.shape[0], 1)
        logits = cola_discriminator(node_embeddings, repeated).unsqueeze(1)
    return logits.cpu()

train_scores = compute_cola_scores(train_all_graph)
test_scores = compute_cola_scores(test_all_graph)

cola_score_mean = train_scores.mean()
cola_score_std = train_scores.std().clamp_min(1e-6)
train_scores = (train_scores - cola_score_mean) / cola_score_std
test_scores = (test_scores - cola_score_mean) / cola_score_std

cola_score_train = torch.cat([train_feature_all.cpu(), train_scores], dim=1).to(device)
cola_score_test = torch.cat([test_features_all.cpu(), test_scores], dim=1).to(device)

cola_score_graph = dgl.add_self_loop(dgl.to_bidirected(train_all_graph, copy_ndata=True)).to(device)
cola_score_test_graph = dgl.add_self_loop(dgl.to_bidirected(test_all_graph, copy_ndata=True)).to(device)
train_labels = cola_score_graph.ndata['label']
train_mask = (train_labels >= 0)
test_labels = cola_score_test_graph.ndata['label']
test_mask = (test_labels >= 0)

cola_score_gcn = GCN(cola_score_train.shape[1], embedding_dim, 2).to(device)
optimizer = torch.optim.Adam(cola_score_gcn.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

history_cola_score = train_model(
    cola_score_gcn,
    "CoLA Node Score + Raw Features",
    optimizer,
    criterion,
    cola_score_graph,
    cola_score_train,
    train_labels.to(device),
    train_mask,
    cola_score_test_graph,
    cola_score_test,
    test_labels.to(device),
    test_mask,
    num_epochs=300,
    test_every=50,
    early_stopping_patience=10,
    checkpoint_path='checkpoints/CoLA/gcn_scores/best_model_state.pt'
)

save_history(history_cola_score, sub_dir='CoLA_gcn_scores')


def build_cola_score_features(g_eval, feats):
    scores = compute_cola_scores(g_eval)
    scores = (scores - cola_score_mean) / cola_score_std
    scores = scores.to(device)
    feats = feats.to(device)
    return torch.cat([feats, scores], dim=1)


cola_score_timestep_metrics = report_timestep_performance(
    "CoLA Node Score + Raw Features",
    cola_score_gcn,
    feature_builder=build_cola_score_features
)



Epoch 001: Loss 0.6947, Train F1 0.2167, Test Loss 0.6692, Test F1 0.0666
Epoch 050: Loss 0.2198, Train F1 0.7306, Test Loss 0.3207, Test F1 0.4066
Epoch 100: Loss 0.1794, Train F1 0.7856, Test Loss 0.2701, Test F1 0.5203
Epoch 150: Loss 0.1579, Train F1 0.8090, Test Loss 0.2549, Test F1 0.5770
Epoch 200: Loss 0.1436, Train F1 0.8297, Test Loss 0.2587, Test F1 0.5926
Epoch 250: Loss 0.1326, Train F1 0.8425, Test Loss 0.2715, Test F1 0.5941
Epoch 300: Loss 0.1236, Train F1 0.8557, Test Loss 0.2910, Test F1 0.5714
CoLA Node Score + Raw Features Last Classification Report on Labeled Test Graph:
              precision    recall  f1-score   support

     illicit       0.81      0.44      0.57      1083
       licit       0.96      0.99      0.98     15587

    accuracy                           0.96     16670
   macro avg       0.89      0.72      0.77     16670
weighted avg       0.95      0.96      0.95     16670

CoLA Node Score + Raw Features Best Classification Report on Labeled Test 