ResNet34 with random-rotation and random-horizontal-flip augmentation, 224x224 resize, stratified train/validation split, and focal loss

Public Score: 0.037

In [None]:
'''
Imports
'''
import h5py
# import os
# import shutil
import csv
import pandas as pd
import numpy as np
from io import BytesIO
from PIL import Image
import torch, torchvision
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import v2
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedShuffleSplit

In [None]:
'''
Constants
'''
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Batch size handed to get_loader() when the data loaders are built.
BATCH_SIZE = 32
# Learning rate given to the Adam optimizer inside train().
LEARNING_RATE = 0.01
# NOTE(review): not referenced anywhere in the visible code; the model head is
# built with a hard-coded out_features=1.
CLASSES = 1
# Number of epochs passed to train().
EPOCH = 3

In [None]:
# Paths
# HDF5 archives read by ISIC2024_HDF5; each key is used as an image id.
TRAIN_HDF5_PATH = "/kaggle/input/isic-2024-challenge/train-image.hdf5"
TEST_HDF5_PATH = "/kaggle/input/isic-2024-challenge/test-image.hdf5"
# CSV providing the 'isic_id' and 'target' columns for the training images.
ANNOTATIONS_FILE = "/kaggle/input/isic-2024-challenge/train-metadata.csv"
# Checkpoint written after every epoch by train() and read back by load_model().
MODEL_SAVE_PATH = "/kaggle/working/model_resnet34_aug_2-1.pth"
# Per-epoch training/validation metrics written by train().
LOG_FILE = "/kaggle/working/log_res34_aug.csv"
# State dict loaded into a freshly built ResNet-34 before the head is replaced.
RESNET34_IMAGENET_WEIGHTS_PYTORCH = "/kaggle/input/resnet34-weights/pytorch/nan/1/resnet34-b627a593.pth"        # adjust to wherever the weights dataset is mounted
# Destination of the CSV produced by create_submission().
SUBMISSION_FILE_PATH = "/kaggle/working/submission.csv"

In [None]:
# Report how many missing (NaN) values each metadata column contains.
print("Checking for NaN values in the dataset...")
# low_memory=False parses the file in one pass, like every other read_csv call
# in this file, so column dtypes are inferred consistently instead of chunk by
# chunk.
df = pd.read_csv(ANNOTATIONS_FILE, low_memory=False)
print(df.isna().sum())

In [None]:
'''
Utils
'''
def pauc_above_tpr(y_true, y_pred, min_tpr=0.80):
    """Return the partial area under the ROC curve above ``min_tpr``.

    Labels are flipped and scores negated, which turns the region
    ``TPR >= min_tpr`` of the original problem into the region
    ``FPR <= 1 - min_tpr`` of the transformed one. The area of the transformed
    ROC curve over that FPR range is returned, so the result lies in
    ``[0, 1 - min_tpr]``.

    Args:
        y_true: Sequence of binary ground-truth labels (0 or 1).
        y_pred: Sequence of predicted scores, one per label.
        min_tpr: Minimum true-positive rate of interest, within ``[0, 1]``.

    Returns:
        The partial AUC, or ``0`` when it cannot be computed (empty input,
        NaN values, or only one class present in ``y_true``).

    Raises:
        ValueError: If ``min_tpr`` is outside ``[0, 1]``.
    """
    if not 0 <= min_tpr <= 1:
        raise ValueError(f"min_tpr must be within [0, 1], got {min_tpr}")

    y_true = abs(np.array(y_true) - 1)
    y_pred = -1.0 * np.array(y_pred)

    if y_true.size == 0:
        print("Warning: Not enough points to compute pAUC. Returning 0.")
        return 0

    # Check for NaN values
    if np.isnan(y_true).any() or np.isnan(y_pred).any():
        print("NaN values detected in inputs to pauc_above_tpr")
        return 0

    # A ROC curve needs both classes; with a single class it is undefined.
    if np.unique(y_true).size < 2:
        print("Warning: Only one class present in y_true; pAUC is undefined. Returning 0.")
        return 0

    fpr, tpr, _ = roc_curve(y_true, y_pred)
    max_fpr = 1 - min_tpr

    stop = np.searchsorted(fpr, max_fpr, "right")
    if stop < len(fpr):
        # Cut the curve at max_fpr, linearly interpolating the TPR at the cut.
        x_interp = [fpr[stop - 1], fpr[stop]]
        y_interp = [tpr[stop - 1], tpr[stop]]
        tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp))
        fpr = np.append(fpr[:stop], max_fpr)
    # Otherwise max_fpr is at or past the last curve point: use the whole curve.

    if len(fpr) < 2:
        print("Warning: Not enough points to compute pAUC. Returning 0.")
        return 0

    partial_auc = auc(fpr, tpr)

    return partial_auc

def evaluate(loader, model):
    """Evaluate ``model`` on ``loader`` and print and return summary metrics.

    Batches whose outputs contain NaN are skipped and excluded from every
    statistic, including the average loss. Predictions are sigmoid
    probabilities thresholded at 0.5.

    Args:
        loader: Iterable yielding ``(inputs, labels, image_id)`` batches.
        model: Network producing one logit per sample with shape ``(N, 1)``.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1, pauc)``. When no
        batch could be evaluated, ``avg_loss`` is NaN and ``pauc`` is 0.

    Note:
        The model is switched to eval mode for the pass and is left in
        training mode on return.
    """
    criterion = FocalLoss(alpha=0.25, gamma=2)
    metric = BinaryF1Score(threshold=0.5).to(DEVICE)
    prec = BinaryPrecision(threshold=0.5).to(DEVICE)
    recall = BinaryRecall(threshold=0.5).to(DEVICE)
    acc = BinaryAccuracy(threshold=0.5).to(DEVICE)
    loss = 0.0
    num_corr = 0
    num_samp = 0
    num_batches = 0  # batches that actually contributed to the statistics
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for inputs, labels, _ in tqdm(loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs).squeeze(1)

            # Check for NaN in outputs
            if torch.isnan(outputs).any():
                print("NaN detected in model outputs")
                continue

            preds = torch.sigmoid(outputs)
            # .item() keeps the counter a plain int so it prints cleanly.
            num_corr += ((preds > 0.5) == labels).sum().item()
            num_samp += preds.size(0)
            num_batches += 1
            loss += criterion(outputs, labels.float()).item()
            metric.update(preds, labels)
            prec.update(preds, labels)
            recall.update(preds, labels)
            acc.update(preds, labels)
            all_preds.extend(preds.cpu().detach().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Average over evaluated batches only; skipped batches added no loss.
    avg_loss = loss / num_batches if num_batches else float("nan")
    accu = num_corr / num_samp if num_samp else 0.0
    pauc = pauc_above_tpr(all_labels, all_preds) if num_samp else 0

    # Compute each torchmetrics value once and reuse it for print and return.
    acc_value = acc.compute().item()
    prec_value = prec.compute().item()
    recall_value = recall.compute().item()
    f1_value = metric.compute().item()

    print(f"Total loss: {loss}, Average loss: {avg_loss}")
    print(f"Got {num_corr}/{num_samp} correct with accuracy {accu*100:.2f}")
    print(f"pAUC above 80% TPR: {pauc:.3f}, Accuracy: {acc_value:.3f}, precision: {prec_value:.3f}, recall: {recall_value:.3f}, F1Score: {f1_value:.3f}")
    model.train()

    return avg_loss, acc_value, prec_value, recall_value, f1_value, pauc

def save_checkpoint(state, filename="my_checkpoint.pth"):
    """Announce the save, then serialize ``state`` to ``filename``.

    Args:
        state: Object to persist (train() passes a dict of state dicts).
        filename: Destination path handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Announce the load, then copy the checkpoint's weights into ``model``.

    Args:
        checkpoint: Mapping holding the weights under the "state_dict" key.
        model: Module that receives the weights via ``load_state_dict``.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each sample's binary cross-entropy is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-BCE)`` is the
    probability assigned to the true class. The same ``alpha`` is applied to
    both classes.

    Args:
        alpha: Constant weight multiplying every sample's loss.
        gamma: Exponent of the modulating factor.
        reduction: 'mean' or 'sum' to aggregate; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, output, target):
        """Return the focal loss of logits ``output`` against ``target``."""
        eps = 1e-7  # keeps the base of the power strictly positive
        bce = F.binary_cross_entropy_with_logits(output, target.float(), reduction='none')
        prob_true = torch.exp(-bce)
        per_sample = self.alpha * (1 - prob_true + eps) ** self.gamma * bce

        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample
        
def load_model():
    """Rebuild the single-logit ResNet-34 and load weights from MODEL_SAVE_PATH.

    Returns:
        The model on ``DEVICE``, in eval mode, with the "state_dict" entry of
        the saved checkpoint loaded.
    """
    model = torchvision.models.resnet34(weights=None)
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=num_ftrs, out_features=1)
    # map_location remaps the stored tensors to the current device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only host.
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(DEVICE)
    model.eval()
    return model

def create_submission(model, test_loader):
    """Run ``model`` over ``test_loader`` and write the submission CSV.

    The CSV is written to ``SUBMISSION_FILE_PATH`` with one row per image:
    ``isic_id`` and ``target`` (the sigmoid probability of the single logit).

    Args:
        model: Network producing one logit per sample with shape ``(N, 1)``.
            It is put into eval mode before inference.
        test_loader: Iterable yielding ``(inputs, image_names)`` batches.
    """
    predictions = []
    image_ids = []

    # Make sure batch-norm/dropout layers run in inference mode even if the
    # caller passes a model that is still in training mode.
    model.eval()
    with torch.no_grad():
        for inputs, image_names in tqdm(test_loader, desc="Evaluating"):
            inputs = inputs.to(DEVICE)
            outputs = model(inputs).squeeze(1)
            probs = torch.sigmoid(outputs)
            predictions.extend(probs.cpu().numpy())
            image_ids.extend(image_names)  # Append all image names from the batch

    # Check if the lengths match
    if len(image_ids) != len(predictions):
        print(f"Warning: Number of image IDs ({len(image_ids)}) does not match number of predictions ({len(predictions)})")

    # Create DataFrame
    submission_df = pd.DataFrame({
        'isic_id': image_ids,
        'target': predictions
    })

    # Save to CSV
    submission_df.to_csv(SUBMISSION_FILE_PATH, index=False)
    print(f"Submission file saved to {SUBMISSION_FILE_PATH}")

In [None]:
'''
Transformations
'''
# Training pipeline: resize to 224x224, convert to an image tensor, apply a
# random rotation anywhere in [0, 360] degrees and a horizontal flip with
# probability 0.5, then convert to float32 with value scaling enabled.
TRAIN_TRANS = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.RandomRotation(degrees=(0, 360)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)

# Evaluation pipeline: the same resize and conversion steps without any
# random augmentation.
TEST_TRANS = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [None]:
'''
DataClass
'''
class ISIC2024_HDF5(Dataset):
    """Dataset serving encoded images stored as entries of an HDF5 file.

    Every key of the HDF5 file is treated as an image id. When an annotations
    CSV is given, its 'isic_id' -> 'target' mapping supplies the labels and
    samples are ``(image, label, image_id)``; otherwise samples are
    ``(image, image_id)``.

    Args:
        hdf5_path: Path of the HDF5 file holding the encoded image bytes.
        annotations_file: Optional CSV with 'isic_id' and 'target' columns.
        transform: Optional callable applied to the decoded PIL image.
    """

    def __init__(self, hdf5_path, annotations_file=None, transform=None):
        self.hdf5_path = hdf5_path
        self.annotations_file = annotations_file
        self.transform = transform

        # The file handle stays open for the dataset's lifetime; see close().
        self.hdf5_file = h5py.File(self.hdf5_path, 'r')
        self.image_ids = list(self.hdf5_file.keys())

        if self.annotations_file is not None:
            self.labels = pd.read_csv(annotations_file, low_memory=False).set_index('isic_id')['target'].to_dict()

    def __len__(self):
        """Return the number of images (HDF5 keys) in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Decode and return the sample at position ``idx``."""
        image_id = self.image_ids[idx]
        image = Image.open(BytesIO(self.hdf5_file[image_id][()]))

        if self.transform is not None:
            image = self.transform(image)

        # Check for NaN in image. torch.isnan only accepts tensors; without a
        # transform the sample is still a PIL image, so the check is skipped.
        if isinstance(image, torch.Tensor) and torch.isnan(image).any():
            print(f"NaN detected in image {image_id}")

        if self.annotations_file is not None:
            label = self.labels[image_id]
            # Check for NaN in label
            if np.isnan(label):
                print(f"NaN detected in label for image {image_id}")
            return image, label, image_id
        else:
            return image, image_id

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.hdf5_file.close()

In [None]:
'''
Data Loader
'''
def get_loader(dataset_cls=ISIC2024_HDF5,
               train_hdf5_path=TRAIN_HDF5_PATH,
               test_hdf5_path=TEST_HDF5_PATH,
               train_labels_file=ANNOTATIONS_FILE,
               train_img_trans=TRAIN_TRANS,
               test_img_trans=TEST_TRANS,
               batch=32,
               seed=None):
    """Build the train, validation and test data loaders.

    The labelled data is split 80/20 with a stratified shuffle split. The
    training subset uses ``train_img_trans``; the validation subset reads the
    same files through a second dataset instance that uses ``test_img_trans``,
    so validation metrics are not affected by random augmentation.

    Args:
        dataset_cls: Dataset class accepting ``hdf5_path``,
            ``annotations_file`` and ``transform`` keyword arguments. Two
            instances built from the same file must order samples identically.
        train_hdf5_path: HDF5 file with the labelled images.
        test_hdf5_path: HDF5 file with the unlabelled test images.
        train_labels_file: CSV with 'isic_id' and 'target' columns.
        train_img_trans: Transform for training samples.
        test_img_trans: Transform for validation and test samples.
        batch: Batch size of all three loaders.
        seed: ``random_state`` of the split; ``None`` gives a random split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_dataset_all = dataset_cls(hdf5_path=train_hdf5_path, annotations_file=train_labels_file, transform=train_img_trans)
    val_dataset_all = dataset_cls(hdf5_path=train_hdf5_path, annotations_file=train_labels_file, transform=test_img_trans)
    test_dataset = dataset_cls(hdf5_path=test_hdf5_path, transform=test_img_trans)

    # The split indices address dataset positions, so the labels used for
    # stratification must follow the dataset's own ordering rather than the
    # CSV row order, which is not guaranteed to match.
    image_ids = getattr(train_dataset_all, "image_ids", None)
    label_map = getattr(train_dataset_all, "labels", None)
    if image_ids is not None and label_map is not None:
        labels = np.asarray([label_map[image_id] for image_id in image_ids])
    else:
        # Dataset classes without these attributes fall back to CSV order.
        labels = pd.read_csv(train_labels_file, low_memory=False)['target'].to_numpy()

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    # Only the number of samples of the first argument matters to the split.
    train_idx, val_idx = next(splitter.split(np.zeros(len(labels)), labels))
    train_subset = Subset(train_dataset_all, train_idx)
    val_subset = Subset(val_dataset_all, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch, shuffle=True)
    # Evaluation does not depend on sample order, so shuffling is unnecessary.
    val_loader = DataLoader(val_subset, batch_size=batch, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
'''
Loaders and model definition
'''
# Build the three loaders; the fixed seed is forwarded to the stratified split
# as its random_state, making the train/validation partition reproducible.
train_dl, val_dl, test_dl = get_loader(batch=BATCH_SIZE, seed=42)
# Create ResNet-34 without fetching weights, then load them from a local file
# (presumably ImageNet-pretrained, judging by the constant's name).
model_resnet = torchvision.models.resnet34(weights=None)
model_resnet.load_state_dict(torch.load(RESNET34_IMAGENET_WEIGHTS_PYTORCH))
# Swap the classification head for a single-output layer; its logit is passed
# through a sigmoid wherever predictions are made.
num_ftrs = model_resnet.fc.in_features
model_resnet.fc = nn.Linear(in_features=num_ftrs, out_features=1)
model_resnet.to(DEVICE)

In [None]:
'''
Training Loop
'''
def train(epochs, model):
    """Train ``model`` with focal loss under AMP, validating after each epoch.

    Uses the module-level ``train_dl`` and ``val_dl`` loaders. After every
    epoch a row of metrics is appended to ``LOG_FILE`` and a checkpoint
    holding the model and optimizer state dicts is written to
    ``MODEL_SAVE_PATH``, overwriting the previous one.

    Batches whose loss is NaN are skipped and excluded from the epoch's
    average loss and accuracy.

    Args:
        epochs: Number of passes over ``train_dl``.
        model: Network producing one logit per sample with shape ``(N, 1)``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_func = FocalLoss(alpha=0.25, gamma=2)
    scaler = torch.cuda.amp.GradScaler()
    with open(LOG_FILE, 'w', newline='') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['Epoch', 'Training Loss', 'Training Accuracy', 'Validation Loss', 'Validation Accuracy', 'Validation Precision', 'Validation Recall', 'Validation F1 Score', 'Validation pAUC'])
        for epoch in range(epochs):
            print(f"\n | Epoch: {epoch+1}")
            total_loss = 0.0
            num_corr = 0
            num_samp = 0
            num_batches = 0  # batches that actually produced an update
            loop = tqdm(train_dl)
            model.train()
            for batch_idx, (inputs, labels, _) in enumerate(loop):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    outputs = model(inputs).squeeze(1)
                    loss = loss_func(outputs, labels.float())

                # Check for NaN in loss
                if torch.isnan(loss):
                    print(f"NaN loss detected at batch {batch_idx}")
                    continue

                scaler.scale(loss).backward()

                # Gradient clipping. The gradients are still multiplied by the
                # loss scale here, so they must be unscaled first; otherwise
                # max_norm is compared against scaled magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                preds = torch.sigmoid(outputs.detach())
                num_corr += ((preds > 0.5) == labels).sum().item()
                num_samp += preds.size(0)
                num_batches += 1
                total_loss += batch_loss
                loop.set_postfix(loss=batch_loss)
            # Average over the batches that contributed; skipped NaN batches
            # added nothing to total_loss.
            avg_loss = total_loss / num_batches if num_batches else float('nan')
            acc = num_corr / num_samp if num_samp else 0.0
            print(f"| Epoch {epoch+1}/{epochs} total training loss: {total_loss}, average training loss: {avg_loss}.")
            print("On Validation Data:")
            model.eval()
            with torch.inference_mode():
                val_loss, val_acc, val_pre, val_rec, val_f1, val_pauc = evaluate(val_dl, model)
            row = [epoch+1, avg_loss, acc, val_loss, val_acc, val_pre, val_rec, val_f1, val_pauc]
            csv_writer.writerow(row)
            # Flush so the log survives an interrupted run.
            f.flush()
            print('Saving model...')
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint, MODEL_SAVE_PATH)
            print(f'Model saved at {MODEL_SAVE_PATH}')

In [None]:
'''
Train
'''
# Train model_resnet for EPOCH epochs; train() writes per-epoch metrics to
# LOG_FILE and saves a checkpoint to MODEL_SAVE_PATH after every epoch.
train(epochs=EPOCH, model=model_resnet)

In [None]:
'''
Generating submission file
'''
# Rebuild the network from the checkpoint at MODEL_SAVE_PATH; load_model()
# returns it in eval mode on DEVICE.
model = load_model()
# Only the test loader is used here. get_loader() runs with its defaults, so
# it also rebuilds the labelled datasets and split, which are discarded.
_, _, test_loader = get_loader()
# Write one sigmoid probability per isic_id to SUBMISSION_FILE_PATH.
create_submission(model, test_loader)