<a href="https://colab.research.google.com/github/Zfeng0207/FIT3199-FYP/blob/dev%2Fzfeng/hpa_pt_lightning_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [330]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [331]:
import os
os.chdir('/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main')

In [332]:
# !pip install -qqqq mlflow torchmetrics pytorch_lightning iterative-stratification

# Step 3.5: Multihot encode: Setting up target binary labels

In [333]:
import numpy as np

def multihot_encode(diagnoses, icd_codes):
    """
    Encodes a list of diagnoses into a multi-hot vector.
    Disregards rows where no target ICD codes are set for the sample.

    Args:
        diagnoses (str or list): A string representing a list of diagnoses or a list of diagnoses.
        icd_codes (tuple): A tuple of ICD codes to be encoded.

    Returns:
        np.ndarray or None: A multi-hot vector representing the diagnoses,
                            or None if no target ICD codes are found.
    """
    num_classes = len(icd_codes)
    res = np.zeros(num_classes, dtype=np.float32)

    # Ensure diagnoses is a list
    if isinstance(diagnoses, str):
        diagnoses = diagnoses.strip('[]').replace("'", "").split(",")  # Handle list-like strings
        diagnoses = [d.strip() for d in diagnoses]  # Remove any whitespace around code

    found_target = False  # Flag to track if any target ICD code is found

    for diag in diagnoses:
        for i, code in enumerate(icd_codes):
            if diag.startswith(code):
                res[i] = 1
                found_target = True
                break  # Break inner loop if target is found
        if found_target:
            break  # Break outer loop if target is found

    if not found_target:
        return None  # Return None if no target ICD code is found

    return res

In [334]:
import numpy as np
import pandas as pd

df_full = pd.read_csv("src/data/label_df.csv")

df_labels =  df_full[["filename",
            "study_id",
            "patient_id",
            "ecg_time",
            "label_train",
            "all_diag_all",
            "label_stroke",
            "start",
            "length"]]

target_icd_codes = (
   "I48", "E11"
)

df_labels['res'] = df_labels['label_train'].apply(lambda diagnoses: multihot_encode(diagnoses, target_icd_codes))

# df_labels['stroke_yn'] = df_labels['res'].apply(lambda x: 1 if 1 in x else 0)

df_rm_nan = df_labels[df_labels['all_diag_all'].apply(lambda x: len(x) > 0)]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_labels['res'] = df_labels['label_train'].apply(lambda diagnoses: multihot_encode(diagnoses, target_icd_codes))


# Number of sparse target class

In [335]:
import pandas as pd

# df = pd.read_csv("src/data/label_df.csv")
df = df_rm_nan.copy()

In [336]:
df = df[df['res'].notna()]

In [337]:
df.shape

(157501, 10)

In [338]:
import numpy as np
import pandas as pd

def calculate_mean_positive_rate(df, label_col="res"):
    """
    Calculates the mean positive rate per label for a pandas DataFrame.

    Args:
        df (pd.DataFrame): The DataFrame containing the multi-hot labels.
        label_col (str): The column name containing the multi-hot labels.

    Returns:
        np.ndarray: An array containing the mean positive rate for each label.
    """

    # Extract labels and ensure they are NumPy arrays with consistent shape
    labels = df[label_col].apply(lambda x: np.array(x, dtype=np.float32)).values
    labels = np.vstack(labels)  # Stack the labels into a 2D array


    # Calculate mean positive rate per label
    mean_positive_rate = labels.mean(axis=0)

    return mean_positive_rate

# Assuming 'df' is your DataFrame
mean_positive_rates = calculate_mean_positive_rate(df)

# Print the results
print("Mean positive rate per label:", mean_positive_rates)

Mean positive rate per label: [0.49932382 0.5006762 ]


In [339]:
def count_empty_labels(df, label_col="res"):
    """
    Counts how many samples in the DataFrame have all-zero labels.

    Args:
        df (pd.DataFrame): DataFrame containing the dataset
        label_col (str): Column name containing the multi-hot labels

    Returns:
        int: Number of rows with all-zero labels
    """
    empty_count = 0

    for label_str in df[label_col]:
        if label_str.sum() == 0:
            empty_count += 1

    return empty_count


In [340]:
import numpy as np

num_empty = count_empty_labels(df, label_col="res")
print(f"There are {num_empty} samples with all-zero labels.")

There are 0 samples with all-zero labels.


# Configurations

In [527]:
from dataclasses import dataclass
import os
import platform

# You can define ROOT_PATH somewhere above
ROOT_PATH = "/content/drive/MyDrive/Colab Notebooks (1)/ECG-MIMIC-main/src"

@dataclass
class DatasetConfig:
    # ECG-specific
    NUM_LEADS:    int = 12  # 12 ECG channels (leads)
    NUM_CLASSES:  int = 2  # 12 ICD disease codes
    VALID_PCT:  float = 0.1

    # Dataset file and folder paths
    TRAIN_CSV:   str = os.path.join(ROOT_PATH, "data/train.csv")  # Your preprocessed split CSV
    TEST_CSV:    str = os.path.join(ROOT_PATH, "data/test.csv")
    MEMMAP_FILE: str = os.path.join(ROOT_PATH, "ecg_dataset", "data/memmap/memmap.npy")
    MEMMAP_META: str = os.path.join(ROOT_PATH, "ecg_dataset", "data/memmap/memmap_meta.npz")

@dataclass
class TrainingConfig:
    BATCH_SIZE:      int = 64
    NUM_EPOCHS:      int = 30  # Actual training epochs
    INIT_LR:       float = 1e-3
    NUM_WORKERS:     int = 7
    OPTIMIZER_NAME:  str = "Adam"
    WEIGHT_DECAY:  float = 1e-4
    USE_SCHEDULER:  bool = True
    SCHEDULER:       str = "multi_step_lr"  # or "cosine_annealing"
    F1_METRIC_THRESH: float = 0.5
    FREEZE_BACKBONE: bool = False

    # (Optional) model name (if you want to log it somewhere)
    MODEL_NAME:      str = "resnet18"


In [528]:
def encode_label(label: list, num_classes=10):
    """
    This functions converts labels into multi-hot encoding.
    Handles both single ICD codes and lists of codes.
    """
    target = torch.zeros(num_classes)

    # If label is a single code, make it a list
    if isinstance(label, str):
        label = [label]

    for l in label:
        # Check if 'l' contains brackets (indicating list within a string)
        if '[' in l or ']' in l:
            l = l.strip('[]').replace("'", "").split(",")  # Handle list-like strings
            for code in l:
                code = code.strip()  # Remove any whitespace around code
                if code in icd_to_index:
                    target[icd_to_index[code]] = 1.0
        else:
            l = l.strip()  # Remove any whitespace around code
            if l in icd_to_index:
                target[icd_to_index[l]] = 1.0
    return target


def decode_target(
    target: list,
    text_labels: bool = False,
    threshold: float = 0.4,
    cls_labels: dict = None,
):
    """This function converts the labels from
    probablities to outputs or string representations
    """

    result = []
    for i, x in enumerate(target):
        if x >= threshold:
            if text_labels:
                result.append(cls_labels[i] + "(" + str(i) + ")")
            else:
                result.append(str(i))
    return " ".join(result)


# This function is used for reversing the Normalization step performed
# during image preprocessing.
# Note the mean and std values must match the ones used.

def denormalize(tensors, *, mean, std):
    """Denormalizes image tensors using mean and std provided
    and clip values between 0 and 1"""

    for c in range(DatasetConfig.CHANNELS):
        tensors[:, c, :, :].mul_(std[c]).add_(mean[c])

    return torch.clamp(tensors, min=0.0, max=1.0)

In [529]:
# Create a dictionary mapping ICD codes to index
# icd_to_index = {code: idx for idx, code in enumerate(target_icd_codes)}


# Dataset

In [530]:
import torch
from torch.utils.data import Dataset
import numpy as np

class ECGDataset(Dataset):
    def __init__(self, dataframe, memmap, normalize=True):
        self.df = dataframe.reset_index(drop=True)
        self.memmap = memmap
        self.normalize = normalize
        self.num_classes = DatasetConfig.NUM_CLASSES

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        start = int(row['start'])
        length = int(row['length'])

        signal = self.memmap[start : start + length * 12]
        signal = signal.reshape(length, 12)

        # Check for NaN or infinity values
        if np.isnan(signal).any() or np.isinf(signal).any():
            # Skip this sample and return an empty tensor
            return torch.empty((12, 0)), torch.zeros(self.num_classes)

        if self.normalize:
            signal = (signal - np.nanmean(signal, axis=0)) / (np.nanstd(signal, axis=0) + 1e-6)

        signal = torch.tensor(signal, dtype=torch.float32).permute(1, 0)  # [12, time]
        label = row['res']

        # Return label as float32 and the signal without padding
        return signal, torch.tensor(label, dtype=torch.float32)


# Data Module

In [556]:
def pad_collate(batch):
    """
    Pads the ECG signals in a batch to the maximum length.

    Args:
        batch: A list of tuples (signal, label).

    Returns:
        A tuple containing the padded signals and labels.
    """
    signals, labels = zip(*batch)

    # Find the maximum length
    max_len = max(signal.shape[1] for signal in signals)

    # Pad the signals
    padded_signals = [torch.nn.functional.pad(signal, (0, max_len - signal.shape[1]), 'constant', 0) for signal in signals]

    # Stack the signals and labels
    signals = torch.stack(padded_signals)
    labels = torch.stack(labels)

    return signals, labels


In [557]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import numpy as np
import os
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit

class ECGDataModule(pl.LightningDataModule):
    def __init__(self, dataframe, memmap, batch_size, num_workers, pin_memory, valid_pct, normalize=True, shuffle_validation=False): # Add shuffle_validation as a parameter
        super().__init__()
        self.dataframe = dataframe  # Full df
        self.memmap = memmap
        self.normalize = normalize
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.valid_pct = valid_pct
        self.shuffle_validation = shuffle_validation

    def setup(self, stage=None):
        label_cols = 'res'
        np.random.seed(42)

        # 1. Prepare your multi-hot label matrix
        # Y = np.vstack(self.dataframe[label_cols].values)
        Y = np.vstack(self.dataframe[label_cols].apply(lambda x: np.array(x, dtype=np.float32)).values)

        # 2. First split into train_val and test
        splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_val_idx, test_idx = next(splitter.split(self.dataframe, Y))

        df_train_val = self.dataframe.iloc[train_val_idx].reset_index(drop=True)
        df_test = self.dataframe.iloc[test_idx].reset_index(drop=True)

        # 3. Now split train_val into train and val
        splitter_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
        # Y_train_val_wrong = df_train_val[label_cols].values
        # Y_train_val = np.vstack([np.fromstring(row[0][1:-1], sep=' ', dtype=np.float64) for row in Y_train_val_wrong])
        Y_train_val = np.vstack(df_train_val[label_cols].apply(lambda x: np.array(x, dtype=np.float32)).values)


        train_idx, val_idx = next(splitter_val.split(df_train_val, Y_train_val))

        df_train = df_train_val.iloc[train_idx].reset_index(drop=True)
        df_val = df_train_val.iloc[val_idx].reset_index(drop=True)

        # 4. Create datasets
        self.train_ds = ECGDataset(
            dataframe=df_train,
            memmap=self.memmap,
            normalize=self.normalize,
        )

        self.valid_ds = ECGDataset(
            dataframe=df_val,
            memmap=self.memmap,
            normalize=self.normalize,
        )

        self.test_ds = ECGDataset(
            dataframe=df_test,
            memmap=self.memmap,
            normalize=self.normalize,
        )
    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=pad_collate, # Apply padding collate function
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            shuffle=self.shuffle_validation, # Use the instance variable
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=pad_collate, # Apply padding collate function

        )

    def test_dataloader(self):
        # Optional: If you set up a test set later
        return None

In [558]:
import torchvision

def get_model(model_name: str, num_classes: int, freeze_backbone: bool= True):
    """A helper function to load and prepare any classification model
    available in Torchvision for transfer learning or fine-tuning."""

    model = getattr(torchvision.models, model_name)(weights="DEFAULT")

    if freeze_backbone:
        # Set all layer to be non-trainable
        for param in model.parameters():
            param.requires_grad = False

    model_childrens = [name for name, _ in model.named_children()]

    try:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}")[-1].in_features
    except Exception as e:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}").in_features

    new_output_layer = nn.Linear(
        in_features=final_layer_in_features,
        out_features=num_classes
    )

    try:
        getattr(model, f"{model_childrens[-1]}")[-1] = new_output_layer
    except:
        setattr(model, model_childrens[-1], new_output_layer)

    return model

**Function usage example:**

In [559]:
!pip install torchinfo



In [560]:
from torchinfo import summary
import torch.nn as nn

# Suppose your ECG signals are 1000 time steps long
TIME_LENGTH = 1000

model = get_model(
    model_name=TrainingConfig.MODEL_NAME,    # Should be "resnet50"
    num_classes=DatasetConfig.NUM_CLASSES,
    freeze_backbone=False,
)

# Correctly modify the first convolutional layer to accept 12 channels
model.conv1 = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=(7, 1), stride=(2, 1), padding=(3, 0), bias=False) # Reassign the layer

# Proper ECG input shape
summary(
    model,
    input_size=(TrainingConfig.BATCH_SIZE, DatasetConfig.NUM_LEADS, TIME_LENGTH, 1),  # (batch, channels=12, time, width=1)
    depth=2,
    device="cpu",
    col_names=["output_size", "num_params", "trainable"]
)

Layer (type:depth-idx)                   Output Shape              Param #                   Trainable
ResNet                                   [64, 2]                   --                        True
├─Conv2d: 1-1                            [64, 64, 500, 1]          5,376                     True
├─BatchNorm2d: 1-2                       [64, 64, 500, 1]          128                       True
├─ReLU: 1-3                              [64, 64, 500, 1]          --                        --
├─MaxPool2d: 1-4                         [64, 64, 250, 1]          --                        --
├─Sequential: 1-5                        [64, 64, 250, 1]          --                        True
│    └─BasicBlock: 2-1                   [64, 64, 250, 1]          73,984                    True
│    └─BasicBlock: 2-2                   [64, 64, 250, 1]          73,984                    True
├─Sequential: 1-6                        [64, 128, 125, 1]         --                        True
│    └─BasicBlock: 

In [561]:
# # Assuming 'df' is your DataFrame and 'res' is the column with labels
# class_frequencies = []
# for code in target_icd_codes:
#     # Count occurrences of the current code in the 'res' column
#     freq = df['res'].str.contains(code).sum()
#     class_frequencies.append(freq)

# # Convert the list to a PyTorch tensor
# class_frequencies = torch.tensor(class_frequencies, dtype=torch.float32)

# Model

In [562]:
import pytorch_lightning as pl
import torch

class Swish(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

## ConvNormPool

In [563]:
class ConvNormPool(pl.LightningModule):
    """Conv Skip-connection module"""
    def __init__(
        self,
        input_size,
        hidden_size,
        kernel_size,
        norm_type='bachnorm'
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.conv_1 = nn.Conv1d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_2 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_3 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.swish_1 = Swish()
        self.swish_2 = Swish()
        self.swish_3 = Swish()
        if norm_type == 'group':
            self.normalization_1 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_2 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_3 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
        else:
            self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size)

        self.pool = nn.MaxPool1d(kernel_size=2)

    def forward(self, input):
        conv1 = self.conv_1(input)
        x = self.normalization_1(conv1)
        x = self.swish_1(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.conv_2(x)
        x = self.normalization_2(x)
        x = self.swish_2(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        conv3 = self.conv_3(x)
        x = self.normalization_3(conv1+conv3)
        x = self.swish_3(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.pool(x)
        return x


# CNN

In [564]:
class CNN(pl.LightningModule):
    def __init__(
        self,
        input_size = 1,
        hid_size = 256,
        kernel_size = 5,
        num_classes = 5,
    ):

        super().__init__()

        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size//2,
            kernel_size=kernel_size,
        )
        self.conv3 = ConvNormPool(
            input_size=hid_size//2,
            hidden_size=hid_size//4,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(in_features=hid_size//4, out_features=num_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avgpool(x)
        # print(x.shape) # num_features * num_channels
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)
        return x


# RNN

In [565]:
class RNN(pl.LightningModule):
    """RNN module(cell type lstm or gru)"""
    def __init__(
        self,
        input_size,
        hid_size,
        num_rnn_layers=1,
        dropout_p = 0.2,
        bidirectional = False,
        rnn_type = 'lstm',
    ):
        super().__init__()

        if rnn_type == 'lstm':
            self.rnn_layer = nn.LSTM(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )

        else:
            self.rnn_layer = nn.GRU(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
    def forward(self, input):
        outputs, hidden_states = self.rnn_layer(input)
        return outputs, hidden_states


# RNN Model

In [566]:
class RNNModel(pl.LightningModule):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
    ):
        super().__init__()

        self.rnn_layer = RNN(
            input_size=46,#hid_size * 2 if bidirectional else hid_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x, _ = self.rnn_layer(x)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.sigmoid(self.fc(x), dim=1)#.squeeze(1)
        return x


# RNN Attention Model

In [567]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics.classification import MultilabelAccuracy, MultilabelF1Score, MultilabelAUROC

class RNNAttentionModel(pl.LightningModule):
    def __init__(
        self,
        hid_size =64,
        rnn_type = 'lstm',
        bidirectional=False,
        num_classes=2,
        input_size =12,
        kernel_size=5,
        lr=1e-3,
        f1_metric_threshold=0.5,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )

        self.rnn_layer = RNN(
            input_size=hid_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )

        self.attn = nn.Linear(hid_size, hid_size, bias=False)
        self.fc = nn.Linear(in_features=hid_size, out_features=num_classes)  # Multi-label output
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.lr = lr

        # Metrics
        self.train_acc = MultilabelAccuracy(num_labels=num_classes, threshold=f1_metric_threshold)
        self.train_f1 = MultilabelF1Score(num_labels=num_classes, average="macro", threshold=f1_metric_threshold)
        self.train_auc = MultilabelAUROC(num_labels=num_classes)

        self.val_acc = MultilabelAccuracy(num_labels=num_classes, threshold=f1_metric_threshold)
        self.val_f1 = MultilabelF1Score(num_labels=num_classes, average="macro", threshold=f1_metric_threshold)
        self.val_auc = MultilabelAUROC(num_labels=num_classes)

    def forward(self, input):
        # input = input.permute(0, 2, 1)  # Remove this line - permutation is done in the dataset
        x = self.conv1(input)
        x = self.conv2(x)
        x = x.permute(0, 2, 1)  # Permute before the RNN layer

        x_out, _ = self.rnn_layer(x)

        attn_weights = torch.softmax(self.attn(x_out), dim=1)
        x = torch.sum(attn_weights * x_out, dim=1)

        logits = self.fc(x)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)

        acc = self.train_acc(probs, y.int())
        f1 = self.train_f1(probs, y.int())
        auc = self.train_auc(probs, y.int())

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        self.log("train_f1", f1, prog_bar=True)
        self.log("train_auc", auc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)

        acc = self.val_acc(probs, y.int())
        f1 = self.val_f1(probs, y.int())
        auc = self.val_auc(probs, y.int())

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_f1", f1, prog_bar=True)
        self.log("val_auc", auc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Dataset Initialization

In [568]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# 1. Seed everything for reproducibility
pl.seed_everything(42, workers=True)

memmap_path = "src/data/memmap/memmap.npy"

memmap_data = np.memmap(memmap_path, dtype=np.float32, mode='r')

# Instantiate the ECGDataModule
dm = ECGDataModule(
    dataframe=df,            # Your loaded DataFrame
    memmap=memmap_data,             # Your loaded memmap
    # icd_to_index=icd_to_index,      # Your ICD code -> index mapping
    batch_size=TrainingConfig.BATCH_SIZE,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    valid_pct=DatasetConfig.VALID_PCT,
)

# Prepare data (nothing to download for ECG, so will pass)
dm.prepare_data()

# Split dataset into training and validation sets
dm.setup()

# 4. Create ModelCheckpoint callback
model_checkpoint = ModelCheckpoint(
    monitor="valid/f1",        # Monitor validation F1 score
    mode="max",                # Maximize F1
    filename="ecg_epoch{epoch:03d}_vloss{valid/loss:.4f}_vf1{valid/f1:.4f}",
    auto_insert_metric_name=False,
    save_top_k=1,              # Save the best model only
)

# 5. Create Learning Rate Monitor callback
lr_monitor = LearningRateMonitor(logging_interval="epoch")


INFO:lightning_fabric.utilities.seed:Seed set to 42


In [569]:
# # To reload tensorBoard
# %reload_ext tensorboard

# # logs folder path
# %tensorboard --logdir=lightning_logs

**Train**

# Training

In [570]:
# Access train_dataloader through the dm instance
for batch in dm.train_dataloader():  # Call train_dataloader() on dm
    signals, labels = batch
    print("Sample labels:", labels[:5])
    print("Mean positive rate per label:", labels.mean(dim=0))
    break

Sample labels: tensor([[1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.]])
Mean positive rate per label: tensor([0.3906, 0.6094])


In [571]:
# Assuming 'dm' is your ECGDataModule instance
train_loader = dm.train_dataloader()

# 1. Using len() on the dataloader:
num_batches = len(train_loader)
print(f"Number of batches in train_dataloader: {num_batches}")

# 2. Calculating total samples from batch size and num_batches:
total_samples = num_batches * train_loader.batch_size
print(f"Estimated total samples in training dataset: {total_samples}")

# 3. Accessing the underlying dataset directly (more accurate):
total_samples_accurate = len(train_loader.dataset)
print(f"Actual total samples in training dataset: {total_samples_accurate}")


Number of batches in train_dataloader: 1772
Estimated total samples in training dataset: 113408
Actual total samples in training dataset: 113400


In [572]:
model = RNNAttentionModel()

In [573]:
# Initializing the Trainer class object.
# It uses 'Tensorboard' as its default logger.
trainer = pl.Trainer(
    accelerator="auto", # Auto select the best hardware accelerator available
    devices="auto", # Auto select available devices for the accelerator (For eg. mutiple GPUs)
    strategy="auto", # Auto select the distributed training strategy.
    max_epochs=TrainingConfig.NUM_EPOCHS, # Maximum number of epoch to train for.
    deterministic=True, # For deteministic and reproducible training.
    enable_model_summary=False, # Disable printing of model summary as we are using torchinfo.
    callbacks=[model_checkpoint, lr_monitor],  # Declaring callbacks to use.
    precision="16", # Using Mixed Precision training.
    logger=True, # Auto generate TensorBoard logs.
)

# Start training
trainer.fit(model, dm)

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

MisconfigurationException: `ModelCheckpoint(monitor='valid/f1')` could not find the monitored key in the returned metrics: ['lr-Adam', 'train_loss', 'train_acc', 'train_f1', 'train_auc', 'val_loss', 'val_acc', 'val_f1', 'val_auc', 'epoch', 'step']. HINT: Did you call `log('valid/f1', value)` in the `LightningModule`?

## 7 Inference

To perform inference, first, we need to load the best checkpoint saved during training. We can do it simply by executing the following:

In [None]:
model = ProteinModel.load_from_checkpoint(CKPT_PATH)

In [None]:
# Initialize trainer class for inference.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    enable_checkpointing=False,
    inference_mode=True,
)

# Run evaluation.
data_module.setup()
valid_loader = data_module.val_dataloader()
trainer.validate(model=model, dataloaders=valid_loader)

## 8 Summary

Multi-label image classification is a fundamental and necessary task in many real-world scenarios where an image may contain more than one object or feature of interest. Unlike single-label classification, where each image is associated with only one label or class, multi-label classification acknowledges the inherent complexity and variety in real-world images by allowing them to be associated with multiple labels or classes simultaneously. This is particularly relevant in various domains. For example, in medical imaging, a scan may reveal multiple conditions or observations. Similarly, in social media, a photo may contain multiple people, objects, or activities. Additionally, in the field of autonomous vehicles, a single frame of a video feed may contain cars, pedestrians, signs, and more. By recognizing and categorizing multiple elements within a single image, multi-label classification provides a more comprehensive and nuanced understanding of the visual world, enabling us to build more effective and versatile AI systems.

To summarise this article📜, we covered a comprehensive list of related topics:

1. We explored image classification, highlighting the distinction between multi-class (one label per image) and multi-label (multiple labels per image) types.

2. We emphasized the unique post-processing and loss function requirements in multi-label classification, which set it apart from traditional classifications.

3. We utilized a subset of Kaggle's "Human Protein Atlas Image Classification" challenge to illustrate medical multi-label image classification in PyTorch.

4. We streamlined our code and improve readability using the PyTorch-Lightning library, which simplifies PyTorch's complex aspects.

5.We leveraged the pre-trained EfficientNetv2-small model from torchvision as our starting point and then fine-tune it for our specific task.

6. We designed a user-friendly interface using the Gradio app, making our medical multi-label image classification model accessible to everyone.