<a href="https://colab.research.google.com/github/Eric-rWang/VivoX/blob/main/PPG_Lightweight_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Lightweight Transformer Implementation

In [None]:
# imports
import os
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import time

**Model Architecture:**

**Input:**
(batch_size, 1, 36, window_size) (36 channels: 12 locations × 3 wavelengths)

**Channel-wise Embedding:**
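A stack of 1D convolutions and pooling layers compresses the whole 36-channel window into a single 256-dimensional feature vector. Note that the first convolution mixes all input channels, so the embedding is computed jointly over the channels rather than separately per channel.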

**Transformer Encoder:**
A 2-layer Transformer encoder (4 attention heads, model width 256) processes that feature vector as a sequence of length one.

**Regression Head:**
Fully connected layers map the output to 2 values (SvO2 and SaO2).

**Flow:**
Input → Conv1d/Pooling → Feature Vector → Transformer → Output Head → Prediction

In [None]:
class LightweightTransformer(nn.Module):
    """Small conv + Transformer regressor producing two outputs per window.

    Pipeline: a 1D convolutional stack reduces a multi-channel window to one
    256-dimensional vector, a 2-layer Transformer encoder processes that
    vector as a single-token sequence, and an MLP head maps it to 2 values.

    Args:
        num_channels: Number of input channels fed to the first convolution.
        window_size: Accepted for interface compatibility; it is not used,
            because the final adaptive pooling collapses any window length.
    """

    def __init__(self, num_channels=36, window_size=350):
        super().__init__()

        # 1. Convolutional embedding: the two max-pools halve the time axis
        #    twice, the adaptive average pool then reduces it to length 1.
        conv_stack = [
            nn.Conv1d(num_channels, 64, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        self.embed = nn.Sequential(*conv_stack)

        # 2. Transformer encoder operating on (batch, sequence, feature) tensors
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256,
                nhead=4,
                dim_feedforward=512,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=2,
        )

        # 3. Regression head: 256 -> 128 -> 2
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        """Map an input of shape (B, 1, C, T) to predictions of shape (B, 2)."""
        signal = x.squeeze(1)                    # (B, C, T)
        pooled = self.embed(signal)              # (B, 256, 1)
        token = pooled.squeeze(-1).unsqueeze(1)  # (B, 1, 256): one-token sequence
        encoded = self.transformer(token)        # (B, 1, 256)
        features = encoded.squeeze(1)            # (B, 256)
        return self.head(features)

Supporting functions for script


In [None]:
def verify_model(model, num_channels=36, window_size=350):
    """Smoke-test a model with one random batch and check its output shape.

    The model is switched to eval mode (and left there). A random batch of
    shape (2, 1, num_channels, window_size) is created on the same device as
    the model's parameters, so the check also works after ``model.to(device)``.
    The forward pass runs without gradient tracking.

    Args:
        model: Module to check; expected to return a tensor of shape (2, 2).
        num_channels: Size of the channel axis of the dummy batch.
        window_size: Size of the time axis of the dummy batch.

    Returns:
        True if the forward pass succeeds and yields shape (2, 2), otherwise
        False. The outcome is also printed.
    """
    model.eval()

    # Parameter-free modules have no device of their own; fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Broad catch is deliberate: this is a best-effort check that must report
    # any failure of the forward pass instead of raising.
    try:
        dummy_input = torch.randn(2, 1, num_channels, window_size, device=device)
        with torch.no_grad():
            output = model(dummy_input)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if tuple(output.shape) != (2, 2):
            raise ValueError(f"Bad output shape: {output.shape}")
    except Exception as e:
        print(f"Model verification failed: {e}")
        return False

    print("Model verification passed")
    return True

In [None]:
def import_h5py(file_path):
    """Read the 'waveforms' and 'labels' datasets from an HDF5 file.

    Both datasets are read fully into memory before the file is closed.

    Args:
        file_path: Path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(X, y)`` holding the contents of 'waveforms' and 'labels'.
    """
    with h5py.File(file_path, mode='r') as h5_file:
        waveforms = h5_file['waveforms'][:]
        targets = h5_file['labels'][:]
    return waveforms, targets

Configurations

In [None]:
# Output directory for the saved data splits and model checkpoints.
# `save_dir` is read by the saving cells further down but was never assigned,
# which made them fail with a NameError; define it here and make sure it exists.
save_dir = "OUTPUT"
os.makedirs(save_dir, exist_ok=True)

USE_SEPARATE_TEST_FILE = True  # Set to False to use single file approach

# Only the paths are chosen here; the files themselves are read in the
# data-loading cell.
if USE_SEPARATE_TEST_FILE:
    train_val_file_path = "DATA/JUL22_recordings.h5"
    test_file_path = "DATA/JUL22_test.h5"
    print(f"Loading training/validation data from {train_val_file_path} ...")
    print(f"Loading test data from {test_file_path} ...")
else:
    file_path = "DATA/jul14th_shift_sensor_data_eric.h5"
    print(f"Loading data from {file_path} ...")

Model hyperparameters

In [None]:
# --- Hyperparameters ---

# Input geometry passed to every model constructor
num_channels = 36
window_size = 350

# Optimisation settings
batch_size = 32
lr = 1e-4
weight_decay = 1e-4
num_epochs = 5  # was 300
n_patience = 50  # epochs without validation improvement before early stopping

# Registry of models to train: display name -> model class
MODELS = dict(LightTransformer=LightweightTransformer)

Data loading and preparation

In [None]:
# --- Load the recordings and build the train / validation / test splits ---
if not USE_SEPARATE_TEST_FILE:
    # Single-file mode: one recording file provides all three splits
    combined_data, labels_array = import_h5py(file_path)
    print(f"Loaded ✅\nData shape: {combined_data.shape}, Labels shape: {labels_array.shape}")

    # First cut keeps 60% for training ...
    X_train, X_temp, y_train, y_temp = train_test_split(
        combined_data,
        labels_array,
        train_size=0.6,
        random_state=42,
    )
    # ... and the remaining 40% is halved into validation and test (20% each)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp,
        y_temp,
        test_size=0.5,
        random_state=42,
    )
else:
    # Two-file mode: train/validation come from one file, test from another
    combined_data, labels_array = import_h5py(train_val_file_path)
    print(f"Training/Val data loaded ✅\nData shape: {combined_data.shape}, Labels shape: {labels_array.shape}")

    X_test, y_test = import_h5py(test_file_path)
    print(f"Test data loaded ✅\nData shape: {X_test.shape}, Labels shape: {y_test.shape}")

    # 75% training / 25% validation
    X_train, X_val, y_train, y_val = train_test_split(
        combined_data,
        labels_array,
        test_size=0.25,
        random_state=42,
    )

# Quick sanity check of the raw value ranges
print("Training data range:")
print(f"X_train: {X_train.min()} to {X_train.max()}")
print(f"y_train: {y_train.min()} to {y_train.max()}")
print("\nTest data range:")
print(f"X_test: {X_test.min()} to {X_test.max()}")
print(f"y_test: {y_test.min()} to {y_test.max()}")

Diagnostic

In [None]:
# Diagnostic: sample count per class in the training split before balancing.
# Every distinct label row counts as one class; `labels` and `counts` are
# reused by the balancing cells below.
labels, counts = np.unique(y_train, return_counts=True, axis=0)
print("\nOriginal training set class counts:")
for class_label, class_count in zip(labels, counts):
    print(f"  {class_label.tolist()}: {class_count} samples")

Balancing

In [None]:
# --- Balance the TRAINING split only ---
# Every class is brought to the median class size: larger classes are
# undersampled, smaller ones are topped up with augmented copies.
target = int(np.median(counts))
print(f"\nBalancing training set to {target} samples per class (median count)")
balanced_X = []
balanced_y = []

def jitter(x, σ=0.02):
    """Return `x` plus zero-mean Gaussian noise scaled to `σ` times the std of `x`."""
    noise_std = σ * np.std(x)
    noise = np.random.normal(loc=0, scale=noise_std, size=x.shape)
    return x + noise
def time_shift(x, max_shift=15):
    """Circularly shift `x` along axis 0 by a random offset in [-max_shift, max_shift]."""
    offset = np.random.randint(low=-max_shift, high=max_shift + 1)
    return np.roll(x, shift=offset, axis=0)

for lbl, cnt in zip(labels, counts):
    # Row indices of the training samples that carry exactly this label
    idxs = np.flatnonzero((y_train == lbl).all(axis=1))

    if cnt > target:
        # Over-represented class: keep a random subset without replacement
        chosen = np.random.choice(idxs, target, replace=False)
        balanced_X.append(X_train[chosen])
        balanced_y.append(y_train[chosen])
        continue

    # Classes at or below the target keep all of their original samples
    balanced_X.append(X_train[idxs])
    balanced_y.append(y_train[idxs])

    # Under-represented class: fill the gap with jittered, time-shifted copies
    # of randomly drawn originals (no iterations when cnt == target)
    for _ in range(target - cnt):
        source_idx = np.random.choice(idxs)
        augmented = time_shift(jitter(X_train[source_idx]))
        balanced_X.append(augmented[None])  # leading axis so vstack treats it as one row
        balanced_y.append(lbl[None])

# Stack the per-class pieces back into single arrays
X_train = np.vstack(balanced_X)
y_train = np.vstack(balanced_y)

Post balancing diagnostic

In [None]:
# Diagnostic: per-class counts of the training set after balancing, together
# with the absolute and relative change against the pre-balancing counts
new_labels, new_counts = np.unique(y_train, return_counts=True, axis=0)
print("Post‐balance training set class counts:")
for lbl, old_cnt, new_cnt in zip(labels, counts, new_counts):
    delta = new_cnt - old_cnt
    pct = (delta / old_cnt) * 100
    print(f"  {lbl.tolist()}: {new_cnt} samples ({'+' if delta>=0 else ''}{delta}, {pct:.1f}%)")

Normalization


In [None]:
# Robust per-channel input normalization (median / MAD over the time axis)
def normalize_ppg(X, time_axis=1):
    """Robustly standardise every channel of every sample over time.

    For each sample and channel the median over `time_axis` is subtracted and
    the result is divided by the scaled median absolute deviation (MAD). The
    factor 1.4826 is the usual constant that makes the MAD comparable to a
    standard deviation for Gaussian data.

    The arrays in this notebook are laid out as (samples, time, channels):
    they are later permuted with ``permute(0, 2, 1)`` to reach
    (N, channels, time). The statistics were previously taken over axis 2,
    i.e. across the channels at each time step, which is not a per-channel
    normalization; the time axis is therefore now explicit and defaults to 1.

    Args:
        X: Array with a time axis at position `time_axis`.
        time_axis: Axis along which median and MAD are computed. Pass 2 to
            reproduce the previous behaviour.

    Returns:
        Array of the same shape as `X`.
    """
    median = np.median(X, axis=time_axis, keepdims=True)
    mad = 1.4826 * np.median(np.abs(X - median), axis=time_axis, keepdims=True)
    # Small epsilon keeps constant (zero-MAD) channels from dividing by zero
    return (X - median) / (mad + 1e-6)

# Normalize each split on its own; normalize_ppg only uses statistics taken
# within the array it is given
X_train, X_val, X_test = (
    normalize_ppg(split) for split in (X_train, X_val, X_test)
)

# Label normalization: divide all targets by the same constant
y_max = 100.0
y_train, y_val, y_test = (
    targets / y_max for targets in (y_train, y_val, y_test)
)

print("Post-scaling X range:", X_train.min(), X_train.max())
print("Post-scaling y range:", y_train.min(), y_train.max())

In [None]:
# Persist all six arrays in one .npz archive so the exact splits can be
# reloaded later (`save_dir` has to be defined before this cell runs)
splits = {
    'X_train': X_train, 'y_train': y_train,
    'X_val': X_val, 'y_val': y_val,
    'X_test': X_test, 'y_test': y_test,
}
np.savez(os.path.join(save_dir, 'data_splits.npz'), **splits)

In [None]:
# Wrap every NumPy split in a float32 torch tensor
X_train_t, X_val_t, X_test_t = (
    torch.from_numpy(arr).float() for arr in (X_train, X_val, X_test)
)
y_train_t, y_val_t, y_test_t = (
    torch.from_numpy(arr).float() for arr in (y_train, y_val, y_test)
)

In [None]:
# Swap the last two axes and insert a singleton axis at position 1 so the
# inputs end up as (N, 1, 36, 350)
X_train_rs = X_train_t.permute(0, 2, 1)[:, None]
X_val_rs = X_val_t.permute(0, 2, 1)[:, None]
X_test_rs = X_test_t.permute(0, 2, 1)[:, None]

In [None]:
# Standardise with training-set statistics only: one mean/std per index of
# dim 2, computed over dims 0, 1 and 3, then applied unchanged to all splits
mean = X_train_rs.mean(dim=(0, 1, 3), keepdim=True)
std = X_train_rs.std(dim=(0, 1, 3), keepdim=True)
X_train_rs, X_val_rs, X_test_rs = (
    (split - mean) / (std + 1e-6)
    for split in (X_train_rs, X_val_rs, X_test_rs)
)

In [None]:
# DataLoaders: only the training loader shuffles its samples
train_dataset = TensorDataset(X_train_rs, y_train_t)
val_dataset = TensorDataset(X_val_rs, y_val_t)
test_dataset = TensorDataset(X_test_rs, y_test_t)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# --- Training loop for all models ---
def _mean_loss(model, loader, criterion, device):
    """Return the sample-weighted mean of `criterion` over `loader`, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for Xb, yb in loader:
            Xb, yb = Xb.to(device), yb.to(device)
            total += criterion(model(Xb), yb).item() * Xb.size(0)
    return total / len(loader.dataset)


if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print("🍏 Using Apple Silicon GPU (MPS)")
else:
    device = torch.device('cpu')
    print("💀 Falling back to CPU? Yikes!")

for model_name, model_class in MODELS.items():

    print(f"\n=== Training {model_name} ===")
    model = model_class(
        num_channels=num_channels,
        window_size=window_size
    ).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # The scheduler's `verbose` flag is deprecated; learning-rate reductions
    # are reported manually after each scheduler step instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    best_val_loss = float('inf')
    best_epoch = 0
    best_state = None  # CPU copy of the weights with the lowest validation loss
    patience_counter = 0
    train_losses = []
    val_losses = []
    test_losses = []  # test loss of every epoch (monitoring only, never used for selection)
    epoch_times = []  # Track epoch durations

    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()

        # Training phase
        model.train()
        epoch_train_loss = 0.0
        for Xb, yb in train_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            optimizer.zero_grad()
            preds = model(Xb)
            loss = criterion(preds, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a short last batch
            epoch_train_loss += loss.item() * Xb.size(0)
        train_losses.append(epoch_train_loss / len(train_loader.dataset))

        # Validation and test phases (both in eval mode, without gradients)
        val_losses.append(_mean_loss(model, val_loader, criterion, device))
        test_losses.append(_mean_loss(model, test_loader, criterion, device))

        # Duration of the current epoch
        epoch_duration = time.time() - epoch_start_time
        epoch_times.append(epoch_duration)

        # Calculate ETA based on average epoch time
        avg_epoch_time = np.mean(epoch_times)
        remaining_epochs = num_epochs - epoch
        eta_seconds = remaining_epochs * avg_epoch_time
        eta_mins = eta_seconds / 60

        # Update scheduler and report any learning-rate reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_losses[-1])
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # Early-stopping bookkeeping
        improved = val_losses[-1] < best_val_loss
        if improved:
            best_val_loss = val_losses[-1]
            best_epoch = epoch
            # Snapshot the weights; clone so later optimizer steps cannot alter them
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            status = ''
        else:
            patience_counter += 1
            status = '⚪️'

        print(f"{status} {model_name} Epoch {epoch:3d}: Train {train_losses[-1]:.4f} | Val {val_losses[-1]:.4f} | Test {test_losses[-1]:.4f} | {epoch_duration:.1f}s/epoch | ETA: {eta_mins:.1f}mins ({eta_seconds:.0f}s)")

        # Stop only after the summary above, so the final epoch is still logged
        if not improved and patience_counter >= n_patience:
            print(f"\nEarly stopping triggered at epoch {epoch}!")
            break

    print(f"\n{model_name} Training complete!")
    print(f"Best val loss: {best_val_loss:.4f}")

    if best_state is None:
        # No epoch ever improved on the initial value (e.g. zero epochs or
        # non-finite validation losses): fall back to the current weights.
        best_state = model.state_dict()
    else:
        # Leave `model` holding the same weights that are written to disk
        model.load_state_dict(best_state)

    # Save the best-validation weights together with the full loss history
    torch.save({
        'model_state': best_state,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_losses': test_losses,  # full per-epoch test loss trajectory
        'best_epoch': best_epoch,
        'best_val_loss': best_val_loss,
        'config': {
            'num_channels': num_channels,
            'window_size': window_size,
            'model_name': model_name
        }
    }, os.path.join(save_dir, f'{model_name}.pt'))

print("\nAll models trained and saved!")