In [1]:
import re
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torchmetrics import F1Score
from sklearn.metrics import multilabel_confusion_matrix, precision_score, recall_score

In [2]:
# Select the torch device: CPU by default, upgraded to an accelerator when one is reported.
device = torch.device("cpu")

# MPS for Apple Silicon GPUs.
# `torch.backends.mps.is_available()` is the documented availability check; the getattr
# guard keeps this cell working on PyTorch builds that ship without the MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    print("MPS is available")
    device = torch.device("mps")

# CUDA for Nvidia GPUs (checked last, so it takes precedence if both are reported)
if torch.cuda.is_available():
    print("CUDA is available")
    device = torch.device("cuda")
print(device)

CUDA is available
cuda


amdgpu.ids: No such file or directory


In [3]:
def count_parameters(model):
    """Print the parameter count of every module inside ``model``.

    Iterates over ``model.named_modules()`` and prints one ``"<name>: <count> parameters"``
    line per entry. The count is the total number of elements over ``module.parameters()``,
    so a container's figure includes the parameters of the modules nested in it. The
    top-level module is reported under an empty name.

    Args:
        model: an ``nn.Module`` to inspect.

    Returns:
        None. The report is written to stdout only.
    """
    for module_name, submodule in model.named_modules():
        n_params = 0
        for param in submodule.parameters():
            n_params += param.numel()
        print(f"{module_name}: {n_params} parameters")

# Dataset and Dataloader

To investigate: Normalization or other transforms

In [4]:
class NormalizeECG:
    """Callable transform applying z-score normalization along dim 1 of a tensor.

    For the (leads, timesteps) tensors produced by the dataset in this notebook this
    gives every lead zero mean and (approximately) unit standard deviation.
    """

    def __call__(self, tensor):
        # Statistics are reduced over dim=1 with keepdim=True so they broadcast back
        # against the input row by row.
        centered = tensor - tensor.mean(dim=1, keepdim=True)
        # The small epsilon keeps the division finite when a row has zero spread.
        scale = tensor.std(dim=1, keepdim=True) + 1e-8
        return centered / scale

In [5]:
class ECGDataset(Dataset):
    """Multi-label ECG dataset: one signal CSV per recording plus a labels CSV.

    The labels CSV must contain an ``ID`` column; every other column is treated as one
    binary label. Non-digit characters are stripped from the IDs, and the remaining digits
    are used as the file stem of the signal file, i.e. ``<path>/<digits>.csv``.

    Args:
        path: directory holding the per-recording signal CSVs.
        diagnoses: path of the labels CSV.
        transform: optional callable applied to the (leads, timesteps) signal tensor.
        cache_samples: when True, successfully loaded samples are kept in ``self.cache``
            and returned from memory on later accesses. Off by default because a full
            dataset of signals can be large.
    """

    def __init__(self, path="data/ecg", diagnoses='data/diagnoses.csv', transform=None, cache_samples=False):
        # Directory the signal files are read from (previously ignored in favour of a
        # hard-coded 'data/ecg/' prefix).
        self.path = path

        # Load and prepare labels
        self.labels_df = pd.read_csv(diagnoses)

        # Keep only the digits of each ID (e.g. drops a leading "JS" prefix)
        self.labels_df['ID'] = self.labels_df['ID'].astype(str).str.replace(r'\D', '', regex=True)
        self.labels_df.set_index('ID', inplace=True)
        self.num_classes = self.labels_df.shape[1]
        print(f'Number of classes: {self.num_classes}')

        self.transform = transform
        self.cache_samples = cache_samples
        self.cache = {}

    def get_pos_weights(self):
        """Return per-class ``#negatives / #positives`` as a float32 tensor.

        The tensor is created on the notebook-level ``device``. Classes without any
        positive sample are not guarded against (the division is by zero for them).
        """
        # Compute counts
        pos_counts = self.labels_df.sum()
        neg_counts = len(self.labels_df) - pos_counts

        # Calculate pos_weight = #neg / #pos for each class
        pos_weight = (neg_counts / pos_counts).values

        # Move to device
        pos_weight_tensor = torch.tensor(pos_weight, dtype=torch.float32, device=device)
        return pos_weight_tensor

    def get_num_classes(self):
        """Return the number of label columns."""
        return self.num_classes

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` for the row at position ``idx``.

        ``signal`` is a float32 tensor of shape (leads, timesteps) built from every column
        of the signal CSV except ``time``; ``label`` is a float32 tensor with one entry per
        class.

        Raises:
            IndexError: if ``idx`` is out of range.

        If the signal file cannot be read or converted, the error is printed and a
        placeholder ``(zeros(12, 5000), full(num_classes, -1))`` is returned instead, so a
        single bad file does not abort an epoch.
        """
        if idx in self.cache:
            return self.cache[idx]

        # Resolve the row outside the try block: an out-of-range index must surface as
        # IndexError (plain iteration over the dataset relies on it) instead of being
        # turned into a placeholder sample.
        row = self.labels_df.iloc[idx]
        ID = row.name
        file_path = os.path.join(self.path, f'{ID}.csv')

        try:
            # Load ECG data
            df = pd.read_csv(file_path)
            ecg_data = df.drop(columns=['time']).values
            tensor = torch.tensor(ecg_data, dtype=torch.float32).T  # (leads, timesteps)

            if self.transform:
                tensor = self.transform(tensor)

            # Labels come from the row already fetched by position, which stays 1-D even
            # if the same ID appears on several rows. Float dtype for multi-label loss.
            label = torch.tensor(row.values, dtype=torch.float32)

        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            # Return zero tensor and -1 label placeholder
            return torch.zeros((12, 5000), dtype=torch.float32), torch.full((self.num_classes,), -1, dtype=torch.float32)

        sample = (tensor, label)
        if self.cache_samples:
            # Placeholders are never cached, so a file that appears later is picked up.
            self.cache[idx] = sample
        return sample

In [6]:
# Smoke test: build the dataset, pull one sample, and inspect shapes and class weights
dataset = ECGDataset(
    path="data/ecg",
    diagnoses='data/diagnoses_balanced.csv',
)
data, label = dataset[23423]  # subscripting dispatches to ECGDataset.__getitem__
print(data.shape)
print(label.shape)
dataset.get_pos_weights()

Number of classes: 5
torch.Size([12, 5000])
torch.Size([5])


tensor([3.1199, 3.1205, 3.1199, 3.1205, 3.1205], device='cuda:0')

# Basic Transformer

In [7]:
class ECGTransformer(nn.Module):
    """Transformer encoder stack followed by a linear classification head.

    Args:
        d_model: feature size of every sequence element fed to the encoder.
        num_classes: number of output logits.
        nhead: attention heads per encoder layer.
        num_encoder_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.

    ``forward`` takes a (batch_size, seq_len, d_model) tensor and returns raw logits of
    shape (batch_size, num_classes); no sigmoid is applied here.
    """

    def __init__(self, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()

        # Single layer definition that the encoder replicates num_encoder_layers times
        layer_template = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_encoder_layers)

        # Maps one encoded sequence element to the class logits
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Encoder output is (batch_size, seq_len, d_model); only the final sequence
        # position is handed to the classification head.
        last_position = self.transformer(x)[:, -1, :]
        return self.classifier(last_position)

In [8]:
# Instantiate a small stand-alone transformer and inspect its structure and size
model = ECGTransformer(
    d_model=12,
    nhead=4,
    num_classes=63,
    num_encoder_layers=6,
    dim_feedforward=512,
)
print(model)
count_parameters(model)

ECGTransformer(
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Linear(in_features=12, out_features=63, bias=True)
)
: 81723 parameters
transformer: 80904 parameters
transformer.layers: 80904 parameters
transformer.layers.0: 13484 parameters
transformer.layers.0.self_attn: 624 parameters
transformer.layers.0.self_attn.out_proj:

In [9]:
# Shape check: push a random (batch=2, seq_len=5000, d_model=12) batch through the model
inputs = torch.rand(2, 5000, 12)
out = model(inputs)
print(out.shape)
print(out[0])

torch.Size([2, 63])
tensor([-0.3114,  0.7817,  0.3441,  0.6371,  0.4557, -0.2647,  0.0785,  1.4021,
        -0.5778,  0.5539, -1.1239, -0.5755, -0.3548, -1.1942, -0.7231, -0.9842,
         0.3549,  0.3671,  0.0777, -1.0710, -0.1278, -0.1683,  0.3399, -0.8317,
         1.1035,  0.4088,  0.4098, -0.8206, -0.8247, -0.5101,  0.8955, -0.2588,
        -0.2010, -0.1479, -0.2765, -0.0554, -0.1968, -0.2181,  0.3201,  0.0775,
         0.3281,  0.6800, -0.1614,  0.1403, -0.5414,  0.5355, -0.1686,  0.0174,
        -0.4426, -0.6130,  1.1272, -1.0367,  0.0677, -0.2011, -0.3319,  1.0212,
        -0.7615,  0.7125, -0.5321,  0.2784, -0.0158,  0.4694,  0.9192],
       grad_fn=<SelectBackward0>)


# An embedding model that uses convolution

Convolution turning 12 channels to 128, repeated to transfer forward 200ms.


In [10]:
class ECGEmbeddings_old(nn.Module):
    """Earlier embedding variant: a stack of same-padded Conv1d layers with kernel size 51.

    The first layer maps ``d_input`` channels to ``d_model``; every following layer keeps
    ``d_model`` channels. ReLU is applied between layers but not after the last one, so
    the sequence length is preserved and the final output is not rectified.

    Args:
        d_input: number of input channels.
        d_model: number of channels produced by every layer.
        n_conv_layers: how many convolutions to stack.
    """

    def __init__(self, d_input, d_model, n_conv_layers=8):
        super().__init__()
        stack = []
        for layer_idx in range(n_conv_layers):
            in_ch = d_input if layer_idx == 0 else d_model
            stack.append(nn.Conv1d(in_ch, d_model, 51, stride=1, padding='same'))
        self.conv_layers = nn.ModuleList(stack)
        self.activation = nn.ReLU(inplace=False)  # Important for checkpointing

    def forward(self, x):
        final_idx = len(self.conv_layers) - 1
        for layer_idx, conv in enumerate(self.conv_layers):
            x = conv(x)
            if layer_idx != final_idx:
                x = self.activation(x)

        # Always hand back a tensor that requires grad: if the result is not attached to
        # a graph, return a detached copy flagged as requiring grad instead.
        if x.requires_grad:
            return x
        return x.detach().requires_grad_(True)

In [11]:
class ECGEmbeddings(nn.Module):
    """Pointwise (kernel size 1) convolutional embedding from ``d_input`` to ``d_model`` channels.

    When widening, the channel count is doubled layer by layer and capped at ``d_model``
    (e.g. 12 -> 24 -> 48 -> 96 -> 192 -> 256). When ``d_model`` is smaller than
    ``d_input`` a single projection layer is used, so the output always has ``d_model``
    channels. When both are equal no layer is created and the input passes through
    unchanged. ReLU is applied between layers but not after the last one.

    Args:
        d_input: number of input channels (must be >= 1).
        d_model: number of output channels (must be >= 1).

    Raises:
        ValueError: if ``d_input`` or ``d_model`` is smaller than 1. A non-positive
            ``d_input`` would otherwise never leave the doubling loop.
    """

    def __init__(self, d_input, d_model):
        super().__init__()
        if d_input < 1 or d_model < 1:
            raise ValueError(
                f"d_input and d_model must be positive, got d_input={d_input}, d_model={d_model}"
            )

        self.conv_layers = nn.ModuleList()
        in_channels = d_input
        out_channels = d_input

        # Dynamically adjust the number of channels to reach d_model
        while out_channels < d_model:
            out_channels = min(d_model, out_channels * 2)  # Double channels, but cap at d_model
            self.conv_layers.append(nn.Conv1d(in_channels, out_channels, 1, stride=1, padding='same'))
            in_channels = out_channels

        # The doubling loop only widens. Without this projection a narrower d_model would
        # leave the module empty and the output would keep d_input channels.
        if d_model < d_input:
            self.conv_layers.append(nn.Conv1d(d_input, d_model, 1, stride=1, padding='same'))

        self.activation = nn.ReLU(inplace=False)  # Important for checkpointing

    def forward(self, x):
        """Map a (batch, d_input, timesteps) tensor to (batch, d_model, timesteps)."""
        last_idx = len(self.conv_layers) - 1
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if i < last_idx:  # Apply activation except for the last layer
                x = self.activation(x)

        # Always hand back a tensor that requires grad: if the result is not attached to
        # a graph, return a detached copy flagged as requiring grad instead.
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)

        return x

In [None]:
# Build the embedding stack for 12 input leads and list its layers and parameter counts
embedding_model = ECGEmbeddings(
    d_input=12,
    d_model=256,
)
print(embedding_model)
count_parameters(embedding_model)

ECGEmbeddings(
  (conv_layers): ModuleList(
    (0): Conv1d(12, 24, kernel_size=(1,), stride=(1,), padding=same)
    (1): Conv1d(24, 48, kernel_size=(1,), stride=(1,), padding=same)
    (2): Conv1d(48, 96, kernel_size=(1,), stride=(1,), padding=same)
    (3): Conv1d(96, 192, kernel_size=(1,), stride=(1,), padding=same)
    (4): Conv1d(192, 256, kernel_size=(1,), stride=(1,), padding=same)
  )
  (activation): ReLU()
)
: 74248 parameters
conv_layers: 74248 parameters
conv_layers.0: 312 parameters
conv_layers.1: 1200 parameters
conv_layers.2: 4704 parameters
conv_layers.3: 18624 parameters
conv_layers.4: 49408 parameters
activation: 0 parameters


# Combining together embedding with transformer

In [13]:
class ECGCombined(nn.Module):
    """Convolutional embedding followed by the transformer classifier.

    Args:
        d_input: number of input channels handed to the embedding.
        d_model: embedding width, which is also the transformer's feature size.
        num_classes: number of output logits; also exposed as ``self.num_classes``.
        nhead, num_encoder_layers, dim_feedforward: forwarded to ``ECGTransformer``.
    """

    def __init__(self, d_input, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()
        self.num_classes = num_classes

        self.embedding_model = ECGEmbeddings(d_input, d_model)
        self.transformer = ECGTransformer(
            d_model,
            num_classes=num_classes,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
        )

    def forward(self, x):
        # The embedding output has channels on dim 1; swap the last two dims to get
        # (batch_size, seq_len, d_model), the layout the transformer expects.
        sequence = self.embedding_model(x).permute(0, 2, 1)
        return self.transformer(sequence)

# Trainer class

In [14]:
class Trainer:
    """Training loop with gradient accumulation, periodic evaluation and checkpointing.

    Args:
        model: module to train; must expose ``num_classes``.
        device: device that batches are moved to.
        pos_weights: optional per-class ``pos_weight`` for ``BCEWithLogitsLoss``.
        accum_steps: number of batches whose gradients are accumulated per optimizer step.
        checkpoint_interval: evaluate and save a checkpoint every this many processed batches.
        lr: Adam learning rate. Note that resuming restores the optimizer state stored in
            the checkpoint, including its learning rate, so ``lr`` has no effect after a
            resume unless the optimizer's param groups are changed afterwards.
        resume_checkpoint: optional path of a checkpoint written by ``_save_checkpoint``.

    Attributes:
        loss_history: list of ``[batch_count, mean loss of the accumulation group]``.
        acc_history: list of ``[batch_count, metrics dict returned by evaluate()]``.
    """

    def __init__(self, model, device, pos_weights=None, accum_steps=4, checkpoint_interval=256, lr=1e-4,
                 resume_checkpoint=None):
        self.model = model
        self.device = device
        self.accum_steps = accum_steps
        self.checkpoint_interval = checkpoint_interval

        # Initialize essential components
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.f1 = F1Score(task='multilabel', num_labels=self.model.num_classes, average=None)
        self.loss = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
        self.accum_loss = 0.0
        self.loss_history = []
        self.acc_history = []
        self.batch_count = 0
        self.start_epoch = 0
        self.start_batch = 0

        if resume_checkpoint:
            self._load_checkpoint(resume_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Load training state from checkpoint"""
        # NOTE(security): weights_only=False un-pickles arbitrary objects (needed here
        # because the stored histories are plain Python/NumPy objects). Only load
        # checkpoint files you created yourself.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Essential parameters
        self.model.load_state_dict(checkpoint['model_state'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])

        # Training progress
        self.loss_history = checkpoint['loss_history']
        self.acc_history = checkpoint['acc_history']
        self.batch_count = checkpoint.get('batch_count', 0)
        self.start_epoch = checkpoint['epoch']
        # Continue with the batch after the one that was checkpointed
        self.start_batch = checkpoint.get('batch', 0) + 1

        # Configurations
        self.checkpoint_interval = checkpoint.get('checkpoint_interval',
                                                 self.checkpoint_interval)

        print(f"Resuming from epoch {self.start_epoch} batch {self.start_batch}")

    def train(self, train_dataloader, test_dataloader, num_epochs, save_path="training_progress"):
        """Train from ``start_epoch`` up to ``num_epochs``.

        Gradients are accumulated over ``accum_steps`` batches before each optimizer step.
        Every ``checkpoint_interval`` processed batches the model is evaluated on
        ``test_dataloader`` and a checkpoint is written below ``save_path``.

        After a resume, the first ``start_batch`` batches are skipped in the resumed epoch
        only. A partial accumulation group left at the end of an epoch is applied before
        the next epoch starts instead of leaking into it (or being dropped).
        """
        os.makedirs(save_path, exist_ok=True)
        self.model.train()

        for epoch in range(self.start_epoch, num_epochs):
            pending = 0  # batches back-propagated since the last optimizer step

            for batch_idx, (inputs, labels) in enumerate(train_dataloader):
                if batch_idx < self.start_batch:
                    continue

                # Forward pass
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                # Scale so that the accumulated gradient is the mean over the group
                loss = self.loss(outputs, labels) / self.accum_steps

                # Backward pass
                loss.backward()

                # Every batch
                self.accum_loss += loss.item()
                self.batch_count += 1
                pending += 1

                # Every accum_steps
                if (batch_idx + 1) % self.accum_steps == 0:
                    self._step_and_log(pending, epoch, num_epochs, batch_idx, len(train_dataloader))
                    pending = 0

                # Every checkpoint_interval
                if self.batch_count % self.checkpoint_interval == 0:
                    acc = self.evaluate(test_dataloader)
                    self.acc_history.append([self.batch_count, acc])
                    self._save_checkpoint(save_path, epoch, batch_idx)

                # One-off sanity print of a single example early in training
                if self.batch_count == 64:
                    probs = torch.sigmoid(outputs)
                    print("➡️ Example probs:", probs[0].detach().cpu().numpy())
                    print("➡️ Predictions:", (probs > 0.5).float()[0])
                    print("➡️ Ground truth :", labels[0].cpu().numpy())

                del inputs, labels, outputs, loss

            # Apply gradients of a trailing partial group (len(dataloader) not divisible
            # by accum_steps) so they are neither lost nor mixed into the next epoch.
            if pending:
                self._step_and_log(pending, epoch, num_epochs, batch_idx, len(train_dataloader))

            # The resume offset only applies to the epoch that was interrupted.
            self.start_batch = 0

    def _step_and_log(self, pending, epoch, num_epochs, batch_idx, total_batches):
        """Run one optimizer step and record the mean loss of the ``pending`` batches."""
        self._update_parameters()

        # accum_loss holds sum(loss_i / accum_steps); for a full group that already is
        # the mean, for a partial group it is rescaled to the mean of what was seen.
        if pending == self.accum_steps:
            avg_loss = self.accum_loss
        else:
            avg_loss = self.accum_loss * self.accum_steps / pending
        self.loss_history.append([self.batch_count, avg_loss])
        self.accum_loss = 0.0

        print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx+1}/{total_batches} | "
              f"Avg Loss: {avg_loss:.4f}")

    def evaluate(self, dataloader):
        """Evaluate the model on ``dataloader`` with a 0.5 probability threshold.

        The model is switched to eval mode for the duration of the call and put back into
        whatever mode it was in before, so calling this from the training loop does not
        silently disable dropout for the rest of training.

        Returns:
            dict with per-class F1, overall and per-class Hamming loss, per-class
            precision and recall, and per-class TP/FP/FN/TN counts (as lists).
        """
        was_training = self.model.training
        self.model.eval()
        total_samples = 0
        num_classes = self.model.num_classes
        mismatches_per_class = torch.zeros(num_classes, device=self.device)

        all_preds = []
        all_labels = []

        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    outputs = self.model(inputs)

                    # Get binary predictions (0 or 1) using threshold
                    predicted = (torch.sigmoid(outputs) >= 0.5).float()

                    all_preds.append(predicted.cpu())
                    all_labels.append(labels.cpu())

                    # Track mismatches per class
                    mismatches_per_class += (predicted != labels).sum(dim=0).float()
                    total_samples += inputs.size(0)  # Batch size
        finally:
            # Restore the previous mode even if evaluation fails part-way
            self.model.train(was_training)

        # Concatenate all predictions and labels
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        y_true = all_labels.numpy().astype(int)
        y_pred = all_preds.numpy().astype(int)

        # Per-class metrics (F1 is computed once; it is reused for print and result)
        f1_score_per_class = self.f1(all_preds, all_labels).cpu().numpy()
        hamming_loss_per_class = mismatches_per_class.cpu().numpy() / total_samples
        overall_hamming_loss = mismatches_per_class.sum().item() / (total_samples * num_classes)

        precision_per_class = precision_score(y_true, y_pred, average=None, zero_division=0)
        recall_per_class = recall_score(y_true, y_pred, average=None, zero_division=0)

        # Extract TP, FP, FN, TN from multilabel confusion matrices
        conf_matrices = multilabel_confusion_matrix(y_true, y_pred)
        tp = conf_matrices[:, 1, 1].tolist()
        fp = conf_matrices[:, 0, 1].tolist()
        fn = conf_matrices[:, 1, 0].tolist()
        tn = conf_matrices[:, 0, 0].tolist()

        print(f"F1 Score per class: {f1_score_per_class}")

        return {
            "f1_per_class": f1_score_per_class,
            "overall_hamming_loss": overall_hamming_loss,
            "hamming_loss_per_class": hamming_loss_per_class,
            "precision": precision_per_class,
            "recall": recall_per_class,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn
        }

    def _update_parameters(self):
        """Update model parameters with gradient clipping"""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _save_checkpoint(self, path, epoch, batch_idx):
        """Save model and training state"""
        checkpoint = {
            'epoch': epoch,
            'batch': batch_idx,
            'batch_count': self.batch_count,
            'checkpoint_interval': self.checkpoint_interval,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'loss_history': self.loss_history,
            'acc_history': self.acc_history
        }

        torch.save(checkpoint, f"{path}/checkpoint_ep{epoch}_b{batch_idx}.pt")
        print(f"\nCheckpoint saved at epoch {epoch+1} batch {batch_idx+1}")

        # Histories are also stored stand-alone so they can be plotted without loading
        # a full checkpoint. acc_history holds dicts, so it is saved as an object array.
        np.save(f"{path}/loss_history.npy", np.array(self.loss_history))
        np.save(f"{path}/acc_history.npy", np.array(self.acc_history))


# Let's Go Training

## Settings

In [15]:
# Meta: file locations and checkpoint cadence
diagnoses = "data/diagnoses_balanced.csv"        # labels CSV handed to ECGDataset(diagnoses=...)
data_path = "data/ecg_clipped"                   # ECG signal directory (intended for ECGDataset's `path` -- TODO confirm it is wired through)
save_path = "training_progress/new_balanced"     # output directory for checkpoints and the history .npy files
checkpoint_interval = 64                         # evaluate + save a checkpoint every N processed batches

# Hyperparameters
add_pos_weights = True  # weight positives in the BCE loss by #neg / #pos per class
normalize = True        # apply the NormalizeECG transform to every sample
batch_size = 4          # samples per DataLoader batch
accum_steps = 4         # Optimizer steps once every accum_steps batches (gradient accumulation)
starting_lr = 1e-5      # For resuming, set lr (could be lower) at the resume cell below

# Embeddings parameters
d_input = 12            # input channels of the embedding (passed as ECGCombined's d_input)
d_model = 128           # embedding width, which is also the transformer feature size

# Transformer parameters
nhead = 4               # attention heads per encoder layer
num_encoder_layers = 2  # stacked encoder layers
dim_feedforward = 256   # hidden width of each encoder layer's feed-forward block

In [16]:
# Build the dataset from the locations configured in the settings cell. `data_path` is
# passed explicitly so the configured signal directory is the one handed to the dataset.
ecg_dataset = ECGDataset(path=data_path, diagnoses=diagnoses, transform=NormalizeECG() if normalize else None)
pos_weights = ecg_dataset.get_pos_weights() if add_pos_weights else None
num_classes = ecg_dataset.get_num_classes()

# Hold out 500 samples each for test and validation; everything else is training data.
# The split is seeded so that it is identical after a kernel restart. Otherwise a run
# resumed from a checkpoint would be evaluated on samples it has already trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset, val_dataset = random_split(
    ecg_dataset, [len(ecg_dataset) - 1000, 500, 500], generator=split_generator)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Number of classes: 5


## Start from 0:

In [17]:
# Fresh model and trainer built from the settings above, then a single training epoch
model = ECGCombined(
    d_input=d_input,
    d_model=d_model,
    num_classes=num_classes,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
).to(device)
trainer = Trainer(
    model,
    device,
    accum_steps=accum_steps,
    lr=starting_lr,
    pos_weights=pos_weights,
    checkpoint_interval=checkpoint_interval,
)
trainer.train(train_dataloader, test_dataloader, num_epochs=1, save_path=save_path)

  proj = linear(q, w, b)


Epoch 1/1 | Batch 4/6998 | Avg Loss: 1.0758
Epoch 1/1 | Batch 8/6998 | Avg Loss: 1.0608
Epoch 1/1 | Batch 12/6998 | Avg Loss: 1.1923
Epoch 1/1 | Batch 16/6998 | Avg Loss: 1.0454
Epoch 1/1 | Batch 20/6998 | Avg Loss: 1.0194
Epoch 1/1 | Batch 24/6998 | Avg Loss: 1.2011
Epoch 1/1 | Batch 28/6998 | Avg Loss: 1.0529
Epoch 1/1 | Batch 32/6998 | Avg Loss: 1.0375
Epoch 1/1 | Batch 36/6998 | Avg Loss: 1.0557
Epoch 1/1 | Batch 40/6998 | Avg Loss: 1.0674
Epoch 1/1 | Batch 44/6998 | Avg Loss: 1.0651
Epoch 1/1 | Batch 48/6998 | Avg Loss: 1.0788
Epoch 1/1 | Batch 52/6998 | Avg Loss: 1.1982
Epoch 1/1 | Batch 56/6998 | Avg Loss: 1.1743
Epoch 1/1 | Batch 60/6998 | Avg Loss: 1.1770
Epoch 1/1 | Batch 64/6998 | Avg Loss: 1.0971
➡️ Example probs: [0.57436013 0.4888995  0.4334205  0.7291495  0.5747777 ]
➡️ Predictions: tensor([1., 0., 0., 1., 1.], device='cuda:0')
➡️ Ground truth : [0. 0. 0. 1. 0.]
Epoch 1/1 | Batch 68/6998 | Avg Loss: 1.1165
Epoch 1/1 | Batch 72/6998 | Avg Loss: 1.1874
Epoch 1/1 | Batch 76

KeyboardInterrupt: 

## Resume from a checkpoint:

In [None]:
resume_from = f"{save_path}/checkpoint_ep0_b2047.pt"
resume_lr = 1e-8

model = ECGCombined(d_input=d_input, d_model=d_model, num_classes=num_classes, nhead=nhead, num_encoder_layers=num_encoder_layers, dim_feedforward=dim_feedforward).to(device)
trainer = Trainer(model, device, accum_steps=accum_steps, lr=resume_lr, pos_weights=pos_weights, checkpoint_interval=checkpoint_interval, resume_checkpoint=resume_from)

# Loading the checkpoint restores the optimizer's saved param groups, including the
# learning rate they were saved with, so the lr given to the constructor is overwritten.
# Re-apply `resume_lr` explicitly so the value configured in this cell is the one used.
for param_group in trainer.optimizer.param_groups:
    param_group['lr'] = resume_lr

trainer.train(train_dataloader, test_dataloader, num_epochs=1, save_path=save_path)

## Plot accuracy and loss

In [None]:
# Reload the histories stored next to the checkpoints.
# NOTE: allow_pickle=True un-pickles Python objects (the accuracy history holds dicts);
# only load .npy files that this notebook wrote itself.
loss_history = np.load(
    f'{save_path}/loss_history.npy',
    allow_pickle=True,
)
acc_history = np.load(
    f'{save_path}/acc_history.npy',
    allow_pickle=True,
)

In [None]:
# Loss history plot -- every entry is a [batch_count, loss] pair
x = [step for step, _ in loss_history]
y = [value for _, value in loss_history]
plt.plot(x, y, label='Loss')
plt.title('Loss History')
plt.xlabel('Batches')
plt.ylabel('Loss')
plt.show()

In [None]:
# F1 score plot -- every entry is a [batch_count, metrics dict] pair; one line per class
x = [step for step, _ in acc_history]
y = [metrics['f1_per_class'] for _, metrics in acc_history]

plt.plot(x, y)
plt.legend([f'Class {i}' for i in range(len(y[0]))])
plt.title('F1 Score History')
plt.xlabel('Batches')
plt.ylabel('F1 Score')
plt.show()

In [None]:
# Hamming loss plot -- every entry is a [batch_count, metrics dict] pair; one line per class
x = [step for step, _ in acc_history]
y = [metrics['hamming_loss_per_class'] for _, metrics in acc_history]

plt.plot(x, y)
plt.legend([f'Class {i}' for i in range(len(y[0]))])
plt.xlabel('Batches')
plt.ylabel('Hamming Loss')
#plt.yscale('log')
plt.title('Hamming Loss per class')