In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
from tqdm import tqdm
from typing import List
import random

# -------- Dataset --------
class MultilingualDeepfakeDataset(Dataset):
    def __init__(self, root_dir, sr=16000, max_len=160000):
        self.sr = sr
        self.max_len = max_len
        self.samples = []
        self.lang_map = {}

        lang_folders = os.listdir(root_dir)
        for i, lang in enumerate(lang_folders):
            self.lang_map[lang] = i
            for label_name in ["real", "fake"]:
                label = 0 if label_name == "real" else 1
                folder_path = os.path.join(root_dir, lang, label_name)
                for fname in os.listdir(folder_path):
                    fpath = os.path.join(folder_path, fname)
                    self.samples.append((fpath, label, lang))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label, lang = self.samples[idx]
        waveform, sr = torchaudio.load(path)
        waveform = torchaudio.functional.resample(waveform, sr, self.sr)
        waveform = waveform[0]  # mono

        if len(waveform) > self.max_len:
            waveform = waveform[:self.max_len]
        else:
            pad_len = self.max_len - len(waveform)
            waveform = F.pad(waveform, (0, pad_len))

        return waveform, torch.tensor(label, dtype=torch.float32), torch.tensor(self.lang_map[lang])

# -------- Collate --------
def collate_fn(batch):
    waveforms, labels, langs = zip(*batch)
    waveforms = torch.stack(waveforms)
    labels = torch.stack(labels)
    langs = torch.tensor(langs)
    return waveforms, labels, langs

# -------- Model --------
class MultiTaskHead(nn.Module):
    def __init__(self, in_dim, num_langs):
        super().__init__()
        self.deepfake = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )
        self.lang_head = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_langs)
        )

    def forward(self, x):
        return self.deepfake(x).squeeze(1), self.lang_head(x)

class Wav2VecMultilingualClassifier(nn.Module):
    def __init__(self, num_langs, freeze_encoder=True):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
        if freeze_encoder:
            for param in self.wav2vec.parameters():
                param.requires_grad = False
        self.head = MultiTaskHead(self.wav2vec.config.hidden_size, num_langs)

    def forward(self, x):
        out = self.wav2vec(x).last_hidden_state
        pooled = out.mean(dim=1)
        return self.head(pooled)

# -------- Training --------
def train_model(model, train_loader, val_loader, num_langs, device, epochs=10, lr=1e-4, λ=0.5):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs)
    scaler = torch.cuda.amp.GradScaler()

    bce_loss = nn.BCEWithLogitsLoss()
    ce_loss = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for x, y, lang in loop:
            x, y, lang = x.to(device), y.to(device), lang.to(device)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                df_logits, lang_logits = model(x)
                loss_df = bce_loss(df_logits, y)
                loss_lang = ce_loss(lang_logits, lang)
                loss = loss_df + λ * loss_lang

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        print(f"✅ Epoch {epoch+1} Avg Loss: {total_loss / len(train_loader):.4f}")
        evaluate_model(model, val_loader, device)

# -------- Evaluation --------
@torch.no_grad()
def evaluate_model(model, loader, device):
    model.eval()
    correct_df, correct_lang = 0, 0
    total = 0
    for x, y, lang in loader:
        x, y, lang = x.to(device), y.to(device), lang.to(device)
        df_logits, lang_logits = model(x)
        df_preds = (torch.sigmoid(df_logits) > 0.5).float()
        lang_preds = torch.argmax(lang_logits, dim=1)

        correct_df += (df_preds == y).sum().item()
        correct_lang += (lang_preds == lang).sum().item()
        total += y.size(0)

    print(f"🎯 Deepfake Acc: {correct_df / total:.4f}, Lang ID Acc: {correct_lang / total:.4f}")


2025-06-19 09:28:56.753076: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750325337.179177      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750325337.285712      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
data_root = "/kaggle/input/multi-lingual-audio-deepfake-detection-challenge/archive"  # 🔁 Replace this
batch_size = 8
num_epochs = 10

dataset = MultilingualDeepfakeDataset(data_root)
num_langs = len(dataset.lang_map)

# Train/Val split
indices = list(range(len(dataset)))
random.shuffle(indices)
split = int(0.8 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]

train_ds = torch.utils.data.Subset(dataset, train_idx)
val_ds = torch.utils.data.Subset(dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
model = Wav2VecMultilingualClassifier(num_langs=num_langs, freeze_encoder=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_model(model, train_loader, val_loader, num_langs, device, epochs=num_epochs)

config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler()

  with torch.cuda.amp.autocast():

Epoch 1:   0%|          | 0/24728 [00:05<?, ?it/s, loss=1.04][A
Epoch 1:   0%|          | 1/24728 [00:05<39:35:12,  5.76s/it, loss=1.04][A
Epoch 1:   0%|          | 1/24728 [00:06<39:35:12,  5.76s/it, loss=1.02][A
Epoch 1:   0%|          | 2/24728 [00:06<19:39:23,  2.86s/it, loss=1.02][A
Epoch 1:   0%|          | 2/24728 [00:07<19:39:23,  2.86s/it, loss=1.03][A
Epoch 1:   0%|          | 3/24728 [00:07<13:17:11,  1.93s/it, loss=1.03][A
Epoch 1:   0%|          | 3/24728 [00:08<13:17:11,  1.93s/it, loss=1.05][A
Epoch 1:   0%|          | 4/24728 [00:08<10:22:21,  1.51s/it, loss=1.05][A
Epoch 1:   0%|          | 4/24728 [00:09<10:22:21,  1.51s/it, loss=1.02][A
Epoch 1:   0%|          | 5/24728 [00:09<8:39:50,  1.26s/it, loss=1.02] [A
Epoch 1:   0%|          | 5/24728 [00:09<8:39:50,  1.26s/it, loss=1.02][A
Epoch 1:   0%|          | 6/24728 [00:09<7:38:14,  1.11s/it, loss=1.02][A
Epoch 1:   0%|          | 