In [1]:
## Imports

import os

print(os.cpu_count())

import copy
import wandb

wandb.require("core")

import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import torch.optim as optim
from torch import nn
import torch.backends.cudnn as cudnn
from torcheval.metrics.functional import binary_auroc

from torchvision.models import (
    convnext_tiny,
    convnext_small,
)

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_curve, auc, roc_auc_score

import albumentations as A
from albumentations.pytorch import ToTensorV2

from colorama import Fore, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device")

cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

30


Using cuda device


In [2]:
configs = dict(
    image_size = 224,
)

In [3]:
train_metadata_df = pd.read_csv("../data/stratified_5_fold_train_metadata.csv")

def add_path(row):
    return f"../data/train-image/image/{row.isic_id}.jpg"

train_metadata_df["path"] = train_metadata_df.apply(lambda row: add_path(row), axis=1)

  train_metadata_df = pd.read_csv("../data/stratified_5_fold_train_metadata.csv")


## Dataloader

In [4]:
class PretrainSkinDataset(Dataset):
    def __init__(self, df: pd.DataFrame, transform=None):
        assert "path" in df.columns
        assert "tbp_lv_symm_2axis" in df.columns
        # TODO: add more features

        self.paths = df.path.tolist()
        self.tbp_lv_symm_2axis = df.tbp_lv_symm_2axis.tolist() # continuous, float
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int):
        image = read_image(self.paths[idx]).to(torch.float32) / 255.0
        label = self.tbp_lv_symm_2axis[idx]
        if self.transform:
            image = image.numpy().transpose((1,2,0))
            image = self.transform(image=image)["image"]
        return image, label

In [5]:
train_df = train_metadata_df.loc[
    train_metadata_df.fold != 0
]
valid_df = train_metadata_df.loc[
    train_metadata_df.fold == 0
]

In [6]:
# # placeholders
# psum = torch.tensor([0.0, 0.0, 0.0])
# psum_sq = torch.tensor([0.0, 0.0, 0.0])

# num_workers = 24 # based on profiling

# simple_transform = A.Compose([
#     A.Resize(configs["image_size"], configs["image_size"]),
#     ToTensorV2(),
# ])

# _train_dataset = PretrainSkinDataset(train_df, transform=simple_transform)
# _train_dataloader = DataLoader(
#     _train_dataset, batch_size=128, shuffle=True,
#     num_workers=num_workers, pin_memory=True, persistent_workers=True
# )

# # loop through images
# for inputs, labels in tqdm(_train_dataloader):
#     psum += inputs.sum(axis=[0, 2, 3])
#     psum_sq += (inputs**2).sum(axis=[0, 2, 3])

# count = len(train_df) * configs["image_size"] * configs["image_size"]
# total_mean = psum / count
# total_var = (psum_sq / count) - (total_mean**2)
# total_std = torch.sqrt(total_var)
# print("mean: " + str(total_mean))
# print("std:  " + str(total_std))

In [7]:
transforms_train = A.Compose([
    A.Transpose(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7),

    A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=0.7),
    A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
    A.Resize(configs["image_size"], configs["image_size"]),
    A.Normalize(
        mean=(0.6962, 0.5209, 0.4193),
        std=(0.1395, 0.1320, 0.1240)
    ),
    ToTensorV2(),
])

transforms_val = A.Compose([
    A.Resize(configs["image_size"], configs["image_size"]),
    A.Normalize(
        mean=(0.6962, 0.5209, 0.4193),
        std=(0.1395, 0.1320, 0.1240)
    ),
    ToTensorV2(),
])

  __pydantic_self__.__pydantic_validator__.validate_python(data, self_instance=__pydantic_self__)
  Expected `Union[float, json-or-python[json=list[float], python=list[float]]]` but got `tuple` - serialized value may not be as expected
  Expected `Union[float, json-or-python[json=list[float], python=list[float]]]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


In [8]:
# # Calculate statistics
# mean_train = train_df['tbp_lv_symm_2axis'].mean()
# median_train = train_df['tbp_lv_symm_2axis'].median()
# mean_valid = valid_df['tbp_lv_symm_2axis'].mean()
# median_valid = valid_df['tbp_lv_symm_2axis'].median()

# # Plot the distributions
# plt.figure(figsize=(10, 6))
# sns.histplot(train_df['tbp_lv_symm_2axis'], color='blue', label='Train Set', kde=True, alpha=0.5)
# sns.histplot(valid_df['tbp_lv_symm_2axis'], color='green', label='Validation Set', kde=True, alpha=0.5)

# # Add mean and median lines and annotations
# plt.axvline(mean_train, color='blue', linestyle='dashed', linewidth=1)
# plt.axvline(median_train, color='blue', linestyle='solid', linewidth=1)
# plt.axvline(mean_valid, color='green', linestyle='dashed', linewidth=1)
# plt.axvline(median_valid, color='green', linestyle='solid', linewidth=1)

# plt.text(mean_train, plt.ylim()[1]*0.8, f'Mean: {mean_train:.2f}', color='blue', ha='center')
# plt.text(median_train, plt.ylim()[1]*0.7, f'Median: {median_train:.2f}', color='blue', ha='center')
# plt.text(mean_valid, plt.ylim()[1]*0.6, f'Mean: {mean_valid:.2f}', color='green', ha='center')
# plt.text(median_valid, plt.ylim()[1]*0.5, f'Median: {median_valid:.2f}', color='green', ha='center')

# # Add title and labels
# plt.title('Distribution of tbp_lv_symm_2axis in Train and Validation Sets')
# plt.xlabel('tbp_lv_symm_2axis')
# plt.ylabel('Density')
# plt.legend()

# # Show plot
# plt.show()

In [9]:
num_workers = 24 # based on profiling

train_dataset = PretrainSkinDataset(train_df, transform=transforms_train)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=num_workers, pin_memory=True, persistent_workers=True, drop_last=True)

valid_dataset = PretrainSkinDataset(valid_df, transform=transforms_val)
valid_dataloader = DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=num_workers, pin_memory=True, persistent_workers=True, drop_last=True)

dataset_sizes = {"train": len(train_dataset), "val": len(valid_dataset)}
dataset_sizes

{'train': 320848, 'val': 80211}

## Model

In [10]:
model_ft = convnext_tiny()
model_ft.classifier[2] = nn.Linear(model_ft.classifier[2].in_features, 1, bias=False)

model_ft = model_ft.to(device)
model_ft = torch.compile(model_ft)
print(model_ft)

OptimizedModule(
  (_orig_mod): ConvNeXt(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (1): Permute()
            (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (3): Linear(in_features=96, out_features=384, bias=True)
            (4): GELU(approximate='none')
            (5): Linear(in_features=384, out_features=96, bias=True)
            (6): Permute()
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): CNBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (1): Permute()
            (2): LayerNorm

## Utils

In [11]:
def train_model(model, dataloader, criterion, optimizer, scheduler=None):
    model.train()  # Set model to training mode

    running_loss = 0.0

    for idx, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device).flatten().to(torch.float32)

        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs = model(inputs).flatten()
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

        loss_val = loss.detach()
        running_loss += loss_val * inputs.size(0)

        if (idx + 1) % 10 == 0:
            wandb.log({"train_loss": loss_val})
            print(f"Train Batch Loss: {loss_val}")

    epoch_loss = running_loss / dataset_sizes["train"]

    return model, epoch_loss


def validate_model(model, dataloader, criterion, optimizer):
    model.eval()

    running_loss = 0.0

    for idx, (inputs, labels) in enumerate(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device).flatten().to(torch.float32)

        optimizer.zero_grad()

        with torch.no_grad():
            outputs = model(inputs).flatten()
            loss = criterion(outputs, labels)

        loss_val = loss.detach()
        running_loss += loss_val * inputs.size(0)

        if (idx + 1) % 10 == 0:
            wandb.log({"valid_loss": loss_val})
            print(f"valid Batch Loss: {loss_val}")

    valid_loss = running_loss / dataset_sizes["val"]

    return model, valid_loss

## Pretrain

In [12]:
criterion = nn.MSELoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001, weight_decay=1e-6)

In [13]:
run = wandb.init(project="isic_lesions_24", job_type="pretrain")

best_epoch_auroc = -np.inf
best_valid_loss = np.inf
early_stopping_patience = 4
epochs_no_improve = 0

for epoch in range(10):
    model_ft, epoch_loss = train_model(
        model_ft, train_dataloader, criterion, optimizer_ft
    )
    model_ft, valid_loss = validate_model(
        model_ft, valid_dataloader, criterion, optimizer_ft
    )

    print(
        f"Epoch: {epoch} | Train Loss: {epoch_loss} | Valid Loss: {valid_loss}\n"
    )
    wandb.log(
        {
            "epoch": epoch,
            "epoch_loss": epoch_loss,
            "epoch_val_loss": valid_loss,
        }
    )

    # earlystopping dependent on validation loss
    if best_valid_loss >= valid_loss:
        print(f"{b_}Validation Loss Improved ({best_valid_loss} ---> {valid_loss}){sr_}")
        
        # checkpointing
        best_model_wts = copy.deepcopy(model_ft.state_dict())
        PATH = "../models/pretrain_valid_loss{:.4f}_epoch{:.0f}.bin".format(valid_loss, epoch)
        torch.save(model_ft.state_dict(), PATH)
        # Save a model file from the current directory
        print(f"{b_}Model Saved{sr_}")
        best_valid_loss = valid_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= early_stopping_patience:
        print(
            f"{b_}Early stopping triggered after {epochs_no_improve} epochs with no improvement.{sr_}"
        )
        break

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mayush-thakur[0m. Use [1m`wandb login --relogin`[0m to force relogin


Train Batch Loss: 0.9938347339630127
Train Batch Loss: 0.29654261469841003
Train Batch Loss: 0.07255585491657257
Train Batch Loss: 0.024559348821640015
Train Batch Loss: 0.016417246311903
Train Batch Loss: 0.013899106532335281
Train Batch Loss: 0.018576525151729584
Train Batch Loss: 0.015132731758058071
Train Batch Loss: 0.015436557121574879
Train Batch Loss: 0.01929567940533161
Train Batch Loss: 0.0200906191021204
Train Batch Loss: 0.014607090502977371
Train Batch Loss: 0.014009620994329453
Train Batch Loss: 0.012998992577195168
Train Batch Loss: 0.015194405801594257
Train Batch Loss: 0.014454683288931847
Train Batch Loss: 0.01786598190665245
Train Batch Loss: 0.016805382445454597
Train Batch Loss: 0.01873870939016342
Train Batch Loss: 0.017156973481178284
Train Batch Loss: 0.015766769647598267
Train Batch Loss: 0.015033743344247341
Train Batch Loss: 0.016078319400548935
Train Batch Loss: 0.015765275806188583
Train Batch Loss: 0.014872428961098194
Train Batch Loss: 0.01492001395672559