<a href="https://colab.research.google.com/github/Eric-rWang/VivoX/blob/main/PPG_VivoX_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PPG Array Transformer for Arterial/Venous SpO₂ Prediction

In [1]:
!pip install h5py torchsummary
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import h5py
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from torchsummary import summary
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math

os.chdir("/content")
print(os.getcwd())

/content


Data Loading + Preprocessing

In [2]:
class PPGDataset(Dataset):
    """Loads PPG array data from HDF5 files"""
    def __init__(self, file_path, augment=True, test_mode=False):
        """
        Args:
            file_path (string): Path to the HDF5 file
            augment (bool): Whether to apply data augmentation
            test_mode (bool): If True, returns additional metadata for testing
        """
        with h5py.File(file_path, 'r') as f:
            self.waveforms = f['waveforms'][:]
            self.labels = f['labels'][:]

        self.augment = augment
        self.test_mode = test_mode

        # Add placeholder for metadata if in test mode
        if test_mode:
            self.positions = np.zeros((len(self.waveforms), 2))  # [arterial_pos, venous_pos]
            self.spacings = np.zeros(len(self.waveforms))        # vessel spacing

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        x = self.waveforms[idx].astype(np.float32)
        y = self.labels[idx].astype(np.float32)

        if self.augment:
            x = self.augment_sample(x)

        if self.test_mode:
            return torch.tensor(x), torch.tensor(y), torch.tensor(self.positions[idx]), torch.tensor(self.spacings[idx])
        return torch.tensor(x), torch.tensor(y)

    def augment_sample(self, x):
        """Applies physics-based data augmentation"""
        # 1. Random channel shift (simulate sensor placement variation)
        if np.random.rand() > 0.5:
            shift = np.random.randint(-3, 3)
            x = np.roll(x, shift, axis=1)

        # 2. Venous signal inversion (randomly apply to 940nm channels)
        if np.random.rand() > 0.7:
            x[:, 24:] *= -1

        # 3. Distance-based signal attenuation
        vessel_position = np.random.randint(0, 12)  # Random vessel center
        distance = np.abs(np.arange(12) - vessel_position)
        decay = np.exp(-distance/2.0)  # Light decay model

        # Apply decay to all wavelengths
        for i in range(3):
            start_idx = i * 12
            end_idx = (i + 1) * 12
            x[:, start_idx:end_idx] *= decay

        # 4. Add wavelength-specific noise
        noise_levels = [0.02, 0.015, 0.025]  # 660nm, 850nm, 940nm
        for i, noise_level in enumerate(noise_levels):
            start_idx = i * 12
            end_idx = (i + 1) * 12
            noise = np.random.normal(0, noise_level, x[:, start_idx:end_idx].shape)
            x[:, start_idx:end_idx] += noise

        return x

Model Architecture

In [3]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.d_model = d_model

        # Create positional encoding buffer
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # FIXED: Corrected div_term calculation with proper parentheses
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.max_len:
            # Dynamically extend positional encoding
            self.extend_pe(seq_len)
        x = x + self.pe[:seq_len, :].unsqueeze(0)  # Add batch dimension
        return self.dropout(x)

    def extend_pe(self, seq_len):
        """Extend positional encoding for longer sequences"""
        position = torch.arange(self.max_len, seq_len).float().unsqueeze(1)
        # FIXED: Corrected div_term calculation with proper parentheses
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        new_pe = torch.zeros(len(position), self.d_model)
        new_pe[:, 0::2] = torch.sin(position * div_term)
        new_pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = torch.cat([self.pe, new_pe], dim=0)
        self.max_len = seq_len


# Transformer Model
class PPGArrayTransformer(nn.Module):
    def __init__(self, d_model=126, nhead=6, num_layers=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Wavelength-specific embedding
        self.embed_660 = nn.Linear(1, d_model//3)
        self.embed_850 = nn.Linear(1, d_model//3)
        self.embed_940 = nn.Linear(1, d_model//3)

        # Positional encodings
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=5000)

        # Learnable PD positions
        self.pd_position = nn.Embedding(12, d_model//3)

        # Transformers
        self.spatial_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=256, dropout=dropout),
            num_layers=1
        )
        self.temporal_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512, dropout=dropout),
            num_layers=num_layers
        )

        # Venous presence detector
        self.venous_detector = nn.Sequential(
            nn.Linear(d_model, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Output heads
        self.art_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.ven_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Split into wavelengths
        x_660 = x[:, :, :12].unsqueeze(-1)  # (batch, 350, 12, 1)
        x_850 = x[:, :, 12:24].unsqueeze(-1)
        x_940 = x[:, :, 24:].unsqueeze(-1)

        # PD position embeddings
        pd_indices = torch.arange(12).to(x.device)
        pd_emb = self.pd_position(pd_indices)  # (12, d_model//3)
        pd_emb = pd_emb.unsqueeze(0).unsqueeze(0)  # (1, 1, 12, d_model//3)

        # Embed each wavelength
        emb_660 = self.embed_660(x_660) + pd_emb
        emb_850 = self.embed_850(x_850) + pd_emb
        emb_940 = self.embed_940(x_940) + pd_emb

        # Combine embeddings
        x = torch.cat([emb_660, emb_850, emb_940], dim=-1)  # (batch, 350, 12, d_model)

        # Combine time and PD dimensions
        x = x.reshape(batch_size, seq_len * 12, self.d_model)  # (batch, 4200, d_model)
        x = self.pos_encoder(x)

        # Transformer processing
        x = x.permute(1, 0, 2)  # (4200, batch, d_model)
        x = self.spatial_transformer(x)
        temporal_out = self.temporal_transformer(x)
        temporal_out = temporal_out.permute(1, 0, 2)  # (batch, 4200, d_model)

        # Venous presence weighting
        venous_mask = self.venous_detector(temporal_out)  # (batch, 4200, 1)
        temporal_out = temporal_out * venous_mask  # Automatic broadcasting

        # Global pooling
        pooled = temporal_out.mean(dim=1)  # (batch, d_model)

        # Predictions
        art_pred = self.art_head(pooled)
        ven_pred = self.ven_head(pooled)

        return torch.cat([art_pred, ven_pred], dim=1), venous_mask

Training Setup

In [4]:
def physiological_loss(art_pred, ven_pred, art_true, ven_true):
    """Custom loss with physiological constraints"""
    # Base MSE losses
    art_loss = F.mse_loss(art_pred, art_true)
    ven_loss = F.mse_loss(ven_pred, ven_true)

    # Physiological constraint: venous SpO₂ < arterial SpO₂
    violation = torch.relu(ven_pred - art_pred)
    constraint_loss = torch.mean(violation) * 0.5

    return art_loss + ven_loss + constraint_loss

def train_model(model, train_loader, val_loader, epochs=50, lr=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_epoch_loss = 0
        for batch_x, batch_y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            # FIX: Properly handle model output
            output, _ = model(batch_x)
            art_pred = output[:, 0]
            ven_pred = output[:, 1]

            art_true, ven_true = batch_y[:, 0], batch_y[:, 1]

            loss = physiological_loss(art_pred, ven_pred, art_true, ven_true)
            loss.backward()
            optimizer.step()

            train_epoch_loss += loss.item()

        # Validation phase
        model.eval()
        val_epoch_loss = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                # FIX: Properly handle model output
                output, _ = model(batch_x)
                art_pred = output[:, 0]
                ven_pred = output[:, 1]

                art_true, ven_true = batch_y[:, 0], batch_y[:, 1]

                loss = physiological_loss(art_pred, ven_pred, art_true, ven_true)
                val_epoch_loss += loss.item()

        # Calculate epoch metrics (unchanged)
        train_epoch_loss /= len(train_loader)
        val_epoch_loss /= len(val_loader)
        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)

        # Update scheduler
        scheduler.step(val_epoch_loss)

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}")

        # Save best model
        if val_epoch_loss < best_val_loss:
            best_val_loss = val_epoch_loss
            torch.save(model.state_dict(), "best_model.pth")
            print("Saved new best model")

    # Plot training history (unchanged)
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training History')
    plt.savefig('training_history.png')
    plt.show()

    return model

Evaluation Metrics

In [5]:
def evaluate_model(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    art_preds, art_trues = [], []
    ven_preds, ven_trues = [], []
    venous_scores = []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            # FIX: Properly handle model output
            output, venous_mask = model(batch_x)
            art_pred = output[:, 0]
            ven_pred = output[:, 1]

            art_preds.append(art_pred.cpu().numpy())
            ven_preds.append(ven_pred.cpu().numpy())
            art_trues.append(batch_y[:, 0].cpu().numpy())
            ven_trues.append(batch_y[:, 1].cpu().numpy())
            venous_scores.append(venous_mask.cpu().numpy())

    # Concatenate results (unchanged)
    art_preds = np.concatenate(art_preds)
    ven_preds = np.concatenate(ven_preds)
    art_trues = np.concatenate(art_trues)
    ven_trues = np.concatenate(ven_trues)
    venous_scores = np.concatenate(venous_scores)

    # Calculate metrics (unchanged)
    art_mae = np.mean(np.abs(art_preds - art_trues))
    ven_mae = np.mean(np.abs(ven_preds - ven_trues))
    art_rmse = np.sqrt(np.mean((art_preds - art_trues)**2))
    ven_rmse = np.sqrt(np.mean((ven_preds - ven_trues)**2))

    print(f"Arterial SpO₂ - MAE: {art_mae:.2f}%, RMSE: {art_rmse:.2f}%")
    print(f"Venous SpO₂   - MAE: {ven_mae:.2f}%, RMSE: {ven_rmse:.2f}%")

    # Plot results (unchanged)
    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    plt.scatter(art_trues, art_preds, alpha=0.5)
    plt.plot([50, 100], [50, 100], 'r--')
    plt.xlabel('True Arterial SpO₂ (%)')
    plt.ylabel('Predicted Arterial SpO₂ (%)')
    plt.title(f'Arterial SpO₂ Prediction (MAE: {art_mae:.2f}%)')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.scatter(ven_trues, ven_preds, alpha=0.5)
    plt.plot([30, 80], [30, 80], 'r--')
    plt.xlabel('True Venous SpO₂ (%)')
    plt.ylabel('Predicted Venous SpO₂ (%)')
    plt.title(f'Venous SpO₂ Prediction (MAE: {ven_mae:.2f}%)')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('spO2_predictions.png')
    plt.show()

    # Plot venous detection scores (unchanged)
    plt.figure(figsize=(10, 5))
    plt.hist(venous_scores, bins=50, alpha=0.7)
    plt.axvline(0.5, color='r', linestyle='--')
    plt.xlabel('Venous Presence Score')
    plt.ylabel('Frequency')
    plt.title('Venous Signal Detection Distribution')
    plt.savefig('venous_detection.png')
    plt.show()

    return {
        'art_mae': art_mae,
        'ven_mae': ven_mae,
        'art_rmse': art_rmse,
        'ven_rmse': ven_rmse,
        'venous_scores': venous_scores
    }

Workflow

In [None]:
def main():
    # Configuration
    DATA_PATH = "Shifting_Position_v3.3.h5"  # Update this path
    BATCH_SIZE = 32
    EPOCHS = 100
    LR = 1e-4
    TEST_SIZE = 0.2
    VAL_SIZE = 0.1

    # Load dataset
    full_dataset = PPGDataset(DATA_PATH, augment=True)

    # Split into train/val/test
    train_idx, test_idx = train_test_split(
        range(len(full_dataset)),
        test_size=TEST_SIZE,
        random_state=42
    )
    train_idx, val_idx = train_test_split(
        train_idx,
        test_size=VAL_SIZE/(1-TEST_SIZE),
        random_state=42
    )

    # Create subsets
    train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(full_dataset, val_idx)
    test_dataset = torch.utils.data.Subset(full_dataset, test_idx)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PPGArrayTransformer(d_model=126, nhead=6, num_layers=4).to(device)

    # Print model summary
    # summary(model, input_size=(350, 36))

    # Train model
    trained_model = train_model(model, train_loader, val_loader, epochs=EPOCHS, lr=LR)

    # Evaluate on test set
    results = evaluate_model(trained_model, test_loader)

    # Save final model
    torch.save(trained_model.state_dict(), "final_model.pth")
    print("Training complete. Model saved.")

if __name__ == "__main__":
    main()

Epoch 1/100: 100%|██████████| 113/113 [10:53<00:00,  5.79s/it]


Epoch 1/100 - Train Loss: 11904.1403, Val Loss: 11028.0198
Saved new best model


Epoch 2/100: 100%|██████████| 113/113 [10:54<00:00,  5.79s/it]


Epoch 2/100 - Train Loss: 10221.4757, Val Loss: 9152.7883
Saved new best model


Epoch 3/100: 100%|██████████| 113/113 [10:54<00:00,  5.79s/it]


Epoch 3/100 - Train Loss: 8161.2056, Val Loss: 6980.5519
Saved new best model


Epoch 4/100: 100%|██████████| 113/113 [10:53<00:00,  5.79s/it]


Epoch 4/100 - Train Loss: 5973.0540, Val Loss: 4868.3383
Saved new best model


Epoch 5/100: 100%|██████████| 113/113 [10:53<00:00,  5.79s/it]


Epoch 5/100 - Train Loss: 3998.5876, Val Loss: 3105.6494
Saved new best model


Epoch 6/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 6/100 - Train Loss: 2445.2392, Val Loss: 1807.5715
Saved new best model


Epoch 7/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 7/100 - Train Loss: 1363.8363, Val Loss: 962.8081
Saved new best model


Epoch 8/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 8/100 - Train Loss: 711.6155, Val Loss: 503.1525
Saved new best model


Epoch 9/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 9/100 - Train Loss: 388.6825, Val Loss: 303.7620
Saved new best model


Epoch 10/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 10/100 - Train Loss: 261.7249, Val Loss: 237.0788
Saved new best model


Epoch 11/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 11/100 - Train Loss: 221.6928, Val Loss: 218.7762
Saved new best model


Epoch 12/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 12/100 - Train Loss: 211.6535, Val Loss: 215.8425
Saved new best model


Epoch 13/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 13/100 - Train Loss: 209.9357, Val Loss: 214.9667
Saved new best model


Epoch 14/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 14/100 - Train Loss: 209.9053, Val Loss: 215.1142


Epoch 15/100: 100%|██████████| 113/113 [10:53<00:00,  5.78s/it]


Epoch 15/100 - Train Loss: 209.8776, Val Loss: 215.2351


Epoch 16/100:  12%|█▏        | 14/113 [01:21<09:34,  5.81s/it]