In [None]:
!unzip '/content/organized_audio_output_nested.zip' -d '/content/'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_37_morris_0055.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_39_morris_0212.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_40_morris_0029.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_40_morris_0362.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_41_morris_0007.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_41_morris_0183.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_42_morris_0025.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_42_morris_0037.wav  
  inflating: /content/organized_audio_output_nested/train/natural/6097/roots_43_morris_0234.wav  
  inflating: /content/organized_audio_output_nested/t

In [None]:
# This cell runs before the main import cell, so `os` must be imported here;
# otherwise a fresh-kernel "Run All" fails with NameError on the first os.path.join.
import os

# Dataset root and protocol file locations.
# The same names are assigned again, with identical values, in the configuration
# section of the main training cell.
BASE_DIR = '/content/organized_audio_output_nested'
TRAIN_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/train.txt")
DEV_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/dev.txt")
EVAL_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/eval.txt")
EVAL_ASV_SCORES_PATH = os.path.join(BASE_DIR, "protocols/eval_asv_scores.txt") # Example path, adjust!

LEARNING FROM YOURSELF: A SELF-DISTILLATION METHOD FOR FAKE SPEECH DETECTION

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import librosa
import math
import glob
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d

# Seed the torch and numpy global RNGs for reproducibility.
# NOTE(review): Python's `random` module and cuDNN determinism flags are not set here.
torch.manual_seed(42)
np.random.seed(42)

# === Configuration ===
# Dataset root; the default points at the Google Colab extraction directory (adjust as needed).
BASE_DIR = '/content/organized_audio_output_nested'
# --- Protocol file paths ---
# Each protocol file is parsed by FSDDataset as whitespace-separated lines whose first
# field is a path relative to BASE_DIR and whose last field is the label.
# You MUST ensure these paths are correct for your setup.
TRAIN_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/train.txt")
DEV_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/dev.txt")
EVAL_PROTOCOL_PATH = os.path.join(BASE_DIR, "protocols/eval.txt")
# --- ASV scores path ---
# Read by load_asv_scores() as "utt_id score" lines. If the file is missing, the
# loader returns None and the t-DCF computation is skipped during training/evaluation.
EVAL_ASV_SCORES_PATH = os.path.join(BASE_DIR, "protocols/eval_asv_scores.txt") # Example path, adjust!

# Sampling rate (Hz) that audio is resampled to when loaded.
SAMPLE_RATE = 16000

# --- Training parameters ---
# Values tagged "per paper" are the original author's attribution and have not been re-verified here.
BATCH_SIZE = 32
NUM_EPOCHS = 30 # Author's note: the paper mentions 32 epochs; this notebook trains for 30
LEARNING_RATE = 1e-4 # Adam learning rate (per paper)
WEIGHT_DECAY = 1e-4 # Adam weight decay (per paper)
ADAM_BETA1 = 0.9 # Adam beta1 (per paper)
ADAM_BETA2 = 0.98 # Adam beta2 (per paper)
ADAM_EPS = 1e-9 # Adam epsilon (per paper)

# --- Model parameters ---
MODEL_TYPE = 'eca'  # Attention block type: 'eca' or 'se'
MODEL_DEPTH = 18    # Backbone depth: 9, 18, 34, or 50

# --- Self-Distillation parameters ---
# SelfDistillationLoss combines: ALPHA * hard + (1 - ALPHA) * soft + BETA * feature.
ALPHA = 0.7         # Weight for hard loss (author cites the paper's Table 3 caption)
BETA = 0.3          # Weight for feature loss (author cites the paper's Table 3 caption)
TEMPERATURE = 3.0   # Temperature for KL divergence (common default, verify if paper specifies)
ASOFTMAX_MARGIN = 4 # Margin 'm' for A-Softmax (common default, verify if paper specifies)

# --- Feature extraction parameters ---
N_FFT = 1728        # STFT size; also used as the window length (per paper)
HOP_LENGTH = 130    # STFT hop in samples (per paper)
N_FREQ = 45         # Number of lowest frequency bins kept as the F0 subband (per paper)
N_FRAMES = 600      # Fixed number of time frames after padding/truncation (per paper)
WINDOW = 'blackman' # STFT window type (per paper)

# --- Cost parameters used as defaults by compute_tDCF ---
# Author labels these as the ASVspoof 2019 standard values.
P_TARGET = 0.05
C_MISS = 1
C_FA = 10

# === Dataset Class ===
class FSDDataset(Dataset):
    """Audio dataset driven by a whitespace-separated protocol file.

    Every usable protocol line has at least two fields: the first is the audio
    path relative to ``base_dir`` and the last is the label string. The label is
    encoded as 1 for ``'spoof'`` and 0 for anything else. The file name without
    its extension is kept as the utterance ID.
    """

    def __init__(self, protocol_file, base_dir, transform=None, return_utt_id=False):
        """
        Args:
            protocol_file (str): Path to the protocol file.
            base_dir (str): Directory that the relative audio paths are joined to.
            transform (callable, optional): Feature extractor applied to the waveform.
            return_utt_id (bool): When True, items carry the utterance ID as a third element.
        """
        self.base_dir = base_dir
        self.transform = transform
        self.return_utt_id = return_utt_id

        self.data = []
        with open(protocol_file, 'r') as handle:
            for raw_line in handle:
                fields = raw_line.strip().split()
                if len(fields) < 2:
                    # Blank or single-field lines carry no (path, label) pair.
                    continue
                rel_path, label_name = fields[0], fields[-1]
                # Utterance ID = file name stripped of directory and extension.
                stem, _ext = os.path.splitext(os.path.basename(rel_path))
                self.data.append((rel_path, int(label_name == 'spoof'), stem))

        print(f"Loaded {len(self.data)} samples from {protocol_file}")

    def __len__(self):
        return len(self.data)

    def _pack(self, sample, label, utt_id):
        """Build the item tuple, appending the utterance ID only when requested."""
        if self.return_utt_id:
            return sample, label, utt_id
        return sample, label

    def __getitem__(self, idx):
        """Load one item; on any loading/feature error, print it and return zeros instead."""
        relative_path, label, utt_id = self.data[idx]
        full_path = os.path.join(self.base_dir, relative_path)

        try:
            audio, _ = librosa.load(full_path, sr=SAMPLE_RATE, mono=True)
            if not self.transform:
                return self._pack(audio, label, utt_id)
            # Add a leading channel axis so the features look like a 1-channel image.
            features = np.expand_dims(self.transform(audio), axis=0)
            return self._pack(torch.FloatTensor(features), label, utt_id)
        except Exception as e:
            print(f"Error processing {full_path}: {e}")
            # Substitute an all-zero sample so batching keeps working; the label is kept as-is.
            if self.transform:
                fallback = torch.FloatTensor(np.zeros((1, N_FREQ, N_FRAMES)))
            else:
                fallback = np.zeros(SAMPLE_RATE)
            return self._pack(fallback, label, utt_id)

# === Front-end Feature Extraction (Unchanged) ===
class F0SubbandFeatureExtractor:
    """Callable front-end producing a fixed-size low-frequency log power spectrogram.

    The waveform is transformed with an STFT, converted to power and then to dB,
    cropped to the first ``n_freq`` frequency bins, and finally zero-padded or
    truncated along time to exactly ``n_frames`` frames.
    """

    def __init__(self, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=N_FFT,
                 window=WINDOW, n_freq=N_FREQ, n_frames=N_FRAMES):
        self.n_fft, self.hop_length, self.win_length = n_fft, hop_length, win_length
        self.window = window
        self.n_freq, self.n_frames = n_freq, n_frames

    def __call__(self, audio):
        """Return an array of shape (n_freq, n_frames) for the given waveform."""
        spectrum = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length,
                                win_length=self.win_length, window=self.window)
        power = np.abs(spectrum) ** 2
        # Keep only the lowest n_freq bins of the dB-scaled power spectrogram.
        subband = librosa.power_to_db(power)[:self.n_freq, :]

        missing_frames = self.n_frames - subband.shape[1]
        if missing_frames > 0:
            # Too short: pad with zeros at the end of the time axis.
            return np.pad(subband, ((0, 0), (0, missing_frames)))
        # Long enough: keep the first n_frames frames.
        return subband[:, :self.n_frames]

# === Building Blocks (Unchanged) ===
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU stem block (the convolution has no bias because BN follows)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)

class ECABlock(nn.Module):
    """Efficient Channel Attention: per-channel gates from a 1-D conv over pooled channel stats.

    The 1-D kernel size is derived from the channel count as
    ``int(|log2(channels) + b| / gamma)``, bumped to the next odd number when even.
    """

    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        raw_size = int(abs(math.log(channels, 2) + b) / gamma)
        # Force an odd kernel so that symmetric padding preserves the channel length.
        kernel = raw_size + 1 - raw_size % 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel, padding=(kernel - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = self.avg_pool(x)                      # (B, C, 1, 1)
        channel_seq = pooled.squeeze(-1).transpose(-1, -2)  # (B, 1, C): channels become the 1-D axis
        mixed = self.conv(channel_seq)
        gate = self.sigmoid(mixed.transpose(-1, -2).unsqueeze(-1))  # back to (B, C, 1, 1)
        return x * gate.expand_as(x)

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: global average pool, bottleneck MLP, sigmoid channel gates."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, chans)
        gate = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * gate.expand_as(x)

class AttentionResidualBlock(nn.Module):
    """Two-conv residual block whose residual branch is re-weighted by channel attention.

    ``attention_type == 'eca'`` selects ECABlock; any other value selects SEBlock
    (with the given ``reduction``). ``downsample``, when provided, is applied to
    the shortcut so that it matches the main branch's shape.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 attention_type='eca', reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if attention_type == 'eca':
            self.attention = ECABlock(out_channels)
        else:
            self.attention = SEBlock(out_channels, reduction)
        self.downsample = downsample

    def forward(self, x):
        # Main branch: conv-bn-relu, conv-bn, then channel attention.
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        branch = self.attention(branch)
        # Shortcut: identity, or the projection supplied by the caller.
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)

# === AngleLinear Layer (Unchanged, ASoftmax logic is in the loss) ===
class AngleLinear(nn.Module):
    """Bias-free linear head that outputs cosine similarities instead of raw logits.

    Both the input rows and the class weight rows are L2-normalised, so each
    output entry is cos(theta) between a sample and a class direction.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        unit_x = F.normalize(x, p=2, dim=1)
        unit_w = F.normalize(self.weight, p=2, dim=1)
        # Rounding can push a cosine marginally outside [-1, 1]; clamp it back.
        return torch.clamp(F.linear(unit_x, unit_w), -1, 1)

# === Main Model (Unchanged Structure) ===
class SelfDistillationFSD(nn.Module):
    """Attention-ResNet backbone with one cosine classifier attached after each stage.

    In training mode ``forward`` returns a dict with the pooled features and the
    cosine outputs of all four stages (shallow to deep). In eval mode it returns
    only the deepest classifier's output.
    """

    def __init__(self, block_type='eca', depth=18, num_classes=2, in_channels=1):
        super().__init__()
        self.block_type = block_type

        # Residual blocks per stage for each supported depth.
        # Depth 50 reuses the depth-34 layout (a true ResNet-50 would need bottleneck blocks).
        layer_plans = (
            (9, [1, 1, 1, 1]),
            (18, [2, 2, 2, 2]),
            (34, [3, 4, 6, 3]),
            (50, [3, 4, 6, 3]),
        )
        for supported_depth, plan in layer_plans:
            if depth == supported_depth:
                blocks = plan
                break
        else:
            raise ValueError(f"Unsupported depth: {depth}")

        # --- Backbone: stem followed by four stages with doubling width ---
        self.conv1 = ConvBlock(in_channels, 16)
        self.block1 = self._make_layer(16, 32, blocks[0], stride=1)
        self.block2 = self._make_layer(32, 64, blocks[1], stride=2)
        self.block3 = self._make_layer(64, 128, blocks[2], stride=2)
        self.block4 = self._make_layer(128, 256, blocks[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # --- One classifier per stage; classifier4 is the main (deepest) head ---
        self.classifier1 = AngleLinear(32, num_classes)
        self.classifier2 = AngleLinear(64, num_classes)
        self.classifier3 = AngleLinear(128, num_classes)
        self.classifier4 = AngleLinear(256, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        """Stack ``num_blocks`` attention residual blocks; only the first changes shape."""
        needs_projection = stride != 1 or in_channels != out_channels
        downsample = None
        if needs_projection:
            # 1x1 conv + BN so the shortcut matches the main branch's channels/resolution.
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))
        stage = [AttentionResidualBlock(in_channels, out_channels, stride, downsample, self.block_type)]
        stage.extend(
            AttentionResidualBlock(out_channels, out_channels, attention_type=self.block_type)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.conv1(x)

        # Run the stages in order, keeping a globally pooled feature vector after each.
        pooled_features = []
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
            pooled_features.append(self.avg_pool(x).view(x.size(0), -1))

        heads = (self.classifier1, self.classifier2, self.classifier3, self.classifier4)
        cosines = [head(feat) for head, feat in zip(heads, pooled_features)]

        if self.training:
            return {'features': pooled_features, 'logits': cosines}
        # Evaluation only needs the deepest (main) classifier.
        return cosines[-1]

# === Loss Functions ===

# --- A-Softmax Loss Implementation ---
class ASoftmaxLoss(nn.Module):
    """Angular-margin (A-Softmax style) cross entropy on cosine similarities.

    The target-class cosine is replaced by
    ``psi(theta) = (-1)^k * cos(m * theta) - 2k`` with ``theta in [k*pi/m, (k+1)*pi/m]``,
    and all logits are multiplied by ``lambda_val`` before cross entropy.

    Args:
        margin (int): Angular margin ``m``.
        gamma (float): Decay rate of ``lambda_val`` per forward call.
        base (float): Initial value of ``lambda_val``.
        power (float): Decay exponent of ``lambda_val``.

    NOTE(review): ``lambda_val`` is used here as a multiplicative logit scale. With
    the default ``gamma=0`` it never decays and stays at ``max(LambdaMin, base)``,
    i.e. 1000, which is why the reported loss values are in the hundreds. This
    behaviour is preserved; confirm it is intended.
    """

    def __init__(self, margin=4, gamma=0, base=1000.0, power=2):
        super(ASoftmaxLoss, self).__init__()
        self.margin = margin # m
        self.gamma = gamma   # decay rate of lambda (0 = constant)
        self.base = base
        self.power = power
        self.LambdaMin = 5.0
        self.iter = 0        # number of forward calls so far (drives the lambda schedule)

    def calculate_psi_theta(self, cos_theta):
        """Return psi(theta) for the given cosines (any shape), element-wise."""
        # d/dx acos(x) is infinite at x = +/-1, which turns the gradient into NaN
        # (0 * inf) when a sample is exactly aligned with a class weight. Keep the
        # argument strictly inside (-1, 1).
        eps = torch.finfo(cos_theta.dtype).eps
        theta = torch.acos(cos_theta.clamp(-1.0 + eps, 1.0 - eps))
        # k is piecewise constant, so it carries no useful gradient; detach it.
        k = torch.floor(self.margin * theta / math.pi).detach()
        # (-1)^k from the parity of k, avoiding a negative-base power whose
        # exponent-gradient is NaN.
        sign = 1.0 - 2.0 * torch.remainder(k, 2)
        return sign * torch.cos(self.margin * theta) - 2.0 * k

    def forward(self, cosine_similarities, labels):
        """Compute the loss.

        Args:
            cosine_similarities (Tensor): (batch, num_classes) cosines from AngleLinear.
            labels (LongTensor): (batch,) class indices.

        Returns:
            Tensor: scalar mean cross-entropy over the batch.
        """
        self.iter += 1

        one_hot_labels = F.one_hot(labels, num_classes=cosine_similarities.size(1))
        one_hot_labels = one_hot_labels.to(cosine_similarities.dtype)
        cos_theta_target = cosine_similarities.gather(1, labels.unsqueeze(1)).squeeze(1)

        psi_theta_target = self.calculate_psi_theta(cos_theta_target)

        lambda_val = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-self.power))

        # Swap cos(theta) for psi(theta) in the target column only, then scale all logits.
        margin_delta = (psi_theta_target - cos_theta_target).unsqueeze(1)  # (batch, 1)
        final_logits = (cosine_similarities + one_hot_labels * margin_delta) * lambda_val

        return F.cross_entropy(final_logits, labels)


# --- Self-Distillation Loss (Updated) ---
class SelfDistillationLoss(nn.Module):
    """Self-distillation objective: the deepest classifier teaches the shallower ones.

    ``total = alpha * hard + (1 - alpha) * soft + beta * feature`` where
      * hard    = A-Softmax loss of the deepest classifier against the labels,
      * soft    = sum over shallower classifiers of the temperature-scaled KL
                  divergence to the (detached) deepest classifier,
      * feature = sum over shallower stages of the MSE between their (projected)
                  pooled features and the (detached) deepest features.

    Args:
        alpha (float): Weight of the hard loss; the soft loss gets ``1 - alpha``.
        beta (float): Weight of the feature loss.
        temperature (float): Softmax temperature for the KL term.
        asoftmax_margin (int): Margin passed to ASoftmaxLoss.
    """

    def __init__(self, alpha=ALPHA, beta=BETA, temperature=TEMPERATURE, asoftmax_margin=ASOFTMAX_MARGIN):
        super(SelfDistillationLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature
        self.hard_loss_fn = ASoftmaxLoss(margin=asoftmax_margin)
        # Projection layers live in a ModuleDict (not a plain dict) so that they are
        # registered sub-modules: visible to .parameters(), .to() and state_dict().
        self.feature_projection_cache = nn.ModuleDict()

    def get_projection(self, in_dim, out_dim, device):
        """Return the cached bias-free linear map in_dim -> out_dim, creating it on first use.

        NOTE: layers are created lazily, so an optimizer built before the first
        forward pass does not contain them. Unless their parameters are handed to
        an optimizer afterwards they remain fixed Xavier-initialised projections.
        """
        key = f"{in_dim}_{out_dim}"  # ModuleDict keys must be strings
        if key not in self.feature_projection_cache:
            projection = nn.Linear(in_dim, out_dim, bias=False).to(device)
            nn.init.xavier_uniform_(projection.weight)
            self.feature_projection_cache[key] = projection
        return self.feature_projection_cache[key]

    def forward(self, outputs, labels):
        """Combine hard, soft and feature losses.

        Args:
            outputs (dict): ``{'features': [...], 'logits': [...]}`` as returned by the
                model in training mode, ordered shallow to deep.
            labels (LongTensor): (batch,) class indices.

        Returns:
            Tensor: scalar total loss.
        """
        features = outputs['features']
        logits = outputs['logits']  # cosine similarities from AngleLinear

        # Hard loss: only the deepest classifier is trained directly on the labels.
        hard_loss = self.hard_loss_fn(logits[-1], labels)

        # Soft loss: teacher is detached so gradients reach only the students.
        teacher_logits = logits[-1].detach()
        soft_loss = 0
        for student_logits in logits[:-1]:
            soft_loss += self.kl_divergence_loss(student_logits, teacher_logits)

        # Feature loss: project shallower features up to the teacher's width when needed.
        teacher_features = features[-1].detach()
        feature_loss = 0
        for student_features in features[:-1]:
            if student_features.size(1) != teacher_features.size(1):
                projection = self.get_projection(student_features.size(1),
                                                 teacher_features.size(1),
                                                 student_features.device)
                student_features = projection(student_features)
            feature_loss += F.mse_loss(student_features, teacher_features)

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss + self.beta * feature_loss

    def kl_divergence_loss(self, student_logits, teacher_logits):
        """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        # kl_div expects log-probabilities for the input and (log_target=False) probabilities for the target.
        kl_div = F.kl_div(soft_student, soft_teacher, reduction='batchmean', log_target=False)
        # T^2 keeps gradient magnitudes comparable across temperatures.
        return kl_div * (self.temperature ** 2)


# === EER & t-DCF Calculation ===
def compute_eer(scores, labels):
    """Compute the Equal Error Rate in percent.

    Args:
        scores: Detection scores where higher means more likely spoof (label 1).
        labels: Ground-truth labels, 0 = bona fide, 1 = spoof.

    Returns:
        float: EER * 100, i.e. the false positive rate at the ROC point where it
        equals the false negative rate, found by root-finding on the linearly
        interpolated ROC curve.
    """
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    # Build the interpolator once instead of on every root-finder iteration.
    roc = interp1d(fpr, tpr)
    # EER is the x where FPR == FNR, i.e. where 1 - x - TPR(x) crosses zero.
    eer = brentq(lambda x: 1. - x - roc(x), 0., 1.) * 100
    return eer

def compute_tDCF(bonafide_scores, spoof_scores, P_target=P_TARGET, C_miss=C_MISS, C_fa=C_FA):
    """Compute a normalised minimum detection cost for the countermeasure (CM) alone.

    Scores are assumed to increase with "spoof-ness": a trial is called spoof when
    its score is above the threshold. For each threshold,
      * P_miss = fraction of spoof trials at or below the threshold (accepted as bona fide),
      * P_fa   = fraction of bona fide trials above the threshold (rejected as spoof),
    and the cost is ``P_target * C_miss * P_miss + (1 - P_target) * C_fa * P_fa``.
    The minimum over thresholds is divided by the cheaper of the two trivial systems.

    NOTE(review): no ASV error rates enter this computation, so this is a
    standalone-CM cost rather than the full tandem t-DCF — confirm against the
    official ASVspoof scoring before reporting it as t-DCF.

    Args:
        bonafide_scores: CM scores of bona fide trials.
        spoof_scores: CM scores of spoof trials.
        P_target, C_miss, C_fa: Prior and cost parameters.

    Returns:
        float: normalised minimum cost, or NaN when either class has no trials.
    """
    bonafide_scores = np.asarray(bonafide_scores, dtype=float).ravel()
    spoof_scores = np.asarray(spoof_scores, dtype=float).ravel()
    N_bona = bonafide_scores.size
    N_spoof = spoof_scores.size
    if N_bona == 0 or N_spoof == 0:
        # Error rates are undefined without trials of both classes.
        return float('nan')

    all_scores = np.concatenate((bonafide_scores, spoof_scores))
    all_labels = np.concatenate((np.zeros(N_bona), np.ones(N_spoof)))
    sorted_labels = all_labels[np.argsort(all_scores)]

    # Entry i+1 corresponds to a threshold at the i-th smallest score. Entry 0 is
    # the threshold below every score (everything called spoof: P_miss=0, P_fa=1);
    # without it the sweep can never pick that trivial system and the normalised
    # minimum could exceed 1.
    P_miss = np.concatenate(([0.0], np.cumsum(sorted_labels) / N_spoof))
    P_fa = np.concatenate(([1.0], (N_bona - np.cumsum(1 - sorted_labels)) / N_bona))

    # Standalone CM: the effective target prior is just P_target.
    P_tar_eff = P_target
    tDCF = P_tar_eff * C_miss * P_miss + (1 - P_tar_eff) * C_fa * P_fa

    # Normalise by the cheaper of "accept everything" and "reject everything".
    min_baseline_cost = min(P_tar_eff * C_miss, (1 - P_tar_eff) * C_fa)
    return float(np.min(tDCF) / min_baseline_cost)

# Function to load ASV scores (Needs adaptation based on your file format)
def load_asv_scores(filepath):
    """Read ASV scores from a text file of ``utt_id score`` lines.

    Lines with fewer than two whitespace-separated fields are ignored; extra
    fields after the second are ignored too. If an utterance ID repeats, the
    last occurrence wins.

    Returns:
        dict mapping utterance ID to float score, or None when the file does not exist.
    """
    try:
        with open(filepath, 'r') as handle:
            rows = (raw_line.strip().split() for raw_line in handle)
            asv_scores = {fields[0]: float(fields[1]) for fields in rows if len(fields) >= 2}
    except FileNotFoundError:
        print(f"Warning: ASV scores file not found at {filepath}. t-DCF calculation will be skipped.")
        return None
    print(f"Loaded {len(asv_scores)} ASV scores from {filepath}")
    return asv_scores


# === Training Loop (Updated for t-DCF) ===
def _collect_cm_scores(model, loader, device, desc):
    """Score every batch of `loader` with `model` (caller must have set eval mode).

    The loader must yield ``(inputs, labels, utt_ids)``. The CM score of a trial is
    the softmax probability of class 1 (spoof) computed from the model output.

    Returns:
        (scores, labels, utt_ids): two numpy arrays and a list, in loader order.
    """
    cm_scores, all_labels, utt_ids = [], [], []
    with torch.no_grad():
        for inputs, labels, utt_ids_batch in tqdm(loader, desc=desc):
            inputs = inputs.to(device)
            outputs = model(inputs)  # eval mode: main classifier cosines only
            scores = F.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            cm_scores.extend(scores)
            all_labels.extend(labels.numpy())
            utt_ids.extend(utt_ids_batch)
    return np.array(cm_scores), np.array(all_labels), utt_ids


def train_and_evaluate(model, train_loader, dev_loader, test_loader, criterion, optimizer, num_epochs, device, eval_asv_scores):
    """Train with per-epoch validation, checkpoint the best epoch, then evaluate on the test set.

    Args:
        model: Network returning a dict in training mode and main-classifier outputs in eval mode.
        train_loader: Yields ``(inputs, labels)``.
        dev_loader, test_loader: Yield ``(inputs, labels, utt_ids)``.
        criterion: Loss taking ``(model_outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        num_epochs (int): Number of training epochs.
        device: Device that inputs are moved to.
        eval_asv_scores: Only its truthiness is used: when truthy, the cost from
            compute_tDCF is computed and used as the model-selection metric;
            otherwise dev EER is used. The scores themselves are not consumed here.

    Returns:
        (train_losses, dev_eers, dev_tdcfs, test_eer, test_tdcf). ``dev_tdcfs`` holds
        None for epochs where the cost was skipped; ``test_tdcf`` is inf when skipped.

    Side effects: writes the best checkpoint under ``checkpoints/`` and saves the
    training-curve figure, both named after the MODEL_TYPE / MODEL_DEPTH globals.
    """
    best_dev_metric = float('inf')
    train_losses = []
    dev_eers = []
    dev_tdcfs = []

    os.makedirs('checkpoints', exist_ok=True)
    checkpoint_path = f'checkpoints/best_model_{MODEL_TYPE}{MODEL_DEPTH}.pth'
    # Tracks whether THIS run wrote the checkpoint, so a stale file left by an
    # earlier run (possibly a different architecture) is never loaded by mistake.
    checkpoint_saved = False
    print("Starting training...")

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training")
        for batch_idx, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            progress_bar.set_postfix({'loss': running_loss / (batch_idx + 1)})
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # --- Validation Phase ---
        model.eval()
        dev_cm_scores_np, dev_labels_np, _ = _collect_cm_scores(model, dev_loader, device, "Validation")

        dev_eer = compute_eer(dev_cm_scores_np, dev_labels_np)
        dev_eers.append(dev_eer)

        dev_tdcf = float('inf')  # placeholder when the cost is skipped
        if eval_asv_scores:
            dev_bona_cm_scores = dev_cm_scores_np[dev_labels_np == 0]
            dev_spoof_cm_scores = dev_cm_scores_np[dev_labels_np == 1]
            dev_tdcf = compute_tDCF(dev_bona_cm_scores, dev_spoof_cm_scores)
            dev_tdcfs.append(dev_tdcf)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Dev EER: {dev_eer:.4f}%, Dev t-DCF: {dev_tdcf:.4f}')
        else:
            dev_tdcfs.append(None)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Dev EER: {dev_eer:.4f}% (t-DCF skipped)')

        # Model selection: lower is better for both metrics.
        current_metric = dev_tdcf if eval_asv_scores else dev_eer
        metric_name = "tDCF" if eval_asv_scores else "EER"

        if current_metric < best_dev_metric:
            best_dev_metric = current_metric
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'best_dev_metric': best_dev_metric,
                'metric_name': metric_name
            }, checkpoint_path)
            checkpoint_saved = True
            print(f'>>> New best model saved with Dev {metric_name}: {best_dev_metric:.4f}')

    # --- Training curves ---
    plt.figure(figsize=(18, 5))
    plt.subplot(1, 3, 1)
    plt.plot(train_losses)
    plt.title('Training Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.grid(True)

    plt.subplot(1, 3, 2)
    plt.plot(dev_eers)
    plt.title('Dev EER (%)'); plt.xlabel('Epoch'); plt.ylabel('EER (%)')
    plt.grid(True)

    if dev_tdcfs and any(t is not None for t in dev_tdcfs):
        plt.subplot(1, 3, 3)
        # Plot only the epochs for which the cost was actually computed.
        epochs_with_tdcf = [i + 1 for i, t in enumerate(dev_tdcfs) if t is not None]
        valid_dev_tdcfs = [t for t in dev_tdcfs if t is not None]
        plt.plot(epochs_with_tdcf, valid_dev_tdcfs)
        plt.title('Dev min t-DCF'); plt.xlabel('Epoch'); plt.ylabel('min t-DCF')
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'training_curves_{MODEL_TYPE}{MODEL_DEPTH}.png')
    plt.close()  # the figure is written to disk, not displayed

    # --- Final Test Evaluation ---
    if checkpoint_saved:
        print("Loading best model for final evaluation...")
        # map_location keeps the load valid regardless of the device the tensors were saved from.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # No epoch produced a checkpoint (e.g. num_epochs == 0 or a NaN metric).
        print("Warning: no checkpoint was saved during this run; evaluating the current model weights.")
    model.eval()

    test_cm_scores_np, test_labels_np, _ = _collect_cm_scores(model, test_loader, device, "Test Evaluation")

    test_eer = compute_eer(test_cm_scores_np, test_labels_np)
    print(f'Final Test EER: {test_eer:.4f}%')

    test_tdcf = float('inf')
    if eval_asv_scores:
        test_bona_cm_scores = test_cm_scores_np[test_labels_np == 0]
        test_spoof_cm_scores = test_cm_scores_np[test_labels_np == 1]
        test_tdcf = compute_tDCF(test_bona_cm_scores, test_spoof_cm_scores)
        print(f'Final Test min t-DCF: {test_tdcf:.4f}')
    else:
        print("Final Test t-DCF: Skipped (ASV scores not provided)")

    return train_losses, dev_eers, dev_tdcfs, test_eer, test_tdcf

# === Main script ===
# Entry point: build the data pipeline, model, loss and optimizer, then run
# training followed by the final test evaluation.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Feature extractor shared by all three datasets (uses the config defaults)
    feature_extractor = F0SubbandFeatureExtractor()

    # --- Load ASV scores first ---
    # load_asv_scores returns None when the file is missing; train_and_evaluate
    # then skips the t-DCF computation and selects the best model by dev EER.
    eval_asv_scores = load_asv_scores(EVAL_ASV_SCORES_PATH)
    # NOTE: only the presence of these scores is checked downstream; the same
    # object gates the cost computation on both the dev and the eval set.

    # Fail early with a clear message if any protocol file is missing
    if not os.path.exists(TRAIN_PROTOCOL_PATH):
        raise FileNotFoundError(f"Train protocol: {TRAIN_PROTOCOL_PATH}")
    if not os.path.exists(DEV_PROTOCOL_PATH):
        raise FileNotFoundError(f"Dev protocol: {DEV_PROTOCOL_PATH}")
    if not os.path.exists(EVAL_PROTOCOL_PATH):
        raise FileNotFoundError(f"Eval protocol: {EVAL_PROTOCOL_PATH}")

    # Create datasets. The dev/eval datasets return utterance IDs because the
    # evaluation loops unpack three values per batch; the training loop unpacks two.
    print("Loading datasets...")
    train_dataset = FSDDataset(TRAIN_PROTOCOL_PATH, BASE_DIR, transform=feature_extractor)
    dev_dataset = FSDDataset(DEV_PROTOCOL_PATH, BASE_DIR, transform=feature_extractor, return_utt_id=True) # Need IDs for eval
    test_dataset = FSDDataset(EVAL_PROTOCOL_PATH, BASE_DIR, transform=feature_extractor, return_utt_id=True) # Need IDs for eval

    # Create data loaders: only the training set is shuffled
    print("Creating data loaders...")
    # Consider pinning memory if using GPU: pin_memory=True
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    print("Data loaders created.")

    # Create model and report its trainable parameter count
    model = SelfDistillationFSD(block_type=MODEL_TYPE, depth=MODEL_DEPTH)
    model = model.to(device)
    print(f"Created {MODEL_TYPE.upper()}Net{MODEL_DEPTH} model with self-distillation.")
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {num_params:,}")


    # Define loss function and optimizer.
    # The optimizer is built from model.parameters() only, so any parameters
    # owned by the criterion are not updated by it.
    criterion = SelfDistillationLoss(alpha=ALPHA, beta=BETA, temperature=TEMPERATURE, asoftmax_margin=ASOFTMAX_MARGIN)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(ADAM_BETA1, ADAM_BETA2),
                          eps=ADAM_EPS, weight_decay=WEIGHT_DECAY)

    # Train and evaluate
    results = train_and_evaluate(
        model, train_loader, dev_loader, test_loader, criterion, optimizer, NUM_EPOCHS, device, eval_asv_scores
    )

    # Unpack results for clarity
    train_losses, dev_eers, dev_tdcfs, test_eer, test_tdcf = results

    print("\n--- Training Complete ---")
    print(f"Final Test EER: {test_eer:.4f}%")
    if eval_asv_scores:
        print(f"Final Test min t-DCF: {test_tdcf:.4f}")
    else:
        print("Final Test min t-DCF: Not calculated (ASV scores unavailable)")

    # Save the final model state (optional, best model already saved)
    # torch.save(model.state_dict(), f'final_model_state_{MODEL_TYPE}{MODEL_DEPTH}.pth')
    print("Best model saved during training in 'checkpoints/' based on validation metric.")

Using device: cuda
Loading datasets...
Loaded 6046 samples from /content/organized_audio_output_nested/protocols/train.txt
Loaded 756 samples from /content/organized_audio_output_nested/protocols/dev.txt
Loaded 754 samples from /content/organized_audio_output_nested/protocols/eval.txt
Creating data loaders...
Data loaders created.
Created ECANet18 model with self-distillation.
Model parameters: 2,791,248
Starting training...


Epoch 1/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=1.05e+3]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.69it/s]


Epoch 1/30, Loss: 1048.7353, Dev EER: 6.3492% (t-DCF skipped)
>>> New best model saved with Dev EER: 6.3492


Epoch 2/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.01it/s, loss=630]
Validation: 100%|██████████| 24/24 [00:08<00:00,  2.88it/s]


Epoch 2/30, Loss: 630.0774, Dev EER: 2.9101% (t-DCF skipped)
>>> New best model saved with Dev EER: 2.9101


Epoch 3/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.06it/s, loss=511]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.64it/s]


Epoch 3/30, Loss: 510.5699, Dev EER: 3.4392% (t-DCF skipped)


Epoch 4/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=434]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.23it/s]


Epoch 4/30, Loss: 434.0442, Dev EER: 2.1164% (t-DCF skipped)
>>> New best model saved with Dev EER: 2.1164


Epoch 5/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.05it/s, loss=382]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.33it/s]


Epoch 5/30, Loss: 382.0875, Dev EER: 3.4392% (t-DCF skipped)


Epoch 6/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=333]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.69it/s]


Epoch 6/30, Loss: 333.4936, Dev EER: 1.8519% (t-DCF skipped)
>>> New best model saved with Dev EER: 1.8519


Epoch 7/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.01it/s, loss=313]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.22it/s]


Epoch 7/30, Loss: 312.5177, Dev EER: 2.1164% (t-DCF skipped)


Epoch 8/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.06it/s, loss=292]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.49it/s]


Epoch 8/30, Loss: 291.9819, Dev EER: 3.4392% (t-DCF skipped)


Epoch 9/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.05it/s, loss=274]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.60it/s]


Epoch 9/30, Loss: 274.3182, Dev EER: 1.8519% (t-DCF skipped)
>>> New best model saved with Dev EER: 1.8519


Epoch 10/30 Training: 100%|██████████| 189/189 [01:03<00:00,  2.99it/s, loss=271]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]


Epoch 10/30, Loss: 271.0495, Dev EER: 3.9683% (t-DCF skipped)


Epoch 11/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.05it/s, loss=284]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.61it/s]


Epoch 11/30, Loss: 284.0741, Dev EER: 30.6878% (t-DCF skipped)


Epoch 12/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.03it/s, loss=263]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.34it/s]


Epoch 12/30, Loss: 262.5540, Dev EER: 1.0582% (t-DCF skipped)
>>> New best model saved with Dev EER: 1.0582


Epoch 13/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.03it/s, loss=251]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.06it/s]


Epoch 13/30, Loss: 250.7003, Dev EER: 1.0582% (t-DCF skipped)
>>> New best model saved with Dev EER: 1.0582


Epoch 14/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.01it/s, loss=243]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.60it/s]


Epoch 14/30, Loss: 242.6692, Dev EER: 3.1746% (t-DCF skipped)


Epoch 15/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.00it/s, loss=235]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.12it/s]


Epoch 15/30, Loss: 235.3472, Dev EER: 1.8519% (t-DCF skipped)


Epoch 16/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.05it/s, loss=233]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.38it/s]


Epoch 16/30, Loss: 232.7623, Dev EER: 3.4392% (t-DCF skipped)


Epoch 17/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.02it/s, loss=228]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.53it/s]


Epoch 17/30, Loss: 228.2952, Dev EER: 3.7037% (t-DCF skipped)


Epoch 18/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.02it/s, loss=232]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.11it/s]


Epoch 18/30, Loss: 231.5201, Dev EER: 1.0582% (t-DCF skipped)


Epoch 19/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.05it/s, loss=227]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.20it/s]


Epoch 19/30, Loss: 226.8983, Dev EER: 0.7937% (t-DCF skipped)
>>> New best model saved with Dev EER: 0.7937


Epoch 20/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=231]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.52it/s]


Epoch 20/30, Loss: 230.9066, Dev EER: 1.0582% (t-DCF skipped)


Epoch 21/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.00it/s, loss=222]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.08it/s]


Epoch 21/30, Loss: 222.4956, Dev EER: 1.8519% (t-DCF skipped)


Epoch 22/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=228]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.60it/s]


Epoch 22/30, Loss: 227.7740, Dev EER: 1.8519% (t-DCF skipped)


Epoch 23/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.05it/s, loss=213]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.29it/s]


Epoch 23/30, Loss: 213.2543, Dev EER: 1.3228% (t-DCF skipped)


Epoch 24/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.05it/s, loss=219]
Validation: 100%|██████████| 24/24 [00:08<00:00,  2.92it/s]


Epoch 24/30, Loss: 219.4794, Dev EER: 2.3810% (t-DCF skipped)


Epoch 25/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=216]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.70it/s]


Epoch 25/30, Loss: 215.8732, Dev EER: 1.0582% (t-DCF skipped)


Epoch 26/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.01it/s, loss=212]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.19it/s]


Epoch 26/30, Loss: 211.5073, Dev EER: 1.3228% (t-DCF skipped)


Epoch 27/30 Training: 100%|██████████| 189/189 [01:01<00:00,  3.05it/s, loss=212]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.46it/s]


Epoch 27/30, Loss: 211.8207, Dev EER: 1.3228% (t-DCF skipped)


Epoch 28/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.04it/s, loss=210]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.64it/s]


Epoch 28/30, Loss: 210.4858, Dev EER: 0.7937% (t-DCF skipped)


Epoch 29/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.01it/s, loss=209]
Validation: 100%|██████████| 24/24 [00:07<00:00,  3.21it/s]


Epoch 29/30, Loss: 208.8959, Dev EER: 1.8519% (t-DCF skipped)


Epoch 30/30 Training: 100%|██████████| 189/189 [01:02<00:00,  3.05it/s, loss=209]
Validation: 100%|██████████| 24/24 [00:06<00:00,  3.54it/s]


Epoch 30/30, Loss: 209.0306, Dev EER: 1.0582% (t-DCF skipped)
Loading best model for final evaluation...


Test Evaluation: 100%|██████████| 24/24 [00:07<00:00,  3.35it/s]

Final Test EER: 0.2653%
Final Test t-DCF: Skipped (ASV scores not provided)

--- Training Complete ---
Final Test EER: 0.2653%
Final Test min t-DCF: Not calculated (ASV scores unavailable)
Best model saved during training in 'checkpoints/' based on validation metric.



