## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -q torch torchvision timm transformers
!pip install -q librosa soundfile albumentations
!pip install -q scikit-learn matplotlib seaborn tqdm
!pip install -q efficientnet-pytorch

In [None]:
# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

# Image processing
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Audio processing
import librosa
import soundfile as sf
import parselmouth
from parselmouth.praat import call

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import timm
from transformers import Wav2Vec2Model, Wav2Vec2Processor

# Metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_auc_score, roc_curve
)

# Set seed
def set_seed(seed=42):
    """Seed every RNG used in this notebook so runs are repeatable.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for deterministic kernel selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning picks kernels by timing, which can differ between runs;
    # it must be off for `deterministic = True` to give repeatable results.
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Clone Dataset and Check for Pre-trained Models

In [None]:
# Clone the dataset repository into an explicit directory, so the paths below always
# match the clone location regardless of the current working directory. Skipped if
# the repository is already present (e.g. when the cell is re-run).
import os
import subprocess

REPO_URL = 'https://github.com/Tvenkatathanuj/SDP.git'
REPO_DIR = '/content/SDP'

if os.path.isdir(REPO_DIR):
    print(f"Repository already present at {REPO_DIR}; skipping clone.")
else:
    # check=True: a failed clone stops the notebook here instead of failing later.
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)

# Paths
HANDWRITING_PATH = os.path.join(REPO_DIR, 'handwritten dataset', 'Dataset', 'Dataset')
SPEECH_PATH = os.path.join(REPO_DIR, 'speech dataset')

# Fail fast with a clear message if the expected directory layout is missing.
for _required_dir in (HANDWRITING_PATH, SPEECH_PATH):
    if not os.path.isdir(_required_dir):
        raise FileNotFoundError(f"Expected dataset directory not found: {_required_dir}")

print("Dataset cloned successfully!")

In [None]:
# Check that the pre-trained single-modality checkpoints are available.
# Nothing is downloaded here. The .pth files must be produced by the individual
# notebooks and placed in the working directory (uploaded to Colab or pulled from GitHub).
import os

REQUIRED_CHECKPOINTS = [
    'handwriting_parkinsons_model_final.pth',  # from handwriting_parkinsons_detection.ipynb
    'speech_parkinsons_model_final.pth',       # from speech_parkinsons_detection.ipynb
]
missing_checkpoints = [path for path in REQUIRED_CHECKPOINTS if not os.path.isfile(path)]

if missing_checkpoints:
    print("\n⚠️ IMPORTANT: Before running fusion model, you must:")
    print("1. Train handwriting model (handwriting_parkinsons_detection.ipynb)")
    print("2. Train speech model (speech_parkinsons_detection.ipynb)")
    print("3. Download the .pth files from Colab")
    print("4. Upload them to this Colab session or push to GitHub")
    print(f"\nMissing checkpoint files: {', '.join(missing_checkpoints)}")
else:
    print("✓ Ready to load models!")

## 3. Define Individual Model Architectures

In [None]:
# Copy handwriting model architecture
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    Args:
        in_channels: Number of channels C of the input feature map.
        reduction: Bottleneck factor of the channel MLP (hidden width C // reduction).

    forward(x): x is (B, C, H, W); the result has the same shape.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        bottleneck = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP (1x1 convolutions) applied to both pooled channel descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()
        # 7x7 conv over the stacked [mean, max] maps; padding=3 preserves H x W.
        self.spatial_conv = nn.Conv2d(2, 1, 7, padding=3, bias=False)

    def forward(self, x):
        # Channel gate: sum the shared-MLP responses to avg- and max-pooled inputs.
        channel_logits = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        refined = x * self.sigmoid(channel_logits)

        # Spatial gate computed from the channel-refined map.
        descriptor = torch.cat(
            [refined.mean(dim=1, keepdim=True), refined.amax(dim=1, keepdim=True)],
            dim=1,
        )
        return refined * self.sigmoid(self.spatial_conv(descriptor))

class SpatialPyramidPooling(nn.Module):
    """Spatial pyramid pooling using adaptive average pooling at several grid sizes.

    For each grid size p, the (B, C, H, W) input is average-pooled to p x p and
    flattened. All levels are concatenated on the last axis, which gives a
    (B, C, num_bins) tensor whose size does not depend on H and W.

    Args:
        pool_sizes: Iterable of grid sizes; defaults to (1, 2, 4), i.e. 21 bins.
    """

    def __init__(self, pool_sizes=(1, 2, 4)):
        super(SpatialPyramidPooling, self).__init__()
        # Immutable default (no shared mutable list); caller input is copied into a tuple.
        self.pool_sizes = tuple(pool_sizes)

    @property
    def num_bins(self):
        """Number of spatial bins per channel: sum(p * p) over the pyramid levels."""
        return sum(p * p for p in self.pool_sizes)

    def forward(self, x):
        batch_size, channels, _, _ = x.shape  # requires a 4-D (B, C, H, W) input
        levels = [
            F.adaptive_avg_pool2d(x, (p, p)).reshape(batch_size, channels, -1)
            for p in self.pool_sizes
        ]
        return torch.cat(levels, dim=2)

class HandwritingModel(nn.Module):
    """Handwriting classifier: EfficientNet-B4 features -> CBAM -> SPP -> MLP head.

    Args:
        num_classes: Number of output logits.

    forward(x, return_features=False):
        x: image batch, presumably (B, 3, H, W) normalised RGB. TODO: confirm.
        Returns class logits (B, num_classes). With ``return_features=True`` it
        returns the 512-d hidden activation of the trained classifier (B, 512).
    """

    def __init__(self, num_classes=2):
        super(HandwritingModel, self).__init__()
        self.backbone = timm.create_model('efficientnet_b4', pretrained=False, features_only=True)

        # Channel count of the deepest feature map, read from timm's feature metadata.
        # This avoids a dummy forward pass, which in train mode would also update the
        # backbone's BatchNorm running statistics with random data.
        feature_dim = self.backbone.feature_info.channels()[-1]

        self.cbam = CBAM(feature_dim)
        self.spp = SpatialPyramidPooling()
        # SPP yields sum(p*p) bins per channel (1 + 4 + 16 = 21 for the default pyramid).
        spp_dim = feature_dim * sum(p * p for p in self.spp.pool_sizes)

        # Retained only so checkpoints containing `feature_extractor.*` keys still load
        # with strict=True. forward() does not use it: nothing in this notebook trains
        # it (the classification path never calls it, and the fusion model freezes it).
        self.feature_extractor = nn.Sequential(
            nn.Linear(spp_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

        # Updated to match improved model
        self.classifier = nn.Sequential(
            nn.Linear(spp_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def _pooled_features(self, x):
        """Backbone (deepest map) -> CBAM -> SPP, flattened to (B, spp_dim)."""
        x = self.backbone(x)[-1]
        x = self.cbam(x)
        x = self.spp(x)
        return x.view(x.size(0), -1)

    def forward(self, x, return_features=False):
        pooled = self._pooled_features(x)
        if return_features:
            # The classifier's first Linear -> BatchNorm -> ReLU: a 512-d embedding whose
            # weights were trained by the classification objective (unlike feature_extractor).
            return self.classifier[:3](pooled)
        return self.classifier(pooled)

print("Handwriting model architecture loaded!")

In [None]:
# Copy speech model architecture
class MultiHeadAttention(nn.Module):
    """Post-norm self-attention block: LayerNorm(x + Dropout(MHA(x, x, x))).

    Built with ``batch_first=True``, so inputs are (B, T, hidden_dim) and the
    output has the same shape. The attention weights are discarded.
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        context = self.attention(x, x, x)[0]
        residual_sum = x + self.dropout(context)
        return self.norm(residual_sum)

class SpeechModel(nn.Module):
    """Speech classifier: wav2vec2 -> BiLSTM -> self-attention, fused with an acoustic MLP.

    Args:
        num_classes: Number of output logits.
        acoustic_feature_dim: Size of the hand-crafted acoustic feature vector.

    forward(input_values, acoustic_features, return_features=False):
        input_values: waveform batch for wav2vec2, presumably (B, samples). TODO: confirm.
        acoustic_features: (B, acoustic_feature_dim).
        Returns logits (B, num_classes). With ``return_features=True`` it returns
        the 512-d hidden activation of the trained classifier (B, 512).
    """

    def __init__(self, num_classes=2, acoustic_feature_dim=110):
        super(SpeechModel, self).__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")

        # LSTM input 768 matches the wav2vec2 hidden states it consumes;
        # the BiLSTM emits 2 * 256 = 512 features per frame.
        self.lstm = nn.LSTM(768, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
        self.attention = MultiHeadAttention(512, num_heads=8)

        # Updated deeper acoustic branch
        self.acoustic_branch = nn.Sequential(
            nn.Linear(acoustic_feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )

        fusion_dim = 512 + 128  # time-pooled sequence features + acoustic embedding

        # Retained only so checkpoints containing `feature_extractor.*` keys still load
        # with strict=True. forward() does not use it: nothing in this notebook trains
        # it (the classification path never calls it, and the fusion model freezes it).
        self.feature_extractor = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def _fused_representation(self, input_values, acoustic_features):
        """Concatenate the time-pooled sequence embedding with the acoustic embedding."""
        wav2vec_out = self.wav2vec(input_values).last_hidden_state
        lstm_out, _ = self.lstm(wav2vec_out)
        attended = self.attention(lstm_out)
        pooled = torch.mean(attended, dim=1)  # mean over the time axis
        acoustic_out = self.acoustic_branch(acoustic_features)
        return torch.cat([pooled, acoustic_out], dim=1)

    def forward(self, input_values, acoustic_features, return_features=False):
        fused = self._fused_representation(input_values, acoustic_features)
        if return_features:
            # The classifier's first Linear -> BatchNorm -> ReLU: a 512-d embedding whose
            # weights were trained by the classification objective (unlike feature_extractor).
            return self.classifier[:3](fused)
        return self.classifier(fused)

print("Speech model architecture loaded!")

## 4. Novel Fusion Architecture: Cross-Modal Attention Fusion

In [None]:
class CrossModalAttention(nn.Module):
    """Bidirectional cross-modal attention between two feature vectors.

    Each modality's vector is projected to ``hidden_dim`` and split into
    ``num_tokens`` tokens of size ``hidden_dim // num_tokens``. The tokens of one
    modality (queries) attend over the tokens of the other (keys/values), in both
    directions, with shared query/key/value projections.

    Why tokens: with a single (B, 1, D) query and a single key, the softmax runs
    over one element and is always 1. The old "attention" therefore reduced to
    ``value(other)`` and ignored the querying modality entirely. Splitting into
    tokens gives the softmax real choices. ``num_tokens=1`` reproduces the old
    behaviour.

    Args:
        dim1, dim2: Input feature sizes of the two modalities.
        hidden_dim: Common projection size; must be divisible by ``num_tokens``.
        num_tokens: Number of tokens each projected vector is split into.
        attn_dropout: Dropout probability applied to the attention weights.

    forward(feat1, feat2), with feat1 (B, dim1) and feat2 (B, dim2), returns
    ``(attended_f1, attended_f2)``, each (B, hidden_dim):
        attended_f1: values from feat1, weighted by queries from feat2.
        attended_f2: values from feat2, weighted by queries from feat1.
    """

    def __init__(self, dim1, dim2, hidden_dim=256, num_tokens=8, attn_dropout=0.1):
        super(CrossModalAttention, self).__init__()
        if num_tokens < 1 or hidden_dim % num_tokens != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be a positive multiple of num_tokens ({num_tokens})"
            )
        self.num_tokens = num_tokens
        self.token_dim = hidden_dim // num_tokens

        # Project features to common space
        self.proj1 = nn.Linear(dim1, hidden_dim)
        self.proj2 = nn.Linear(dim2, hidden_dim)

        # Cross-attention projections (shared by both directions)
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

        self.scale = self.token_dim ** -0.5  # scaled dot-product on per-token width
        self.dropout = nn.Dropout(attn_dropout)

    def _tokens(self, x):
        """(B, hidden_dim) -> (B, num_tokens, token_dim)."""
        return x.view(x.size(0), self.num_tokens, self.token_dim)

    def _attend(self, querier, source):
        """Tokens of `querier` attend over tokens of `source`; returns (B, hidden_dim) of source values."""
        q = self._tokens(self.query(querier))
        k = self._tokens(self.key(source))
        v = self._tokens(self.value(source))
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, T, T)
        attn = self.dropout(F.softmax(attn, dim=-1))
        return torch.matmul(attn, v).reshape(querier.size(0), -1)

    def forward(self, feat1, feat2):
        f1 = self.proj1(feat1)  # (B, hidden_dim)
        f2 = self.proj2(feat2)  # (B, hidden_dim)

        attended_f2 = self._attend(f1, f2)  # feat1 attends to feat2
        attended_f1 = self._attend(f2, f1)  # feat2 attends to feat1
        return attended_f1, attended_f2

class UncertaintyModule(nn.Module):
    """Class probabilities plus uncertainty via Monte Carlo dropout.

    Runs ``n_samples`` stochastic passes (dropout -> linear -> softmax) over the
    same input. It returns the mean class probabilities and, per example, the
    predictive variance averaged over classes.

    MC dropout needs dropout active at inference. ``nn.Dropout`` is an identity
    in eval mode, so every sample would be identical and the uncertainty exactly
    0 during validation and testing. Dropout is therefore sampled explicitly
    unless ``mc_at_eval=False``.

    Args:
        input_dim: Feature size of the input (B, input_dim).
        dropout_rate: Dropout probability for each MC sample.
        num_classes: Number of output classes.
        mc_at_eval: Keep sampling dropout in eval mode (standard MC dropout).
    """

    def __init__(self, input_dim, dropout_rate=0.3, num_classes=2, mc_at_eval=True):
        super(UncertaintyModule, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(input_dim, num_classes)  # Confidence scores
        self.mc_at_eval = mc_at_eval

    def forward(self, x, n_samples=10):
        """Return (mean_pred (B, num_classes) probabilities, uncertainty (B,))."""
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        sample_dropout = self.training or self.mc_at_eval
        samples = torch.stack([
            F.softmax(self.fc(F.dropout(x, p=self.dropout.p, training=sample_dropout)), dim=1)
            for _ in range(n_samples)
        ])  # (n_samples, B, num_classes)

        mean_pred = samples.mean(dim=0)
        if n_samples > 1:
            uncertainty = samples.var(dim=0).mean(dim=1)  # Average variance across classes
        else:
            # The unbiased variance of a single sample is NaN; report zero spread instead.
            uncertainty = samples.new_zeros(samples.shape[1])
        return mean_pred, uncertainty

class MultimodalFusionModel(nn.Module):
    """Cross-modal attention fusion of the handwriting and speech models.

    The two single-modality models are frozen (requires_grad=False) and used as
    512-d feature extractors. Their features are mixed by `CrossModalAttention`
    and scored per modality by `UncertaintyModule` (MC dropout). They are also
    classified jointly by `fusion_classifier`. A learned linear `ensemble` layer
    combines the three class-probability vectors into the final logits.

    Args:
        handwriting_model: called as model(image, return_features=True) -> (B, 512).
        speech_model: called as model(audio, acoustic_feat, return_features=True) -> (B, 512).
        num_classes: Number of output classes.
    """

    def __init__(self, handwriting_model, speech_model, num_classes=2):
        super(MultimodalFusionModel, self).__init__()

        self.handwriting_model = handwriting_model
        self.speech_model = speech_model

        # Freeze individual models initially
        for param in self.handwriting_model.parameters():
            param.requires_grad = False
        for param in self.speech_model.parameters():
            param.requires_grad = False

        # Cross-modal attention
        self.cross_attention = CrossModalAttention(dim1=512, dim2=512, hidden_dim=256)

        # Uncertainty modules
        self.uncertainty_hand = UncertaintyModule(256)
        self.uncertainty_speech = UncertaintyModule(256)

        # NOTE(review): fusion_weights is never used in forward(), so it receives no
        # gradient. It is kept only so existing fusion checkpoints load with strict=True.
        self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))

        # Final fusion classifier
        fusion_dim = 256 * 2  # Two modalities
        self.fusion_classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, num_classes)
        )

        # Ensemble predictor over [hand probs, speech probs, fusion probs]
        self.ensemble = nn.Linear(num_classes * 3, num_classes)  # 3 predictions

    def train(self, mode=True):
        """Set training mode, but keep fully-frozen sub-models in eval mode.

        requires_grad=False stops weight updates, but not BatchNorm running-stat
        updates or dropout. Without this override, fusion_model.train() would let the
        frozen models' BatchNorm statistics drift on small fusion batches and would
        add dropout noise to the features they extract.
        """
        super().train(mode)
        for submodel in (self.handwriting_model, self.speech_model):
            if not any(p.requires_grad for p in submodel.parameters()):
                submodel.eval()
        return self

    def forward(self, image=None, audio=None, acoustic_feat=None, mode='fusion'):
        """Run one of three modes.

        'handwriting_only' and 'speech_only' return the sub-model logits.
        'fusion' returns (final_logits, aux_dict). In aux_dict, 'hand_pred',
        'speech_pred' and 'weighted_pred' are probabilities and 'fusion_pred'
        is logits.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        if mode == 'handwriting_only':
            return self.handwriting_model(image)

        elif mode == 'speech_only':
            return self.speech_model(audio, acoustic_feat)

        elif mode == 'fusion':
            # Extract features from both modalities
            hand_features = self.handwriting_model(image, return_features=True)
            speech_features = self.speech_model(audio, acoustic_feat, return_features=True)

            # Cross-modal attention
            attended_hand, attended_speech = self.cross_attention(hand_features, speech_features)

            # Get predictions with uncertainty
            hand_pred, hand_uncertainty = self.uncertainty_hand(attended_hand)
            speech_pred, speech_uncertainty = self.uncertainty_speech(attended_speech)

            # Adaptive weighting based on uncertainty (lower uncertainty = higher weight)
            hand_confidence = 1.0 / (1.0 + hand_uncertainty)
            speech_confidence = 1.0 / (1.0 + speech_uncertainty)

            total_confidence = hand_confidence + speech_confidence
            hand_weight = hand_confidence / total_confidence
            speech_weight = speech_confidence / total_confidence

            # Weighted predictions (reported only; not used in final_pred)
            weighted_pred = hand_weight.unsqueeze(1) * hand_pred + speech_weight.unsqueeze(1) * speech_pred

            # Fusion features
            fused_features = torch.cat([attended_hand, attended_speech], dim=1)
            fusion_pred = self.fusion_classifier(fused_features)

            # Ensemble
            ensemble_input = torch.cat([hand_pred, speech_pred, F.softmax(fusion_pred, dim=1)], dim=1)
            final_pred = self.ensemble(ensemble_input)

            return final_pred, {
                'hand_pred': hand_pred,
                'speech_pred': speech_pred,
                'fusion_pred': fusion_pred,
                'weighted_pred': weighted_pred,
                'hand_uncertainty': hand_uncertainty,
                'speech_uncertainty': speech_uncertainty,
                'hand_weight': hand_weight,
                'speech_weight': speech_weight
            }

        raise ValueError(
            f"Unknown mode {mode!r}; expected 'fusion', 'handwriting_only' or 'speech_only'"
        )

print("✓ Multimodal Fusion Model architecture created!")

## 5. Multimodal Dataset

In [None]:
def _praat_jitter_shimmer(audio_path):
    """Local jitter and shimmer measured with Praat (via parselmouth); (0.0, 0.0) on failure."""
    try:
        sound = parselmouth.Sound(audio_path)
        # Periodic point process with a 75-600 Hz pitch floor/ceiling.
        point_process = call(sound, "To PointProcess (periodic, cc)", 75, 600)
        jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
        shimmer = call([sound, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)
        return jitter, shimmer
    except Exception:
        # Best effort: a Praat failure is treated as a missing measurement.
        return 0.0, 0.0


def extract_acoustic_features(audio_path, sr=16000, strict=False):
    """Extract a fixed-length 110-d hand-crafted acoustic feature vector.

    Layout:
        [0:40]    MFCC means (40 coefficients)
        [40:80]   MFCC standard deviations
        [80:92]   chroma means (12 bins)
        [92:104]  chroma standard deviations
        [104:108] mean spectral centroid, rolloff, bandwidth, zero-crossing rate
        [108:110] local jitter, local shimmer (0 if Praat fails)

    Args:
        audio_path: Path to an audio file readable by librosa and Praat.
        sr: Sample rate used by librosa.load.
        strict: If True, re-raise extraction errors instead of returning zeros.

    Returns:
        np.ndarray of shape (110,), dtype float32. NaN/inf entries are replaced
        with 0. Returns all zeros if extraction fails and ``strict`` is False.
    """
    try:
        y, sr = librosa.load(audio_path, sr=sr)

        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
        chroma = librosa.feature.chroma_stft(y=y, sr=sr, n_chroma=12)
        spectral = [
            np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
            np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
            np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr)),
            np.mean(librosa.feature.zero_crossing_rate(y)),
        ]
        jitter, shimmer = _praat_jitter_shimmer(audio_path)

        # Total: 80 + 24 + 4 + 2 = 110
        feature_vector = np.concatenate([
            np.mean(mfccs, axis=1), np.std(mfccs, axis=1),
            np.mean(chroma, axis=1), np.std(chroma, axis=1),
            spectral, [jitter, shimmer],
        ])
    except Exception as exc:
        if strict:
            raise
        # print rather than warnings.warn: the notebook silences all warnings at import time.
        print(f"[extract_acoustic_features] returning zeros for {audio_path}: {exc!r}")
        return np.zeros(110, dtype=np.float32)

    # Undefined Praat measurements or silent clips can produce NaN/inf, which would
    # poison BatchNorm statistics and the loss downstream; map them to 0.
    return np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

class MultimodalDataset(Dataset):
    """Pairs each speech recording with a randomly chosen handwriting image of the same class.

    The handwriting and speech sets have different sizes and are not linked per
    subject. The pairs are therefore synthetic: matched by label only, and fixed
    once at construction time.

    Args:
        handwriting_df: DataFrame with 'image_path' and 'label' columns.
        speech_df: DataFrame with 'audio_path' and 'label' columns.
        processor: Wav2Vec2-style processor, called as processor(y, sampling_rate=..., return_tensors="pt", padding=True).
        image_transform: albumentations-style transform, called as transform(image=...)['image'].
        max_audio_length: Number of samples to zero-pad or crop each waveform to.
        seed: Optional seed for the pairing. None uses NumPy's global RNG.

    Raises:
        ValueError: if a speech label has no handwriting image to pair with.
    """

    def __init__(self, handwriting_df, speech_df, processor, image_transform, max_audio_length=80000, seed=None):
        rng = np.random.default_rng(seed) if seed is not None else np.random

        # Index image paths by label once, instead of filtering the frame for every speech row.
        images_by_label = {
            label: group['image_path'].tolist()
            for label, group in handwriting_df.groupby('label')
        }

        self.samples = []
        for audio_path, label in zip(speech_df['audio_path'], speech_df['label']):
            candidates = images_by_label.get(label)
            if not candidates:
                raise ValueError(
                    f"No handwriting samples with label {label!r} to pair with {audio_path}"
                )
            self.samples.append({
                'image_path': candidates[int(rng.choice(len(candidates)))],
                'audio_path': audio_path,
                'label': int(label),
            })

        self.processor = processor
        self.image_transform = image_transform
        self.max_audio_length = max_audio_length

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _fit_length(waveform, length):
        """Zero-pad at the end, or truncate, a 1-D waveform to exactly `length` samples."""
        if len(waveform) < length:
            return np.pad(waveform, (0, length - len(waveform)))
        return waveform[:length]

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load and process image (cv2 reads BGR; transforms expect RGB)
        image = cv2.imread(sample['image_path'])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {sample['image_path']}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)['image']

        # Load and process audio
        waveform, _ = librosa.load(sample['audio_path'], sr=16000)
        waveform = self._fit_length(waveform, self.max_audio_length)
        audio_input = self.processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)

        # Extract acoustic features
        acoustic_features = extract_acoustic_features(sample['audio_path'])

        return {
            'image': image,
            'audio': audio_input.input_values.squeeze(0),
            'acoustic_features': torch.as_tensor(acoustic_features, dtype=torch.float32),
            'label': torch.tensor(sample['label'], dtype=torch.long),
        }

print("Multimodal dataset class created!")

## 6. Prepare Data

In [None]:
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')


def _list_files(directory, extensions):
    """Sorted names of files in `directory` ending in one of `extensions` (case-insensitive).

    Sorting makes the DataFrames, and therefore the seeded splits below, independent
    of the filesystem's arbitrary listing order.
    """
    return sorted(name for name in os.listdir(directory) if name.lower().endswith(extensions))


# Prepare handwriting data
hand_data = [
    {'image_path': os.path.join(HANDWRITING_PATH, class_dir, name), 'label': label}
    for class_dir, label in (('Healthy', 0), ('Parkinson', 1))
    for name in _list_files(os.path.join(HANDWRITING_PATH, class_dir), IMAGE_EXTENSIONS)
]
hand_df = pd.DataFrame(hand_data)

# Prepare speech data
speech_data = [
    {'audio_path': os.path.join(SPEECH_PATH, class_dir, name), 'label': label}
    for class_dir, label in (('HC_AH/HC_AH', 0), ('PD_AH/PD_AH', 1))
    for name in _list_files(os.path.join(SPEECH_PATH, class_dir), ('.wav',))
]
speech_df = pd.DataFrame(speech_data)

# Both classes must be present for the stratified splits and same-label pairing.
assert set(hand_df['label']) == {0, 1}, "handwriting data is missing a class"
assert set(speech_df['label']) == {0, 1}, "speech data is missing a class"

print(f"Handwriting samples: {len(hand_df)}")
print(f"Speech samples: {len(speech_df)}")

# Split speech data (limiting factor)
train_speech, temp_speech = train_test_split(speech_df, test_size=0.3, stratify=speech_df['label'], random_state=42)
val_speech, test_speech = train_test_split(temp_speech, test_size=0.5, stratify=temp_speech['label'], random_state=42)

print(f"\nTrain: {len(train_speech)}, Val: {len(val_speech)}, Test: {len(test_speech)}")

In [None]:
# Image transform (matching handwriting model)
image_transform = A.Compose([
    A.Resize(224, 224),  # Match handwriting training
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Wav2Vec processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# Split the handwriting images too. Passing the full hand_df to every split let the
# same image be paired into train, val and test at once, which leaked data.
train_hand, temp_hand = train_test_split(hand_df, test_size=0.3, stratify=hand_df['label'], random_state=42)
val_hand, test_hand = train_test_split(temp_hand, test_size=0.5, stratify=temp_hand['label'], random_state=42)
# NOTE(review): the pre-trained single-modality models were split in their own notebooks;
# confirm their training images/recordings do not overlap the test split used here.

# Create datasets
train_dataset = MultimodalDataset(train_hand, train_speech, processor, image_transform)
val_dataset = MultimodalDataset(val_hand, val_speech, processor, image_transform)
test_dataset = MultimodalDataset(test_hand, test_speech, processor, image_transform)

# DataLoaders (seeded generator -> reproducible shuffle order)
BATCH_SIZE = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("✓ Multimodal DataLoaders ready!")

## 7. Load Pre-trained Models and Create Fusion Model

In [None]:
def _load_checkpoint(path):
    """Load a checkpoint dict onto `device`, with an actionable error if the file is missing."""
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"{path} not found. Train the corresponding single-modality notebook and "
            "place the .pth file in the working directory first."
        )
    return torch.load(path, map_location=device)


# Load handwriting model (used frozen, so switch to eval: BatchNorm running stats, no dropout)
handwriting_model = HandwritingModel(num_classes=2).to(device)
hand_checkpoint = _load_checkpoint('handwriting_parkinsons_model_final.pth')
handwriting_model.load_state_dict(hand_checkpoint['model_state_dict'])
handwriting_model.eval()
print(f"✓ Handwriting model loaded (Acc: {hand_checkpoint['test_acc']:.4f})")

# Load speech model
speech_model = SpeechModel(num_classes=2).to(device)
speech_checkpoint = _load_checkpoint('speech_parkinsons_model_final.pth')
speech_model.load_state_dict(speech_checkpoint['model_state_dict'])
speech_model.eval()
print(f"✓ Speech model loaded (Acc: {speech_checkpoint['test_acc']:.4f})")

# Create fusion model
fusion_model = MultimodalFusionModel(handwriting_model, speech_model, num_classes=2).to(device)
print(f"\n✓ Fusion model created!")
print(f"Trainable parameters: {sum(p.numel() for p in fusion_model.parameters() if p.requires_grad):,}")

## 8. Training Functions for Fusion Model

In [None]:
def train_fusion_epoch(model, dataloader, criterion, optimizer, device, aux_weight=0.3, eps=1e-8):
    """Train the fusion model for one epoch.

    Loss = criterion(final logits) + aux_weight * (hand + speech + fusion auxiliary losses).

    Args:
        model: MultimodalFusionModel; called with mode='fusion'.
        dataloader: Yields dicts with 'image', 'audio', 'acoustic_features', 'label'.
        criterion: Classification loss on logits (e.g. nn.CrossEntropyLoss).
        optimizer: Optimizer over the trainable parameters.
        device: Target device for the batch tensors.
        aux_weight: Weight of the summed auxiliary losses.
        eps: Floor applied to probabilities before the log, to avoid log(0).

    Returns:
        (mean total loss per batch, accuracy of the final prediction).

    Raises:
        ValueError: if the dataloader is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")

    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    pbar = tqdm(dataloader, desc='Training Fusion')
    # Explicit step counter: tqdm only refreshes pbar.n periodically, so it is stale here.
    for step, batch in enumerate(pbar, start=1):
        images = batch['image'].to(device)
        audios = batch['audio'].to(device)
        acoustic_feats = batch['acoustic_features'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs, aux_outputs = model(images, audios, acoustic_feats, mode='fusion')

        # Main loss (outputs are logits)
        loss = criterion(outputs, labels)

        # Auxiliary losses. hand_pred/speech_pred are already softmax probabilities,
        # so pass their log: CE(log p) == NLL(log p) because log_softmax(log p) == log p
        # when p sums to 1. Passing p itself would apply softmax twice. fusion_pred is logits.
        aux_loss = aux_weight * (
            criterion(torch.log(aux_outputs['hand_pred'].clamp_min(eps)), labels)
            + criterion(torch.log(aux_outputs['speech_pred'].clamp_min(eps)), labels)
            + criterion(aux_outputs['fusion_pred'], labels)
        )

        total_loss = loss + aux_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        running_loss += total_loss.item()
        pbar.set_postfix({'loss': running_loss / step})

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = accuracy_score(all_labels, all_preds)
    return epoch_loss, epoch_acc

def validate_fusion_epoch(model, dataloader, criterion, device):
    """Evaluate the fusion model without gradient tracking.

    Args:
        model: MultimodalFusionModel; called with mode='fusion'.
        dataloader: Yields dicts with 'image', 'audio', 'acoustic_features', 'label'.
        criterion: Classification loss on the final logits.
        device: Target device for the batch tensors.

    Returns:
        (mean loss per batch, accuracy, labels, predictions, P(class 1) per sample).

    Raises:
        ValueError: if the dataloader is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty")

    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validating Fusion')
        # Explicit step counter: tqdm only refreshes pbar.n periodically, so it is stale here.
        for step, batch in enumerate(pbar, start=1):
            images = batch['image'].to(device)
            audios = batch['audio'].to(device)
            acoustic_feats = batch['acoustic_features'].to(device)
            labels = batch['label'].to(device)

            outputs, _ = model(images, audios, acoustic_feats, mode='fusion')
            loss = criterion(outputs, labels)

            probs = F.softmax(outputs, dim=1)
            preds = outputs.argmax(dim=1)

            running_loss += loss.item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # positive-class (Parkinson) probability

            pbar.set_postfix({'loss': running_loss / step})

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = accuracy_score(all_labels, all_preds)
    return epoch_loss, epoch_acc, all_labels, all_preds, all_probs

print("Training functions ready!")

## 9. Train Fusion Model

In [None]:
# Training setup with improved parameters: only the unfrozen fusion parameters are optimised.
criterion = nn.CrossEntropyLoss()
trainable_params = [param for param in fusion_model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5, weight_decay=5e-4)  # Lower LR, higher weight decay
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Run length and early-stopping state; model selection uses the lowest validation loss.
NUM_EPOCHS = 50
early_stopping_patience = 15
best_val_loss = float('inf')
best_val_acc = 0.0
patience_counter = 0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

banner = "=" * 60
print("\n" + banner)
print("TRAINING MULTIMODAL FUSION MODEL")
print(banner)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)

    train_loss, train_acc = train_fusion_epoch(fusion_model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate_fusion_epoch(fusion_model, val_loader, criterion, device)[:2]
    scheduler.step()

    for key, value in zip(('train_loss', 'train_acc', 'val_loss', 'val_acc'),
                          (train_loss, train_acc, val_loss, val_acc)):
        history[key].append(value)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # A NaN loss compares False, so it counts as "no improvement".
    if not val_loss < best_val_loss:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{early_stopping_patience}")
    else:
        best_val_loss, best_val_acc, patience_counter = val_loss, val_acc, 0
        checkpoint_payload = {
            'model_state_dict': fusion_model.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'epoch': epoch
        }
        torch.save(checkpoint_payload, 'best_fusion_model.pth')
        print(f"✓ Fusion model saved with val_acc: {val_acc:.4f}, val_loss: {val_loss:.4f}")

    if patience_counter >= early_stopping_patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

print(f"\nBest Validation Loss: {best_val_loss:.4f}")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")

## 10. Final Evaluation

In [None]:
# Load best model (map_location so a GPU-saved checkpoint also loads on CPU)
if not os.path.isfile('best_fusion_model.pth'):
    raise FileNotFoundError("best_fusion_model.pth not found: no epoch improved the validation loss")
checkpoint = torch.load('best_fusion_model.pth', map_location=device)
fusion_model.load_state_dict(checkpoint['model_state_dict'])

# Test evaluation
test_loss, test_acc, y_true, y_pred, y_probs = validate_fusion_epoch(fusion_model, test_loader, criterion, device)

print(f"\n{'='*70}")
print(f"MULTIMODAL FUSION MODEL - FINAL TEST RESULTS")
print(f"{'='*70}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Loss: {test_loss:.4f}")
print(f"\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=['Healthy', 'Parkinson'], digits=4))

# Metrics (positive class = Parkinson = 1)
precision = precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)
auc = roc_auc_score(y_true, y_probs)
# labels=[0, 1] forces a 2x2 matrix even if a class is absent from y_true/y_pred.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) else float('nan')

print(f"\n{'='*70}")
print(f"PERFORMANCE METRICS")
print(f"{'='*70}")
print(f"Precision: {precision:.4f}")
print(f"Recall (Sensitivity): {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"AUC-ROC: {auc:.4f}")

# Compare with individual models.
# NOTE(review): the individual accuracies come from each model's own notebook and
# test split, not from this test set, so this is not a like-for-like comparison.
best_individual_acc = max(hand_checkpoint['test_acc'], speech_checkpoint['test_acc'])
print(f"\n{'='*70}")
print(f"MODEL COMPARISON")
print(f"{'='*70}")
print(f"Handwriting Model Accuracy: {hand_checkpoint['test_acc']:.4f}")
print(f"Speech Model Accuracy: {speech_checkpoint['test_acc']:.4f}")
print(f"Fusion Model Accuracy: {test_acc:.4f}")
print(f"\nImprovement over best individual: {(test_acc - best_individual_acc)*100:+.2f} percentage points")

## 11. Comprehensive Visualization

In [None]:
# Create comprehensive visualization
fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Training history. The x-axis is 1-based to match the "Epoch k/N" lines printed during training.
epochs = np.arange(1, len(history['train_acc']) + 1)
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(epochs, history['train_acc'], label='Train Acc', marker='o', linewidth=2)
ax1.plot(epochs, history['val_acc'], label='Val Acc', marker='s', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy', fontsize=12)
ax1.set_title('Fusion Model Training History', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Confusion Matrix (labels fixed so it is always 2x2)
ax2 = fig.add_subplot(gs[1, 0])
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
sns.heatmap(cm, annot=True, fmt='d', cmap='RdYlGn', xticklabels=['Healthy', 'Parkinson'],
            yticklabels=['Healthy', 'Parkinson'], ax=ax2, cbar_kws={'label': 'Count'})
ax2.set_xlabel('Predicted', fontsize=11)
ax2.set_ylabel('Actual', fontsize=11)
ax2.set_title(f'Confusion Matrix\nAcc: {test_acc:.4f}', fontsize=12, fontweight='bold')

# ROC Curve
ax3 = fig.add_subplot(gs[1, 1])
fpr, tpr, _ = roc_curve(y_true, y_probs)
ax3.plot(fpr, tpr, label=f'Fusion (AUC={auc:.4f})', linewidth=3, color='darkgreen')
ax3.plot([0, 1], [0, 1], 'k--', label='Random', linewidth=2)
ax3.set_xlabel('False Positive Rate', fontsize=11)
ax3.set_ylabel('True Positive Rate', fontsize=11)
ax3.set_title('ROC Curve', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)

# Model Comparison
ax4 = fig.add_subplot(gs[1, 2])
models = ['Handwriting', 'Speech', 'Fusion']
accuracies = [hand_checkpoint['test_acc'], speech_checkpoint['test_acc'], test_acc]
colors = ['#3498db', '#e74c3c', '#2ecc71']
bars = ax4.bar(models, accuracies, color=colors, edgecolor='black', linewidth=2)
ax4.set_ylabel('Accuracy', fontsize=11)
ax4.set_title('Model Comparison', fontsize=12, fontweight='bold')
ax4.set_ylim([0, 1.1])  # headroom so value labels above bars near 1.0 stay inside the axes
for bar, value in zip(bars, accuracies):
    ax4.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
             f'{value:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
ax4.grid(True, axis='y', alpha=0.3)

# Metrics comparison
ax5 = fig.add_subplot(gs[2, :])
metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']
metrics_values = [test_acc, precision, recall, f1, auc]
x_pos = np.arange(len(metrics_names))
bars = ax5.barh(x_pos, metrics_values, color='#16a085', edgecolor='black', linewidth=2)
ax5.set_yticks(x_pos)
ax5.set_yticklabels(metrics_names, fontsize=11)
ax5.set_xlabel('Score', fontsize=11)
ax5.set_title('Comprehensive Performance Metrics', fontsize=12, fontweight='bold')
ax5.set_xlim([0, 1.15])  # headroom for the value labels to the right of each bar
for bar, value in zip(bars, metrics_values):
    ax5.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2.,
             f'{value:.4f}', ha='left', va='center', fontsize=10, fontweight='bold')
ax5.grid(True, axis='x', alpha=0.3)

plt.suptitle('Multimodal Fusion Model - Complete Analysis', fontsize=16, fontweight='bold', y=0.995)
plt.savefig('fusion_model_complete_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Comprehensive visualization saved!")

## 12. Save Final Model

In [None]:
# Bundle the trained fusion weights, the frozen component weights and the headline
# test metrics into one checkpoint file.
final_test_metrics = {
    'accuracy': test_acc,
    'precision': precision,
    'recall': recall,
    'f1_score': f1,
    'auc_roc': auc
}
component_accuracies = {
    'handwriting_acc': hand_checkpoint['test_acc'],
    'speech_acc': speech_checkpoint['test_acc']
}
final_checkpoint = {
    'fusion_model_state_dict': fusion_model.state_dict(),
    'handwriting_model_state_dict': handwriting_model.state_dict(),
    'speech_model_state_dict': speech_model.state_dict(),
    'test_metrics': final_test_metrics,
    'individual_models': component_accuracies
}
torch.save(final_checkpoint, 'multimodal_fusion_parkinsons_final.pth')

summary_banner = "=" * 70
print("\n" + summary_banner)
print("✓ COMPLETE MULTIMODAL SYSTEM SAVED")
print(summary_banner)
print("\nFiles saved:")
print("1. multimodal_fusion_parkinsons_final.pth (Complete fusion model)")
print("2. fusion_model_complete_analysis.png (Visualization)")
print("\nFinal Performance Summary:")
for metric_name, metric_value in (("Fusion Accuracy", test_acc), ("AUC-ROC", auc), ("F1-Score", f1),
                                  ("Precision", precision), ("Recall", recall)):
    print(f"  • {metric_name}: {metric_value:.4f}")
print("\n" + summary_banner)