<a href="https://colab.research.google.com/github/SrivardhanS/Speech_GNN_FYP/blob/main/fyp_antispoofing_train_eval_error.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Install torch-scatter, torch-sparse, torch-cluster, torch-spline-conv
#    from the official PyG wheel index so that prebuilt binaries are used.
#    The index URL below targets torch 2.4.0 (CPU build); change it to match
#    the torch + CUDA version installed in your runtime.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html

# 2. Finally install torch-geometric
!pip install torch-geometric


Looking in links: https://data.pyg.org/whl/torch-2.4.0+cpu.html
Looking in links: https://data.pyg.org/whl/torch-2.4.0+cpu.html
Looking in links: https://data.pyg.org/whl/torch-2.4.0+cpu.html
Looking in links: https://data.pyg.org/whl/torch-2.4.0+cpu.html


In [None]:
# Checkpoint location and compute device shared by the cells below.
# `torch` is imported here so that this cell also runs on a fresh kernel,
# before the main setup cell (which holds the other imports) has executed.
import torch

save_path = "/content/gcn_compressed_asvspoof.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import kagglehub
import os

# Download dataset
# kagglehub returns the local directory holding the dataset files; the same
# call is repeated in the main setup cell, which is the one the pipeline uses.
path = kagglehub.dataset_download("awsaf49/asvpoof-2019-dataset")
print("Dataset path:", path)

Using Colab cache for faster access to the 'asvpoof-2019-dataset' dataset.
Dataset path: /kaggle/input/asvpoof-2019-dataset


In [None]:
# ======================
# Full: Compression + GCN training for ASVspoof2019 LA
# ======================
!pip install -q kaggle torch_geometric librosa

import os, sys, logging, random
import pandas as pd
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Logging setup (Colab-friendly): records at INFO and above are written to
# stdout so they appear in the cell output. force=True makes re-running the
# cell reconfigure the root logger instead of keeping earlier handlers.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger("asvspoof-compress-train")


# ----------------------
# Dataset / Protocols
# ----------------------
import kagglehub
# Resolve the dataset location (download or cache hit) and derive every
# directory / protocol path used later from that single root.
path = kagglehub.dataset_download("awsaf49/asvpoof-2019-dataset")
logger.info(f"Dataset path: {path}")

dataset_path = os.path.join(path, "LA", "LA")
proto_dir = os.path.join(dataset_path, "ASVspoof2019_LA_cm_protocols")
train_proto = os.path.join(proto_dir, "ASVspoof2019.LA.cm.train.trn.txt")

# One "flac" directory per split (train / dev / eval).
train_audio_dir, dev_audio_dir, eval_audio_dir = (
    os.path.join(dataset_path, split_dir, "flac")
    for split_dir in (
        "ASVspoof2019_LA_train",
        "ASVspoof2019_LA_dev",
        "ASVspoof2019_LA_eval",
    )
)

# The protocol file is space separated with no header row.
# NOTE(review): with these labels the 'speaker_id' column holds the audio file
# stem (e.g. LA_T_1138215) and 'utt_id' holds values like LA_0079; the labels
# look swapped, but build_graph_list relies on them as they are.
protocol_columns = ["utt_id", "speaker_id", "system_id", "attack_id", "label"]
protocol_df = pd.read_csv(train_proto, sep=" ", header=None)
protocol_df.columns = protocol_columns
logger.info(f"Protocol sample:\n{protocol_df.head()}")


# ----------------------
# audio -> temporal graph
# ----------------------
def audio_to_graph(file_path, sr=16000, n_mfcc=13, pool_size=4):
    """Turn one audio file into a chain graph of pooled MFCC frames.

    Args:
        file_path: Path of the audio file, loaded through ``librosa.load``.
        sr: Sample rate passed to librosa for loading and MFCC extraction.
        n_mfcc: Number of MFCC coefficients, i.e. the node feature size.
        pool_size: Number of consecutive MFCC frames averaged into one node.

    Returns:
        A ``Data`` object whose ``x`` is ``[nodes, n_mfcc]`` and whose
        ``edge_index`` links every node ``i`` to node ``i + 1``.
    """
    waveform, _ = librosa.load(file_path, sr=sr)
    mfcc = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=n_mfcc)

    # Number of complete pooling windows; very short clips are wrap-padded
    # up to two windows so that the graph has at least two nodes.
    n_windows = mfcc.shape[1] // pool_size
    if n_windows < 2:
        missing_frames = pool_size * 2 - mfcc.shape[1]
        mfcc = np.pad(mfcc, ((0, 0), (0, missing_frames)), mode='wrap')
        n_windows = mfcc.shape[1] // pool_size

    # Average each window of `pool_size` frames into a single feature vector.
    window_means = []
    for w in range(n_windows):
        start = w * pool_size
        window_means.append(np.mean(mfcc[:, start:start + pool_size], axis=1))
    pooled = np.stack(window_means)

    x = torch.tensor(pooled, dtype=torch.float)        # [nodes, feat]
    chain = [[src, src + 1] for src in range(x.size(0) - 1)]
    edge_index = torch.tensor(chain, dtype=torch.long).T
    return Data(x=x, edge_index=edge_index)


# ----------------------
# DifferentiableGraphCompressor (exactly as you provided)
# ----------------------
class DifferentiableGraphCompressor(nn.Module):
    """Soft (training) and hard (inference) compressor for temporal graphs.

    Node pairs ``(t, t + k)`` inside a sliding window are scored with a
    learnable mix of node-feature cosine similarity and neighbourhood cosine
    similarity. During training the scores become soft merge probabilities
    that blend node features; at inference nodes whose scores clear
    quantile-based thresholds are removed.

    Args:
        feature_dim: Size of each node feature vector.
        tau_T: Temperature dividing the sigmoid gate arguments.
        lambda_id: Weight of the identity-preservation loss in ``forward``.
        lambda_comp: Weight of the compression loss in ``forward``.
    """

    def __init__(self, feature_dim, tau_T=1.0, lambda_id=1.0, lambda_comp=0.1):
        super().__init__()
        # Unconstrained parameter; sigmoid(a_raw) is the mixing weight alpha.
        self.a_raw = nn.Parameter(torch.tensor(0.0))
        self.tau_T = tau_T
        self.lambda_id = lambda_id
        self.lambda_comp = lambda_comp
        # Maps a mean-pooled feature vector to a 32-d embedding that is
        # compared before/after compression in speaker_identity_loss.
        self.speaker_embedding = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )

    def get_alpha(self):
        """Return the mixing weight alpha = sigmoid(a_raw), in (0, 1)."""
        return torch.sigmoid(self.a_raw)

    def node_similarity(self, x_u, x_v):
        """Cosine similarity of two feature vectors, returned with shape [1]."""
        return F.cosine_similarity(x_u.unsqueeze(0), x_v.unsqueeze(0), dim=1)

    def neighborhood_similarity(self, x, edge_index, u, v):
        """Cosine similarity of the mean features of the out-neighbours of u and v.

        Neighbours of a node are the targets of edges whose source is that
        node. Returns a scalar 0.0 when either node has no such neighbour.
        """
        def get_neighbors(node):
            mask = (edge_index[0] == node)
            if mask.any():
                return edge_index[1][mask]
            else:
                return torch.tensor([], dtype=torch.long, device=x.device)

        neighbors_u = get_neighbors(u)
        neighbors_v = get_neighbors(v)
        if neighbors_u.numel() == 0 or neighbors_v.numel() == 0:
            return torch.tensor(0.0, device=x.device)
        m_u = x[neighbors_u].mean(dim=0)
        m_v = x[neighbors_v].mean(dim=0)
        return F.cosine_similarity(m_u.unsqueeze(0), m_v.unsqueeze(0), dim=1)

    def combined_similarity(self, x, edge_index, u, v):
        """alpha * node similarity + (1 - alpha) * neighbourhood similarity."""
        alpha = self.get_alpha()
        sim_x = self.node_similarity(x[u], x[v])
        sim_m = self.neighborhood_similarity(x, edge_index, u, v)
        return alpha * sim_x + (1 - alpha) * sim_m

    def compute_adaptive_thresholds(self, similarities):
        """Return the 0.75 and 0.60 quantiles of ``similarities``."""
        tau_1_bar = torch.quantile(similarities, 0.75)
        tau_2_bar = torch.quantile(similarities, 0.60)
        return tau_1_bar, tau_2_bar

    @staticmethod
    def _candidate_pairs(num_nodes, window_size):
        """Yield (t, k) with 1 <= t, 1 <= k <= window_size and t + k + 1 < num_nodes."""
        for t in range(1, num_nodes - 1):
            for k in range(1, min(window_size + 1, num_nodes - t)):
                if t + k >= num_nodes - 1:
                    continue
                yield t, k

    def _window_similarities(self, x, edge_index, window_size):
        """Compute (s1, s2, s3) once for every candidate pair, keyed by (t, k).

        s1 compares t with t + k, s2 the preceding nodes and s3 the following
        nodes. _candidate_pairs guarantees that every index is in range, so no
        exception handling is needed around the lookups.
        """
        sims = {}
        for t, k in self._candidate_pairs(x.size(0), window_size):
            sims[(t, k)] = (
                self.combined_similarity(x, edge_index, t, t + k),
                self.combined_similarity(x, edge_index, t - 1, t + k - 1),
                self.combined_similarity(x, edge_index, t + 1, t + k + 1),
            )
        return sims

    def compute_merge_probabilities(self, x, edge_index, window_size=10):
        """Soft merge probability per node, kept inside the autograd graph.

        For each node t the probability is the maximum over k of the product
        of three sigmoid gates (s1 against the 0.75 quantile, s2 and s3
        against the 0.60 quantile). Nodes without a candidate pair keep 0.

        Returns:
            Tensor of shape [num_nodes] on ``x.device``.
        """
        num_nodes = x.size(0)
        merge_probs = torch.zeros(num_nodes, device=x.device)

        sims = self._window_similarities(x, edge_index, window_size)
        if not sims:
            return merge_probs

        similarities_tensor = torch.stack(
            [s for triple in sims.values() for s in triple]
        ).view(-1)
        tau_1_bar, tau_2_bar = self.compute_adaptive_thresholds(similarities_tensor)

        # Keep the running maximum as a tensor: converting it with .item()
        # would cut the gradient path to a_raw and make the compression loss
        # a constant.
        best = {}
        for (t, _k), (s1, s2, s3) in sims.items():
            gate1 = torch.sigmoid((s1 - tau_1_bar) / self.tau_T)
            gate2 = torch.sigmoid((s2 - tau_2_bar) / self.tau_T)
            gate3 = torch.sigmoid((s3 - tau_2_bar) / self.tau_T)
            prob = (gate1 * gate2 * gate3).reshape(())
            best[t] = prob if t not in best else torch.maximum(best[t], prob)

        for t, prob in best.items():
            merge_probs[t] = prob

        return merge_probs

    def differentiable_compression(self, x, edge_index):
        """Blend each interior node towards its most similar later node.

        Returns:
            (x_compressed, merge_probs); ``x_compressed`` has the shape of ``x``.
        """
        merge_probs = self.compute_merge_probabilities(x, edge_index)
        x_compressed = x.clone()
        num_nodes = x.size(0)
        for t in range(1, num_nodes - 1):
            if merge_probs[t] > 0:
                # Picking the partner is an argmax, which carries no gradient,
                # so it is evaluated without building an autograd graph.
                best_k = 1
                best_sim = -1
                with torch.no_grad():
                    # k < num_nodes - t keeps t + k a valid node index.
                    for k in range(1, min(10, num_nodes - t)):
                        sim = self.combined_similarity(x, edge_index, t, t + k).item()
                        if sim > best_sim:
                            best_sim = sim
                            best_k = k
                p = merge_probs[t]
                interpolated = (x[t] + x[t + best_k]) / 2
                x_compressed[t] = (1 - p) * x[t] + p * interpolated
        return x_compressed, merge_probs

    def speaker_identity_loss(self, x_original, x_compressed):
        """1 - cosine similarity between embeddings of the two mean-pooled inputs."""
        g_original = self.speaker_embedding(x_original.mean(dim=0))
        g_compressed = self.speaker_embedding(x_compressed.mean(dim=0))
        cos_sim = F.cosine_similarity(g_original.unsqueeze(0), g_compressed.unsqueeze(0))
        return 1 - cos_sim

    def compression_loss(self, merge_probs):
        """Mean of (1 - merge probability); lower means more merging."""
        return (1 - merge_probs).mean()

    def forward(self, x, edge_index):
        """Run soft compression and return features, probabilities and losses.

        Returns:
            Dict with keys 'compressed_features', 'merge_probs', 'loss'
            (lambda_id * loss_id + lambda_comp * loss_comp), 'loss_id',
            'loss_comp' and 'alpha'.
        """
        x_compressed, merge_probs = self.differentiable_compression(x, edge_index)
        L_id = self.speaker_identity_loss(x, x_compressed)
        L_comp = self.compression_loss(merge_probs)
        total_loss = self.lambda_id * L_id + self.lambda_comp * L_comp
        return {
            'compressed_features': x_compressed,
            'merge_probs': merge_probs,
            'loss': total_loss,
            'loss_id': L_id,
            'loss_comp': L_comp,
            'alpha': self.get_alpha()
        }

    # Hard compression inference helper for single graph
    def hard_compression_inference(self, x, edge_index, window_size=10):
        """Remove nodes whose window similarities clear the quantile thresholds.

        Puts the module in eval mode and runs without gradients. A node t is
        removed for the first k where none of the six nodes involved has
        already been removed and s1 >= tau_1, s2 >= tau_2, s3 >= tau_2.
        Edges touching a removed node are dropped; the remaining neighbours
        are not reconnected.

        Returns:
            (compressed_data, to_remove, alpha): a ``Data`` object with the
            kept nodes and re-indexed edges (on ``x.device``), the set of
            removed node indices, and alpha as a Python float.
        """
        self.eval()
        with torch.no_grad():
            alpha = self.get_alpha().item()
            num_nodes = x.size(0)
            to_remove = set()

            sims = self._window_similarities(x, edge_index, window_size)
            if sims:
                similarities_tensor = torch.stack(
                    [s for triple in sims.values() for s in triple]
                ).view(-1)
                tau_1_hat, tau_2_hat = self.compute_adaptive_thresholds(similarities_tensor)
                # sims is ordered by t, then k, so this is a greedy scan.
                for (t, k), (s1, s2, s3) in sims.items():
                    if t in to_remove:
                        continue  # t was already removed at a smaller k
                    involved = (t + k, t - 1, t + k - 1, t + 1, t + k + 1)
                    if any(n in to_remove for n in involved):
                        continue
                    if s1 >= tau_1_hat and s2 >= tau_2_hat and s3 >= tau_2_hat:
                        to_remove.add(t)

            remaining_indices = [i for i in range(num_nodes) if i not in to_remove]
            x_new = x[remaining_indices]
            index_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(remaining_indices)}
            new_edges = [
                [index_mapping[src], index_mapping[dst]]
                for src, dst in edge_index.t().tolist()
                if src not in to_remove and dst not in to_remove
            ]
            # Build the edge tensor on the same device as the node features.
            if new_edges:
                edge_index_new = torch.tensor(new_edges, dtype=torch.long, device=x.device).T
            else:
                edge_index_new = torch.empty((2, 0), dtype=torch.long, device=x.device)
            compressed_data = Data(x=x_new, edge_index=edge_index_new)
            return compressed_data, to_remove, alpha


# ----------------------
# GCN classifier
# ----------------------
class GCNClassifier(nn.Module):
    """Graph-level binary classifier: two GCN layers, mean pooling, linear head.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of both GCN layers.
    """

    def __init__(self, in_channels, hidden_channels=32):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, 1)

    def forward(self, data):
        """Return one sigmoid probability per graph in ``data`` as a 1-D tensor."""
        hidden = self.conv1(data.x, data.edge_index)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden, data.edge_index)
        # Average node embeddings per graph using the batch assignment vector.
        graph_embedding = global_mean_pool(hidden, data.batch)
        logits = self.fc(graph_embedding)
        return torch.sigmoid(logits).view(-1)


# ----------------------
# Build lists of Data (graphs)
# ----------------------
def build_graph_list(df, audio_dir, limit=None, n_mfcc=13, pool_size=4):
    """Convert protocol rows into labelled graphs.

    Args:
        df: Protocol table with at least 'speaker_id' and 'label' columns.
            The 'speaker_id' value is used as the audio file stem.
        audio_dir: Directory containing ``<stem>.flac`` files.
        limit: If given, only the first ``limit`` rows are processed.
        n_mfcc: Forwarded to ``audio_to_graph``.
        pool_size: Forwarded to ``audio_to_graph``.

    Returns:
        List of graphs with ``y`` set to 1.0 for 'bonafide' rows and 0.0
        otherwise. Rows whose audio file is missing are logged and skipped.
    """
    rows = df.reset_index(drop=True)
    if limit is not None:
        rows = rows.iloc[:limit]

    graphs = []
    for record in rows.to_dict("records"):
        file_stem = record['speaker_id']
        file_path = os.path.join(audio_dir, f"{file_stem}.flac")
        if not os.path.isfile(file_path):
            logger.warning(f"Missing file {file_path}, skipping")
            continue
        graph = audio_to_graph(file_path, n_mfcc=n_mfcc, pool_size=pool_size)
        is_bonafide = record['label'] == 'bonafide'
        graph.y = torch.tensor(1 if is_bonafide else 0, dtype=torch.float)
        graphs.append(graph)
    return graphs



# ----------------------
# Training + evaluation helpers
# ----------------------
def train_epoch(compressor, classifier, loader, optimizer, criterion, device, epoch,
                lambda_comp_loss=None):
    """Train the compressor and the classifier jointly for one epoch.

    Each batch is treated as a single graph (the loader must use
    batch_size=1): the compressor produces softly compressed node features,
    the classifier predicts on them, and the two losses are combined as
    ``cls_loss + lambda_comp_loss * comp_loss``.

    Args:
        compressor: Module returning a dict with 'compressed_features' and 'loss'.
        classifier: Module mapping a graph to a probability tensor.
        loader: Iterable of single-graph batches.
        optimizer: Optimizer over both modules' parameters.
        criterion: Loss applied to (prediction, label).
        device: Device the batches are moved to.
        epoch: Epoch number, used only in log messages.
        lambda_comp_loss: Weight of the compressor loss. When None, a
            module-level ``lambda_comp_loss`` is used if defined, else 0.1.

    Returns:
        (mean classification loss, mean compressor loss) over the batches,
        or (nan, nan) when the loader yields no batch.
    """
    if lambda_comp_loss is None:
        # Backward compatibility: this weight used to be read from a global
        # that a later notebook cell defines.
        lambda_comp_loss = globals().get("lambda_comp_loss", 0.1)

    classifier.train()
    compressor.train()
    total_cls_loss = 0.0
    total_comp_loss = 0.0
    n = 0
    for batch_idx, batch in enumerate(loader):
        batch = batch.to(device)
        x = batch.x
        edge_index = batch.edge_index

        # Run compressor (differentiable)
        comp_res = compressor(x, edge_index)
        x_compressed = comp_res['compressed_features']
        comp_loss = comp_res['loss']

        # batch_size == 1 assumption: every node belongs to graph 0, so the
        # batch vector is all zeros.
        graph_ids = torch.zeros(x_compressed.size(0), dtype=torch.long, device=device)
        compressed_data = Data(x=x_compressed, edge_index=edge_index, y=batch.y, batch=graph_ids)

        pred = classifier(compressed_data)
        cls_loss = criterion(pred, batch.y)

        loss = cls_loss + lambda_comp_loss * comp_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_cls_loss += cls_loss.item()
        total_comp_loss += comp_loss.item()
        n += 1

        if (batch_idx + 1) % 20 == 0:
            logger.info(f"Epoch {epoch} | Batch {batch_idx+1}/{len(loader)} | cls_loss={cls_loss.item():.4f} comp_loss={comp_loss.item():.4f}")

    if n == 0:
        # Nothing was trained on; avoid dividing by zero.
        return float('nan'), float('nan')
    return total_cls_loss / n, total_comp_loss / n


def evaluate_with_hard_compression(compressor, classifier, loader, device):
    """Score the classifier on hard-compressed graphs.

    Every batch from ``loader`` is expected to hold a single graph. The graph
    is compressed with ``compressor.hard_compression_inference`` and then
    classified; predictions use a 0.5 probability threshold.

    Returns:
        Dict with 'acc', 'prec', 'rec', 'f1' and 'auc'. 'auc' is NaN when the
        labels contain a single class.
    """
    classifier.eval()
    compressor.eval()

    labels, decisions, scores = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            # Hard compression; the removed-node set and alpha are not needed here.
            graph, _removed, _alpha = compressor.hard_compression_inference(batch.x, batch.edge_index)

            # Attach the label and a single-graph batch vector before classifying.
            graph.y = batch.y.cpu()
            graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long)
            graph = graph.to(device)

            prob = classifier(graph)
            pred = (prob >= 0.5).long()

            labels.append(int(batch.y.cpu().item()))
            decisions.append(int(pred.cpu().item()))
            scores.append(float(prob.cpu().item()))

    metrics = {
        "acc": accuracy_score(labels, decisions),
        "prec": precision_score(labels, decisions, zero_division=0),
        "rec": recall_score(labels, decisions, zero_division=0),
        "f1": f1_score(labels, decisions, zero_division=0),
    }
    if len(set(labels)) > 1:
        metrics["auc"] = roc_auc_score(labels, scores)
    else:
        metrics["auc"] = float('nan')
    return metrics



# Destination file for the checkpoint written by the training cell.
save_path = "/content/gcn_compressed_asvspoof.pth"


Using Colab cache for faster access to the 'asvpoof-2019-dataset' dataset.
14:08:25 [INFO] Dataset path: /kaggle/input/asvpoof-2019-dataset
14:08:25 [INFO] Protocol sample:
    utt_id    speaker_id system_id attack_id     label
0  LA_0079  LA_T_1138215         -         -  bonafide
1  LA_0079  LA_T_1271820         -         -  bonafide
2  LA_0079  LA_T_1272637         -         -  bonafide
3  LA_0079  LA_T_1276960         -         -  bonafide
4  LA_0079  LA_T_1341447         -         -  bonafide


In [None]:
# ======================
# Run compression + GCN training
# ======================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Seed every RNG involved so that reruns give the same initial weights and
# the same shuffling order.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build dataset (small demo subset; set DEMO_LIMIT = None for full training).
# The first rows of the protocol are all 'bonafide', so taking the head of the
# table yields a single-class training set (loss collapses to 0, AUC is NaN).
# The demo subset therefore takes rows from both classes.
DEMO_LIMIT = 10
if DEMO_LIMIT is None:
    train_df = protocol_df
else:
    n_bonafide = max(1, DEMO_LIMIT // 2)
    is_bonafide = protocol_df["label"] == "bonafide"
    train_df = pd.concat([
        protocol_df[is_bonafide].head(n_bonafide),
        protocol_df[~is_bonafide].head(DEMO_LIMIT - n_bonafide),
    ])
train_graphs = build_graph_list(train_df, train_audio_dir)
if len(train_graphs) == 0:
    raise RuntimeError("No training graphs were built; check train_audio_dir and the protocol file")
train_loader = DataLoader(train_graphs, batch_size=1, shuffle=True)

# Initialize models; the feature size is read from the data instead of being
# hard-coded, so it follows the n_mfcc used by build_graph_list.
feature_dim = train_graphs[0].x.size(1)
compressor = DifferentiableGraphCompressor(feature_dim).to(device)
classifier = GCNClassifier(in_channels=feature_dim).to(device)

# Training setup
criterion = nn.BCELoss()
optimizer = optim.Adam(
    list(compressor.parameters()) + list(classifier.parameters()), lr=1e-3
)
lambda_comp_loss = 0.1  # weight of the compressor loss; read by train_epoch
epochs = 6  # increase later when GPU runtime allows

# Train
for epoch in range(1, epochs + 1):
    cls_loss, comp_loss = train_epoch(
        compressor, classifier, train_loader, optimizer, criterion, device, epoch
    )
    logger.info(f"Epoch {epoch}: cls_loss={cls_loss:.4f} | comp_loss={comp_loss:.4f}")

# Evaluate
# NOTE: this reuses the training graphs, so the numbers only show that the
# pipeline runs end to end; they are not a held-out estimate.
metrics = evaluate_with_hard_compression(compressor, classifier, train_loader, device)
logger.info(f"Evaluation metrics: {metrics}")

# Save model
torch.save({
    "compressor_state_dict": compressor.state_dict(),
    "classifier_state_dict": classifier.state_dict(),
}, save_path)
logger.info(f"✅ Model saved to {save_path}")


14:09:20 [INFO] Using device: cpu
14:09:24 [INFO] Epoch 1: cls_loss=0.0039 | comp_loss=0.0890
14:09:28 [INFO] Epoch 2: cls_loss=0.0000 | comp_loss=0.0890
14:09:32 [INFO] Epoch 3: cls_loss=0.0000 | comp_loss=0.0890
14:09:36 [INFO] Epoch 4: cls_loss=0.0000 | comp_loss=0.0890
14:09:39 [INFO] Epoch 5: cls_loss=0.0000 | comp_loss=0.0890
14:09:42 [INFO] Epoch 6: cls_loss=0.0000 | comp_loss=0.0890
14:09:44 [INFO] Evaluation metrics: {'acc': 1.0, 'prec': 1.0, 'rec': 1.0, 'f1': 1.0, 'auc': nan}
14:09:44 [INFO] ✅ Model saved to /content/gcn_compressed_asvspoof.pth


In [None]:
# Standalone evaluation script for ASVspoof2019 LA (hard-compression + GCN)
# Run in a fresh session. Make sure `save_path`, `proto_dir`, `eval_audio_dir` point to real files.

import os, sys, logging
import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report

# Logging: INFO and above, written to stdout so messages show in the cell output.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("eval-standalone")

# ----------------------
# Configs: update these if different in your environment
# ----------------------
# Path to the saved checkpoint.
save_path = "/content/gcn_compressed_asvspoof.pth"

# Dataset root. If the dataset was obtained with kagglehub.dataset_download,
# use the directory that call returned.
# dataset_base = "/kaggle/input/asvpoof-2019-dataset"   # change if your dataset is elsewhere
dataset_base = "/root/.cache/kagglehub/datasets/awsaf49/asvpoof-2019-dataset/versions/1"

# Same directory layout as in the training session.
dataset_path = os.path.join(dataset_base, "LA", "LA")
eval_audio_dir = os.path.join(dataset_path, "ASVspoof2019_LA_eval", "flac")
proto_dir = os.path.join(dataset_path, "ASVspoof2019_LA_cm_protocols")
test_proto = os.path.join(proto_dir, "ASVspoof2019.LA.cm.eval.trl.txt")

# ----------------------
# audio -> temporal graph helper (same logic as training)
# ----------------------
def audio_to_graph(file_path, sr=16000, n_mfcc=13, pool_size=4):
    """Turn one audio file into a chain graph of pooled MFCC frames.

    Args:
        file_path: Path of the audio file, loaded through ``librosa.load``.
        sr: Sample rate passed to librosa for loading and MFCC extraction.
        n_mfcc: Number of MFCC coefficients, i.e. the node feature size.
        pool_size: Number of consecutive MFCC frames averaged into one node.

    Returns:
        A ``Data`` object whose ``x`` is ``[nodes, n_mfcc]`` and whose
        ``edge_index`` links every node ``i`` to node ``i + 1``.
    """
    waveform, _ = librosa.load(file_path, sr=sr)
    mfcc = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=n_mfcc)

    # Number of complete pooling windows; very short clips are wrap-padded
    # up to two windows.
    n_windows = mfcc.shape[1] // pool_size
    if n_windows < 2:
        missing_frames = pool_size * 2 - mfcc.shape[1]
        mfcc = np.pad(mfcc, ((0, 0), (0, missing_frames)), mode='wrap')
        n_windows = mfcc.shape[1] // pool_size

    # Average each window of `pool_size` frames into a single feature vector.
    window_means = []
    for w in range(n_windows):
        start = w * pool_size
        window_means.append(np.mean(mfcc[:, start:start + pool_size], axis=1))
    pooled = np.stack(window_means)

    x = torch.tensor(pooled, dtype=torch.float)
    if x.size(0) < 2:
        # ensure at least 2 nodes (avoid zero-edge)
        x = torch.cat([x, x], dim=0)
    chain = [[src, src + 1] for src in range(x.size(0) - 1)]
    edge_index = torch.tensor(chain, dtype=torch.long).T
    return Data(x=x, edge_index=edge_index)

def build_graph_list(df, audio_dir, limit=None, n_mfcc=13, pool_size=4):
    """Convert protocol rows into labelled graphs.

    Args:
        df: Protocol table with at least 'speaker_id' and 'label' columns.
            The 'speaker_id' value is used as the audio file stem.
        audio_dir: Directory containing ``<stem>.flac`` files.
        limit: If given, only the first ``limit`` rows are processed.
        n_mfcc: Forwarded to ``audio_to_graph``.
        pool_size: Forwarded to ``audio_to_graph``.

    Returns:
        List of graphs with ``y`` set to 1.0 for 'bonafide' rows and 0.0
        otherwise. Rows whose audio file is missing are logged and skipped.
    """
    rows = df.reset_index(drop=True)
    if limit is not None:
        rows = rows.iloc[:limit]

    graphs = []
    for record in rows.to_dict("records"):
        file_stem = record['speaker_id']
        file_path = os.path.join(audio_dir, f"{file_stem}.flac")
        if not os.path.isfile(file_path):
            logger.warning(f"Missing file {file_path}, skipping")
            continue
        graph = audio_to_graph(file_path, n_mfcc=n_mfcc, pool_size=pool_size)
        is_bonafide = record['label'] == 'bonafide'
        graph.y = torch.tensor(1 if is_bonafide else 0, dtype=torch.float)
        graphs.append(graph)
    return graphs

# ----------------------
# Model definitions (same as training)
# ----------------------
class DifferentiableGraphCompressor(nn.Module):
    """Soft (training) and hard (inference) compressor for temporal graphs.

    Node pairs ``(t, t + k)`` inside a sliding window are scored with a
    learnable mix of node-feature cosine similarity and neighbourhood cosine
    similarity. During training the scores become soft merge probabilities
    that blend node features; at inference nodes whose scores clear
    quantile-based thresholds are removed.

    Args:
        feature_dim: Size of each node feature vector.
        tau_T: Temperature dividing the sigmoid gate arguments.
        lambda_id: Weight of the identity-preservation loss in ``forward``.
        lambda_comp: Weight of the compression loss in ``forward``.
    """

    def __init__(self, feature_dim, tau_T=1.0, lambda_id=1.0, lambda_comp=0.1):
        super().__init__()
        # Unconstrained parameter; sigmoid(a_raw) is the mixing weight alpha.
        self.a_raw = nn.Parameter(torch.tensor(0.0))
        self.tau_T = tau_T
        self.lambda_id = lambda_id
        self.lambda_comp = lambda_comp
        # Maps a mean-pooled feature vector to a 32-d embedding that is
        # compared before/after compression in speaker_identity_loss.
        self.speaker_embedding = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )

    def get_alpha(self):
        """Return the mixing weight alpha = sigmoid(a_raw), in (0, 1)."""
        return torch.sigmoid(self.a_raw)

    def node_similarity(self, x_u, x_v):
        """Cosine similarity of two feature vectors, returned with shape [1]."""
        return F.cosine_similarity(x_u.unsqueeze(0), x_v.unsqueeze(0), dim=1)

    def neighborhood_similarity(self, x, edge_index, u, v):
        """Cosine similarity of the mean features of the out-neighbours of u and v.

        Neighbours of a node are the targets of edges whose source is that
        node. Returns a scalar 0.0 when either node has no such neighbour.
        """
        def get_neighbors(node):
            mask = (edge_index[0] == node)
            if mask.any():
                return edge_index[1][mask]
            else:
                return torch.tensor([], dtype=torch.long, device=x.device)
        neighbors_u = get_neighbors(u)
        neighbors_v = get_neighbors(v)
        if neighbors_u.numel() == 0 or neighbors_v.numel() == 0:
            return torch.tensor(0.0, device=x.device)
        m_u = x[neighbors_u].mean(dim=0)
        m_v = x[neighbors_v].mean(dim=0)
        return F.cosine_similarity(m_u.unsqueeze(0), m_v.unsqueeze(0), dim=1)

    def combined_similarity(self, x, edge_index, u, v):
        """alpha * node similarity + (1 - alpha) * neighbourhood similarity."""
        alpha = self.get_alpha()
        sim_x = self.node_similarity(x[u], x[v])
        sim_m = self.neighborhood_similarity(x, edge_index, u, v)
        return alpha * sim_x + (1 - alpha) * sim_m

    def compute_adaptive_thresholds(self, similarities):
        """Return the 0.75 and 0.60 quantiles of ``similarities``."""
        tau_1_bar = torch.quantile(similarities, 0.75)
        tau_2_bar = torch.quantile(similarities, 0.60)
        return tau_1_bar, tau_2_bar

    @staticmethod
    def _candidate_pairs(num_nodes, window_size):
        """Yield (t, k) with 1 <= t, 1 <= k <= window_size and t + k + 1 < num_nodes."""
        for t in range(1, num_nodes - 1):
            for k in range(1, min(window_size + 1, num_nodes - t)):
                if t + k >= num_nodes - 1:
                    continue
                yield t, k

    def _window_similarities(self, x, edge_index, window_size):
        """Compute (s1, s2, s3) once for every candidate pair, keyed by (t, k).

        s1 compares t with t + k, s2 the preceding nodes and s3 the following
        nodes. _candidate_pairs guarantees that every index is in range, so no
        exception handling is needed around the lookups.
        """
        sims = {}
        for t, k in self._candidate_pairs(x.size(0), window_size):
            sims[(t, k)] = (
                self.combined_similarity(x, edge_index, t, t + k),
                self.combined_similarity(x, edge_index, t - 1, t + k - 1),
                self.combined_similarity(x, edge_index, t + 1, t + k + 1),
            )
        return sims

    def compute_merge_probabilities(self, x, edge_index, window_size=10):
        """Soft merge probability per node, kept inside the autograd graph.

        For each node t the probability is the maximum over k of the product
        of three sigmoid gates (s1 against the 0.75 quantile, s2 and s3
        against the 0.60 quantile). Nodes without a candidate pair keep 0.

        Returns:
            Tensor of shape [num_nodes] on ``x.device``.
        """
        num_nodes = x.size(0)
        merge_probs = torch.zeros(num_nodes, device=x.device)

        sims = self._window_similarities(x, edge_index, window_size)
        if not sims:
            return merge_probs

        similarities_tensor = torch.stack(
            [s for triple in sims.values() for s in triple]
        ).view(-1)
        tau_1_bar, tau_2_bar = self.compute_adaptive_thresholds(similarities_tensor)

        # Keep the running maximum as a tensor: converting it with .item()
        # would cut the gradient path to a_raw and make the compression loss
        # a constant.
        best = {}
        for (t, _k), (s1, s2, s3) in sims.items():
            gate1 = torch.sigmoid((s1 - tau_1_bar) / self.tau_T)
            gate2 = torch.sigmoid((s2 - tau_2_bar) / self.tau_T)
            gate3 = torch.sigmoid((s3 - tau_2_bar) / self.tau_T)
            prob = (gate1 * gate2 * gate3).reshape(())
            best[t] = prob if t not in best else torch.maximum(best[t], prob)

        for t, prob in best.items():
            merge_probs[t] = prob

        return merge_probs

    def differentiable_compression(self, x, edge_index):
        """Blend each interior node towards its most similar later node.

        Returns:
            (x_compressed, merge_probs); ``x_compressed`` has the shape of ``x``.
        """
        merge_probs = self.compute_merge_probabilities(x, edge_index)
        x_compressed = x.clone()
        num_nodes = x.size(0)
        for t in range(1, num_nodes - 1):
            if merge_probs[t] > 0:
                # Picking the partner is an argmax, which carries no gradient,
                # so it is evaluated without building an autograd graph.
                best_k = 1
                best_sim = -1
                with torch.no_grad():
                    # k < num_nodes - t keeps t + k a valid node index.
                    for k in range(1, min(10, num_nodes - t)):
                        sim = self.combined_similarity(x, edge_index, t, t + k).item()
                        if sim > best_sim:
                            best_sim = sim
                            best_k = k
                p = merge_probs[t]
                interpolated = (x[t] + x[t + best_k]) / 2
                x_compressed[t] = (1 - p) * x[t] + p * interpolated
        return x_compressed, merge_probs

    def speaker_identity_loss(self, x_original, x_compressed):
        """1 - cosine similarity between embeddings of the two mean-pooled inputs."""
        g_original = self.speaker_embedding(x_original.mean(dim=0))
        g_compressed = self.speaker_embedding(x_compressed.mean(dim=0))
        cos_sim = F.cosine_similarity(g_original.unsqueeze(0), g_compressed.unsqueeze(0))
        return 1 - cos_sim

    def compression_loss(self, merge_probs):
        """Mean of (1 - merge probability); lower means more merging."""
        return (1 - merge_probs).mean()

    def forward(self, x, edge_index):
        """Run soft compression and return features, probabilities and losses.

        Returns:
            Dict with keys 'compressed_features', 'merge_probs', 'loss'
            (lambda_id * loss_id + lambda_comp * loss_comp), 'loss_id',
            'loss_comp' and 'alpha'.
        """
        x_compressed, merge_probs = self.differentiable_compression(x, edge_index)
        L_id = self.speaker_identity_loss(x, x_compressed)
        L_comp = self.compression_loss(merge_probs)
        total_loss = self.lambda_id * L_id + self.lambda_comp * L_comp
        return {
            'compressed_features': x_compressed,
            'merge_probs': merge_probs,
            'loss': total_loss,
            'loss_id': L_id,
            'loss_comp': L_comp,
            'alpha': self.get_alpha()
        }

    def hard_compression_inference(self, x, edge_index, window_size=10):
        """Remove nodes whose window similarities clear the quantile thresholds.

        Puts the module in eval mode and runs without gradients. A node t is
        removed for the first k where none of the six nodes involved has
        already been removed and s1 >= tau_1, s2 >= tau_2, s3 >= tau_2.
        Edges touching a removed node are dropped; the remaining neighbours
        are not reconnected.

        Returns:
            (compressed_data, to_remove, alpha): a ``Data`` object with the
            kept nodes and re-indexed edges (on ``x.device``), the set of
            removed node indices, and alpha as a Python float.
        """
        self.eval()
        with torch.no_grad():
            alpha = self.get_alpha().item()
            num_nodes = x.size(0)
            to_remove = set()

            sims = self._window_similarities(x, edge_index, window_size)
            if sims:
                similarities_tensor = torch.stack(
                    [s for triple in sims.values() for s in triple]
                ).view(-1)
                tau_1_hat, tau_2_hat = self.compute_adaptive_thresholds(similarities_tensor)
                # sims is ordered by t, then k, so this is a greedy scan.
                for (t, k), (s1, s2, s3) in sims.items():
                    if t in to_remove:
                        continue  # t was already removed at a smaller k
                    involved = (t + k, t - 1, t + k - 1, t + 1, t + k + 1)
                    if any(n in to_remove for n in involved):
                        continue
                    if s1 >= tau_1_hat and s2 >= tau_2_hat and s3 >= tau_2_hat:
                        to_remove.add(t)

            remaining_indices = [i for i in range(num_nodes) if i not in to_remove]
            x_new = x[remaining_indices]
            index_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(remaining_indices)}
            new_edges = [
                [index_mapping[src], index_mapping[dst]]
                for src, dst in edge_index.t().tolist()
                if src not in to_remove and dst not in to_remove
            ]
            # Build the edge tensor on the same device as the node features.
            if new_edges:
                edge_index_new = torch.tensor(new_edges, dtype=torch.long, device=x.device).T
            else:
                edge_index_new = torch.empty((2, 0), dtype=torch.long, device=x.device)
            compressed_data = Data(x=x_new, edge_index=edge_index_new)
            return compressed_data, to_remove, alpha

class GCNClassifier(nn.Module):
    """Graph-level binary classifier: two GCN layers, mean pooling, linear head.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of both GCN layers.
    """

    def __init__(self, in_channels, hidden_channels=32):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, 1)

    def forward(self, data):
        """Return one sigmoid probability per graph in ``data`` as a 1-D tensor."""
        hidden = self.conv1(data.x, data.edge_index)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden, data.edge_index)
        # Average node embeddings per graph using the batch assignment vector.
        graph_embedding = global_mean_pool(hidden, data.batch)
        logits = self.fc(graph_embedding)
        return torch.sigmoid(logits).view(-1)

# ----------------------
# Evaluation utility (same as training)
# ----------------------
def evaluate_with_hard_compression(compressor, classifier, loader, device):
    """Score the classifier on hard-compressed graphs.

    Every batch from ``loader`` is expected to hold a single graph. The graph
    is compressed with ``compressor.hard_compression_inference`` and then
    classified; predictions use a 0.5 probability threshold.

    Returns:
        Dict with 'acc', 'prec', 'rec', 'f1', 'auc' (NaN when the labels
        contain a single class) plus the raw lists 'y_true', 'y_pred' and
        'y_prob'.
    """
    classifier.eval()
    compressor.eval()

    labels, decisions, scores = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            # Hard compression; the removed-node set and alpha are not needed here.
            graph, _removed, _alpha = compressor.hard_compression_inference(batch.x, batch.edge_index)
            graph.y = batch.y.cpu()
            graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long)
            graph = graph.to(device)

            prob = classifier(graph)  # length 1 tensor (batch_size=1)
            pred = (prob >= 0.5).long()

            labels.append(int(batch.y.cpu().item()))
            decisions.append(int(pred.cpu().item()))
            scores.append(float(prob.cpu().item()))

    results = {
        "acc": accuracy_score(labels, decisions),
        "prec": precision_score(labels, decisions, zero_division=0),
        "rec": recall_score(labels, decisions, zero_division=0),
        "f1": f1_score(labels, decisions, zero_division=0),
    }
    if len(set(labels)) > 1:
        results["auc"] = roc_auc_score(labels, scores)
    else:
        results["auc"] = float('nan')
    results["y_true"] = labels
    results["y_pred"] = decisions
    results["y_prob"] = scores
    return results

# ----------------------
# Main evaluation flow
# ----------------------
def _checkpoint_entry(ckpt, *candidate_keys):
    """Return ckpt[key] for the first key in candidate_keys that is present."""
    for key in candidate_keys:
        if key in ckpt:
            return ckpt[key]
    raise KeyError(
        f"Checkpoint has none of the keys {candidate_keys}; available keys: {list(ckpt.keys())}"
    )


if not os.path.isfile(save_path):
    raise FileNotFoundError(f"Checkpoint not found at {save_path}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Load the checkpoint first: building the test graphs takes a long time, so
# any problem with the checkpoint should surface before that work starts.
# The training cell saves 'compressor_state_dict' / 'classifier_state_dict';
# the shorter key names are accepted as a fallback.
ckpt = torch.load(save_path, map_location=device)
compressor_state = _checkpoint_entry(ckpt, 'compressor_state_dict', 'compressor_state')
classifier_state = _checkpoint_entry(ckpt, 'classifier_state_dict', 'classifier_state')

# Read test protocol
if not os.path.isfile(test_proto):
    raise FileNotFoundError(f"Test protocol not found at {test_proto}. Update proto_dir/test_proto paths.")
test_df = pd.read_csv(test_proto, sep=" ", header=None)
test_df.columns = ["utt_id", "speaker_id", "system_id", "attack_id", "label"]
logger.info(f"Test protocol contains {len(test_df)} rows")

# Build test graphs (this will read/flac files — can be slow)
logger.info("Building test graphs (this may take a few minutes)...")
test_graphs = build_graph_list(test_df, eval_audio_dir, limit=None)
if len(test_graphs) == 0:
    raise RuntimeError("No test graphs were built — check eval_audio_dir and speaker_id->filename mapping")
logger.info(f"Built {len(test_graphs)} test graphs")

test_loader = DataLoader(test_graphs, batch_size=1, shuffle=False)

# Recreate models (must match training hyperparams)
feature_dim = test_graphs[0].x.size(1)
compressor = DifferentiableGraphCompressor(feature_dim, tau_T=1.0, lambda_id=1.0, lambda_comp=0.05).to(device)
classifier = GCNClassifier(in_channels=feature_dim, hidden_channels=32).to(device)

# Restore the trained weights
compressor.load_state_dict(compressor_state)
classifier.load_state_dict(classifier_state)
logger.info(f"Loaded checkpoint from {save_path} (saved epoch={ckpt.get('epoch','?')}, dev_f1={ckpt.get('dev_f1','?')})")

# Evaluate
metrics = evaluate_with_hard_compression(compressor, classifier, test_loader, device)
logger.info("===== Test Results =====")
logger.info(f"Accuracy : {metrics['acc']:.4f}")
logger.info(f"Precision: {metrics['prec']:.4f}")
logger.info(f"Recall   : {metrics['rec']:.4f}")
logger.info(f"F1-score : {metrics['f1']:.4f}")
logger.info(f"AUC      : {metrics['auc']}")

# Additional reports
y_true = metrics['y_true']
y_pred = metrics['y_pred']
logger.info("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))
logger.info("Classification report:")
print(classification_report(y_true, y_pred, digits=4, zero_division=0))


14:10:10 [INFO] Using device: cpu
14:10:10 [INFO] Test protocol contains 71237 rows
14:10:10 [INFO] Building test graphs (this may take a few minutes)...
14:29:54 [INFO] Built 71237 test graphs


KeyError: 'compressor_state'