In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

print(os.cpu_count())

import gc
import re
import math
from glob import glob
import wandb

wandb.require("core")

import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import h5py
from PIL import Image
from io import BytesIO

import torch
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import WeightedRandomSampler
import torch.optim as optim
from torch import nn
from torchvision import models

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torcheval.metrics.functional import binary_auroc

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from colorama import Fore, Style

b_ = Fore.BLUE
sr_ = Style.RESET_ALL

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device")

30


INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.11). Upgrade using: pip install --upgrade albumentations


Using cuda device


In [3]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_float32_matmul_precision("highest")


# Set the random seed
set_seed(42)

In [4]:
train_metadata_df = pd.read_csv("../data/train-metadata.csv")
print(f"Train: {len(train_metadata_df)}")

train_cols = [
    'isic_id', 'patient_id', 'target', 'age_approx', 'sex',
    'anatom_site_general', 'clin_size_long_diam_mm',
    'tbp_tile_type', 'tbp_lv_A', 'tbp_lv_Aext', 'tbp_lv_B', 'tbp_lv_Bext',
    'tbp_lv_C', 'tbp_lv_Cext', 'tbp_lv_H', 'tbp_lv_Hext', 'tbp_lv_L',
    'tbp_lv_Lext', 'tbp_lv_areaMM2', 'tbp_lv_area_perim_ratio',
    'tbp_lv_color_std_mean', 'tbp_lv_deltaA', 'tbp_lv_deltaB',
    'tbp_lv_deltaL', 'tbp_lv_deltaLB', 'tbp_lv_deltaLBnorm',
    'tbp_lv_eccentricity', 'tbp_lv_location', 'tbp_lv_location_simple',
    'tbp_lv_minorAxisMM', 'tbp_lv_nevi_confidence', 'tbp_lv_norm_border',
    'tbp_lv_norm_color', 'tbp_lv_perimeterMM',
    'tbp_lv_radial_color_std_max', 'tbp_lv_stdL', 'tbp_lv_stdLExt',
    'tbp_lv_symm_2axis', 'tbp_lv_symm_2axis_angle', 'tbp_lv_x', 'tbp_lv_y',
    'tbp_lv_z'
]

train_metadata_df = train_metadata_df[train_cols].dropna().reset_index(drop=True) # dropping nan doesn't drop any pos sample

N_SPLITS = 4
gkf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
train_metadata_df["fold"] = -1
for idx, (train_idx, val_idx) in enumerate(
    gkf.split(
        train_metadata_df,
        train_metadata_df["target"],
        groups=train_metadata_df["patient_id"],
    )
):
    train_metadata_df.loc[val_idx, "fold"] = idx

train_metadata_df = train_metadata_df.drop(columns=["patient_id"])

# Scale
scaler = MinMaxScaler()

features = train_metadata_df.drop(columns=['target', 'isic_id', 'fold'])
numerical_cols = features.select_dtypes(include=['float64', 'int64']).columns

train_metadata_df[numerical_cols] = scaler.fit_transform(train_metadata_df[numerical_cols])

# Categorical
female_male = pd.get_dummies(train_metadata_df["sex"]).astype(float)
train_metadata_df = train_metadata_df.drop(columns=["sex"])
train_metadata_df = train_metadata_df.join(female_male)

anatom_site_general = pd.get_dummies(train_metadata_df["anatom_site_general"]).astype(float)
train_metadata_df = train_metadata_df.drop(columns=["anatom_site_general"])
train_metadata_df = train_metadata_df.join(anatom_site_general)

train_metadata_df = train_metadata_df.drop(columns=["tbp_tile_type", "tbp_lv_location", "tbp_lv_location_simple"]) # ignoring two cat columns for now.

  train_metadata_df = pd.read_csv("../data/train-metadata.csv")


Train: 401059


In [5]:
# dataset
class SkinDataset(Dataset):
    def __init__(self, df: pd.DataFrame, file_hdf: str, transform=None):
        assert "isic_id" in df.columns
        assert "target" in df.columns

        # add features
        feature_cols = df.select_dtypes(include=['float64']).columns
        self.features = df[feature_cols].values.astype('float32')

        self.fp_hdf = h5py.File(file_hdf, mode="r")
        self.isic_ids = df['isic_id'].values
        self.labels = df.target.tolist()
        self.transform = transform

    def __len__(self):
        return len(self.isic_ids)

    def __getitem__(self, idx: int):
        isic_id = self.isic_ids[idx]
        image = np.array(Image.open(BytesIO(self.fp_hdf[isic_id][()])))
        label = self.labels[idx] / 1.0
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, self.features[idx], label

    def get_class_samples(self, class_label):
        indices = [i for i, label in enumerate(self.labels) if label == class_label]
        return indices

In [6]:
transforms_train = A.Compose(
    [
        A.Resize(124, 124),
        A.CenterCrop(
            height=124,
            width=124,
            p=1.0,
        ),
        A.CLAHE(
            clip_limit=4, tile_grid_size=(10, 10), p=0.5
        ),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=60, p=0.6),
        A.HueSaturationValue(
            hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5
        ),
        A.RandomRotate90(p=0.6),
        A.Flip(p=0.7),
        A.Normalize(),
        ToTensorV2(),
    ]
)

transforms_valid = A.Compose(
    [
        A.Resize(124, 124),
        A.Normalize(),
        ToTensorV2(),
    ]
)

  Expected `Union[float, json-or-python[json=list[float], python=list[float]]]` but got `tuple` - serialized value may not be as expected
  Expected `Union[float, json-or-python[json=list[float], python=list[float]]]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


In [7]:
def get_dataloaders_and_stats(fold):
    # dataloaders
    train_df = train_metadata_df.loc[train_metadata_df.fold != fold]
    valid_df = train_metadata_df.loc[train_metadata_df.fold == fold]

    num_workers = 24  # based on profiling

    file_hdf = "/home/ubuntu/ayusht/skin/data/train-image.hdf5"
    train_dataset = SkinDataset(train_df, file_hdf, transform=transforms_train)
    valid_dataset = SkinDataset(valid_df, file_hdf, transform=transforms_valid)
    dataset_sizes = {"train": len(train_dataset), "val": len(valid_dataset)}
    print(dataset_sizes)

    # calculate bias value
    neg_samples = len(train_dataset.get_class_samples(0))
    pos_samples = len(train_dataset.get_class_samples(1))
    p_positive = pos_samples / (neg_samples + pos_samples)
    bias_value = math.log(p_positive / (1 - p_positive))
    print(f"Calculated bias value: {bias_value}")

    # calculate class weight
    pos_weight = torch.ones([1]) * (neg_samples / pos_samples)
    pos_weight = pos_weight.to(device)
    print(f"Calculated pos_weight: {pos_weight}")

    # calculate weight for each class for random sampler
    neg_wts = 1 / neg_samples
    pos_wts = 1 / pos_samples
    sample_wts = []

    for label in train_dataset.labels:
        if label == 0:
            sample_wts.append(neg_wts)
        else:
            sample_wts.append(pos_wts)

    sampler = WeightedRandomSampler(
        weights=sample_wts, num_samples=int(len(train_dataset) / 3), replacement=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=128,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        drop_last=True,
    )

    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=128,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
    )

    return (
        train_dataloader,
        valid_dataloader,
        dataset_sizes,
        bias_value,
        pos_weight,
    )

In [8]:
# model
class SkinClassifier(nn.Module):
    def __init__(self, model_name="resnet18", freeze_backbone=False, num_tabular_features=41, bias_value=None):
        super(SkinClassifier, self).__init__()

        if model_name == "efficientnet_v2_s":
            self.backbone = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
            if freeze_backbone:
                self.freeze_backbone()
            num_ftrs = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Model {model_name} not supported")

        self.mlp = nn.Sequential(
            nn.Linear(num_tabular_features, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(num_ftrs + 256, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 1)
        )

        if bias_value is not None:
            nn.init.constant_(self.classifier[-1].bias, bias_value)

    def forward(self, image, tabular_features):
        # Forward pass through the CNN backbone
        cnn_out = self.backbone(image)

        # Forward pass through the MLP for tabular data
        mlp_out = self.mlp(tabular_features)
        
        # Concatenate CNN and MLP outputs
        combined_out = torch.cat((cnn_out, mlp_out), dim=1)
        
        # Pass through the final classifier
        return self.classifier(combined_out)

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

        for param in self.backbone.features[6].parameters():
            param.requires_grad = True

        for param in self.backbone.features[7].parameters():
            param.requires_grad = True

    def count_parameters(self):
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        non_trainable_params = sum(
            p.numel() for p in self.parameters() if not p.requires_grad
        )
        return trainable_params, non_trainable_params

In [9]:
# training and validation utils
def train_model(
    model,
    dataloader,
    criterion,
    optimizer,
    train_step,
    dataset_sizes,
    scheduler=None,
    debug=False,
):
    model.train()

    running_loss = 0.0
    preds = []
    gts = []

    for idx, (inputs, feats, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        feats = feats.to(device)
        labels = labels.to(device).flatten().to(torch.float32)

        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs = model(inputs, feats).flatten()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        loss_val = loss.detach()
        running_loss += loss_val * inputs.size(0)

        preds.extend(outputs)
        gts.extend(labels)

        if (idx + 1) % 10 == 0:
            train_step += 1
            wandb.log(
                {
                    "train_loss": loss_val,
                    "train_step": train_step,
                    "lr": optimizer.param_groups[0]["lr"],
                }
            )
            print(f"Train Batch Loss: {loss_val}")

        if scheduler:
            scheduler.step()

        if debug:
            break

    epoch_loss = running_loss / dataset_sizes["train"]
    epoch_auroc = binary_auroc(
        input=torch.tensor(preds).to(device), target=torch.tensor(gts).to(device)
    ).item()

    return model, epoch_loss, epoch_auroc


def validate_model(
    model, dataloader, criterion, optimizer, valid_step, dataset_sizes, debug=False
):
    model.eval()

    running_loss = 0.0
    preds = []
    gts = []

    for idx, (inputs, feats, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        feats = feats.to(device)
        labels = labels.to(device).flatten().to(torch.float32)

        optimizer.zero_grad()

        with torch.no_grad():
            outputs = model(inputs, feats).flatten()
            loss = criterion(outputs, labels)

        loss_val = loss.detach()
        running_loss += loss_val * inputs.size(0)

        preds.extend(outputs)
        gts.extend(labels)

        if (idx + 1) % 10 == 0:
            valid_step += 1
            wandb.log({"valid_loss": loss_val, "valid_step": valid_step})
            print(f"valid Batch Loss: {loss_val}")

        if debug:
            break

    valid_loss = running_loss / dataset_sizes["val"]
    valid_auroc = binary_auroc(
        input=torch.tensor(preds).to(device), target=torch.tensor(gts).to(device)
    ).item()

    return model, valid_loss, valid_auroc

In [125]:
def train_and_validate_folds(fold):
    # Initialize wandb
    run = wandb.init(
        project="isic_lesions_24", job_type="4_fold_nn", name=f"fold_{fold}"
    )
    wandb.define_metric("train_step")
    wandb.define_metric("valid_step")

    model_name = "efficientnet_v2_s"
    debug = False
    epochs = 15 if not debug else 1

    # Get data and stats
    train_dataloader, valid_dataloader, dataset_sizes, bias_value, pos_weight = (
        get_dataloaders_and_stats(fold)
    )

    # Create the model
    model = SkinClassifier(
        model_name=model_name, freeze_backbone=True, bias_value=bias_value
    )
    model = model.to(device)
    model = torch.compile(model)
    print(model)

    trainable_params, non_trainable_params = model.count_parameters()
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

    # Loss fn and optimizer
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=0.001,
        weight_decay=1e-5,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=781 * 2, T_mult=2, eta_min=1e-6, last_epoch=-1
    )

    # train loop
    train_step = 0
    valid_step = 0
    best_epoch_auroc = -np.inf
    best_valid_loss = np.inf
    early_stopping_patience = 4
    epochs_no_improve = 0

    for epoch in range(
        epochs
    ):  # reducing epoch to 15 because quick overfitting after correct init
        model, epoch_loss, epoch_train_auroc = train_model(
            model,
            train_dataloader,
            criterion,
            optimizer,
            train_step,
            dataset_sizes,
            scheduler,
            debug=debug,
        )

        model, valid_loss, epoch_valid_auroc = validate_model(
            model,
            valid_dataloader,
            criterion,
            optimizer,
            valid_step,
            dataset_sizes,
            debug=debug,
        )

        wandb.log(
            {
                "epoch": epoch,
                "epoch_loss": epoch_loss,
                "epoch_val_loss": valid_loss,
                "epoch_train_auroc": epoch_train_auroc,
                "epoch_valid_auroc": epoch_valid_auroc,
            }
        )

        print(f"Epoch: {epoch} | Train Loss: {epoch_loss} | Valid Loss: {valid_loss}\n")
        print(
            f"Epoch: {epoch} | Train AUROC: {epoch_train_auroc} | Valid AUROC: {epoch_valid_auroc}\n"
        )

        # earlystopping dependent on validation loss
        if best_valid_loss >= valid_loss:
            print(
                f"{b_}Validation Loss Improved ({best_valid_loss} ---> {valid_loss}){sr_}"
            )

            # checkpointing
            PATH = f"../models/{model_name}_{run.id}_valid_loss{valid_loss}_epoch{epoch}_fold{fold}.bin"
            torch.save(model.state_dict(), PATH)
            print(f"{b_}Model Saved{sr_}")
            best_valid_loss = valid_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

    # end training for this fold
    # offload to cpu
    model = model.to("cpu")
    del model
    del train_dataloader
    del valid_dataloader
    gc.collect()

    # finish wandb run
    run.finish()

In [126]:
train_and_validate_folds(0)

VBox(children=(Label(value='0.054 MB of 0.054 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

{'train': 290148, 'val': 91766}
Calculated bias value: -6.883390138019604
Calculated pos_weight: tensor([975.9293], device='cuda:0')
OptimizedModule(
  (_orig_mod): SkinClassifier(
    (backbone): EfficientNet(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
        )
        (1): Sequential(
          (0): FusedMBConv(
            (block): Sequential(
              (0): Conv2dNormActivation(
                (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
                (2): SiLU(inplace=True)
              )
            )
            (stochastic_depth): StochasticDepth(p=0.0, mode=row)
          )
          (1): Fuse



Train Batch Loss: 314.7380065917969
Train Batch Loss: 11.717415809631348
Train Batch Loss: 18.08574867248535
Train Batch Loss: 29.576379776000977
Train Batch Loss: 26.49083709716797
Train Batch Loss: 19.69489288330078
Train Batch Loss: 24.176483154296875
Train Batch Loss: 25.03244400024414
Train Batch Loss: 22.138965606689453
Train Batch Loss: 16.58848762512207
Train Batch Loss: 17.642223358154297
Train Batch Loss: 23.06472396850586
Train Batch Loss: 17.123607635498047
Train Batch Loss: 13.009292602539062
Train Batch Loss: 15.282464981079102
Train Batch Loss: 14.833134651184082
Train Batch Loss: 12.460323333740234
Train Batch Loss: 12.060583114624023
Train Batch Loss: 10.933899879455566
Train Batch Loss: 10.405517578125
Train Batch Loss: 7.901212215423584
Train Batch Loss: 7.223062038421631
Train Batch Loss: 8.098145484924316
Train Batch Loss: 6.41211462020874
Train Batch Loss: 6.555633068084717
Train Batch Loss: 5.717312812805176
Train Batch Loss: 7.088094234466553
Train Batch Loss: 5

KeyboardInterrupt: 