In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.utils as utils
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import seaborn as sns
import args
import time

In [3]:
def prepare_gtnvae_data(root, dataset_name='Cora', pos_dim=16):
    from torch_geometric.datasets import Planetoid
    from torch_geometric.utils import to_dense_adj, train_test_split_edges, get_laplacian

    dataset = Planetoid(root=root, name=dataset_name)
    data = dataset[0]
    data_split = train_test_split_edges(data)
    N = data.num_nodes

    # Dense node features
    x = data.x.unsqueeze(0)

    # Dense adjacencies
    def to_dense_binary(edge_index):
        adj = to_dense_adj(edge_index, max_num_nodes=N)[0]
        return (adj > 0).float()

    train_adj = to_dense_binary(data_split.train_pos_edge_index).unsqueeze(0)
    val_adj   = to_dense_binary(data_split.val_pos_edge_index).unsqueeze(0)
    test_adj  = to_dense_binary(data_split.test_pos_edge_index).unsqueeze(0)

    # Positional encodings from train edges
    edge_index, edge_weight = get_laplacian(data_split.train_pos_edge_index, normalization='sym', num_nodes=N)
    L = torch.sparse_coo_tensor(edge_index, edge_weight, (N, N)).to_dense()
    eigval, eigvec = torch.linalg.eigh(L)
    pos_enc = eigvec[:, 1:pos_dim+1].unsqueeze(0)

    return x, pos_enc, train_adj, val_adj, test_adj, data.y

x, pos_enc, train_adj, val_adj, test_adj, y = prepare_gtnvae_data("../data/", 'Cora', pos_dim=128)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [9]:
print(
    f"\n===== DATASET SUMMARY =====\n"
    f"Node Feature Shape (x):           {x.shape}  -> [B, N, d_node_in]\n"
    f"Positional Encoding Shape:        {pos_enc.shape}  -> [B, N, d_pos]\n"
    f"Train Adjacency Shape:            {train_adj.shape}  -> [B, N, N]\n"
    f"Validation Adjacency Shape:       {val_adj.shape}  -> [B, N, N]\n"
    f"Test Adjacency Shape:             {test_adj.shape}  -> [B, N, N]\n"
    f"--------------------------------------\n"
    f"Total Nodes (N):                  {train_adj.size(-1)}\n"
    f"Feature Dimension (d_node_in):    {x.size(-1)}\n"
    f"Positional Dim (d_pos):           {pos_enc.size(-1)}\n"
    f"Total Train Edges:                {int(train_adj.sum().item())}\n"
    f"Total Val Edges:                  {int(val_adj.sum().item())}\n"
    f"Total Test Edges:                 {int(test_adj.sum().item())}\n"
    f"======================================\n"
)


===== DATASET SUMMARY =====
Node Feature Shape (x):           torch.Size([1, 2708, 1433])  -> [B, N, d_node_in]
Positional Encoding Shape:        torch.Size([1, 2708, 128])  -> [B, N, d_pos]
Train Adjacency Shape:            torch.Size([1, 2708, 2708])  -> [B, N, N]
Validation Adjacency Shape:       torch.Size([1, 2708, 2708])  -> [B, N, N]
Test Adjacency Shape:             torch.Size([1, 2708, 2708])  -> [B, N, N]
--------------------------------------
Total Nodes (N):                  2708
Feature Dimension (d_node_in):    1433
Positional Dim (d_pos):           128
Total Train Edges:                8976
Total Val Edges:                  263
Total Test Edges:                 527



In [13]:
import torch.optim as optim
device = torch.device("mps")  # or "cuda" / "cpu"
load = None

model = GTN_VAE(
    input_dim=x.size(-1),
    pos_dim=pos_enc.size(-1),
    hidden_dim=128,
    n_layers=4,
    n_heads=4
).to(device)

if load:
    model = torch.load(load, weights_only=False)

In [15]:
x, pos_enc, train_adj, val_adj, test_adj, y = (
    x.to(device),
    pos_enc.to(device),
    train_adj.to(device),
    val_adj.to(device),
    test_adj.to(device),
    y.to(device)
)

# Weighting setup
pos_weight = float(train_adj.numel() - train_adj.sum()) / train_adj.sum()
norm = train_adj.numel() / ((train_adj.numel() - train_adj.sum()) * 2)
weight_mask = train_adj.flatten() == 1
weight_tensor = torch.ones(weight_mask.size(0), device=device)
weight_tensor[weight_mask] = pos_weight

In [6]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [7]:
from sklearn.metrics import roc_auc_score, average_precision_score

def compute_metrics(adj_pred, adj_true):
    adj_true_np = adj_true.detach().cpu().numpy().flatten()
    adj_pred_np = adj_pred.detach().cpu().numpy().flatten()
    roc = roc_auc_score(adj_true_np, adj_pred_np)
    ap = average_precision_score(adj_true_np, adj_pred_np)
    return roc, ap

def compute_accuracy(adj_pred, adj_true, threshold=0.5):
    pred_bin = (adj_pred > threshold).float()
    correct = (pred_bin == adj_true).float().sum()
    return (correct / adj_true.numel()).item()

def weighted_vae_loss(adj_recon, adj_true, mu, logvar, weight_tensor, norm):
    bce = F.binary_cross_entropy(
        adj_recon.view(-1),
        adj_true.view(-1),
        weight=weight_tensor,
    )
    recon_loss = norm * bce
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss, recon_loss, kl_loss

def balanced_edge_indices(adj_true, num_samples=5000):
    """
    Returns equal numbers of positive and negative edge indices.
    Ensures class balance and guards against small graphs.
    """
    # Flatten and find all positive/negative indices
    pos_idx = (adj_true.view(-1) == 1).nonzero(as_tuple=False).view(-1)
    neg_idx = (adj_true.view(-1) == 0).nonzero(as_tuple=False).view(-1)

    # How many to sample per class
    num_pos = min(len(pos_idx), num_samples // 2)
    num_neg = num_pos  # strict equality
    if num_pos == 0 or num_neg == 0:
        raise ValueError("No positive or negative edges to sample from.")

    # Random balanced sample
    pos_idx = pos_idx[torch.randperm(len(pos_idx))[:num_pos]]
    neg_idx = neg_idx[torch.randperm(len(neg_idx))[:num_neg]]
    idx = torch.cat([pos_idx, neg_idx])

    return idx, num_pos, num_neg


def sampled_vae_loss(adj_pred, adj_true, mu, logvar, num_samples=5000):
    """
    Balanced random BCE + KL loss for VGAE-style models.
    """
    idx, num_pos, num_neg = balanced_edge_indices(adj_true, num_samples)
    y_true = adj_true.view(-1)[idx]
    y_pred = adj_pred.view(-1)[idx]

    # BCE over balanced sample
    bce = F.binary_cross_entropy(y_pred, y_true)
    kl  = -0.001 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return bce + kl, bce, kl


def sampled_metrics(adj_pred, adj_true, num_samples=5000):
    """
    Computes ROC-AUC, AP, and Accuracy on balanced edge samples.
    """
    idx, num_pos, num_neg = balanced_edge_indices(adj_true, num_samples)
    y_true = adj_true.view(-1)[idx].cpu().numpy()
    y_pred = adj_pred.view(-1)[idx].detach().cpu().numpy()

    roc = roc_auc_score(y_true, y_pred)
    ap  = average_precision_score(y_true, y_pred)
    acc = ((torch.tensor(y_pred) > 0.5).float().numpy() == y_true).mean()
    return roc, ap, acc

In [8]:
def model_in_out(model_comp, inputs):
    output = model_comp(*inputs)
    output_mean = output.mean()
    output_std = output.std()
    print(f"output: {output}\noutput_mean:{output_mean}\noutput_std:{output_std}")
    return output

def model_debug_auto(model_comp, inputs, verbose=True):
    # Run forward pass
    if not isinstance(inputs, (tuple, list)):
        inputs = (inputs,)
    
    output = model_comp(*inputs)
    
    def summarize_tensor(name, tensor):
        if isinstance(tensor, torch.Tensor):
            return {
                "name": name,
                "shape": tuple(tensor.shape),
                "dtype": tensor.dtype,
                "min": tensor.min().item(),
                "max": tensor.max().item(),
                "mean": tensor.mean().item(),
                "std": tensor.std().item()
            }
        elif isinstance(tensor, (list, tuple)):
            return [summarize_tensor(f"{name}[{i}]", t) for i, t in enumerate(tensor)]
        elif isinstance(tensor, dict):
            return {k: summarize_tensor(f"{name}.{k}", v) for k, v in tensor.items()}
        else:
            return {name: str(type(tensor))}

    # Summarize input(s) and output(s)
    input_summary = [summarize_tensor(f"input[{i}]", inp) for i, inp in enumerate(inputs)]
    output_summary = summarize_tensor("output", output)

    if verbose:
        print("=== INPUT SUMMARY ===")
        for s in input_summary:
            print(s)
        print("\n=== OUTPUT SUMMARY ===")
        print(output_summary)
    
    return output, input_summary, output_summary

# Model Training

In [8]:
num_epochs = 200

for epoch in range(1, num_epochs + 1):
    # ---------- TRAIN ----------
    model.train()
    optimizer.zero_grad()

    # Forward pass on training adjacency
    adj_recon = model(x, pos_enc, train_adj)

    # Compute sampled loss (balanced positives/negatives)
    total_loss, recon_loss, kl_loss = sampled_vae_loss(
        adj_recon, train_adj, model.z_mean, model.z_logvar, num_samples=8000
    )

    total_loss.backward()
    optimizer.step()

    # Metrics (sampled)
    roc_train, ap_train, acc_train = sampled_metrics(adj_recon, train_adj, num_samples=8000)

    # ---------- VALIDATION ----------
    model.eval()
    with torch.no_grad():
        adj_val_recon = model(x, pos_enc, val_adj)
        val_loss, _, _ = sampled_vae_loss(
            adj_val_recon, val_adj, model.z_mean, model.z_logvar, num_samples=400
        )
        roc_val, ap_val, acc_val = sampled_metrics(adj_val_recon, val_adj, num_samples=400)

    # ---------- LOGGING ----------
    print(
        f"Epoch [{epoch:03d}/{num_epochs}] | "
        f"Train Loss: {total_loss.item():.4f} | Val Loss: {val_loss.item():.4f} | "
        f"Train Acc: {acc_train:.4f} | Val Acc: {acc_val:.4f} | "
        f"Train ROC: {roc_train:.4f} | Val ROC: {roc_val:.4f} | "
        f"Train AP: {ap_train:.4f} | Val AP: {ap_val:.4f}"
    )

Epoch [001/200] | Train Loss: 4.7473 | Val Loss: 4.5702 | Train Acc: 0.5032 | Val Acc: 0.4950 | Train ROC: 0.5038 | Val ROC: 0.5086 | Train AP: 0.5019 | Val AP: 0.5080
Epoch [002/200] | Train Loss: 4.2339 | Val Loss: 3.9358 | Train Acc: 0.5180 | Val Acc: 0.4975 | Train ROC: 0.5258 | Val ROC: 0.5035 | Train AP: 0.5148 | Val AP: 0.5013
Epoch [003/200] | Train Loss: 3.8169 | Val Loss: 3.7201 | Train Acc: 0.5201 | Val Acc: 0.5600 | Train ROC: 0.5242 | Val ROC: 0.5543 | Train AP: 0.5154 | Val AP: 0.5398
Epoch [004/200] | Train Loss: 3.8571 | Val Loss: 3.8252 | Train Acc: 0.5051 | Val Acc: 0.5075 | Train ROC: 0.5136 | Val ROC: 0.4986 | Train AP: 0.5097 | Val AP: 0.4936
Epoch [005/200] | Train Loss: 3.8008 | Val Loss: 3.2760 | Train Acc: 0.5224 | Val Acc: 0.5150 | Train ROC: 0.5229 | Val ROC: 0.5214 | Train AP: 0.5139 | Val AP: 0.5096
Epoch [006/200] | Train Loss: 3.6599 | Val Loss: 3.3327 | Train Acc: 0.5200 | Val Acc: 0.5275 | Train ROC: 0.5248 | Val ROC: 0.5325 | Train AP: 0.5153 | Val AP:

In [9]:
# torch.manual_seed(42)
def test_model(model, x, pos_e, test_adj, num_samples):
    model.eval()
    with torch.no_grad():
        adj_test_recon = model(x, pos_enc, test_adj)
    
        # Balanced sampled VAE loss
        test_loss, _, _ = sampled_vae_loss(
            adj_test_recon, test_adj, model.z_mean, model.z_logvar, num_samples=1054

        )
    
        # Sampled metrics (balanced)
        roc_test, ap_test, acc_test = sampled_metrics(
            adj_test_recon, test_adj, num_samples=num_samples
        )
    return test_loss, acc_test, roc_test, ap_test

exps = 20
losses = []
accs = []
rocs = []
precs = []
for i in range(exps):
    torch.manual_seed(42 + i)
    t_loss, t_acc, t_roc, t_prec = test_model(model, x, pos_enc, test_adj, num_samples=1000)
    losses.append(t_loss.detach().cpu().item())
    accs.append(t_acc)
    rocs.append(t_roc)
    precs.append(t_prec)

In [10]:
losses = np.array(losses)
accs = np.array(accs)
rocs = np.array(rocs)
precs = np.array(precs)

In [11]:
print("\n=== FINAL TEST RESULTS ===")
print(
    f"Test Loss: {losses.mean():.4f} +- {losses.std()} | \n"
    f"Test Acc: {accs.mean():.4f} +- {accs.std()} | \n"
    f"ROC-AUC: {rocs.mean():.4f} +- {rocs.std()} |\nAP: {precs.mean():.4f} +- {precs.std()}"
)


=== FINAL TEST RESULTS ===
Test Loss: 0.4981 +- 0.011562777382434627 | 
Test Acc: 0.7343 +- 0.01081711606667878 | 
ROC-AUC: 0.9161 +- 0.005897844162064628 |
AP: 0.9235 +- 0.00753070491905668


=== FINAL TEST RESULTS ===
kl: 0.005
n_heads = 4
dim = 128
n_layers = 4

Test Loss: 0.4868 +- 0.009812361955942385 | 
Test Acc: 0.7414 +- 0.008182145195485114 | 
ROC-AUC: 0.9176 +- 0.005479434469906512 |
AP: 0.9240 +- 0.007316556545360688

kl: 0.0005
lr: 0.0001

=== FINAL TEST RESULTS ===
Test Loss: 0.5380 +- 0.014766905705631252 | 
Test Acc: 0.7196 +- 0.008816320094007483 | 
ROC-AUC: 0.8679 +- 0.007809779471278348 |
AP: 0.8753 +- 0.009753158063202054

# Best Result

kl: 0.0005
lr: 1e-3
n_heads = 4, n_layers=4, pos_dim=128, hidden_dim=128

=== FINAL TEST RESULTS ===
Test Loss: 0.4818 +- 0.007868968959063699 | 
Test Acc: 0.7419 +- 0.01077253452071518 | 
ROC-AUC: 0.9204 +- 0.006907713010830738 |
AP: 0.9266 +- 0.008169060892064308

In [12]:
torch.save(model, "92_4H_4L_128.pt")

In [16]:
from analysis_model import run_full_analysis
run_full_analysis(model, x, pos_enc, train_adj, y=y)

  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = qr_normalizer(A @ Q)
  Q, _ = qr_normalizer(A @ Q)
  Q, _ = qr_normalizer(A @ Q)
  B = Q.T @ M
  B = Q.T @ M
  B = Q.T @ M
  U = Q @ Uhat
  U = Q @ Uhat
  U = Q @ Uhat
  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = normalizer(A.T @ Q)
  Q, _ = qr_normalizer(A @ Q)
  Q, _ = qr_normalizer(A @ Q)
  Q, _ = qr_normalizer(A @ Q)
  B = Q.T @ M
  B = Q.T @ M
  B = Q.T @ M
  U = Q @ Uhat
  U = Q @ Uhat
  U = Q @ Uhat


{'latent_2d': '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/latent_tsne_2d_20251021_093826.png',
 'latent_3d': '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/latent_tsne_3d_20251021_093839.html',
 'adj_comparison': '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/adjacency_comparison_20251021_093850.png',
 'link_confidence': '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/link_confidence_20251021_093902.png',
 'attention_stats': '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/attention_stats_20251021_093906.png',
 'attention_layer_files': ['/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/attention_layer_0_20251021_093924.png',
  '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/attention_layer_1_20251021_093956.png',
  '/Users/sid/COLLEGE_MATERIALS/Research/vgae_pytorch/visualizations/attention_layer_2_20251021_094028.png',
  '/Users/sid/COLLEGE_MATERIALS/