In [None]:
# ================================
# Environment setup (Colab)
# ================================
!pip -q install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install torch-geometric torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
print("✅ PyTorch & PyG install attempted.")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m86.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m90.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25h✅ PyTorch & PyG install attempted.


In [None]:
# ================================
# Imports & Utilities
# ================================
import os, random, copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, SAGEConv, GraphSAGE
from torch_geometric.utils import train_test_split_edges, to_undirected, subgraph
from torch_geometric.transforms import RandomLinkSplit
from sklearn.metrics import roc_auc_score, average_precision_score

# Device & seed
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

def set_seed(seed: int = 42):
    """Seed every RNG the notebook draws from and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, PyTorch's CPU RNG and the RNGs of
    all CUDA devices, then turns on cuDNN deterministic mode and turns off benchmark mode.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

Device: cuda


In [None]:
#================================
# Load Citeseer + edge split for link prediction (70/10/20)
#================================
dataset = Planetoid(root='/content/data/Citeseer', name='Citeseer')
num_feats = dataset.num_node_features
data = dataset[0]
# Store every edge in both directions before splitting.
data.edge_index = to_undirected(data.edge_index)
# RandomLinkSplit holds out 10% of the edges for validation and 20% for test.
# add_negative_train_samples=True adds sampled non-edges (label 0) to the training
# supervision edges, so the training split is supervised with both classes.
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    is_undirected=True,
    add_negative_train_samples=True,
)
train_data, val_data, test_data = transform(data)
print(train_data)
print('Features:', num_feats)

Data(x=[3327, 3703], edge_index=[2, 6374], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[6374], edge_label_index=[2, 6374])
Features: 3703


In [None]:
#================================
# Config (paper-ish, Citeseer for link prediction)
#================================
CFG = {
    # Suspect-model pool sizes (positives = derived from the target, negatives = independent).
    "POS_TRAIN": 50, "POS_TEST": 50,
    "NEG_TRAIN": 50, "NEG_TEST": 50,
    # Attacks used to generate positive suspects.
    "USE_FT_LAST": True, "USE_FT_ALL": True,
    "USE_PR_LAST": True, "USE_PR_ALL": True,
    "USE_DISTILL": True,
    "DISTILL_STEPS": 250,    # paper ~1000; adjust if GPU/time allows
    # Fingerprint graphs.
    "FP_P": 64,              # number of fingerprint graphs
    "FP_NODES": 32,          # nodes per fingerprint graph
    "FP_SAMPLE_M": 32,       # Number of node pairs to sample per fingerprint graph
    "FP_EDGE_INIT_P": 0.05,  # per-entry probability for the initial random adjacency
    "FP_EDGE_TOPK": 96,      # pairs considered per graph in each edge update
    "EDGE_LOGIT_STEP": 2.5,  # amount added to / subtracted from an edge logit per update
    # Joint optimisation schedule.
    "OUTER_ITERS": 20,
    "FP_STEPS": 5,
    "V_STEPS": 10,
    # Learning rates / weight decay.
    "LR_TARGET": 0.005,
    "WD_TARGET": 5e-4,
    "LR_V": 1e-3,
    "LR_X": 1e-3,
    "SEED": 1,
}
print(CFG)

{'POS_TRAIN': 50, 'POS_TEST': 50, 'NEG_TRAIN': 50, 'NEG_TEST': 50, 'USE_FT_LAST': True, 'USE_FT_ALL': True, 'USE_PR_LAST': True, 'USE_PR_ALL': True, 'USE_DISTILL': True, 'DISTILL_STEPS': 250, 'FP_P': 64, 'FP_NODES': 32, 'FP_SAMPLE_M': 32, 'FP_EDGE_INIT_P': 0.05, 'FP_EDGE_TOPK': 96, 'EDGE_LOGIT_STEP': 2.5, 'OUTER_ITERS': 20, 'FP_STEPS': 5, 'V_STEPS': 10, 'LR_TARGET': 0.005, 'WD_TARGET': 0.0005, 'LR_V': 0.001, 'LR_X': 0.001, 'SEED': 1}


In [None]:
#================================
# Define 3-layer GNN models (output embeddings for link prediction)
#================================
class GCN(nn.Module):
    """Three stacked GCNConv layers mapping node features to `hidden`-dim embeddings.

    ReLU and dropout follow the first two convolutions. The third layer's output is
    returned without an activation so it can be scored with a dot product. The layers
    stay as `conv1`/`conv2`/`conv3` because other cells access them by those names.
    """

    def __init__(self, in_channels, hidden, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, hidden)  # Output embeddings of size hidden
        self.dropout = dropout

    def forward(self, x, edge_index):
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.dropout(F.relu(conv(h, edge_index)), p=self.dropout, training=self.training)
        return self.conv3(h, edge_index)

class GraphSAGE(nn.Module):
    """Three stacked SAGEConv layers mapping node features to `hidden`-dim embeddings.

    ReLU and dropout follow the first two convolutions, and the third layer's output is
    returned without an activation. NOTE: this class shadows the `GraphSAGE` imported
    from torch_geometric.nn in the imports cell.
    """

    def __init__(self, in_channels, hidden, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.conv3 = SAGEConv(hidden, hidden)  # Output embeddings of size hidden
        self.dropout = dropout

    def forward(self, x, edge_index):
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.dropout(F.relu(conv(h, edge_index)), p=self.dropout, training=self.training)
        return self.conv3(h, edge_index)

In [None]:
#================================
# Train helpers for link prediction
#================================
@torch.no_grad()
def evaluate(model, loader_data):
    """Return the edge-classification accuracy of `model` on one link-prediction split.

    Node embeddings are computed on `loader_data.edge_index`. Each supervision pair in
    `edge_label_index` is scored as sigmoid(<h_u, h_v>) and predicted positive when the
    score exceeds 0.5. The prediction is then compared with `edge_label`. Leaves `model`
    in eval mode.
    """
    model.eval()
    emb = model(loader_data.x.to(device), loader_data.edge_index.to(device))
    src, dst = loader_data.edge_label_index.to(device)
    targets = loader_data.edge_label.to(device).float()
    probs = torch.sigmoid((emb[src] * emb[dst]).sum(dim=-1))
    hits = (probs > 0.5).float() == targets
    return hits.float().mean().item()
def train_link_pred(model, train_data, val_data, epochs=200, lr=0.005, wd=5e-4,
                    test_data=None, log_every=20):
    """Train `model` for link prediction and restore its best-validation weights.

    Each epoch takes one full-batch Adam step on the training supervision pairs
    (`train_data.edge_label_index` / `edge_label`). A pair (u, v) is scored by the dot
    product <h_u, h_v> of the model's node embeddings. Validation accuracy is then
    measured with `evaluate`, and the best-validation state is reloaded at the end.

    Args:
        model: GNN mapping (x, edge_index) -> node embeddings.
        train_data, val_data: link-prediction splits (as produced by RandomLinkSplit).
        epochs, lr, wd: number of full-batch steps, Adam learning rate, weight decay.
        test_data: optional test split reported after training. When omitted, the
            notebook-level `test_data` is used (the previous implicit behaviour). If
            neither exists, only validation accuracy is reported.
        log_every: print a progress line every `log_every` epochs (0 disables logging).

    Returns:
        The model, moved to `device`, with the best-validation weights loaded.
    """
    if test_data is None:
        # Backward-compatible fallback: this helper used to read the global split implicitly.
        test_data = globals().get('test_data')
    model = model.to(device)
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    opt = Adam(model.parameters(), lr=lr, weight_decay=wd)
    best = {'val': 0.0, 'state': None}
    # Supervision pairs and labels do not change across epochs.
    src, dst = train_data.edge_label_index
    edge_label = train_data.edge_label.float()
    for ep in range(epochs):
        model.train(); opt.zero_grad()
        h = model(train_data.x, train_data.edge_index)
        logits = (h[src] * h[dst]).sum(dim=-1)
        # BCE on raw logits is the numerically stable form of sigmoid + BCE: large dot
        # products saturate the sigmoid to exactly 0/1, where plain BCE clamps the log.
        loss = F.binary_cross_entropy_with_logits(logits, edge_label)
        loss.backward(); opt.step()
        va = evaluate(model, val_data)
        if va > best['val']:
            best['val'] = va; best['state'] = copy.deepcopy(model.state_dict())
        if log_every and ep % log_every == 0:
            print(f"Epoch {ep:03d} | loss {loss.item():.4f} | val {va:.3f}")
    if best['state'] is not None:
        model.load_state_dict(best['state'])
    if test_data is not None:
        te = evaluate(model, test_data.to(device))
        print(f"✅ Final (best-val) | val {best['val']:.3f} | test {te:.3f}")
    else:
        print(f"✅ Final (best-val) | val {best['val']:.3f}")
    return model

In [None]:
# Train target model f (GCN for link prediction)
#================================
# Seed first so the GCN's initial weights are reproducible.
set_seed(CFG["SEED"])
model_f = train_link_pred(
    GCN(num_feats, hidden=16, dropout=0.5),
    train_data,
    val_data,
    epochs=200,
    lr=CFG["LR_TARGET"],
    wd=CFG["WD_TARGET"],
)

Epoch 000 | loss 0.6900 | val 0.504
Epoch 020 | loss 0.4350 | val 0.724
Epoch 040 | loss 0.3384 | val 0.703
Epoch 060 | loss 0.2734 | val 0.686
Epoch 080 | loss 0.2241 | val 0.699
Epoch 100 | loss 0.2020 | val 0.698
Epoch 120 | loss 0.1766 | val 0.698
Epoch 140 | loss 0.1746 | val 0.691
Epoch 160 | loss 0.1564 | val 0.688
Epoch 180 | loss 0.1614 | val 0.709
✅ Final (best-val) | val 0.726 | test 0.698


In [None]:
#=========================================
# Build suspect models (F+ and F−)  [BUDGETED, adapted for link prediction]
#=========================================
@torch.no_grad()
def reset_module(m):
    """Re-initialise every module in `m.modules()` (`m` included) that defines `reset_parameters`, in that order."""
    resettable = (layer for layer in m.modules() if hasattr(layer, 'reset_parameters'))
    for layer in resettable:
        layer.reset_parameters()

def _fit_link_pred_steps(m, train_data, epochs, lr):
    """Run `epochs` full-batch Adam steps of link-prediction BCE on the trainable params of `m`; return `m` in eval mode."""
    opt = Adam([p for p in m.parameters() if p.requires_grad], lr=lr)
    # Move inputs explicitly rather than relying on `train_data` having been moved in place elsewhere.
    x = train_data.x.to(device)
    edge_index = train_data.edge_index.to(device)
    src, dst = train_data.edge_label_index.to(device)
    edge_label = train_data.edge_label.to(device).float()
    for _ in range(epochs):
        m.train(); opt.zero_grad()
        h = m(x, edge_index)
        logits = (h[src] * h[dst]).sum(dim=-1)
        # Numerically stable equivalent of sigmoid followed by binary_cross_entropy.
        loss = F.binary_cross_entropy_with_logits(logits, edge_label)
        loss.backward(); opt.step()
    return m.eval()

def ft_model(base_model, train_data, last_only=True, epochs=10, lr=0.005, seed=123):
    """Fine-tune a copy of `base_model` to produce a positive (derived) suspect.

    last_only=True trains only `conv3`; last_only=False trains every parameter.
    `base_model` itself is not modified. Returns the fine-tuned copy in eval mode.
    """
    set_seed(seed)
    m = copy.deepcopy(base_model).to(device)
    for p in m.parameters():
        p.requires_grad_(not last_only)
    for p in m.conv3.parameters():
        p.requires_grad_(True)  # the last layer is always fine-tuned
    return _fit_link_pred_steps(m, train_data, epochs, lr)

def pr_model(base_model, train_data, last_only=True, epochs=10, lr=0.005, seed=456):
    """Partially retrain a copy of `base_model` to produce a positive suspect.

    Re-initialises `conv3` (last_only=True) or every resettable layer (last_only=False),
    then trains all parameters. Returns the retrained copy in eval mode.

    NOTE(review): with last_only=False the copy presumably keeps none of base_model's
    learned weights (only its architecture), yet it is still labelled positive. Confirm
    this matches the intended threat model.
    """
    set_seed(seed)
    m = copy.deepcopy(base_model).to(device)
    reset_module(m.conv3 if last_only else m)
    # The copy inherits `requires_grad` flags from base_model; retraining needs all of them enabled.
    for p in m.parameters():
        p.requires_grad_(True)
    return _fit_link_pred_steps(m, train_data, epochs, lr)

def make_student(arch='GCN', hidden=16):
    """Build an untrained student on `device`: a GCN when arch == 'GCN', otherwise a GraphSAGE."""
    student_cls = GCN if arch == 'GCN' else GraphSAGE
    return student_cls(num_feats, hidden, dropout=0.5).to(device)

def random_subgraph_idx(n, keep_ratio=0.6, seed=7):
    """Return `int(n * keep_ratio)` distinct node ids from range(n), in random order.

    Draws from a private CPU generator seeded with `seed`, so the global torch RNG is not consumed.
    """
    n_keep = int(n * keep_ratio)
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed))
    return perm[:n_keep]

def distill_from_teacher(teacher, data, arch='GCN', steps=250, lr=0.01, seed=777):
    """Train a fresh student to imitate `teacher`'s node embeddings (a distillation suspect).

    Each step samples a random node subset covering 50-80% of the nodes (with a fresh
    seed per step) and takes the induced subgraph. It then minimises the MSE between the
    student's and the teacher's embeddings on that subgraph. Puts `teacher` in eval mode.
    Returns the student in eval mode.
    """
    set_seed(seed)
    student = make_student(arch, hidden=16)
    opt = Adam(student.parameters(), lr=lr)
    mse = nn.MSELoss()
    # The teacher only supplies regression targets, so run it without dropout.
    teacher.eval()
    x_all = data.x.to(device); ei_all = data.edge_index.to(device)
    for t in range(steps):
        keep_ratio = float(torch.empty(1).uniform_(0.5, 0.8))
        idx = random_subgraph_idx(data.num_nodes, keep_ratio=keep_ratio, seed=seed+t).to(device)
        # Pass num_nodes explicitly. Otherwise PyG infers it from the largest node id in
        # edge_index, and a sampled isolated node with a larger id would be out of range.
        ei_sub, _ = subgraph(idx, ei_all, relabel_nodes=True, num_nodes=data.num_nodes)
        x_sub = x_all[idx]
        with torch.no_grad():
            h_t = teacher(x_sub, ei_sub)
        student.train(); opt.zero_grad()
        h_s = student(x_sub, ei_sub)
        loss = mse(h_s, h_t)
        loss.backward(); opt.step()
    return student.eval()

#---- budget splitter ----
def _distribute_budget(total, keys):
    if not keys: return {}
    base = total // len(keys); rem = total - base*len(keys)
    out = {k: base for k in keys}
    for k in keys[:rem]: out[k] += 1
    return out

#===== Positives (F+) =====
# Each enabled attack gets an (almost) equal share of the positive budget.
pos_total = CFG["POS_TRAIN"] + CFG["POS_TEST"]
pos_keys = [k for k in ("FT_LAST", "FT_ALL", "PR_LAST", "PR_ALL", "DISTILL") if CFG["USE_" + k]]
pos_budget = _distribute_budget(pos_total, pos_keys)

# (builder, last_only) for the fine-tuning / partial-retraining attacks.
_pos_attacks = {
    "FT_LAST": (ft_model, True),
    "FT_ALL":  (ft_model, False),
    "PR_LAST": (pr_model, True),
    "PR_ALL":  (pr_model, False),
}

F_pos_all = []
seed_base = 10
for key in pos_keys:
    cnt = pos_budget[key]
    if key == "DISTILL":
        # Half GCN students, the rest (including any odd one) GraphSAGE; seeds run from 1400 upward.
        arches = ['GCN'] * (cnt // 2) + ['SAGE'] * (cnt - cnt // 2)
        F_pos_all.extend(
            distill_from_teacher(model_f, data, arch=arch, steps=CFG["DISTILL_STEPS"], seed=1000 + i)
            for i, arch in enumerate(arches, 400)
        )
    else:
        builder, last_only = _pos_attacks[key]
        F_pos_all.extend(
            builder(model_f, train_data, last_only=last_only, epochs=10, seed=s)
            for s in range(seed_base, seed_base + cnt)
        )
    seed_base += cnt

assert len(F_pos_all) == pos_total, f"Expected {pos_total} positives, got {len(F_pos_all)}"

#===== Negatives (F−) =====
# Independently trained models (never derived from model_f), split between GCN and GraphSAGE.
def _train_negative(arch_cls, hidden, seed):
    """Seed, build and fully train one independent negative model; return it in eval mode."""
    set_seed(seed)
    model = train_link_pred(arch_cls(num_feats, hidden, dropout=0.5), train_data, val_data,
                            epochs=120, lr=CFG["LR_TARGET"], wd=CFG["WD_TARGET"])
    return model.eval()

F_neg_all = []
neg_total = CFG["NEG_TRAIN"] + CFG["NEG_TEST"]
neg_keys = ["GCN", "SAGE"]
neg_budget = _distribute_budget(neg_total, neg_keys)
seed_base = 500
F_neg_all += [_train_negative(GCN, 16, s) for s in range(seed_base, seed_base + neg_budget["GCN"])]
seed_base += neg_budget["GCN"]
# GraphSAGE negatives use a wider hidden size (32) than the GCN ones (16).
F_neg_all += [_train_negative(GraphSAGE, 32, s) for s in range(seed_base, seed_base + neg_budget["SAGE"])]

#===== Train/Test split =====
def split_pool(pool, n_train, n_test, seed=999):
    """Shuffle `pool` with a seeded permutation and return (first n_train items, next n_test items).

    Calls `set_seed(seed)`, so it also reseeds the global RNGs.
    """
    set_seed(seed)
    order = torch.randperm(len(pool)).tolist()
    shuffled = [pool[i] for i in order]
    return shuffled[:n_train], shuffled[n_train:n_train + n_test]

F_pos_tr, F_pos_te = split_pool(F_pos_all, CFG["POS_TRAIN"], CFG["POS_TEST"])
F_neg_tr, F_neg_te = split_pool(F_neg_all, CFG["NEG_TRAIN"], CFG["NEG_TEST"])
print(f"F+ train/test: {len(F_pos_tr)}/{len(F_pos_te)} | F- train/test: {len(F_neg_tr)}/{len(F_neg_te)} (totals match budgets)")

Epoch 000 | loss 0.6888 | val 0.499
Epoch 020 | loss 0.4150 | val 0.727
Epoch 040 | loss 0.3185 | val 0.711
Epoch 060 | loss 0.2563 | val 0.692
Epoch 080 | loss 0.2080 | val 0.685
Epoch 100 | loss 0.1894 | val 0.679
✅ Final (best-val) | val 0.735 | test 0.723
Epoch 000 | loss 0.6907 | val 0.500
Epoch 020 | loss 0.4350 | val 0.713
Epoch 040 | loss 0.3433 | val 0.701
Epoch 060 | loss 0.2987 | val 0.681
Epoch 080 | loss 0.2513 | val 0.685
Epoch 100 | loss 0.2204 | val 0.696
✅ Final (best-val) | val 0.744 | test 0.716
Epoch 000 | loss 0.6859 | val 0.499
Epoch 020 | loss 0.4111 | val 0.699
Epoch 040 | loss 0.3244 | val 0.685
Epoch 060 | loss 0.2737 | val 0.682
Epoch 080 | loss 0.2352 | val 0.671
Epoch 100 | loss 0.2229 | val 0.669
✅ Final (best-val) | val 0.736 | test 0.702
Epoch 000 | loss 0.6893 | val 0.512
Epoch 020 | loss 0.4129 | val 0.725
Epoch 040 | loss 0.3199 | val 0.720
Epoch 060 | loss 0.2565 | val 0.682
Epoch 080 | loss 0.2259 | val 0.707
Epoch 100 | loss 0.1866 | val 0.708
✅ Fi

In [None]:
#=======================================================
# Fingerprint set for edge-level (P fingerprints, sample m pairs per graph)
#=======================================================
class FingerprintGraph(nn.Module):
    """One learnable fingerprint graph.

    It consists of node features `X`, a symmetric adjacency parameterised by logits
    `A_logits`, and a fixed set of sampled node pairs whose link scores form the graph's
    output.

    Args:
        n_nodes: number of nodes.
        feat_dim: node feature dimension (must match the suspect models' input size).
        sample_m: requested number of node pairs to score; capped at n_nodes*(n_nodes-1)//2.
        edge_init_p: per-entry probability used to draw the initial adjacency before it
            is symmetrised.
    """

    def __init__(self, n_nodes, feat_dim, sample_m, edge_init_p=0.05):
        super().__init__()
        self.n = n_nodes
        self.d = feat_dim
        self.m = min(sample_m, n_nodes*(n_nodes-1)//2)  # Max possible pairs
        X = torch.empty(self.n, self.d).uniform_(-0.5, 0.5)
        self.X = nn.Parameter(X.to(device))
        # Random symmetric 0/1 adjacency with an empty diagonal, stored as clamped logits (about ±9.21).
        A0 = (torch.rand(self.n, self.n, device=device) < edge_init_p).float()
        A0.fill_diagonal_(0.0)
        A0 = torch.maximum(A0, A0.T)
        self.A_logits = nn.Parameter(torch.logit(torch.clamp(A0, 1e-4, 1-1e-4)))
        # Sample m distinct unordered node pairs (u < v) to be scored by suspect models.
        all_pairs = torch.combinations(torch.arange(self.n, device=device), r=2)
        perm = torch.randperm(len(all_pairs), device=device)[:self.m]
        self.sample_pairs = all_pairs[perm].t()  # 2 x m

    @torch.no_grad()
    def edge_index(self):
        """Return the current graph as a 2 x E edge_index (both directions), thresholding sigmoid(A_logits) at 0.5."""
        A_prob = torch.sigmoid(self.A_logits)
        A_bin = (A_prob > 0.5).float()
        A_bin.fill_diagonal_(0.0)
        A_bin = torch.maximum(A_bin, A_bin.T)
        idx = A_bin.nonzero(as_tuple=False)
        if idx.numel() == 0:
            return torch.empty(2, 0, dtype=torch.long, device=device)
        return idx.t().contiguous()

    @torch.no_grad()
    def flip_topk_by_grad(self, gradA, topk=64, step=2.5):
        """Nudge the logits of the `topk` upper-triangular pairs with the largest |gradA|.

        Each selected pair (u, v) is judged against the adjacency as it was before this
        call:
        - an existing edge (prob > 0.5) with gradA[u, v] <= 0 has A_logits[u, v] and
          A_logits[v, u] lowered by `step`;
        - a missing edge with gradA[u, v] >= 0 has both raised by `step`;
        - any other selected pair is left unchanged.
        Logits start near ±9.21, so with step=2.5 a pair needs 4 consistent updates to
        cross 0 and actually flip. The diagonal is then reset to -10 (no self-loops).
        """
        g = gradA.abs()
        triu = torch.triu(torch.ones_like(g), diagonal=1)
        scores = (g * triu).flatten()
        k = min(topk, scores.numel())
        if k == 0: return
        _, idxs = torch.topk(scores, k=k)
        r = self.n
        logits = self.A_logits.data
        idxs = idxs.to(logits.device)
        u, v = idxs // r, idxs % r
        # Vectorised over all selected pairs: a per-pair Python loop with `.item()` forced
        # one host/device sync per pair.
        guv = gradA.to(logits.device)[u, v]
        exist = torch.sigmoid(logits)[u, v] > 0.5
        delta = torch.zeros_like(guv, dtype=logits.dtype)
        delta[exist & (guv <= 0)] = -step
        delta[~exist & (guv >= 0)] = step
        logits[u, v] += delta
        logits[v, u] += delta
        logits.fill_diagonal_(-10.0)

class FingerprintSet(nn.Module):
    """A collection of `P` FingerprintGraphs that are queried together.

    `concat_outputs(model)` runs `model` on every fingerprint graph and concatenates the
    predicted link probabilities of each graph's sampled pairs into one flat vector.
    """

    def __init__(self, P, n_nodes, feat_dim, sample_m, edge_init_p=0.05, topk_edges=64, edge_step=2.5):
        super().__init__()
        self.P = P
        graphs = [FingerprintGraph(n_nodes, feat_dim, sample_m, edge_init_p) for _ in range(P)]
        self.fps = nn.ModuleList(graphs).to(device)
        self.topk_edges = topk_edges
        self.edge_step = edge_step

    def concat_outputs(self, model, *, require_grad: bool = False):
        """Return sigmoid(<h_u, h_v>) for every sampled pair of every graph, concatenated.

        Switches `model` to eval mode. Autograd is enabled only when `require_grad` is True,
        which is needed to optimise the fingerprint features through these scores.
        """
        model.eval()
        grad_mode = torch.enable_grad() if require_grad else torch.no_grad()
        per_graph = []
        with grad_mode:
            for fp in self.fps:
                emb = model(fp.X, fp.edge_index())
                src, dst = fp.sample_pairs
                per_graph.append(torch.sigmoid((emb[src] * emb[dst]).sum(dim=-1)))
        return torch.cat(per_graph, dim=0)

    def flip_adj_by_grad(self, surrogate_grad_list):
        """Call `flip_topk_by_grad` on each graph with the matching entry of `surrogate_grad_list`."""
        for fp, grad in zip(self.fps, surrogate_grad_list):
            fp.flip_topk_by_grad(grad, topk=self.topk_edges, step=self.edge_step)

fp_set = FingerprintSet(
    P=CFG["FP_P"],
    n_nodes=CFG["FP_NODES"],
    feat_dim=num_feats,
    sample_m=CFG["FP_SAMPLE_M"],
    edge_init_p=CFG["FP_EDGE_INIT_P"],
    topk_edges=CFG["FP_EDGE_TOPK"],
    edge_step=CFG["EDGE_LOGIT_STEP"],
)
# Size the verifier input from the graphs actually built. Each graph contributes fp.m
# scores, and fp.m is capped at n_nodes*(n_nodes-1)//2, so FP_P * FP_SAMPLE_M would
# over-count when FP_NODES is small. With the current config the two are equal (2048).
INPUT_DIM = sum(fp.m for fp in fp_set.fps)  # P graphs * m pairs * 1 (prob)
print("Univerifier input dim =", INPUT_DIM)

Univerifier input dim = 2048


In [None]:
#========================================
# Univerifier (binary classifier)
#========================================
class Univerifier(nn.Module):
    """MLP that classifies a suspect model from its concatenated fingerprint responses.

    Architecture: input_dim -> 128 -> 64 -> 32 -> 2, with LeakyReLU(0.01) between the
    linear layers. Returns raw logits: class 1 = positive (derived from the target),
    class 0 = negative.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        widths = (input_dim, 128, 64, 32)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.01)])
        layers.append(nn.Linear(widths[-1], 2))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Global verifier and its optimizer; update_verifier() steps opt_V.
V = Univerifier(input_dim=INPUT_DIM).to(device)
opt_V = Adam(params=V.parameters(), lr=CFG["LR_V"])

In [None]:
#=====================================================
# Joint learning (Algorithm-1 for edge-level)
#=====================================================
# The target model itself is included among the training positives.
models_pos_tr = [model_f.to(device), *(pos.to(device) for pos in F_pos_tr)]
models_neg_tr = [neg.to(device) for neg in F_neg_tr]
print(f"Train pools -> Pos: {len(models_pos_tr)} | Neg: {len(models_neg_tr)}")
def batch_from_pool(fp_set, pos_models, neg_models, *, require_grad: bool):
    """Stack every model's fingerprint responses into one batch for the verifier.

    Positive models come first (label 1), followed by negative models (label 0).
    Returns (X, y): X has one row per model; y is an integer label tensor on `device`.
    """
    pool = [(m, 1) for m in pos_models] + [(m, 0) for m in neg_models]
    feats = [fp_set.concat_outputs(m, require_grad=require_grad) for m, _ in pool]
    labels = [label for _, label in pool]
    return torch.stack(feats, dim=0), torch.tensor(labels, device=device)
def surrogate_grad_A_for_fp(fp, model):
    """Heuristic stand-in for an adjacency gradient of one fingerprint graph (no autograd).

    Embeds the fingerprint's nodes with `model.conv1` followed by ReLU. Returns an (n, n)
    CPU tensor holding cosine_similarity(h_u, h_v) - 0.5 for every node pair: positive
    entries mark similar node pairs, negative entries dissimilar ones.
    """
    with torch.no_grad():
        hidden = F.relu(model.conv1(fp.X, fp.edge_index()))
        unit = F.normalize(hidden, dim=-1)
        gradA = unit @ unit.t() - 0.5
    return gradA.detach().cpu()
def update_features(fp_set, V, pos_models, neg_models, steps, lr_x):
    """Optimise the fingerprint node features, then the adjacencies, against a frozen verifier.

    Feature step (repeated `steps` times): compute V's cross-entropy on the fingerprint
    responses of all suspect models, then take a gradient-descent step on every
    fingerprint's X. The fingerprints and the verifier are optimised toward the same
    goal (V classifying suspects correctly), so X must move to *lower* this loss.

    Edge step: nudge each graph's adjacency logits using the cosine-similarity surrogate
    computed with pos_models[0].

    Suspect-model and verifier parameters are frozen only for the duration of the call.
    Their original `requires_grad` flags are restored afterwards, even on error.
    """
    frozen = [p for m in pos_models + neg_models for p in m.parameters()] + list(V.parameters())
    saved_flags = [p.requires_grad for p in frozen]
    for p in frozen:
        p.requires_grad_(False)
    for fp in fp_set.fps:
        fp.X.requires_grad_(True)
    try:
        V.eval()
        for _ in range(steps):
            for fp in fp_set.fps:
                fp.X.grad = None
            Xb, yb = batch_from_pool(fp_set, pos_models, neg_models, require_grad=True)
            loss = F.cross_entropy(V(Xb.to(device)), yb)
            loss.backward()
            with torch.no_grad():
                for fp in fp_set.fps:
                    if fp.X.grad is None:
                        continue
                    # Gradient descent on the verifier's CE: X <- X - lr_x * dCE/dX.
                    fp.X.sub_(lr_x * fp.X.grad)
                    fp.X.grad = None
    finally:
        for p, flag in zip(frozen, saved_flags):
            p.requires_grad_(flag)
    grads = [surrogate_grad_A_for_fp(fp, pos_models[0]) for fp in fp_set.fps]
    fp_set.flip_adj_by_grad(grads)
def update_verifier(fp_set, V, pos_models, neg_models, steps, opt=None):
    """Train verifier `V` for `steps` optimizer steps on the current fingerprints.

    Fingerprint responses are recomputed without autograd at each step; only V's
    parameters receive gradients.

    Args:
        opt: optimizer over V's parameters. Defaults to the notebook-level `opt_V`.
            Pass one explicitly whenever `V` is not the global verifier; otherwise the
            default optimizer would step a different model's parameters and `V` would
            never change.
    """
    if opt is None:
        opt = opt_V
    V.train()
    for _ in range(steps):
        Xb, yb = batch_from_pool(fp_set, pos_models, neg_models, require_grad=False)
        loss = F.cross_entropy(V(Xb.to(device)), yb)
        opt.zero_grad(); loss.backward(); opt.step()
n_pos_tr = len(models_pos_tr)
for it in range(1, CFG["OUTER_ITERS"] + 1):
    # Alternate: refine the fingerprints against the frozen verifier, then retrain the verifier.
    update_features(fp_set, V, models_pos_tr, models_neg_tr, steps=CFG["FP_STEPS"], lr_x=CFG["LR_X"])
    update_verifier(fp_set, V, models_pos_tr, models_neg_tr, steps=CFG["V_STEPS"])
    V.eval()
    Xb, yb = batch_from_pool(fp_set, models_pos_tr, models_neg_tr, require_grad=False)
    with torch.no_grad():
        pred = V(Xb).argmax(dim=1)
        hits = pred.cpu() == yb.cpu()
        acc = hits.float().mean().item()
        # Rows are ordered positives first, then negatives (see batch_from_pool).
        pos_acc = (pred[:n_pos_tr].cpu() == 1).float().mean().item()
        neg_acc = (pred[n_pos_tr:].cpu() == 0).float().mean().item()
    print(f"Iter {it:02d}/{CFG['OUTER_ITERS']} | train all {acc:.3f} | pos {pos_acc:.3f} | neg {neg_acc:.3f}")

Train pools -> Pos: 51 | Neg: 50
Iter 01/20 | train all 0.495 | pos 0.000 | neg 1.000
Iter 02/20 | train all 0.782 | pos 1.000 | neg 0.560
Iter 03/20 | train all 0.822 | pos 1.000 | neg 0.640
Iter 04/20 | train all 0.960 | pos 0.941 | neg 0.980
Iter 05/20 | train all 0.891 | pos 1.000 | neg 0.780
Iter 06/20 | train all 0.921 | pos 1.000 | neg 0.840
Iter 07/20 | train all 0.970 | pos 0.961 | neg 0.980
Iter 08/20 | train all 0.980 | pos 1.000 | neg 0.960
Iter 09/20 | train all 0.980 | pos 1.000 | neg 0.960
Iter 10/20 | train all 0.941 | pos 0.882 | neg 1.000
Iter 11/20 | train all 0.941 | pos 1.000 | neg 0.880
Iter 12/20 | train all 0.990 | pos 1.000 | neg 0.980
Iter 13/20 | train all 1.000 | pos 1.000 | neg 1.000
Iter 14/20 | train all 1.000 | pos 1.000 | neg 1.000
Iter 15/20 | train all 1.000 | pos 1.000 | neg 1.000
Iter 16/20 | train all 1.000 | pos 1.000 | neg 1.000
Iter 17/20 | train all 1.000 | pos 1.000 | neg 1.000
Iter 18/20 | train all 0.812 | pos 1.000 | neg 0.620
Iter 19/20 | 

In [None]:
# Held-out verification (Robustness/Uniqueness/ARUC)
#==========================================================
# Only suspects the verifier never trained on. model_f is deliberately NOT included: it
# is part of models_pos_tr, so scoring it here would leak a training example into the
# held-out robustness estimate.
models_pos_te = [m.to(device) for m in F_pos_te]
models_neg_te = [m.to(device) for m in F_neg_te]
@torch.no_grad()
def verify_scores(V, fp_set, models):
    """Return V's softmax probability of class 1 (positive) for each model, as a NumPy array."""
    batch = torch.stack([fp_set.concat_outputs(m, require_grad=False) for m in models], dim=0)
    probs = F.softmax(V(batch.to(device)), dim=-1)
    return probs[:, 1].detach().cpu().numpy()
def sweep_threshold(p_pos, p_neg, num=301):
    """Sweep `num` evenly spaced decision thresholds λ over [0, 1].

    Returns (ths, R, U, A), where for each λ in `ths`:
        R = fraction of positive-model scores >= λ (robustness),
        U = fraction of negative-model scores < λ  (uniqueness),
        A = (R + U) / 2, the balanced accuracy.
    """
    ths = np.linspace(0.0, 1.0, num=num)
    robustness = np.array([(p_pos >= t).mean() for t in ths])
    uniqueness = np.array([(p_neg < t).mean() for t in ths])
    balanced = (robustness + uniqueness) / 2.0
    return ths, robustness, uniqueness, balanced
p_pos = verify_scores(V, fp_set, models_pos_te)
p_neg = verify_scores(V, fp_set, models_neg_te)
ths, R, U, A = sweep_threshold(p_pos, p_neg, num=301)
best_idx = A.argmax()
mean_acc = A.mean()
# ARUC: area under min(Robustness, Uniqueness) over the threshold sweep.
# numpy>=2.0 has trapezoid; fallback to trapz if unavailable
_trapezoid = getattr(np, "trapezoid", None) or np.trapz
ARUC = _trapezoid(np.minimum(R, U), ths)
print(f"Best @ λ={ths[best_idx]:.3f} | Robustness={R[best_idx]:.3f} | Uniqueness={U[best_idx]:.3f} | MeanAcc*={A[best_idx]:.3f}")
print(f"Mean Test Accuracy (avg over λ): {mean_acc:.3f}")
print(f"ARUC (approx): {ARUC:.3f}")

Best @ λ=0.387 | Robustness=0.922 | Uniqueness=1.000 | MeanAcc*=0.961
Mean Test Accuracy (avg over λ): 0.912
ARUC (approx): 0.835
