In [19]:
import re
import os
import platform

if platform.system() == "Darwin":
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torchmetrics import F1Score
from sklearn.metrics import multilabel_confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns

In [20]:
# Pick the compute device: start from CPU and upgrade if an accelerator is present.
device = torch.device("cpu")

# MPS for Apple Silicon GPUs.
# torch.backends.mps.is_available() is the same check the Trainer uses before
# emptying the MPS cache, and it also exists on PyTorch releases that do not
# expose torch.mps.is_available().
if torch.backends.mps.is_available():
    print("MPS is available")
    device = torch.device("mps")

# CUDA for Nvidia GPUs (checked last, so it takes precedence if both are reported)
if torch.cuda.is_available():
    print("CUDA is available")
    device = torch.device("cuda")
print(device)

MPS is available
mps


In [21]:
def count_parameters(model):
    """Print the parameter count of every module inside ``model``.

    Walks ``model.named_modules()`` and prints one
    ``"<name>: <count> parameters"`` line per module. Each count sums
    ``numel()`` over ``module.parameters()``, so a parent module's figure
    includes the parameters of its children. Returns nothing.
    """
    for module_name, submodule in model.named_modules():
        n_params = 0
        for param in submodule.parameters():
            n_params += param.numel()
        print(f"{module_name}: {n_params} parameters")

# Dataset and Dataloader

**TODO:** investigate normalization and other input transforms.

In [22]:
class NormalizeECG:
    """Callable transform that z-scores a tensor along dimension 1.

    For every index of dimension 0 (one lead, when the tensor is laid out as
    ``(leads, timesteps)``), the mean over dimension 1 is subtracted and the
    result is divided by the standard deviation over dimension 1.
    """

    def __call__(self, tensor):
        # keepdim=True leaves a size-1 dimension so the statistics broadcast
        # back over the original tensor.
        row_mean = tensor.mean(dim=1, keepdim=True)
        row_std = tensor.std(dim=1, keepdim=True)
        centered = tensor - row_mean
        # The small epsilon keeps the division finite when a row is constant.
        return centered / (row_std + 1e-8)

In [23]:
class ECGDataset(Dataset):
    """Multi-label ECG dataset backed by one CSV file per recording.

    ``diagnoses`` is a CSV with an ``ID`` column plus one column per label.
    Every non-digit character is stripped from each ID (the original comment
    mentions removing a "JS" prefix), and the recording for a row is read from
    ``<path>/<ID>.csv``. That file must contain a ``time`` column, which is
    dropped; every remaining column becomes one row of the returned tensor.

    Args:
        path: Directory holding the per-recording CSV files.
        diagnoses: Path to the label CSV.
        transform: Optional callable applied to the ``(leads, timesteps)``
            tensor before it is returned.
        use_cache: When True, successfully loaded samples are stored in
            ``self.cache`` and returned from there on later accesses. Off by
            default, because every cached sample keeps its full tensor in
            memory.
    """

    def __init__(self, path="data/ecg_clipped", diagnoses='data/diagnoses_balanced.csv', transform=None,
                 use_cache=False):
        self.path = path

        # Load and prepare labels
        labels = pd.read_csv(diagnoses)
        # Keep only the digits of each ID so it matches the recording file name
        labels['ID'] = labels['ID'].astype(str).str.replace(r'\D', '', regex=True)
        self.labels_df = labels.set_index('ID')

        # Every column left after moving ID into the index is one label
        self.num_classes = self.labels_df.shape[1]
        print(f'Number of classes: {self.num_classes}')

        self.transform = transform
        self.use_cache = use_cache
        self.cache = {}

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` for row ``idx`` of the label table.

        Raises:
            IndexError: if ``idx`` is outside the label table.

        If the recording cannot be read or converted, the error is printed and
        a placeholder pair is returned instead: an all-zero ``(12, 5000)``
        tensor and a label vector filled with ``-1``.
        """
        if idx in self.cache:
            return self.cache[idx]

        # Resolve the row outside the try block: a bad index should surface as
        # IndexError, and file_path must exist before the handler formats it.
        row = self.labels_df.iloc[idx]
        file_path = f'{self.path}/{row.name}.csv'

        try:
            # Load ECG data
            df = pd.read_csv(file_path)
            ecg_data = df.drop(columns=['time']).values
            tensor = torch.tensor(ecg_data, dtype=torch.float32).T  # (leads, timesteps)

            if self.transform:
                tensor = self.transform(tensor)

            # Labels come from the row already fetched by position, so a
            # duplicated ID cannot turn the label into a 2-D array.
            label = torch.tensor(row.to_numpy(dtype='float32'), dtype=torch.float32)  # float for multi-label
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            # Return zero tensor and -1 label placeholder (never cached)
            return (torch.zeros((12, 5000), dtype=torch.float32),
                    torch.full((self.num_classes,), -1, dtype=torch.float32))

        sample = (tensor, label)
        if self.use_cache:
            self.cache[idx] = sample
        return sample

In [24]:
# Smoke test: build the dataset with its default paths and fetch one sample by index
dataset = ECGDataset()
data, label = dataset[23423]
print(data.shape, label.shape, sep="\n")

Number of classes: 5
torch.Size([12, 5000])
torch.Size([5])


# Basic Transformer

In [25]:
class ECGTransformer(nn.Module):
    """Transformer encoder stack followed by a linear classification head.

    Args:
        d_model: Feature size of each sequence element.
        num_classes: Number of outputs produced by the classifier.
        nhead: Attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.

    ``forward`` takes ``(batch_size, seq_len, d_model)`` input and returns
    ``(batch_size, num_classes)``, computed from the last sequence position.
    """

    def __init__(self, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()

        # One layer definition, stacked num_encoder_layers times by the encoder
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
        )

        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Encoder output is (batch_size, seq_len, d_model); only the final
        # sequence position is passed on to the classifier.
        last_position = self.transformer(x)[:, -1, :]
        return self.classifier(last_position)

In [26]:
# Sanity check of the bare transformer with d_model=12 (presumably one feature
# per ECG lead, with no embedding stage in front of it — confirm).
model = ECGTransformer(d_model=12, nhead=4, num_classes=63, num_encoder_layers=6, dim_feedforward=512)
print(model)
# Per-module parameter breakdown
count_parameters(model)

ECGTransformer(
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Linear(in_features=12, out_features=63, bias=True)
)
: 81723 parameters
transformer: 80904 parameters
transformer.layers: 80904 parameters
transformer.layers.0: 13484 parameters
transformer.layers.0.self_attn: 624 parameters
transformer.layers.0.self_attn.out_proj:

In [27]:
# Forward-pass smoke test on random data shaped (batch=2, seq_len=5000, d_model=12)
inputs = torch.rand((2, 5000, 12))
out = model(inputs)
print(out.shape)  # one row of classifier outputs per sequence
print(out[0])

torch.Size([2, 63])
tensor([-0.1000,  0.4248,  0.1513,  0.8337,  0.6989, -0.0704, -0.6072,  0.2872,
        -0.9834,  0.1062,  0.2191,  0.2958,  0.5035, -0.3359,  1.0074,  0.1418,
         0.2994,  0.4639,  0.7729,  1.0185, -0.5924, -0.0194, -0.2567,  0.8975,
        -0.5290, -0.0741,  0.3213, -0.1744,  0.4785,  0.3235,  0.0231, -0.3299,
         0.0893,  0.8521, -0.1676,  0.9647,  0.6593, -0.5105, -0.8359, -0.2195,
        -0.2407, -0.3787,  0.1329,  0.9259,  0.2729,  0.1595, -0.5420, -0.2772,
         1.0798,  0.0233, -0.2949,  1.5003, -0.3199,  0.2852,  0.3586, -0.9559,
        -0.2353,  0.4107,  0.1862,  1.1228, -0.3644, -0.5036, -1.1120],
       grad_fn=<SelectBackward0>)


# An embedding model that uses convolution

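A stack of 1-D convolutions (kernel size 51, `padding='same'`) turns the 12 input channels into `d_model` feature channels — this note originally said 128; the example below uses 512 and the training run uses 64. The convolution is repeated (three layers by default) so that each output step can draw on a wider window of the signal; the stated aim is to carry information forward by about 200 ms.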


In [28]:
class ECGEmbeddings(nn.Module):
    """Stack of same-padded 1-D convolutions that lifts the input to ``d_model`` channels.

    Args:
        d_input: Number of input channels of the first convolution.
        d_model: Number of output channels of every convolution.
        n_conv_layers: How many convolutions to stack.

    Every layer uses kernel size 51, stride 1 and ``padding='same'``, so the
    length of the last dimension is unchanged. A ReLU follows every layer
    except the last one.
    """

    def __init__(self, d_input, d_model, n_conv_layers=3):
        super().__init__()
        # Only the first layer sees d_input channels; the rest map d_model -> d_model
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(d_input if i == 0 else d_model, d_model, 51, stride=1, padding='same')
            for i in range(n_conv_layers)
        ])
        self.activation = nn.ReLU(inplace=False)

    def forward(self, x):
        last_index = len(self.conv_layers) - 1
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if i < last_index:
                x = self.activation(x)

        # The output's requires_grad flag is left to autograd. Forcing it to
        # True here made results produced under torch.no_grad() claim to need
        # gradients, which breaks calls such as .numpy() on them.
        return x

In [29]:
# Instantiate the embedding stage on its own (12 input channels, d_model=512)
# to inspect its layers and parameter counts.
embedding_model = ECGEmbeddings(d_input = 12, d_model=512)
print(embedding_model)
count_parameters(embedding_model)

ECGEmbeddings(
  (conv_layers): ModuleList(
    (0): Conv1d(12, 512, kernel_size=(51,), stride=(1,), padding=same)
    (1-2): 2 x Conv1d(512, 512, kernel_size=(51,), stride=(1,), padding=same)
  )
  (activation): ReLU()
)
: 27053568 parameters
conv_layers: 27053568 parameters
conv_layers.0: 313856 parameters
conv_layers.1: 13369856 parameters
conv_layers.2: 13369856 parameters
activation: 0 parameters


# Combining together embedding with transformer

In [30]:
class ECGCombined(nn.Module):
    """Convolutional embedding followed by the transformer classifier.

    Args:
        d_input: Input channels passed to ``ECGEmbeddings``.
        d_model: Embedding width, shared by both sub-models.
        num_classes: Number of classifier outputs; also stored on the instance.
        nhead, num_encoder_layers, dim_feedforward: Forwarded to ``ECGTransformer``.

    ``forward`` raises ``ValueError`` if NaNs appear either in the embedded
    sequence or in the transformer output.
    """

    def __init__(self, d_input, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()
        self.num_classes = num_classes

        self.embedding_model = ECGEmbeddings(d_input, d_model)
        self.transformer = ECGTransformer(d_model, num_classes, nhead, num_encoder_layers, dim_feedforward)
        self._init_weights()

    @staticmethod
    def _raise_on_nan(tensor, banner, stats_prefix, error_message):
        """Print diagnostics and raise ``ValueError`` when ``tensor`` contains NaNs."""
        if not torch.isnan(tensor).any():
            return
        print(banner)
        print(stats_prefix, tensor.min(), "max:", tensor.max(), "mean:", tensor.mean())
        raise ValueError(error_message)

    def forward(self, x):
        embedded = self.embedding_model(x)

        # The embedding yields (batch, channels, seq_len); the transformer
        # expects (batch, seq_len, d_model).
        sequence = embedded.permute(0, 2, 1)
        self._raise_on_nan(sequence, "❗ NaNs BEFORE transformer",
                           "Input stats → min:", "NaNs before transformer")

        logits = self.transformer(sequence)
        self._raise_on_nan(logits, "❗ NaNs AFTER transformer",
                           "Output stats → min:", "NaNs after transformer")

        return logits

    def _init_weights(self):
        """Kaiming-normal init for conv weights, Xavier-uniform for linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            else:
                # Any other module type keeps its default initialisation
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)



# Trainer class

In [41]:
class Trainer:
    """Gradient-accumulating training loop with periodic evaluation and checkpoints.

    Args:
        model: Module to train. ``evaluate`` reads ``model.num_classes``.
        device: ``torch.device`` the batches are moved to. On CUDA the forward
            pass runs under autocast with a ``GradScaler``; on any other device
            it runs without either.
        accum_steps: Number of batches whose gradients are accumulated before
            one optimizer step. Each batch loss is divided by this value.
        checkpoint_interval: Number of processed batches between
            evaluation + checkpoint rounds.
        lr: Learning rate passed to Adam.
        resume_checkpoint: Optional path to a checkpoint written by
            ``_save_checkpoint``; when given, model, optimizer, histories and
            the epoch/batch position are restored from it.
    """

    # Names used in confusion-matrix titles, looked up by class index.
    LABEL_NAMES = (
        "Dx_426177001",
        "Dx_426783006",
        "Dx_164890007",
        "Dx_427084000",
        "Dx_164934002",
    )

    def __init__(self, model, device, accum_steps=4, checkpoint_interval=256, lr=1e-5,
                 resume_checkpoint=None):
        self.model = model
        self.device = device
        self.accum_steps = accum_steps
        self.checkpoint_interval = checkpoint_interval

        if self.device.type == 'cuda':
            self.autocast = torch.cuda.amp.autocast
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            from contextlib import nullcontext
            self.autocast = nullcontext  # no-op context
            self.scaler = None

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_history = []
        self.acc_history = []
        self.batch_count = 0
        self.start_epoch = 0
        self.start_batch = 0

        if resume_checkpoint:
            self._load_checkpoint(resume_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer, scaler, histories and position from ``checkpoint_path``."""
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
        # load checkpoints from a trusted source. It is kept because the
        # checkpoint also stores the loss/metric histories, not just tensors.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])

        if self.scaler and checkpoint.get('scaler_state'):
            self.scaler.load_state_dict(checkpoint['scaler_state'])

        self.loss_history = checkpoint['loss_history']
        self.acc_history = checkpoint['acc_history']
        self.batch_count = checkpoint.get('batch_count', 0)
        self.start_epoch = checkpoint['epoch']
        # The stored batch was already processed, so resume with the next one
        self.start_batch = checkpoint.get('batch', 0) + 1
        self.checkpoint_interval = checkpoint.get('checkpoint_interval', self.checkpoint_interval)
        print(f"Resuming from epoch {self.start_epoch} batch {self.start_batch}")

    def loss(self, output, target):
        """Binary cross-entropy on raw logits, averaged over all elements."""
        return F.binary_cross_entropy_with_logits(output, target.float())

    def train(self, train_dataloader, test_dataloader, num_epochs, save_path="training_progress"):
        """Train from ``self.start_epoch`` up to ``num_epochs``.

        Every ``checkpoint_interval`` processed batches, the model is evaluated
        on ``test_dataloader`` and a checkpoint is written under ``save_path``.

        Raises:
            ValueError: if a batch contains NaN/Inf inputs or the loss is NaN.
        """
        os.makedirs(save_path, exist_ok=True)
        self.model.train()

        for epoch in range(self.start_epoch, num_epochs):
            # Only the epoch being resumed has already-processed batches to
            # skip; later epochs must start from their first batch.
            skip_batches = self.start_batch if epoch == self.start_epoch else 0
            pending_grads = False

            for batch_idx, (inputs, labels) in enumerate(train_dataloader):
                if batch_idx < skip_batches:
                    continue

                inputs, labels = inputs.to(self.device), labels.to(self.device)

                if torch.isnan(inputs).any() or torch.isinf(inputs).any():
                    print("⚠️ Bad input detected")
                    raise ValueError("Inputs contain NaNs or Infs")

                with self.autocast():
                    outputs = self.model(inputs)
                    # Scale so the accumulated gradient is an average over the window
                    loss = self.loss(outputs, labels) / self.accum_steps
                    if torch.isnan(loss):
                        print("⚠️ Loss is NaN!")
                        raise ValueError("Loss turned NaN")

                if self.scaler:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                pending_grads = True

                if (batch_idx + 1) % self.accum_steps == 0:
                    self._update_parameters()
                    pending_grads = False

                # Undo the accumulation scaling so the history holds per-batch loss
                current_loss = loss.item() * self.accum_steps
                self.loss_history.append(current_loss)
                self.batch_count += 1

                if self.batch_count % self.accum_steps == 0:
                    print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx+1}/{len(train_dataloader)} | "
                          f"Loss: {current_loss:.4f}")

                if self.batch_count % self.checkpoint_interval == 0:
                    metrics = self.evaluate(test_dataloader)
                    self.acc_history.append([self.batch_count, metrics])
                    # Hand the metrics over so the test set is not evaluated twice
                    self._save_checkpoint(save_path, epoch, batch_idx, test_dataloader,
                                          eval_metrics=metrics)

                del inputs, labels, outputs, loss
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()

            # Apply gradients left over from an incomplete accumulation window
            # so they neither get dropped nor leak into the next epoch.
            if pending_grads:
                self._update_parameters()

    def evaluate(self, dataloader):
        """Compute multi-label metrics over ``dataloader``.

        Predictions are ``sigmoid(outputs) >= 0.5``. The model's train/eval
        mode is restored afterwards, so calling this mid-training does not
        leave the model in eval mode.

        Returns:
            dict with keys ``"f1"`` (macro multilabel F1), ``"overall"``
            (mismatches over all sample/class pairs), ``"per_class"``
            (mismatch fraction per class), ``"precision"``, ``"recall"``
            (per class) and ``"conf_matrices"`` (one 2x2 matrix per class).
        """
        was_training = self.model.training
        self.model.eval()
        total_samples = 0
        num_classes = self.model.num_classes
        mismatches_per_class = torch.zeros(num_classes, device=self.device)

        all_preds = []
        all_labels = []

        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    outputs = self.model(inputs)
                    predicted = (torch.sigmoid(outputs) >= 0.5).float()

                    all_preds.append(predicted.cpu())
                    all_labels.append(labels.cpu())

                    mismatches_per_class += (predicted != labels).sum(dim=0).float()
                    total_samples += inputs.size(0)
        finally:
            # Put the model back into whatever mode the caller had it in
            self.model.train(was_training)

        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        y_true = all_labels.numpy().astype(int)
        y_pred = all_preds.numpy().astype(int)

        f1 = F1Score(task='multilabel', num_labels=num_classes, average='macro')
        f1_score = f1(all_preds, all_labels)

        hamming_loss_per_class = mismatches_per_class.cpu().numpy() / total_samples
        overall_hamming_loss = mismatches_per_class.sum().item() / (total_samples * num_classes)

        precision_per_class = precision_score(y_true, y_pred, average=None, zero_division=0)
        recall_per_class = recall_score(y_true, y_pred, average=None, zero_division=0)
        conf_matrices = multilabel_confusion_matrix(y_true, y_pred)

        return {
            "f1": f1_score.item(),
            "overall": overall_hamming_loss,
            "per_class": hamming_loss_per_class,
            "precision": precision_per_class,
            "recall": recall_per_class,
            "conf_matrices": conf_matrices
        }

    def _update_parameters(self):
        """Clip gradients to norm 0.5, take one optimizer step and clear the gradients."""
        if self.scaler:
            # Gradients must be unscaled before clipping. GradScaler.step skips
            # the update itself when it finds inf/NaN gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            # A NaN/Inf gradient would be written straight into the weights by
            # Adam, so drop this update instead of stepping.
            if torch.isfinite(grad_norm):
                self.optimizer.step()
            else:
                print("⚠️ Non-finite gradient norm, skipping optimizer step")

        self.optimizer.zero_grad()

    def _save_checkpoint(self, path, epoch, batch_idx, dataloader, eval_metrics=None):
        """Write a checkpoint plus history arrays, metrics and confusion-matrix plots.

        Args:
            path: Output directory.
            epoch, batch_idx: Position stored in the checkpoint and used in file names.
            dataloader: Evaluated only when ``eval_metrics`` is not supplied.
            eval_metrics: Optional result of ``evaluate`` to reuse.
        """
        checkpoint = {
            'epoch': epoch,
            'batch': batch_idx,
            'batch_count': self.batch_count,
            'checkpoint_interval': self.checkpoint_interval,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'loss_history': self.loss_history,
            'acc_history': self.acc_history,
            'scaler_state': self.scaler.state_dict() if self.scaler else None
        }

        checkpoint_path = f"{path}/checkpoint_ep{epoch}_b{batch_idx}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"\nCheckpoint saved at epoch {epoch+1} batch {batch_idx+1}")

        np.save(f"{path}/loss_history.npy", np.array(self.loss_history))
        np.save(f"{path}/acc_history.npy", np.array(self.acc_history))

        if eval_metrics is None:
            eval_metrics = self.evaluate(dataloader=dataloader)

        np.save(f"{path}/precision_ep{epoch}_b{batch_idx}.npy", eval_metrics["precision"])
        np.save(f"{path}/recall_ep{epoch}_b{batch_idx}.npy", eval_metrics["recall"])

        for i, cm in enumerate(eval_metrics["conf_matrices"]):
            # Fall back to a generic name when the model has more classes than
            # there are known label names.
            label_name = self.LABEL_NAMES[i] if i < len(self.LABEL_NAMES) else f"class {i}"

            fig = plt.figure(figsize=(4, 3))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                        xticklabels=["Pred 0", "Pred 1"], yticklabels=["True 0", "True 1"])
            plt.title(f'Confusion Matrix for {label_name}')
            plt.xlabel('Prediction')
            plt.ylabel('Actual')
            plt.tight_layout()
            plt.savefig(f"{path}/conf_matrix_class{i}_ep{epoch}_b{batch_idx}.png")
            plt.close(fig)


# Let's Go Training

In [42]:
# Build the normalised dataset, then hold out two 500-sample subsets from it.
# The generator is seeded so the split is the same on every run.
ecg_dataset = ECGDataset(diagnoses='data/diagnoses_balanced.csv', transform=NormalizeECG())

split_lengths = [len(ecg_dataset) - 1000, 500, 500]  # train / test / val
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset, val_dataset = random_split(ecg_dataset, split_lengths, split_generator)

batch_size = 4
# Only the training loader is shuffled
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader, val_dataloader = (
    DataLoader(held_out, batch_size=batch_size, shuffle=False)
    for held_out in (test_dataset, val_dataset)
)

Number of classes: 5


## Start from 0:

In [44]:
# Fresh run: a small combined model with 5 outputs, matching the
# "Number of classes: 5" reported when the dataset was built.
model = ECGCombined(d_input=12, d_model=64, num_classes=5, nhead=4, num_encoder_layers=2, dim_feedforward=128).to(device)
# accum_steps=16 and checkpoint_interval=256 are both counted in batches
trainer = Trainer(model, device, accum_steps=16,checkpoint_interval=256, lr=5e-6)
# Single epoch; checkpoints and plots are written under training_progress/balanced
trainer.train(train_dataloader, test_dataloader, num_epochs=1, save_path="training_progress/balanced")

KeyboardInterrupt: 

In [179]:
# Debug probe: push one test batch through the model and compare how many
# positives are predicted per label with how many are actually present.
model.eval()

with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)  # raw logits
        logit_min, logit_max = outputs.min().item(), outputs.max().item()
        print("Logits:", logit_min, logit_max)

        # Threshold the sigmoid at 0.5 to get hard per-label predictions
        probs = torch.sigmoid(outputs)
        preds = (probs >= 0.5).float()

        predicted_counts = preds.sum(dim=0).cpu().numpy()
        true_counts = labels.sum(dim=0).cpu().numpy()
        print("Predicted class counts per label:", predicted_counts)
        print("True class counts per label:", true_counts)

        break  # a single batch is enough for this check


Logits: nan nan
Predicted class counts per label: [0. 0. 0. 0. 0.]
True class counts per label: [3. 1. 0. 0. 3.]


## Resume from a checkpoint:

In [None]:
# Checkpoint file to continue training from
resume_from = "training_progress/cut/checkpoint_ep0_b2559.pt"

# NOTE(review): num_classes=2 here, while the dataloaders built earlier come from a
# dataset that reported 5 classes — confirm the loaders match this checkpoint.
model = ECGCombined(d_input=12, d_model=64, num_classes=2, nhead=4, num_encoder_layers=2, dim_feedforward=128).to(device)
# NOTE(review): the optimizer state is restored from the checkpoint on resume —
# verify whether lr=1e-8 passed here still takes effect afterwards.
trainer = Trainer(model, device, accum_steps=16, lr=1e-8, resume_checkpoint=resume_from)
trainer.train(train_dataloader, test_dataloader, num_epochs=1, save_path="training_progress/cut")

# Plot a checkpoint

In [None]:
# Evaluation history as saved by Trainer._save_checkpoint: each entry pairs a batch
# count with a metrics dict, so the array holds Python objects and needs allow_pickle=True.
acc_history = np.load('training_progress/cut/acc_history.npy', allow_pickle=True)

In [None]:
# Loss history plot: one point per processed training batch
loss_history = np.load('training_progress/cut/loss_history.npy', allow_pickle=True)

fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set_xlabel('Batches')
ax.set_ylabel('Loss')
ax.set_title('Loss History')
plt.show()

In [None]:
# F1 score plot
x = np.array([i[0] for i in acc_history])

f1 = np.array([i[1]['f1'] for i in acc_history])
plt.plot(x, f1)
plt.xlabel('Batches')
plt.ylabel('F1 Score')
plt.title('F1 Score History')
plt.show()

In [None]:
# Hamming accuracy plot
classes_acc = np.array([epoch[1]['per_class'] for epoch in acc_history])
plt.plot(x, classes_acc)
plt.legend([f'Class {i}' for i in range(2)], loc='upper right')
plt.xlabel('Epochs')
plt.ylabel('Hamming Loss')
plt.yscale('log')
plt.title('Hamming Loss per class')