<a href="https://colab.research.google.com/github/Zfeng0207/FIT3199-FYP/blob/dev%2Fzfeng/hpa_pt_lightning_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
os.chdir('/content/drive/MyDrive/Colab Notebooks (1)/ECG-MIMIC-main')

In [3]:
!pip install -qqqq mlflow torchmetrics pytorch_lightning iterative-stratification

# Number of samples with all-zero labels

In [4]:
import numpy as np  # needed by count_empty_labels below
import pandas as pd

df = pd.read_csv("src/data/label_df.csv")

In [5]:
def count_empty_labels(df, label_col="res"):
    """
    Counts how many samples in the DataFrame have all-zero labels.

    Args:
        df (pd.DataFrame): DataFrame containing the dataset
        label_col (str): Column name containing the multi-hot labels

    Returns:
        int: Number of rows with all-zero labels
    """
    empty_count = 0

    for label_str in df[label_col]:
        label_array = np.fromstring(label_str.strip('[]'), sep=' ')  # Use np.fromstring instead of eval
        if label_array.sum() == 0:
            empty_count += 1

    return empty_count


In [6]:
import numpy as np

num_empty = count_empty_labels(df, label_col="res")
print(f"There are {num_empty} samples with all-zero labels.")

There are 56715 samples with all-zero labels.
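The same parsing step can also be done once for the whole column, which gives per-class positive counts alongside the empty-row count. The sketch below is illustrative only: the three label strings are made up and simply mimic the space-separated `[0. 1. 0.]` format that `count_empty_labels` assumes for the `res` column.

```python
import numpy as np

# Hypothetical label strings in the same format as the `res` column.
label_strings = ["[0. 1. 0.]", "[0. 0. 0.]", "[1. 1. 0.]"]

# Parse every string into one (num_samples, num_classes) matrix.
Y = np.vstack([np.fromstring(s.strip("[]"), sep=" ") for s in label_strings])

# Rows whose labels sum to zero carry no positive class.
num_empty = int((Y.sum(axis=1) == 0).sum())

# Column sums give the number of positive samples per class.
class_counts = Y.sum(axis=0).astype(int)

print("empty rows:", num_empty)
print("positives per class:", class_counts.tolist())
```

On the real DataFrame, `label_strings` would be replaced by `df["res"]`.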


# Configurations

In [7]:
from dataclasses import dataclass
import os
import platform

# Project source root on the mounted Drive
ROOT_PATH = "/content/drive/MyDrive/Colab Notebooks (1)/ECG-MIMIC-main/src"

@dataclass
class DatasetConfig:
    # ECG-specific
    NUM_LEADS:    int = 12  # 12 ECG channels (leads)
    NUM_CLASSES:  int = 12  # 12 ICD disease codes
    VALID_PCT:  float = 0.1

    # Dataset file and folder paths
    TRAIN_CSV:   str = os.path.join(ROOT_PATH, "data/train.csv")  # Your preprocessed split CSV
    TEST_CSV:    str = os.path.join(ROOT_PATH, "data/test.csv")
    MEMMAP_FILE: str = os.path.join(ROOT_PATH, "ecg_dataset", "data/memmap/memmap.npy")
    MEMMAP_META: str = os.path.join(ROOT_PATH, "ecg_dataset", "data/memmap/memmap_meta.npz")

@dataclass
class TrainingConfig:
    BATCH_SIZE:      int = 32
    NUM_EPOCHS:      int = 30  # Actual training epochs
    INIT_LR:       float = 1e-3
    NUM_WORKERS:     int = 0
    OPTIMIZER_NAME:  str = "Adam"
    WEIGHT_DECAY:  float = 1e-4
    USE_SCHEDULER:  bool = True
    SCHEDULER:       str = "multi_step_lr"  # or "cosine_annealing"
    F1_METRIC_THRESH: float = 0.5
    FREEZE_BACKBONE: bool = False

    # (Optional) model name (if you want to log it somewhere)
    MODEL_NAME:      str = "resnet50"


In [8]:
import torch

def encode_label(label: list, num_classes=10):
    """
    Converts labels into a multi-hot encoding.
    Handles both single ICD codes and lists of codes.
    Relies on a module-level `icd_to_index` dict (ICD code -> class index).
    """
    target = torch.zeros(num_classes)

    # If label is a single code, make it a list
    if isinstance(label, str):
        label = [label]

    for l in label:
        # Check if 'l' contains brackets (indicating list within a string)
        if '[' in l or ']' in l:
            l = l.strip('[]').replace("'", "").split(",")  # Handle list-like strings
            for code in l:
                code = code.strip()  # Remove any whitespace around code
                if code in icd_to_index:
                    target[icd_to_index[code]] = 1.0
        else:
            l = l.strip()  # Remove any whitespace around code
            if l in icd_to_index:
                target[icd_to_index[l]] = 1.0
    return target


def decode_target(
    target: list,
    text_labels: bool = False,
    threshold: float = 0.4,
    cls_labels: dict = None,
):
    """Converts per-class probabilities into a space-separated string of
    predicted class indices, or of class names when `text_labels` is True.
    """

    result = []
    for i, x in enumerate(target):
        if x >= threshold:
            if text_labels:
                result.append(cls_labels[i] + "(" + str(i) + ")")
            else:
                result.append(str(i))
    return " ".join(result)


# This function is used for reversing the Normalization step performed
# during image preprocessing.
# Note the mean and std values must match the ones used.

def denormalize(tensors, *, mean, std):
    """Denormalizes image tensors using the mean and std provided
    and clips values between 0 and 1"""

    # DatasetConfig defines no CHANNELS attribute, so read the channel count from the input.
    for c in range(tensors.shape[1]):
        tensors[:, c, :, :].mul_(std[c]).add_(mean[c])

    return torch.clamp(tensors, min=0.0, max=1.0)

In [9]:
# Create a dictionary mapping ICD codes to index
# icd_to_index = {code: idx for idx, code in enumerate(target_icd_codes)}


# Dataset

In [10]:
import torch
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, dataframe, memmap, normalize=True):
        self.df = dataframe.reset_index(drop=True)
        self.memmap = memmap
        self.normalize = normalize
        self.num_classes = 12

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        start = int(row['start'])
        length = int(row['length'])

        signal = self.memmap[start : start + length * 12]
        signal = signal.reshape(length, 12)

        if self.normalize:
            signal = (signal - np.nanmean(signal, axis=0)) / (np.nanstd(signal, axis=0) + 1e-6)

        signal = torch.tensor(signal, dtype=torch.float32).permute(1, 0)  # [12, time]
        label = np.fromstring(row['res'].strip('[]'), sep=' ', dtype=np.float32)  # parse the multi-hot label string

        # Return the signal and its label as float32 tensors
        return signal, torch.tensor(label, dtype=torch.float32)
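The reshape and per-lead normalization in `__getitem__` can be checked in isolation with NumPy. This is a minimal sketch on a tiny made-up buffer (5 time steps, 12 leads) standing in for one memmap slice; the sizes and values are assumptions for illustration, not properties of the real data.

```python
import numpy as np

LENGTH, NUM_LEADS = 5, 12  # hypothetical tiny recording

# Flat buffer standing in for memmap[start : start + length * 12].
flat = np.arange(LENGTH * NUM_LEADS, dtype=np.float32)

# Rows are time steps, columns are leads.
signal = flat.reshape(LENGTH, NUM_LEADS)

# Z-score each lead independently; the epsilon guards against flat leads.
mean = np.nanmean(signal, axis=0)
std = np.nanstd(signal, axis=0)
signal = (signal - mean) / (std + 1e-6)

# Channels-first layout, matching the [12, time] tensor the dataset returns.
signal = signal.T

print(signal.shape)
```

After this step every lead has approximately zero mean and unit variance, regardless of its original amplitude.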

# Data Module

In [11]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import numpy as np
import os
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit

class ECGDataModule(pl.LightningDataModule):
    def __init__(self, dataframe, memmap, batch_size, num_workers, pin_memory, valid_pct, normalize=True, shuffle_validation=False):
        super().__init__()
        self.dataframe = dataframe  # Full df
        self.memmap = memmap
        self.normalize = normalize
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.valid_pct = valid_pct
        self.shuffle_validation = shuffle_validation

    def setup(self, stage=None):
        np.random.seed(42)

        # 1. Prepare your multi-hot label matrix
        label_cols = ["res"]
        Y_wrong = self.dataframe[label_cols].values
        Y = np.vstack([np.fromstring(row[0][1:-1], sep=' ') for row in Y_wrong])

        # 2. First split into train_val and test
        splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_val_idx, test_idx = next(splitter.split(self.dataframe, Y))

        df_train_val = self.dataframe.iloc[train_val_idx].reset_index(drop=True)
        df_test = self.dataframe.iloc[test_idx].reset_index(drop=True)

        # 3. Now split train_val into train and val
        splitter_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=self.valid_pct, random_state=42)  # use the configured validation fraction
        Y_train_val_wrong = df_train_val[label_cols].values
        Y_train_val = np.vstack([np.fromstring(row[0][1:-1], sep=' ', dtype=np.float64) for row in Y_train_val_wrong])

        train_idx, val_idx = next(splitter_val.split(df_train_val, Y_train_val))

        df_train = df_train_val.iloc[train_idx].reset_index(drop=True)
        df_val = df_train_val.iloc[val_idx].reset_index(drop=True)

        # 4. Create datasets
        self.train_ds = ECGDataset(
            dataframe=df_train,
            memmap=self.memmap,
            normalize=self.normalize,
        )

        self.valid_ds = ECGDataset(
            dataframe=df_val,
            memmap=self.memmap,
            normalize=self.normalize,
        )

        self.test_ds = ECGDataset(
            dataframe=df_test,
            memmap=self.memmap,
            normalize=self.normalize,
        )
    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            shuffle=self.shuffle_validation, # Use the instance variable
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

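    def test_dataloader(self):
        # Serve the held-out test split created in setup()
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )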
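Because `setup()` splits twice (20% held out for test, then 10% of the remainder for validation), the effective fractions of the full dataset are not 0.8 / 0.1 / 0.2. The arithmetic, using the two `test_size` values from the code above, works out as follows:

```python
TEST_SIZE = 0.2   # first split: fraction of all samples held out for test
VALID_SIZE = 0.1  # second split: fraction of the remaining train+val pool

train_val = 1.0 - TEST_SIZE       # share left after removing the test set
val = train_val * VALID_SIZE      # validation share of the full dataset
train = train_val - val           # training share of the full dataset

print(f"train={train:.2f}, val={val:.2f}, test={TEST_SIZE:.2f}")
```

So roughly 72% of the samples are used for training, 8% for validation, and 20% for testing.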

In [12]:
import torch.nn as nn  # used for the replacement output layer
import torchvision

def get_model(model_name: str, num_classes: int, freeze_backbone: bool= True):
    """A helper function to load and prepare any classification model
    available in Torchvision for transfer learning or fine-tuning."""

    model = getattr(torchvision.models, model_name)(weights="DEFAULT")

    if freeze_backbone:
        # Set all layers to be non-trainable
        for param in model.parameters():
            param.requires_grad = False

    model_childrens = [name for name, _ in model.named_children()]

    # The classifier head is either a Sequential (use its last layer) or a single Linear.
    try:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}")[-1].in_features
    except Exception:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}").in_features

    new_output_layer = nn.Linear(
        in_features=final_layer_in_features,
        out_features=num_classes
    )

    # Swap in the new output layer at the same position.
    try:
        getattr(model, f"{model_childrens[-1]}")[-1] = new_output_layer
    except Exception:
        setattr(model, model_childrens[-1], new_output_layer)

    return model

**Function usage example:**

In [13]:
!pip install torchinfo



In [14]:
from torchinfo import summary
import torch.nn as nn

# Suppose your ECG signals are 1000 time steps long
TIME_LENGTH = 1000

model = get_model(
    model_name=TrainingConfig.MODEL_NAME,    # Should be "resnet50"
    num_classes=DatasetConfig.NUM_CLASSES,
    freeze_backbone=False,
)

# Correctly modify the first convolutional layer to accept 12 channels
model.conv1 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=(7, 1), stride=(2, 1), padding=(3, 0), bias=False) # Reassign the layer

# Proper ECG input shape
summary(
    model,
    input_size=(TrainingConfig.BATCH_SIZE, DatasetConfig.NUM_LEADS, TIME_LENGTH, 1),  # (batch, channels=12, time, width=1)
    depth=2,
    device="cpu",
    col_names=["output_size", "num_params", "trainable"]
)

Layer (type:depth-idx)                   Output Shape              Param #                   Trainable
ResNet                                   [32, 12]                  --                        True
├─Conv2d: 1-1                            [32, 64, 500, 1]          5,376                     True
├─BatchNorm2d: 1-2                       [32, 64, 500, 1]          128                       True
├─ReLU: 1-3                              [32, 64, 500, 1]          --                        --
├─MaxPool2d: 1-4                         [32, 64, 250, 1]          --                        --
├─Sequential: 1-5                        [32, 256, 250, 1]         --                        True
│    └─Bottleneck: 2-1                   [32, 256, 250, 1]         75,008                    True
│    └─Bottleneck: 2-2                   [32, 256, 250, 1]         70,400                    True
│    └─Bottleneck: 2-3                   [32, 256, 250, 1]         70,400                    True
├─Sequential: 1-6  
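The output shapes in the summary follow from the standard convolution size formula, `out = (in + 2 * padding - kernel) // stride + 1`, applied separately to the time and width axes. The sketch below reproduces the first rows of the table; the max-pool settings (kernel 3, stride 2, padding 1) are the torchvision ResNet defaults rather than values set in this notebook.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

TIME_LENGTH = 1000

# conv1 uses kernel (7, 1), stride (2, 1), padding (3, 0).
t_conv = conv_out(TIME_LENGTH, kernel=7, stride=2, padding=3)
w_conv = conv_out(1, kernel=1, stride=1, padding=0)

# ResNet max-pool: kernel 3, stride 2, padding 1 on both axes.
t_pool = conv_out(t_conv, kernel=3, stride=2, padding=1)
w_pool = conv_out(w_conv, kernel=3, stride=2, padding=1)

# conv1 has no bias, so its parameter count is in * out * kH * kW.
conv1_params = 12 * 64 * 7 * 1

print(t_conv, w_conv, t_pool, w_pool, conv1_params)
```

The time axis is halved twice (1000 to 500 to 250) while the width stays at 1, matching the `[32, 64, 500, 1]` and `[32, 64, 250, 1]` rows and the 5,376 parameters reported for `Conv2d: 1-1`.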

In [15]:
# # Assuming 'df' is your DataFrame and 'res' is the column with labels
# class_frequencies = []
# for code in target_icd_codes:
#     # Count occurrences of the current code in the 'res' column
#     freq = df['res'].str.contains(code).sum()
#     class_frequencies.append(freq)

# # Convert the list to a PyTorch tensor
# class_frequencies = torch.tensor(class_frequencies, dtype=torch.float32)

In [16]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import MeanMetric
from torchmetrics.classification import MultilabelF1Score
from torchvision.models import resnet50

class ECGModel(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 12,
        init_lr: float = 1e-3,
        optimizer_name: str = "Adam",
        weight_decay: float = 1e-4,
        use_scheduler: bool = False,
        f1_metric_threshold: float = 0.5,
        freeze_backbone: bool = False,
    ):
        super().__init__()

        # Save the arguments as hyperparameters.
        self.save_hyperparameters()

        # Build model
        self.model = self.get_resnet50_for_ecg(num_classes, freeze_backbone)

        # class_weights = torch.tensor([1 / freq for freq in class_frequencies])

        # Loss function
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Metrics
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1 = MultilabelF1Score(num_labels=num_classes, average="macro", threshold=f1_metric_threshold)
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1 = MultilabelF1Score(num_labels=num_classes, average="macro", threshold=f1_metric_threshold)

    def get_resnet50_for_ecg(self, num_classes, freeze_backbone):
        model = resnet50(weights="DEFAULT")  # `pretrained=True` is deprecated in recent torchvision

        # Modify the first conv layer to accept 12 leads instead of 3 RGB channels
        model.conv1 = nn.Conv2d(
            in_channels=12,  # ECG leads
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # Modify the final fully connected layer to output num_classes
        model.fc = nn.Linear(model.fc.in_features, num_classes)

        if freeze_backbone:
            for name, param in model.named_parameters():
                if "fc" not in name:  # Only leave the fc layer unfrozen
                    param.requires_grad = False

        return model

    def forward(self, x):
        # x shape: [batch_size, channels=12, time_steps]
        # ResNet expects 2D images, so we need to simulate [batch, channels, height, width]
        # Treat ECG [channels, time] as [channels, height, width=1]
        x = x.unsqueeze(-1)  # [batch, 12, time, 1]
        return self.model(x)

    def training_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        self.mean_train_loss(loss, weight=data.shape[0])
        self.mean_train_f1(logits, target)

        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True)
        self.log("train/batch_f1", self.mean_train_f1, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self.log("train/loss", self.mean_train_loss, prog_bar=True)
        self.log("train/f1", self.mean_train_f1, prog_bar=True)
        self.log("step", self.current_epoch)

    def validation_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        self.mean_valid_loss.update(loss, weight=data.shape[0])
        self.mean_valid_f1.update(logits, target)

    def on_validation_epoch_end(self):
        self.log("valid/loss", self.mean_valid_loss, prog_bar=True)
        self.log("valid/f1", self.mean_valid_f1, prog_bar=True)
        self.log("step", self.current_epoch)

    def configure_optimizers(self):
        optimizer = getattr(torch.optim, self.hparams.optimizer_name)(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.hparams.init_lr,
            weight_decay=self.hparams.weight_decay,
        )

        if self.hparams.use_scheduler:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[self.trainer.max_epochs // 2],
                gamma=0.1,
            )

            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        else:
            return optimizer
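With `NUM_EPOCHS = 30`, the `MultiStepLR` configuration above has a single milestone at epoch 15 and `gamma=0.1`. The learning rate it produces can be sketched in plain Python; `multi_step_lr` is a hypothetical helper written for illustration, not part of PyTorch.

```python
def multi_step_lr(epoch, init_lr=1e-3, milestones=(15,), gamma=0.1):
    """Learning rate at a given epoch: multiplied by gamma once per milestone reached."""
    drops = sum(epoch >= m for m in milestones)
    return init_lr * gamma ** drops

# Milestone derived as in configure_optimizers: max_epochs // 2.
MAX_EPOCHS = 30
milestone = MAX_EPOCHS // 2

for epoch in (0, milestone - 1, milestone, MAX_EPOCHS - 1):
    print(epoch, multi_step_lr(epoch, milestones=(milestone,)))
```

The rate stays at `1e-3` for the first half of training and drops to `1e-4` for the second half.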


# Dataset Initialization

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Seed everything for reproducibility
pl.seed_everything(42, workers=True)

memmap_path = "src/data/memmap/memmap.npy"

memmap_data = np.memmap(memmap_path, dtype=np.float32, mode='r')

# Instantiate the ECGDataModule
dm = ECGDataModule(
    dataframe=df,            # Your loaded DataFrame
    memmap=memmap_data,             # Your loaded memmap
    # icd_to_index=icd_to_index,      # Your ICD code -> index mapping
    batch_size=TrainingConfig.BATCH_SIZE,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    valid_pct=DatasetConfig.VALID_PCT,
)

# Prepare data (nothing to download for ECG, so will pass)
dm.prepare_data()

# Split dataset into training and validation sets
dm.setup()

# Create ModelCheckpoint callback
model_checkpoint = ModelCheckpoint(
    monitor="valid/f1",        # Monitor validation F1 score
    mode="max",                # Maximize F1
    filename="ecg_epoch{epoch:03d}_vloss{valid/loss:.4f}_vf1{valid/f1:.4f}",
    auto_insert_metric_name=False,
    save_top_k=1,              # Save the best model only
)

# Create Learning Rate Monitor callback
lr_monitor = LearningRateMonitor(logging_interval="epoch")


INFO:lightning_fabric.utilities.seed:Seed set to 42


In [None]:
# # To reload tensorBoard
# %reload_ext tensorboard

# # logs folder path
# %tensorboard --logdir=lightning_logs

**Train**

In [None]:
model = ECGModel(
    num_classes=DatasetConfig.NUM_CLASSES,
    init_lr=TrainingConfig.INIT_LR,
    optimizer_name=TrainingConfig.OPTIMIZER_NAME,
    weight_decay=TrainingConfig.WEIGHT_DECAY,
    use_scheduler=TrainingConfig.USE_SCHEDULER,
    f1_metric_threshold=TrainingConfig.F1_METRIC_THRESH,
    freeze_backbone=TrainingConfig.FREEZE_BACKBONE,
)

# Training

In [None]:
# Access train_dataloader through the dm instance
for batch in dm.train_dataloader():  # Call train_dataloader() on dm
    signals, labels = batch
    print("Sample labels:", labels[:5])
    print("Mean positive rate per label:", labels.mean(dim=0))
    break

In [None]:
# Initializing the Trainer class object.
# It uses 'Tensorboard' as its default logger.
trainer = pl.Trainer(
    accelerator="auto", # Auto select the best hardware accelerator available
    devices="auto", # Auto select available devices for the accelerator (e.g. multiple GPUs)
    strategy="auto", # Auto select the distributed training strategy.
    max_epochs=TrainingConfig.NUM_EPOCHS, # Maximum number of epochs to train for.
    deterministic=True, # For deterministic and reproducible training.
    enable_model_summary=False, # Disable printing of model summary as we are using torchinfo.
    callbacks=[model_checkpoint, lr_monitor],  # Declaring callbacks to use.
    precision="16-mixed", # Mixed-precision training (Lightning 2.x spelling of "16").
    logger=True, # Auto generate TensorBoard logs.
)

# Start training
trainer.fit(model, dm)

# Inference

To perform inference, we first load the best checkpoint saved during training:

In [None]:
CKPT_PATH = model_checkpoint.best_model_path  # best checkpoint tracked by the ModelCheckpoint callback
model = ECGModel.load_from_checkpoint(CKPT_PATH)

In [None]:
# Initialize trainer class for inference.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    enable_checkpointing=False,
    inference_mode=True,
)

# Run evaluation.
dm.setup()
valid_loader = dm.val_dataloader()
trainer.validate(model=model, dataloaders=valid_loader)
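The model outputs raw logits, one per class, so turning them into predicted label sets takes a sigmoid followed by a threshold. The NumPy sketch below shows that post-processing step on made-up logits; `predict_labels` is a hypothetical helper, and the 0.5 threshold mirrors `F1_METRIC_THRESH` from the training configuration.

```python
import numpy as np

def predict_labels(logits, threshold=0.5):
    """Map raw logits to a multi-hot prediction via an element-wise sigmoid."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(int)

# Hypothetical logits for two samples and three classes.
logits = np.array([[2.0, -1.0, 0.1],
                   [-3.0, 0.5, -0.2]])

preds = predict_labels(logits)
print(preds.tolist())
```

Each class is decided independently, so a sample can receive several positive labels or none at all.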

# Summary

Multi-label image classification is a fundamental and necessary task in many real-world scenarios where an image may contain more than one object or feature of interest. Unlike single-label classification, where each image is associated with only one label or class, multi-label classification acknowledges the inherent complexity and variety in real-world images by allowing them to be associated with multiple labels or classes simultaneously. This is particularly relevant in various domains. For example, in medical imaging, a scan may reveal multiple conditions or observations. Similarly, in social media, a photo may contain multiple people, objects, or activities. Additionally, in the field of autonomous vehicles, a single frame of a video feed may contain cars, pedestrians, signs, and more. By recognizing and categorizing multiple elements within a single image, multi-label classification provides a more comprehensive and nuanced understanding of the visual world, enabling us to build more effective and versatile AI systems.

To summarise this article📜, we covered a comprehensive list of related topics:

1. We explored image classification, highlighting the distinction between multi-class (one label per image) and multi-label (multiple labels per image) types.

2. We emphasized the unique post-processing and loss function requirements in multi-label classification, which set it apart from traditional classifications.

3. We utilized 12-lead ECG recordings labelled with 12 ICD codes (the ECG-MIMIC data loaded above) to illustrate medical multi-label classification in PyTorch.

4. We streamlined our code and improved its readability using the PyTorch-Lightning library, which simplifies PyTorch's complex aspects.

5. We leveraged the pre-trained ResNet-50 model from torchvision as our starting point and then fine-tuned it for our specific task.

6. We designed a user-friendly interface using the Gradio app, making our medical multi-label image classification model accessible to everyone.
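The loss-function requirement mentioned in point 2 can be made concrete. Multi-label training treats every class as its own binary problem, so the loss is binary cross-entropy applied per logit and then averaged. The following is a NumPy sketch of that computation using the usual numerically stable form; it is meant to illustrate what `nn.BCEWithLogitsLoss` with its default mean reduction computes, not to replace it.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (sample, class) entries.

    Stable form: max(x, 0) - x * y + log(1 + exp(-|x|)).
    """
    per_entry = (
        np.maximum(logits, 0)
        - logits * targets
        + np.log1p(np.exp(-np.abs(logits)))
    )
    return float(per_entry.mean())

# Hypothetical batch: two samples, three classes, multi-hot targets.
targets = np.array([[1.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0]])

uninformed = bce_with_logits(np.zeros_like(targets), targets)
confident = bce_with_logits(np.where(targets == 1, 6.0, -6.0), targets)

print(uninformed, confident)
```

Logits of zero (probability 0.5 for every class) give a loss of `log(2)`, while confident, correct logits push the loss towards zero. An all-zero target row, like the second sample here, still contributes to the loss.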