# Improved Hybrid CNN+MLP Training (V4.3 Enhanced) with Context Windows

**Enhanced version** with Focal Loss, SpecAugment, and improved architecture:

**Key improvements in V4.3:**
1. **Focal Loss**: Replaces LabelSmoothingCrossEntropy to focus on hard examples (70%+ high-confidence errors)
2. **SpecAugment**: Frequency and time masking for spectrogram augmentation during training
3. **Enhanced Architecture**:
   - Multi-head attention in cross-attention fusion
   - Residual connections in MLP branch
   - Enhanced SE blocks in CNN branch

**Expected improvements:**
- Better handling of hard examples (Focal Loss)
- Improved generalization (SpecAugment)
- Better feature fusion (Multi-head attention)
- More stable training (Residual connections)
- Target: Accuracy > 0.976 (from 0.9660)


In [3]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import pandas as pd
import numpy as np
import h5py
import joblib
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
from sklearn.preprocessing import StandardScaler, LabelEncoder
import math

# Project root
PROJECT_ROOT = Path('/Volumes/SSanDisk/SpeechRec-German')

# Data directory (with context v2 - includes VOT, burst features)
DATA_DIR = PROJECT_ROOT / 'artifacts' / 'd-t_dl_models_with_context_v2'
FEATURES_DIR = DATA_DIR / 'features'

# Device setup
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"Using MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using CUDA device")
else:
    device = torch.device("cpu")
    print(f"Using CPU device")

print(f"Data directory: {DATA_DIR}")
print(f"Features directory: {FEATURES_DIR}")


Using MPS device
Data directory: /Volumes/SSanDisk/SpeechRec-German/artifacts/d-t_dl_models_with_context_v2
Features directory: /Volumes/SSanDisk/SpeechRec-German/artifacts/d-t_dl_models_with_context_v2/features


## Load Data with Context Windows (V2 - with VOT and Burst Features)


In [4]:
# Load feature columns
with open(DATA_DIR / 'feature_cols.json', 'r') as f:
    feature_cols = json.load(f)

# Load feature scaler
feature_scaler = joblib.load(DATA_DIR / 'feature_scaler.joblib')

# Load class weights
with open(DATA_DIR / 'class_weights.json', 'r') as f:
    class_weights_dict = json.load(f)

# Load features DataFrame (from 02.2 - includes VOT, burst features)
df = pd.read_parquet(FEATURES_DIR / 'features.parquet')
print(f"Dataset shape: {df.shape}")
print(f"Feature columns (loaded): {len(feature_cols)}")

# Filter feature_cols to only include columns that exist in DataFrame
original_feature_count = len(feature_cols)
feature_cols = [col for col in feature_cols if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]

if len(feature_cols) != original_feature_count:
    missing_cols = set([col for col in json.load(open(DATA_DIR / 'feature_cols.json', 'r')) if col not in df.columns])
    print(f"Warning: {original_feature_count - len(feature_cols)} feature columns are missing from DataFrame")
    if missing_cols:
        print(f"Missing columns: {list(missing_cols)[:10]}...")
        
    if 'duration_ms_features' in missing_cols:
        print("Note: 'duration_ms_features' is missing - this is expected if duration_ms wasn't duplicated during merge.")
        print("      This column is not a real feature and can be safely ignored.")

print(f"Feature columns (filtered): {len(feature_cols)}")

# Verify feature count matches scaler
if hasattr(feature_scaler, 'n_features_in_'):
    if len(feature_cols) != feature_scaler.n_features_in_:
        print(f"Warning: Feature count mismatch. Scaler expects {feature_scaler.n_features_in_} features, but we have {len(feature_cols)}")
        print("This is OK if some features were removed from the dataset. The scaler will be applied to available features.")

# Check what metadata columns we have
metadata_cols = ['phoneme_id', 'class', 'duration_ms', 'phoneme', 'utterance_id']
present_metadata = [col for col in metadata_cols if col in df.columns]
print(f"\nMetadata columns present: {present_metadata}")

# Handle class column
if 'class' not in df.columns:
    if 'phoneme' in df.columns:
        df['class'] = df['phoneme']
        print("Created 'class' column from 'phoneme'")
    else:
        raise ValueError("Neither 'class' nor 'phoneme' column found in features.parquet.")

# Filter to only d and t classes
if 'pf' in df['class'].values:
    df = df[df['class'].isin(['d', 't'])].copy()
    print(f"Dataset after filtering to d/t: {len(df)} samples")

# Encode target
le = LabelEncoder()
df['class_encoded'] = le.fit_transform(df['class'])  # d=0, t=1
print(f"\nClass encoding: {dict(zip(le.classes_, le.transform(le.classes_)))}")
print(f"Class distribution:\n{df['class'].value_counts()}")

# Load split indices
with open(DATA_DIR / 'split_indices.json', 'r') as f:
    split_indices = json.load(f)

# Reset index
df = df.reset_index(drop=True)

# Create split column based on indices
df['split'] = 'train'
if len(df) > max(split_indices['val'] + split_indices['test']):
    df.loc[split_indices['val'], 'split'] = 'val'
    df.loc[split_indices['test'], 'split'] = 'test'
else:
    print("Warning: Split indices may not match DataFrame indices. Using phoneme_id matching...")
    val_ids = set(df.loc[split_indices['val'], 'phoneme_id'].values) if len(df) > max(split_indices['val']) else set()
    test_ids = set(df.loc[split_indices['test'], 'phoneme_id'].values) if len(df) > max(split_indices['test']) else set()
    df.loc[df['phoneme_id'].isin(val_ids), 'split'] = 'val'
    df.loc[df['phoneme_id'].isin(test_ids), 'split'] = 'test'

print(f"\nSplit distribution:")
print(df['split'].value_counts())

# Load spectrograms
spectrograms_dict = {}
with h5py.File(FEATURES_DIR / 'spectrograms.h5', 'r') as f:
    phoneme_ids = list(f.keys())
    for phoneme_id in tqdm(phoneme_ids, desc="Loading spectrograms"):
        spectrograms_dict[phoneme_id] = f[phoneme_id][:]

print(f"\nLoaded {len(spectrograms_dict):,} spectrograms")
if spectrograms_dict:
    print(f"Spectrogram shape: {list(spectrograms_dict.values())[0].shape}")

# Filter to only phonemes with spectrograms
df['phoneme_id_str'] = df['phoneme_id'].astype(str)
df['has_spectrogram'] = df['phoneme_id_str'].isin(spectrograms_dict.keys())
df = df[df['has_spectrogram']].copy()
print(f"\nDataset after filtering for spectrograms: {len(df)} samples")


Dataset shape: (132992, 134)
Feature columns (loaded): 130
Missing columns: ['duration_ms_features']...
Note: 'duration_ms_features' is missing - this is expected if duration_ms wasn't duplicated during merge.
      This column is not a real feature and can be safely ignored.
Feature columns (filtered): 129
This is OK if some features were removed from the dataset. The scaler will be applied to available features.

Metadata columns present: ['phoneme_id', 'class', 'duration_ms']

Class encoding: {'d': np.int64(0), 't': np.int64(1)}
Class distribution:
class
t    74454
d    58538
Name: count, dtype: int64

Split distribution:
split
train    93147
test     19949
val      19896
Name: count, dtype: int64


Loading spectrograms: 100%|██████████| 132992/132992 [00:45<00:00, 2944.35it/s]



Loaded 132,992 spectrograms
Spectrogram shape: (128, 7)

Dataset after filtering for spectrograms: 132992 samples


## Define SpecAugment for Spectrogram Augmentation


In [5]:
class SpecAugment:
    """
    SpecAugment: Simple spectrogram augmentation for speech recognition.
    Applies frequency masking and time masking to spectrograms.
    Adaptively adjusts parameters based on spectrogram dimensions.
    """
    def __init__(self, F=27, T=40, m_F=2, m_T=2):
        """
        Args:
            F: Maximum frequency mask width (will be clamped to H-1)
            T: Maximum time mask width (will be clamped to W-1)
            m_F: Number of frequency masks
            m_T: Number of time masks
        """
        self.F = F
        self.T = T
        self.m_F = m_F
        self.m_T = m_T
    
    def __call__(self, spectrogram):
        """
        Apply SpecAugment to spectrogram.
        Args:
            spectrogram: numpy array of shape (C, H, W) or (H, W)
        Returns:
            Augmented spectrogram
        """
        # Handle different input shapes
        if len(spectrogram.shape) == 2:
            # (H, W) -> (1, H, W)
            spec = np.expand_dims(spectrogram, axis=0)
            squeeze_output = True
        else:
            spec = spectrogram.copy()
            squeeze_output = False
        
        C, H, W = spec.shape
        
        # Adaptively adjust mask sizes to fit spectrogram dimensions
        # Ensure we can always apply at least some masking
        max_F = max(1, min(self.F, H - 1))  # At least 1, at most H-1
        max_T = max(1, min(self.T, W - 1))  # At least 1, at most W-1
        
        # Apply frequency masking
        for _ in range(self.m_F):
            if max_F > 0 and H > 0:
                f = np.random.randint(0, max_F + 1)
                if f > 0 and H - f > 0:
                    f0 = np.random.randint(0, H - f + 1)
                    spec[:, f0:f0+f, :] = 0
        
        # Apply time masking
        for _ in range(self.m_T):
            if max_T > 0 and W > 0:
                t = np.random.randint(0, max_T + 1)
                if t > 0 and W - t > 0:
                    t0 = np.random.randint(0, W - t + 1)
                    spec[:, :, t0:t0+t] = 0
        
        if squeeze_output:
            spec = np.squeeze(spec, axis=0)
        
        return spec

print("SpecAugment class defined successfully!")


SpecAugment class defined successfully!


## Create Dataset Classes with SpecAugment


In [6]:
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler

class HybridDataset(Dataset):
    """Dataset for hybrid models using both spectrograms and features with SpecAugment"""
    def __init__(self, df, spectrograms_dict, feature_cols, scaler=None, split='train', fit_scaler=False, transform=None, use_specaugment=False):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.spectrograms_dict = spectrograms_dict
        self.transform = transform
        self.split = split
        self.use_specaugment = use_specaugment and (split == 'train')  # Only apply to training
        
        # Initialize SpecAugment if needed
        if self.use_specaugment:
            self.specaugment = SpecAugment(F=27, T=40, m_F=2, m_T=2)
        
        self.feature_cols = [col for col in feature_cols if col in self.df.columns and pd.api.types.is_numeric_dtype(self.df[col])]
        if len(self.feature_cols) != len(feature_cols):
            missing = set(feature_cols) - set(self.feature_cols)
            print(f"Warning: {len(missing)} feature columns missing from DataFrame: {list(missing)[:5]}...")
        
        X_features = self.df[self.feature_cols].values.astype(np.float32)
        X_features = np.nan_to_num(X_features, nan=0.0, posinf=0.0, neginf=0.0)
        
        if fit_scaler:
            self.scaler = StandardScaler()
            X_features = self.scaler.fit_transform(X_features)
        elif scaler is not None:
            if hasattr(scaler, 'n_features_in_') and X_features.shape[1] != scaler.n_features_in_:
                print(f"Warning: Feature count mismatch ({X_features.shape[1]} vs {scaler.n_features_in_}). Retraining scaler on current features.")
                self.scaler = StandardScaler()
                X_features = self.scaler.fit_transform(X_features)
            else:
                self.scaler = scaler
                X_features = self.scaler.transform(X_features)
        else:
            self.scaler = None
        
        self.X_features = torch.from_numpy(X_features)
        self.y = torch.from_numpy(self.df['class_encoded'].values).long()
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        phoneme_id = str(row['phoneme_id'])
        
        spectrogram = self.spectrograms_dict[phoneme_id].astype(np.float32)
        if len(spectrogram.shape) == 2:
            spectrogram = np.expand_dims(spectrogram, axis=0)
        spectrogram = (spectrogram - spectrogram.min()) / (spectrogram.max() - spectrogram.min() + 1e-8)
        
        # Apply SpecAugment during training
        if self.use_specaugment:
            spectrogram = self.specaugment(spectrogram)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        features = self.X_features[idx]
        label = self.y[idx]
        
        return (torch.from_numpy(spectrogram), features), label

# Check and retrain scaler if needed
train_df = df[df['split'] == 'train'].reset_index(drop=True)
train_feature_cols = [col for col in feature_cols if col in train_df.columns and pd.api.types.is_numeric_dtype(train_df[col])]
feature_cols = train_feature_cols

if hasattr(feature_scaler, 'n_features_in_') and len(feature_cols) != feature_scaler.n_features_in_:
    print(f"Feature count mismatch detected: {len(feature_cols)} features in DataFrame vs {feature_scaler.n_features_in_} in scaler")
    print("Retraining scaler on train split with current features...")
    X_train_features = train_df[feature_cols].values.astype(np.float32)
    X_train_features = np.nan_to_num(X_train_features, nan=0.0, posinf=0.0, neginf=0.0)
    feature_scaler = StandardScaler()
    feature_scaler.fit(X_train_features)
    print(f"Scaler retrained on {len(feature_cols)} features")
else:
    print(f"Using existing scaler with {feature_scaler.n_features_in_} features")

# Create datasets with SpecAugment for training
train_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='train', use_specaugment=True)
val_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='val', use_specaugment=False)
test_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='test', use_specaugment=False)

print(f"Train dataset: {len(train_hybrid_ds)} samples (with SpecAugment)")
print(f"Val dataset: {len(val_hybrid_ds)} samples")
print(f"Test dataset: {len(test_hybrid_ds)} samples")

# Create weighted sampler
train_labels = df[df['split'] == 'train']['class_encoded'].values
class_weights_array = np.array([class_weights_dict.get(str(i), class_weights_dict.get(i, 1.0)) for i in range(2)])
sample_weights = np.array([class_weights_array[label] for label in train_labels])
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# Create DataLoaders
BATCH_SIZE = 64
train_hybrid_loader = DataLoader(train_hybrid_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_hybrid_loader = DataLoader(val_hybrid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_hybrid_loader = DataLoader(test_hybrid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"\nTrain batches: {len(train_hybrid_loader)}")
print(f"Val batches: {len(val_hybrid_loader)}")
print(f"Test batches: {len(test_hybrid_loader)}")

# Test a batch
sample_batch = next(iter(train_hybrid_loader))
print(f"\nSample batch - Spectrogram shape: {sample_batch[0][0].shape}, Features shape: {sample_batch[0][1].shape}, Labels shape: {sample_batch[1].shape}")


Feature count mismatch detected: 129 features in DataFrame vs 130 in scaler
Retraining scaler on train split with current features...
Scaler retrained on 129 features
Train dataset: 93147 samples (with SpecAugment)
Val dataset: 19896 samples
Test dataset: 19949 samples

Train batches: 1456
Val batches: 311
Test batches: 312

Sample batch - Spectrogram shape: torch.Size([64, 1, 128, 7]), Features shape: torch.Size([64, 129]), Labels shape: torch.Size([64])


## Define Enhanced Model Architecture V4.3 with Multi-Head Attention and Residual Connections


In [7]:
# Define Residual Block for CNN
class ResidualBlock2D(nn.Module):
    """Residual block for CNN branch"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock2D, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# Enhanced Channel Attention Module (SE block with improved design)
class EnhancedChannelAttention(nn.Module):
    """Enhanced Channel attention module with improved design"""
    def __init__(self, channels, reduction=8):
        super(EnhancedChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        # Use smaller reduction for better capacity
        reduced_dim = max(1, channels // reduction)
        
        self.fc = nn.Sequential(
            nn.Linear(channels, reduced_dim, bias=False),
            nn.ReLU(),
            nn.Linear(reduced_dim, channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        return x * out.view(b, c, 1, 1)


# Define Feature Attention Module (SE-like for MLP features)
class FeatureAttention(nn.Module):
    """Squeeze-and-Excitation attention for feature vectors"""
    def __init__(self, n_features, reduction=8):
        super(FeatureAttention, self).__init__()
        self.reduction = reduction
        reduced_dim = max(1, n_features // reduction)
        
        self.fc = nn.Sequential(
            nn.Linear(n_features, reduced_dim, bias=False),
            nn.ReLU(),
            nn.Linear(reduced_dim, n_features, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        attention_weights = self.fc(x)
        return x * attention_weights


# Define Multi-Scale Convolution Block
class MultiScaleConvBlock(nn.Module):
    """Multi-scale convolution with parallel 3x3 and 5x5 kernels"""
    def __init__(self, in_channels, out_channels):
        super(MultiScaleConvBlock, self).__init__()
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
        self.conv5x5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU()
        )
    
    def forward(self, x):
        out3x3 = self.conv3x3(x)
        out5x5 = self.conv5x5(x)
        return torch.cat([out3x3, out5x5], dim=1)


# Multi-Head Cross-Attention Fusion Module
class MultiHeadCrossAttentionFusion(nn.Module):
    """Multi-head cross-attention between CNN and MLP outputs"""
    def __init__(self, cnn_dim, mlp_dim, hidden_dim=256, num_heads=4):
        super(MultiHeadCrossAttentionFusion, self).__init__()
        self.cnn_dim = cnn_dim
        self.mlp_dim = mlp_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        
        # Projections for multi-head attention
        self.cnn_to_qkv = nn.Linear(cnn_dim, hidden_dim * 3)
        self.mlp_to_qkv = nn.Linear(mlp_dim, hidden_dim * 3)
        
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        
        # Projections back to original dimensions
        self.cnn_proj = nn.Linear(hidden_dim, cnn_dim)
        self.mlp_proj = nn.Linear(hidden_dim, mlp_dim)
        
    def forward(self, cnn_out, mlp_out):
        # cnn_out: (batch, cnn_dim)
        # mlp_out: (batch, mlp_dim)
        batch_size = cnn_out.size(0)
        
        # CNN enhanced by MLP (multi-head attention)
        cnn_qkv = self.cnn_to_qkv(cnn_out).reshape(batch_size, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        cnn_q, cnn_k, cnn_v = cnn_qkv[0], cnn_qkv[1], cnn_qkv[2]  # (batch, num_heads, head_dim)
        
        mlp_qkv = self.mlp_to_qkv(mlp_out).reshape(batch_size, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        mlp_q, mlp_k, mlp_v = mlp_qkv[0], mlp_qkv[1], mlp_qkv[2]
        
        # Cross-attention: CNN queries attend to MLP keys/values
        scores = torch.matmul(cnn_q, mlp_k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(scores, dim=-1)
        cnn_attended = torch.matmul(attn_weights, mlp_v)  # (batch, num_heads, head_dim)
        cnn_attended = cnn_attended.transpose(1, 2).contiguous().view(batch_size, self.hidden_dim)
        cnn_enhanced = cnn_out + self.cnn_proj(self.norm1(cnn_attended))
        
        # MLP enhanced by CNN (multi-head attention)
        scores2 = torch.matmul(mlp_q, cnn_k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights2 = F.softmax(scores2, dim=-1)
        mlp_attended = torch.matmul(attn_weights2, cnn_v)
        mlp_attended = mlp_attended.transpose(1, 2).contiguous().view(batch_size, self.hidden_dim)
        mlp_enhanced = mlp_out + self.mlp_proj(self.norm2(mlp_attended))
        
        return cnn_enhanced, mlp_enhanced


# Define Hybrid CNN+MLP Model V4.3
class HybridCNNMLP_V4_3(nn.Module):
    """
    Enhanced Hybrid model: CNN for spectrograms + MLP for features
    Version 4.3 Improvements:
    - Multi-Head Cross-Attention Fusion
    - Residual connections in MLP branch
    - Enhanced SE blocks in CNN branch
    - SpecAugment support
    - Focal Loss support
    """
    
    def __init__(self, n_features=129, num_classes=2, dropout=0.4):
        super(HybridCNNMLP_V4_3, self).__init__()
        
        # Multi-Scale CNN branch with enhanced attention
        self.cnn_initial = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # (64, 64, 3)
        )
        
        # Multi-scale block
        self.multiscale = MultiScaleConvBlock(64, 128)
        
        self.cnn_branch = nn.Sequential(
            ResidualBlock2D(128, 128),
            EnhancedChannelAttention(128, reduction=8),
            nn.MaxPool2d(2, 2),  # (128, 32, 1)
            
            ResidualBlock2D(128, 256),
            EnhancedChannelAttention(256, reduction=8),
            ResidualBlock2D(256, 512),
            EnhancedChannelAttention(512, reduction=8),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )
        
        # MLP branch with feature attention and residual connections
        self.feature_attention = FeatureAttention(n_features, reduction=8)
        
        # First layer
        self.mlp_layer1 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Second layer with residual
        self.mlp_layer2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.75)
        )
        self.mlp_residual1 = nn.Linear(256, 512)  # For residual connection
        
        # Third layer with residual
        self.mlp_layer3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )
        self.mlp_residual2 = nn.Linear(512, 256)  # For residual connection
        
        # Final layer
        self.mlp_layer4 = nn.Linear(256, 128)
        
        # Multi-head cross-attention fusion
        self.cross_attention = MultiHeadCrossAttentionFusion(cnn_dim=512, mlp_dim=128, hidden_dim=256, num_heads=4)
        
        # Enhanced Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(512 + 128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout * 0.75),
            
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            
            nn.Linear(64, num_classes)
        )
        
    def forward(self, x):
        spectrogram, features = x
        
        # CNN branch with multi-scale
        cnn_init = self.cnn_initial(spectrogram)
        cnn_multiscale = self.multiscale(cnn_init)
        cnn_out = self.cnn_branch(cnn_multiscale)  # (batch, 512)
        
        # MLP branch with feature attention and residual connections
        features_attended = self.feature_attention(features)
        
        mlp = self.mlp_layer1(features_attended)  # (batch, 256)
        mlp_input1 = mlp  # Save input for residual
        mlp = self.mlp_layer2(mlp) + self.mlp_residual1(mlp_input1)  # (batch, 512) with residual
        mlp_input2 = mlp  # Save input for residual
        mlp = self.mlp_layer3(mlp) + self.mlp_residual2(mlp_input2)  # (batch, 256) with residual
        mlp_out = self.mlp_layer4(mlp)  # (batch, 128)
        
        # Multi-head cross-attention fusion
        cnn_enhanced, mlp_enhanced = self.cross_attention(cnn_out, mlp_out)
        
        # Concatenate enhanced outputs
        fused = torch.cat([cnn_enhanced, mlp_enhanced], dim=1)  # (batch, 640)
        
        # Final classification
        out = self.fusion(fused)  # (batch, 2)
        
        return out
    
    def get_config(self):
        """Return model configuration"""
        return {
            'model_type': 'HybridCNNMLP_V4_3',
            'num_classes': 2,
            'n_features': 129,
            'input_shapes': {
                'spectrogram': (1, 128, 7),
                'features': (129,)
            },
            'version': '4.3'
        }

print("Model architecture V4.3 (Enhanced) defined successfully!")


Model architecture V4.3 (Enhanced) defined successfully!


In [8]:
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance and hard examples.
    FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
    """
    def __init__(self, alpha=0.25, gamma=2.0, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight  # Class weights
        self.reduction = reduction
    
    def forward(self, pred, target):
        """
        Args:
            pred: (N, C) logits
            target: (N,) class indices
        Returns:
            Focal loss value
        """
        log_prob = F.log_softmax(pred, dim=1)
        prob = torch.exp(log_prob)
        
        # Get probability of true class
        prob_t = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        
        # Compute focal weight: (1 - p_t)^gamma
        focal_weight = (1 - prob_t) ** self.gamma
        
        # Compute cross entropy
        ce_loss = -log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
        
        # Apply class weights if provided
        if self.weight is not None:
            class_weights = self.weight[target]
            ce_loss = ce_loss * class_weights
        
        # Apply alpha weighting
        alpha_t = self.alpha if self.alpha is not None else 1.0
        if isinstance(alpha_t, (float, int)):
            alpha_t = torch.tensor(alpha_t, device=pred.device)
        if isinstance(alpha_t, torch.Tensor) and len(alpha_t.shape) == 0:
            # Scalar alpha - apply uniformly
            focal_loss = alpha_t * focal_weight * ce_loss
        else:
            # Per-class alpha
            alpha_t = alpha_t[target]
            focal_loss = alpha_t * focal_weight * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

print("FocalLoss class defined successfully!")


FocalLoss class defined successfully!


## Define Training Utilities


In [9]:
# Training utilities
def train_epoch(model, dataloader, criterion, optimizer, device, max_grad_norm=None):
    """Train for one epoch with optional gradient clipping"""
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for batch in tqdm(dataloader, desc="Training", leave=False):
        if isinstance(batch[0], (tuple, list)) and len(batch[0]) == 2:
            inputs = tuple(x.to(device) for x in batch[0])
        else:
            inputs = batch[0].to(device)
        
        labels = batch[1].to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        optimizer.step()
        
        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())
    
    avg_loss = running_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    
    return avg_loss, accuracy


def validate(model, dataloader, criterion, device):
    """Validate model"""
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating", leave=False):
            if isinstance(batch[0], (tuple, list)) and len(batch[0]) == 2:
                inputs = tuple(x.to(device) for x in batch[0])
            else:
                inputs = batch[0].to(device)
            
            labels = batch[1].to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)
    
    avg_loss = running_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    
    try:
        roc_auc = roc_auc_score(all_labels, np.array(all_probs)[:, 1])
    except:
        roc_auc = 0.0
    
    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc
    }
    
    return metrics, all_preds, all_labels, all_probs


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                device, num_epochs, save_dir, model_name, early_stopping_patience=20, max_grad_norm=None):
    """Train model with early stopping and optional gradient clipping"""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    best_val_f1 = 0.0
    best_epoch = 0
    patience_counter = 0
    training_history = []
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)
        
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, max_grad_norm)
        val_metrics, _, _, _ = validate(model, val_loader, criterion, device)
        
        if scheduler is not None:
            scheduler.step()
        
        current_lr = optimizer.param_groups[0]['lr']
        epoch_metrics = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
            'val_precision': val_metrics['precision'],
            'val_recall': val_metrics['recall'],
            'val_f1': val_metrics['f1'],
            'val_roc_auc': val_metrics['roc_auc'],
            'learning_rate': current_lr
        }
        training_history.append(epoch_metrics)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f}, Val ROC-AUC: {val_metrics['roc_auc']:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            best_epoch = epoch + 1
            patience_counter = 0
            
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': best_val_f1,
                'val_metrics': val_metrics
            }, save_dir / 'best_model.pt')
            
            print(f"✓ New best model saved! (F1: {best_val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                print(f"Best F1: {best_val_f1:.4f} at epoch {best_epoch}")
                break
    
    with open(save_dir / 'training_history.json', 'w') as f:
        json.dump(training_history, f, indent=2)
    
    config = model.get_config() if hasattr(model, 'get_config') else {}
    config.update({
        'best_epoch': best_epoch,
        'best_val_f1': best_val_f1,
        'num_epochs': num_epochs
    })
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)
    
    return training_history, best_epoch


def evaluate_model(model, test_loader, criterion, device):
    """Evaluate model on test set"""
    metrics, preds, labels, probs = validate(model, test_loader, criterion, device)
    
    precision_per_class = precision_score(labels, preds, average=None, zero_division=0)
    recall_per_class = recall_score(labels, preds, average=None, zero_division=0)
    f1_per_class = f1_score(labels, preds, average=None, zero_division=0)
    
    metrics['precision_d'] = float(precision_per_class[0])
    metrics['precision_t'] = float(precision_per_class[1])
    metrics['recall_d'] = float(recall_per_class[0])
    metrics['recall_t'] = float(recall_per_class[1])
    metrics['f1_d'] = float(f1_per_class[0])
    metrics['f1_t'] = float(f1_per_class[1])
    metrics['confusion_matrix'] = confusion_matrix(labels, preds).tolist()
    
    return metrics, preds, labels, probs


class WarmupCosineScheduler:
    """Learning rate scheduler with warmup and cosine annealing"""
    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=1e-6):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        self.base_lr = optimizer.param_groups[0]['lr']
        self.current_epoch = 0
    
    def step(self):
        self.current_epoch += 1
        
        if self.current_epoch <= self.warmup_epochs:
            lr = self.base_lr * (self.current_epoch / self.warmup_epochs)
        else:
            progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
            lr = self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_last_lr(self):
        return [self.optimizer.param_groups[0]['lr']]

print("Training utilities defined successfully!")


Training utilities defined successfully!


## Create Model and Training Configuration


In [10]:
# Create model V4.3 with automatic feature count detection
model = HybridCNNMLP_V4_3(n_features=len(feature_cols), num_classes=2, dropout=0.4).to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {model.get_config()['model_type']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Number of features: {len(feature_cols)}")

# Prepare class weights for loss function
class_weights = torch.tensor([
    class_weights_dict.get('0', class_weights_dict.get(0, 1.0)), 
    class_weights_dict.get('1', class_weights_dict.get(1, 1.0))
], dtype=torch.float32).to(device)

# Loss function: Focal Loss with class weights
criterion = FocalLoss(alpha=0.25, gamma=2.0, weight=class_weights, reduction='mean')

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

# Learning rate scheduler with warmup and cosine annealing
num_epochs = 200
warmup_epochs = 5
scheduler = WarmupCosineScheduler(optimizer, warmup_epochs=warmup_epochs, total_epochs=num_epochs, min_lr=1e-6)

# Output directory
OUTPUT_DIR = DATA_DIR / 'improved_models'
save_dir = OUTPUT_DIR / 'hybrid_cnn_mlp_v4_3_enhanced'
save_dir.mkdir(parents=True, exist_ok=True)

print(f"\nTraining configuration:")
print(f"- Epochs: {num_epochs}")
print(f"- Warmup epochs: {warmup_epochs}")
print(f"- Initial LR: {optimizer.param_groups[0]['lr']}")
print(f"- Loss function: Focal Loss (alpha=0.25, gamma=2.0)")
print(f"- Gradient clipping: 1.0")
print(f"- Early stopping patience: 20")
print(f"- Dropout: 0.4")
print(f"- SpecAugment: Enabled for training")
print(f"- Context windows: ±100ms (V2 with VOT and burst features)")
print(f"- Save directory: {save_dir}")


Model: HybridCNNMLP_V4_3
Total parameters: 6,579,554
Trainable parameters: 6,579,554
Number of features: 129

Training configuration:
- Epochs: 200
- Warmup epochs: 5
- Initial LR: 0.0005
- Loss function: Focal Loss (alpha=0.25, gamma=2.0)
- Gradient clipping: 1.0
- Early stopping patience: 20
- Dropout: 0.4
- SpecAugment: Enabled for training
- Context windows: ±100ms (V2 with VOT and burst features)
- Save directory: /Volumes/SSanDisk/SpeechRec-German/artifacts/d-t_dl_models_with_context_v2/improved_models/hybrid_cnn_mlp_v4_3_enhanced


## Train Model


In [11]:
# Train model
history, best_epoch = train_model(
    model, train_hybrid_loader, val_hybrid_loader, criterion, optimizer, scheduler,
    device, num_epochs=num_epochs, save_dir=save_dir, model_name='hybrid_cnn_mlp_v4_3_enhanced', 
    early_stopping_patience=20, max_grad_norm=1.0
)

# Load best model and evaluate on test set
checkpoint = torch.load(save_dir / 'best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
test_metrics, test_preds, test_labels, test_probs = evaluate_model(model, test_hybrid_loader, criterion, device)

# Save test metrics
with open(save_dir / 'test_metrics.json', 'w') as f:
    json.dump(test_metrics, f, indent=2)

print(f"\n{'='*60}")
print(f"Final Test Results:")
print(f"{'='*60}")
print(f"Accuracy: {test_metrics['accuracy']:.4f}")
print(f"F1-score: {test_metrics['f1']:.4f}")
print(f"ROC-AUC: {test_metrics['roc_auc']:.4f}")
print(f"Precision: {test_metrics['precision']:.4f}")
print(f"Recall: {test_metrics['recall']:.4f}")
print(f"Best epoch: {best_epoch}")



Epoch 1/200
--------------------------------------------------


                                                             

Train Loss: 0.0190, Train Acc: 0.8804
Val Loss: 0.0148, Val Acc: 0.9100
Val F1: 0.9104, Val ROC-AUC: 0.9746
Learning Rate: 0.000100
✓ New best model saved! (F1: 0.9104)

Epoch 2/200
--------------------------------------------------


                                                             

Train Loss: 0.0143, Train Acc: 0.9126
Val Loss: 0.0121, Val Acc: 0.9238
Val F1: 0.9240, Val ROC-AUC: 0.9807
Learning Rate: 0.000200
✓ New best model saved! (F1: 0.9240)

Epoch 3/200
--------------------------------------------------


                                                             

Train Loss: 0.0137, Train Acc: 0.9172
Val Loss: 0.0121, Val Acc: 0.9261
Val F1: 0.9263, Val ROC-AUC: 0.9811
Learning Rate: 0.000300
✓ New best model saved! (F1: 0.9263)

Epoch 4/200
--------------------------------------------------


                                                             

Train Loss: 0.0134, Train Acc: 0.9189
Val Loss: 0.0123, Val Acc: 0.9206
Val F1: 0.9209, Val ROC-AUC: 0.9819
Learning Rate: 0.000400

Epoch 5/200
--------------------------------------------------


                                                             

Train Loss: 0.0128, Train Acc: 0.9243
Val Loss: 0.0113, Val Acc: 0.9334
Val F1: 0.9334, Val ROC-AUC: 0.9832
Learning Rate: 0.000500
✓ New best model saved! (F1: 0.9334)

Epoch 6/200
--------------------------------------------------


                                                             

Train Loss: 0.0127, Train Acc: 0.9249
Val Loss: 0.0121, Val Acc: 0.9322
Val F1: 0.9323, Val ROC-AUC: 0.9835
Learning Rate: 0.000500

Epoch 7/200
--------------------------------------------------


                                                             

Train Loss: 0.0120, Train Acc: 0.9288
Val Loss: 0.0114, Val Acc: 0.9299
Val F1: 0.9301, Val ROC-AUC: 0.9839
Learning Rate: 0.000500

Epoch 8/200
--------------------------------------------------


                                                             

Train Loss: 0.0118, Train Acc: 0.9320
Val Loss: 0.0108, Val Acc: 0.9325
Val F1: 0.9326, Val ROC-AUC: 0.9844
Learning Rate: 0.000500

Epoch 9/200
--------------------------------------------------


                                                             

Train Loss: 0.0115, Train Acc: 0.9324
Val Loss: 0.0108, Val Acc: 0.9334
Val F1: 0.9335, Val ROC-AUC: 0.9845
Learning Rate: 0.000499
✓ New best model saved! (F1: 0.9335)

Epoch 10/200
--------------------------------------------------


                                                             

Train Loss: 0.0114, Train Acc: 0.9333
Val Loss: 0.0106, Val Acc: 0.9358
Val F1: 0.9360, Val ROC-AUC: 0.9846
Learning Rate: 0.000499
✓ New best model saved! (F1: 0.9360)

Epoch 11/200
--------------------------------------------------


                                                             

Train Loss: 0.0112, Train Acc: 0.9342
Val Loss: 0.0105, Val Acc: 0.9375
Val F1: 0.9376, Val ROC-AUC: 0.9853
Learning Rate: 0.000499
✓ New best model saved! (F1: 0.9376)

Epoch 12/200
--------------------------------------------------


                                                             

Train Loss: 0.0111, Train Acc: 0.9355
Val Loss: 0.0112, Val Acc: 0.9345
Val F1: 0.9347, Val ROC-AUC: 0.9846
Learning Rate: 0.000498

Epoch 13/200
--------------------------------------------------


                                                             

Train Loss: 0.0110, Train Acc: 0.9362
Val Loss: 0.0106, Val Acc: 0.9382
Val F1: 0.9383, Val ROC-AUC: 0.9851
Learning Rate: 0.000498
✓ New best model saved! (F1: 0.9383)

Epoch 14/200
--------------------------------------------------


                                                             

Train Loss: 0.0110, Train Acc: 0.9350
Val Loss: 0.0104, Val Acc: 0.9402
Val F1: 0.9403, Val ROC-AUC: 0.9856
Learning Rate: 0.000497
✓ New best model saved! (F1: 0.9403)

Epoch 15/200
--------------------------------------------------


                                                             

Train Loss: 0.0108, Train Acc: 0.9378
Val Loss: 0.0105, Val Acc: 0.9361
Val F1: 0.9362, Val ROC-AUC: 0.9852
Learning Rate: 0.000497

Epoch 16/200
--------------------------------------------------


                                                             

Train Loss: 0.0109, Train Acc: 0.9367
Val Loss: 0.0106, Val Acc: 0.9365
Val F1: 0.9366, Val ROC-AUC: 0.9848
Learning Rate: 0.000496

Epoch 17/200
--------------------------------------------------


                                                             

Train Loss: 0.0109, Train Acc: 0.9368
Val Loss: 0.0105, Val Acc: 0.9342
Val F1: 0.9344, Val ROC-AUC: 0.9857
Learning Rate: 0.000495

Epoch 18/200
--------------------------------------------------


                                                             

Train Loss: 0.0106, Train Acc: 0.9373
Val Loss: 0.0103, Val Acc: 0.9369
Val F1: 0.9371, Val ROC-AUC: 0.9861
Learning Rate: 0.000495

Epoch 19/200
--------------------------------------------------


                                                             

Train Loss: 0.0106, Train Acc: 0.9392
Val Loss: 0.0102, Val Acc: 0.9394
Val F1: 0.9395, Val ROC-AUC: 0.9858
Learning Rate: 0.000494

Epoch 20/200
--------------------------------------------------


                                                             

Train Loss: 0.0104, Train Acc: 0.9405
Val Loss: 0.0102, Val Acc: 0.9399
Val F1: 0.9399, Val ROC-AUC: 0.9857
Learning Rate: 0.000493

Epoch 21/200
--------------------------------------------------


                                                             

Train Loss: 0.0104, Train Acc: 0.9391
Val Loss: 0.0104, Val Acc: 0.9358
Val F1: 0.9360, Val ROC-AUC: 0.9863
Learning Rate: 0.000492

Epoch 22/200
--------------------------------------------------


                                                             

Train Loss: 0.0104, Train Acc: 0.9400
Val Loss: 0.0103, Val Acc: 0.9371
Val F1: 0.9373, Val ROC-AUC: 0.9860
Learning Rate: 0.000491

Epoch 23/200
--------------------------------------------------


                                                             

Train Loss: 0.0103, Train Acc: 0.9402
Val Loss: 0.0101, Val Acc: 0.9371
Val F1: 0.9373, Val ROC-AUC: 0.9864
Learning Rate: 0.000490

Epoch 24/200
--------------------------------------------------


                                                             

Train Loss: 0.0104, Train Acc: 0.9402
Val Loss: 0.0104, Val Acc: 0.9350
Val F1: 0.9352, Val ROC-AUC: 0.9858
Learning Rate: 0.000488

Epoch 25/200
--------------------------------------------------


                                                             

Train Loss: 0.0103, Train Acc: 0.9402
Val Loss: 0.0102, Val Acc: 0.9408
Val F1: 0.9408, Val ROC-AUC: 0.9863
Learning Rate: 0.000487
✓ New best model saved! (F1: 0.9408)

Epoch 26/200
--------------------------------------------------


                                                             

Train Loss: 0.0102, Train Acc: 0.9410
Val Loss: 0.0102, Val Acc: 0.9396
Val F1: 0.9397, Val ROC-AUC: 0.9860
Learning Rate: 0.000486

Epoch 27/200
--------------------------------------------------


                                                             

Train Loss: 0.0102, Train Acc: 0.9411
Val Loss: 0.0106, Val Acc: 0.9372
Val F1: 0.9373, Val ROC-AUC: 0.9862
Learning Rate: 0.000484

Epoch 28/200
--------------------------------------------------


                                                             

Train Loss: 0.0103, Train Acc: 0.9398
Val Loss: 0.0104, Val Acc: 0.9372
Val F1: 0.9373, Val ROC-AUC: 0.9856
Learning Rate: 0.000483

Epoch 29/200
--------------------------------------------------


                                                             

Train Loss: 0.0100, Train Acc: 0.9414
Val Loss: 0.0102, Val Acc: 0.9390
Val F1: 0.9391, Val ROC-AUC: 0.9861
Learning Rate: 0.000482

Epoch 30/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9416
Val Loss: 0.0101, Val Acc: 0.9393
Val F1: 0.9394, Val ROC-AUC: 0.9864
Learning Rate: 0.000480

Epoch 31/200
--------------------------------------------------


                                                             

Train Loss: 0.0102, Train Acc: 0.9408
Val Loss: 0.0104, Val Acc: 0.9347
Val F1: 0.9349, Val ROC-AUC: 0.9858
Learning Rate: 0.000478

Epoch 32/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9417
Val Loss: 0.0101, Val Acc: 0.9416
Val F1: 0.9417, Val ROC-AUC: 0.9863
Learning Rate: 0.000477
✓ New best model saved! (F1: 0.9417)

Epoch 33/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9413
Val Loss: 0.0098, Val Acc: 0.9395
Val F1: 0.9397, Val ROC-AUC: 0.9871
Learning Rate: 0.000475

Epoch 34/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9414
Val Loss: 0.0101, Val Acc: 0.9383
Val F1: 0.9385, Val ROC-AUC: 0.9865
Learning Rate: 0.000473

Epoch 35/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9413
Val Loss: 0.0099, Val Acc: 0.9388
Val F1: 0.9390, Val ROC-AUC: 0.9866
Learning Rate: 0.000471

Epoch 36/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9413
Val Loss: 0.0100, Val Acc: 0.9399
Val F1: 0.9400, Val ROC-AUC: 0.9869
Learning Rate: 0.000470

Epoch 37/200
--------------------------------------------------


                                                             

Train Loss: 0.0099, Train Acc: 0.9422
Val Loss: 0.0100, Val Acc: 0.9388
Val F1: 0.9389, Val ROC-AUC: 0.9866
Learning Rate: 0.000468

Epoch 38/200
--------------------------------------------------


                                                             

Train Loss: 0.0098, Train Acc: 0.9431
Val Loss: 0.0102, Val Acc: 0.9389
Val F1: 0.9390, Val ROC-AUC: 0.9863
Learning Rate: 0.000466

Epoch 39/200
--------------------------------------------------


                                                             

Train Loss: 0.0100, Train Acc: 0.9435
Val Loss: 0.0105, Val Acc: 0.9380
Val F1: 0.9382, Val ROC-AUC: 0.9859
Learning Rate: 0.000463

Epoch 40/200
--------------------------------------------------


                                                             

Train Loss: 0.0099, Train Acc: 0.9426
Val Loss: 0.0109, Val Acc: 0.9230
Val F1: 0.9233, Val ROC-AUC: 0.9861
Learning Rate: 0.000461

Epoch 41/200
--------------------------------------------------


                                                             

Train Loss: 0.0100, Train Acc: 0.9423
Val Loss: 0.0102, Val Acc: 0.9381
Val F1: 0.9383, Val ROC-AUC: 0.9868
Learning Rate: 0.000459

Epoch 42/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9425
Val Loss: 0.0099, Val Acc: 0.9416
Val F1: 0.9416, Val ROC-AUC: 0.9867
Learning Rate: 0.000457

Epoch 43/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9449
Val Loss: 0.0100, Val Acc: 0.9383
Val F1: 0.9385, Val ROC-AUC: 0.9868
Learning Rate: 0.000455

Epoch 44/200
--------------------------------------------------


                                                             

Train Loss: 0.0101, Train Acc: 0.9423
Val Loss: 0.0100, Val Acc: 0.9422
Val F1: 0.9423, Val ROC-AUC: 0.9866
Learning Rate: 0.000452
✓ New best model saved! (F1: 0.9423)

Epoch 45/200
--------------------------------------------------


                                                             

Train Loss: 0.0099, Train Acc: 0.9440
Val Loss: 0.0098, Val Acc: 0.9430
Val F1: 0.9430, Val ROC-AUC: 0.9870
Learning Rate: 0.000450
✓ New best model saved! (F1: 0.9430)

Epoch 46/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9441
Val Loss: 0.0102, Val Acc: 0.9398
Val F1: 0.9399, Val ROC-AUC: 0.9860
Learning Rate: 0.000448

Epoch 47/200
--------------------------------------------------


                                                             

Train Loss: 0.0096, Train Acc: 0.9443
Val Loss: 0.0101, Val Acc: 0.9397
Val F1: 0.9398, Val ROC-AUC: 0.9863
Learning Rate: 0.000445

Epoch 48/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9440
Val Loss: 0.0099, Val Acc: 0.9409
Val F1: 0.9410, Val ROC-AUC: 0.9868
Learning Rate: 0.000442

Epoch 49/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9442
Val Loss: 0.0099, Val Acc: 0.9404
Val F1: 0.9405, Val ROC-AUC: 0.9868
Learning Rate: 0.000440

Epoch 50/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9437
Val Loss: 0.0102, Val Acc: 0.9397
Val F1: 0.9399, Val ROC-AUC: 0.9865
Learning Rate: 0.000437

Epoch 51/200
--------------------------------------------------


                                                             

Train Loss: 0.0096, Train Acc: 0.9442
Val Loss: 0.0101, Val Acc: 0.9395
Val F1: 0.9396, Val ROC-AUC: 0.9867
Learning Rate: 0.000435

Epoch 52/200
--------------------------------------------------


                                                             

Train Loss: 0.0096, Train Acc: 0.9442
Val Loss: 0.0100, Val Acc: 0.9413
Val F1: 0.9414, Val ROC-AUC: 0.9868
Learning Rate: 0.000432

Epoch 53/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9459
Val Loss: 0.0099, Val Acc: 0.9396
Val F1: 0.9398, Val ROC-AUC: 0.9870
Learning Rate: 0.000429

Epoch 54/200
--------------------------------------------------


                                                             

Train Loss: 0.0096, Train Acc: 0.9451
Val Loss: 0.0102, Val Acc: 0.9344
Val F1: 0.9346, Val ROC-AUC: 0.9869
Learning Rate: 0.000426

Epoch 55/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9457
Val Loss: 0.0099, Val Acc: 0.9386
Val F1: 0.9387, Val ROC-AUC: 0.9871
Learning Rate: 0.000423

Epoch 56/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9446
Val Loss: 0.0098, Val Acc: 0.9420
Val F1: 0.9421, Val ROC-AUC: 0.9869
Learning Rate: 0.000420

Epoch 57/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9434
Val Loss: 0.0103, Val Acc: 0.9366
Val F1: 0.9368, Val ROC-AUC: 0.9868
Learning Rate: 0.000417

Epoch 58/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9459
Val Loss: 0.0100, Val Acc: 0.9400
Val F1: 0.9401, Val ROC-AUC: 0.9869
Learning Rate: 0.000414

Epoch 59/200
--------------------------------------------------


                                                             

Train Loss: 0.0092, Train Acc: 0.9470
Val Loss: 0.0100, Val Acc: 0.9379
Val F1: 0.9381, Val ROC-AUC: 0.9870
Learning Rate: 0.000411

Epoch 60/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9452
Val Loss: 0.0107, Val Acc: 0.9324
Val F1: 0.9327, Val ROC-AUC: 0.9866
Learning Rate: 0.000408

Epoch 61/200
--------------------------------------------------


                                                             

Train Loss: 0.0097, Train Acc: 0.9442
Val Loss: 0.0099, Val Acc: 0.9430
Val F1: 0.9431, Val ROC-AUC: 0.9871
Learning Rate: 0.000405
✓ New best model saved! (F1: 0.9431)

Epoch 62/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9445
Val Loss: 0.0100, Val Acc: 0.9367
Val F1: 0.9369, Val ROC-AUC: 0.9872
Learning Rate: 0.000402

Epoch 63/200
--------------------------------------------------


                                                             

Train Loss: 0.0092, Train Acc: 0.9468
Val Loss: 0.0102, Val Acc: 0.9391
Val F1: 0.9392, Val ROC-AUC: 0.9868
Learning Rate: 0.000399

Epoch 64/200
--------------------------------------------------


                                                             

Train Loss: 0.0095, Train Acc: 0.9452
Val Loss: 0.0105, Val Acc: 0.9364
Val F1: 0.9366, Val ROC-AUC: 0.9870
Learning Rate: 0.000396

Epoch 65/200
--------------------------------------------------


                                                             

Train Loss: 0.0094, Train Acc: 0.9457
Val Loss: 0.0104, Val Acc: 0.9394
Val F1: 0.9395, Val ROC-AUC: 0.9865
Learning Rate: 0.000392

Epoch 66/200
--------------------------------------------------


                                                             

Train Loss: 0.0094, Train Acc: 0.9457
Val Loss: 0.0101, Val Acc: 0.9411
Val F1: 0.9412, Val ROC-AUC: 0.9871
Learning Rate: 0.000389

Epoch 67/200
--------------------------------------------------


                                                             

Train Loss: 0.0093, Train Acc: 0.9460
Val Loss: 0.0099, Val Acc: 0.9424
Val F1: 0.9425, Val ROC-AUC: 0.9874
Learning Rate: 0.000386

Epoch 68/200
--------------------------------------------------


                                                             

Train Loss: 0.0094, Train Acc: 0.9462
Val Loss: 0.0102, Val Acc: 0.9390
Val F1: 0.9391, Val ROC-AUC: 0.9871
Learning Rate: 0.000382

Epoch 69/200
--------------------------------------------------


                                                             

Train Loss: 0.0094, Train Acc: 0.9469
Val Loss: 0.0099, Val Acc: 0.9415
Val F1: 0.9416, Val ROC-AUC: 0.9873
Learning Rate: 0.000379

Epoch 70/200
--------------------------------------------------


                                                             

Train Loss: 0.0092, Train Acc: 0.9471
Val Loss: 0.0098, Val Acc: 0.9411
Val F1: 0.9413, Val ROC-AUC: 0.9875
Learning Rate: 0.000375

Epoch 71/200
--------------------------------------------------


                                                             

Train Loss: 0.0092, Train Acc: 0.9466
Val Loss: 0.0099, Val Acc: 0.9432
Val F1: 0.9433, Val ROC-AUC: 0.9872
Learning Rate: 0.000372
✓ New best model saved! (F1: 0.9433)

Epoch 72/200
--------------------------------------------------


                                                             

Train Loss: 0.0091, Train Acc: 0.9485
Val Loss: 0.0103, Val Acc: 0.9347
Val F1: 0.9349, Val ROC-AUC: 0.9866
Learning Rate: 0.000368

Epoch 73/200
--------------------------------------------------


                                                             

Train Loss: 0.0090, Train Acc: 0.9476
Val Loss: 0.0097, Val Acc: 0.9454
Val F1: 0.9455, Val ROC-AUC: 0.9876
Learning Rate: 0.000365
✓ New best model saved! (F1: 0.9455)

Epoch 74/200
--------------------------------------------------


                                                             

Train Loss: 0.0089, Train Acc: 0.9490
Val Loss: 0.0100, Val Acc: 0.9436
Val F1: 0.9437, Val ROC-AUC: 0.9872
Learning Rate: 0.000361

Epoch 75/200
--------------------------------------------------


                                                             

Train Loss: 0.0093, Train Acc: 0.9469
Val Loss: 0.0098, Val Acc: 0.9409
Val F1: 0.9410, Val ROC-AUC: 0.9873
Learning Rate: 0.000357

Epoch 76/200
--------------------------------------------------


                                                             

Train Loss: 0.0092, Train Acc: 0.9466
Val Loss: 0.0100, Val Acc: 0.9382
Val F1: 0.9383, Val ROC-AUC: 0.9871
Learning Rate: 0.000354

Epoch 77/200
--------------------------------------------------


                                                             

Train Loss: 0.0089, Train Acc: 0.9485
Val Loss: 0.0096, Val Acc: 0.9434
Val F1: 0.9435, Val ROC-AUC: 0.9877
Learning Rate: 0.000350

Epoch 78/200
--------------------------------------------------


                                                             

Train Loss: 0.0088, Train Acc: 0.9494
Val Loss: 0.0101, Val Acc: 0.9445
Val F1: 0.9445, Val ROC-AUC: 0.9874
Learning Rate: 0.000346

Epoch 79/200
--------------------------------------------------


                                                             

Train Loss: 0.0090, Train Acc: 0.9490
Val Loss: 0.0097, Val Acc: 0.9429
Val F1: 0.9430, Val ROC-AUC: 0.9875
Learning Rate: 0.000343

Epoch 80/200
--------------------------------------------------


                                                             

Train Loss: 0.0090, Train Acc: 0.9478
Val Loss: 0.0098, Val Acc: 0.9432
Val F1: 0.9432, Val ROC-AUC: 0.9874
Learning Rate: 0.000339

Epoch 81/200
--------------------------------------------------


                                                             

Train Loss: 0.0090, Train Acc: 0.9480
Val Loss: 0.0097, Val Acc: 0.9442
Val F1: 0.9442, Val ROC-AUC: 0.9875
Learning Rate: 0.000335

Epoch 82/200
--------------------------------------------------


                                                             

Train Loss: 0.0087, Train Acc: 0.9503
Val Loss: 0.0100, Val Acc: 0.9400
Val F1: 0.9401, Val ROC-AUC: 0.9875
Learning Rate: 0.000331

Epoch 83/200
--------------------------------------------------


                                                             

Train Loss: 0.0089, Train Acc: 0.9483
Val Loss: 0.0098, Val Acc: 0.9391
Val F1: 0.9393, Val ROC-AUC: 0.9873
Learning Rate: 0.000328

Epoch 84/200
--------------------------------------------------


                                                             

Train Loss: 0.0088, Train Acc: 0.9497
Val Loss: 0.0097, Val Acc: 0.9439
Val F1: 0.9439, Val ROC-AUC: 0.9873
Learning Rate: 0.000324

Epoch 85/200
--------------------------------------------------


                                                             

Train Loss: 0.0088, Train Acc: 0.9496
Val Loss: 0.0097, Val Acc: 0.9422
Val F1: 0.9423, Val ROC-AUC: 0.9874
Learning Rate: 0.000320

Epoch 86/200
--------------------------------------------------


                                                             

Train Loss: 0.0089, Train Acc: 0.9493
Val Loss: 0.0099, Val Acc: 0.9398
Val F1: 0.9400, Val ROC-AUC: 0.9873
Learning Rate: 0.000316

Epoch 87/200
--------------------------------------------------


                                                             

Train Loss: 0.0087, Train Acc: 0.9504
Val Loss: 0.0097, Val Acc: 0.9446
Val F1: 0.9447, Val ROC-AUC: 0.9876
Learning Rate: 0.000312

Epoch 88/200
--------------------------------------------------


                                                             

Train Loss: 0.0087, Train Acc: 0.9498
Val Loss: 0.0096, Val Acc: 0.9448
Val F1: 0.9448, Val ROC-AUC: 0.9875
Learning Rate: 0.000308

Epoch 89/200
--------------------------------------------------


                                                             

Train Loss: 0.0087, Train Acc: 0.9506
Val Loss: 0.0099, Val Acc: 0.9428
Val F1: 0.9429, Val ROC-AUC: 0.9872
Learning Rate: 0.000304

Epoch 90/200
--------------------------------------------------


                                                             

Train Loss: 0.0085, Train Acc: 0.9509
Val Loss: 0.0098, Val Acc: 0.9421
Val F1: 0.9422, Val ROC-AUC: 0.9874
Learning Rate: 0.000300

Epoch 91/200
--------------------------------------------------


                                                             

Train Loss: 0.0087, Train Acc: 0.9504
Val Loss: 0.0096, Val Acc: 0.9433
Val F1: 0.9434, Val ROC-AUC: 0.9878
Learning Rate: 0.000296

Epoch 92/200
--------------------------------------------------


                                                             

Train Loss: 0.0085, Train Acc: 0.9508
Val Loss: 0.0096, Val Acc: 0.9452
Val F1: 0.9453, Val ROC-AUC: 0.9880
Learning Rate: 0.000293

Epoch 93/200
--------------------------------------------------


                                                             

Train Loss: 0.0086, Train Acc: 0.9508
Val Loss: 0.0095, Val Acc: 0.9448
Val F1: 0.9448, Val ROC-AUC: 0.9877
Learning Rate: 0.000289

Early stopping at epoch 93
Best F1: 0.9455 at epoch 73


                                                             


Final Test Results:
Accuracy: 0.9442
F1-score: 0.9443
ROC-AUC: 0.9872
Precision: 0.9445
Recall: 0.9442
Best epoch: 73




## Save Predictions with Probabilities


In [12]:
# Get test dataset to extract phoneme metadata
test_df = df[df['split'] == 'test'].reset_index(drop=True)

# Create predictions dataframe with probabilities
predictions_data = []
for idx, row in test_df.iterrows():
    predictions_data.append({
        'phoneme_id': row['phoneme_id'],
        'utterance_id': row.get('utterance_id', None),
        'phoneme': row.get('phoneme', row.get('class', None)),
        'true_class': row['class'],
        'true_class_encoded': int(test_labels[idx]),
        'predicted_class_encoded': int(test_preds[idx]),
        'predicted_class': 'd' if test_preds[idx] == 0 else 't',
        'prob_class_0': float(test_probs[idx][0]),
        'prob_class_1': float(test_probs[idx][1]),
        'max_prob': float(np.max(test_probs[idx])),
        'is_correct': int(test_labels[idx] == test_preds[idx]),
        'confidence': float(np.max(test_probs[idx])) if test_labels[idx] == test_preds[idx] else float(test_probs[idx][test_preds[idx]]),
        'duration_ms': row.get('duration_ms', None)
    })

predictions_df = pd.DataFrame(predictions_data)

# Save to CSV
predictions_df.to_csv(save_dir / 'test_predictions_with_probs.csv', index=False)
print(f"Saved predictions with probabilities to: {save_dir / 'test_predictions_with_probs.csv'}")
print(f"Total predictions: {len(predictions_df)}")
print(f"Correct predictions: {predictions_df['is_correct'].sum()}")
print(f"Incorrect predictions: {(~predictions_df['is_correct'].astype(bool)).sum()}")

# Save summary statistics
summary_stats = {
    'total_samples': len(predictions_df),
    'correct_predictions': int(predictions_df['is_correct'].sum()),
    'incorrect_predictions': int((~predictions_df['is_correct'].astype(bool)).sum()),
    'accuracy': float(predictions_df['is_correct'].mean()),
    'avg_confidence_correct': float(predictions_df[predictions_df['is_correct'] == 1]['confidence'].mean()),
    'avg_confidence_incorrect': float(predictions_df[predictions_df['is_correct'] == 0]['confidence'].mean()),
    'min_confidence_incorrect': float(predictions_df[predictions_df['is_correct'] == 0]['confidence'].min()),
    'max_confidence_incorrect': float(predictions_df[predictions_df['is_correct'] == 0]['confidence'].max()),
    'high_confidence_errors': int(((predictions_df['is_correct'] == 0) & (predictions_df['confidence'] > 0.8)).sum()),
    'low_confidence_errors': int(((predictions_df['is_correct'] == 0) & (predictions_df['confidence'] < 0.6)).sum()),
}

with open(save_dir / 'predictions_summary.json', 'w') as f:
    json.dump(summary_stats, f, indent=2)

print(f"\nSummary Statistics:")
print(f"- Average confidence (correct): {summary_stats['avg_confidence_correct']:.4f}")
print(f"- Average confidence (incorrect): {summary_stats['avg_confidence_incorrect']:.4f}")
print(f"- High confidence errors (>0.8): {summary_stats['high_confidence_errors']}")
print(f"- Low confidence errors (<0.6): {summary_stats['low_confidence_errors']}")


Saved predictions with probabilities to: /Volumes/SSanDisk/SpeechRec-German/artifacts/d-t_dl_models_with_context_v2/improved_models/hybrid_cnn_mlp_v4_3_enhanced/test_predictions_with_probs.csv
Total predictions: 19949
Correct predictions: 18836
Incorrect predictions: 1113

Summary Statistics:
- Average confidence (correct): 0.8499
- Average confidence (incorrect): 0.6387
- High confidence errors (>0.8): 115
- Low confidence errors (<0.6): 498


## Save Validation Predictions


In [13]:
# Get validation predictions
val_metrics, val_preds, val_labels, val_probs = evaluate_model(model, val_hybrid_loader, criterion, device)
val_df = df[df['split'] == 'val'].reset_index(drop=True)

val_predictions_data = []
for idx, row in val_df.iterrows():
    val_predictions_data.append({
        'phoneme_id': row['phoneme_id'],
        'utterance_id': row.get('utterance_id', None),
        'phoneme': row.get('phoneme', row.get('class', None)),
        'true_class': row['class'],
        'true_class_encoded': int(val_labels[idx]),
        'predicted_class_encoded': int(val_preds[idx]),
        'predicted_class': 'd' if val_preds[idx] == 0 else 't',
        'prob_class_0': float(val_probs[idx][0]),
        'prob_class_1': float(val_probs[idx][1]),
        'max_prob': float(np.max(val_probs[idx])),
        'is_correct': int(val_labels[idx] == val_preds[idx]),
        'confidence': float(np.max(val_probs[idx])) if val_labels[idx] == val_preds[idx] else float(val_probs[idx][val_preds[idx]]),
        'duration_ms': row.get('duration_ms', None)
    })

val_predictions_df = pd.DataFrame(val_predictions_data)
val_predictions_df.to_csv(save_dir / 'val_predictions_with_probs.csv', index=False)
print(f"Saved validation predictions to: {save_dir / 'val_predictions_with_probs.csv'}")


                                                             

Saved validation predictions to: /Volumes/SSanDisk/SpeechRec-German/artifacts/d-t_dl_models_with_context_v2/improved_models/hybrid_cnn_mlp_v4_3_enhanced/val_predictions_with_probs.csv


## Confusion Matrix Analysis

Visualize confusion matrix to understand model errors per class

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Create confusion matrix for test set
cm = confusion_matrix(test_labels, test_preds)

# Get class names from label encoder
class_names = le.classes_

# Create figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Confusion matrix with counts
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=axes[0]
)
axes[0].set_xlabel('Predicted Class', fontsize=12)
axes[0].set_ylabel('True Class', fontsize=12)
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

# Plot 2: Confusion matrix with percentages
cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
sns.heatmap(
    cm_percent, 
    annot=True, 
    fmt='.1f', 
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=axes[1]
)
axes[1].set_xlabel('Predicted Class', fontsize=12)
axes[1].set_ylabel('True Class', fontsize=12)
axes[1].set_title('Confusion Matrix (Percentages)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(save_dir / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
logger.info(f"Confusion matrix saved to: {save_dir / 'confusion_matrix.png'}")
plt.show()

# Print detailed statistics
logger.info(f"\n{'='*60}")
logger.info(f"Confusion Matrix Analysis:")
logger.info(f"{'='*60}")
for i, true_class in enumerate(class_names):
    total_true = cm[i].sum()
    correct = cm[i, i]
    errors = total_true - correct
    accuracy_per_class = (correct / total_true * 100) if total_true > 0 else 0
    
    logger.info(f"\nTrue Class: {true_class}")
    logger.info(f"  Total samples: {total_true}")
    logger.info(f"  Correctly predicted: {correct} ({accuracy_per_class:.2f}%)")
    logger.info(f"  Incorrectly predicted: {errors}")
    
    # Show error breakdown
    for j, pred_class in enumerate(class_names):
        if i != j and cm[i, j] > 0:
            error_pct = (cm[i, j] / total_true * 100) if total_true > 0 else 0
            logger.info(f"    → Misclassified as '{pred_class}': {cm[i, j]} ({error_pct:.2f}%)")

# Calculate overall metrics
total_samples = cm.sum()
correct_predictions = cm.trace()
total_errors = total_samples - correct_predictions

logger.info(f"\n{'='*60}")
logger.info(f"Overall Statistics:")
logger.info(f"  Total test samples: {total_samples}")
logger.info(f"  Correct predictions: {correct_predictions} ({correct_predictions/total_samples*100:.2f}%)")
logger.info(f"  Total errors: {total_errors} ({total_errors/total_samples*100:.2f}%)")
