# Hybrid Transformer Fusion Training (V2)

Trains a hybrid model that combines CNN+MLP and Transformer branches:
- CNN branch: processes spectrograms (from HybridCNNMLP_V2)
- MLP branch: processes extracted features (from HybridCNNMLP_V2)
- Transformer branch: processes sequences (from TransformerSequence_V2)
- Learnable weighted fusion to balance all three branches
- Improved training: dropout 0.15-0.2, weight decay, warmup, cosine annealing

## Architecture
- **CNN Branch**: 64→128→256 channels with residual connections
- **MLP Branch**: 512→256→128 neurons
- **Transformer Branch**: d_model=256, 6 layers, dim_feedforward=1024
- **Fusion**: Learnable weights + concatenation → 256→128→64→2

## Training Strategy
- Dropout: 0.15–0.2 range to reduce overfitting (0.15 is used in this run)
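- Optimizer: AdamW (lr 1e-4) with weight decay 0.01 for regularization
- LR schedule: 10 linear warmup epochs, then cosine annealing down to 1e-6 over at most 110 epochs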
- Label smoothing: 0.1
- Gradient clipping: 1.0
- Early stopping: patience 15


In [None]:
import sys
from pathlib import Path
import torch
import torch.nn as nn
import json

# Determine project root (parent of notebooks directory)
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name in ['notebooks', 'b-p_first_experiments'] else Path.cwd()
sys.path.insert(0, str(PROJECT_ROOT))

from models.hybrid.hybrid_transformer_fusion import HybridTransformerFusion_V2
from utils.training_utils import train_model, evaluate_model, WarmupCosineScheduler, LabelSmoothingCrossEntropy
from utils.data_loader import load_data, create_dataloaders

# Pick the best available accelerator: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Announce the selected device for every backend (not only MPS), so CUDA/CPU
# runs are equally traceable in the notebook output.
print(f"Using {device.type.upper()} device")

# Load the dataframe, spectrograms and scaled features, then build the
# per-architecture dataloaders ('hybrid' and 'sequence', each with
# train/val/test splits).
df, spectrograms_dict, feature_cols, feature_scaler, class_weights_dict = load_data(PROJECT_ROOT)
dataloaders = create_dataloaders(df, spectrograms_dict, feature_cols, feature_scaler, class_weights_dict, batch_size=64)

# Hybrid (spectrogram + features) loaders
train_hybrid_loader = dataloaders['hybrid']['train']
val_hybrid_loader = dataloaders['hybrid']['val']
test_hybrid_loader = dataloaders['hybrid']['test']

# Sequence (Transformer) loaders
train_sequence_loader = dataloaders['sequence']['train']
val_sequence_loader = dataloaders['sequence']['val']
test_sequence_loader = dataloaders['sequence']['test']

OUTPUT_DIR = PROJECT_ROOT / 'artifacts' / 'b-p_dl_models' / 'improved_models'

# Per-class loss weights for classes 0 and 1. The dict may be keyed by str
# ('0') or int (0); try the str key first, then the int key, else 1.0.
class_weights = torch.tensor(
    [class_weights_dict.get(str(c), class_weights_dict.get(c, 1.0)) for c in (0, 1)],
    dtype=torch.float32,
).to(device)

# Create custom dataloader that combines hybrid and sequence data
from torch.utils.data import Dataset, DataLoader

class FusionDataset(Dataset):
    """Pair a hybrid (spectrogram + features) dataset with a sequence dataset.

    Both datasets must be index-aligned: item ``i`` of each must describe the
    same sample. Each item is returned as
    ``((spectrogram, features, sequence), label)``.

    Args:
        hybrid_dataset: dataset yielding ``((spectrogram, features), label)``.
        sequence_dataset: dataset yielding ``(sequence, label)``.

    Raises:
        ValueError: at construction if the datasets differ in length, and on
            item access if the two labels at that index disagree.
    """

    def __init__(self, hybrid_dataset, sequence_dataset):
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and a silent misalignment would corrupt every pair.
        if len(hybrid_dataset) != len(sequence_dataset):
            raise ValueError(
                "Datasets must have same length "
                f"(hybrid={len(hybrid_dataset)}, sequence={len(sequence_dataset)})"
            )
        self.hybrid_dataset = hybrid_dataset
        self.sequence_dataset = sequence_dataset

    def __len__(self):
        return len(self.hybrid_dataset)

    @staticmethod
    def _labels_match(label_a, label_b):
        """Return True if two labels (Python/numpy scalars or tensors) are equal."""
        if isinstance(label_a, torch.Tensor) or isinstance(label_b, torch.Tensor):
            # ``.all()`` avoids the ambiguous-truth-value error that a bare
            # ``tensor == tensor`` raises for multi-element labels.
            return bool(torch.eq(torch.as_tensor(label_a), torch.as_tensor(label_b)).all())
        return bool(label_a == label_b)

    def __getitem__(self, idx):
        (spectrogram, features), label_hybrid = self.hybrid_dataset[idx]
        sequence, label_seq = self.sequence_dataset[idx]
        # Cheap per-item guard that the two datasets really are index-aligned.
        if not self._labels_match(label_hybrid, label_seq):
            raise ValueError(
                f"Labels must match at index {idx} "
                f"(hybrid={label_hybrid!r}, sequence={label_seq!r})"
            )
        return (spectrogram, features, sequence), label_hybrid

# Wrap the index-aligned hybrid and sequence datasets of each split.
train_fusion_ds = FusionDataset(train_hybrid_loader.dataset, train_sequence_loader.dataset)
val_fusion_ds = FusionDataset(val_hybrid_loader.dataset, val_sequence_loader.dataset)
test_fusion_ds = FusionDataset(test_hybrid_loader.dataset, test_sequence_loader.dataset)

# Reuse the batch size configured in create_dataloaders() instead of repeating
# the literal, so the fusion loaders cannot drift from it. DataLoader.batch_size
# is None when a loader was built from a custom batch_sampler; fall back to 64.
fusion_batch_size = train_hybrid_loader.batch_size or 64

# The training loader reuses the hybrid loader's sampler so it keeps the same
# sampling strategy (shuffling or weighted sampling — TODO confirm in
# create_dataloaders). This is valid because FusionDataset has exactly the
# same length and indexing as the hybrid dataset.
train_fusion_loader = DataLoader(
    train_fusion_ds,
    batch_size=fusion_batch_size,
    sampler=train_hybrid_loader.sampler,
    num_workers=0,
)
val_fusion_loader = DataLoader(val_fusion_ds, batch_size=fusion_batch_size, shuffle=False, num_workers=0)
test_fusion_loader = DataLoader(test_fusion_ds, batch_size=fusion_batch_size, shuffle=False, num_workers=0)

print("Created fusion dataloaders:")
print(f"  Train: {len(train_fusion_loader)} batches")
print(f"  Val: {len(val_fusion_loader)} batches")
print(f"  Test: {len(test_fusion_loader)} batches")


Using MPS device
Columns in df_phonemes: ['phoneme_id', 'utterance_id', 'phoneme', 'class', 'start_ms', 'end_ms', 'duration_ms', 'audio_path']
Columns in df_features: ['energy_rms', 'energy_rms_std', 'energy_zcr', 'energy_zcr_std', 'spectral_centroid', 'spectral_centroid_std', 'spectral_rolloff', 'spectral_rolloff_std', 'spectral_bandwidth', 'spectral_bandwidth_std', 'formant_f1', 'formant_f2', 'formant_f3', 'formant_f4', 'formant_f1_std', 'formant_f2_std', 'formant_f3_std', 'formant_f4_std', 'spectral_flatness', 'harmonic_noise_ratio', 'zcr_mean', 'energy_cv', 'phoneme_id', 'class', 'duration_ms', 'mfcc_mean_0', 'mfcc_mean_1', 'mfcc_mean_2', 'mfcc_mean_3', 'mfcc_mean_4', 'mfcc_mean_5', 'mfcc_mean_6', 'mfcc_mean_7', 'mfcc_mean_8', 'mfcc_mean_9', 'mfcc_mean_10', 'mfcc_mean_11', 'mfcc_mean_12', 'mfcc_std_0', 'mfcc_std_1', 'mfcc_std_2', 'mfcc_std_3', 'mfcc_std_4', 'mfcc_std_5', 'mfcc_std_6', 'mfcc_std_7', 'mfcc_std_8', 'mfcc_std_9', 'mfcc_std_10', 'mfcc_std_11', 'mfcc_std_12', 'delta_mfcc

## Model: Hybrid Transformer Fusion V2


In [None]:
# Hyperparameters are named once so the summary printed below always reflects
# the values actually passed to the model, loss, optimizer and scheduler.
DROPOUT = 0.15
LABEL_SMOOTHING = 0.1
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
MIN_LR = 1e-6
# NOTE: the training cell passes early_stopping_patience / max_grad_norm to
# train_model() as literals; keep these two values in sync with that call.
MAX_GRAD_NORM = 1.0
EARLY_STOPPING_PATIENCE = 15

model = HybridTransformerFusion_V2(
    n_features=len(feature_cols),
    num_classes=2,
    dropout=DROPOUT,
    d_model=256,
    nhead=8,
    num_layers=6,
    dim_feedforward=1024,
    batch_first=True,
).to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {model.get_config()['model_type']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Class-weighted loss with label smoothing
criterion = LabelSmoothingCrossEntropy(smoothing=LABEL_SMOOTHING, weight=class_weights)

# AdamW: decoupled weight decay for regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning-rate schedule: warmup followed by cosine annealing down to MIN_LR
num_epochs = 110
warmup_epochs = 10
scheduler = WarmupCosineScheduler(optimizer, warmup_epochs=warmup_epochs, total_epochs=num_epochs, min_lr=MIN_LR)

save_dir = OUTPUT_DIR / 'hybrid_transformer_fusion_v2'
save_dir.mkdir(parents=True, exist_ok=True)

print("\nTraining configuration:")
print(f"- Epochs: {num_epochs}")
print(f"- Warmup epochs: {warmup_epochs}")
print(f"- Initial LR: {optimizer.param_groups[0]['lr']}")
print(f"- Label smoothing: {LABEL_SMOOTHING}")
print(f"- Dropout: {DROPOUT}")
print(f"- Weight decay: {WEIGHT_DECAY}")
print(f"- Gradient clipping: {MAX_GRAD_NORM}")
print(f"- Early stopping patience: {EARLY_STOPPING_PATIENCE}")


Model: HybridTransformerFusion_V2
Total parameters: 6,417,477
Trainable parameters: 6,417,477

Training configuration:
- Epochs: 110
- Warmup epochs: 10
- Initial LR: 0.0001
- Label smoothing: 0.1
- Dropout: 0.15
- Weight decay: 0.01
- Gradient clipping: 1.0
- Early stopping patience: 15


In [None]:
history, best_epoch = train_model(
    model, train_fusion_loader, val_fusion_loader, criterion, optimizer, scheduler,
    device, num_epochs=num_epochs, save_dir=save_dir, model_name='hybrid_transformer_fusion_v2',
    early_stopping_patience=15, max_grad_norm=1.0
)

# Restore the best checkpoint (selected on validation F1) before testing.
# map_location lets the checkpoint load even if it was written on a different
# device (e.g. CUDA vs MPS) than the one currently selected.
checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
test_metrics, _, _, _ = evaluate_model(model, test_fusion_loader, criterion, device)

with open(save_dir / 'test_metrics.json', 'w', encoding='utf-8') as f:
    json.dump(test_metrics, f, indent=2)

separator = '=' * 60
print(f"\n{separator}")
print("Final Test Results:")
print(separator)
for label, key in (
    ("Accuracy", "accuracy"),
    ("F1-score", "f1"),
    ("ROC-AUC", "roc_auc"),
    ("Precision", "precision"),
    ("Recall", "recall"),
):
    print(f"{label}: {test_metrics[key]:.4f}")
print(f"Best epoch: {best_epoch}")

# Learned fusion weights, normalised with softmax for reporting only.
# NOTE(review): the (CNN, MLP, Transformer) order is assumed from the model's
# design; confirm against HybridTransformerFusion_V2.forward.
print(f"\n{separator}")
print("Fusion Weights (CNN, MLP, Transformer):")
print(separator)
with torch.no_grad():
    fusion_weights = torch.softmax(model.fusion_weights, dim=0)
for branch, weight in zip(("CNN", "MLP", "Transformer"), fusion_weights.tolist()):
    print(f"{branch} branch: {weight:.4f}")



Epoch 1/110
--------------------------------------------------


                                                           

Train Loss: 0.4259, Train Acc: 0.8750
Val Loss: 0.3505, Val Acc: 0.9093
Val F1: 0.9115, Val ROC-AUC: 0.9767
Learning Rate: 0.000010
✓ New best model saved! (F1: 0.9115)

Epoch 2/110
--------------------------------------------------


                                                           

Train Loss: 0.3628, Train Acc: 0.9182
Val Loss: 0.3350, Val Acc: 0.9201
Val F1: 0.9218, Val ROC-AUC: 0.9810
Learning Rate: 0.000020
✓ New best model saved! (F1: 0.9218)

Epoch 3/110
--------------------------------------------------


                                                           

Train Loss: 0.3581, Train Acc: 0.9238
Val Loss: 0.3327, Val Acc: 0.9201
Val F1: 0.9219, Val ROC-AUC: 0.9818
Learning Rate: 0.000030
✓ New best model saved! (F1: 0.9219)

Epoch 4/110
--------------------------------------------------


                                                           

Train Loss: 0.3525, Train Acc: 0.9260
Val Loss: 0.3267, Val Acc: 0.9288
Val F1: 0.9301, Val ROC-AUC: 0.9832
Learning Rate: 0.000040
✓ New best model saved! (F1: 0.9301)

Epoch 5/110
--------------------------------------------------


                                                           

Train Loss: 0.3425, Train Acc: 0.9339
Val Loss: 0.3369, Val Acc: 0.9164
Val F1: 0.9184, Val ROC-AUC: 0.9824
Learning Rate: 0.000050

Epoch 6/110
--------------------------------------------------


                                                           

Train Loss: 0.3394, Train Acc: 0.9361
Val Loss: 0.3270, Val Acc: 0.9271
Val F1: 0.9286, Val ROC-AUC: 0.9847
Learning Rate: 0.000060

Epoch 7/110
--------------------------------------------------


                                                           

Train Loss: 0.3374, Train Acc: 0.9357
Val Loss: 0.3159, Val Acc: 0.9381
Val F1: 0.9390, Val ROC-AUC: 0.9851
Learning Rate: 0.000070
✓ New best model saved! (F1: 0.9390)

Epoch 8/110
--------------------------------------------------


                                                           

Train Loss: 0.3305, Train Acc: 0.9414
Val Loss: 0.3184, Val Acc: 0.9359
Val F1: 0.9369, Val ROC-AUC: 0.9853
Learning Rate: 0.000080

Epoch 9/110
--------------------------------------------------


                                                           

Train Loss: 0.3240, Train Acc: 0.9459
Val Loss: 0.3183, Val Acc: 0.9400
Val F1: 0.9407, Val ROC-AUC: 0.9846
Learning Rate: 0.000090
✓ New best model saved! (F1: 0.9407)

Epoch 10/110
--------------------------------------------------


                                                           

Train Loss: 0.3229, Train Acc: 0.9461
Val Loss: 0.3122, Val Acc: 0.9466
Val F1: 0.9470, Val ROC-AUC: 0.9862
Learning Rate: 0.000100
✓ New best model saved! (F1: 0.9470)

Epoch 11/110
--------------------------------------------------


                                                           

Train Loss: 0.3216, Train Acc: 0.9473
Val Loss: 0.3208, Val Acc: 0.9456
Val F1: 0.9458, Val ROC-AUC: 0.9855
Learning Rate: 0.000100

Epoch 12/110
--------------------------------------------------


                                                           

Train Loss: 0.3148, Train Acc: 0.9522
Val Loss: 0.3214, Val Acc: 0.9372
Val F1: 0.9381, Val ROC-AUC: 0.9849
Learning Rate: 0.000100

Epoch 13/110
--------------------------------------------------


                                                           

Train Loss: 0.3102, Train Acc: 0.9543
Val Loss: 0.3108, Val Acc: 0.9496
Val F1: 0.9499, Val ROC-AUC: 0.9874
Learning Rate: 0.000100
✓ New best model saved! (F1: 0.9499)

Epoch 14/110
--------------------------------------------------


                                                           

Train Loss: 0.3028, Train Acc: 0.9589
Val Loss: 0.3221, Val Acc: 0.9494
Val F1: 0.9495, Val ROC-AUC: 0.9844
Learning Rate: 0.000100

Epoch 15/110
--------------------------------------------------


                                                           

Train Loss: 0.3039, Train Acc: 0.9581
Val Loss: 0.3139, Val Acc: 0.9445
Val F1: 0.9451, Val ROC-AUC: 0.9857
Learning Rate: 0.000099

Epoch 16/110
--------------------------------------------------


                                                           

Train Loss: 0.2982, Train Acc: 0.9622
Val Loss: 0.3191, Val Acc: 0.9438
Val F1: 0.9443, Val ROC-AUC: 0.9854
Learning Rate: 0.000099

Epoch 17/110
--------------------------------------------------


                                                           

Train Loss: 0.2912, Train Acc: 0.9671
Val Loss: 0.3144, Val Acc: 0.9413
Val F1: 0.9421, Val ROC-AUC: 0.9866
Learning Rate: 0.000099

Epoch 18/110
--------------------------------------------------


                                                           

Train Loss: 0.2919, Train Acc: 0.9665
Val Loss: 0.3160, Val Acc: 0.9494
Val F1: 0.9496, Val ROC-AUC: 0.9857
Learning Rate: 0.000098

Epoch 19/110
--------------------------------------------------


                                                           

Train Loss: 0.2880, Train Acc: 0.9698
Val Loss: 0.3210, Val Acc: 0.9481
Val F1: 0.9483, Val ROC-AUC: 0.9839
Learning Rate: 0.000098

Epoch 20/110
--------------------------------------------------


                                                           

Train Loss: 0.2856, Train Acc: 0.9710
Val Loss: 0.3202, Val Acc: 0.9473
Val F1: 0.9476, Val ROC-AUC: 0.9863
Learning Rate: 0.000098

Epoch 21/110
--------------------------------------------------


                                                           

Train Loss: 0.2839, Train Acc: 0.9713
Val Loss: 0.3160, Val Acc: 0.9479
Val F1: 0.9483, Val ROC-AUC: 0.9866
Learning Rate: 0.000097

Epoch 22/110
--------------------------------------------------


                                                           

Train Loss: 0.2832, Train Acc: 0.9723
Val Loss: 0.3168, Val Acc: 0.9449
Val F1: 0.9454, Val ROC-AUC: 0.9853
Learning Rate: 0.000097

Epoch 23/110
--------------------------------------------------


                                                           

Train Loss: 0.2767, Train Acc: 0.9757
Val Loss: 0.3217, Val Acc: 0.9475
Val F1: 0.9478, Val ROC-AUC: 0.9847
Learning Rate: 0.000096

Epoch 24/110
--------------------------------------------------


                                                           

Train Loss: 0.2773, Train Acc: 0.9753
Val Loss: 0.3197, Val Acc: 0.9471
Val F1: 0.9475, Val ROC-AUC: 0.9839
Learning Rate: 0.000095

Epoch 25/110
--------------------------------------------------


                                                           

Train Loss: 0.2752, Train Acc: 0.9770
Val Loss: 0.3277, Val Acc: 0.9520
Val F1: 0.9519, Val ROC-AUC: 0.9807
Learning Rate: 0.000095
✓ New best model saved! (F1: 0.9519)

Epoch 26/110
--------------------------------------------------


                                                           

Train Loss: 0.2726, Train Acc: 0.9792
Val Loss: 0.3229, Val Acc: 0.9528
Val F1: 0.9527, Val ROC-AUC: 0.9843
Learning Rate: 0.000094
✓ New best model saved! (F1: 0.9527)

Epoch 27/110
--------------------------------------------------


                                                           

Train Loss: 0.2740, Train Acc: 0.9775
Val Loss: 0.3169, Val Acc: 0.9481
Val F1: 0.9486, Val ROC-AUC: 0.9851
Learning Rate: 0.000093

Epoch 28/110
--------------------------------------------------


                                                           

Train Loss: 0.2679, Train Acc: 0.9810
Val Loss: 0.3211, Val Acc: 0.9492
Val F1: 0.9494, Val ROC-AUC: 0.9834
Learning Rate: 0.000092

Epoch 29/110
--------------------------------------------------


                                                           

Train Loss: 0.2680, Train Acc: 0.9805
Val Loss: 0.3249, Val Acc: 0.9545
Val F1: 0.9544, Val ROC-AUC: 0.9826
Learning Rate: 0.000091
✓ New best model saved! (F1: 0.9544)

Epoch 30/110
--------------------------------------------------


                                                           

Train Loss: 0.2671, Train Acc: 0.9815
Val Loss: 0.3201, Val Acc: 0.9451
Val F1: 0.9457, Val ROC-AUC: 0.9848
Learning Rate: 0.000091

Epoch 31/110
--------------------------------------------------


                                                           

Train Loss: 0.2646, Train Acc: 0.9833
Val Loss: 0.3206, Val Acc: 0.9518
Val F1: 0.9520, Val ROC-AUC: 0.9830
Learning Rate: 0.000090

Epoch 32/110
--------------------------------------------------


                                                           

Train Loss: 0.2642, Train Acc: 0.9826
Val Loss: 0.3216, Val Acc: 0.9505
Val F1: 0.9507, Val ROC-AUC: 0.9826
Learning Rate: 0.000089

Epoch 33/110
--------------------------------------------------


                                                           

Train Loss: 0.2630, Train Acc: 0.9833
Val Loss: 0.3185, Val Acc: 0.9522
Val F1: 0.9524, Val ROC-AUC: 0.9835
Learning Rate: 0.000088

Epoch 34/110
--------------------------------------------------


                                                           

Train Loss: 0.2621, Train Acc: 0.9839
Val Loss: 0.3331, Val Acc: 0.9359
Val F1: 0.9369, Val ROC-AUC: 0.9814
Learning Rate: 0.000087

Epoch 35/110
--------------------------------------------------


                                                           

Train Loss: 0.2608, Train Acc: 0.9855
Val Loss: 0.3186, Val Acc: 0.9515
Val F1: 0.9517, Val ROC-AUC: 0.9838
Learning Rate: 0.000086

Epoch 36/110
--------------------------------------------------


                                                           

Train Loss: 0.2595, Train Acc: 0.9863
Val Loss: 0.3257, Val Acc: 0.9513
Val F1: 0.9513, Val ROC-AUC: 0.9807
Learning Rate: 0.000084

Epoch 37/110
--------------------------------------------------


                                                           

Train Loss: 0.2599, Train Acc: 0.9856
Val Loss: 0.3267, Val Acc: 0.9516
Val F1: 0.9517, Val ROC-AUC: 0.9818
Learning Rate: 0.000083

Epoch 38/110
--------------------------------------------------


                                                           

Train Loss: 0.2574, Train Acc: 0.9867
Val Loss: 0.3237, Val Acc: 0.9507
Val F1: 0.9510, Val ROC-AUC: 0.9804
Learning Rate: 0.000082

Epoch 39/110
--------------------------------------------------


                                                           

Train Loss: 0.2571, Train Acc: 0.9871
Val Loss: 0.3272, Val Acc: 0.9509
Val F1: 0.9510, Val ROC-AUC: 0.9781
Learning Rate: 0.000081

Epoch 40/110
--------------------------------------------------


                                                           

Train Loss: 0.2567, Train Acc: 0.9871
Val Loss: 0.3236, Val Acc: 0.9471
Val F1: 0.9475, Val ROC-AUC: 0.9796
Learning Rate: 0.000080

Epoch 41/110
--------------------------------------------------


                                                           

Train Loss: 0.2544, Train Acc: 0.9889
Val Loss: 0.3210, Val Acc: 0.9537
Val F1: 0.9538, Val ROC-AUC: 0.9786
Learning Rate: 0.000078

Epoch 42/110
--------------------------------------------------


                                                           

Train Loss: 0.2535, Train Acc: 0.9895
Val Loss: 0.3228, Val Acc: 0.9526
Val F1: 0.9526, Val ROC-AUC: 0.9824
Learning Rate: 0.000077

Epoch 43/110
--------------------------------------------------


                                                           

Train Loss: 0.2535, Train Acc: 0.9894
Val Loss: 0.3227, Val Acc: 0.9513
Val F1: 0.9515, Val ROC-AUC: 0.9813
Learning Rate: 0.000076

Epoch 44/110
--------------------------------------------------


                                                           

Train Loss: 0.2517, Train Acc: 0.9902
Val Loss: 0.3241, Val Acc: 0.9528
Val F1: 0.9530, Val ROC-AUC: 0.9779
Learning Rate: 0.000074

Early stopping at epoch 44
Best F1: 0.9544 at epoch 29


                                                           


Final Test Results:
Accuracy: 0.9471
F1-score: 0.9470
ROC-AUC: 0.9775
Precision: 0.9470
Recall: 0.9471
Best epoch: 29

Fusion Weights (CNN, MLP, Transformer):
CNN branch: 0.3466
MLP branch: 0.3644
Transformer branch: 0.2890


