In [49]:
# Fine-tuning the SL model for AFIB at 2y on the PTB-XL lite dataset
import os
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import (
    BinaryAccuracy, 
    BinaryAUROC, 
    BinaryAveragePrecision
)
from torch.utils.data import Dataset
from huggingface_hub import snapshot_download

# Define device: prefer the first GPU, but fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Define hugging face token
# Read it from the HF_TOKEN environment variable instead of pasting a secret
# into the notebook. When the variable is unset or empty the token is None,
# which is the default value of snapshot_download's `token` argument.
hugging_face_token = os.environ.get("HF_TOKEN") or None


# Set seed for reproducibility


In [50]:
def seed_everything(seed: int):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global generator and PyTorch's CPU and
    CUDA generators, then sets cuDNN to deterministic mode with benchmarking
    disabled.

    Parameters:
    seed (int): The value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if using multi-GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# Load Model

In [51]:
# Fetch the model repository from Hugging Face into the current directory
local_dir = snapshot_download(
    "heartwise/EfficientNetV2_SL_Model_Tunable",
    repo_type="model",
    local_dir=".",
    token=hugging_face_token,
)

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 33961.98it/s]


In [52]:
from EfficientNetv2 import *

In [53]:
# Load model from local directory
# The file is used directly as a model object below, i.e. it is a fully
# pickled nn.Module rather than a state_dict, so:
#   - weights_only=False is passed explicitly; from PyTorch 2.6 onwards
#     torch.load defaults to weights_only=True, which refuses such files.
#   - map_location="cpu" lets the checkpoint load on machines that lack the
#     device it was saved from; the model is moved to `device` in a later cell.
# NOTE(security): unpickling can execute arbitrary code - only load
# checkpoints obtained from a source you trust.
sl_model = torch.load("sl_model.h5", map_location="cpu", weights_only=False)

# Freeze every pretrained parameter
for param in sl_model.parameters():
    param.requires_grad = False

# Set up new classifier

In [54]:
class NewClassifier(nn.Module):
    """
    Classification head: average-pool over the last axis, flatten, then apply
    a single linear layer.

    Parameters:
    - out_dim: Number of output units of the linear layer.
    - in_features: Number of channels entering the head (size of axis 1 of the
      input). Defaults to 640, the value previously hard-coded here.
    """

    def __init__(self, out_dim, in_features=640):
        super(NewClassifier, self).__init__()
        # Attribute names are kept as-is so existing state_dict keys still match.
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        self.fc1 = nn.Linear(in_features, out_dim)

    def forward(self, x):
        """
        Map a (batch, in_features, length) tensor to (batch, out_dim).
        """
        x = self.pool(x)      # (batch, in_features, 1)
        x = self.flatten(x)   # (batch, in_features)
        x = self.fc1(x)       # (batch, out_dim)
        return x

In [55]:
# Build the replacement head, mark its parameters as trainable, then attach it
new_head = NewClassifier(out_dim=1)
for head_param in new_head.parameters():
    head_param.requires_grad = True
sl_model.classifier = new_head

In [56]:
# Move the model (including the newly attached classifier) to `device`
sl_model = sl_model.to(device)

# Prepare Dataset

In [57]:
# Fetch the PTB-XL lite dataset repository from Hugging Face into ./ptb-xl_lite
dataset_dir = snapshot_download(
    "heartwise/PTB-XL_lite",
    repo_type="dataset",
    local_dir="ptb-xl_lite",
    token=hugging_face_token,
)

Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 56807.73it/s]




In [58]:
# Locate the train / validation arrays inside the downloaded dataset folder
train_X_path = os.path.join(dataset_dir, "ptb_xl_afib_2y_train_subset.npy")
val_X_path = os.path.join(dataset_dir, "ptb_xl_afib_2y_val_subset.npy")

# Load them and drop every length-1 axis
train_ds_X = np.squeeze(np.load(train_X_path))
val_ds_X = np.squeeze(np.load(val_X_path))

In [59]:
# Locate and read the label tables that accompany the train / validation arrays
train_labels_path = os.path.join(dataset_dir, "ptb_xl_afib_2y_train_subset_labels.parquet")
val_labels_path = os.path.join(dataset_dir, "ptb_xl_afib_2y_val_subset_labels.parquet")

train_df = pd.read_parquet(train_labels_path)
val_df = pd.read_parquet(val_labels_path)

# Pull the 'label_2y' column out as a plain list of Python ints
train_ds_Y = train_df['label_2y'].astype(int).tolist()
val_ds_Y = val_df['label_2y'].astype(int).tolist()

# Define DataLoader

In [60]:
# DataLoader settings common to the training and validation loaders
_shared_loader_params = {
    'batch_size': 254,
    'num_workers': 12,
    'pin_memory': True,
    'multiprocessing_context': 'fork',
    'persistent_workers': True,
}

# The two configurations differ only in whether samples are shuffled
train_params = {**_shared_loader_params, 'shuffle': True}
val_params = {**_shared_loader_params, 'shuffle': False}

class CustomDataset(Dataset):
    """
    Dataset pairing each sample in `data` with the entry of `labels` at the
    same index, served as (float32 tensor, long tensor).
    """

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        # The number of samples is taken from `data` alone
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.labels[idx]

        # Exchange the first two axes of the sample before converting it
        sample = np.swapaxes(sample, 0, 1)

        features = torch.tensor(sample, dtype=torch.float32)
        label = torch.tensor(target, dtype=torch.long)
        return features, label
    
# Wrap the arrays and label lists in datasets
train_set = CustomDataset(train_ds_X, train_ds_Y)
val_set = CustomDataset(val_ds_X, val_ds_Y)

# Build one loader per split from the parameter dicts defined above
dataloader = torch.utils.data.DataLoader(dataset=train_set, **train_params)
val_dataloader = torch.utils.data.DataLoader(dataset=val_set, **val_params)

# Use Binary Focal Loss for imbalanced data

In [61]:
class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Initialize the Binary Focal Loss function.

        Parameters:
        - alpha: Weight applied to samples whose target is 1; samples whose target is 0
          are weighted by (1 - alpha). Default is 0.25.
        - gamma: Focusing parameter that adjusts the rate at which easy examples are down-weighted, default is 2.0.
        - reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default is 'mean'.
          Any value other than 'mean' or 'sum' returns the unreduced per-element loss.
        """
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Forward pass of the loss function.

        Parameters:
        - inputs: Predicted logits (before applying sigmoid).
        - targets: Ground truth labels (0 or 1), same shape as `inputs`.

        Returns:
        - loss: Calculated focal loss (a scalar for 'mean' / 'sum', otherwise
          a tensor shaped like `inputs`).
        """
        # Convert targets to float
        targets = targets.float()

        # Probabilities are only needed for the modulating factor
        probs = torch.sigmoid(inputs)

        # Compute the cross-entropy directly from the logits. Taking
        # binary_cross_entropy of an already-sigmoided value loses the gradient
        # once the sigmoid saturates (a confidently wrong prediction would
        # receive a zero gradient); the logits form stays well-behaved.
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Calculate the modulating factor (1 - p_t)^gamma, where p_t is the
        # probability assigned to the true class
        p_t = probs * targets + (1 - probs) * (1 - targets)
        modulating_factor = torch.pow(1 - p_t, self.gamma)

        # Apply the alpha factor: alpha for targets of 1, (1 - alpha) for targets of 0
        alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        # Combine factors to compute the final focal loss
        focal_loss = alpha_factor * modulating_factor * bce_loss

        # Apply reduction method
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Define Metrics and Training/Validation Loop

In [62]:

def train_and_evaluate_model(
    model,
    dataloader,
    val_dataloader,
    optimizer,
    criterion,
    scheduler,
    device,
    num_epochs=10,
    patience=3
    ):
    """
    Train `model` with early stopping on the validation loss and return it
    with the best-scoring weights restored.

    Every epoch runs one optimisation pass over `dataloader` followed by one
    gradient-free pass over `val_dataloader`. Loss, accuracy, AUROC and AUPRC
    are shown on the progress bars, and the validation figures are printed at
    the end of the epoch.

    Parameters:
    - model: Network producing one logit per sample; it is moved to `device`.
    - dataloader: Iterable of (inputs, labels) training batches.
    - val_dataloader: Iterable of (inputs, labels) validation batches.
    - optimizer: Optimizer stepped once per training batch.
    - criterion: Callable `criterion(logits, labels)` returning a scalar loss.
    - scheduler: Learning-rate scheduler stepped once per completed epoch.
    - device: Device on which batches, metrics and the model are placed.
    - num_epochs: Maximum number of epochs. Default is 10.
    - patience: Number of consecutive epochs without a lower validation loss
      after which training stops. Default is 3.

    Returns:
    - model: The same model object, loaded with the weights of the epoch that
      had the lowest validation loss (if any epoch ran). It is left in eval
      mode when early stopping triggered and in train mode otherwise.
    """
    import copy  # local import: only needed to snapshot the best weights

    # Initialize metrics
    accuracy = BinaryAccuracy().to(device)
    auroc = BinaryAUROC().to(device)
    auprc = BinaryAveragePrecision().to(device)
    model = model.to(device)

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    # NOTE(review): train() switches every submodule to training mode,
    # including layers whose parameters are frozen - confirm this is the
    # intended behaviour for the frozen part of the network.
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        # Defined up front so the end-of-epoch logging cannot hit an unbound
        # name when a loader yields no batches.
        avg_loss = float('nan')
        avg_val_loss = float('nan')

        # Reset metrics at the start of each epoch
        accuracy.reset()
        auroc.reset()
        auprc.reset()

        # Training loop
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch {epoch+1}/{num_epochs} (Training)', leave=False)
        for i, (inputs, labels) in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # Match the labels' shape explicitly. A bare squeeze() would also
            # drop the batch axis when a batch holds a single sample.
            outputs = outputs.reshape(labels.shape)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # Update metrics with explicit probabilities so every batch is
            # interpreted the same way, instead of relying on the metrics
            # deciding per update whether the values look like logits.
            probs = torch.sigmoid(outputs.detach())
            accuracy.update(probs, labels)
            auroc.update(probs, labels)
            auprc.update(probs, labels)

            # Update tqdm progress bar with aggregate metrics
            progress_bar.set_postfix(
                loss=avg_loss,
                acc=accuracy.compute().item(),
                auroc=auroc.compute().item(),
                auprc=auprc.compute().item()
            )

        # Free GPU memory after each training epoch
        torch.cuda.empty_cache()

        # Validation loop
        model.eval()
        val_running_loss = 0.0

        accuracy.reset()
        auroc.reset()
        auprc.reset()

        with torch.no_grad():
            progress_bar = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f'Epoch {epoch+1}/{num_epochs} (Validation)', leave=False)
            for i, (inputs, labels) in progress_bar:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                # Same shape handling as in the training loop
                outputs = outputs.reshape(labels.shape)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item()
                avg_val_loss = val_running_loss / (i + 1)

                # Update metrics (explicit probabilities, as in training)
                probs = torch.sigmoid(outputs)
                accuracy.update(probs, labels)
                auroc.update(probs, labels)
                auprc.update(probs, labels)

                # Update tqdm progress bar with validation metrics
                progress_bar.set_postfix(
                    val_loss=avg_val_loss,
                    val_acc=accuracy.compute().item(),
                    val_auroc=auroc.compute().item(),
                    val_auprc=auprc.compute().item()
                )

        # Free GPU memory after each validation epoch
        torch.cuda.empty_cache()

        # End of epoch logging (including learning rate)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, LR: {current_lr}, Training Loss: {avg_loss}, Validation Loss: {avg_val_loss}")
        print(f"Validation Accuracy: {accuracy.compute().item()}")
        print(f"Validation AUROC: {auroc.compute().item()}")
        print(f"Validation AUPRC: {auprc.compute().item()}")

        # Check if the validation loss improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # state_dict() hands back references to the live tensors, which
            # later optimizer steps keep modifying; deep-copy so the snapshot
            # really is the best epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        # Early stopping check
        if patience_counter >= patience:
            print(f"Early stopping after {epoch+1} epochs due to no improvement in validation loss.")
            break

        # Advance the learning-rate schedule once per completed epoch
        scheduler.step()

        # Switch back to training mode
        model.train()

    # Restore the weights of the best epoch
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Free GPU memory at the end of training
    torch.cuda.empty_cache()

    return model

# Run fine-tuning

In [63]:
# Hand the optimizer only the parameters that are still trainable
trainable_params = [p for p in sl_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-2)
criterion = BinaryFocalLoss(gamma=2)
scheduler = CosineAnnealingLR(optimizer, T_max=15, eta_min=0)

# Fine-tune for at most 50 epochs, stopping after 5 epochs without improvement
fine_tuned_sl_model_afib_2y = train_and_evaluate_model(
    model=sl_model,
    dataloader=dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    num_epochs=50,
    patience=5,
)

# Persist the whole fine-tuned model object
torch.save(fine_tuned_sl_model_afib_2y, "fine_tuned_sl_model_afib_2y.h5")

                                                                                                                                        

Epoch 1/50, LR: 0.01, Training Loss: 0.07342986296862364, Validation Loss: 0.042087603360414505
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.6281976103782654
Validation AUPRC: 0.16620472073554993


                                                                                                                                        

Epoch 2/50, LR: 0.009890738003669028, Training Loss: 0.06786379870027304, Validation Loss: 0.04234182741492987
Validation Accuracy: 0.9039999842643738
Validation AUROC: 0.6419674158096313
Validation AUPRC: 0.16286659240722656


                                                                                                                                        

Epoch 3/50, LR: 0.009567727288213004, Training Loss: 0.06700187921524048, Validation Loss: 0.04757412290200591
Validation Accuracy: 0.9024999737739563
Validation AUROC: 0.6488580703735352
Validation AUPRC: 0.16419777274131775


                                                                                                                                        

Epoch 4/50, LR: 0.009045084971874737, Training Loss: 0.0658906614407897, Validation Loss: 0.043100038543343544
Validation Accuracy: 0.9024999737739563
Validation AUROC: 0.6444391012191772
Validation AUPRC: 0.16409432888031006


                                                                                                                                        

Epoch 5/50, LR: 0.008345653031794291, Training Loss: 0.06562896724790335, Validation Loss: 0.046420552767813206
Validation Accuracy: 0.9010000228881836
Validation AUROC: 0.6401928663253784
Validation AUPRC: 0.15866629779338837


                                                                                                                                        

Epoch 6/50, LR: 0.0075, Training Loss: 0.06572936661541462, Validation Loss: 0.044125021900981665
Validation Accuracy: 0.9024999737739563
Validation AUROC: 0.63759446144104
Validation AUPRC: 0.16052564978599548
Early stopping after 6 epochs due to no improvement in validation loss.


In [None]:
#