In [None]:
# Kaggle Migration: Mounted drive logic removed.
# If you are running on Kaggle, your data will be in /kaggle/input/
# and your outputs will be in /kaggle/working/
import os
print("Running on Kaggle Environment")

Mounted at /content/drive


In [None]:
!pip -q install transformers torchaudio librosa soundfile tqdm scikit-learn webrtcvad

In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
from collections import OrderedDict

import torch
import torch.nn as nn
import torchaudio
import librosa
import numpy as np

from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from transformers import WavLMModel
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve,
)

In [None]:
# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Data paths (Kaggle standard paths) - DIRECT PATH ENTRY
# Paste your folder paths directly into these lists. 
# The code will handle flat folders, nested languages, and mixed sources automatically.
BASE_DATA = {
    "Human": [
        "/kaggle/input/ai-4-bharat/AI_4_bharat/Human/English",
        "/kaggle/input/ai-4-bharat/AI_4_bharat/Human/Hindi",
        "/kaggle/input/ai-4-bharat/AI_4_bharat/Human/Tamil",
        "/kaggle/input/ai-4-bharat/AI_4_bharat/Human/Telugu",
        "/kaggle/input/ai-4-bharat/AI_4_bharat/Human/Malayalam"
    ],
    "AI": [
        "/kaggle/input/datasets/harshshah9104/ai-summit/AI/English-20260213T140035Z-3-001/English",
        "/kaggle/input/datasets/harshshah9104/ai-summit/AI/Hindi-20260213T140140Z-3-001/Hindi",
        "/kaggle/input/datasets/harshshah9104/ai-summit/AI/Tamil-20260213T135756Z-3-001/Tamil",
        "/kaggle/input/datasets/harshshah9104/ai-summit/AI/Telugu-20260213T135744Z-3-001/Telugu",
        "/kaggle/input/datasets/harshshah9104/ai-summit/AI/Malyalam-20260213T135808Z-3-001/Malyalam"
    ]
}
SAVE_DIR = "/kaggle/working/wavlm_ensemble_checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

# Audio processing - FIXED LENGTH
SAMPLE_RATE = 16000
TARGET_DURATION = 5.0  # All audio will be exactly 5 seconds
# This means: 5.0 * 16000 = 80,000 samples per audio

# Allowed audio formats in dataset folders (Lowercased for comparison)
AUDIO_EXTS = (".wav", ".mp3", ".flac")

# Normalization settings
NORM_TYPE = "peak"           # "peak" or "rms"
RMS_TARGET = 0.1             # Target RMS level for RMS normalization
SILENCE_THRESHOLD = 1e-4     # Threshold to detect silence

# Preprocessing (denoise / filtering)
USE_DENOISE = False          # Spectral-gate denoise (OFF per ASVspoof)
DENOISE_N_FFT = 1024
DENOISE_HOP_LENGTH = 256
DENOISE_NOISE_PERCENTILE = 10
DENOISE_THRESHOLD_MULT = 1.5
DENOISE_ATTENUATION = 0.2

USE_BANDPASS = True
HIGHPASS_CUTOFF_HZ = 80.0
LOWPASS_CUTOFF_HZ = 7800.0      # keep below Nyquist (8000 for 16kHz)

# Inference-time speech selection (recommended for call recordings)
USE_VAD_INFERENCE = True
VAD_AGGRESSIVENESS = 2       # 0..3 (higher = more strict)
MIN_VOICED_SECONDS = 1.0     # require at least this much speech after VAD
MAX_INFER_WINDOWS = 6        # cap compute on long calls

# Regularization / anti-overfitting
DROPOUT_P = 0.3
WEIGHT_DECAY = 1e-2
GRAD_CLIP_NORM = 1.0
EARLY_STOPPING_PATIENCE = 3
MIN_DELTA_AUC = 1e-4
LABEL_SMOOTHING = 0.05           
SPEC_AUG_FREQ_MASKS = 2          
SPEC_AUG_FREQ_WIDTH = 30         
SPEC_AUG_TIME_MASKS = 2          # New: Temporal robustness
SPEC_AUG_TIME_WIDTH = 40         

# Forensic Augmentation Settings
USE_AUGMENT = True
SPEED_PERTURB_RANGE = (0.9, 1.1)  
GAIN_RANGE = (0.5, 1.2)           
NOISE_STD = 0.002                 
USE_CODEC_AUG = True              
USE_RANDOM_EQ = True             # New: Mic/Hardware simulation
CLIPPING_PROB = 0.2              # New: Digital artifact simulation

# Model Architecture strategy
UNFREEZE_TOP_LAYERS = 2           # New: Fine-tune last 2 layers of WavLM

# Forensic Data Augmentation paths (Kaggle)
ESC50_PATH = "/kaggle/input/esc50/ESC-50-master/audio" # Adjust to your ESC-50 Kaggle path
HUMAN_NOISE_COUNT = 50            
MP3_BITRATE = "32k"               

# Training hyperparameters
BATCH_SIZE = 32
EPOCHS = 10  # Kaggle GPUs are faster, can do more epochs
LEARNING_RATE = 2e-4   
VAL_SPLIT = 0.2

# Ensemble weights
AASIST_WEIGHT = 0.6
OCSOFT_WEIGHT = 0.4

# Cache settings (only for validation)
MAX_CACHE_SIZE = 1000

# Class-imbalance handling (instead of oversampling): BCE pos_weight for AI=1

OVERSAMPLE_AI = True  # oversample minority class via WeightedRandomSampler
POS_WEIGHT_AI = None  # set automatically after train/val split
# Class-imbalance handling (instead of oversampling): BCE pos_weight for AI=1

Using device: cpu


In [None]:
# Quick sanity check: do we actually have WAVs where we think we do?
import os

print("BASE_DATA Configuration:", BASE_DATA)
# Fixed: handled dict type for the existence check
if isinstance(BASE_DATA, str):
    print("Base Path Exists:", os.path.isdir(BASE_DATA))
else:
    print("Mode: Multiple Dataset Paths")
    
print("AUDIO_EXTS:", AUDIO_EXTS)

def _count_wavs(folder: str):
    try:
        files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
    except Exception:
        return 0, []
    # Case-insensitive extension check
    wavs = [f for f in files if f.lower().endswith(AUDIO_EXTS)]
    return len(wavs), wavs[:5]

total_wavs = 0
for cls in ["Human", "AI"]:
    if isinstance(BASE_DATA, dict):
        cls_paths = BASE_DATA.get(cls, [])
        if isinstance(cls_paths, str): cls_paths = [cls_paths]
    else:
        cls_paths = [os.path.join(BASE_DATA, cls)]
        
    for cls_path in cls_paths:
        if not os.path.isdir(cls_path):
            print(f"⚠️ Missing folder: {cls_path}")
            continue
            
        # Detect if it's flat or nested with languages
        items = sorted(os.listdir(cls_path))
        langs = [d for d in items if os.path.isdir(os.path.join(cls_path, d))]
        
        if langs:
            print(f"\n{cls} (dataset: {os.path.basename(cls_path)}): {len(langs)} language folders")
            for lang in langs:
                lang_path = os.path.join(cls_path, lang)
                n, sample = _count_wavs(lang_path)
                total_wavs += n
                print(f"  {lang}: {n} wav files (sample: {sample})")
        else:
            n, sample = _count_wavs(cls_path)
            total_wavs += n
            print(f"\n{cls} (Flat dataset: {os.path.basename(cls_path)}): {n} wav files")

print(f"\nTotal WAV files seen across dataset: {total_wavs}")
if total_wavs == 0:
    print("\nIf this is 0, then either:")
    print("- your Drive path is different than BASE_DATA, or")
    print("- your files aren’t actually .wav in those folders, or")
    print("- they’re nested deeper (this loader only scans one level), or")
    print("- the conversion cell didn’t run / wrote WAVs elsewhere.")

BASE_DATA: /content/drive/MyDrive/AI_4_bharat
Exists: True
AUDIO_EXTS: ('.wav',)

Human: 5 language folders
  English: 2500 wav files (sample: ['real_1064.wav', 'real_1041.wav', 'real_1031.wav', 'real_1054.wav', 'real_1053.wav'])
  Hindi: 2660 wav files (sample: ['Hindi_002229.wav', 'Hindi_002230.wav', 'Hindi_002231.wav', 'Hindi_002232.wav', 'Hindi_002233.wav'])
  Malayalam: 2710 wav files (sample: ['Malayalam_001002.wav', 'Malayalam_001003.wav', 'Malayalam_001004.wav', 'Malayalam_001005.wav', 'Malayalam_001001.wav'])
  Tamil: 2870 wav files (sample: ['Tamil_001287.wav', 'Tamil_001289.wav', 'Tamil_001290.wav', 'Tamil_001291.wav', 'Tamil_002583.wav'])
  Telugu: 2560 wav files (sample: ['Telugu_002512.wav', 'Telugu_002513.wav', 'Telugu_001049.wav', 'Telugu_002514.wav', 'Telugu_001050.wav'])

AI: 5 language folders
  English: 2500 wav files (sample: ['synthetic_1500.wav', 'synthetic_1501.wav', 'synthetic_1502.wav', 'synthetic_1503.wav', 'synthetic_1504.wav'])
  Hindi: 1749 wav files (samp

In [None]:
def _apply_bandpass_torch(wav_t: torch.Tensor, sr: int) -> torch.Tensor:
    """Bandpass filter to focus on speech band and reduce rumble/hiss."""
    if not USE_BANDPASS:
        return wav_t
    wav_t = torchaudio.functional.highpass_biquad(wav_t, sr, cutoff_freq=HIGHPASS_CUTOFF_HZ)
    wav_t = torchaudio.functional.lowpass_biquad(wav_t, sr, cutoff_freq=LOWPASS_CUTOFF_HZ)
    return wav_t


def _denoise_spectral_gate_np(wav_np: np.ndarray, _sr: int) -> np.ndarray:
    """Mild spectral gating denoise (keeps speech; reduces steady background noise)."""
    if not USE_DENOISE:
        return wav_np
    if wav_np.size == 0:
        return wav_np
    if not np.isfinite(wav_np).all():
        return wav_np

    stft = librosa.stft(wav_np, n_fft=DENOISE_N_FFT, hop_length=DENOISE_HOP_LENGTH)
    mag = np.abs(stft)
    phase = np.exp(1j * np.angle(stft))

    noise_floor = np.percentile(mag, DENOISE_NOISE_PERCENTILE, axis=1, keepdims=True)
    thresh = noise_floor * float(DENOISE_THRESHOLD_MULT)

    mask = (mag >= thresh).astype(np.float32)
    mag_d = mag * mask + mag * (1.0 - mask) * float(DENOISE_ATTENUATION)

    stft_d = mag_d * phase
    wav_out = librosa.istft(stft_d, hop_length=DENOISE_HOP_LENGTH, length=len(wav_np))
    return wav_out.astype(np.float32)


class AudioDataset(Dataset):
    """Dataset with forensic augmentation: Human-noise injection + ESC-50 + MP3."""

    def __init__(self, base_dir, target_duration=5.0, mode='train', max_cache_size=1000):
        self.samples = []
        self.cache = OrderedDict()
        self.max_cache_size = max_cache_size
        self.target_duration = target_duration
        self.mode = mode
        self.failed_files = []
        
        # Forensic Noise Banks
        self.human_noises = []
        self.esc50_bank = []

        # 1. Collect all audio files (Human/AI)
        human_files = []
        for label, cls in [(0, "Human"), (1, "AI")]:
            # Use specific path list from dict if available
            if isinstance(base_dir, dict):
                cls_paths = base_dir.get(cls, [])
                if isinstance(cls_paths, str): cls_paths = [cls_paths]
            else:
                cls_paths = [os.path.join(base_dir, cls)]
                
            for cls_path in cls_paths:
                if not os.path.isdir(cls_path): 
                    print(f"⚠️ Warning: Folder not found: {cls_path}")
                    continue
                
                # Check if this specific folder has subfolders (nested language structure)
                # or contains files directly (direct language structure)
                items = sorted(os.listdir(cls_path))
                subdirs = [d for d in items if os.path.isdir(os.path.join(cls_path, d))]
                
                if subdirs:
                    # Recursive collection for nested folders
                    for lang in subdirs:
                        lang_path = os.path.join(cls_path, lang)
                        found_in_subdir = 0
                        for f in os.listdir(lang_path):
                            if f.lower().endswith(AUDIO_EXTS):
                                path = os.path.join(lang_path, f)
                                self.samples.append((path, label, lang))
                                if label == 0: human_files.append(path)
                                found_in_subdir += 1
                        if found_in_subdir > 0:
                            print(f"  - Loaded {found_in_subdir} files from {lang}")
                else:
                    # Direct collection from the provided folder
                    folder_name = os.path.basename(cls_path)
                    found_direct = 0
                    for f in items:
                        if f.lower().endswith(AUDIO_EXTS):
                            path = os.path.join(cls_path, f)
                            self.samples.append((path, label, folder_name))
                            if label == 0: human_files.append(path)
                            found_direct += 1
                    if found_direct > 0:
                        print(f"  - Loaded {found_direct} files from {folder_name}")
                    else:
                        print(f"  ⚠️ No valid audio files found in: {cls_path}")

        # 2. Extract background noise from Human dataset (only for training)
        if mode == 'train':
            print(f"Extracting background noise from first {HUMAN_NOISE_COUNT} human files...")
            for p in human_files[:HUMAN_NOISE_COUNT]:
                noise_segment = self._extract_background_noise(p)
                if noise_segment is not None:
                    self.human_noises.append(noise_segment)
            
            # 3. Pre-load a small bank of ESC-50 noises (Avoids disk I/O in loop)
            if os.path.isdir(ESC50_PATH):
                all_esc = []
                for root, _, files in os.walk(ESC50_PATH):
                    for f in files:
                        if f.lower().endswith(".wav"):
                            all_esc.append(os.path.join(root, f))
                
                print(f"Pre-loading bank of 100 ESC-50 noises...")
                for p in all_esc[:100]:
                    try:
                        n, _ = librosa.load(p, sr=SAMPLE_RATE, duration=5.0, mono=True)
                        self.esc50_bank.append(torch.tensor(n).float())
                    except: continue
                print(f"Bank ready with {len(self.esc50_bank)} noise profiles.")

    def _extract_background_noise(self, path):
        """Finds 'silent' or low-energy regions in a human file to use as noise."""
        try:
            # Load short snippet to save memory
            wav, _ = librosa.load(path, sr=SAMPLE_RATE, duration=10, mono=True)
            if len(wav) < SAMPLE_RATE: return None
            
            # Find low energy segments
            intervals = librosa.effects.split(wav, top_db=30) # db threshold for "silence"
            # Invert intervals to find the "silence" parts
            noise_segs = []
            last_end = 0
            for start, end in intervals:
                if start > last_end + int(0.5 * SAMPLE_RATE):
                    noise_segs.append(wav[last_end:start])
                last_end = end
            
            if not noise_segs:
                # If no clear silence, just use the whole file at low amplitude
                return torch.tensor(wav).float()
            
            # Pick longest silence
            noise_segs.sort(key=len, reverse=True)
            return torch.tensor(noise_segs[0]).float()
        except:
            return None

    def __len__(self):
        return len(self.samples)

    def _normalize_audio(self, wav):
        """Robust audio normalization."""
        if np.abs(wav).max() < SILENCE_THRESHOLD:
            return wav

        if NORM_TYPE == "peak":
            wav = wav / max(np.abs(wav).max(), 1e-6)
        elif NORM_TYPE == "rms":
            rms = np.sqrt(np.mean(wav**2))
            if rms > 1e-6:
                wav = wav * (RMS_TARGET / rms)
                wav = np.clip(wav, -1.0, 1.0)

        return wav

    def _load_full_audio(self, path):
        """Load, denoise/filter, and normalize FULL audio (cache this)."""
        try:
            wav, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)

            if len(wav) == 0:
                raise ValueError("Empty audio file")
            if not np.isfinite(wav).all():
                raise ValueError("Audio contains NaN or Inf values")

            # Simple bandpass
            wav_t = torch.tensor(wav).float()
            wav_t = _apply_bandpass_torch(wav_t, SAMPLE_RATE)
            wav = wav_t.cpu().numpy()

            # Normalize full audio
            wav = self._normalize_audio(wav)

            return torch.tensor(wav).float()

        except Exception as e:
            self.failed_files.append((path, str(e)))
            return None

    def _crop_to_fixed_duration(self, wav):
        """Crop cached audio to fixed duration."""
        target_length = int(self.target_duration * SAMPLE_RATE)
        current_length = len(wav)

        if current_length == 0:
            return torch.zeros(target_length).float()

        if current_length < target_length:
            # Zero-pad short clips
            pad_length = target_length - current_length
            wav = torch.cat([wav, torch.zeros(pad_length)])
            return wav

        if current_length > target_length:
            if self.mode == 'train':
                max_start = current_length - target_length
                start = np.random.randint(0, max_start + 1)
            else:
                start = (current_length - target_length) // 2
            wav = wav[start:start + target_length]

        return wav

    def _apply_codec_sim(self, wav: torch.Tensor) -> torch.Tensor:
        """Simulate MP3 compression and telephony codecs."""
        if not USE_CODEC_AUG: return wav
        
        # Randomly choose between real MP3 simulation or simple resampling
        if torch.rand(1).item() < 0.5:
            # Fallback: Resampling + Bit-depth reduction (avoids ffmpeg/sox errors)
            low_sr = torch.randint(8000, 12000, (1,)).item()
            wav = torchaudio.functional.resample(wav.unsqueeze(0), SAMPLE_RATE, low_sr)
            wav = torchaudio.functional.resample(wav, low_sr, SAMPLE_RATE).squeeze(0)
            bits = torch.randint(4, 9, (1,)).item()
            wav = torch.round(wav * (2**(bits-1))) / (2**(bits-1))
        return wav

    def _apply_random_eq(self, wav: torch.Tensor) -> torch.Tensor:
        """Simulate different microphone quality by random EQ boosts/cuts."""
        if not USE_RANDOM_EQ: return wav
        try:
            # Randomly boost/cut lows or highs
            f0 = torch.randint(100, 3000, (1,)).item()
            gain = torch.empty(1).uniform_(-6, 6).item() # +/- 6dB
            wav = torchaudio.functional.lowpass_biquad(wav.unsqueeze(0), SAMPLE_RATE, f0).squeeze(0) if gain < 0 else wav
            
            # Random peak filter
            f_center = torch.randint(500, 4000, (1,)).item()
            gain_p = torch.empty(1).uniform_(-10, 5).item()
            wav = torchaudio.functional.equalizer_biquad(wav.unsqueeze(0), SAMPLE_RATE, f_center, gain_p, Q=0.707).squeeze(0)
        except: pass
        return wav

    def _generate_pink_noise(self, length: int) -> torch.Tensor:
        """Generates pink noise by filtering white noise (approx -3dB/octave)."""
        white = torch.randn(length)
        b = [0.049922035, -0.095993537, 0.050223151, -0.004947755, 0.000000000, 0.000000000, 0.000000000]
        fft = torch.fft.rfft(white)
        f = torch.linspace(1, len(fft), len(fft))
        fft = fft / torch.sqrt(f) 
        return torch.fft.irfft(fft, n=length)

    def _augment(self, wav: torch.Tensor, label: int) -> torch.Tensor:
        """Enhanced Forensic suite with Hierarchical Noise Probabilities."""
        if not USE_AUGMENT:
            return wav

        target_length = int(self.target_duration * SAMPLE_RATE)

        # 1. Hardware & Gain (Independent)
        wav = wav * float(torch.empty(1).uniform_(*GAIN_RANGE))
        wav = self._apply_random_eq(wav)

        # 2. Hierarchical Noise Selection (User Specified)
        noise_roll = torch.rand(1).item()
        
        # Mode 1: Untouched (10%) - No noise stage
        if noise_roll < 0.10:
            pass 
            
        # Mode 2: Pink/White Noise (20%)
        elif noise_roll < 0.30:
            snr = torch.randint(20, 35, (1,)).item()
            alpha = 10**(-snr/20)
            if torch.rand(1).item() < 0.7:
                noise = self._generate_pink_noise(len(wav))
            else:
                noise = torch.randn_like(wav)
            wav = torch.clamp(wav + alpha * noise, -1.0, 1.0)

        # Mode 3: Human Room-Hum Injection (35%)
        elif noise_roll < 0.65:
            # Primarily for AI. If human, this effectively stays clean/untouched room hum.
            if label == 1 and self.human_noises:
                 noise = self.human_noises[np.random.randint(len(self.human_noises))]
                 snr = torch.randint(15, 30, (1,)).item()
                 alpha = 10**(-snr/20)
                 if len(noise) < len(wav):
                     noise = noise.repeat(int(np.ceil(len(wav)/len(noise))))[:len(wav)]
                 else:
                     noise = noise[:len(wav)]
                 wav = torch.clamp(wav + alpha * noise, -1.0, 1.0)

        # Mode 4: ESC-50 Environmental (35%)
        else:
            if self.esc50_bank:
                noise = self.esc50_bank[np.random.randint(len(self.esc50_bank))]
                alpha = torch.empty(1).uniform_(0.01, 0.04).item()
                if len(noise) < len(wav):
                    noise = noise.repeat(int(np.ceil(len(wav)/len(noise))))[:len(wav)]
                else:
                    noise = noise[:len(wav)]
                wav = torch.clamp(wav + alpha * noise, -1.0, 1.0)

        # 3. Codec & Digital Artifacts (Independent)
        wav = self._apply_codec_sim(wav)
        
        if torch.rand(1).item() < CLIPPING_PROB:
            limit = torch.empty(1).uniform_(0.7, 0.95).item()
            wav = torch.clamp(wav, -limit, limit)

        # 4. Speed perturbation (Independent)
        if torch.rand(1).item() < 0.3:
            speed = float(torch.empty(1).uniform_(*SPEED_PERTURB_RANGE))
            wav = torchaudio.functional.resample(wav.unsqueeze(0), int(SAMPLE_RATE * speed), SAMPLE_RATE).squeeze(0)

        # Ensure exact length
        if len(wav) > target_length:
            wav = wav[:target_length]
        elif len(wav) < target_length:
            wav = torch.cat([wav, torch.zeros(target_length - len(wav))])

        return wav

    def __getitem__(self, idx):
        path, label, _lang = self.samples[idx]

        if path not in self.cache:
            if len(self.cache) >= self.max_cache_size: self.cache.popitem(last=False)
            self.cache[path] = self._load_full_audio(path)
        else:
            self.cache.move_to_end(path)

        full_audio = self.cache[path]
        if full_audio is None: return None

        cropped_audio = self._crop_to_fixed_duration(full_audio)

        if self.mode == 'train':
            cropped_audio = self._augment(cropped_audio, label)

        return cropped_audio, label

    def get_failed_files(self):
        return self.failed_files

In [None]:
def collate_fn(batch):
    """Fixed-length collate with invalid-sample filtering."""
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None, None

    waves, labels = zip(*batch)
    waves = torch.stack(waves)
    labels = torch.tensor(labels)
    return waves, labels

In [10]:
# Create dataset instances for train and val modes
train_dataset_full = AudioDataset(BASE_DATA, target_duration=TARGET_DURATION,
                                   mode='train', max_cache_size=MAX_CACHE_SIZE)
val_dataset_full = AudioDataset(BASE_DATA, target_duration=TARGET_DURATION,
                                 mode='val', max_cache_size=MAX_CACHE_SIZE)

print(f"Total samples collected: {len(train_dataset_full)}")
print(f"Each audio is exactly {TARGET_DURATION} seconds = {int(TARGET_DURATION * SAMPLE_RATE)} samples")


Total samples collected: 22961
Each audio is exactly 5.0 seconds = 80000 samples


In [None]:
# Create stratified train/val split by (label, language)
strata = [
    f"{train_dataset_full.samples[i][1]}_{train_dataset_full.samples[i][2]}"
    for i in range(len(train_dataset_full))
]
train_indices, val_indices = train_test_split(
    range(len(train_dataset_full)),
    test_size=VAL_SPLIT,
    stratify=strata,
    random_state=42,
)

# Create subset datasets
train_dataset = Subset(train_dataset_full, train_indices)
val_dataset = Subset(val_dataset_full, val_indices)

train_subset_labels = [train_dataset_full.samples[i][1] for i in train_indices]  # 0/1
val_subset_labels = [train_dataset_full.samples[i][1] for i in val_indices]      # 0/1
train_counts = np.bincount(train_subset_labels, minlength=2)
val_counts = np.bincount(val_subset_labels, minlength=2)

print(f"Train samples: {len(train_dataset)} | Human={train_counts[0]}, AI={train_counts[1]}")
print(f"Validation samples: {len(val_dataset)} | Human={val_counts[0]}, AI={val_counts[1]}")
if val_counts.min() == 0:
    print("⚠️ Validation split contains only one class. AUC will be undefined (NaN).")
    print("   Fix: lower VAL_SPLIT, or stratify only by label, or ensure each (label,lang) has enough samples.")

# Class-imbalance handling: compute pos_weight ratio for reference
POS_WEIGHT_AI = float(train_counts[0] / max(train_counts[1], 1))
print(f"POS_WEIGHT_AI (Human/AI): {POS_WEIGHT_AI:.3f}")

# Oversampling AI class through weighted random sampling
# Since all AI data comes from Edge TTS, augmentation diversity is critical
if OVERSAMPLE_AI:
    sample_weights = []
    for idx in train_indices:
        label = train_dataset_full.samples[idx][1]
        # Upweight AI samples to appear ~equally often as Human
        weight = float(train_counts[0] / max(train_counts[1], 1)) if label == 1 else 1.0
        sample_weights.append(weight)
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_indices),
        replacement=True,
    )
else:
    train_sampler = None

# Create dataloaders
# Kaggle T4x2: num_workers=4 is ideal to balance CPU/GPU load
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=(train_sampler is None),   
    sampler=train_sampler,             
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)

if OVERSAMPLE_AI:
    print("Train loader uses WeightedRandomSampler (AI oversampled via augmentation).")
else:
    print("Train loader uses shuffle (no oversampling).")

print(f"Batch shape: ({BATCH_SIZE}, {int(TARGET_DURATION * SAMPLE_RATE)})")
print("Regularization: label smoothing + SpecAugment + dropout.")
print("No padding needed - all samples exactly 5 seconds!")

Train samples: 18368 | Human=10640, AI=7728
Validation samples: 4593 | Human=2660, AI=1933

Batch shape: (16, 80000)
Train loader uses oversampling to balance classes.
No padding needed - all samples exactly 5 seconds!


In [None]:
# Load WavLM backbone
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")

# Multi-GPU Support for Kaggle T4x2
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    wavlm = nn.DataParallel(wavlm)
wavlm.to(DEVICE)

# Selective Unfreezing
if UNFREEZE_TOP_LAYERS > 0:
    # Handle DataParallel naming if active
    model_ref = wavlm.module if hasattr(wavlm, "module") else wavlm
    for name, param in model_ref.named_parameters():
        if "encoder.layers" in name:
            layer_num = int(name.split("encoder.layers.")[1].split(".")[0])
            if layer_num >= (12 - UNFREEZE_TOP_LAYERS):
                param.requires_grad = True
            else:
                param.requires_grad = False
        else:
            param.requires_grad = False
    print(f"Fine-tuning top {UNFREEZE_TOP_LAYERS} layers.")
else:
    for param in wavlm.parameters():
        param.requires_grad = False
    print("Backbone frozen.")

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

WavLM backbone loaded and frozen.


In [None]:
class AASISTHead(nn.Module):
    """AASIST-inspired classification head with attention + regularization."""

    def __init__(self, dim=768, dropout=DROPOUT_P, n_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm(x + attn_out)
        pooled = x.mean(dim=1)
        return self.mlp(pooled)


class OCSoftmaxHead(nn.Module):
    """Regularized one-class style head (trained with BCE)."""

    def __init__(self, dim=768, dropout=DROPOUT_P):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        pooled = self.norm(x.mean(dim=1))
        return self.mlp(pooled)


# Initialize classification heads with Multi-GPU support
aasist = AASISTHead().to(DEVICE)
ocsoft = OCSoftmaxHead().to(DEVICE)

if torch.cuda.device_count() > 1:
    aasist = nn.DataParallel(aasist)
    ocsoft = nn.DataParallel(ocsoft)

# Print model summary
model_ref = wavlm.module if hasattr(wavlm, "module") else wavlm
total_params = sum(p.numel() for p in model_ref.parameters())
trainable_params = sum(p.numel() for p in aasist.parameters()) + sum(p.numel() for p in ocsoft.parameters())

print(f"\nModel Summary:")
print(f"WavLM parameters (frozen): {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params + trainable_params:,}")


Model Summary:
WavLM parameters (frozen): 94,381,936
Trainable parameters: 2,775,938
Total parameters: 97,157,874


In [None]:
# With oversampling via WeightedRandomSampler, no pos_weight needed.
# Label smoothing is applied in the training loop, not in the loss function.
criterion = nn.BCEWithLogitsLoss()

# Collect all trainable parameters (backbone + heads)
trainable_params = [p for p in list(wavlm.parameters()) + list(aasist.parameters()) + list(ocsoft.parameters()) if p.requires_grad]

optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
 )

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2
 )

In [None]:
def validate(wavlm, aasist, ocsoft, val_loader, criterion, device):
    """Validation function."""
    aasist.eval()
    ocsoft.eval()

    total_loss = 0
    all_preds = []
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for wavs, labels in tqdm(val_loader, desc="Validating", leave=False, mininterval=1.0):
            if wavs is None:
                continue

            wavs = wavs.float().to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            feats = wavlm(wavs).last_hidden_state

            logits_aasist = aasist(feats).squeeze(-1)
            logits_oc = ocsoft(feats).squeeze(-1)

            loss1 = criterion(logits_aasist, labels.float())
            loss2 = criterion(logits_oc, labels.float())
            loss = AASIST_WEIGHT * loss1 + OCSOFT_WEIGHT * loss2

            total_loss += loss.item()

            score_aasist = torch.sigmoid(logits_aasist)
            score_oc = torch.sigmoid(logits_oc)
            final_score = AASIST_WEIGHT * score_aasist + OCSOFT_WEIGHT * score_oc

            all_scores.extend(final_score.detach().cpu().numpy())
            all_labels.extend(labels.detach().cpu().numpy())
            all_preds.extend((final_score > 0.5).detach().cpu().numpy())

    avg_loss = total_loss / max(len(all_labels) // BATCH_SIZE, 1)  # use actual processed count
    accuracy = accuracy_score(all_labels, all_preds) if len(all_labels) else 0.0
    precision = precision_score(all_labels, all_preds, zero_division=0) if len(all_labels) else 0.0
    recall = recall_score(all_labels, all_preds, zero_division=0) if len(all_labels) else 0.0
    f1 = f1_score(all_labels, all_preds, zero_division=0) if len(all_labels) else 0.0

    labels_np = np.asarray(all_labels)
    scores_np = np.asarray(all_scores)
    auc = float("nan")
    if len(scores_np) and np.isfinite(scores_np).all() and len(np.unique(labels_np)) == 2:
        auc = roc_auc_score(labels_np, scores_np)

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'scores': all_scores,
        'labels': all_labels
    }

In [None]:
history = {
    'train_loss': [],
    'val_loss': [],
    'val_accuracy': [],
    'val_auc': []
}

best_auc = 0.0
no_improve = 0

for epoch in range(EPOCHS):
    # Training phase
    aasist.train()
    ocsoft.train()
    if UNFREEZE_TOP_LAYERS > 0:
        wavlm.train() # Enable dropout/norm in unfrozen layers
    else:
        wavlm.eval()

    total_loss = 0.0
    seen_batches = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False, mininterval=1.0)

    for step, (wavs, labels) in enumerate(progress):
        if wavs is None:
            continue

        wavs = wavs.float().to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad()

        # Backbone inference: partial gradients if unfrozen
        if UNFREEZE_TOP_LAYERS > 0:
            output = wavlm(wavs)
            feats = output.last_hidden_state
        else:
            with torch.no_grad():
                output = wavlm(wavs)
                feats = output.last_hidden_state

        # SpecAugment: random frequency + temporal masking on WavLM features
        if SPEC_AUG_FREQ_MASKS > 0 or SPEC_AUG_TIME_MASKS > 0:
            feats = feats.clone()
            if SPEC_AUG_FREQ_MASKS > 0:
                for _ in range(SPEC_AUG_FREQ_MASKS):
                    f_start = torch.randint(0, max(feats.size(-1) - SPEC_AUG_FREQ_WIDTH, 1), (1,)).item()
                    feats[:, :, f_start:f_start + SPEC_AUG_FREQ_WIDTH] = 0.0
            if SPEC_AUG_TIME_MASKS > 0:
                for _ in range(SPEC_AUG_TIME_MASKS):
                    t_start = torch.randint(0, max(feats.size(1) - SPEC_AUG_TIME_WIDTH, 1), (1,)).item()
                    feats[:, t_start:t_start + SPEC_AUG_TIME_WIDTH, :] = 0.0

        logits_aasist = aasist(feats).squeeze(-1)
        logits_oc = ocsoft(feats).squeeze(-1)

        # Label smoothing: prevents loss collapse to 0
        smooth_labels = labels.float() * (1.0 - 2.0 * LABEL_SMOOTHING) + LABEL_SMOOTHING

        loss1 = criterion(logits_aasist, smooth_labels)
        loss2 = criterion(logits_oc, smooth_labels)
        loss = AASIST_WEIGHT * loss1 + OCSOFT_WEIGHT * loss2

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, GRAD_CLIP_NORM)
        optimizer.step()

        total_loss += float(loss.item())
        seen_batches += 1
        
        if step % 20 == 0:
            ai_rate = float(labels.float().mean().item())
            progress.set_postfix({
                'loss': f'{loss.item():.3f}',
                'ai%': f'{100.0 * ai_rate:.0f}',
                'bs': int(labels.numel()),
            })

    avg_train_loss = total_loss / max(seen_batches, 1)
    history['train_loss'].append(avg_train_loss)

    # Validation phase
    val_metrics = validate(wavlm, aasist, ocsoft, val_loader, criterion, DEVICE)

    history['val_loss'].append(val_metrics['loss'])
    history['val_accuracy'].append(val_metrics['accuracy'])
    history['val_auc'].append(val_metrics['auc'])

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f} | Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Accuracy: {val_metrics['accuracy']:.4f} | AUC: {val_metrics['auc']}")

    metric_for_scheduler = val_metrics['auc'] if np.isfinite(val_metrics['auc']) else val_metrics['accuracy']
    scheduler.step(metric_for_scheduler)

    # Save checkpoint
    ckpt = {
        'wavlm': wavlm.state_dict(),
        'aasist': aasist.state_dict(),
        'ocsoft': ocsoft.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch + 1,
        'val_metrics': val_metrics,
        'history': history,
        'config': { 'target_duration': TARGET_DURATION, 'sample_rate': SAMPLE_RATE },
    }

    torch.save(ckpt, f"{SAVE_DIR}/latest_model.pt")
    if epoch == 0: torch.save(ckpt, f"{SAVE_DIR}/epoch_1.pt")

    current_auc = val_metrics['auc']
    if np.isfinite(current_auc) and current_auc > best_auc + MIN_DELTA_AUC:
        best_auc = current_auc
        no_improve = 0
        torch.save(ckpt, f"{SAVE_DIR}/best_model.pt")
        print(f"  ✓ Save Best Model (AUC: {best_auc:.4f})")
    else:
        no_improve += 1
        if no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"  Early stopping triggered at epoch {epoch+1}")
            break
    print("-" * 30)


Epoch 1/5: 100%|██████████| 1330/1330 [48:22<00:00,  2.18s/it, loss=0.0199]
Validating: 100%|██████████| 288/288 [15:39<00:00,  3.26s/it]



Epoch 1 Summary:
  Train Loss: 0.0360
  Val Loss: 0.0091
  Val Accuracy: 0.9978
  Val Precision: 0.9959
  Val Recall: 0.9990
  Val F1: 0.9974
  Val AUC: 0.9999669375196917
  ✓ Best model saved (AUC: 1.0000)



Epoch 2/5: 100%|██████████| 1330/1330 [29:07<00:00,  1.31s/it, loss=0.0003]
Validating: 100%|██████████| 288/288 [04:47<00:00,  1.00it/s]



Epoch 2 Summary:
  Train Loss: 0.0148
  Val Loss: 0.0107
  Val Accuracy: 0.9974
  Val Precision: 0.9964
  Val Recall: 0.9974
  Val F1: 0.9969
  Val AUC: 0.9999858025819852
  Early-stopping counter: 1/3



Epoch 3/5: 100%|██████████| 1330/1330 [24:27<00:00,  1.10s/it, loss=0.0021]
Validating: 100%|██████████| 288/288 [04:56<00:00,  1.03s/it]



Epoch 3 Summary:
  Train Loss: 0.0126
  Val Loss: 0.0083
  Val Accuracy: 0.9978
  Val Precision: 0.9959
  Val Recall: 0.9990
  Val F1: 0.9974
  Val AUC: 0.9999900812559076
  Early-stopping counter: 2/3



Epoch 4/5: 100%|██████████| 1330/1330 [23:58<00:00,  1.08s/it, loss=0.0010]
Validating: 100%|██████████| 288/288 [04:51<00:00,  1.01s/it]



Epoch 4 Summary:
  Train Loss: 0.0114
  Val Loss: 0.0075
  Val Accuracy: 0.9980
  Val Precision: 0.9964
  Val Recall: 0.9990
  Val F1: 0.9977
  Val AUC: 0.9999963047816126
  Early-stopping counter: 3/3
Early stopping triggered.
