# Structural Heart Disease Multimodal Model - Training on Google Colab

This notebook trains a multimodal deep learning model for structural heart disease prediction using:
- **A1**: ECG Transformer Encoder (HuBERT-ECG pretrained)
- **A2**: Tabular Encoder (FTTransformer)
- **A3**: Gated Multimodal Fusion + Multi-label Prediction
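
To make the A3 step concrete, here is a minimal sketch of gated fusion in plain Python. It assumes both embeddings have already been projected to a common dimension and that the gate is a softmax over two modality scores, so the ECG and tabular weights sum to 1. The function name `gated_fusion` and the score inputs are hypothetical; the actual fusion module in `src/models.py` is not shown in this notebook and may differ.

```python
import math

def gated_fusion(ecg_emb, tab_emb, ecg_score, tab_score):
    """Fuse two equal-length embeddings with softmax gates (illustrative only)."""
    if len(ecg_emb) != len(tab_emb):
        raise ValueError("embeddings must share a dimension before fusion")
    # Softmax over the two modality scores; subtract the max for numerical stability
    m = max(ecg_score, tab_score)
    e_ecg, e_tab = math.exp(ecg_score - m), math.exp(tab_score - m)
    g_ecg = e_ecg / (e_ecg + e_tab)
    g_tab = e_tab / (e_ecg + e_tab)
    # Weighted sum of the two embeddings, element by element
    fused = [g_ecg * a + g_tab * b for a, b in zip(ecg_emb, tab_emb)]
    return fused, (g_ecg, g_tab)

# Equal scores give equal gates, so the result is the average of both embeddings
fused, gates = gated_fusion([1.0, 0.0], [0.0, 1.0], ecg_score=0.0, tab_score=0.0)
print(gates)
print(fused)
```

Under this assumption, a gate close to 1 for one modality means the fused representation is dominated by that modality for the given sample.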

---

## 📋 Table of Contents
1. [Setup & Installation](#setup)
2. [Data Upload/Mounting](#data)
3. [Model Configuration](#config)
4. [Training](#training)
5. [Evaluation & Visualization](#evaluation)
6. [Download Results](#download)
7. [Inference Example (Optional)](#inference)

## 1. Setup & Installation

First, confirm that the runtime has a GPU attached. Training also runs on CPU, but it is much slower.

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected. Training will be slow on CPU.")
    print("Go to Runtime > Change runtime type > Hardware accelerator > GPU")

In [None]:
# Install required packages
# Quote the version specifiers: unquoted, the shell treats ">" as output redirection
print("Installing dependencies...")
!pip install -q "transformers>=4.30.0"
!pip install -q "tab-transformer-pytorch>=0.2.0"
!pip install -q tensorboard

print("✓ Installation complete!")

### Clone Repository (Option 1: From GitHub)

In [None]:
# If your code is on GitHub, clone it here
# Uncomment and modify the following lines:

# !git clone https://github.com/YOUR_USERNAME/structural_heart_disease.git
# %cd structural_heart_disease

# For now, we'll create the necessary files directly
print("Skipping git clone - will upload files manually or mount from Drive")

### Upload Code Files (Option 2: Manual Upload)

In [None]:
# Create project structure
import os

os.makedirs('src', exist_ok=True)
os.makedirs('echonext_dataset', exist_ok=True)
os.makedirs('outputs', exist_ok=True)

print("✓ Project structure created")
print("\n📁 Please upload the following files:")
print("  - src/models.py")
print("  - src/dataset.py")
print("  - src/utils.py")
print("  - src/__init__.py")
print("\nUse the file upload button on the left sidebar or run the cell below.")

In [None]:
# Upload source files
import shutil
from google.colab import files

# Upload one file per prompt; each upload is moved to its expected path
for target in ['src/models.py', 'src/dataset.py', 'src/utils.py']:
    print(f"\nUpload {target}:")
    uploaded = files.upload()
    for filename in uploaded:
        # shutil.move handles filenames with spaces, unlike an unquoted !mv
        shutil.move(filename, target)

# Create __init__.py so that src can be imported as a package
!touch src/__init__.py

print("\n✓ Source files uploaded!")

## 2. Data Upload/Mounting

Choose one of the following options to access your data.

### Option A: Mount Google Drive (Recommended for large datasets)

In [None]:
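import os
from google.colab import drive
drive.mount('/content/drive')

# Update this path to where your data is stored in Google Drive
DRIVE_DATA_PATH = '/content/drive/MyDrive/echonext_dataset'

# The setup cell created an empty echonext_dataset/ directory. Remove it first;
# otherwise the link would be created inside that directory instead of replacing it.
if os.path.isdir('echonext_dataset') and not os.path.islink('echonext_dataset'):
    os.rmdir('echonext_dataset')  # raises if the directory is not empty, protecting uploaded data

# Create symbolic link (skipped if this cell has already been run)
if not os.path.islink('echonext_dataset'):
    os.symlink(DRIVE_DATA_PATH, 'echonext_dataset')

print(f"✓ Data mounted from: {DRIVE_DATA_PATH}")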

### Option B: Upload Data Files Manually (For smaller datasets)

In [None]:
# Upload data files one by one
# WARNING: This can be slow for large .npy files!

from google.colab import files
import shutil

print("Upload EchoNext_metadata_100k.csv:")
uploaded = files.upload()
for filename in uploaded.keys():
    shutil.move(filename, 'echonext_dataset/EchoNext_metadata_100k.csv')

print("\nUpload training waveforms (.npy):")
uploaded = files.upload()
for filename in uploaded.keys():
    shutil.move(filename, 'echonext_dataset/EchoNext_train_waveforms.npy')

# Continue for other files...
print("\n⚠️ Note: Upload remaining .npy files using the same pattern")

### Verify Data Files

In [None]:
# Check if all required files exist
import os

required_files = [
    'echonext_dataset/EchoNext_metadata_100k.csv',
    'echonext_dataset/EchoNext_train_waveforms.npy',
    'echonext_dataset/EchoNext_train_tabular_features.npy',
    'echonext_dataset/EchoNext_val_waveforms.npy',
    'echonext_dataset/EchoNext_val_tabular_features.npy',
    'echonext_dataset/EchoNext_test_waveforms.npy',
    'echonext_dataset/EchoNext_test_tabular_features.npy',
]

print("Checking data files:")
all_present = True
for filepath in required_files:
    exists = os.path.exists(filepath)
    status = "✓" if exists else "✗"
    size = f"{os.path.getsize(filepath) / 1e6:.1f} MB" if exists else "Missing"
    print(f"{status} {filepath}: {size}")
    if not exists:
        all_present = False

if all_present:
    print("\n✓ All data files present!")
else:
    print("\n⚠️ Some files are missing. Please upload them before continuing.")

## 3. Model Configuration

Configure training hyperparameters and model architecture.

In [None]:
# Training Configuration
config = {
    # Data
    'data_dir': './echonext_dataset',
    'batch_size': 16,  # Reduced for Colab GPU memory
    'num_workers': 2,  # Colab has limited CPU cores
    
    # Model - ECG Encoder
    'ecg_model_size': 'large',  # Options: 'small', 'base', 'large'
    'ecg_embed_dim': 256,
    'ecg_freeze': False,  # Set True to freeze pretrained weights
    'ecg_use_pretrained': True,
    
    # Model - Tabular Encoder
    'tabular_dim': 32,
    'tabular_depth': 2,
    'tabular_heads': 4,
    'tabular_output_dim': 128,
    
    # Model - Fusion
    'fusion_dim': 256,
    
    # Training
    'num_epochs': 50,  # Reduced for Colab time limits
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'warmup_epochs': 3,
    'dropout': 0.1,
    
    # Loss
    'loss_type': 'asymmetric',  # Options: 'bce', 'focal', 'asymmetric'
    'use_pos_weights': True,
    
    # Regularization
    'early_stopping_patience': 10,
    
    # Computational
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'mixed_precision': True,  # Use AMP for faster training
    
    # Output
    'output_dir': './outputs',
    'save_freq': 5,
    
    # Evaluation
    'eval_uncertainty': True,
    'uncertainty_samples': 20,
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")
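
The `use_pos_weights` flag in the configuration enables per-label weighting for class imbalance. The helper `get_pos_weights` lives in `src/utils.py` and is not shown here. A common choice, assumed in this sketch, is the ratio of negative to positive samples for each label, which is also the usual convention for the `pos_weight` argument of `BCEWithLogitsLoss`. The function name `pos_weights_from_labels` is hypothetical.

```python
def pos_weights_from_labels(labels):
    """Return one weight per label column: (#negatives / #positives)."""
    n_samples = len(labels)
    n_labels = len(labels[0])
    weights = []
    for j in range(n_labels):
        pos = sum(row[j] for row in labels)
        neg = n_samples - pos
        # Guard against labels with no positives to avoid dividing by zero
        weights.append(neg / pos if pos > 0 else 1.0)
    return weights

# Toy multi-label matrix: 4 samples, 2 labels
labels = [[1, 0], [0, 0], [0, 1], [0, 1]]
print(pos_weights_from_labels(labels))  # [3.0, 1.0]
```

A rare label (one positive in four samples) gets a weight of 3, while a balanced label keeps a weight of 1.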

## 4. Training

Now let's train the model!

In [None]:
# Import necessary modules
import sys
sys.path.append('.')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from tqdm.notebook import tqdm
import numpy as np
import json

from src.models import SHDMultimodalModel
from src.dataset import get_dataloaders, EchoNextDataset
from src.utils import (
    compute_metrics, compute_calibration_metrics,
    FocalLoss, AsymmetricLoss, EarlyStopping,
    save_checkpoint, load_checkpoint, AverageMeter,
    get_pos_weights
)

print("✓ Imports successful!")

In [None]:
# Load data
print("Loading datasets...")
train_loader, val_loader, test_loader = get_dataloaders(
    data_dir=config['data_dir'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers']
)

print(f"\n✓ Data loaded:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

In [None]:
# Create model
print("Creating model...")

device = torch.device(config['device'])

model_config = {
    'ecg_config': {
        'model_size': config['ecg_model_size'],
        'embed_dim': config['ecg_embed_dim'],
        'freeze_encoder': config['ecg_freeze'],
        'use_pretrained': config['ecg_use_pretrained'],
        'pooling': 'mean',
    },
    'tabular_config': {
        'dim': config['tabular_dim'],
        'depth': config['tabular_depth'],
        'heads': config['tabular_heads'],
        'output_dim': config['tabular_output_dim'],
        'attn_dropout': config['dropout'],
        'ff_dropout': config['dropout'],
    },
    'fusion_config': {
        'ecg_dim': config['ecg_embed_dim'],
        'tabular_dim': config['tabular_output_dim'],
        'output_dim': config['fusion_dim'],
    },
    'prediction_config': {
        'input_dim': config['fusion_dim'],
        'num_labels': len(EchoNextDataset.LABEL_COLUMNS),
        'dropout': config['dropout'],
    }
}

model = SHDMultimodalModel(**model_config)
model = model.to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n✓ Model created with {num_params:,} trainable parameters")

In [None]:
# Setup training components
print("Setting up training...")

# Loss function
pos_weights = None
if config['use_pos_weights']:
    train_labels = train_loader.dataset.labels
    pos_weights = get_pos_weights(train_labels).to(device)
    print("Using positive weights for class imbalance")

# Note: pos_weights is only passed to the BCE loss; the other losses ignore it
if config['loss_type'] == 'bce':
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
elif config['loss_type'] == 'focal':
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
elif config['loss_type'] == 'asymmetric':
    criterion = AsymmetricLoss(gamma_neg=4.0, gamma_pos=1.0)
else:
    raise ValueError(f"Unknown loss_type: {config['loss_type']!r}")

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay']
)

# Learning rate scheduler with warmup
num_training_steps = len(train_loader) * config['num_epochs']
num_warmup_steps = len(train_loader) * config['warmup_epochs']

def lr_lambda(current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

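# Mixed precision scaler (only enabled on a CUDA device; `device` is set in the model cell)
use_amp = config['mixed_precision'] and device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if use_amp else None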

# Early stopping
early_stopping = EarlyStopping(
    patience=config['early_stopping_patience'],
    mode='max'  # Maximize validation AUROC
)

# Create output directory
output_dir = Path(config['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

# Save configuration
with open(output_dir / 'config.json', 'w') as f:
    json.dump(config, f, indent=2)

print("✓ Training setup complete!")
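
The `lr_lambda` function in the cell above returns a multiplier on the base learning rate: it rises linearly from 0 to 1 during warmup, then decays linearly to 0 at the final step. The sketch below restates the same formula as a standalone function so the shape can be checked without a model. The step counts (100 steps per epoch) are toy assumptions, and `lr_multiplier` is a name introduced here for illustration.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Linear warmup from 0 to 1, then linear decay back to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

# Toy numbers: 100 steps per epoch, 3 warmup epochs, 50 epochs in total
warmup, total = 300, 5000
for step in (0, 150, 300, 2650, 5000):
    print(step, lr_multiplier(step, warmup, total))
```

The multiplier is 0.5 halfway through warmup, peaks at 1.0 when warmup ends, and is back at 0.5 halfway through the decay phase.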

In [None]:
# Training function
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, scaler=None):
    """Train for one epoch."""
    model.train()
    losses = AverageMeter()
    
    pbar = tqdm(dataloader, desc='Training', leave=False)
    for batch in pbar:
        waveform = batch['waveform'].to(device)
        tabular = batch['tabular'].to(device)
        tabular_mask = batch['tabular_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        if scaler is not None:
            with torch.cuda.amp.autocast():
                output = model(waveform, tabular, tabular_mask)
                loss = criterion(output['logits'], labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(waveform, tabular, tabular_mask)
            loss = criterion(output['logits'], labels)
            loss.backward()
            optimizer.step()
        
        if scheduler is not None:
            scheduler.step()
        
        losses.update(loss.item(), waveform.size(0))
        pbar.set_postfix({'loss': losses.avg})
    
    return losses.avg


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate model."""
    model.eval()
    losses = AverageMeter()
    all_labels = []
    all_probs = []
    all_calibrated_probs = []
    all_fusion_gates = []
    
    pbar = tqdm(dataloader, desc='Evaluating', leave=False)
    for batch in pbar:
        waveform = batch['waveform'].to(device)
        tabular = batch['tabular'].to(device)
        tabular_mask = batch['tabular_mask'].to(device)
        labels = batch['labels'].to(device)
        
        output = model(waveform, tabular, tabular_mask)
        loss = criterion(output['logits'], labels)
        
        losses.update(loss.item(), waveform.size(0))
        all_labels.append(labels.cpu().numpy())
        all_probs.append(output['probs'].cpu().numpy())
        all_calibrated_probs.append(output['calibrated_probs'].cpu().numpy())
        all_fusion_gates.append(output['fusion_gates'].cpu().numpy())
    
    all_labels = np.concatenate(all_labels, axis=0)
    all_probs = np.concatenate(all_probs, axis=0)
    all_calibrated_probs = np.concatenate(all_calibrated_probs, axis=0)
    all_fusion_gates = np.concatenate(all_fusion_gates, axis=0)
    
    metrics = compute_metrics(all_labels, all_calibrated_probs, EchoNextDataset.LABEL_COLUMNS)
    calibration_metrics = compute_calibration_metrics(all_labels, all_calibrated_probs)
    
    metrics.update(calibration_metrics)
    metrics['loss'] = losses.avg
    # Convert NumPy scalars to Python floats so the metrics can be saved as JSON
    metrics['avg_ecg_gate'] = float(all_fusion_gates[:, 0].mean())
    metrics['avg_tabular_gate'] = float(all_fusion_gates[:, 1].mean())
    
    return metrics

print("✓ Training functions defined")
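
Both functions above track the loss with `AverageMeter`, calling `update(loss.item(), waveform.size(0))` and reading `.avg`. The class itself is imported from `src/utils.py` and is not shown. The sketch below is one plausible implementation of that interface, under the assumption that it computes a sample-weighted mean; the name `RunningAverage` is hypothetical.

```python
class RunningAverage:
    """Track a running, sample-weighted average of a scalar."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is a per-batch mean; weight it by the batch size n
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

meter = RunningAverage()
meter.update(1.0, 2)  # batch of 2 with mean loss 1.0
meter.update(4.0, 1)  # smaller final batch with mean loss 4.0
print(meter.avg)      # 2.0
```

Weighting by batch size matters when the last batch is smaller than the rest: a plain mean of the two batch losses would give 2.5 instead of the per-sample average of 2.0.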

In [None]:
# Main training loop
print("Starting training...\n")

# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard
writer = SummaryWriter(log_dir=output_dir / 'logs')

best_val_auroc = 0.0
training_history = []

for epoch in range(config['num_epochs']):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{config['num_epochs']}")
    print(f"{'='*60}")
    
    # Train
    train_loss = train_epoch(
        model, train_loader, criterion, optimizer, scheduler, device, scaler
    )
    
    # Evaluate
    val_metrics = evaluate(model, val_loader, criterion, device)
    
    # Log metrics
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
    writer.add_scalar('AUROC/val_macro', val_metrics.get('macro_auroc', 0), epoch)
    writer.add_scalar('AUPRC/val_macro', val_metrics.get('macro_auprc', 0), epoch)
    writer.add_scalar('Calibration/val_ece', val_metrics.get('mean_ece', 0), epoch)
    writer.add_scalar('FusionGates/ecg', val_metrics['avg_ecg_gate'], epoch)
    writer.add_scalar('FusionGates/tabular', val_metrics['avg_tabular_gate'], epoch)
    
    # Print metrics
    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f}")
    print(f"Val Macro AUROC: {val_metrics.get('macro_auroc', 0):.4f}")
    print(f"Val Macro AUPRC: {val_metrics.get('macro_auprc', 0):.4f}")
    print(f"Val ECE: {val_metrics.get('mean_ece', 0):.4f}")
    print(f"Fusion Gates - ECG: {val_metrics['avg_ecg_gate']:.3f}, Tabular: {val_metrics['avg_tabular_gate']:.3f}")
    
    # Save history
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_loss': val_metrics['loss'],
        'val_auroc': val_metrics.get('macro_auroc', 0),
        'val_auprc': val_metrics.get('macro_auprc', 0),
    })
    
    # Save best model
    val_auroc = val_metrics.get('macro_auroc', 0)
    if val_auroc > best_val_auroc:
        best_val_auroc = val_auroc
        save_checkpoint(
            model, optimizer, epoch, val_metrics,
            output_dir / 'best_model.pt'
        )
        print(f"✓ New best model saved (AUROC: {val_auroc:.4f})")
    
    # Save checkpoint periodically
    if (epoch + 1) % config['save_freq'] == 0:
        save_checkpoint(
            model, optimizer, epoch, val_metrics,
            output_dir / f'checkpoint_epoch_{epoch+1}.pt'
        )
    
    # Early stopping
    if early_stopping(val_auroc):
        print(f"\n⚠️ Early stopping triggered at epoch {epoch+1}")
        break

writer.close()
print("\n" + "="*60)
print("✓ Training complete!")
print(f"Best validation AUROC: {best_val_auroc:.4f}")
print("="*60)
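
The loop above stops when `early_stopping(val_auroc)` returns `True`. `EarlyStopping` comes from `src/utils.py` and is not shown; the sketch below illustrates the usual patience-based logic for `mode='max'`, assuming the object is callable and returns a stop flag. The class name `SimpleEarlyStopping` is hypothetical.

```python
class SimpleEarlyStopping:
    """Signal a stop after `patience` consecutive checks without improvement."""

    def __init__(self, patience=10, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_checks = 0

    def __call__(self, score):
        improved = (
            self.best is None
            or (self.mode == 'max' and score > self.best)
            or (self.mode == 'min' and score < self.best)
        )
        if improved:
            self.best = score
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        # True means "stop training now"
        return self.bad_checks >= self.patience

stopper = SimpleEarlyStopping(patience=2, mode='max')
for auroc in (0.70, 0.75, 0.74, 0.75):
    print(auroc, stopper(auroc))
```

In this toy run the fourth value only ties the best score, so it counts as a second check without improvement and the stop flag becomes `True`.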

### View TensorBoard

In [None]:
# Launch TensorBoard
%tensorboard --logdir outputs/logs

## 5. Evaluation & Visualization

Evaluate the best model on the test set.

In [None]:
# Load best model and evaluate on test set
print("Evaluating on test set...\n")

load_checkpoint(model, None, output_dir / 'best_model.pt', device)
test_metrics = evaluate(model, test_loader, criterion, device)

# Print test results
print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
print(f"Test Loss: {test_metrics['loss']:.4f}")
print(f"Test Macro AUROC: {test_metrics.get('macro_auroc', 0):.4f}")
print(f"Test Macro AUPRC: {test_metrics.get('macro_auprc', 0):.4f}")
print(f"Test ECE: {test_metrics.get('mean_ece', 0):.4f}")
print(f"Test MCE: {test_metrics.get('mean_mce', 0):.4f}")

print("\nPer-label AUROC:")
for label in EchoNextDataset.LABEL_COLUMNS:
    auroc = test_metrics.get(f'{label}_auroc', None)
    if auroc is not None:
        print(f"  {label}: {auroc:.4f}")

# Save test results (default=float converts NumPy scalars, which json cannot serialize directly)
with open(output_dir / 'test_results.json', 'w') as f:
    json.dump(test_metrics, f, indent=2, default=float)

print(f"\n✓ Results saved to {output_dir}")
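
The ECE and MCE values reported above come from `compute_calibration_metrics` in `src/utils.py`, which is not shown. As a rough guide to what these numbers mean, the sketch below computes both for a single binary label using equal-width probability bins. The binning scheme, the bin count, and the function name `calibration_errors` are assumptions and may not match the project's implementation.

```python
def calibration_errors(y_true, y_prob, n_bins=10):
    """Return (ECE, MCE) for one binary label using equal-width probability bins."""
    bins = [[] for _ in range(n_bins)]
    for y, p in zip(y_true, y_prob):
        # A probability of exactly 1.0 falls into the last bin
        idx = min(int(p * n_bins), n_bins - 1)
        bins[idx].append((y, p))
    n = len(y_true)
    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue
        frac_pos = sum(y for y, _ in members) / len(members)
        mean_prob = sum(p for _, p in members) / len(members)
        gap = abs(frac_pos - mean_prob)
        ece += gap * len(members) / n  # gap weighted by the bin's share of samples
        mce = max(mce, gap)            # worst gap over all non-empty bins
    return ece, mce

# Toy data: predictions of 0.25 for negatives and 0.75 for positives
y_true = [0, 0, 1, 1]
y_prob = [0.25, 0.25, 0.75, 0.75]
print(calibration_errors(y_true, y_prob, n_bins=4))
```

Here every bin is off by 0.25 (for example, samples predicted at 0.75 are positive 100% of the time), so both ECE and MCE equal 0.25. Lower values indicate better-calibrated probabilities.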

In [None]:
# Plot training history
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs = [h['epoch'] for h in training_history]
train_losses = [h['train_loss'] for h in training_history]
val_losses = [h['val_loss'] for h in training_history]
val_aurocs = [h['val_auroc'] for h in training_history]

# Loss plot
axes[0].plot(epochs, train_losses, label='Train Loss', marker='o')
axes[0].plot(epochs, val_losses, label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# AUROC plot
axes[1].plot(epochs, val_aurocs, label='Val AUROC', marker='o', color='green')
axes[1].axhline(y=best_val_auroc, color='r', linestyle='--', label=f'Best: {best_val_auroc:.4f}')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('AUROC')
axes[1].set_title('Validation AUROC')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training history plotted")

In [None]:
# Plot per-label performance
import matplotlib.pyplot as plt

labels = EchoNextDataset.LABEL_COLUMNS
aurocs = [test_metrics.get(f'{label}_auroc', 0) for label in labels]
auprcs = [test_metrics.get(f'{label}_auprc', 0) for label in labels]

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(labels))
width = 0.35

ax.bar(x - width/2, aurocs, width, label='AUROC', alpha=0.8)
ax.bar(x + width/2, auprcs, width, label='AUPRC', alpha=0.8)

ax.set_xlabel('Label')
ax.set_ylabel('Score')
ax.set_title('Per-Label Performance on Test Set')
ax.set_xticks(x)
ax.set_xticklabels([l.replace('_flag', '').replace('_', ' ')[:20] for l in labels], rotation=45, ha='right')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0, 1])

plt.tight_layout()
plt.savefig(output_dir / 'per_label_performance.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Per-label performance plotted")

## 6. Download Results

Download trained model and results to your local machine.

In [None]:
# Create a zip file with all results
import shutil

print("Creating results archive...")

# Create archive
shutil.make_archive('shd_training_results', 'zip', output_dir)

print("✓ Archive created: shd_training_results.zip")
print("\nContents:")
!unzip -l shd_training_results.zip | head -20

In [None]:
# Download the results
from google.colab import files

print("Downloading results...")
files.download('shd_training_results.zip')

print("\n✓ Download started!")
print("\nThe archive contains:")
print("  - best_model.pt (trained model weights)")
print("  - config.json (training configuration)")
print("  - test_results.json (test metrics)")
print("  - training_history.png (loss/AUROC plots)")
print("  - per_label_performance.png (per-label metrics)")
print("  - logs/ (TensorBoard logs)")

## 7. Inference Example (Optional)

Run inference on a single sample.

In [None]:
# Get a single sample from test set
model.eval()

sample = test_loader.dataset[0]
waveform = sample['waveform'].unsqueeze(0).to(device)
tabular = sample['tabular'].unsqueeze(0).to(device)
tabular_mask = sample['tabular_mask'].unsqueeze(0).to(device)
true_labels = sample['labels'].numpy()

# Run inference
with torch.no_grad():
    output = model(waveform, tabular, tabular_mask, return_embeddings=True)

# Get predictions
probs = output['calibrated_probs'].cpu().numpy()[0]
fusion_gates = output['fusion_gates'].cpu().numpy()[0]

print("Inference Results:")
print("="*60)
print(f"Fusion Gates - ECG: {fusion_gates[0]:.3f}, Tabular: {fusion_gates[1]:.3f}")
print("\nPredictions:")
print(f"{'Label':<50} {'True':<6} {'Pred':<6}")
print("-"*60)
for i, label in enumerate(EchoNextDataset.LABEL_COLUMNS):
    print(f"{label:<50} {int(true_labels[i]):<6} {probs[i]:.3f}")

print("\n✓ Inference complete!")

## 🎉 Training Complete!

### Next Steps:
1. **Review TensorBoard** for detailed training metrics
2. **Download results** using the cells above
3. **Experiment** with different hyperparameters
4. **Fine-tune** on your specific use case

### Tips for Better Performance:
- Increase `num_epochs` for longer training (watch for overfitting)
- Try different `ecg_model_size` ('small', 'base', 'large')
- Adjust `batch_size` based on GPU memory
- Experiment with different loss functions
- Use `ecg_freeze=True` for faster training with frozen pretrained weights
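
When comparing loss functions, it helps to see what the `gamma_neg` and `gamma_pos` arguments of the default `'asymmetric'` loss do. The sketch below shows the asymmetric focusing idea for a single prediction, without the probability margin that some implementations add. It is an assumption about the general technique, not a copy of `AsymmetricLoss` in `src/utils.py`, and the function name `asymmetric_loss` is hypothetical.

```python
import math

def asymmetric_loss(p, y, gamma_pos=1.0, gamma_neg=4.0, eps=1e-8):
    """Per-element asymmetric focal loss for a probability p and binary target y."""
    p = min(max(p, eps), 1.0 - eps)  # keep log() finite
    if y == 1:
        # Positives: down-weight confident (easy) positives by (1 - p)^gamma_pos
        return -((1.0 - p) ** gamma_pos) * math.log(p)
    # Negatives: down-weight confident (easy) negatives by p^gamma_neg
    return -(p ** gamma_neg) * math.log(1.0 - p)

# With both gammas at 0 the loss reduces to plain binary cross-entropy
print(asymmetric_loss(0.9, 1, gamma_pos=0.0, gamma_neg=0.0))
# An easy negative (p = 0.1) is down-weighted far more than an easy positive (p = 0.9)
print(asymmetric_loss(0.1, 0), asymmetric_loss(0.9, 1))
```

A larger `gamma_neg` than `gamma_pos` suppresses the contribution of the many easy negatives, which is why this family of losses is often tried on imbalanced multi-label problems.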

---

**Questions or issues?** Check the project README or documentation.