# In [None]:
"""
MS-SNSD VAD Classifier for Clean Speech vs. Noise
--------------------------------------------------
Dataset details:
- CleanSpeech: clean speech files (16 kHz, 16-bit WAV)
- Noise: noise files (16 kHz, 16-bit WAV)
The task is to classify entire audio files as either speech (label 1) or noise (label 0).

This script trains a small CNN classifier (two convolutional layers) on whole-file MFCC spectrograms and reports accuracy, precision, recall and F1 on a held-out test split.
"""

import os
import torch
import torchaudio
import librosa
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch.nn as nn
import torch.optim as optim

# Select the compute device once; the model and every batch are moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ========================
# Audio Loading and MFCC Extraction Functions
# ========================

def load_audio(file_path, sample_rate=16000):
    """
    Load an audio file as a mono float numpy array at ``sample_rate`` Hz.

    Parameters:
      - file_path: path to a file torchaudio can decode
      - sample_rate: target rate; files stored at another rate are resampled
    Returns a 1-D numpy array. If decoding fails, the error is printed and an
    empty array is returned (callers treat an empty array as a failed load).
    """
    try:
        waveform, sr = torchaudio.load(file_path)
    except (RuntimeError, OSError) as e:
        print(f"Error loading {file_path}: {e}")
        return np.array([])
    # torchaudio returns (channels, samples). Flattening would concatenate the
    # channels end to end (doubling a stereo clip's length), so average them
    # into a single mono channel instead. Mono files are unchanged by this.
    waveform = waveform.mean(dim=0)
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform.numpy()

def extract_mfcc_spectrogram(audio, sr=16000, n_mfcc=40, n_fft=512, hop_length=160):
    """
    Turn a 1-D audio signal into a frame-by-frame MFCC matrix.

    Parameters:
      - audio: samples to analyse
      - sr: sample rate of ``audio`` in Hz
      - n_mfcc: number of cepstral coefficients kept per frame (40 by default)
      - n_fft: FFT window length; clipped to len(audio) when the clip is
        shorter than the window
      - hop_length: number of samples between successive frames
    Returns the numpy array produced by librosa.feature.mfcc, laid out as
    (n_mfcc, time_frames).
    """
    window = n_fft if n_fft <= len(audio) else len(audio)
    return librosa.feature.mfcc(
        y=audio,
        sr=sr,
        n_mfcc=n_mfcc,
        n_fft=window,
        hop_length=hop_length,
    )

# ========================
# Custom Collate Function to Pad Variable Length MFCC Spectrograms
# ========================
def pad_collate_fn(batch):
    """
    Collate (features, label) pairs into a zero-padded, CNN-ready batch.

    ``features`` is a 2-D array laid out as (n_mfcc, frames). Each sample is
    right-padded with zeros along the frame axis up to the longest sample in
    the batch, then a singleton channel axis is inserted.

    Returns (features, labels) with shapes (batch, 1, n_mfcc, max_frames) and
    (batch,), or None when the batch contains no non-None samples.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None
    # pad_sequence pads along dim 0, so lay each sample out as (frames, n_mfcc).
    time_major = [torch.tensor(feat, dtype=torch.float32).T for feat, _ in kept]
    padded = nn.utils.rnn.pad_sequence(time_major, batch_first=True, padding_value=0.0)
    # NOTE(review): zero is not necessarily a neutral MFCC value, and padded
    # frames reach the model's global average pooling, so a clip's prediction
    # may depend on which other clips share its batch. Consider masking.
    batch_features = padded.transpose(1, 2).contiguous().unsqueeze(1)
    batch_labels = torch.tensor([label for _, label in kept], dtype=torch.long)
    return batch_features, batch_labels

# ========================
# Dataset Classes
# ========================

class CleanSpeechDataset(Dataset):
    """
    Clean-speech recordings (label 1) found recursively under ``speech_dir``.

    Files ending in .wav or .flac (case-insensitive) are collected in sorted
    path order. Each item is ``(mfcc_spec, 1)`` with ``mfcc_spec`` computed by
    ``extract_mfcc_spectrogram``; an unreadable file raises ValueError.
    """

    def __init__(self, speech_dir, sample_rate=16000):
        # os.walk silently yields nothing for a missing directory; fail loudly
        # so a wrong path is not mistaken for "no speech files".
        if not os.path.isdir(speech_dir):
            raise FileNotFoundError(f"Speech directory not found: {speech_dir}")
        self.sample_rate = sample_rate
        found = [
            os.path.join(root, name)
            for root, _dirs, files in os.walk(speech_dir)
            for name in files
            if name.lower().endswith(('.wav', '.flac'))
        ]
        # Directory traversal order is arbitrary; sort so a seeded train/test
        # split over this dataset sees the same file order on every machine.
        self.file_paths = sorted(found)
        self.labels = [1] * len(self.file_paths)  # Label 1 for speech
        print(f"CleanSpeechDataset: Found {len(self.file_paths)} speech files.")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio = load_audio(self.file_paths[idx], self.sample_rate)
        if audio.size == 0:
            raise ValueError(f"Failed to load {self.file_paths[idx]}")
        # Full MFCC spectrogram (not averaged over time).
        mfcc_spec = extract_mfcc_spectrogram(audio, sr=self.sample_rate)
        return mfcc_spec, self.labels[idx]

class NoiseDataset(Dataset):
    """
    Noise recordings (label 0) found recursively under ``noise_dir``.

    Files ending in .wav or .flac (case-insensitive) are collected in sorted
    path order, matching how CleanSpeechDataset discovers files. Each item is
    ``(mfcc_spec, 0)`` with ``mfcc_spec`` computed by
    ``extract_mfcc_spectrogram``; an unreadable file raises ValueError.
    """

    def __init__(self, noise_dir, sample_rate=16000):
        # os.walk silently yields nothing for a missing directory; raise the
        # same FileNotFoundError that os.listdir used to raise here.
        if not os.path.isdir(noise_dir):
            raise FileNotFoundError(f"Noise directory not found: {noise_dir}")
        self.sample_rate = sample_rate
        found = [
            os.path.join(root, name)
            for root, _dirs, files in os.walk(noise_dir)
            for name in files
            if name.lower().endswith(('.wav', '.flac'))
        ]
        # Directory listing order is arbitrary; sort so a seeded train/test
        # split over this dataset sees the same file order on every machine.
        self.file_paths = sorted(found)
        self.labels = [0] * len(self.file_paths)  # Label 0 for noise
        print(f"NoiseDataset: Found {len(self.file_paths)} noise files.")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio = load_audio(self.file_paths[idx], self.sample_rate)
        if audio.size == 0:
            raise ValueError(f"Failed to load {self.file_paths[idx]}")
        mfcc_spec = extract_mfcc_spectrogram(audio, sr=self.sample_rate)
        return mfcc_spec, self.labels[idx]

# ========================
# CNN Classifier Definition
# ========================

class VAD_CNN(nn.Module):
    """
    Small CNN mapping a (batch, 1, n_mfcc, time_frames) MFCC batch to two
    class logits (index 0 = noise, index 1 = speech, as labelled by the
    datasets above).

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages (1 -> 32 -> 64
    channels) are followed by adaptive average pooling to 1x1, so the time
    length may vary between batches. NOTE(review): each max-pool halves both
    axes, so presumably at least 4 time frames are needed — confirm.
    """

    def __init__(self):
        super().__init__()
        # Modules are registered in a fixed order; that order also determines
        # the order in which their random initial weights are drawn.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        # x: (batch, 1, n_mfcc, time_frames)
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
        )
        for conv, norm, pool in stages:
            x = pool(torch.relu(norm(conv(x))))
        embedding = self.global_pool(x).flatten(start_dim=1)  # (batch, 64)
        hidden = torch.relu(self.fc1(embedding))
        return self.fc2(hidden)

# ========================
# Training and Evaluation Function
# ========================

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``loader``; return the mean per-sample loss.

    ``loader`` yields (features, labels) batches. Each batch loss is weighted
    by the batch size, assuming ``criterion`` averages over the batch
    (CrossEntropyLoss's default reduction). Returns 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for features, labels in loader:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
        batch_size = features.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    return total_loss / n_seen if n_seen else 0.0


def _predict(model, loader, device):
    """
    Return ``(true_labels, predicted_labels)`` as Python lists covering every
    sample in ``loader``; each prediction is the arg-max logit.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for features, labels in loader:
            logits = model(features.to(device))
            y_pred.extend(torch.argmax(logits, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())
    return y_true, y_pred


def _plot_metrics(metrics):
    """Show a bar chart of ``metrics`` (name -> score), labelling each bar with its value."""
    names = list(metrics)
    values = [metrics[name] for name in names]
    x = np.arange(len(names))
    fig, ax = plt.subplots()
    rects = ax.bar(x, values, 0.35, label="Metrics")
    ax.set_ylabel("Scores")
    ax.set_title("VAD Classifier Performance on MS-SNSD")
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.legend()
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f"{height:.4f}",
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha="center", va="bottom")
    fig.tight_layout()
    plt.show()


def train_and_evaluate(clean_speech_dir, noise_dir, *, num_epochs=30, batch_size=32,
                       lr=0.001, test_size=0.2, random_state=42):
    """
    Train a VAD_CNN on speech (label 1) vs. noise (label 0) files and evaluate it.

    Parameters:
      - clean_speech_dir: directory passed to CleanSpeechDataset
      - noise_dir: directory passed to NoiseDataset
      - num_epochs, batch_size, lr: training hyper-parameters (Adam optimiser)
      - test_size, random_state: forwarded to sklearn's train_test_split
    Prints the per-epoch training loss and the test metrics, shows a bar chart,
    and returns a dict with keys "Accuracy", "Precision", "Recall" and "F1".
    Precision, recall and F1 treat speech (label 1) as the positive class.
    """
    speech_ds = CleanSpeechDataset(clean_speech_dir)
    noise_ds = NoiseDataset(noise_dir)
    combined_ds = ConcatDataset([speech_ds, noise_ds])
    print(f"Combined dataset contains {len(combined_ds)} samples.")

    # Load every file and compute its MFCCs once; all epochs reuse the result.
    # Index explicitly: a map-style Dataset defines no __iter__, so
    # list(combined_ds) would rely on the legacy IndexError iteration protocol.
    all_samples = [combined_ds[i] for i in range(len(combined_ds))]
    labels = [label for _, label in all_samples]

    split_kwargs = dict(test_size=test_size, random_state=random_state)
    try:
        # Stratify so the speech/noise ratio (possibly very imbalanced) is the
        # same in the train and test sets.
        train_samples, test_samples = train_test_split(all_samples, stratify=labels, **split_kwargs)
    except ValueError:
        # sklearn cannot stratify when a class has too few samples; fall back
        # to a plain random split.
        train_samples, test_samples = train_test_split(all_samples, **split_kwargs)

    # The custom collate function pads variable-length spectrograms per batch.
    train_loader = DataLoader(train_samples, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)
    test_loader = DataLoader(test_samples, batch_size=batch_size, shuffle=False, collate_fn=pad_collate_fn)

    model = VAD_CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    y_true, y_pred = _predict(model, test_loader, device)
    metrics = {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "F1": f1_score(y_true, y_pred, zero_division=0),
    }

    print("\nEvaluation Metrics:")
    print(f"Accuracy:  {metrics['Accuracy']:.4f}")
    print(f"Precision: {metrics['Precision']:.4f}")
    print(f"Recall:    {metrics['Recall']:.4f}")
    print(f"F1 Score:  {metrics['F1']:.4f}")

    _plot_metrics(metrics)
    return metrics

# ========================
# Run the training and evaluation
# ========================

# Update these paths to your local directories for MS-SNSD
clean_speech_dir = "./CleanSpeech"  # Directory containing clean speech files
noise_dir = "./Noise"               # Directory containing noise files

# Guarded so that importing this module (e.g. to reuse VAD_CNN or the
# datasets) does not start a training run. Jupyter/IPython kernels execute
# cells in __main__, so the notebook still trains when this cell runs.
if __name__ == "__main__":
    train_and_evaluate(clean_speech_dir, noise_dir)