Import Statements

In [26]:
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from torch import nn
import numpy as np
import torch
import os

from torch.nn.utils.rnn import pad_sequence

from models.EHRMamba import MambaEHR

Define paths and load data in

In [27]:
# Define paths
split = 1
train_path = f"./P12data/split_{split}/train_physionet2012_{split}.npy"
validation_path = f"./P12data/split_{split}/validation_physionet2012_{split}.npy"
test_path = f"./P12data/split_{split}/test_physionet2012_{split}.npy"

# Load data from .npy files
train_data = np.load(train_path, allow_pickle=True)
validation_data = np.load(validation_path, allow_pickle=True)
test_data = np.load(test_path, allow_pickle=True)

Preprocess data

In [28]:
# Preprocess Data
def preprocess(data):
    # Convert lists of time series to tensors, padding to the maximum length
    ts_values_all = [torch.tensor(patient['ts_values'], dtype=torch.float32) for patient in data]
    ts_indicators_all = [torch.tensor(patient['ts_indicators'], dtype=torch.float32) for patient in data]
    times_all = [torch.tensor(patient['ts_times'], dtype=torch.float32) for patient in data]

    # Pad sequences along the time dimension
    ts_values_padded = pad_sequence(ts_values_all, batch_first=True)  # Shape: (N, max_time, 37)
    ts_indicators_padded = pad_sequence(ts_indicators_all, batch_first=True)  # Shape: (N, max_time, 37)

    # Get the lengths of each sequence
    sequence_lengths = torch.tensor([len(patient['ts_times']) for patient in data], dtype=torch.long)

    # Static features and labels
    static_all = torch.tensor([patient['static'] for patient in data], dtype=torch.float32)  # Shape: (N, static_dim)
    labels_all = torch.tensor([patient['labels'] for patient in data], dtype=torch.float32).unsqueeze(1)  # Shape: (N, 1)
    
    return ts_values_padded, ts_indicators_padded, static_all, labels_all


train_ts_values_padded, train_ts_indicators_padded, train_static_all, train_labels_all = preprocess(train_data)
val_ts_values_padded, val_ts_indicators_padded, val_static_all, val_labels_all = preprocess(validation_data)
test_ts_values_padded, test_ts_indicators_padded, test_static_all, test_labels_all = preprocess(test_data)

# Check the sizes
print(f"Shape of ts_values after padding: {train_ts_values_padded.shape}")
print(f"Shape of ts_indicators after padding: {train_ts_indicators_padded.shape}")
print(f"Shape of static_all: {train_static_all.shape}")
print(f"Shape of labels_all: {train_labels_all.shape}")

Shape of ts_values after padding: torch.Size([9590, 215, 37])
Shape of ts_indicators after padding: torch.Size([9590, 215, 37])
Shape of static_all: torch.Size([9590, 8])
Shape of labels_all: torch.Size([9590, 1])


In [29]:
# Define model
model = MambaEHR(ts_dim=37, static_dim=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch_size = 32
loss_fn = nn.BCELoss()  # Binary cross-entropy for mortality prediction

train_labels_all_int = train_labels_all.squeeze().long()  # Remove extra dimension and convert to integers
val_labels_all_int = val_labels_all.squeeze().long()  # Remove extra dimension and convert to integers
test_labels_all_int = test_labels_all.squeeze().long()  # Remove extra dimension and convert to integers

In [30]:
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader, TensorDataset

# Prepare data loaders
def create_dataloader(ts_values, ts_indicators, static, labels, batch_size):
    dataset = TensorDataset(ts_values, ts_indicators, static, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

train_loader = create_dataloader(train_ts_values_padded, train_ts_indicators_padded, train_static_all, train_labels_all, batch_size)
val_loader = create_dataloader(val_ts_values_padded, val_ts_indicators_padded, val_static_all, val_labels_all, batch_size)

# Define training and validation loops
def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    for ts_values, ts_indicators, static, labels in loader:
        ts_values, ts_indicators, static, labels = ts_values.to(device), ts_indicators.to(device), static.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(ts_values, ts_indicators, static)  # Forward pass
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        all_preds.append(outputs.detach().cpu())
        all_labels.append(labels.cpu())

    avg_loss = total_loss / len(loader)
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Calculate AUROC and AUPRC
    auroc = roc_auc_score(all_labels, all_preds)
    auprc = average_precision_score(all_labels, all_preds)
    
    return avg_loss, auroc, auprc

def validate(model, loader, loss_fn, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for ts_values, ts_indicators, static, labels in loader:
            ts_values, ts_indicators, static, labels = ts_values.to(device), ts_indicators.to(device), static.to(device), labels.to(device)

            outputs = model(ts_values, ts_indicators, static)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            all_preds.append(outputs.cpu())
            all_labels.append(labels.cpu())

    avg_loss = total_loss / len(loader)
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Calculate AUROC and AUPRC
    auroc = roc_auc_score(all_labels, all_preds)
    auprc = average_precision_score(all_labels, all_preds)
    
    return avg_loss, auroc, auprc

# Main training loop
num_epochs = 10
model.to(device)

for epoch in range(num_epochs):
    train_loss, train_auroc, train_auprc = train(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_auroc, val_auprc = validate(model, val_loader, loss_fn, device)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {train_loss:.4f}, Train AUROC: {train_auroc:.4f}, Train AUPRC: {train_auprc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val AUROC: {val_auroc:.4f}, Val AUPRC: {val_auprc:.4f}\n")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x8 and 7x32)