In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the working directory (/kaggle/working/)

import os
# Recursively print the full path of every file found beneath /kaggle/working/.
for folder, _subdirs, names in os.walk('/kaggle/working/'):
    full_paths = [os.path.join(folder, name) for name in names]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ecapa_tdnn_kaggle.py
import os
import glob
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

##############################################
# VoxCeleb1 Dataset Loader
##############################################
class VoxCeleb1Dataset(Dataset):
    """
    A simple dataset class for VoxCeleb1.

    Expected folder structure (wav files may sit directly inside a speaker
    folder or inside nested per-session sub-folders):
      data_dir/
          vox1_dev_wav/
              id00001/
                  session_a/
                      00001.wav
                      00002.wav
              id00002/
                  session_b/
                      00001.wav
                  ...

    Args:
        data_dir: Folder that contains the ``vox1_dev_wav`` directory.
        transform: Optional callable applied to each loaded waveform.

    Attributes:
        data: List of wav file paths.
        labels: Integer speaker label for each entry of ``data``.
        label2idx: Mapping from speaker id (the ``id*`` folder name) to label.
    """
    def __init__(self, data_dir, transform=None):
        self.data = []     # list of file paths
        self.labels = []   # numeric label per file
        self.label2idx = {}  # mapping from speaker id to integer label
        self.transform = transform

        # One folder per speaker. Sorting makes the speaker -> label mapping
        # reproducible across runs (glob order is otherwise arbitrary).
        speaker_dirs = sorted(
            path
            for path in glob.glob(os.path.join(data_dir, 'vox1_dev_wav', 'id*'))
            if os.path.isdir(path)
        )
        for spkr_dir in speaker_dirs:
            # The label must come from the speaker folder itself, not from the
            # session sub-folders below it, so search for wav files recursively.
            wav_files = sorted(
                glob.glob(os.path.join(spkr_dir, '**', '*.wav'), recursive=True)
            )
            if not wav_files:
                # Do not allocate a class index for a speaker with no audio.
                continue
            spkr_id = os.path.basename(spkr_dir)
            label = self.label2idx.setdefault(spkr_id, len(self.label2idx))
            self.data.extend(wav_files)
            self.labels.extend([label] * len(wav_files))

    def __len__(self):
        """Return the number of wav files found."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(waveform, label)``.

        The waveform is resampled to 16 kHz when its sampling rate differs and
        reduced to its first channel, so it keeps a leading channel axis of
        size one.
        """
        # Load waveform and its sampling rate
        file_path = self.data[idx]
        waveform, sr = torchaudio.load(file_path)
        # Resample to 16 kHz if necessary (the model expects 16kHz mono)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)
        # If stereo, take the first channel
        if waveform.shape[0] > 1:
            waveform = waveform[0:1, :]
        # Apply any extra transform (if provided)
        if self.transform:
            waveform = self.transform(waveform)
        label = self.labels[idx]
        return waveform, label

##############################################
# Simplified ECAPA-TDNN Model
##############################################
class ECAPATDNN(nn.Module):
    """
    A simplified ECAPA-TDNN architecture.
    It computes log Mel-spectrograms on the fly, passes through several
    1D convolutional layers and uses an attentive statistical pooling
    layer. Finally, it outputs an embedding and class logits.

    Args:
        num_classes: Number of output classes of the final classifier.
        channels: Width of the convolutional stack and of the pooling layer.
        emb_dim: Size of the speaker embedding.
    """
    def __init__(self, num_classes, channels=512, emb_dim=192):
        super(ECAPATDNN, self).__init__()
        # Instead of the full ECAPA-TDNN, we use a few 1D Conv layers.
        # The input feature dimension will be 80 (log Mel bins).
        self.conv1 = nn.Conv1d(in_channels=80, out_channels=channels, kernel_size=5, dilation=1)
        self.bn1   = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, dilation=2)
        self.bn2   = nn.BatchNorm1d(channels)
        self.conv3 = nn.Conv1d(channels, channels, kernel_size=3, dilation=3)
        self.bn3   = nn.BatchNorm1d(channels)
        self.relu  = nn.ReLU()

        # Attentive Statistical Pooling: per-channel weights that are
        # softmax-normalised along the time axis (dim=2).
        self.attention = nn.Sequential(
            nn.Conv1d(channels, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, channels, kernel_size=1),
            nn.Softmax(dim=2)
        )
        # Fully-connected layer for speaker embedding from pooled statistics
        self.fc_emb = nn.Linear(channels * 2, emb_dim)  # concatenation of mean and std
        # Final classifier for training (speaker identification)
        self.classifier = nn.Linear(emb_dim, num_classes)

        # MelSpectrogram transformation (fixed parameters)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=400, win_length=400, hop_length=160, n_mels=80
        )

    def forward(self, waveform):
        """
        Input:
          waveform: Tensor of shape (batch, 1, time) or (batch, time)
        Returns:
          (logits, emb): class logits of shape (batch, num_classes) and the
          speaker embedding of shape (batch, emb_dim).
        Raises:
          ValueError: if ``waveform`` is neither 2-D nor 3-D.
        """
        # MelSpectrogram keeps every leading dimension, so a (B, 1, T) input
        # would give a 4-D (B, 1, 80, T') tensor that Conv1d rejects. Drop the
        # channel axis first (keeping the first channel).
        if waveform.dim() == 3:
            waveform = waveform[:, 0, :]
        elif waveform.dim() != 2:
            raise ValueError(
                "waveform must have shape (batch, 1, time) or (batch, time), "
                "got {}".format(tuple(waveform.shape))
            )

        # Compute Mel-spectrogram and then take the logarithm.
        # Output shape: (batch, n_mels, time_frames)
        mel_spec = self.mel_transform(waveform)  # (B, 80, T)
        log_mel = torch.log(mel_spec + 1e-6)  # offset avoids log(0)

        # Pass through convolutional layers
        x = self.conv1(log_mel)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)  # now x has shape (B, channels, T')

        # Apply attentive statistical pooling
        attn = self.attention(x)  # (B, channels, T')
        # Compute weighted mean
        mean = torch.sum(x * attn, dim=2)
        # Compute weighted standard deviation (epsilon keeps sqrt differentiable at 0)
        std = torch.sqrt(torch.sum(((x - mean.unsqueeze(2)) ** 2) * attn, dim=2) + 1e-6)
        stat_pool = torch.cat((mean, std), dim=1)  # (B, 2 * channels)

        # Obtain speaker embedding
        emb = self.fc_emb(stat_pool)  # (B, emb_dim)

        # Classification logits (for training)
        logits = self.classifier(emb)  # (B, num_classes)
        return logits, emb

##############################################
# Training and Evaluation Functions
##############################################
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader`` and return the mean loss.

    Each batch is moved to ``device``, the first output of ``model`` is scored
    against the labels with ``criterion``, and ``optimizer`` takes one step.
    The returned value is the batch-size-weighted loss sum divided by
    ``len(dataloader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for batch_wave, batch_labels in dataloader:
        batch_wave, batch_labels = batch_wave.to(device), batch_labels.to(device)

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_logits, _embeddings = model(batch_wave)
        batch_loss = criterion(batch_logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        loss_sum += batch_loss.item() * batch_wave.size(0)
    return loss_sum / len(dataloader.dataset)

def evaluate_epoch(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    Returns a ``(mean_loss, accuracy)`` tuple. Both are normalised by
    ``len(dataloader.dataset)``; accuracy is the fraction of samples whose
    arg-max logit equals the label.
    """
    model.eval()
    total_samples = len(dataloader.dataset)
    loss_sum = 0.0
    num_correct = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_wave, batch_labels in dataloader:
            batch_wave, batch_labels = batch_wave.to(device), batch_labels.to(device)
            batch_logits, _embeddings = model(batch_wave)
            batch_loss = criterion(batch_logits, batch_labels)
            loss_sum += batch_loss.item() * batch_wave.size(0)
            predicted = batch_logits.argmax(dim=1)
            num_correct += predicted.eq(batch_labels).sum().item()
    return loss_sum / total_samples, num_correct / total_samples

##############################################
# Main training script
##############################################
def _collate_waveforms(batch, max_samples=16000 * 4):
    """Stack variable-length ``(waveform, label)`` items into one batch.

    Each waveform is cut to its leading ``max_samples`` samples (4 s at 16 kHz
    by default) and zero-padded on the right to the longest clip remaining in
    the batch, because the default collate function cannot stack tensors of
    different lengths.
    """
    waves, labels = zip(*batch)
    waves = [wave[..., :max_samples] for wave in waves]
    target_len = max(wave.shape[-1] for wave in waves)
    padded = [F.pad(wave, (0, target_len - wave.shape[-1])) for wave in waves]
    return torch.stack(padded), torch.tensor(labels, dtype=torch.long)


def main(args):
    """Train the model on VoxCeleb1 and save the best checkpoint.

    Args:
        args: Namespace providing ``data_dir``, ``batch_size``, ``lr`` and
            ``epochs``.

    Raises:
        RuntimeError: if no wav files are found under ``args.data_dir``.

    Side effects:
        Writes ``best_ecapa_tdnn.pth`` to the current directory whenever the
        validation accuracy improves, and prints progress to stdout.
    """
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Create dataset – assume the Kaggle dataset folder is provided via --data_dir
    full_dataset = VoxCeleb1Dataset(args.data_dir)
    num_classes = len(full_dataset.label2idx)
    print(f"Found {len(full_dataset)} samples from {num_classes} speakers.")
    if len(full_dataset) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError later.
        raise RuntimeError(f"No wav files found under {args.data_dir!r}")

    # Split dataset into training and validation (e.g., 90/10 split)
    indices = list(range(len(full_dataset)))
    random.shuffle(indices)
    split = int(0.9 * len(full_dataset))
    train_indices, valid_indices = indices[:split], indices[split:]
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices)

    # Utterances differ in length, so a custom collate function is required.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, collate_fn=_collate_waveforms)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                              num_workers=4, collate_fn=_collate_waveforms)

    # Create model instance
    model = ECAPATDNN(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    best_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        valid_loss, valid_acc = evaluate_epoch(model, valid_loader, criterion, device)
        print(f"Epoch {epoch}/{args.epochs}: Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc*100:.2f}%")
        # Save the model if validation accuracy improves
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), "best_ecapa_tdnn.pth")
            print("Saved best model with acc {:.2f}%".format(best_acc * 100))
    print("Training complete. Best validation accuracy: {:.2f}%".format(best_acc * 100))

if __name__ == '__main__':
    # Command-line entry point: collect the hyper-parameters and start training.
    parser = argparse.ArgumentParser(description="Train simplified ECAPA-TDNN on VoxCeleb1 (Kaggle)")
    parser.add_argument("--data_dir", type=str, default="/kaggle/input/audiodataset10percent/VoxCeleb/", help="Path to VoxCeleb1 dataset folder")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    # Notebook kernels add their own arguments (e.g. "-f <connection file>") to
    # sys.argv; parse_args() would exit on them, so ignore unknown arguments.
    args, _unknown = parser.parse_known_args()
    main(args)