In [2]:
# Fine-tuning SL model for AFIB at 2y PTB-XL lite dataset
import os
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import (
    BinaryAccuracy, 
    BinaryAUROC, 
    BinaryAveragePrecision
)
from torch.utils.data import Dataset
from huggingface_hub import snapshot_download

# Define device
device = torch.device('cuda:0')

# Define hugging face token
hugging_face_token = "" # Set your Hugging Face token here

  from .autonotebook import tqdm as notebook_tqdm


# Set seed for reproducibility

In [3]:
def seed_everything(seed: int):
    """
    Seed everything for reproducibility.

    Parameters:
    seed (int): The seed to set for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# Load Model

In [4]:
# Download model from Hugging Face
local_dir = snapshot_download(
    repo_id="heartwise/EfficientNetV2_SL_Model_Tunable", 
    local_dir=".", 
    repo_type="model", 
    token=hugging_face_token
)

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 9992.39it/s]


In [5]:
from EfficientNetv2 import *

In [6]:
sl_model = torch.load("sl_model.h5")

for param in sl_model.parameters():
    param.requires_grad = True

# Set up new classifier and reset weights

In [7]:
def reinitialize_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.AdaptiveAvgPool1d):
        # AdaptiveAvgPool1d has no parameters to initialize
        pass
    elif isinstance(m, nn.Dropout):
        # Dropout has no parameters to initialize
        pass

sl_model = sl_model.apply(reinitialize_weights)


class NewClassifier(nn.Module):
    def __init__(self):
        super(NewClassifier, self).__init__()
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        #self.dropout1 = nn.Dropout(p=0.2)
        self.fc1 = nn.Linear(640, 1)

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x


sl_model.classifier = NewClassifier()
for param in sl_model.classifier.parameters():
    param.requires_grad = True


sl_model = sl_model.to(device)

# Prepare Dataset

In [15]:
# Download ptb-xl lite data from Hugging Face
dataset_dir = snapshot_download(
    repo_id="heartwise/PTB-XL_lite", 
    local_dir="ptb-xl_lite", 
    repo_type="dataset", 
    token=hugging_face_token
)

Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 21094.57it/s]


In [16]:
train_ds_X = np.squeeze(
    np.load(os.path.join(dataset_dir, "ptb_xl_afib_2y_train_subset.npy")),
)
val_ds_X = np.squeeze(
    np.load(os.path.join(dataset_dir, "ptb_xl_afib_2y_val_subset.npy")),
)

In [17]:
train_df = pd.read_parquet(
    os.path.join(dataset_dir, "ptb_xl_afib_2y_train_subset_labels.parquet")
)
val_df = pd.read_parquet(
    os.path.join(dataset_dir, "ptb_xl_afib_2y_val_subset_labels.parquet")
)

train_ds_Y = train_df['label_2y'].astype(int).tolist()
val_ds_Y = val_df['label_2y'].astype(int).tolist()

# Define DataLoader

In [18]:

train_params = {
        'batch_size': 254,
        'shuffle': True,
        'num_workers': 12,
        'pin_memory':True,
        'multiprocessing_context': 'fork',
        'persistent_workers': True
    }

val_params = {
        'batch_size': 254,
        'shuffle': False,
        'num_workers': 12,
        'pin_memory':True,
        'multiprocessing_context': 'fork',
        'persistent_workers': True
    }

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        x = np.swapaxes(x, 0, 1)
        
        # Convert numpy arrays to torch tensors
        x = torch.tensor(x, dtype=torch.float32)
        y = torch.tensor(y, dtype=torch.long)

        return x, y
    
train_set = CustomDataset(train_ds_X, train_ds_Y)
dataloader = torch.utils.data.DataLoader(train_set, **train_params)

val_set = CustomDataset(val_ds_X, val_ds_Y)
val_dataloader = torch.utils.data.DataLoader(val_set, **val_params)

# Use Binary Focal Loss for imbalanced data

In [19]:
class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Initialize the Binary Focal Loss function.

        Parameters:
        - alpha: Weighting factor for the rare class (usually the minority class), default is 0.25.
        - gamma: Focusing parameter that adjusts the rate at which easy examples are down-weighted, default is 2.0.
        - reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default is 'mean'.
        """
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Forward pass of the loss function.

        Parameters:
        - inputs: Predicted logits (before applying sigmoid), shape (batch_size, 1).
        - targets: Ground truth labels, shape (batch_size, 1).

        Returns:
        - loss: Calculated focal loss.
        """
        # Convert targets to float
        targets = targets.float()
        
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(inputs)

        # Calculate binary cross-entropy loss
        bce_loss = F.binary_cross_entropy(probs, targets, reduction='none')

        # Calculate the modulating factor (1 - p_t)^gamma
        p_t = probs * targets + (1 - probs) * (1 - targets)
        modulating_factor = torch.pow(1 - p_t, self.gamma)

        # Apply the alpha factor
        alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        # Combine factors to compute the final focal loss
        focal_loss = alpha_factor * modulating_factor * bce_loss

        # Apply reduction method
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Define Metrics and Training/Validation Loop

In [20]:
def train_and_evaluate_model(
    model, 
    dataloader, 
    val_dataloader, 
    optimizer, 
    criterion, 
    scheduler, 
    device, 
    num_epochs=10, 
    patience=3
    ):
    # Initialize metrics
    accuracy = BinaryAccuracy().to(device)
    auroc = BinaryAUROC().to(device)
    auprc = BinaryAveragePrecision().to(device)
    model = model.to(device)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0

        # Reset metrics at the start of each epoch
        accuracy.reset()
        auroc.reset()
        auprc.reset()

        # Training loop
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch+1}/{num_epochs} (Training)', leave=False)
        for i, (inputs, labels) in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.squeeze()  # Squeeze to match the shape of labels
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # Update metrics
            accuracy.update(outputs, labels)
            auroc.update(outputs, labels)
            auprc.update(outputs, labels)

            # Update tqdm progress bar with aggregate metrics
            progress_bar.set_postfix(
                loss=avg_loss,
                acc=accuracy.compute().item(),
                auroc=auroc.compute().item(),
                auprc=auprc.compute().item()
            )

        # Free GPU memory after each training epoch
        torch.cuda.empty_cache()

        # Validation loop
        model.eval()
        val_running_loss = 0.0

        accuracy.reset()
        auroc.reset()
        auprc.reset()

        with torch.no_grad():
            progress_bar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f'Epoch {epoch+1}/{num_epochs} (Validation)', leave=False)
            for i, (inputs, labels) in progress_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                outputs = outputs.squeeze()  # Squeeze to match the shape of labels
                loss = criterion(outputs, labels)

                val_running_loss += loss.item()
                avg_val_loss = val_running_loss / (i + 1)

                # Update metrics
                accuracy.update(outputs, labels)
                auroc.update(outputs, labels)
                auprc.update(outputs, labels)

                # Update tqdm progress bar with validation metrics
                progress_bar.set_postfix(
                    val_loss=avg_val_loss,
                    val_acc=accuracy.compute().item(),
                    val_auroc=auroc.compute().item(),
                    val_auprc=auprc.compute().item()
                )

        # Free GPU memory after each validation epoch
        torch.cuda.empty_cache()

        # End of epoch logging (including learning rate)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, LR: {current_lr}, Training Loss: {avg_loss}, Validation Loss: {avg_val_loss}")
        print(f"Validation Accuracy: {accuracy.compute().item()}")
        print(f"Validation AUROC: {auroc.compute().item()}")
        print(f"Validation AUPRC: {auprc.compute().item()}")

        # Check if the validation loss improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            best_model_state = model.state_dict()  # Save the best model's state
        else:
            patience_counter += 1

        # Early stopping check
        if patience_counter >= patience:
            print(f"Early stopping after {epoch+1} epochs due to no improvement in validation loss.")
            break

        # Scheduler step with warm restarts
        scheduler.step()

        # Switch back to training mode
        model.train()

    # Load the best model state (if needed later)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Free GPU memory at the end of training
    torch.cuda.empty_cache()

    return model

# Run Fine tuning

In [21]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, sl_model.parameters()), lr=1e-3)
criterion = BinaryFocalLoss(gamma=2)
scheduler = CosineAnnealingLR(optimizer, T_max=15, eta_min=0)

e2e_sl_model_afib_2y = train_and_evaluate_model(
    sl_model, 
    dataloader, 
    val_dataloader, 
    optimizer, 
    criterion, 
    scheduler, 
    device, 
    num_epochs=50, 
    patience=5
)


torch.save(e2e_sl_model_afib_2y, "e2e_sl_model_afib_2y.h5")

                                                                                                                                       

Epoch 1/50, LR: 0.001, Training Loss: 0.10344375111162663, Validation Loss: 0.11459421087056398
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5074581503868103
Validation AUPRC: 0.09703006595373154


                                                                                                                                         

Epoch 2/50, LR: 0.0009890738003669028, Training Loss: 0.0937123941257596, Validation Loss: 0.09033680986613035
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.49239927530288696
Validation AUPRC: 0.09474300593137741


                                                                                                                                         

Epoch 3/50, LR: 0.0009567727288213003, Training Loss: 0.08461196254938841, Validation Loss: 0.07981687225401402
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.4524739384651184
Validation AUPRC: 0.08631341904401779


                                                                                                                                        

Epoch 4/50, LR: 0.0009045084971874737, Training Loss: 0.08005208056420088, Validation Loss: 0.07234698999673128
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5254093408584595
Validation AUPRC: 0.10624224692583084


                                                                                                                                         

Epoch 5/50, LR: 0.0008345653031794292, Training Loss: 0.07999465055763721, Validation Loss: 0.06326500046998262
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.49052098393440247
Validation AUPRC: 0.09969422221183777


                                                                                                                                        

Epoch 6/50, LR: 0.00075, Training Loss: 0.07618012093007565, Validation Loss: 0.06475048419088125
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5584064722061157
Validation AUPRC: 0.11620519310235977


                                                                                                                                        

Epoch 7/50, LR: 0.0006545084971874737, Training Loss: 0.07415025494992733, Validation Loss: 0.0537027376703918
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5676276683807373
Validation AUPRC: 0.11436361074447632


                                                                                                                                         

Epoch 8/50, LR: 0.0005522642316338268, Training Loss: 0.07704869098961353, Validation Loss: 0.05789340892806649
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.4764443635940552
Validation AUPRC: 0.09478125721216202


                                                                                                                                         

Epoch 9/50, LR: 0.0004477357683661734, Training Loss: 0.07466145697981119, Validation Loss: 0.05669789295643568
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.4423295259475708
Validation AUPRC: 0.08700594305992126


                                                                                                                                          

Epoch 10/50, LR: 0.0003454915028125264, Training Loss: 0.07765408605337143, Validation Loss: 0.051485198084264994
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5043873190879822
Validation AUPRC: 0.1074489951133728


                                                                                                                                         

Epoch 11/50, LR: 0.0002500000000000001, Training Loss: 0.07007570285350084, Validation Loss: 0.05107170855626464
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5336669683456421
Validation AUPRC: 0.11195982992649078


                                                                                                                                         

Epoch 12/50, LR: 0.00016543469682057103, Training Loss: 0.07153486087918282, Validation Loss: 0.0515222093090415
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5469081401824951
Validation AUPRC: 0.11658299714326859


                                                                                                                                         

Epoch 13/50, LR: 9.549150281252631e-05, Training Loss: 0.07045465614646673, Validation Loss: 0.05142587190493941
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5388046503067017
Validation AUPRC: 0.11704397946596146


                                                                                                                                         

Epoch 14/50, LR: 4.322727117869951e-05, Training Loss: 0.06976697873324156, Validation Loss: 0.0508447727188468
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.540520191192627
Validation AUPRC: 0.1183457002043724


                                                                                                                                         

Epoch 15/50, LR: 1.0926199633097156e-05, Training Loss: 0.0693838307633996, Validation Loss: 0.05059600621461868
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5434988141059875
Validation AUPRC: 0.11919495463371277


                                                                                                                                         

Epoch 16/50, LR: 0.0, Training Loss: 0.06935949623584747, Validation Loss: 0.050081104040145874
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.544285237789154
Validation AUPRC: 0.1192476898431778


                                                                                                                                         

Epoch 17/50, LR: 1.0926199633097156e-05, Training Loss: 0.0691316407173872, Validation Loss: 0.04996328288689256
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5506328344345093
Validation AUPRC: 0.12080682814121246


                                                                                                                                         

Epoch 18/50, LR: 4.322727117869957e-05, Training Loss: 0.06998142600059509, Validation Loss: 0.04976402409374714
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5452473759651184
Validation AUPRC: 0.11964261531829834


                                                                                                                                         

Epoch 19/50, LR: 9.549150281252622e-05, Training Loss: 0.06947323121130466, Validation Loss: 0.049510282929986715
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5461807250976562
Validation AUPRC: 0.11992952227592468


                                                                                                                                         

Epoch 20/50, LR: 0.00016543469682057078, Training Loss: 0.06908863130956888, Validation Loss: 0.04859771579504013
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5422730445861816
Validation AUPRC: 0.1182282343506813


                                                                                                                                         

Epoch 21/50, LR: 0.0002499999999999998, Training Loss: 0.06975671742111444, Validation Loss: 0.04988825926557183
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5468634366989136
Validation AUPRC: 0.1177796721458435


                                                                                                                                         

Epoch 22/50, LR: 0.0003454915028125263, Training Loss: 0.0697169853374362, Validation Loss: 0.049544093664735556
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.547191858291626
Validation AUPRC: 0.11371799558401108


                                                                                                                                         

Epoch 23/50, LR: 0.000447735768366173, Training Loss: 0.07251705881208181, Validation Loss: 0.051039488054811954
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5478098392486572
Validation AUPRC: 0.12066294997930527


                                                                                                                                         

Epoch 24/50, LR: 0.0005522642316338266, Training Loss: 0.07043930422514677, Validation Loss: 0.04989225836470723
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.5491824746131897
Validation AUPRC: 0.122047558426857


                                                                                                                                          

Epoch 25/50, LR: 0.0006545084971874736, Training Loss: 0.07332280557602644, Validation Loss: 0.055592084769159555
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.49423858523368835
Validation AUPRC: 0.10547289252281189
Early stopping after 25 epochs due to no improvement in validation loss.
