# Ensemble Model Training to Detect Infant Cries, Screams, and Normal Utterances

## Dataset preparation

* For infant cries, I combined the [Infant Cry Audio Corpus](https://www.kaggle.com/datasets/warcoder/infant-cry-audio-corpus?utm_source=chatgpt.com) from Kaggle with [Infant's Cry Sound](https://www.kaggle.com/datasets/warcoder/infant-cry-audio-corpus?utm_source=chatgpt.com) from Mendeley Data into a larger dataset named "infant_cry" with 520 audio files.

* For screams, I chose the [Human Screaming Detection Dataset](https://www.kaggle.com/datasets/whats2000/human-screaming-detection-dataset?resource=download&select=Screaming) from Kaggle, extracted 550 audio files, and saved them in the folder "scream".

* For normal utterances, I chose the **Common Voice Dataset** from Mozilla, extracted 535 audio files, and saved them in the folder "normal_utterance".

## Justification of the dataset
1. **Balanced Classes:**
    * The dataset is roughly balanced, with 520 audio files for Infant Cry, 550 audio files for Screams, and 535 audio files for Normal Utterances.
    * This prevents the model from being biased towards a particular class.
2. **Sample Size Adequacy:** 
    * The 1605 audio samples are used to fine-tune pre-trained models (YAMNet and Wav2Vec2) rather than to train from scratch, so far fewer samples are needed.
    * The sample size is modest, but sufficient to adapt the pre-trained models to the acoustic differences between infant cries, screams, and normal utterances.
    * A 70/15/15 split into training, validation, and test sets leaves roughly 240 samples each for validation and testing.

3. **Dataset Diversity:** 
    * The combination of datasets from various sources (Kaggle, Mendeley Data, Mozilla Common Voice) introduces diversity in recording conditions, environments, and subject demographics.
    * This diversity is essential for building models that are robust to real-world variability and can perform reliably across different scenarios.
    * The inclusion of normal utterances from the Common Voice dataset ensures that the models are exposed to a wide range of speech patterns, improving their ability to distinguish between normal speech and distress vocalizations.
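
As a quick check on the balance claim in point 1, the class counts can be turned into proportions and a largest-to-smallest ratio. This is a minimal sketch: the `counts` dictionary restates the numbers quoted in the text, and the variable names are illustrative only.

```python
# Class counts as stated in the dataset description
counts = {"infant_cry": 520, "scream": 550, "normal_utterance": 535}

total = sum(counts.values())
proportions = {name: n / total for name, n in counts.items()}

# Ratio of the largest class to the smallest; 1.0 would be perfectly balanced
imbalance_ratio = max(counts.values()) / min(counts.values())

for name, share in proportions.items():
    print(f"{name}: {share:.1%}")
print(f"imbalance ratio: {imbalance_ratio:.3f}")
```

A ratio close to 1 supports treating the classes as roughly balanced.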

## Libraries Used

In [17]:
import os
import glob
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
import torchaudio
import numpy as np
import librosa
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.models as models
from tqdm.auto import tqdm
import tensorflow as tf
import tensorflow_hub as hub
from transformers import Wav2Vec2Model, Wav2Vec2Config
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report, confusion_matrix, 
    roc_curve, auc, precision_recall_curve
)
from sklearn.model_selection import KFold
from datetime import datetime
import json
import warnings
warnings.filterwarnings("ignore", message="Exception ignored in.*DataLoader.*shutdown_workers")

In [18]:
# Create a df with columns: filepath, label(0-infant_cry; 1-scream; 2-normal_utterance)
def create_df(data_dir):
    df = pd.DataFrame()
    for label in ['infant_cry', 'scream', 'normal_utterance']:
        wav_files = glob.glob(os.path.join(data_dir, label, '*.wav'))
        mp3_files = glob.glob(os.path.join(data_dir, label, '*.mp3'))
        files = wav_files + mp3_files
        if not files:
            print(f"No files found for label: {label}")  # Debugging print
            continue  # Skip to the next label if no files are found

        temp_df = pd.DataFrame(files, columns=['filepath'])
        if label == 'infant_cry':
            temp_df['label'] = 0
        elif label == 'scream':
            temp_df['label'] = 1
        else:
            temp_df['label'] = 2

        df = pd.concat([df, temp_df], ignore_index=True)

    return df

In [19]:
df = create_df('/kaggle/input/frontera-health-sounds-dataset/dataset/')
df.head()

| | filepath | label |
|---|---|---|
| 0 | /kaggle/input/frontera-health-sounds-dataset/d... | 0 |
| 1 | /kaggle/input/frontera-health-sounds-dataset/d... | 0 |
| 2 | /kaggle/input/frontera-health-sounds-dataset/d... | 0 |
| 3 | /kaggle/input/frontera-health-sounds-dataset/d... | 0 |
| 4 | /kaggle/input/frontera-health-sounds-dataset/d... | 0 |


In [20]:
# no. of files in each class
df['label'].value_counts()

label
1    550
2    535
0    514
Name: count, dtype: int64

## Data splitting

In [21]:
# do stratified train-val-test split in 70-15-15 ratio
train_df, test_df = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=42)
val_df, test_df = train_test_split(test_df, test_size=0.5, stratify=test_df['label'], random_state=42)
# Print dataset distribution
print("Dataset Distribution:")
print("\nTrain:")
print(train_df['label'].value_counts().to_frame().to_string())
print("\nValidation:")
print(val_df['label'].value_counts().to_frame().to_string())
print("\nTest:")
print(test_df['label'].value_counts().to_frame().to_string())

Dataset Distribution:

Train:
       count
label       
1        385
2        374
0        360

Validation:
       count
label       
1         82
2         81
0         77

Test:
       count
label       
1         83
2         80
0         77
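
The printed counts can be checked against the intended 70/15/15 ratio. The sketch below hard-codes the numbers from the output above instead of re-running the split; the `splits` dictionary is only an illustration.

```python
# Per-class counts copied from the printed distribution (label -> count)
splits = {
    "train": {1: 385, 2: 374, 0: 360},
    "val": {1: 82, 2: 81, 0: 77},
    "test": {1: 83, 2: 80, 0: 77},
}

sizes = {name: sum(per_class.values()) for name, per_class in splits.items()}
total = sum(sizes.values())

for name, size in sizes.items():
    print(f"{name}: {size} files ({size / total:.1%})")
```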


## Configuration of the models

In [22]:
class CFG:
    # Audio Processing Parameters
    sample_rate = 16000
    yamnet_frame_duration = 0.96    # YAMNet's expected frame duration in seconds
    yamnet_hop_duration = 0.48      # YAMNet's hop duration in seconds
    wav2vec_duration = 10.0         # Duration for Wav2Vec2
    yamnet_duration = 10.0
    # YAMNet specific parameters
    yamnet_samples_per_frame = int(yamnet_frame_duration * sample_rate)  # 15360 samples
    yamnet_hop_samples = int(yamnet_hop_duration * sample_rate)          # 7680 samples
    
    # Spectrogram Parameters (YAMNet specific)
    n_mels = 64                     # YAMNet uses 64 mel bins
    window_size = int(0.025 * sample_rate)  # 25ms window (400 samples at 16kHz)
    hop_length = int(0.01 * sample_rate)    # 10ms hop (160 samples at 16kHz)
    fmin = 125                      # YAMNet's minimum frequency
    fmax = 7500                     # YAMNet's maximum frequency
    
    # Training Parameters
    epochs = 10
    batch_size = 32                 # Reduced from 64 to handle longer sequences
    num_classes = 3                 # Our specific classes
    
    # Label Mapping
    int2label = {
        0: 'infant_cry',
        1: 'scream',
        2: 'normal_utterance'
    }
    
    # Augmentation Parameters
    aug_prob = 0.5
    time_stretch_range = (0.8, 1.2)
    pitch_shift_range = (-2, 2)
    noise_factor = (0.001, 0.015)
    
    # Cross-validation
    n_fold = 5
    seed = 42
    num_workers = 0
    
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    learning_rate = 1e-4
    weight_decay = 1e-6
    class_distribution = [514, 550, 535]
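
The trainer later calls `calculate_class_weights()`, whose body is not shown in this section. One common way to derive loss weights from `class_distribution` is inverse-frequency weighting; the sketch below assumes that scheme, and the helper name `inverse_frequency_weights` is hypothetical rather than the notebook's own method.

```python
def inverse_frequency_weights(class_counts):
    """Hypothetical helper: weight_c = total / (num_classes * count_c)."""
    total = sum(class_counts)
    num_classes = len(class_counts)
    return [total / (num_classes * count) for count in class_counts]


class_distribution = [514, 550, 535]  # same values as CFG.class_distribution
weights = inverse_frequency_weights(class_distribution)
print([round(w, 4) for w in weights])
```

Under this scheme the smallest class (infant_cry) gets the largest weight, and the list could be wrapped in `torch.tensor(weights, dtype=torch.float32)` before being passed to `nn.CrossEntropyLoss`.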

## YAMNet Decoder
* **Specifications:** 
    * Sample Rate: 16 kHz
    * Audio Range: [-1, 1] (normalized)
    * Input Shape: (batch_size, num_samples)
    * Frame Size: 0.96 seconds = 15,360 samples
    * Hop Size: 0.48 seconds = 7,680 samples
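
To make the frame and hop sizes concrete, the following sketch counts how many 0.96 s windows fit in a 10-second clip, both with the 0.48 s hop and without overlap. It is plain sliding-window arithmetic; the number of embeddings YAMNet actually returns may differ slightly because of its internal spectrogram framing.

```python
SAMPLE_RATE = 16_000
FRAME_SAMPLES = 15_360   # 0.96 s window
HOP_SAMPLES = 7_680      # 0.48 s hop


def sliding_window_count(num_samples, frame=FRAME_SAMPLES, hop=HOP_SAMPLES):
    """Number of full windows that fit when sliding by `hop` samples."""
    if num_samples < frame:
        return 0
    return (num_samples - frame) // hop + 1


ten_seconds = 10 * SAMPLE_RATE
print(sliding_window_count(ten_seconds))   # overlapping 0.96 s windows
print(ten_seconds // FRAME_SAMPLES)        # non-overlapping 0.96 s windows
```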

In [23]:
class YAMNetDecoder:
    def __init__(self, cfg):
        self.cfg = cfg
        self.target_length = int(cfg.yamnet_duration * cfg.sample_rate)  # 10 sec * 16000 Hz = 160000 samples
        
    def get_audio(self, file_path: str) -> np.ndarray:
        """Load and normalize audio with fixed 10-second duration"""
        try:
            # Load audio
            audio, sr = librosa.load(
                file_path, 
                sr=self.cfg.sample_rate, 
                mono=True
            )
            
            # Normalize audio
            audio = librosa.util.normalize(audio)
            
            # Handle duration
            if len(audio) < self.target_length:
                # Pad shorter audio with zeros
                audio = np.pad(audio, (0, self.target_length - len(audio)))
            else:
                # For longer audio, take the middle segment
                start = (len(audio) - self.target_length) // 2
                audio = audio[start:start + self.target_length]
            
            return audio
            
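        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            # Fall back to silence so callers always receive a fixed-length array
            return np.zeros(self.target_length, dtype=np.float32)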
    
    def frame_audio(self, audio: np.ndarray) -> np.ndarray:
        """Convert fixed-length audio into frames for YAMNet"""
        # YAMNet analyses audio in 0.96 s windows (15360 samples at 16 kHz)
        frame_length = self.cfg.yamnet_samples_per_frame
        
        # Calculate number of complete frames
        n_frames = self.target_length // frame_length
        
        target_length = n_frames * frame_length
        
        # Pad or trim audio to match target length
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)))
        else:
            audio = audio[:target_length]
        
        # Reshape into frames
        frames = audio.reshape(n_frames, frame_length)
        return frames
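
The length handling and framing in the decoder can be tried on synthetic arrays without any audio files. This is a standalone sketch of the same pad, centre-crop, and reshape logic; `fix_length` and `to_frames` are illustrative names, not part of the notebook.

```python
import numpy as np


def fix_length(audio, target_length):
    """Zero-pad short clips at the end; centre-crop long clips."""
    if len(audio) < target_length:
        return np.pad(audio, (0, target_length - len(audio)))
    start = (len(audio) - target_length) // 2
    return audio[start:start + target_length]


def to_frames(audio, frame_length):
    """Split into non-overlapping frames, dropping the incomplete tail."""
    n_frames = len(audio) // frame_length
    return audio[:n_frames * frame_length].reshape(n_frames, frame_length)


target = 160_000                                   # 10 s at 16 kHz
short_clip = np.ones(48_000, dtype=np.float32)     # 3 s of dummy audio
long_clip = np.arange(200_000, dtype=np.float32)   # 12.5 s of dummy audio

frames = to_frames(fix_length(short_clip, target), 15_360)
cropped = fix_length(long_clip, target)
print(frames.shape, len(cropped))
```

Note that 10 non-overlapping frames cover only 153,600 of the 160,000 samples, so the last 0.4 s of each clip is dropped by the framing step.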

## Data Augmentation for YAMNet
* I apply data augmentation on the fly because:
    1. Storage Efficiency: No need to store additional augmented audio files
    2. Dynamic Variety: Each epoch sees different augmented versions of the same samples
    3. Memory Efficient: Loads and augments data in batches
    4. Training Flexibility: Can easily adjust augmentation parameters

* The following augmentations are planned:
    * **Time Stretching:** Randomly stretches the audio in time.
    * **Pitch Shifting:** Randomly shifts the pitch of the audio.
    * **Gaussian Noise:** Adds Gaussian noise scaled by `CFG.noise_factor`. This step is configured but not called in `apply_augmentations`.


In [24]:
class YAMNetAugmentations:
    def __init__(self, cfg):
        self.cfg = cfg
        
    def apply_augmentations(self, audio: np.ndarray) -> np.ndarray:
        """Apply augmentations to audio"""
        if np.random.rand() < self.cfg.aug_prob:
            audio = audio.copy()  # Create a writable copy
            audio = self.time_stretch(audio)
            audio = self.pitch_shift(audio)
        return audio
    
    def time_stretch(self, audio: np.ndarray) -> np.ndarray:
        try:
            factor = np.random.uniform(*self.cfg.time_stretch_range)
            return librosa.effects.time_stretch(y=audio, rate=factor)
        except Exception as e:
            print(f"Time stretch failed: {str(e)}")
            return audio
    
    def pitch_shift(self, audio: np.ndarray) -> np.ndarray:
        try:
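            # randint excludes the upper bound, so add 1 to make the top of the range reachable
            low, high = self.cfg.pitch_shift_range
            n_steps = np.random.randint(low, high + 1)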
            return librosa.effects.pitch_shift(
                y=audio,
                sr=self.cfg.sample_rate,
                n_steps=n_steps
            )
        except Exception as e:
            print(f"Pitch shift failed: {str(e)}")
            return audio
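
The augmentation list mentions Gaussian noise and `CFG.noise_factor` is defined, but the class above applies only time stretching and pitch shifting. A noise step could look like the sketch below. It assumes `noise_factor` is a (min, max) range for the noise standard deviation; `add_gaussian_noise` is a hypothetical helper, not code from the notebook.

```python
import numpy as np


def add_gaussian_noise(audio, noise_factor=(0.001, 0.015), rng=None):
    """Hypothetical helper: add zero-mean Gaussian noise with a random std."""
    if rng is None:
        rng = np.random.default_rng()
    std = rng.uniform(*noise_factor)
    noise = rng.normal(0.0, std, size=audio.shape).astype(audio.dtype)
    # Clip so the result stays in the [-1, 1] range used for normalized audio
    return np.clip(audio + noise, -1.0, 1.0)


rng = np.random.default_rng(42)
clean = np.zeros(16_000, dtype=np.float32)   # 1 s of silence as dummy input
noisy = add_gaussian_noise(clean, rng=rng)
print(noisy.shape, float(noisy.std()))
```

Unlike time stretching, this step keeps the clip length unchanged, so it could run either before or after the padding and cropping.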

## YAMNet Data Pipeline

In [25]:
class YAMNetDataset(torch.utils.data.Dataset):
    def __init__(self, 
                 df: pd.DataFrame, 
                 cfg: CFG, 
                 train: bool = True):
        """
        Args:
            df: DataFrame with columns ['filepath', 'label']
            cfg: Configuration class
            train: Whether this is training set (for augmentations)
        """
        self.df = df
        self.cfg = cfg
        self.train = train
        self.decoder = YAMNetDecoder(cfg)
        self.augmenter = YAMNetAugmentations(cfg) if train else None

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        try:
            # Get file path and label
            row = self.df.iloc[idx]
            file_path = row['filepath']
            label = row['label']
            
            # Load and normalize audio
            audio = self.decoder.get_audio(file_path)
            
            # Apply augmentations during training
            if self.train and self.augmenter is not None:
                audio = self.augmenter.apply_augmentations(audio)
            
            # Frame audio for YAMNet
            audio_frames = self.decoder.frame_audio(audio)
            # Create a writable copy of the array
            audio_frames = np.array(audio_frames, copy=True)
            
            # Convert to tensors
            input_tensor = torch.from_numpy(audio_frames).float()
            label_tensor = torch.tensor(label, dtype=torch.long)
            
            return {
                'input': input_tensor,      # Shape: (num_frames, 15360)
                'label': label_tensor,      # Shape: () (scalar class index)
            }
            
        except Exception as e:
            # Re-raise so the DataLoader never receives None and fails later in collation
            print(f"Error processing sample {idx}: {str(e)}")
            raise

## Wav2Vec2 Decoder

In [26]:
class Wav2Vec2Decoder:
    def __init__(self, cfg):
        self.cfg = cfg
        
    def get_audio(self, file_path: str) -> np.ndarray:
        """Load and normalize audio"""
        try:
            audio, sr = librosa.load(
                file_path, 
                sr=self.cfg.sample_rate, 
                mono=True
            )
            # Normalize
            audio = librosa.util.normalize(audio)
            return audio
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            return np.zeros(self.cfg.sample_rate)
    
    def process_audio(self, audio: np.ndarray) -> np.ndarray:
        """Process audio for Wav2Vec2"""
        target_length = int(self.cfg.wav2vec_duration * self.cfg.sample_rate)
        
        if len(audio) > target_length:
            # Take center segment
            start = (len(audio) - target_length) // 2
            audio = audio[start:start + target_length]
        else:
            # Pad with zeros
            audio = np.pad(audio, (0, target_length - len(audio)), mode='constant')
            
        return audio  # Shape: (target_length,)

## Wav2Vec2 Augmentations

In [27]:
class Wav2Vec2Augmentations:
    def __init__(self, cfg):
        self.cfg = cfg
        
    def apply_augmentations(self, audio: np.ndarray) -> np.ndarray:
        """Apply augmentations to audio"""
        if np.random.rand() < self.cfg.aug_prob:
            # Apply augmentations
            audio = self.time_stretch(audio)
            audio = self.pitch_shift(audio)
        return audio
    
    def time_stretch(self, audio: np.ndarray) -> np.ndarray:
        """Apply time stretching to audio"""
        try:
            factor = np.random.uniform(*self.cfg.time_stretch_range)
            return librosa.effects.time_stretch(y=audio, rate=factor)
        except Exception as e:
            print(f"Time stretch failed: {str(e)}")
            return audio
    
    def pitch_shift(self, audio: np.ndarray) -> np.ndarray:
        """Apply pitch shifting to audio"""
        try:
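            # randint excludes the upper bound, so add 1 to make the top of the range reachable
            low, high = self.cfg.pitch_shift_range
            n_steps = np.random.randint(low, high + 1)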
            return librosa.effects.pitch_shift(
                y=audio, 
                sr=self.cfg.sample_rate, 
                n_steps=n_steps
            )
        except Exception as e:
            print(f"Pitch shift failed: {str(e)}")
            return audio

## Wav2Vec2 Data Pipeline

In [28]:
class Wav2Vec2Dataset(torch.utils.data.Dataset):
    def __init__(self, 
                 df: pd.DataFrame, 
                 cfg: CFG, 
                 train: bool = True):
        """
        Args:
            df: DataFrame with columns ['filepath', 'label']
            cfg: Configuration class
            train: Whether this is training set (for augmentations)
        """
        self.df = df
        self.cfg = cfg
        self.train = train
        self.decoder = Wav2Vec2Decoder(cfg)
        self.augmenter = Wav2Vec2Augmentations(cfg) if train else None

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        try:
            # Get file path and label
            row = self.df.iloc[idx]
            file_path = row['filepath']
            label = row['label']
            
            # Load and normalize audio
            audio = self.decoder.get_audio(file_path)
            # Create a writable copy
            audio = np.array(audio, copy=True)
            # Apply augmentations during training
            if self.train and self.augmenter is not None:
                audio = self.augmenter.apply_augmentations(audio)
            
            # Process audio for Wav2Vec2
            audio = self.decoder.process_audio(audio)
            
            # Convert to tensors
            input_tensor = torch.from_numpy(audio).float()
            label_tensor = torch.tensor(label, dtype=torch.long)
            
            return input_tensor, label_tensor
            
        except Exception as e:
            # Re-raise so the DataLoader never receives None and fails later in collation
            print(f"Error processing sample {idx}: {str(e)}")
            raise

## YAMNet Model Training

In [29]:
import torch
import torch.nn as nn
import tensorflow_hub as hub
import tensorflow as tf

class YAMNetFeatureExtractor:
    """Efficient YAMNet feature extraction for batched PyTorch tensors."""
    def __init__(self):
        self.model = hub.load('https://tfhub.dev/google/yamnet/1')

    def extract_features(self, waveform_batch):
        """Extracts YAMNet embeddings efficiently for a batch of PyTorch tensors."""
        with torch.no_grad():
            batch_size = waveform_batch.shape[0]
            embeddings_list = []
            
            for i in range(batch_size):
                # Get single waveform and flatten it to 1D
                single_waveform = waveform_batch[i].cpu().flatten().numpy()
                
                # Extract features using YAMNet (expects 1D input)
                _, embeddings, _ = self.model(single_waveform)
                embeddings_list.append(embeddings.numpy())
            
            # Stack embeddings and convert to PyTorch tensor
            embeddings_torch = torch.tensor(np.stack(embeddings_list), dtype=torch.float32)
            return embeddings_torch.to(waveform_batch.device)

class YAMNetFineTuner(nn.Module):
    def __init__(self, num_classes, cfg):
        super().__init__()
        self.cfg = cfg
        self.yamnet_extractor = YAMNetFeatureExtractor()
        
        # Define trainable classifier layers with proper BatchNorm momentum
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512, momentum=0.01),  # Lower momentum for stable training
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Initialize classifier weights with proper scaling"""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                # Use Xavier initialization for better gradient flow
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Input validation
        if x.dim() != 3:  # (batch, frames, samples)
            raise ValueError(f"Expected 3D input (batch, frames, samples), got shape {x.shape}")
            
        # Extract features using frozen YAMNet
        features = self.yamnet_extractor.extract_features(x)  # (batch, frames, 1024)
        
        # Global average pooling with safe handling
        if features.dim() == 3:
            pooled = torch.mean(features, dim=1)  # (batch, 1024)
        else:
            raise ValueError(f"Expected 3D features, got shape {features.shape}")
        
        # Classification
        output = self.classifier(pooled)
        return output

yamnet = YAMNetFineTuner(num_classes=CFG.num_classes, cfg=CFG).to(CFG.device)
print(yamnet)

YAMNetFineTuner(
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=256, out_features=3, bias=True)
  )
)


## Wav2Vec2 Model

In [30]:
class Wav2Vec2FineTuner(nn.Module):
    def __init__(self, num_classes, pretrained=True, freeze_encoder=True):
        super().__init__()
        
        # Load pretrained Wav2Vec2 model
        if pretrained:
            self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        else:
            config = Wav2Vec2Config()
            self.wav2vec2 = Wav2Vec2Model(config)
        
        # Initial freezing
        if freeze_encoder:
            self.freeze_feature_extractor()
            self.freeze_encoder_layers(12)  # Freeze all layers initially
        
        hidden_size = self.wav2vec2.config.hidden_size
        
        # Improved classifier architecture
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
        self._initialize_classifier()
    
    def _initialize_classifier(self):
        """Initialize classifier weights properly"""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                # Xavier uniform initialization with unit gain
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def freeze_feature_extractor(self):
        """Freeze the CNN feature extractor"""
        for param in self.wav2vec2.feature_extractor.parameters():
            param.requires_grad = False
    
    def freeze_encoder_layers(self, num_layers):
        """Freeze specified number of transformer layers"""
        for layer in self.wav2vec2.encoder.layers[:num_layers]:
            for param in layer.parameters():
                param.requires_grad = False
    
    def unfreeze_encoder_layers(self, num_layers_to_unfreeze):
        """Unfreeze the top `num_layers_to_unfreeze` transformer layers"""
        for i, layer in enumerate(reversed(self.wav2vec2.encoder.layers)):
            if i < num_layers_to_unfreeze:
                for param in layer.parameters():
                    param.requires_grad = True

    def forward(self, x):
        # Input validation
        if x.dim() != 2:  # (batch_size, sequence_length)
            raise ValueError(f"Expected input shape (batch_size, sequence_length), got {x.shape}")

        # Wav2Vec2 forward pass
        outputs = self.wav2vec2(x)
        hidden_states = outputs.last_hidden_state
        
        # Global mean pooling
        pooled = torch.mean(hidden_states, dim=1)
        
        # Classification
        output = self.classifier(pooled)
        return output

# Initialize model
wav2vec2 = Wav2Vec2FineTuner(num_classes=3).to(CFG.device)
print(wav2vec2)



Wav2Vec2FineTuner(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encod
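
The kernel sizes and strides printed for the feature encoder determine how many hidden states Wav2Vec2 produces per clip, which is the sequence that `forward` averages over. The sketch below applies the standard unpadded 1-D convolution length formula to the seven layers listed above; it is an arithmetic illustration and does not load the model.

```python
# (kernel_size, stride) for the seven Conv1d layers in the printed feature encoder
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def encoder_output_length(num_samples, layers=CONV_LAYERS):
    """Apply L_out = (L_in - kernel) // stride + 1 for each unpadded conv layer."""
    length = num_samples
    for kernel, stride in layers:
        length = (length - kernel) // stride + 1
    return length


ten_seconds = 10 * 16_000
print(encoder_output_length(ten_seconds))
```

The combined stride is 5 × 2⁶ = 320 samples, so each hidden state covers a step of 20 ms at 16 kHz.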

## Training
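
For Wav2Vec2, the trainer in this section splits the epoch budget into three phases and unfreezes two more encoder layers per phase (`phase * 2`). The sketch below reproduces that schedule as plain arithmetic, assuming `CFG.epochs = 10` is passed as `num_epochs`; `unfreezing_schedule` is an illustrative helper, not part of the trainer.

```python
def unfreezing_schedule(num_epochs, num_phases=3, layers_per_phase=2):
    """Return (phase, epochs_in_phase, unfrozen_top_layers) tuples."""
    epochs_per_phase = max(1, num_epochs // num_phases)
    return [
        (phase, epochs_per_phase, phase * layers_per_phase)
        for phase in range(num_phases)
    ]


for phase, epochs, unfrozen in unfreezing_schedule(10):
    print(f"phase {phase + 1}: {epochs} epochs, top {unfrozen} encoder layers trainable")
```

Because of the integer division, a budget of 10 epochs results in 9 epochs actually being run, and at most 4 of the 12 encoder layers are unfrozen by the final phase.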

In [31]:
class ModelTrainer:
    def __init__(self, model_type, cfg):
        """Initialize trainer with model type and config"""
        self.model_type = model_type
        self.cfg = cfg
        self.device = cfg.device
        self.k_folds = cfg.n_fold
        self.current_phase = 0
        
        # Create directories for logs and plots
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.run_dir = os.path.join('runs', f"{model_type}_{timestamp}")
        self.log_dir = os.path.join(self.run_dir, 'logs')
        self.plot_dir = os.path.join(self.run_dir, 'plots')
        
        for dir_path in [self.run_dir, self.log_dir, self.plot_dir]:
            os.makedirs(dir_path, exist_ok=True)
            
        # Initialize model and dataset class
        self.init_model()
        self.dataset_class = YAMNetDataset if model_type == 'yamnet' else Wav2Vec2Dataset

        # Initialize metrics tracking
        self.best_metrics = None
        self.fold_predictions = []
        self.training_history = []
        
    def init_model(self):
        """Initialize model based on type"""
        if self.model_type == 'yamnet':
            self.model = YAMNetFineTuner(
                num_classes=self.cfg.num_classes,
                cfg=self.cfg
            ).to(self.device)
        else:
            self.model = Wav2Vec2FineTuner(
                num_classes=self.cfg.num_classes,
                pretrained=True,
                freeze_encoder=True
            ).to(self.device)
    
    def setup_training(self, phase=0):
        """Setup training components for current phase"""
        # Calculate class weights
        class_weights = self.calculate_class_weights()
        self.criterion = nn.CrossEntropyLoss(
            weight=class_weights.to(self.device)
        )
        
        # Progressively unfreeze the top encoder layers (Wav2Vec2 only)
        if self.model_type == 'wav2vec2':
            if phase > 0:
                num_layers = phase * 2  # Unfreeze 2 more layers per phase
                self.model.unfreeze_encoder_layers(num_layers)
        
        # Setup optimizer with appropriate parameters
        self.optimizer = AdamW(
            self.get_param_groups(phase),
            weight_decay=self.cfg.weight_decay
        )
        
        # Setup scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=2,
            verbose=True
        )
        
        self.early_stopping = EarlyStopping(
            patience=5,
            min_delta=1e-4
        )
    
    def get_param_groups(self, phase):
        """Get parameter groups with different learning rates"""
        if self.model_type == 'yamnet':
            return [
                {
                    'params': self.model.classifier.parameters(),
                    'lr': self.cfg.learning_rate
                }
            ]
        else:  # wav2vec2
            params = [
                {
                    'params': self.model.classifier.parameters(),
                    'lr': self.cfg.learning_rate
                }
            ]
            if phase > 0:
                # Only pass encoder parameters that are currently unfrozen
                params.append({
                    'params': [p for p in self.model.wav2vec2.parameters() if p.requires_grad],
                    'lr': self.cfg.learning_rate * 0.1
                })
            return params
    
    def train(self, train_loader, val_loader, num_epochs, fold):
        """Training loop with progressive unfreezing"""
        print(f"Training on {self.device}")
        best_metrics = None
        best_val_loss = float('inf')
        
        # Define training phases
        num_phases = 1 if self.model_type == 'yamnet' else 3
        epochs_per_phase = max(1,num_epochs // num_phases)
        
        for phase in range(num_phases):
            print(f"\nPhase {phase + 1}/{num_phases}")
            self.setup_training(phase)
            
            for epoch in range(epochs_per_phase):
                print(f'\nEpoch {epoch + 1}/{epochs_per_phase}')
                
                # Training phase
                train_loss, train_preds, train_labels, train_probs = self.train_epoch(train_loader)
                
                # Validation phase
                val_loss, val_preds, val_labels, val_probs = self.validate(val_loader)
                
                # Compute metrics with probabilities
                train_metrics = self.compute_metrics(train_labels, train_preds, train_probs)
                val_metrics = self.compute_metrics(val_labels, val_preds, val_probs)
                
                # Log and save metrics
                self.log_metrics(
                    epoch + phase * epochs_per_phase,
                    train_metrics,
                    val_metrics,
                    train_loss,
                    val_loss,
                    fold
                )
                self.save_metrics(val_metrics, fold)
                
                # Learning rate scheduling
                self.scheduler.step(val_loss)
                
                # Early stopping check
                if self.early_stopping.step(val_loss):
                    print('Early stopping triggered')
                    break
                
                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_metrics = {
                        'val_loss': val_loss,
                        'val_accuracy': val_metrics['basic']['report']['accuracy'],
                        'val_f1': val_metrics['basic']['report']['weighted avg']['f1-score']
                    }
                    os.makedirs('models', exist_ok=True)  # make sure the checkpoint directory exists
                    self.save_model(f'models/best_{self.model_type}_fold_{fold}_model.pt')
        
        return best_metrics

    def train_epoch(self, train_loader):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        predictions = []
        probabilities = []
        true_labels = []
        
        pbar = tqdm(train_loader, desc='Training')
        for batch in pbar:
            if self.model_type == 'yamnet':
                inputs, labels = batch['input'].to(self.device), batch['label'].to(self.device)
            else:
                inputs, labels = batch[0].to(self.device), batch[1].to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            
            # Get probabilities using softmax
            probs = F.softmax(outputs, dim=1).detach().cpu().numpy()
            preds = outputs.argmax(dim=1).cpu().numpy().tolist()
            
            total_loss += loss.item()
            predictions.extend(preds)
            probabilities.extend(probs)
            true_labels.extend(labels.cpu().numpy().tolist())
            
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        return (
            total_loss / len(train_loader),
            np.array(predictions, dtype=np.int32),  # Ensure integer type
            np.array(true_labels, dtype=np.int32),  # Ensure integer type
            np.array(probabilities)
        )
    
    @torch.no_grad()
    def validate(self, val_loader):
        """Validation loop"""
        self.model.eval()
        total_loss = 0
        predictions = []
        probabilities = []
        true_labels = []
        
        for batch in tqdm(val_loader, desc='Validation'):
            if self.model_type == 'yamnet':
                inputs, labels = batch['input'].to(self.device), batch['label'].to(self.device)
            else:
                inputs, labels = batch[0].to(self.device), batch[1].to(self.device)
            
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            
            # Get probabilities using softmax
            probs = F.softmax(outputs, dim=1).cpu().numpy()
            preds = outputs.argmax(dim=1).cpu().numpy()
            labels = labels.cpu().numpy()
            
            total_loss += loss.item()
            predictions.extend(preds)
            probabilities.extend(probs)
            true_labels.extend(labels)
        
        return (
            total_loss / len(val_loader),
            np.array(predictions, dtype=np.int32),
            np.array(true_labels, dtype=np.int32),
            np.array(probabilities)
        )

    def cross_validate(self, full_df):
        """Perform k-fold cross-validation"""
        try:
            if full_df is None or len(full_df) == 0:
                raise ValueError("Empty or invalid dataset provided")
                
            kfold = KFold(n_splits=self.k_folds, shuffle=True, random_state=self.cfg.seed)
            fold_metrics = []
            
            for fold, (train_idx, val_idx) in enumerate(kfold.split(full_df)):
                print(f'\nFold {fold + 1}/{self.k_folds}')
                
                try:
                    # Create datasets
                    train_fold_df = full_df.iloc[train_idx].reset_index(drop=True)
                    val_fold_df = full_df.iloc[val_idx].reset_index(drop=True)
                    
                    train_dataset = self.dataset_class(train_fold_df, self.cfg, train=True)
                    val_dataset = self.dataset_class(val_fold_df, self.cfg, train=False)
                    
                    train_loader = DataLoader(
                        train_dataset,
                        batch_size=self.cfg.batch_size,
                        shuffle=True,
                        num_workers=self.cfg.num_workers,
                        pin_memory=True
                    )
                    
                    val_loader = DataLoader(
                        val_dataset,
                        batch_size=self.cfg.batch_size,
                        shuffle=False,
                        num_workers=self.cfg.num_workers,
                        pin_memory=True
                    )
                    
                    # Reset model
                    self.init_model()
                    
                    # Train fold
                    best_metrics = self.train(train_loader, val_loader, self.cfg.epochs, fold)
                    if best_metrics:
                        fold_metrics.append({
                            'fold': fold + 1,
                            **best_metrics
                        })
                    
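                        # Re-run validation to collect predictions and probabilities.
                        # These come from the weights left after the last epoch, which may
                        # differ from the best checkpoint saved during training.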
                        val_loss, val_preds, val_labels, val_probs = self.validate(val_loader)
                        val_metrics = self.compute_metrics(val_labels, val_preds, val_probs)
                        self.save_metrics(val_metrics, fold)
                        
                        # Generate plots
                        self.plot_training_history(fold)
                        self.plot_confusion_matrix(val_metrics['basic']['confusion_matrix'], fold)
                        if 'curves' in val_metrics:
                            self.plot_roc_curves(val_metrics, fold)
                            
                except Exception as e:
                    print(f"Error in fold {fold + 1}: {str(e)}")
                    continue
                    
            if fold_metrics:
                self.log_cv_results(fold_metrics)
            else:
                print("No valid metrics collected during cross-validation")
                
        except Exception as e:
            print(f"Error in cross-validation: {str(e)}")

    def calculate_class_weights(self):
        """Calculate normalized inverse-frequency class weights.

        Entry i of cfg.class_distribution must be the sample count for class i.
        """
        if hasattr(self.cfg, 'class_distribution'):
            # If distribution is provided in config
            class_counts = torch.tensor(self.cfg.class_distribution, dtype=torch.float)
        else:
            # Default to balanced weights
            class_counts = torch.ones(self.cfg.num_classes)
        
        weights = 1.0 / class_counts
        weights = weights / weights.sum()
        return weights
    
    
    def compute_metrics(self, true_labels, predictions, probabilities=None):
        """Compute comprehensive performance metrics"""
        # Convert lists to numpy arrays
        true_labels = np.array(true_labels, dtype=np.int32)
        predictions = np.array(predictions, dtype=np.int32)
        
        metrics = {
            'basic': {
                'report': classification_report(true_labels, predictions, 
                                             output_dict=True, 
                                             zero_division=0),
                'confusion_matrix': confusion_matrix(true_labels, predictions).tolist()
            }
        }
        
        if probabilities is not None:
            probabilities = np.array(probabilities)
            metrics['curves'] = {}
            
            for i in range(self.cfg.num_classes):
                try:
                    # Create binary labels for current class
                    binary_labels = (true_labels == i).astype(int)
                    class_probs = probabilities[:, i]
                    
                    # ROC curve
                    fpr, tpr, _ = roc_curve(binary_labels, class_probs)
                    roc_auc = float(auc(fpr, tpr))
                    metrics['curves'][f'class_{i}_roc'] = {
                        'fpr': fpr.tolist(),
                        'tpr': tpr.tolist(),
                        'auc': roc_auc
                    }
                    
                    # PR curve
                    precision, recall, _ = precision_recall_curve(binary_labels, class_probs)
                    metrics['curves'][f'class_{i}_pr'] = {
                        'precision': precision.tolist(),
                        'recall': recall.tolist()
                    }
                except Exception as e:
                    print(f"Warning: Could not compute curves for class {i}: {str(e)}")
                    continue
            
        # Convert any remaining numpy numbers to Python types
        metrics = self._convert_numpy_to_python(metrics)
        return metrics
    
    def _convert_numpy_to_python(self, obj):
        """Convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: self._convert_numpy_to_python(value) for key, value in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self._convert_numpy_to_python(item) for item in obj]
        return obj
    
    def save_metrics(self, metrics_log, fold):
        """Save metrics to a JSON file named after the zero-based fold index"""
        try:
            # Convert metrics to JSON-serializable format
            metrics_log = self._convert_numpy_to_python(metrics_log)
            
            # Save to file
            metrics_file = os.path.join(self.log_dir, f'metrics_fold_{fold}.json')
            with open(metrics_file, 'w') as f:
                json.dump(metrics_log, f, indent=4)
        except Exception as e:
            print(f"Warning: Could not save metrics for fold {fold}: {str(e)}")

    def plot_training_history(self, fold):
        """Plot training history for current fold"""
        try:
            history_df = pd.DataFrame(self.training_history)
            if len(history_df) == 0:
                print("Warning: No training history to plot")
                return
                
            # Loss plot
            plt.figure(figsize=(10, 5))
            plt.plot(history_df['train_loss'], label='Train Loss')
            plt.plot(history_df['val_loss'], label='Validation Loss')
            plt.title(f'Loss History - Fold {fold+1}')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.savefig(os.path.join(self.plot_dir, f'loss_fold_{fold+1}.png'))
            plt.close()
            
            # Accuracy plot
            plt.figure(figsize=(10, 5))
            plt.plot(history_df['train_accuracy'], label='Train Accuracy')
            plt.plot(history_df['val_accuracy'], label='Validation Accuracy')
            plt.title(f'Accuracy History - Fold {fold+1}')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy') 
            plt.legend()
            plt.savefig(os.path.join(self.plot_dir, f'accuracy_fold_{fold+1}.png'))
            plt.close()
        except Exception as e:
            print(f"Error plotting training history: {str(e)}")

    def plot_confusion_matrix(self, cm, fold):
        """Plot confusion matrix heatmap"""
        try:
            # Convert to numpy array if not already
            cm = np.array(cm) if isinstance(cm, list) else cm
            
            if not isinstance(cm, np.ndarray):
                print("Warning: Invalid confusion matrix data")
                return
                
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
            plt.title(f'Confusion Matrix - Fold {fold+1}')
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.savefig(os.path.join(self.plot_dir, f'confusion_matrix_fold_{fold+1}.png'))
            plt.close()
        except Exception as e:
            print(f"Error plotting confusion matrix: {str(e)}")
    
    def plot_roc_curves(self, metrics, fold):
        """Plot ROC curves for each class"""
        try:
            if 'curves' not in metrics:
                print("Warning: No ROC curve data available")
                return
                
            plt.figure(figsize=(10, 8))
            for i in range(self.cfg.num_classes):
                curve_key = f'class_{i}_roc'
                if curve_key not in metrics['curves']:
                    continue
                    
                curve_data = metrics['curves'][curve_key]
                plt.plot(
                    curve_data['fpr'],
                    curve_data['tpr'],
                    label=f'Class {i} (AUC = {curve_data["auc"]:.2f})'
                )
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curves - Fold {fold+1}')
            plt.legend()
            plt.savefig(os.path.join(self.plot_dir, f'roc_curves_fold_{fold+1}.png'))
            plt.close()
        except Exception as e:
            print(f"Error plotting ROC curves: {str(e)}")
    
    def log_metrics(self, epoch, train_metrics, val_metrics, train_loss, val_loss, fold):
        """Log metrics and update history"""
        metrics = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_accuracy': train_metrics['basic']['report']['accuracy'],
            'val_accuracy': val_metrics['basic']['report']['accuracy'],
            'train_f1_weighted': train_metrics['basic']['report']['weighted avg']['f1-score'],
            'val_f1_weighted': val_metrics['basic']['report']['weighted avg']['f1-score'],
            'learning_rate': self.optimizer.param_groups[0]['lr']
        }
        
        self.training_history.append(metrics)
        
        # Print current metrics
        print(f"\nEpoch {epoch+1}")
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {metrics['train_accuracy']:.4f}, Val Acc: {metrics['val_accuracy']:.4f}")
        print(f"Train F1: {metrics['train_f1_weighted']:.4f}, Val F1: {metrics['val_f1_weighted']:.4f}")
        
        return metrics
    
    def log_cv_results(self, fold_metrics):
        """Log cross-validation results"""
        metrics = ['val_loss', 'val_accuracy', 'val_f1']
        cv_results = {}
        
        for metric in metrics:
            values = [m[metric] for m in fold_metrics]
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f'CV {metric}: {mean_val:.4f} ± {std_val:.4f}')
            cv_results[f'cv_mean_{metric}'] = mean_val
            cv_results[f'cv_std_{metric}'] = std_val
        
        # Save CV results
        with open(os.path.join(self.log_dir, 'cv_results.json'), 'w') as f:
            json.dump(cv_results, f, indent=4)
    
    def save_model(self, path):
        """Save model checkpoint"""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.cfg
        }, path)

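class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` consecutive checks."""

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        
    def step(self, val_loss):
        """Record `val_loss` and return True when training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            # No improvement of at least min_delta over the best loss so far
            self.counter += 1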
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return False
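
To see how the patience counter behaves, here is a small self-contained sketch. It repeats the `EarlyStopping` logic from above so that it runs on its own. The helper `first_stop_epoch`, the loss values, and the `patience=2` setting are made up for illustration and are not taken from the training runs.

```python
class EarlyStopping:
    """Same logic as the class above, repeated so this sketch runs standalone."""

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None

    def step(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return False


def first_stop_epoch(losses, patience, min_delta=0.0):
    """Hypothetical helper: 1-based epoch at which step() first returns True, else None."""
    stopper = EarlyStopping(patience=patience, min_delta=min_delta)
    for epoch, loss in enumerate(losses, start=1):
        if stopper.step(loss):
            return epoch
    return None


# Illustrative validation losses: improvement stalls after epoch 3.
illustrative_losses = [0.80, 0.50, 0.30, 0.31, 0.32, 0.29]
stop_epoch = first_stop_epoch(illustrative_losses, patience=2)
print(stop_epoch)  # 5: epochs 4 and 5 both fail to beat the best loss of 0.30
```

With `patience=2` the stop fires at epoch 5, so the better loss at epoch 6 is never seen. A larger `patience` trades longer training for tolerance of short plateaus like this one.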

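The arithmetic in `calculate_class_weights` can be checked without PyTorch. The sketch below is a plain-Python rendering of the same inverse-frequency formula, not the trainer's own code. The helper names are hypothetical, and the class ids (0, 1, 2) and their mapping to the file counts from the dataset section are assumptions made for illustration.

```python
from collections import Counter


def counts_by_class_index(labels, num_classes):
    """Count samples per class so that position i holds the count for class i."""
    tally = Counter(labels)
    return [tally.get(class_index, 0) for class_index in range(num_classes)]


def inverse_frequency_weights(class_counts):
    """1 / count per class, rescaled so the weights sum to 1 (counts must be non-zero)."""
    inverse = [1.0 / count for count in class_counts]
    total = sum(inverse)
    return [value / total for value in inverse]


# Illustrative labels only: 520, 550 and 535 samples assigned to assumed class ids 0, 1, 2.
labels = [0] * 520 + [1] * 550 + [2] * 535
counts = counts_by_class_index(labels, num_classes=3)
weights = inverse_frequency_weights(counts)

print(counts)
print([round(weight, 4) for weight in weights])
```

The smallest class receives the largest weight, and the weights only mean something if entry `i` of the counts belongs to class `i`. pandas `value_counts()` sorts by frequency by default, so counts taken from it need to be reordered by label (for example with `sort_index()`) before they are used as `class_distribution`.
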
In [32]:
# Combine train and validation dataframes
full_df = pd.concat([train_df, val_df], ignore_index=True)

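# Update class distribution in config.
# sort_index() orders the counts by label so that entry i is the count of class i
# (this assumes 'label' holds integer class ids); value_counts() alone sorts by frequency.
CFG.class_distribution = full_df['label'].value_counts().sort_index().tolist()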

# Train models
yamnet_trainer = ModelTrainer('yamnet', CFG)
yamnet_trainer.cross_validate(full_df)


Fold 1/5
Training on cuda

Phase 1/1

Epoch 1/10


Epoch 1
Train Loss: 0.4093, Val Loss: 0.7723
Train Acc: 0.8537, Val Acc: 0.6360
Train F1: 0.8536, Val F1: 0.5358

Epoch 2/10


Epoch 2
Train Loss: 0.1897, Val Loss: 0.4102
Train Acc: 0.9246, Val Acc: 0.9044
Train F1: 0.9243, Val F1: 0.9052

Epoch 3/10


Epoch 3
Train Loss: 0.1551, Val Loss: 0.2377
Train Acc: 0.9448, Val Acc: 0.9522
Train F1: 0.9446, Val F1: 0.9526

Epoch 4/10


Epoch 4
Train Loss: 0.1158, Val Loss: 0.1484
Train Acc: 0.9623, Val Acc: 0.9706
Train F1: 0.9622, Val F1: 0.9706

Epoch 5/10


Epoch 5
Train Loss: 0.1080, Val Loss: 0.1322
Train Acc: 0.9632, Val Acc: 0.9669
Train F1: 0.9631, Val F1: 0.9673

Epoch 6/10


Epoch 6
Train Loss: 0.0949, Val Loss: 0.0922
Train Acc: 0.9669, Val Acc: 0.9853
Train F1: 0.9668, Val F1: 0.9854

Epoch 7/10


Epoch 7
Train Loss: 0.0805, Val Loss: 0.0778
Train Acc: 0.9724, Val Acc: 0.9779
Train F1: 0.9723, Val F1: 0.9779

Epoch 8/10


Epoch 8
Train Loss: 0.0968, Val Loss: 0.0536
Train Acc: 0.9696, Val Acc: 0.9890
Train F1: 0.9696, Val F1: 0.9890

Epoch 9/10


Epoch 9
Train Loss: 0.0733, Val Loss: 0.0427
Train Acc: 0.9779, Val Acc: 0.9926
Train F1: 0.9779, Val F1: 0.9926

Epoch 10/10


Epoch 10
Train Loss: 0.0805, Val Loss: 0.0262
Train Acc: 0.9752, Val Acc: 0.9963
Train F1: 0.9751, Val F1: 0.9963


Fold 2/5
Training on cuda

Phase 1/1

Epoch 1/10


Epoch 1
Train Loss: 0.3572, Val Loss: 0.8331
Train Acc: 0.8629, Val Acc: 0.6176
Train F1: 0.8618, Val F1: 0.5129

Epoch 2/10


Epoch 2
Train Loss: 0.1827, Val Loss: 0.5124
Train Acc: 0.9365, Val Acc: 0.7831
Train F1: 0.9363, Val F1: 0.7750

Epoch 3/10


Epoch 3
Train Loss: 0.1339, Val Loss: 0.2589
Train Acc: 0.9531, Val Acc: 0.9375
Train F1: 0.9529, Val F1: 0.9379

Epoch 4/10


Epoch 4
Train Loss: 0.1258, Val Loss: 0.1856
Train Acc: 0.9623, Val Acc: 0.9412
Train F1: 0.9621, Val F1: 0.9412

Epoch 5/10


Epoch 5
Train Loss: 0.1159, Val Loss: 0.1647
Train Acc: 0.9614, Val Acc: 0.9669
Train F1: 0.9612, Val F1: 0.9670

Epoch 6/10


Epoch 6
Train Loss: 0.1127, Val Loss: 0.1002
Train Acc: 0.9623, Val Acc: 0.9743
Train F1: 0.9622, Val F1: 0.9743

Epoch 7/10


Epoch 7
Train Loss: 0.0917, Val Loss: 0.0776
Train Acc: 0.9678, Val Acc: 0.9743
Train F1: 0.9677, Val F1: 0.9743

Epoch 8/10


Epoch 8
Train Loss: 0.0997, Val Loss: 0.0656
Train Acc: 0.9660, Val Acc: 0.9816
Train F1: 0.9659, Val F1: 0.9816

Epoch 9/10


Epoch 9
Train Loss: 0.1045, Val Loss: 0.0993
Train Acc: 0.9669, Val Acc: 0.9779
Train F1: 0.9669, Val F1: 0.9779

Epoch 10/10


Epoch 10
Train Loss: 0.0799, Val Loss: 0.0839
Train Acc: 0.9752, Val Acc: 0.9816
Train F1: 0.9751, Val F1: 0.9817


Fold 3/5
Training on cuda

Phase 1/1

Epoch 1/10


Epoch 1
Train Loss: 0.4159, Val Loss: 0.8295
Train Acc: 0.8381, Val Acc: 0.6176
Train F1: 0.8363, Val F1: 0.5200

Epoch 2/10


Epoch 2
Train Loss: 0.1736, Val Loss: 0.5336
Train Acc: 0.9512, Val Acc: 0.8015
Train F1: 0.9510, Val F1: 0.7979

Epoch 3/10


Epoch 3
Train Loss: 0.1546, Val Loss: 0.2636
Train Acc: 0.9503, Val Acc: 0.9301
Train F1: 0.9502, Val F1: 0.9305

Epoch 4/10


Epoch 4
Train Loss: 0.1251, Val Loss: 0.1929
Train Acc: 0.9586, Val Acc: 0.9559
Train F1: 0.9584, Val F1: 0.9559

Epoch 5/10


Epoch 5
Train Loss: 0.0986, Val Loss: 0.1541
Train Acc: 0.9678, Val Acc: 0.9559
Train F1: 0.9677, Val F1: 0.9558

Epoch 6/10


Epoch 6
Train Loss: 0.0953, Val Loss: 0.0919
Train Acc: 0.9669, Val Acc: 0.9779
Train F1: 0.9669, Val F1: 0.9778

Epoch 7/10


Epoch 7
Train Loss: 0.0927, Val Loss: 0.0744
Train Acc: 0.9706, Val Acc: 0.9816
Train F1: 0.9705, Val F1: 0.9815

Epoch 8/10


Epoch 8
Train Loss: 0.0989, Val Loss: 0.0636
Train Acc: 0.9733, Val Acc: 0.9743
Train F1: 0.9733, Val F1: 0.9741

Epoch 9/10


Epoch 9
Train Loss: 0.1004, Val Loss: 0.0539
Train Acc: 0.9660, Val Acc: 0.9816
Train F1: 0.9660, Val F1: 0.9815

Epoch 10/10


Epoch 10
Train Loss: 0.0686, Val Loss: 0.0554
Train Acc: 0.9770, Val Acc: 0.9816
Train F1: 0.9770, Val F1: 0.9815


Fold 4/5
Training on cuda

Phase 1/1

Epoch 1/10


Epoch 1
Train Loss: 0.4091, Val Loss: 0.8090
Train Acc: 0.8381, Val Acc: 0.6691
Train F1: 0.8384, Val F1: 0.5825

Epoch 2/10


Epoch 2
Train Loss: 0.2000, Val Loss: 0.4754
Train Acc: 0.9282, Val Acc: 0.8456
Train F1: 0.9279, Val F1: 0.8448

Epoch 3/10


Epoch 3
Train Loss: 0.1632, Val Loss: 0.2481
Train Acc: 0.9448, Val Acc: 0.9338
Train F1: 0.9446, Val F1: 0.9342

Epoch 4/10


Epoch 4
Train Loss: 0.1318, Val Loss: 0.1526
Train Acc: 0.9558, Val Acc: 0.9596
Train F1: 0.9557, Val F1: 0.9594

Epoch 5/10


Epoch 5
Train Loss: 0.1147, Val Loss: 0.1136
Train Acc: 0.9650, Val Acc: 0.9743
Train F1: 0.9649, Val F1: 0.9742

Epoch 6/10


Epoch 6
Train Loss: 0.1041, Val Loss: 0.0881
Train Acc: 0.9604, Val Acc: 0.9743
Train F1: 0.9602, Val F1: 0.9742

Epoch 7/10


Epoch 7
Train Loss: 0.0899, Val Loss: 0.0753
Train Acc: 0.9696, Val Acc: 0.9706
Train F1: 0.9696, Val F1: 0.9705

Epoch 8/10


Epoch 8
Train Loss: 0.0972, Val Loss: 0.0670
Train Acc: 0.9660, Val Acc: 0.9926
Train F1: 0.9659, Val F1: 0.9926

Epoch 9/10


Epoch 9
Train Loss: 0.0670, Val Loss: 0.0588
Train Acc: 0.9761, Val Acc: 0.9816
Train F1: 0.9761, Val F1: 0.9816

Epoch 10/10


Epoch 10
Train Loss: 0.0792, Val Loss: 0.0723
Train Acc: 0.9742, Val Acc: 0.9853
Train F1: 0.9742, Val F1: 0.9853


Fold 5/5
Training on cuda

Phase 1/1

Epoch 1/10


Epoch 1
Train Loss: 0.4189, Val Loss: 0.7711
Train Acc: 0.8575, Val Acc: 0.6863
Train F1: 0.8574, Val F1: 0.5927

Epoch 2/10


Epoch 2
Train Loss: 0.2259, Val Loss: 0.4576
Train Acc: 0.9228, Val Acc: 0.8708
Train F1: 0.9226, Val F1: 0.8679

Epoch 3/10


Epoch 3
Train Loss: 0.1345, Val Loss: 0.2285
Train Acc: 0.9522, Val Acc: 0.9410
Train F1: 0.9521, Val F1: 0.9410

Epoch 4/10


Epoch 4
Train Loss: 0.1166, Val Loss: 0.1379
Train Acc: 0.9577, Val Acc: 0.9705
Train F1: 0.9575, Val F1: 0.9703

Epoch 5/10


Epoch 5
Train Loss: 0.0983, Val Loss: 0.0902
Train Acc: 0.9632, Val Acc: 0.9742
Train F1: 0.9632, Val F1: 0.9740

Epoch 6/10


Epoch 6
Train Loss: 0.0848, Val Loss: 0.0724
Train Acc: 0.9706, Val Acc: 0.9779
Train F1: 0.9706, Val F1: 0.9778

Epoch 7/10


Epoch 7
Train Loss: 0.0904, Val Loss: 0.0691
Train Acc: 0.9706, Val Acc: 0.9815
Train F1: 0.9706, Val F1: 0.9815

Epoch 8/10


Epoch 8
Train Loss: 0.0819, Val Loss: 0.0602
Train Acc: 0.9688, Val Acc: 0.9852
Train F1: 0.9687, Val F1: 0.9852

Epoch 9/10


Epoch 9
Train Loss: 0.0829, Val Loss: 0.0402
Train Acc: 0.9724, Val Acc: 0.9926
Train F1: 0.9724, Val F1: 0.9926

Epoch 10/10


Epoch 10
Train Loss: 0.1024, Val Loss: 0.0359
Train Acc: 0.9678, Val Acc: 0.9889
Train F1: 0.9677, Val F1: 0.9889


CV val_loss: 0.0481 ± 0.0147
CV val_accuracy: 0.9860 ± 0.0059
CV val_f1: 0.9860 ± 0.0059
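
The summary lines above can be reproduced from the log. Judging from the end of `train`, each fold contributes the metrics of its lowest-validation-loss epoch, and `log_cv_results` reports their mean together with `np.std`, which defaults to the population standard deviation (`ddof=0`). The sketch below reads those per-fold values off the YAMNet log and uses the standard library's `statistics` module in place of NumPy. The variable and function names are illustrative.

```python
from statistics import fmean, pstdev

# Best epoch per fold (lowest validation loss), read from the YAMNet log above.
best_val_loss = [0.0262, 0.0656, 0.0539, 0.0588, 0.0359]
best_val_accuracy = [0.9963, 0.9816, 0.9816, 0.9816, 0.9889]


def summarize(values):
    """Return (mean, population standard deviation), like np.mean and np.std by default."""
    return fmean(values), pstdev(values)


for name, values in [("val_loss", best_val_loss), ("val_accuracy", best_val_accuracy)]:
    mean_value, std_value = summarize(values)
    print(f"CV {name}: {mean_value:.4f} ± {std_value:.4f}")
# CV val_loss: 0.0481 ± 0.0147
# CV val_accuracy: 0.9860 ± 0.0059
```

Both lines match the reported summary. With only five folds, the sample standard deviation (`statistics.stdev`, or `np.std(..., ddof=1)`) would give a somewhat larger spread, so it is worth stating which one is being reported.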


In [33]:
wav2vec_trainer = ModelTrainer('wav2vec2', CFG)
wav2vec_trainer.cross_validate(full_df)




Fold 1/5
Training on cuda

Phase 1/3

Epoch 1/3


Epoch 1
Train Loss: 0.4727, Val Loss: 0.1401
Train Acc: 0.7976, Val Acc: 0.9412
Train F1: 0.7947, Val F1: 0.9409

Epoch 2/3


Epoch 2
Train Loss: 0.1270, Val Loss: 0.1257
Train Acc: 0.9586, Val Acc: 0.9412
Train F1: 0.9585, Val F1: 0.9407

Epoch 3/3


Epoch 3
Train Loss: 0.1061, Val Loss: 0.1820
Train Acc: 0.9614, Val Acc: 0.9265
Train F1: 0.9612, Val F1: 0.9250

Phase 2/3

Epoch 1/3


Epoch 4
Train Loss: 0.1192, Val Loss: 0.1302
Train Acc: 0.9549, Val Acc: 0.9449
Train F1: 0.9548, Val F1: 0.9442

Epoch 2/3


Epoch 5
Train Loss: 0.0519, Val Loss: 0.0943
Train Acc: 0.9825, Val Acc: 0.9596
Train F1: 0.9825, Val F1: 0.9594

Epoch 3/3


Epoch 6
Train Loss: 0.0446, Val Loss: 0.0499
Train Acc: 0.9880, Val Acc: 0.9779
Train F1: 0.9880, Val F1: 0.9779

Phase 3/3

Epoch 1/3


Epoch 7
Train Loss: 0.0553, Val Loss: 0.0393
Train Acc: 0.9779, Val Acc: 0.9926
Train F1: 0.9779, Val F1: 0.9926

Epoch 2/3


Epoch 8
Train Loss: 0.0300, Val Loss: 0.0460
Train Acc: 0.9917, Val Acc: 0.9779
Train F1: 0.9917, Val F1: 0.9779

Epoch 3/3


Epoch 9
Train Loss: 0.0230, Val Loss: 0.0204
Train Acc: 0.9899, Val Acc: 0.9926
Train F1: 0.9899, Val F1: 0.9926


Fold 2/5




Training on cuda

Phase 1/3

Epoch 1/3


Epoch 1
Train Loss: 0.3459, Val Loss: 0.1922
Train Acc: 0.8574, Val Acc: 0.9191
Train F1: 0.8559, Val F1: 0.9184

Epoch 2/3


Epoch 2
Train Loss: 0.1294, Val Loss: 0.1970
Train Acc: 0.9512, Val Acc: 0.9265
Train F1: 0.9511, Val F1: 0.9257

Epoch 3/3


Epoch 3
Train Loss: 0.1048, Val Loss: 0.2240
Train Acc: 0.9577, Val Acc: 0.9081
Train F1: 0.9575, Val F1: 0.9068

Phase 2/3

Epoch 1/3


Epoch 4
Train Loss: 0.0935, Val Loss: 0.1659
Train Acc: 0.9687, Val Acc: 0.9449
Train F1: 0.9687, Val F1: 0.9447

Epoch 2/3


Epoch 5
Train Loss: 0.0650, Val Loss: 0.2624
Train Acc: 0.9724, Val Acc: 0.9081
Train F1: 0.9723, Val F1: 0.9072

Epoch 3/3


Epoch 6
Train Loss: 0.0552, Val Loss: 0.1429
Train Acc: 0.9807, Val Acc: 0.9596
Train F1: 0.9807, Val F1: 0.9594

Phase 3/3

Epoch 1/3


Epoch 7
Train Loss: 0.0797, Val Loss: 0.2417
Train Acc: 0.9696, Val Acc: 0.9375
Train F1: 0.9696, Val F1: 0.9372

Epoch 2/3


Epoch 8
Train Loss: 0.0389, Val Loss: 0.3907
Train Acc: 0.9862, Val Acc: 0.8824
Train F1: 0.9862, Val F1: 0.8799

Epoch 3/3


Epoch 9
Train Loss: 0.0294, Val Loss: 0.0949
Train Acc: 0.9908, Val Acc: 0.9816
Train F1: 0.9908, Val F1: 0.9816


Fold 3/5




Training on cuda

Phase 1/3

Epoch 1/3


Epoch 1
Train Loss: 0.3406, Val Loss: 0.2244
Train Acc: 0.8749, Val Acc: 0.8971
Train F1: 0.8734, Val F1: 0.8961

Epoch 2/3


Epoch 2
Train Loss: 0.1077, Val Loss: 0.1769
Train Acc: 0.9577, Val Acc: 0.9191
Train F1: 0.9576, Val F1: 0.9186

Epoch 3/3


Epoch 3
Train Loss: 0.0868, Val Loss: 0.1668
Train Acc: 0.9678, Val Acc: 0.9265
Train F1: 0.9677, Val F1: 0.9261

Phase 2/3

Epoch 1/3


Epoch 4
Train Loss: 0.0617, Val Loss: 0.1526
Train Acc: 0.9807, Val Acc: 0.9338
Train F1: 0.9806, Val F1: 0.9333

Epoch 2/3




Epoch 5
Train Loss: 0.0564, Val Loss: 0.1350
Train Acc: 0.9770, Val Acc: 0.9485
Train F1: 0.9770, Val F1: 0.9481

Epoch 3/3




Epoch 6
Train Loss: 0.0428, Val Loss: 0.1333
Train Acc: 0.9880, Val Acc: 0.9559
Train F1: 0.9880, Val F1: 0.9556

Phase 3/3

Epoch 1/3




Epoch 7
Train Loss: 0.0392, Val Loss: 0.1595
Train Acc: 0.9871, Val Acc: 0.9522
Train F1: 0.9871, Val F1: 0.9518

Epoch 2/3




Epoch 8
Train Loss: 0.0183, Val Loss: 0.1317
Train Acc: 0.9945, Val Acc: 0.9522
Train F1: 0.9945, Val F1: 0.9518

Epoch 3/3




Epoch 9
Train Loss: 0.0207, Val Loss: 0.1456
Train Acc: 0.9917, Val Acc: 0.9485
Train F1: 0.9917, Val F1: 0.9481




Fold 4/5




Training on cuda

Phase 1/3

Epoch 1/3




Epoch 1
Train Loss: 0.3443, Val Loss: 0.2271
Train Acc: 0.8712, Val Acc: 0.9191
Train F1: 0.8703, Val F1: 0.9178

Epoch 2/3




Epoch 2
Train Loss: 0.1267, Val Loss: 0.2076
Train Acc: 0.9531, Val Acc: 0.9265
Train F1: 0.9528, Val F1: 0.9249

Epoch 3/3




Epoch 3
Train Loss: 0.0903, Val Loss: 0.2421
Train Acc: 0.9660, Val Acc: 0.9044
Train F1: 0.9659, Val F1: 0.9003

Phase 2/3

Epoch 1/3




Epoch 4
Train Loss: 0.0723, Val Loss: 0.1407
Train Acc: 0.9715, Val Acc: 0.9485
Train F1: 0.9714, Val F1: 0.9479

Epoch 2/3




Epoch 5
Train Loss: 0.0563, Val Loss: 0.1535
Train Acc: 0.9752, Val Acc: 0.9485
Train F1: 0.9752, Val F1: 0.9477

Epoch 3/3




Epoch 6
Train Loss: 0.0405, Val Loss: 0.2052
Train Acc: 0.9844, Val Acc: 0.9375
Train F1: 0.9844, Val F1: 0.9355

Phase 3/3

Epoch 1/3




Epoch 7
Train Loss: 0.0341, Val Loss: 0.3129
Train Acc: 0.9880, Val Acc: 0.9338
Train F1: 0.9880, Val F1: 0.9318

Epoch 2/3




Epoch 8
Train Loss: 0.0315, Val Loss: 0.1806
Train Acc: 0.9871, Val Acc: 0.9559
Train F1: 0.9871, Val F1: 0.9555

Epoch 3/3




Epoch 9
Train Loss: 0.0215, Val Loss: 0.1343
Train Acc: 0.9917, Val Acc: 0.9743
Train F1: 0.9917, Val F1: 0.9741




Fold 5/5




Training on cuda

Phase 1/3

Epoch 1/3




Epoch 1
Train Loss: 0.4104, Val Loss: 0.1767
Train Acc: 0.8401, Val Acc: 0.9483
Train F1: 0.8391, Val F1: 0.9478

Epoch 2/3




Epoch 2
Train Loss: 0.1369, Val Loss: 0.1230
Train Acc: 0.9540, Val Acc: 0.9483
Train F1: 0.9539, Val F1: 0.9480

Epoch 3/3




Epoch 3
Train Loss: 0.1012, Val Loss: 0.1171
Train Acc: 0.9642, Val Acc: 0.9483
Train F1: 0.9641, Val F1: 0.9480

Phase 2/3

Epoch 1/3




Epoch 4
Train Loss: 0.0879, Val Loss: 0.0924
Train Acc: 0.9669, Val Acc: 0.9594
Train F1: 0.9668, Val F1: 0.9592

Epoch 2/3




Epoch 5
Train Loss: 0.0494, Val Loss: 0.1199
Train Acc: 0.9835, Val Acc: 0.9446
Train F1: 0.9835, Val F1: 0.9442

Epoch 3/3




Epoch 6
Train Loss: 0.0458, Val Loss: 0.2402
Train Acc: 0.9844, Val Acc: 0.9299
Train F1: 0.9844, Val F1: 0.9291

Phase 3/3

Epoch 1/3




Epoch 7
Train Loss: 0.0503, Val Loss: 0.1941
Train Acc: 0.9816, Val Acc: 0.9373
Train F1: 0.9816, Val F1: 0.9359

Epoch 2/3




Epoch 8
Train Loss: 0.0360, Val Loss: 0.2240
Train Acc: 0.9899, Val Acc: 0.9410
Train F1: 0.9899, Val F1: 0.9399

Epoch 3/3




Epoch 9
Train Loss: 0.0281, Val Loss: 0.0904
Train Acc: 0.9917, Val Acc: 0.9631
Train F1: 0.9917, Val F1: 0.9629



CV val_loss: 0.0943 ± 0.0412
CV val_accuracy: 0.9728 ± 0.0141
CV val_f1: 0.9726 ± 0.0142
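
The `CV ...` lines above report a mean and a spread over the five folds. The aggregation code itself is not shown in this part of the notebook, so the following is only a minimal sketch of how such a summary could be produced. It assumes one value per metric per fold and NumPy's default population standard deviation (`ddof=0`). The `fold_results` values and the `summarize_cv` helper are made-up placeholders, not the real fold scores or the notebook's own function.

```python
import numpy as np

# Hypothetical per-fold validation metrics (placeholders, NOT the notebook's values).
fold_results = [
    {"val_loss": 0.10, "val_accuracy": 0.96, "val_f1": 0.96},
    {"val_loss": 0.20, "val_accuracy": 0.97, "val_f1": 0.97},
    {"val_loss": 0.30, "val_accuracy": 0.98, "val_f1": 0.98},
    {"val_loss": 0.20, "val_accuracy": 0.97, "val_f1": 0.97},
    {"val_loss": 0.20, "val_accuracy": 0.97, "val_f1": 0.97},
]


def summarize_cv(fold_results):
    """Return {metric: (mean, std)} computed across folds."""
    summary = {}
    for metric in fold_results[0]:
        values = np.array([fold[metric] for fold in fold_results], dtype=float)
        # ndarray.std() defaults to ddof=0 (population standard deviation).
        summary[metric] = (values.mean(), values.std())
    return summary


cv_summary = summarize_cv(fold_results)
for metric, (mean, std) in cv_summary.items():
    print(f"CV {metric}: {mean:.4f} ± {std:.4f}")
```

With only five folds, the choice between `ddof=0` and `ddof=1` changes the reported spread noticeably, so it is worth stating which one is used.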


## Ensemble model
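
The ensemble in the next cell uses soft voting: softmax probabilities are averaged over the fold checkpoints of each architecture, and the two per-architecture means are then averaged with equal weight. The arithmetic can be sketched with NumPy alone. The probabilities below are made up to stand in for real model outputs, and the array names are illustrative rather than taken from the notebook.

```python
import numpy as np

# Made-up class probabilities with shape (n_models, n_samples, n_classes).
yamnet_fold_probs = np.array([
    [[0.7, 0.2, 0.1], [0.2, 0.5, 0.3]],
    [[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]],
])
wav2vec2_fold_probs = np.array([
    [[0.6, 0.3, 0.1], [0.1, 0.2, 0.7]],
    [[0.8, 0.1, 0.1], [0.1, 0.4, 0.5]],
])

# Step 1: average over the fold checkpoints of each architecture.
yamnet_mean = yamnet_fold_probs.mean(axis=0)
wav2vec2_mean = wav2vec2_fold_probs.mean(axis=0)

# Step 2: equal-weight average of the two architectures.
ensemble_probs = (yamnet_mean + wav2vec2_mean) / 2

# Step 3: the predicted class is the argmax of the averaged probabilities.
yamnet_preds = yamnet_mean.argmax(axis=1)
ensemble_preds = ensemble_probs.argmax(axis=1)

print("YAMNet only:", yamnet_preds)
print("Ensemble:   ", ensemble_preds)
```

In this toy input the second sample is labelled class 1 by the YAMNet average alone but class 2 by the ensemble, because the Wav2Vec2 models are more confident. Averaging means of valid probability rows keeps each row summing to 1, so the result can be fed directly to ROC computations.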

In [38]:
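# Imports repeated here so the cell runs on its own (harmless if already imported above).
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_curve, auc, classification_report
)
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# YAMNetFineTuner, Wav2Vec2FineTuner, YAMNetDataset, Wav2Vec2Dataset, CFG and test_df
# are defined in earlier cells.
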
class EnhancedEnsembleModel:
    def __init__(self, yamnet_paths, wav2vec2_paths, cfg):
        """Initialize ensemble with multiple model paths"""
        self.cfg = cfg
        self.device = cfg.device
        self.yamnet_models = []
        self.wav2vec2_models = []
        
        # Load YAMNet models
        for path in yamnet_paths:
            model = YAMNetFineTuner(num_classes=cfg.num_classes, cfg=cfg).to(cfg.device)
            checkpoint = torch.load(path, map_location=cfg.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.yamnet_models.append(model)
            
        # Load Wav2Vec2 models  
        for path in wav2vec2_paths:
            model = Wav2Vec2FineTuner(num_classes=cfg.num_classes).to(cfg.device)
            checkpoint = torch.load(path, map_location=cfg.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.wav2vec2_models.append(model)

    @torch.no_grad()
    def predict(self, yamnet_batch, wav2vec2_batch):
        """Get predictions from all models"""
        # Individual YAMNet predictions
        yamnet_probs = []
        for model in self.yamnet_models:
            logits = model(yamnet_batch['input'].to(self.device))
            probs = F.softmax(logits, dim=1)
            yamnet_probs.append(probs)
        yamnet_mean_probs = torch.stack(yamnet_probs).mean(dim=0)
        
        # Individual Wav2Vec2 predictions
        wav2vec2_probs = []
        for model in self.wav2vec2_models:
            logits = model(wav2vec2_batch[0].to(self.device))
            probs = F.softmax(logits, dim=1)
            wav2vec2_probs.append(probs)
        wav2vec2_mean_probs = torch.stack(wav2vec2_probs).mean(dim=0)
        
        # Ensemble predictions
        ensemble_probs = (yamnet_mean_probs + wav2vec2_mean_probs) / 2
        
        return {
            'yamnet': yamnet_mean_probs,
            'wav2vec2': wav2vec2_mean_probs,
            'ensemble': ensemble_probs
        }

def evaluate_all_models(test_df, cfg):
    """Evaluate individual models and ensemble"""
    # Prepare model paths
    yamnet_paths = [f'/kaggle/working/models/best_yamnet_fold_{i}_model.pt' for i in range(5)]
    wav2vec2_paths = [f'/kaggle/working/models/best_wav2vec2_fold_{i}_model.pt' for i in range(5)]
    
    # Create datasets and loaders
    yamnet_dataset = YAMNetDataset(test_df, cfg, train=False)
    wav2vec2_dataset = Wav2Vec2Dataset(test_df, cfg, train=False)
    
    yamnet_loader = DataLoader(yamnet_dataset, batch_size=cfg.batch_size, 
                             shuffle=False, num_workers=cfg.num_workers)
    wav2vec2_loader = DataLoader(wav2vec2_dataset, batch_size=cfg.batch_size,
                                shuffle=False, num_workers=cfg.num_workers)
    
    # Initialize enhanced ensemble
    ensemble = EnhancedEnsembleModel(yamnet_paths, wav2vec2_paths, cfg)
    
    # Storage for predictions
    predictions = {
        'yamnet': [], 
        'wav2vec2': [], 
        'ensemble': []
    }
    probabilities = {
        'yamnet': [],
        'wav2vec2': [],
        'ensemble': []
    }
    labels = []
    
    # Evaluation loop
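    # zip() has no length, so pass total explicitly; otherwise tqdm only shows "0it".
    # Both loaders use shuffle=False over the same test_df, so their batches stay aligned.
    for yamnet_batch, wav2vec2_batch in tqdm(zip(yamnet_loader, wav2vec2_loader),
                                           total=len(yamnet_loader),
                                           desc="Evaluating Models"):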
        # Get all predictions
        probs = ensemble.predict(yamnet_batch, wav2vec2_batch)
        
        # Store results for each model
        for model_name in probs:
            model_probs = probs[model_name]
            model_preds = torch.argmax(model_probs, dim=1)
            
            predictions[model_name].extend(model_preds.cpu().numpy())
            probabilities[model_name].extend(model_probs.cpu().numpy())
        
        labels.extend(yamnet_batch['label'].numpy())
    
    # Calculate metrics for each model
    metrics = {}
    for model_name in predictions:
        model_preds = np.array(predictions[model_name])
        model_probs = np.array(probabilities[model_name])
        
        metrics[model_name] = {
            'accuracy': accuracy_score(labels, model_preds),
            'precision': precision_score(labels, model_preds, average='weighted'),
            'recall': recall_score(labels, model_preds, average='weighted'),
            'f1': f1_score(labels, model_preds, average='weighted'),
            'confusion_matrix': confusion_matrix(labels, model_preds),
            'classification_report': classification_report(labels, model_preds)
        }
        
        # Add ROC curves
        metrics[model_name]['roc_curves'] = {}
        for i in range(cfg.num_classes):
            binary_labels = (np.array(labels) == i).astype(int)
            fpr, tpr, _ = roc_curve(binary_labels, model_probs[:, i])
            roc_auc = auc(fpr, tpr)
            metrics[model_name]['roc_curves'][f'class_{i}'] = {
                'fpr': fpr,
                'tpr': tpr,
                'auc': roc_auc
            }
    
    # Plot results
    plot_comparative_results(metrics, cfg)
    
    return metrics

def plot_comparative_results(metrics, cfg):
    """Plot comparative visualizations for all models"""
    # Confusion Matrices
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for idx, (model_name, model_metrics) in enumerate(metrics.items()):
        sns.heatmap(model_metrics['confusion_matrix'], 
                   annot=True, fmt='d', ax=axes[idx], cmap='Blues')
        axes[idx].set_title(f'{model_name.capitalize()} Confusion Matrix')
    plt.tight_layout()
    plt.savefig('all_models_confusion_matrices.png')
    plt.close()
    
    # ROC Curves
    plt.figure(figsize=(12, 8))
    colors = plt.cm.rainbow(np.linspace(0, 1, len(metrics) * cfg.num_classes))
    color_idx = 0
    
    for model_name, model_metrics in metrics.items():
        for class_idx in range(cfg.num_classes):
            roc_data = model_metrics['roc_curves'][f'class_{class_idx}']
            plt.plot(roc_data['fpr'], roc_data['tpr'], 
                    color=colors[color_idx],
                    label=f'{model_name}-Class {class_idx} (AUC={roc_data["auc"]:.2f})')
            color_idx += 1
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves - All Models')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('all_models_roc_curves.png')
    plt.close()

# Run evaluation
all_metrics = evaluate_all_models(test_df, CFG)

# Print comparative results
print("\nModel Performance Comparison:")
# Use model_metrics as the loop variable so it does not shadow the metrics dict naming used above
for model_name, model_metrics in all_metrics.items():
    print(f"\n{model_name.upper()} METRICS:")
    print(f"Accuracy: {model_metrics['accuracy']:.4f}")
    print(f"Precision: {model_metrics['precision']:.4f}")
    print(f"Recall: {model_metrics['recall']:.4f}")
    print(f"F1 Score: {model_metrics['f1']:.4f}")
    print("\nClassification Report:")
    print(model_metrics['classification_report'])



Model Performance Comparison:

YAMNET METRICS:
Accuracy: 0.9875
Precision: 0.9877
Recall: 0.9875
F1 Score: 0.9875

Classification Report:
              precision    recall  f1-score   support

           0       0.97      1.00      0.99        77
           1       0.99      0.98      0.98        83
           2       1.00      0.99      0.99        80

    accuracy                           0.99       240
   macro avg       0.99      0.99      0.99       240
weighted avg       0.99      0.99      0.99       240


WAV2VEC2 METRICS:
Accuracy: 0.9875
Precision: 0.9876
Recall: 0.9875
F1 Score: 0.9875

Classification Report:
              precision    recall  f1-score   support

           0       0.97      0.99      0.98        77
           1       0.99      0.98      0.98        83
           2       1.00      1.00      1.00        80

    accuracy                           0.99       240
   macro avg       0.99      0.99      0.99       240
weighted avg       0.99      0.99      0.99 