In [None]:
'''Imports'''
import os, time
import csv, numpy as np, pandas as pd, h5py
from tqdm import tqdm
import torch, torchvision
import torch.nn as nn
from torch.optim import lr_scheduler
from sklearn.metrics import roc_curve, auc
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
import matplotlib.pyplot as plt
from vit_pytorch import ViT
from io import BytesIO
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import v2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import cv2
from sklearn.model_selection import StratifiedShuffleSplit

In [None]:
'''Constants'''
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mini-batch size (get_loader's own default for `batch` is also 32).
BATCH_SIZE = 32
# Learning rate handed to train(), which uses it for the Adam optimizer.
LEARNING_RATE = 0.0001
# Number of model outputs; load_model() builds the ViT with num_classes = 1.
CLASSES = 1
# Upper bound on epochs per fold; early stopping in train() may end a fold sooner.
EPOCHS = 50
# Early stopping is only considered once more than this many epochs have run.
MIN_EPOCH_TRAIN = 10
# Consecutive non-improving epochs (after MIN_EPOCH_TRAIN) tolerated before stopping.
PATIENCE = 5
# Minimum validation-pAUC gain that counts as an improvement after MIN_EPOCH_TRAIN.
EPSILON = 0.0005
# Negatives sampled per positive example when lesgooo() builds the training table.
NEG_POS_RATIO = 20
# Number of independent train/validation runs performed by lesgooo().
FOLDS = 5

In [None]:
'''Paths'''
# HDF5 archives holding the training and test images, keyed by image id.
TRAIN_HDF5_PATH = "/kaggle/input/isic-2024-challenge/train-image.hdf5"
TEST_HDF5_PATH = "/kaggle/input/isic-2024-challenge/test-image.hdf5"
# Training metadata CSV; lesgooo() reads its `isic_id` and `target` columns.
ANNOTATIONS_FILE = "/kaggle/input/isic-2024-challenge/train-metadata.csv"
# Directory (not a file): lesgooo() joins it with `model_vit_aug_fold_{fold}.pth`.
MODEL_SAVE_PATH_ = "/kaggle/working/"
# Directory (not a file): lesgooo() joins it with `log_vit_aug_fold_{fold}.csv`.
LOG_FILE_1 = "/kaggle/working/"
# Per-fold summary CSV (fold, model file name, best validation pAUC).
LOG_FILE_2 = "/kaggle/working/log_folds.csv"
# Output CSV written by create_submission().
SUBMISSION_FILE_PATH = "/kaggle/working/submission.csv"
# Figure written by plot_metrics_from_files().
METRICS_PLOT_SAVE_PATH = "/kaggle/working/metrics.png"

In [None]:
'''Dataset and Dataloaders'''
# torchvision-v2 pipelines. They are called positionally (`transform(image)`),
# which is how ISIC2024_HDF5 invokes its transform.
data_transforms_v2 = {
    # Training: resize, convert to a tensor image, random rotation/flip, scale to float32.
    "train": v2.Compose([
        v2.Resize((224, 224)),
        v2.ToImage(),
        v2.RandomRotation(degrees=(0, 360)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale = True)
    ]),
    # Test: same resize/conversion without any random augmentation.
    "test": v2.Compose([
        v2.Resize((224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale = True)
    ])
}

# Albumentations pipelines. They are called as `transform(image=array)["image"]`,
# which is how ISIC2024_HDF5_ALBUM invokes its transform; get_loader defaults to these.
data_transforms_album = {
    "train": A.Compose([
        A.Resize(224, 224, interpolation=cv2.INTER_AREA),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.Downscale(p=0.25),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=60, p=0.5),
        # NOTE(review): limits of 0.2 look very small if the library interprets them in
        # 8-bit pixel units for uint8 input -- confirm against the albumentations docs.
        A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        # mean/std look like the usual ImageNet channel statistics -- TODO confirm intent.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224,0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2()], p=1.),
    
    # Test: deterministic resize + the same normalization, no augmentation.
    "test": A.Compose([
        A.Resize(224, 224, interpolation=cv2.INTER_AREA),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2()], p=1.)
}

class ISIC2024_HDF5(Dataset):
    """HDF5-backed image dataset for transforms that are called as ``transform(image)``.

    Each HDF5 entry is read with ``[()]`` and decoded through ``PIL.Image.open``.

    Args:
        hdf5_path: Path of the HDF5 file, whose keys are image ids.
        annotations_df: Optional DataFrame with ``isic_id`` and ``target`` columns.
            When given, items are ``(image, label, image_id)`` and only the listed
            ids are served; when omitted, every key of the file is served and
            items are ``(image, image_id)``.
        transform: Optional callable applied to the decoded PIL image.

    The HDF5 file is opened lazily on first access through ``hdf5_file`` instead
    of in ``__init__``, so a dataset handed to DataLoader worker processes before
    any item was read lets each worker open its own handle.
    """

    def __init__(self, hdf5_path, annotations_df=None, transform=None):
        self.hdf5_path = hdf5_path
        self.annotations_df = annotations_df
        self.transform = transform
        self._hdf5_file = None
        if self.annotations_df is not None:
            # A plain list is indexed by position, so lookups stay correct even
            # when the DataFrame index is not 0..n-1 (filtered/shuffled frames).
            self.image_ids = annotations_df['isic_id'].tolist()
            self.labels = annotations_df.set_index('isic_id')['target'].to_dict()
        else:
            # Only the key listing is needed here; do not keep this handle open.
            with h5py.File(self.hdf5_path, 'r') as hdf5_file:
                self.image_ids = list(hdf5_file.keys())

    @property
    def hdf5_file(self):
        """Read-only handle on the HDF5 file, opened on first use."""
        if self._hdf5_file is None:
            self._hdf5_file = h5py.File(self.hdf5_path, 'r')
        return self._hdf5_file

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image = Image.open(BytesIO(self.hdf5_file[image_id][()]))
        if self.transform:
            image = self.transform(image)
        # Without a tensor-producing transform the image is still a PIL object,
        # which torch.isnan cannot inspect.
        if isinstance(image, torch.Tensor) and torch.isnan(image).any():
            print(f"NaN detected in image {image_id}")
        if self.annotations_df is not None:
            label = self.labels[image_id]
            if np.isnan(label):
                print(f"NaN detected in label for image {image_id}")
            return image, label, image_id
        return image, image_id

    def close(self):
        """Close the HDF5 handle if it was opened; safe to call more than once."""
        if self._hdf5_file is not None:
            self._hdf5_file.close()
            self._hdf5_file = None

class ISIC2024_HDF5_ALBUM(Dataset):
    """HDF5-backed image dataset for transforms called as ``transform(image=array)``.

    Each HDF5 entry is read with ``[()]``, decoded through ``PIL.Image.open`` and
    converted to a NumPy array before the transform's ``"image"`` result is taken.

    Args:
        hdf5_path: Path of the HDF5 file, whose keys are image ids.
        annotations_df: Optional DataFrame with ``isic_id`` and ``target`` columns.
            When given, items are ``(image, label, image_id)`` and only the listed
            ids are served; when omitted, every key of the file is served and
            items are ``(image, image_id)``.
        transform: Optional callable invoked with the keyword ``image=``.

    The HDF5 file is opened lazily on first access through ``hdf5_file`` instead
    of in ``__init__``, so a dataset handed to DataLoader worker processes before
    any item was read lets each worker open its own handle.
    """

    def __init__(self, hdf5_path, annotations_df=None, transform=None):
        self.hdf5_path = hdf5_path
        self.annotations_df = annotations_df
        self.transform = transform
        self._hdf5_file = None
        if self.annotations_df is not None:
            # A plain list is indexed by position, so lookups stay correct even
            # when the DataFrame index is not 0..n-1 (filtered/shuffled frames).
            self.image_ids = annotations_df['isic_id'].tolist()
            self.labels = annotations_df.set_index('isic_id')['target'].to_dict()
        else:
            # Only the key listing is needed here; do not keep this handle open.
            with h5py.File(self.hdf5_path, 'r') as hdf5_file:
                self.image_ids = list(hdf5_file.keys())

    @property
    def hdf5_file(self):
        """Read-only handle on the HDF5 file, opened on first use."""
        if self._hdf5_file is None:
            self._hdf5_file = h5py.File(self.hdf5_path, 'r')
        return self._hdf5_file

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image = np.array(Image.open(BytesIO(self.hdf5_file[image_id][()])))
        if self.transform:
            image = self.transform(image=image)["image"]
        # Without a tensor-producing transform the image is still a NumPy array,
        # which torch.isnan cannot inspect.
        if isinstance(image, torch.Tensor) and torch.isnan(image).any():
            print(f"NaN detected in image {image_id}")
        if self.annotations_df is not None:
            label = self.labels[image_id]
            if np.isnan(label):
                print(f"NaN detected in label for image {image_id}")
            return image, label, image_id
        return image, image_id

    def close(self):
        """Close the HDF5 handle if it was opened; safe to call more than once."""
        if self._hdf5_file is not None:
            self._hdf5_file.close()
            self._hdf5_file = None

def get_loader(test_hdf5_path, 
               train_labels_df = None, 
               train_hdf5_path = None, 
               dataset_cls=ISIC2024_HDF5_ALBUM,
               train_img_trans=data_transforms_album["train"], 
               test_img_trans=data_transforms_album["test"], 
               batch=32, 
               seed=None):
    """Build the data loaders for training/validation and/or test inference.

    Args:
        test_hdf5_path: HDF5 file with the test images.
        train_labels_df: DataFrame with ``isic_id`` and ``target`` columns.
        train_hdf5_path: HDF5 file with the training images.
        dataset_cls: Dataset class constructed as
            ``dataset_cls(hdf5_path=..., annotations_df=..., transform=...)``.
        train_img_trans: Transform applied to training items.
        test_img_trans: Transform applied to validation and test items.
        batch: Batch size used by every loader.
        seed: ``random_state`` of the stratified 80/20 train/validation split.

    Returns:
        ``(train_loader, val_loader, test_loader)`` when both ``train_labels_df``
        and ``train_hdf5_path`` are given, otherwise only ``test_loader``.
    """
    test_dataset = dataset_cls(hdf5_path=test_hdf5_path, transform=test_img_trans)
    # Honour `batch` for inference too instead of the DataLoader default of 1.
    test_loader = DataLoader(test_dataset, batch_size=batch, shuffle=False)
    if train_labels_df is None or train_hdf5_path is None:
        return test_loader

    train_dataset = dataset_cls(hdf5_path=train_hdf5_path, annotations_df=train_labels_df, transform=train_img_trans)
    # Validation gets its own dataset over the same rows so it is scored with the
    # deterministic test transform rather than the random training augmentation.
    val_dataset = dataset_cls(hdf5_path=train_hdf5_path, annotations_df=train_labels_df, transform=test_img_trans)

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    train_idx, val_idx = next(splitter.split(train_labels_df, train_labels_df['target']))

    train_loader = DataLoader(Subset(train_dataset, train_idx), batch_size=batch, shuffle=True)
    # Shuffling adds nothing to evaluation; keep the validation order fixed.
    val_loader = DataLoader(Subset(val_dataset, val_idx), batch_size=batch, shuffle=False)
    return train_loader, val_loader, test_loader

In [None]:
'''Utils'''
def _save_training_checkpoint(model, optimizer, model_save_path):
    """Write the model and optimizer state dicts to ``model_save_path`` and report it."""
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, model_save_path)
    print(f'Model saved at {model_save_path}')


def train(epochs, model, learning_rate, train_dl, val_dl, min_epoch_train, patience, epsilon, log_file, model_save_path, criterion = nn.BCEWithLogitsLoss()):
    """Train ``model`` with mixed precision, per-epoch validation and early stopping.

    Args:
        epochs: Maximum number of epochs.
        model: Network producing one logit per sample (output is squeezed on dim 1).
        learning_rate: Learning rate of the Adam optimizer.
        train_dl: Loader yielding ``(inputs, labels, ids)`` training batches.
        val_dl: Loader passed to ``evaluate`` after every epoch.
        min_epoch_train: During the first ``min_epoch_train`` epochs any pAUC gain
            saves a checkpoint and early stopping is not applied.
        patience: Non-improving epochs tolerated after ``min_epoch_train``.
        epsilon: Minimum pAUC gain counted as an improvement after ``min_epoch_train``.
        log_file: CSV file that receives one metrics row per epoch.
        model_save_path: Destination of the best checkpoint.
        criterion: Loss applied to ``(logits, labels.float())``.

    Returns:
        The best validation pAUC observed.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    scaler = torch.cuda.amp.GradScaler()
    best_val_pauc = -1.0
    current_patience = 0
    with open(log_file, "w", newline="") as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['Epoch', 'Learning Rate', 'Training Loss', 'Training Accuracy', 'Validation Loss', 'Validation Accuracy', 'Validation Precision', 'Validation Recall', 'Validation F1 Score', 'Validation pAUC'])
        for epoch in range(epochs):
            print(f"\n | Epoch: {epoch+1}")
            total_loss = 0
            num_corr = 0
            num_samp = 0
            loop = tqdm(train_dl)
            model.train()
            for batch_idx, (inputs, labels, _) in enumerate(loop):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    outputs = model(inputs).squeeze(1)
                    loss = criterion(outputs, labels.float())
                if torch.isnan(loss):
                    print(f"NaN loss detected at batch {batch_idx}")
                    continue
                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here; unscale
                # them first so max_norm applies to the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                preds = torch.sigmoid(outputs.detach())
                num_corr += ((preds > 0.5) == labels).sum().item()
                num_samp += preds.size(0)
                total_loss += loss.item()
                loop.set_postfix(loss=loss.item())
            avg_loss = total_loss / len(train_dl)
            # Plain Python numbers: no crash when every batch was skipped.
            acc = num_corr / num_samp if num_samp else 0.0
            print(f"| Epoch {epoch+1}/{epochs} total training loss: {total_loss}, average training loss: {avg_loss}.")
            print("On Validation Data:")
            model.eval()
            with torch.inference_mode():
                val_loss, val_acc, val_pre, val_rec, val_f1, val_pauc = evaluate(val_dl, model, criterion)
            print("learning rate:", scheduler.get_last_lr()[0])
            row = [epoch+1, scheduler.get_last_lr()[0], avg_loss, acc, val_loss, val_acc, val_pre, val_rec, val_f1, val_pauc]
            csv_writer.writerow(row)
            if epoch + 1 > min_epoch_train:
                if val_pauc > best_val_pauc and (val_pauc - best_val_pauc) > epsilon:
                    # Remember the old best so the message shows new > old.
                    previous_best = best_val_pauc
                    best_val_pauc = val_pauc
                    print(f'Validation pAUC improved by more than {epsilon}, ({val_pauc} > {previous_best}); saving model...')
                    _save_training_checkpoint(model, optimizer, model_save_path)
                    current_patience = 0
                else:
                    current_patience += 1
                    print(f'Validation pAUC did not improve. Patience left: {patience - current_patience}')
                    if current_patience >= patience:
                        print(f'\n---Early stopping at epoch {epoch+1}.---')
                        break
            else:
                # Warm-up phase: keep the best checkpoint without the epsilon test.
                if val_pauc > best_val_pauc:
                    best_val_pauc = val_pauc
                    _save_training_checkpoint(model, optimizer, model_save_path)
            print(f'Current Best Validation pAUC: {best_val_pauc}')
            scheduler.step()
    print('Training complete.')
    return best_val_pauc

def pauc_above_tpr(y_true, y_pred, min_tpr=0.80):
    """Partial area under a ROC curve built from flipped labels and negated scores.

    Labels are mapped through ``|y - 1|`` and scores are multiplied by ``-1``
    before ``roc_curve``; the curve is then cut at ``fpr = 1 - min_tpr`` (with a
    linearly interpolated end point) and integrated with ``auc``.

    Args:
        y_true: Binary ground-truth labels.
        y_pred: Predicted scores.
        min_tpr: Sets the integration limit ``1 - min_tpr`` on the flipped curve.

    Returns:
        The partial area, or ``0`` when an input contains NaN or fewer than two
        curve points remain.
    """
    flipped_truth = np.abs(np.array(y_true) - 1)
    negated_scores = np.array(y_pred) * -1.0
    if np.isnan(flipped_truth).any() or np.isnan(negated_scores).any():
        print("NaN values detected in inputs to pauc_above_tpr")
        return 0

    fpr_curve, tpr_curve, _ = roc_curve(flipped_truth, negated_scores)
    fpr_cap = 1 - min_tpr
    cut = np.searchsorted(fpr_curve, fpr_cap, "right")

    # Interpolate the TPR exactly at the cap between the two neighbouring points.
    tpr_at_cap = np.interp(
        fpr_cap,
        [fpr_curve[cut - 1], fpr_curve[cut]],
        [tpr_curve[cut - 1], tpr_curve[cut]],
    )
    tpr_clipped = np.append(tpr_curve[:cut], tpr_at_cap)
    fpr_clipped = np.append(fpr_curve[:cut], fpr_cap)

    if len(fpr_clipped) < 2:
        print("Warning: Not enough points to compute pAUC. Returning 0.")
        return 0
    return auc(fpr_clipped, tpr_clipped)

def evaluate(loader, model, criterion):
    """Score ``model`` on ``loader`` and print a summary.

    Batches whose logits contain NaN are reported and skipped. The model is put
    in eval mode for the pass and switched to train mode before returning.

    Args:
        loader: Yields ``(inputs, labels, ids)`` batches.
        model: Network producing one logit per sample (output is squeezed on dim 1).
        criterion: Loss applied to ``(logits, labels.float())``.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, pauc)`` where the loss is
        averaged over ``len(loader)`` and the classification metrics use a 0.5
        threshold on the sigmoid probabilities.
    """
    f1_tracker = BinaryF1Score(threshold=0.5).to(DEVICE)
    precision_tracker = BinaryPrecision(threshold=0.5).to(DEVICE)
    recall_tracker = BinaryRecall(threshold=0.5).to(DEVICE)
    accuracy_tracker = BinaryAccuracy(threshold=0.5).to(DEVICE)
    trackers = (f1_tracker, precision_tracker, recall_tracker, accuracy_tracker)

    running_loss = 0.0
    correct = 0
    seen = 0
    collected_probs, collected_labels = [], []

    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels, _ in tqdm(loader):
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            logits = model(batch_inputs).squeeze(1)
            if torch.isnan(logits).any():
                print("NaN detected in model outputs")
                continue
            running_loss += criterion(logits, batch_labels.float()).item()
            probs = torch.sigmoid(logits)
            correct += ((probs > 0.5) == batch_labels).sum()
            seen += probs.size(0)
            for tracker in trackers:
                tracker.update(probs, batch_labels)
            collected_probs.extend(probs.cpu().detach().numpy())
            collected_labels.extend(batch_labels.cpu().numpy())

    mean_loss = running_loss / len(loader)
    manual_accuracy = float(correct) / float(seen)
    pauc = pauc_above_tpr(collected_labels, collected_probs)

    # Resolve each metric once and reuse the values for both the report and the return.
    accuracy_value = accuracy_tracker.compute().item()
    precision_value = precision_tracker.compute().item()
    recall_value = recall_tracker.compute().item()
    f1_value = f1_tracker.compute().item()

    print(f"Total loss: {running_loss}, Average loss: {mean_loss}")
    print(f"Got {correct}/{seen} correct with accuracy {manual_accuracy*100:.2f}")
    print(f"pAUC above 80% TPR: {pauc:.3f}, Accuracy: {accuracy_value:.3f}, precision: {precision_value:.3f}, recall: {recall_value:.3f}, F1Score: {f1_value:.3f}")
    model.train()
    return mean_loss, accuracy_value, precision_value, recall_value, f1_value, pauc

def save_checkpoint(state, filename="my_checkpoint.pth"):
    """Announce the save, then serialize ``state`` to ``filename`` with ``torch.save``."""
    destination = filename
    print("=> Saving checkpoint")
    torch.save(state, destination)

def load_checkpoint(checkpoint, model):
    """Announce the load, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def load_model(model_save_path = None):
    """Build the ViT classifier, optionally restoring weights from a checkpoint.

    Args:
        model_save_path: Path of a checkpoint whose ``"state_dict"`` entry holds
            the model weights. When falsy, a freshly initialised model is returned.

    Returns:
        The model moved to ``DEVICE``. It is switched to eval mode only when a
        checkpoint was loaded.
    """
    model = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1,
        dim = 768,
        depth = 12,
        heads = 12,
        mlp_dim = 3072,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    if model_save_path:
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
        checkpoint = torch.load(model_save_path, map_location=DEVICE)
        model.load_state_dict(checkpoint["state_dict"])
        model.to(DEVICE)
        model.eval()
    else:
        model.to(DEVICE)
    return model

def create_submission(model, test_loader, submission_file_path):
    """Run ``model`` over ``test_loader`` and write an ``isic_id,target`` CSV.

    Args:
        model: Network producing one logit per sample (output is squeezed on dim 1).
        test_loader: Yields ``(inputs, image_ids)`` batches.
        submission_file_path: Destination CSV path (written without the index).
    """
    collected_ids, collected_probs = [], []
    with torch.no_grad():
        for batch_inputs, batch_ids in tqdm(test_loader, desc="Evaluating"):
            logits = model(batch_inputs.to(DEVICE)).squeeze(1)
            batch_probs = torch.sigmoid(logits).cpu().numpy()
            collected_probs.extend(batch_probs)
            collected_ids.extend(batch_ids)

    n_ids, n_probs = len(collected_ids), len(collected_probs)
    if n_ids != n_probs:
        print(f"Warning: Number of image IDs ({n_ids}) does not match number of predictions ({n_probs})")

    pd.DataFrame({
        'isic_id': collected_ids,
        'target': collected_probs
    }).to_csv(submission_file_path, index=False)
    print(f"Submission file saved to {submission_file_path}")

def visualize_train_images(images, titles=None):
    """Show ``images`` side by side, optionally titled ``"<id> (label: <label>)"``.

    Args:
        images: Sequence of images; tensors are permuted from (C, H, W) order
            with ``permute(1, 2, 0)`` and converted to NumPy before display.
        titles: Optional pair ``(ids, labels)`` indexed per image.
    """
    plt.figure(figsize=(15, 5))
    for slot, picture in enumerate(images):
        plt.subplot(1, len(images), slot + 1)
        shown = picture.permute(1, 2, 0).numpy() if isinstance(picture, torch.Tensor) else picture
        plt.imshow(shown)
        if titles:
            image_name, image_label = titles[0][slot], titles[1][slot]
            plt.title(f"{image_name} (label: {image_label})")
        plt.axis('off')
    plt.show()

def visualize_test_images(images, titles=None):
    """Show ``images`` side by side, optionally with one title per image.

    Args:
        images: Sequence of images; tensors are permuted from (C, H, W) order
            with ``permute(1, 2, 0)`` and converted to NumPy before display.
        titles: Optional sequence of titles indexed per image.
    """
    plt.figure(figsize=(15, 5))
    for slot, picture in enumerate(images):
        plt.subplot(1, len(images), slot + 1)
        shown = picture.permute(1, 2, 0).numpy() if isinstance(picture, torch.Tensor) else picture
        plt.imshow(shown)
        if titles:
            caption = titles[slot]
            plt.title(caption)
        plt.axis('off')
    plt.show()

def plot_metrics_from_files(file_paths, save_path=None):
    """Plot loss, learning rate and validation pAUC curves for each training log.

    One column of three panels is drawn per CSV file; the files are expected to
    contain the columns 'Epoch', 'Learning Rate', 'Training Loss',
    'Validation Loss' and 'Validation pAUC'.

    Args:
        file_paths: Paths of the per-run CSV logs.
        save_path: When truthy, the figure is also saved to this path.
    """
    n_logs = len(file_paths)
    # squeeze=False always yields a 2-D axes grid, including the single-file case.
    fig, grid = plt.subplots(3, n_logs, figsize=(10 * n_logs, 10), squeeze=False)

    for col, log_path in enumerate(file_paths):
        log = pd.read_csv(log_path)
        epoch_axis = log['Epoch']
        lr_curve = log['Learning Rate']
        train_curve = log['Training Loss']
        val_curve = log['Validation Loss']
        pauc_curve = log['Validation pAUC']
        panel_title = f'File: {log_path.split("/")[-1]}'

        # (y-axis label, [(series, plot keyword arguments), ...]) for each row.
        panels = (
            ('Loss', [
                (train_curve, {'label': 'Training Loss', 'marker': 'o'}),
                (val_curve, {'label': 'Validation Loss', 'marker': 'o'}),
            ]),
            ('Learning Rate', [
                (lr_curve, {'label': 'Learning Rate', 'marker': 'o', 'color': 'orange'}),
            ]),
            ('Validation pAUC', [
                (pauc_curve, {'label': 'Validation pAUC', 'marker': 'o', 'color': 'green'}),
            ]),
        )
        for row, (y_label, curves) in enumerate(panels):
            panel = grid[row, col]
            for series, plot_kwargs in curves:
                panel.plot(epoch_axis, series, **plot_kwargs)
            panel.set_title(panel_title)
            panel.set_xlabel('Epochs')
            panel.set_ylabel(y_label)
            panel.legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)

In [None]:
'''Training loop'''
def _close_loader_datasets(*loaders):
    """Close the datasets behind ``loaders`` (unwrapping ``Subset``), each only once."""
    closed = set()
    for loader in loaders:
        dataset = loader.dataset
        if isinstance(dataset, Subset):
            dataset = dataset.dataset
        if id(dataset) in closed:
            continue
        closed.add(id(dataset))
        close = getattr(dataset, "close", None)
        if callable(close):
            close()


def lesgooo():
    """Train one model per fold on a class-rebalanced table and report the best fold.

    All positives are kept and up to ``NEG_POS_RATIO`` negatives per positive are
    sampled. Each fold draws its own stratified train/validation split through
    ``get_loader``, trains a fresh model, and records its best validation pAUC in
    ``LOG_FILE_2``. Per-fold logs and checkpoints are written under ``LOG_FILE_1``
    and ``MODEL_SAVE_PATH_``; the metric curves are saved to
    ``METRICS_PLOT_SAVE_PATH``.

    Returns:
        Zero-based index of the fold with the highest validation pAUC.
    """
    annotations_df_full = pd.read_csv(ANNOTATIONS_FILE, low_memory=False)
    df_positive_all = annotations_df_full[annotations_df_full["target"] == 1].reset_index(drop=True)
    df_negative_all = annotations_df_full[annotations_df_full["target"] == 0].reset_index(drop=True)
    # Never request more negatives than exist; DataFrame.sample would raise.
    n_negatives = min(len(df_negative_all), df_positive_all.shape[0] * NEG_POS_RATIO)
    df_negative_trunc = df_negative_all.sample(n_negatives)
    # drop=True keeps the old index from turning into a stray 'index' column.
    annotations_df_trunc = pd.concat([df_positive_all, df_negative_trunc]).sample(frac=1).reset_index(drop=True)
    val_pAUC = []
    with open(LOG_FILE_2, 'w', newline="") as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['Fold', "Model_Name", "val_pAUC"])
        for fold in range(FOLDS):
            model_name = f'model_vit_aug_fold_{fold}.pth'
            train_dl, val_dl, test_dl = get_loader(train_labels_df=annotations_df_trunc,
                                                   train_hdf5_path=TRAIN_HDF5_PATH,
                                                   test_hdf5_path=TEST_HDF5_PATH,
                                                   batch=BATCH_SIZE)
            try:
                model_vit = load_model()
                print(f"---------------\nTraining for fold: {fold+1}:\n---------------")
                val_pAUC_fold = train(epochs=EPOCHS,
                                      model=model_vit,
                                      learning_rate=LEARNING_RATE,
                                      train_dl=train_dl,
                                      val_dl=val_dl,
                                      min_epoch_train=MIN_EPOCH_TRAIN,
                                      patience=PATIENCE,
                                      epsilon=EPSILON,
                                      log_file=os.path.join(LOG_FILE_1, f'log_vit_aug_fold_{fold}.csv'),
                                      model_save_path=os.path.join(MODEL_SAVE_PATH_, model_name))
            finally:
                # Every get_loader call opens HDF5 files; release them per fold.
                _close_loader_datasets(train_dl, val_dl, test_dl)
            val_pAUC.append(val_pAUC_fold)
            csv_writer.writerow([fold, model_name, val_pAUC_fold])
    best_model_fold_index = val_pAUC.index(max(val_pAUC))
    print(f"Average of OOF pAUC: {np.mean(val_pAUC)}")
    file_paths = [os.path.join(LOG_FILE_1, f'log_vit_aug_fold_{i}.csv') for i in range(FOLDS)]
    plot_metrics_from_files(file_paths, save_path=METRICS_PLOT_SAVE_PATH)

    return best_model_fold_index

In [None]:
# Run the multi-fold training; IDX is the zero-based index of the best-scoring fold.
IDX = lesgooo()
print(f"\n\nDone, index = {IDX}\n")

In [None]:
# Checkpoint of the best fold; the file name follows the per-fold naming used in lesgooo().
MODEL_SAVE_PATH = f"/kaggle/working/model_vit_aug_fold_{IDX}.pth"
print(f"Best Model: {MODEL_SAVE_PATH}")

In [None]:
'''Evaluation'''
def predict():
    """Load the best checkpoint and write test-set predictions to the submission CSV.

    Reads ``MODEL_SAVE_PATH`` and ``TEST_HDF5_PATH`` and writes
    ``SUBMISSION_FILE_PATH``.
    """
    model = load_model(model_save_path=MODEL_SAVE_PATH)
    test_loader = get_loader(test_hdf5_path=TEST_HDF5_PATH)
    try:
        create_submission(model, test_loader, submission_file_path=SUBMISSION_FILE_PATH)
    finally:
        # Release the HDF5 file held by the test dataset, even if inference fails.
        test_loader.dataset.close()

In [None]:
# Produce the submission file with the best model selected above.
predict()