In [None]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from torchmetrics import AveragePrecision
from models.AASIST import Model

# Configurations
SAMPLE_RATE = 16000  # Hz; clips at any other rate are resampled to this
BATCH_SIZE = 16
TRAIN_RATIO = 0.7
DEV_RATIO = 0.15
TEST_RATIO = 0.15
SEED = 42

# The three ratios describe a single partition of the data, so they must add up to 1.
assert abs(TRAIN_RATIO + DEV_RATIO + TEST_RATIO - 1.0) < 1e-9, "split ratios must sum to 1"

# Seed the global NumPy and PyTorch generators so that shuffling and any other
# random draws start from the same state on every Restart & Run All.
# (Bit-exact GPU reproducibility may need further settings; this only fixes the seeds.)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# Load metadata
def split_dataset(metadata_csv, train_ratio=None, dev_ratio=None, seed=42):
    """Shuffle the metadata table and cut it into train / dev / test partitions.

    Args:
        metadata_csv: path of the CSV file to read with ``pd.read_csv``.
        train_ratio: fraction of rows for the training partition. ``None``
            (default) uses the notebook-level ``TRAIN_RATIO``.
        dev_ratio: fraction of rows for the dev partition. ``None`` (default)
            uses the notebook-level ``DEV_RATIO``.
        seed: ``random_state`` used for the shuffle, so the split is repeatable.

    Returns:
        ``(train_df, dev_df, test_df)`` - three DataFrames that together contain
        every row exactly once. The test partition is whatever remains after
        the train and dev rows are taken. Row indices are the shuffled
        originals (not reset).

    Raises:
        ValueError: if a ratio is negative or the two ratios exceed 1.
    """
    if train_ratio is None:
        train_ratio = TRAIN_RATIO
    if dev_ratio is None:
        dev_ratio = DEV_RATIO
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1:
        raise ValueError(
            f"invalid split ratios: train={train_ratio}, dev={dev_ratio} "
            "(each must be >= 0 and their sum must not exceed 1)"
        )

    print(f"[INFO] Loading and splitting metadata from {metadata_csv}")
    df = pd.read_csv(metadata_csv)
    n_total = len(df)
    print(f"[INFO] Total dataset size: {n_total} samples")

    shuffled = df.sample(frac=1, random_state=seed)
    train_end = int(train_ratio * n_total)
    dev_end = int((train_ratio + dev_ratio) * n_total)

    # Plain positional slicing instead of np.split(): np.split on a DataFrame
    # goes through DataFrame.swapaxes, which pandas has deprecated.
    train_df = shuffled.iloc[:train_end]
    dev_df = shuffled.iloc[train_end:dev_end]
    test_df = shuffled.iloc[dev_end:]

    print(f"[INFO] Split sizes - Train: {len(train_df)}, Dev: {len(dev_df)}, Test: {len(test_df)}")
    return train_df, dev_df, test_df

class InTheWildDataset(Dataset):
    """Audio clips described by a metadata DataFrame, served as fixed-length mono waveforms.

    Each item is ``(waveform, label)``: ``waveform`` is a 1-D tensor of
    ``target_length`` samples and ``label`` is a ``torch.long`` scalar that is
    1 when the row's ``label`` column equals ``'spoof'`` and 0 otherwise.
    NOTE(review): confirm this label polarity matches the class order the
    pretrained checkpoint was trained with.

    Args:
        data_df: DataFrame with at least ``file`` and ``label`` columns.
        audio_dir: directory that the ``file`` values are relative to.
        sample_rate: rate every clip is resampled to. ``None`` (default) uses
            the notebook-level ``SAMPLE_RATE``.
        clip_seconds: clip duration; longer audio is truncated, shorter audio
            is zero-padded at the end. Defaults to 4 seconds.
    """

    def __init__(self, data_df, audio_dir, sample_rate=None, clip_seconds=4):
        self.data = data_df.reset_index(drop=True)
        self.audio_dir = audio_dir
        self.sample_rate = SAMPLE_RATE if sample_rate is None else sample_rate
        self.target_length = int(self.sample_rate * clip_seconds)
        # One Resample transform per source rate, built on first use and then
        # reused instead of being reconstructed for every item.
        self._resamplers = {}
        print(f"[INFO] Created dataset with {len(self.data)} samples from {audio_dir}")

    def __len__(self):
        return len(self.data)

    def _resampler_for(self, source_rate):
        """Return the cached transform that converts ``source_rate`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(source_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(source_rate, self.sample_rate)
            self._resamplers[source_rate] = resampler
        return resampler

    def _load_waveform(self, file_path, label, verbose):
        """Load one file and return it as a mono ``[1, target_length]`` tensor."""
        waveform, sr = torchaudio.load(file_path)
        if verbose:
            print(f"[DEBUG] Loaded file {file_path} - Shape: {waveform.shape}, SR: {sr}, Label: {label}")

        if sr != self.sample_rate:
            waveform = self._resampler_for(sr)(waveform)
            if verbose:
                print(f"[DEBUG] Resampled to {self.sample_rate}Hz - New shape: {waveform.shape}")

        # Convert to mono by averaging the channels
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
            if verbose:
                print(f"[DEBUG] Converted to mono - New shape: {waveform.shape}")

        # Force the fixed clip length: truncate long clips, right-pad short ones with zeros
        if waveform.shape[1] > self.target_length:
            waveform = waveform[:, :self.target_length]
            if verbose:
                print(f"[DEBUG] Trimmed to {self.target_length} samples - New shape: {waveform.shape}")
        else:
            pad_size = self.target_length - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_size))
            if verbose:
                print(f"[DEBUG] Padded with {pad_size} zeros - New shape: {waveform.shape}")
        return waveform

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_path = os.path.join(self.audio_dir, row['file'])
        label = 1 if row['label'] == 'spoof' else 0  # 1 for fake, 0 for real
        label_tensor = torch.tensor(label, dtype=torch.long)
        verbose = idx % 100 == 0  # print details only occasionally to avoid console spam

        try:
            waveform = self._load_waveform(file_path, label, verbose)
            # Drop only the channel axis; a bare squeeze() would also remove the
            # time axis if target_length were ever 1.
            final_waveform = waveform.squeeze(0)
            if verbose:
                print(f"[DEBUG] Final waveform shape: {final_waveform.shape}, Label: {label}")
            return final_waveform, label_tensor
        except Exception as e:
            # Best-effort fallback: keep the batch going with a silent clip.
            # The row's real label is still returned, so these items are
            # silence paired with a real label; watch the log for this message.
            print(f"[ERROR] Error loading file {file_path}: {e}")
            return torch.zeros(self.target_length), label_tensor

# Load AASIST model from checkpoint
def load_aasist_model(checkpoint_path):
    """Instantiate ``Model`` with the notebook's fixed configuration and load its weights.

    The state dict is read from ``checkpoint_path`` with ``map_location=device``,
    loaded into the new model, and the model is then moved to ``device``.

    Returns:
        The model with the checkpoint weights applied, on ``device``.

    Raises:
        Whatever construction or loading raises; the error is printed and
        re-raised unchanged.
    """
    print(f"[INFO] Loading model from {checkpoint_path}")

    model_config = dict(
        nb_samp=64000,  # matches the 4 s @ 16 kHz clip length used by the dataset
        first_conv=128,
        filts=[70, [1, 32], [32, 32], [32, 64], [64, 64]],
        gat_dims=[64, 32],
        pool_ratios=[0.5, 0.7, 0.5, 0.5],
        temperatures=[2.0, 2.0, 100.0, 100.0],
    )
    print(f"[DEBUG] Model configuration: {model_config}")

    try:
        model = Model(model_config)
        print("[INFO] Model instance created")

        # Dump the module tree for inspection
        print(f"[DEBUG] Model architecture:\n{model}")

        # Total number of scalar parameters
        total_params = 0
        for param in model.parameters():
            total_params += param.numel()
        print(f"[INFO] Total model parameters: {total_params:,}")

        print("[INFO] Loading weights from checkpoint...")
        # NOTE: torch.load unpickles the file, so only load checkpoints from a
        # trusted source.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        print("[INFO] Weights loaded successfully")

        model.to(device)
        print(f"[INFO] Model moved to device: {device}")
        return model
    except Exception as err:
        print(f"[ERROR] Failed to load model: {err}")
        raise

# Training function
def train(model, train_loader, dev_loader, epochs=10, lr=1e-4):
    """Optimise ``model`` with Adam and cross-entropy, scoring the dev set after every epoch.

    Args:
        model: network whose call returns a pair; the second element is used
            as the class scores (the original comment names the pair
            ``(last_hidden, output)``).
        train_loader: iterable of ``(waveform, label)`` training batches.
        dev_loader: batches passed to ``evaluate`` at the end of each epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    A batch that raises during forward, loss, backward or the optimizer step is
    reported with its traceback and skipped. Skipped batches do not contribute
    to the epoch loss or accuracy. Progress details are printed for every 10th
    batch.
    """
    import traceback

    print(f"[INFO] Starting training for {epochs} epochs with learning rate {lr}")
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_num in range(1, epochs + 1):
        print(f"[INFO] Starting epoch {epoch_num}/{epochs}")
        model.train()
        print("[INFO] Model set to training mode")

        loss_sum = 0
        n_correct = 0
        n_seen = 0
        n_ok_batches = 0

        for batch_idx, (waveform, label) in enumerate(train_loader):
            verbose = batch_idx % 10 == 0  # only every 10th batch is logged in detail

            if verbose:
                print(f"[INFO] Processing batch {batch_idx+1}/{len(train_loader)}")
                print(f"[DEBUG] Batch waveform shape: {waveform.shape}, Label shape: {label.shape}")

            waveform, label = waveform.to(device), label.to(device)
            if verbose:
                print(f"[DEBUG] Data moved to {device}")

            # Clear gradients left over from the previous step
            adam.zero_grad()

            try:
                if verbose:
                    print("[DEBUG] Running forward pass...")

                output = model(waveform)[1]

                if verbose:
                    print(f"[DEBUG] Forward pass complete. Output shape: {output.shape}")
                    print(f"[DEBUG] Sample outputs: {output[0]}")

                loss = loss_fn(output, label)
                if verbose:
                    print(f"[DEBUG] Loss: {loss.item():.4f}")

                loss.backward()
                if verbose:
                    print("[DEBUG] Backward pass complete")

                adam.step()
                if verbose:
                    print("[DEBUG] Weights updated")

                # Running statistics for the epoch summary
                batch_size = label.size(0)
                batch_correct = (output.argmax(dim=1) == label).sum().item()
                loss_sum += loss.item()
                n_correct += batch_correct
                n_seen += batch_size
                n_ok_batches += 1

                if verbose:
                    batch_acc = batch_correct / batch_size
                    print(f"[DEBUG] Batch accuracy: {batch_acc:.4f} ({batch_correct}/{batch_size})")

            except Exception as err:
                # Report and move on to the next batch
                print(f"[ERROR] Exception during training batch {batch_idx}: {err}")
                traceback.print_exc()

        # Epoch summary; fall back to 0 when no batch completed
        epoch_loss = loss_sum / n_ok_batches if n_ok_batches > 0 else 0
        train_acc = n_correct / n_seen if n_seen > 0 else 0

        print("[INFO] Evaluating on dev set...")
        dev_acc = evaluate(model, dev_loader)

        print(f"[RESULT] Epoch {epoch_num}/{epochs}, Loss: {epoch_loss:.4f}, Train Accuracy: {train_acc:.4f}, Dev Accuracy: {dev_acc:.4f}")

def evaluate(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and run under ``torch.no_grad()``.
    Predictions are the argmax of the second element of the model's return
    value. A batch that raises during the forward pass is reported and left
    out of the totals. If no sample was scored the result is 0.
    """
    print(f"[INFO] Starting evaluation on {len(dataloader)} batches")
    model.eval()
    print("[INFO] Model set to evaluation mode")

    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_idx, (waveform, label) in enumerate(dataloader):
            verbose = batch_idx % 10 == 0  # detailed logging for every 10th batch only

            if verbose:
                print(f"[INFO] Evaluating batch {batch_idx+1}/{len(dataloader)}")
                print(f"[DEBUG] Batch waveform shape: {waveform.shape}")

            waveform, label = waveform.to(device), label.to(device)

            try:
                output = model(waveform)[1]

                batch_size = label.size(0)
                batch_correct = (output.argmax(dim=1) == label).sum().item()
                n_correct += batch_correct
                n_seen += batch_size

                if verbose:
                    batch_acc = batch_correct / batch_size
                    print(f"[DEBUG] Batch accuracy: {batch_acc:.4f} ({batch_correct}/{batch_size})")

            except Exception as err:
                # Report the failure and carry on with the next batch
                print(f"[ERROR] Exception during evaluation batch {batch_idx}: {err}")

    final_acc = n_correct / n_seen if n_seen > 0 else 0
    print(f"[RESULT] Evaluation complete. Accuracy: {final_acc:.4f} ({n_correct}/{n_seen})")
    return final_acc

print("[INFO] Script starting...")

# Root folder of the dataset; the metadata CSV lives inside it
audio_dir = "release_in_the_wild"
print(f"[INFO] Processing data from directory: {audio_dir}")

train_df, dev_df, test_df = split_dataset("release_in_the_wild/meta.csv")

# One dataset per partition, built in train / dev / test order
print("[INFO] Creating datasets...")
train_dataset, dev_dataset, test_dataset = (
    InTheWildDataset(partition_df, audio_dir)
    for partition_df in (train_df, dev_df, test_df)
)

# Only the training loader shuffles; dev and test keep their stored order
print("[INFO] Creating data loaders...")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader, test_loader = (
    DataLoader(held_out, batch_size=BATCH_SIZE)
    for held_out in (dev_dataset, test_dataset)
)
print(f"[INFO] Created {len(train_loader)} training batches, {len(dev_loader)} dev batches, {len(test_loader)} test batches")

# Build the model from the checkpoint, then train it
print("[INFO] Loading model...")
model = load_aasist_model("models/AASIST.pth")

print("[INFO] Starting training process...")
train(model, train_loader, dev_loader, epochs=10, lr=1e-4)

# Score the held-out test partition once training has finished
print("[INFO] Final evaluation on test set...")
test_accuracy = evaluate(model, test_loader)
print(f"[RESULT] Final Test Accuracy: {test_accuracy:.4f}")
print("[INFO] Script completed successfully")

[INFO] Using device: cuda
[INFO] Script starting...
[INFO] Processing data from directory: release_in_the_wild
[INFO] Loading and splitting metadata from release_in_the_wild/meta.csv
[INFO] Total dataset size: 31779 samples
[INFO] Split sizes - Train: 22245, Dev: 4767, Test: 4767
[INFO] Creating datasets...
[INFO] Created dataset with 22245 samples from release_in_the_wild
[INFO] Created dataset with 4767 samples from release_in_the_wild
[INFO] Created dataset with 4767 samples from release_in_the_wild
[INFO] Creating data loaders...
[INFO] Created 1391 training batches, 298 dev batches, 298 test batches
[INFO] Loading model...
[INFO] Loading model from models/weights/AASIST.pth
[DEBUG] Model configuration: {'nb_samp': 64000, 'first_conv': 128, 'filts': [70, [1, 32], [32, 32], [32, 64], [64, 64]], 'gat_dims': [64, 32], 'pool_ratios': [0.5, 0.7, 0.5, 0.5], 'temperatures': [2.0, 2.0, 100.0, 100.0]}
[INFO] Model instance created
[DEBUG] Model architecture:
Model(
  (conv_time): CONV()
  (

In [None]:
# Persist the current weights of `model` (state dict only, not the whole module object).
torch.save(obj=model.state_dict(), f="models/aasist_trained.pth")

In [10]:
!pip install sounddevice

Defaulting to user installation because normal site-packages is not writeable
Collecting sounddevice
  Downloading sounddevice-0.4.4-py3-none-any.whl (31 kB)
Installing collected packages: sounddevice
Successfully installed sounddevice-0.4.4


In [11]:
def preprocess_audio(file_path, target_sample_rate=16000, target_length=64000):
    """Load an audio file and normalise it to a mono waveform of fixed length.

    Steps, in order: load with torchaudio, resample to ``target_sample_rate``
    if the file's rate differs, average channels down to one if there are
    several, then truncate or right-pad with zeros to ``target_length`` samples.

    Returns:
        The squeezed waveform tensor, or ``None`` if any step raised (the
        error is printed).
    """
    try:
        waveform, sr = torchaudio.load(file_path)
        print(f"[DEBUG] Loaded file {file_path} - Shape: {waveform.shape}, SR: {sr}")

        # Bring the clip to the requested sampling rate
        if sr != target_sample_rate:
            to_target_rate = torchaudio.transforms.Resample(sr, target_sample_rate)
            waveform = to_target_rate(waveform)
            print(f"[DEBUG] Resampled to {target_sample_rate}Hz")

        # Collapse multi-channel audio into a single channel
        n_channels = waveform.shape[0]
        if n_channels > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
            print(f"[DEBUG] Converted to mono - New shape: {waveform.shape}")

        # Fix the length: cut the tail off long clips, zero-pad short ones
        n_samples = waveform.shape[1]
        if n_samples > target_length:
            waveform = waveform[:, :target_length]
        else:
            missing = target_length - n_samples
            waveform = torch.nn.functional.pad(waveform, (0, missing))

        print(f"[DEBUG] Final waveform shape: {waveform.shape}")
        return waveform.squeeze()

    except Exception as err:
        print(f"[ERROR] Failed to process audio {file_path}: {err}")
        return None


In [None]:
def infer(file_path):
    """Classify a single audio file with the notebook-level ``model``.

    Args:
        file_path: path of the audio file, passed to ``preprocess_audio``.

    Returns:
        ``"spoof"`` when the predicted class index is 1, ``"bona-fide"``
        otherwise, or the string ``"Error in processing audio"`` when
        preprocessing returned ``None``.
    """
    print(f"[INFO] Running inference on {file_path}")
    waveform = preprocess_audio(file_path)

    if waveform is None:
        return "Error in processing audio"

    waveform = waveform.to(device).unsqueeze(0)  # Add batch dimension

    # Make sure the forward pass runs in evaluation mode regardless of which
    # cell last touched the model (mode-dependent layers such as dropout or
    # batch norm, if present, would otherwise make the result unreliable).
    # The previous mode is restored afterwards so callers see no change.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, output = model(waveform)  # Forward pass
            prediction = torch.argmax(output, dim=1).item()
    finally:
        model.train(was_training)

    label = "spoof" if prediction == 1 else "bona-fide"
    print(f"[RESULT] Predicted label: {label}")
    return label


In [None]:
# code for real-time testing