In [None]:
import os

# These variables have to be exported before CUDA is initialised, so this cell
# must run ahead of anything that touches the GPU.
os.environ.update(
    {
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import torch.nn as nn
import random
import numpy as np
import json

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

from script_classification.utilities.graph_stats import *
from script_classification.data_loader import BitcoinScriptsDataset
from script_classification.view_augmenter import ViewAugmenter
from script_classification.models import GraphEncoder
from script_classification.engine import *
from script_classification.losses import *
from script_classification.utilities.graph_ops import *
from script_classification.evaluation.metrics import *
from script_classification.evaluation.embeddings import *

In [None]:
SEED = 64

# Seed every random number generator used in this notebook: Python, NumPy,
# PyTorch on CPU, and PyTorch on the current / all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Request deterministic kernels: no cuDNN autotuning, deterministic cuDNN
# algorithms, and PyTorch's global deterministic-algorithms switch.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Read every run setting from config.json (resolved against the current working
# directory). The context manager closes the file handle deterministically
# instead of leaving it to the garbage collector.
with open("config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)

# --- paths ---
DATA_ROOT = config["data_root"]
SAVES_ROOT = config["saves_root"]
MODEL_SAVE_FILENAME = config["model_save_filename"]
MODEL_SAVE_PATH = os.path.join(SAVES_ROOT, MODEL_SAVE_FILENAME)
LOGS_ROOT = os.path.join(SAVES_ROOT, "logs_small")

# --- dataset filtering ---
MIN_NODES_PER_GRAPH = config["min_nodes_per_graph"]
MAX_NODES_PER_GRAPH = config["max_nodes_per_graph"]

# --- optimisation ---
BATCH_SIZE = config["batch_size"]
EPOCHS = config["epochs"]
WARMUP_EPOCHS = config["warmup_epochs"]
LEARNING_RATE = config["learning_rate"]
PATIENCE = config["patience"]

# --- model / loss ---
HIDDEN_DIM = config["hidden_dim"]
OUT_DIM = config["out_dim"]
PROJ_DIM = config["proj_dim"]
TEMPERATURE = config["temperature"]
LAMBDA_VAL = config["lambda_val"]  # Weight for the graph-level loss

# --- validation metric ---
NUM_CLUSTERS = config["num_clusters"]
KNN_K = config["knn_k"]
BLEND_ALPHA = config["blend_alpha"]

In [None]:
dataset = BitcoinScriptsDataset(
    root=DATA_ROOT,
    min_nodes_per_graph=MIN_NODES_PER_GRAPH,
    max_nodes_per_graph=MAX_NODES_PER_GRAPH,
)

# Reverse lookup from each graph's `graph_id` to its position in `dataset`.
# Items are fetched with `dataset.get(i)`, exactly as positions 0..len-1.
graph_id_to_idx_map = {}
for i in range(len(dataset)):
    gid = dataset.get(i).graph_id
    try:
        graph_id_to_idx_map[gid] = i
    except TypeError:
        # An id that cannot be used as a dict key is stored under its string form.
        graph_id_to_idx_map[str(gid)] = i

In [None]:
# Edge-feature width. Prefer the dataset's own `num_edge_features`; only when
# that attribute is missing do we inspect the first graph. The fallback used to
# be passed as the `getattr` default, which Python evaluates eagerly, so
# `dataset[0].edge_attr.size(1)` ran (and could raise) even when it was not needed.
EDGE_DIM = getattr(dataset, "num_edge_features", None)
if EDGE_DIM is None:
    EDGE_DIM = dataset[0].edge_attr.size(1)

In [None]:
# 70 % / 15 % / 15 % split; the test split absorbs whatever the integer
# truncation of the first two leaves over.
n = len(dataset)
n_train, n_val = int(0.7 * n), int(0.15 * n)
n_test = n - (n_train + n_val)

# A dedicated, seeded generator keeps the split independent of the global RNG state.
gen = torch.Generator()
gen.manual_seed(SEED)

train_raw, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=gen,
)

train_dataset = train_raw

In [None]:
# Options shared by all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
encoder = GraphEncoder(
    in_channels=dataset.num_node_features,
    # NOTE(review): HIDDEN_DIM is read from config.json but is not used here;
    # the width is pinned to 64. Confirm which one is intended.
    hidden_channels=64, # divisible by HEADS
    out_channels=OUT_DIM,
    edge_dim=EDGE_DIM
).to(device)


# Two-layer MLP projection head applied on top of the encoder output.
proj_head = nn.Sequential(
    nn.Linear(encoder.out_dim, PROJ_DIM),
    nn.GELU(),
    nn.Linear(PROJ_DIM, PROJ_DIM),
).to(device)


# AdamW + no weight decay on norm/bias
# Trainable encoder parameters are split by name: biases and anything whose
# name contains "norm" are exempt from weight decay, everything else is decayed.
decay, no_decay = [], []
for param_name, param in encoder.named_parameters():
    if not param.requires_grad:
        continue
    if param_name.endswith("bias") or "norm" in param_name.lower():
        no_decay.append(param)
    else:
        decay.append(param)

# Previously `decay` / `no_decay` were built and then ignored, so every encoder
# parameter was decayed. The groups are now actually handed to the optimizer.
# The projection head keeps its original uniform weight decay.
# NOTE: the param-group layout differs from before, so optimizer state from
# older checkpoints will not load into this optimizer.
param_groups = [
    {"params": decay, "weight_decay": 1e-4},
    {"params": no_decay, "weight_decay": 0.0},
    {"params": list(proj_head.parameters()), "weight_decay": 1e-4},
]

optimizer = torch.optim.AdamW(
    # Drop groups that ended up empty so the optimizer only sees real parameters.
    [group for group in param_groups if group["params"]],
    lr=LEARNING_RATE
)

In [None]:
# Phase 1: linear warmup from 1 % of the base LR over WARMUP_EPOCHS steps.
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=1e-2,
    total_iters=WARMUP_EPOCHS,
)

# Phase 2: cosine decay over the remaining epochs (at least 2 steps),
# bottoming out at 10 % of the base LR.
decay_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=max(2, EPOCHS - WARMUP_EPOCHS),
    eta_min=LEARNING_RATE * 0.1,
)

# Chain the two phases, switching after WARMUP_EPOCHS scheduler steps.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, decay_scheduler],
    milestones=[WARMUP_EPOCHS],
)

In [None]:
# TensorBoard event files for this run are written under LOGS_ROOT.
summary_writer = SummaryWriter(log_dir=LOGS_ROOT)

In [None]:
# View augmenter used to build perturbed views during training and in the kNN
# consistency check.
# NOTE(review): the *_col arguments are presumably column indices into the node
# feature matrix; confirm against ViewAugmenter.
augmenter = ViewAugmenter(
    # value feature
    value_col=0,
    value_jitter=0.05,
    # block-height feature
    block_height_col=1,
    block_height_scale_range=(0.99, 1.01),
    block_height_shift_range=(-3.0, 3.0),
    # degree features
    degree_cols=(2, 3),
    degree_jitter=0.05,
)
augmenter = augmenter.to(device)

In [None]:
def evaluate_root_clustering(encoder, loader, device, n_clusters=10, random_state=42):
    """Cluster root embeddings with KMeans and score them with cosine silhouette.

    Args:
        encoder: Model passed through to ``embed_roots``.
        loader: Data loader passed through to ``embed_roots``.
        device: Device passed through to ``embed_roots``.
        n_clusters: Number of KMeans clusters.
        random_state: Seed for KMeans initialisation (defaults to the
            previously hard-coded 42).

    Returns:
        ``(score, Z, labels)``. ``score`` is the silhouette score, or the
        sentinel ``-1.0`` when the silhouette is undefined. ``Z`` and ``labels``
        are ``None`` when clustering was not attempted at all.
    """
    Z = embed_roots(encoder, loader, device)  # [G, D], NumPy
    # in case someone swaps embed_roots to return a tensor:
    if hasattr(Z, "detach"):
        Z = Z.detach().cpu().numpy()

    # The silhouette score needs 2 <= n_labels <= n_samples - 1, so having
    # exactly as many samples as clusters is also not scorable (the previous
    # `<` guard let that case through and silhouette_score then raised).
    if n_clusters < 2 or Z.shape[0] <= n_clusters:
        return -1.0, None, None

    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    labels = km.fit_predict(Z)

    # Collapsed embeddings can leave KMeans with a single distinct label, for
    # which the silhouette is undefined; report the sentinel instead of raising.
    if np.unique(labels).size < 2:
        return -1.0, Z, labels

    score = silhouette_score(Z, labels, metric="cosine")
    return score, Z, labels

In [None]:
# Endpoints of the root temperature schedule (linearly annealed over the run).
T_START, T_END = 0.5, 0.08

best_val = float("-inf")
patience_counter = 0

# Make sure the checkpoint directory exists before spending an epoch on training.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)

for epoch in tqdm(range(1, EPOCHS + 1)):
    # Linear anneal from T_START (epoch 1) to T_END (last epoch), never below T_END.
    temp_root = max(T_END, T_START - (T_START - T_END) * (epoch-1) / max(1, EPOCHS-1))

    train_out = train_one_epoch(
        encoder=encoder,
        loader=train_loader,
        optimizer=optimizer,
        device=device,
        epoch=epoch,
        writer=summary_writer,
        augmenter=augmenter,
        temp_root=temp_root,
        proj_head=proj_head
    )

    # Learning rate that was in effect during this epoch; read it before the
    # schedule is advanced so the logged value matches the epoch it is logged for.
    current_lr = optimizer.param_groups[0]["lr"]
    summary_writer.add_scalar("lr", float(current_lr), epoch)

    # Advance the warmup -> cosine schedule once per epoch. The scheduler is not
    # passed to train_one_epoch, so without this call the LR would stay at the
    # warmup start factor for the entire run.
    scheduler.step()

    sil_score, _, _ = evaluate_root_clustering(
        encoder=encoder,
        loader=val_loader,
        device=device,
        n_clusters=NUM_CLUSTERS
    )

    knn_consistency = evaluate_root_knn_consistency(
        encoder=encoder,
        loader=val_loader,
        device=device,
        augmenter=augmenter,
        k=KNN_K
    )

    # Model-selection metric: blend of silhouette and kNN consistency; when the
    # silhouette is negative (including the -1.0 "undefined" sentinel) only the
    # kNN consistency is used.
    val_metric = knn_consistency if sil_score < 0 else BLEND_ALPHA * sil_score + (1.0 - BLEND_ALPHA) * knn_consistency

    tr_loss = float(train_out.get("loss", float("nan")))
    tr_root_loss = float(train_out.get("root_loss", float("nan")))

    summary_writer.add_scalar("val/silhouette_root", float(sil_score), epoch)
    summary_writer.add_scalar("val/knn_consistency_root", float(knn_consistency), epoch)
    summary_writer.add_scalar("val/metric", float(val_metric), epoch)
    summary_writer.add_scalar("train/total_loss_epoch", tr_loss, epoch)
    summary_writer.add_scalar("train/root_loss_epoch",  tr_root_loss, epoch)

    print(
        f"Epoch {epoch:02d}/{EPOCHS}\t|\t"
        f"Train loss {tr_loss:.4f} "
        f"(root loss {tr_root_loss:.4f})\t|\t"
        f"Val {val_metric:.4f} (sil {sil_score:.4f}, knn {knn_consistency:.4f})\t|\t"
        f"LR {current_lr:.6f}"
    )

    if val_metric > best_val:
        best_val = val_metric
        torch.save({
            "epoch": epoch,
            "model_state": encoder.state_dict(),
            # Extra keys: the optimizer state covers the projection head too, so
            # its weights and the scheduler position are stored alongside it.
            "proj_head_state": proj_head.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_val_metric": float(best_val),
            "silhouette": float(sil_score),
            "knn_consistency": float(knn_consistency),
        }, MODEL_SAVE_PATH)
        print(f"New best model saved (val metric={best_val:.4f})\n")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"Patience {patience_counter}/{PATIENCE}\n")
        if patience_counter >= PATIENCE:
            print(f"\nEarly stop: no improvement in {PATIENCE} epochs.\n")
            break

# Push any buffered scalars to disk so TensorBoard shows the complete run.
summary_writer.flush()

print("\n\nSuccessfully finished training!!")