In [1]:
import re
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F

In [2]:
# Define device for torch
device = torch.device("cpu")
# MPS for Apple Silicon GPUs
if torch.mps.is_available():
   print("MPS is available")
   device = torch.device("mps")

# CUDA for Nvidia GPUs
if torch.cuda.is_available():
   print("CUDA is available")
   device = torch.device("cuda")
print(device)

MPS is available
mps


In [3]:
def count_parameters(model):
    for name, module in model.named_modules():
        params = sum(p.numel() for p in module.parameters())
        print(f"{name}: {params} parameters")

# Dataset and Dataloader

To investigate: Normalization or other transforms

In [4]:
class NormalizeECG:
    def __call__(self, tensor):
        # Z-score normalization per lead
        means = tensor.mean(dim=1, keepdim=True)
        stds = tensor.std(dim=1, keepdim=True)
        return (tensor - means) / (stds + 1e-8)

In [5]:
class ECGDataset(Dataset):
    def __init__(self, path="data/WFDBRecords", diagnoses='data/diagnoses.csv', transform=None):
        # Store file paths with cleaned IDs
        self.file_info = []  # List of tuples (cleaned_id, file_path)
        
        # Extract numeric IDs from filenames
        for root, _, files in os.walk(path):
            for file in files:
                if file.endswith('.csv'):
                    # Extract numeric part from filename (e.g., "abc00001.csv" -> "00001")
                    cleaned_id = re.sub(r'\D', '', file)
                    self.file_info.append((
                        cleaned_id,
                        os.path.join(root, file)
                    ))

        # Load and prepare labels
        self.labels_df = pd.read_csv(diagnoses)
        # Clean IDs in both DataFrame and our file list
        self.labels_df['ID'] = self.labels_df['ID'].astype(str).str.replace(r'\D', '', regex=True)
        self.labels_df.set_index('ID', inplace=True)

        # Filter out files without corresponding labels
        valid_ids = set(self.labels_df.index)
        self.file_info = [(i, p) for i, p in self.file_info if i in valid_ids]

        self.transform = transform
        self.cache = {}

    def __len__(self):
        return len(self.file_info)

    def __getitem__(self, idx):
        if len(self.file_info) == 0:
            raise ValueError("Dataset is empty. Check file paths and diagnoses.csv.")
    
        if idx in self.cache:
            return self.cache[idx]

        cleaned_id, file_path = self.file_info[idx]
        
        try:
            # Load ECG data
            df = pd.read_csv(file_path)
            ecg_data = df.drop(columns=['time']).values
            tensor = torch.tensor(ecg_data, dtype=torch.float32).T  # (leads, timesteps)
            
            if self.transform:
                tensor = self.transform(tensor)
                
            # Get corresponding label
            label_values = self.labels_df.loc[cleaned_id].values  # Get all label columns
            label = torch.tensor(label_values, dtype=torch.float32)  # Use float for multi-label

            return tensor, label
            
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            # Return zero tensor and -1 label placeholder
            return torch.zeros((12, 5000)), torch.tensor(-1, dtype=torch.float32)


In [6]:
dataset = ECGDataset()
data, label = dataset.__getitem__(0)
print(data.shape)
print(label.shape)

torch.Size([12, 5000])
torch.Size([94])


# Basic Transformer

In [7]:
class ECGTransformer(nn.Module):
    def __init__(self, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()
        
        # Define encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, batch_first=True)

        # Encoder stack
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)
        
    def forward(self, x):
        encoded = self.transformer(x)
        # encoded shape: (batch_size, seq_len, d_model)
        # Pick out only the last in the sequence for classification
        encoded = encoded[:, -1, :]
        result = self.classifier(encoded)
        return result

In [8]:
model = ECGTransformer(d_model=12, nhead=4, num_classes=63, num_encoder_layers=6, dim_feedforward=512)
print(model)
count_parameters(model)

ECGTransformer(
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Linear(in_features=12, out_features=63, bias=True)
)
: 81723 parameters
transformer: 80904 parameters
transformer.layers: 80904 parameters
transformer.layers.0: 13484 parameters
transformer.layers.0.self_attn: 624 parameters
transformer.layers.0.self_attn.out_proj:

In [9]:
inputs = torch.rand((2, 5000, 12))
out = model(inputs)
print(out.shape)
print(out[0])

torch.Size([2, 63])
tensor([ 1.0182, -0.9489,  0.2784,  0.9533, -0.1510, -0.6213,  0.5716, -0.4948,
        -0.2717, -0.6362,  0.0804,  0.9113, -0.1880,  0.4689, -0.2032,  0.1908,
         0.3878, -0.3923, -0.0663, -0.0287, -0.8647,  0.4793,  0.2846, -0.3575,
        -0.6002,  0.2385, -0.6557, -0.3875, -0.0052,  0.8377, -1.0622, -0.9162,
        -0.4564, -1.1498,  0.9539,  0.4649,  1.3470, -0.8705,  0.1994,  1.0465,
        -0.7717, -0.5953,  0.4582, -0.1437, -0.9721, -0.4984, -1.4124,  0.3614,
        -0.7855,  0.5249,  0.2516,  0.2192,  0.5857, -0.3805, -0.7895, -0.4147,
         0.3128, -0.8980,  0.2463, -0.7288, -0.8712,  0.3469,  0.1505],
       grad_fn=<SelectBackward0>)


# An embedding model 
that uses convolution

Convolution turning 12 channels to 128, repeated to transfer forward 200ms.


In [10]:
class ECGEmbeddings(nn.Module):
    def __init__(self, d_input, d_model):
        super().__init__()
        # Keep original layers
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(d_model if i>0 else d_input, d_model, 50, stride=1, padding='same')
            for i in range(8)
        ])
        self.activation = nn.ReLU(inplace=False)  # Important for checkpointing

    def forward(self, x):
        # Split into gradient checkpoint blocks
        def run_block(x, start, end):
            for i in range(start, end+1):
                x = self.conv_layers[i](x)
                x = self.activation(x)
            return x

        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        
        # Checkpoint groups of layers
        x = torch.utils.checkpoint.checkpoint(run_block, x, 0, 3, use_reentrant=False)  # Layers 0-3
        x = torch.utils.checkpoint.checkpoint(run_block, x, 4, 7, use_reentrant=False)  # Layers 4-7
        return x

In [11]:
embedding_model = ECGEmbeddings(d_input = 12, d_model=512)
print(embedding_model)
count_parameters(embedding_model)

ECGEmbeddings(
  (conv_layers): ModuleList(
    (0): Conv1d(12, 512, kernel_size=(50,), stride=(1,), padding=same)
    (1-7): 7 x Conv1d(512, 512, kernel_size=(50,), stride=(1,), padding=same)
  )
  (activation): ReLU()
)
: 92061696 parameters
conv_layers: 92061696 parameters
conv_layers.0: 307712 parameters
conv_layers.1: 13107712 parameters
conv_layers.2: 13107712 parameters
conv_layers.3: 13107712 parameters
conv_layers.4: 13107712 parameters
conv_layers.5: 13107712 parameters
conv_layers.6: 13107712 parameters
conv_layers.7: 13107712 parameters
activation: 0 parameters


# Combining together embedding with transformer

In [12]:
class ECGCombined(nn.Module):
    def __init__(self, d_input, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()
        
        self.embedding_model = ECGEmbeddings(d_input, d_model)
        self.transformer = ECGTransformer(d_model, num_classes, nhead, num_encoder_layers, dim_feedforward)

    def forward(self, x):
        x = self.embedding_model(x)
        x = x.permute(0, 2, 1)       # Reshape to (batch_size, seq_len, d_model)
        x = self.transformer(x)
        return x

# Trainer class

In [13]:
class Trainer:
    def __init__(self, model, device, accum_steps=4, checkpoint_interval=100, lr = 1e-4,
                 resume_checkpoint=None):
        self.model = model
        self.device = device
        self.accum_steps = accum_steps
        self.checkpoint_interval = checkpoint_interval
        
        # Initialize essential components first
        self.scaler = torch.cuda.amp.GradScaler()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_history = []
        self.acc_history = []
        self.batch_count = 0
        self.start_epoch = 0
        self.start_batch = 0

        # Override with checkpoint if provided
        if resume_checkpoint:
            self._load_checkpoint(resume_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Load training state from checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Essential parameters
        self.model.load_state_dict(checkpoint['model_state'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])
        self.scaler.load_state_dict(checkpoint['scaler_state'])
        
        # Training progress
        self.loss_history = checkpoint['loss_history']
        self.acc_history = checkpoint['acc_history']
        self.batch_count = checkpoint.get('batch_count', 0)
        self.start_epoch = checkpoint['epoch']  # Resume from same epoch
        self.start_batch = checkpoint.get('batch', 0) + 1  # Next batch
        
        # Configurations (optional but recommended)
        self.checkpoint_interval = checkpoint.get('checkpoint_interval', 
                                                 self.checkpoint_interval)
        
        print(f"Resuming from epoch {self.start_epoch} batch {self.start_batch}")

    def loss(self, output, target):
        return F.binary_cross_entropy_with_logits(output, target.float())

    def train(self, train_dataloader, test_dataloader, num_epochs, save_path="training_progress"):
        os.makedirs(save_path, exist_ok=True)
        self.model.train()
        
        # Start from the last checkpoint epoch
        print("Hi")
        for epoch in range(self.start_epoch, num_epochs):
            print("Hello")
            # Skip batches we've already processed in this epoch
            for batch_idx, (inputs, labels) in enumerate(train_dataloader):
                if batch_idx < self.start_batch:
                    continue
                # Forward pass
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = self.model(inputs)
                    loss = self.loss(outputs, labels) / self.accum_steps

                # Backward pass
                self.scaler.scale(loss).backward()
                
                # Gradient accumulation
                if (batch_idx + 1) % self.accum_steps == 0:
                    self._update_parameters()
                
                # Logging and checkpointing
                current_loss = loss.item() * self.accum_steps
                self.loss_history.append(current_loss)
                self.batch_count += 1

                # Print every batch
                print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx+1}/{len(train_dataloader)} | "
                    f"Loss: {current_loss:.4f}")

                # Save checkpoint
                if self.batch_count % self.checkpoint_interval == 0:
                    acc = self.evaluate(test_dataloader)
                    self.acc_history.append([self.batch_count, acc])
                    self._save_checkpoint(save_path, epoch, batch_idx)
                
                del inputs, labels, outputs, loss

    def evaluate(self, dataloader):
        self.model.eval()
        total_mismatches = 0.0
        total_labels = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                
                # Get binary predictions (0 or 1) using threshold
                predicted = (torch.sigmoid(outputs) >= 0.5).float()
                
                # Calculate mismatches
                mismatches = (predicted != labels).float().sum().item()
                total_mismatches += mismatches
                total_labels += labels.numel()  # Total labels = batch_size * num_classes

        hamming_loss = total_mismatches / total_labels
        print(f'Hamming Loss: {hamming_loss:.4f}')
        return hamming_loss


    def _update_parameters(self):
        """Update model parameters with gradient clipping"""
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def _save_checkpoint(self, path, epoch, batch_idx):
        """Save model and training state"""
        checkpoint = {
            'epoch': epoch,
            'batch': batch_idx,
            'batch_count': self.batch_count,
            'checkpoint_interval': self.checkpoint_interval,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'loss_history': self.loss_history,
            'acc_history': self.acc_history,
            'scaler_state': self.scaler.state_dict()
        }
        
        torch.save(checkpoint, f"{path}/checkpoint_ep{epoch}_b{batch_idx}.pt")
        print(f"\nCheckpoint saved at epoch {epoch+1} batch {batch_idx+1}")
        
        # Save loss history separately for easy plotting
        np.save(f"{path}/loss_history.npy", np.array(self.loss_history))
        np.save(f"{path}/acc_history.npy", np.array(self.acc_history))


# Let's Go Training

In [14]:
ecg_dataset = ECGDataset(transform=None)
train_dataset, test_dataset, val_dataset = random_split(
                                            ecg_dataset, [len(ecg_dataset) - 1000, 500, 500], 
                                            torch.Generator().manual_seed(42))

batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

## Start from 0:

In [41]:
model = ECGCombined(d_input=12, d_model=128, num_classes=94, nhead=4, num_encoder_layers=2, dim_feedforward=256).to(device)
trainer = Trainer(model, device, accum_steps=4, lr=1e-4)
trainer.train(train_dataloader, test_dataloader, num_epochs=1)

  self.scaler = torch.cuda.amp.GradScaler()


Hi
Hello


  return F.conv1d(


Epoch 1/1 | Batch 1/8638 | Loss: 0.7124




Epoch 1/1 | Batch 2/8638 | Loss: 0.7103




Epoch 1/1 | Batch 3/8638 | Loss: 0.7207




Epoch 1/1 | Batch 4/8638 | Loss: 0.7128




Epoch 1/1 | Batch 5/8638 | Loss: 0.6502




## Resume from a checkpoint:

In [15]:
resume_from = "training_progress/checkpoint_ep0_b1099.pt"

model = ECGCombined(d_input=12, d_model=128, num_classes=94, nhead=4, num_encoder_layers=2, dim_feedforward=256).to(device)
trainer = Trainer(model, device, accum_steps=4, resume_checkpoint=resume_from, lr=1e-6)
trainer.train(train_dataloader, test_dataloader, num_epochs=1)

  self.scaler = torch.cuda.amp.GradScaler()


Resuming from epoch 0 batch 1100
Hi
Hello


  return F.conv1d(


Epoch 1/1 | Batch 1101/8638 | Loss: 0.1097




Epoch 1/1 | Batch 1102/8638 | Loss: 0.0998




Epoch 1/1 | Batch 1103/8638 | Loss: 0.1173




Epoch 1/1 | Batch 1104/8638 | Loss: 0.0896




Epoch 1/1 | Batch 1105/8638 | Loss: nan




Epoch 1/1 | Batch 1106/8638 | Loss: nan




Epoch 1/1 | Batch 1107/8638 | Loss: nan




KeyboardInterrupt: 

# Plot a checkpoint

In [None]:
checkpoint = torch.load("training_progress/checkpoint_ep0_b1399.pt")

fig, ax1 = plt.subplots()

ax1.plot(checkpoint['loss_history'], label='Loss')
ax1.set_xlabel('Batch')
ax1.set_ylabel('Training Loss')
ax1.set_title('Training Loss and Accuracy')

ax2 = ax1.twinx()
ax2.plot([x[0] for x in checkpoint['acc_history']], [x[1] for x in checkpoint['acc_history']], color='red', label='Accuracy')
ax2.set_ylabel('Hamming Loss')

fig.tight_layout()
plt.show()