# Deepfake Audio Detection - Colab Test

**IMPORTANT**: Run Cell 1, then **RESTART RUNTIME** before continuing.

In [None]:
# CELL 1: Install dependencies (RUN ONCE, THEN RESTART RUNTIME)
!pip uninstall datasets -y -q 2>/dev/null
!pip install datasets==2.14.7 soundfile librosa -q
print("\n" + "="*60)
print("âœ… INSTALLATION COMPLETE")
print("ðŸ‘‰ NOW GO TO: Runtime â†’ Restart runtime")
print("ðŸ‘‰ THEN skip this cell and run Cell 2")
print("="*60)

In [None]:
# CELL 2: Upload model file
from google.colab import files

print("📤 Upload your model file (prosody_encoder_best_yet.pth):")
uploaded = files.upload()

# A cancelled dialog leaves `uploaded` empty; fail with a clear message instead
# of an IndexError from indexing an empty list.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded. Re-run this cell and select the model checkpoint."
    )

# If several files were selected, the first uploaded name is used.
MODEL_PATH = next(iter(uploaded))
print(f"✅ Uploaded: {MODEL_PATH}")

In [None]:
# CELL 3: All imports and config
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as T
import numpy as np
from sklearn.metrics import roc_curve
from datasets import load_dataset
from tqdm.auto import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Audio front-end settings.
SAMPLE_RATE = 16000  # samples per second
N_MELS = 80  # number of mel bands
HOP_LENGTH = 160  # samples between consecutive spectrogram frames
MAX_AUDIO_LEN_SECONDS = 4
MAX_LEN_SAMPLES = SAMPLE_RATE * MAX_AUDIO_LEN_SECONDS
# One frame per hop plus one; integer arithmetic avoids float truncation.
MAX_MEL_FRAMES = MAX_LEN_SAMPLES // HOP_LENGTH + 1

print(f"Device: {DEVICE}")
print(f"Max audio samples: {MAX_LEN_SAMPLES}")
print(f"Max mel frames: {MAX_MEL_FRAMES}")

In [None]:
# CELL 4: Model definition (EXACT copy from test.py)
class ResBlock(nn.Module):
    """Residual block: two conv + batch-norm stages plus a skip connection.

    The first convolution applies ``stride``; the second always uses stride 1.
    When the stride is not 1 or the channel count changes, the skip path is a
    1x1 convolution followed by batch norm so it matches the main path's
    output; otherwise the skip path is an empty ``nn.Sequential`` (identity).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        kernel_size: Kernel size of both main-path convolutions.
        stride: Stride of the first convolution and of the skip projection.
        padding: Padding of both main-path convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()

        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: only project when the output shape differs from the input.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            skip_layers = [
                nn.Conv2d(in_channels, out_channels, 1, stride),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            skip_layers = []
        self.shortcut = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        return F.relu(hidden + self.shortcut(x))

class ProsodyEncoder(nn.Module):
    """Convolutional encoder with a two-way classification head.

    A stem convolution is followed by four ``ResBlock`` stages (64, 128, 256,
    512 channels), global average pooling, a feature projection (``fc1``) and
    a 2-logit output layer (``fc2``).

    Args:
        n_mels: Accepted for interface compatibility; not referenced when
            building the layers.
        num_features: Size of the feature vector produced by ``fc1``.
    """

    def __init__(self, n_mels=N_MELS, num_features=256):
        super().__init__()

        # Stem: single-channel input -> 64 feature maps.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages, registered as layer1..layer4 so that state_dict
        # keys keep their names.
        stage_specs = (
            (64, 64, (1, 2)),
            (64, 128, (2, 2)),
            (128, 256, (2, 2)),
            (256, 512, (2, 2)),
        )
        for number, (c_in, c_out, stride) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{number}", ResBlock(c_in, c_out, stride=stride))

        # Head: collapse the spatial dimensions, then project and classify.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, num_features)
        self.fc2 = nn.Linear(num_features, 2)

    def forward(self, x):
        """Encode a batch and return ``(features, out_spoof)``.

        Args:
            x: 4-D tensor with one channel, i.e. ``(batch, 1, H, W)``.

        Returns:
            features: ``(batch, num_features)`` activations after ``fc1`` + ReLU.
            out_spoof: ``(batch, 2)`` raw logits from ``fc2``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = self.pool(x)
        flat = pooled.view(pooled.size(0), -1)
        features = F.relu(self.fc1(flat))
        return features, self.fc2(features)

# Confirm that the model classes above were defined without errors.
print("✅ Model class defined")

In [None]:
# CELL 5: Load model weights
# Build the network on the target device, then load the uploaded checkpoint.
model = ProsodyEncoder().to(DEVICE)

# weights_only=True asks torch.load to use its restricted unpickler rather than
# executing arbitrary pickled code; map_location places tensors on DEVICE.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

# Inference mode for the batch-norm layers (use running statistics).
model.eval()
print("✅ Model weights loaded")

In [None]:
# CELL 6: Load HuggingFace dataset
print("📥 Loading ASVspoof 2019 LA eval dataset...")
# streaming=False so the split supports len() and integer indexing below.
hf_dataset = load_dataset("Bisher/ASVspoof_2019_LA", split="eval", streaming=False)
print(f"✅ Loaded {len(hf_dataset)} samples")

# Check first sample to understand structure
sample = hf_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Label key name: 'key', value: {sample['key']}")

In [None]:
# CELL 7: Create PyTorch Dataset wrapper
class HFASVspoofDataset(Dataset):
    """Map-style dataset turning HuggingFace audio rows into mel spectrograms.

    Each item is ``(melspec, label)``: ``melspec`` is a float tensor of shape
    ``(1, N_MELS, MAX_MEL_FRAMES)`` and ``label`` is ``int(row['key'])``.

    Args:
        hf_ds: Indexable dataset whose rows provide ``row['audio']['array']``,
            ``row['audio']['sampling_rate']`` and ``row['key']``.
    """

    def __init__(self, hf_ds):
        self.hf_ds = hf_ds
        self.mel_transform = T.MelSpectrogram(
            sample_rate=SAMPLE_RATE, n_fft=400, win_length=400, hop_length=160, n_mels=N_MELS
        )
        # Resamplers keyed by source rate, created on first use so the
        # resampling kernel is not rebuilt for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.hf_ds)

    def _resampler_for(self, source_rate):
        """Return a cached ``T.Resample`` from ``source_rate`` to ``SAMPLE_RATE``."""
        resampler = self._resamplers.get(source_rate)
        if resampler is None:
            resampler = T.Resample(source_rate, SAMPLE_RATE)
            self._resamplers[source_rate] = resampler
        return resampler

    @staticmethod
    def _fit_length(tensor, target_len, dim):
        """Truncate or zero-pad ``tensor`` at the end of ``dim`` to ``target_len``."""
        current_len = tensor.shape[dim]
        if current_len > target_len:
            return tensor.narrow(dim, 0, target_len)
        if current_len == target_len:
            return tensor
        pad_shape = list(tensor.shape)
        pad_shape[dim] = target_len - current_len
        return torch.cat((tensor, tensor.new_zeros(pad_shape)), dim=dim)

    def __getitem__(self, idx):
        """Return ``(melspec, label)`` for row ``idx``."""
        item = self.hf_ds[idx]
        audio = item['audio']
        sr = audio['sampling_rate']

        # Always work in float32, whatever dtype the decoded array has.
        # NOTE(review): assumes a 1-D (mono) waveform; confirm for other datasets.
        waveform = torch.tensor(audio['array'], dtype=torch.float32)

        # Resample if needed
        if sr != SAMPLE_RATE:
            waveform = self._resampler_for(sr)(waveform)

        # Fixed-length waveform: cut long clips, zero-pad short ones.
        waveform = self._fit_length(waveform, MAX_LEN_SAMPLES, dim=0)

        # Mel spectrogram with a leading channel dim: (1, N_MELS, TIME)
        melspec = self.mel_transform(waveform).unsqueeze(0)

        # Guard the time axis as well so every item has identical shape.
        melspec = self._fit_length(melspec, MAX_MEL_FRAMES, dim=2)

        # NOTE(review): the original author states 'key' is 0 = bonafide,
        # 1 = spoof for this dataset; verify against the dataset's features.
        label = int(item['key'])

        return melspec, label

# Test the dataset: fetch one item and verify its shape before reporting success.
test_ds = HFASVspoofDataset(hf_dataset)
test_mel, test_label = test_ds[0]

expected_mel_shape = (1, N_MELS, MAX_MEL_FRAMES)
if tuple(test_mel.shape) != expected_mel_shape:
    raise RuntimeError(
        f"Unexpected mel shape {tuple(test_mel.shape)}; expected {expected_mel_shape}"
    )

print("✅ Dataset wrapper works!")
print(f"   Mel shape: {test_mel.shape}")
print(f"   Label: {test_label}")

In [None]:
# CELL 8: Run evaluation (low memory version)
def calculate_eer(labels, scores):
    """Estimate the equal error rate (EER) of a binary detector.

    The ROC curve is computed with label ``1`` as the positive class. The EER
    is approximated at the ROC point where the false-positive rate and the
    false-negative rate are closest, and returned as the mean of the two.

    Args:
        labels: Sequence of binary ground-truth labels (``1`` = positive).
        scores: Sequence of scores, one per label; a higher score means
            "more likely positive".

    Returns:
        The EER as a fraction (multiply by 100 for a percentage).

    Raises:
        ValueError: If ``labels`` and ``scores`` differ in length, are empty,
            or ``labels`` does not contain both classes (the ROC curve, and
            hence the EER, is undefined in that case).
    """
    if len(labels) != len(scores):
        raise ValueError(
            f"labels and scores must have the same length "
            f"(got {len(labels)} and {len(scores)})"
        )

    is_positive = np.asarray(labels) == 1
    if is_positive.size == 0:
        raise ValueError("Cannot compute EER on empty input")
    if is_positive.all() or not is_positive.any():
        raise ValueError(
            "EER is undefined unless labels contain both positive (1) and "
            "negative samples"
        )

    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    # Pick the threshold where the two error rates cross (or come closest).
    eer_index = np.nanargmin(np.abs(fpr - fnr))
    return (fpr[eer_index] + fnr[eer_index]) / 2

print(f"\n🚀 Starting evaluation on {len(test_ds)} samples...")

all_scores = []
all_labels = []
BATCH_SIZE = 32
# Run the garbage collector / CUDA cache cleanup roughly every 1000 samples.
# Counting batches avoids relying on start_idx being a multiple of 1000, which
# with a step of 32 only happens every 4000 samples.
CLEANUP_EVERY_N_BATCHES = 32

# Process in batches manually to save memory
with torch.no_grad():
    batch_starts = range(0, len(test_ds), BATCH_SIZE)
    for batch_number, start_idx in enumerate(tqdm(batch_starts)):
        end_idx = min(start_idx + BATCH_SIZE, len(test_ds))

        # Materialise only the current batch of (mel, label) pairs.
        samples = [test_ds[idx] for idx in range(start_idx, end_idx)]
        batch_mels = [mel for mel, _ in samples]
        batch_labels = [label for _, label in samples]

        # Stack into batch tensor: (batch, 1, N_MELS, frames)
        batch_tensor = torch.stack(batch_mels).to(DEVICE)

        # Forward pass; the feature output is not needed for scoring.
        _, out_spoof = model(batch_tensor)

        # Probability of class 1 (spoof) is used as the detection score.
        probs = F.softmax(out_spoof, dim=1)[:, 1]

        all_scores.extend(probs.cpu().numpy())
        all_labels.extend(batch_labels)

        # Free memory
        del batch_tensor, batch_mels, samples
        if batch_number % CLEANUP_EVERY_N_BATCHES == 0:
            gc.collect()
            torch.cuda.empty_cache()

print(f"\n✅ Processed {len(all_scores)} samples")
print(f"   Bonafide (0): {all_labels.count(0)}")
print(f"   Spoof (1): {all_labels.count(1)}")

In [None]:
# CELL 9: Calculate and display EER
# Labels and scores were accumulated by the evaluation loop; the score is the
# model's probability for class 1.
eer = calculate_eer(all_labels, all_scores)

rule = "=" * 50
print("\n" + rule)
print(f"🎯 ASVspoof 2019 LA Eval EER: {eer * 100:.2f}%")
print(rule)