##  Imports

In [None]:
import os
import gc
import cv2
import math
import copy
import time
import random
import glob
from matplotlib import pyplot as plt

import h5py
from PIL import Image
from io import BytesIO

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torchvision

import joblib
from tqdm import tqdm
from collections import defaultdict

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from colorama import Fore, Back, Style
# Colorama shortcuts (blue foreground / reset); presumably meant for coloured
# log output — they are not used in the cells shown here.
b_, sr_ = Fore.BLUE, Style.RESET_ALL

import warnings
warnings.simplefilter("ignore")  # silence every warning category notebook-wide

# NOTE(review): per the CUDA docs, CUDA_LAUNCH_BLOCKING=1 makes kernel launches
# synchronous. It is a debugging aid that slows inference; consider dropping it.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

## Inference Config

In [None]:
# Inference settings shared by the cells below.
CONFIG = dict(
    seed=42,
    img_size=224,                  # side length of the square network input
    model_name="efficientnet_b5",  # timm architecture name
    valid_batch_size=128,
    device=torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
)

## Seed

In [None]:
def set_seed(seed=42):
    '''Seed every random number generator the notebook uses, for REPRODUCIBILITY.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and asks cuDNN for deterministic, non-autotuned kernels.

    Args:
        seed: integer seed applied to every generator.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is only read at interpreter start-up, so setting it here
    # affects Python processes launched afterwards, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(seed=CONFIG["seed"])  # seed all RNGs once, before any data/model work

In [None]:
# Competition inputs.
ROOT_DIR = "/kaggle/input/isic-2024-challenge"
TEST_CSV = ROOT_DIR + "/test-metadata.csv"
TEST_HDF = ROOT_DIR + "/test-image.hdf5"
SAMPLE = ROOT_DIR + "/sample_submission.csv"

# Trained Model_SimAM weights (a state_dict loaded with torch.load below).
BEST_WEIGHT = "/kaggle/input/isic-skin-cancer/pytorch-attention/AUROC0.1841_Loss0.3200_epoch8.bin"

## Data

In [None]:
# Test metadata. ISICDataset reads a 'target' column, so attach a constant
# dummy label (it is never used for scoring).
df = pd.read_csv(TEST_CSV).assign(target=0)
df

In [None]:
# Template whose 'target' column is overwritten with model predictions later.
df_sub = pd.read_csv(filepath_or_buffer=SAMPLE)
df_sub

## Dataset Class

In [None]:
class ISICDataset(Dataset):
    """Dataset that decodes ISIC images stored as encoded bytes in an HDF5 file.

    Each item is ``{'image': img, 'target': target}``, where ``img`` is the
    decoded image (after ``transforms`` if given) and ``target`` is the row's
    value from ``df['target']``.

    The HDF5 file is opened lazily, once per process, rather than in
    ``__init__``: h5py handles must not be shared across fork()ed DataLoader
    workers, and an open handle cannot be pickled for spawn-based workers.
    """

    def __init__(self, df, file_hdf, transforms=None):
        """
        Args:
            df: DataFrame with 'isic_id' and 'target' columns.
            file_hdf: path to the HDF5 file; entries are looked up by isic_id.
            transforms: optional callable invoked as
                ``transforms(image=img)["image"]``.

        Raises:
            FileNotFoundError: if ``file_hdf`` is a path that does not exist
                (fails fast, as the previous eager open did).
        """
        if isinstance(file_hdf, (str, os.PathLike)) and not os.path.isfile(file_hdf):
            raise FileNotFoundError(f"HDF5 file not found: {file_hdf}")
        self.df = df
        self.file_hdf = file_hdf
        self.fp_hdf = None    # opened on first access in each process
        self._fp_pid = None   # pid of the process that opened fp_hdf
        self.isic_ids = df['isic_id'].values
        self.targets = df['target'].values
        self.transforms = transforms

    def _hdf(self):
        """Return an HDF5 handle owned by the current process, opening it if needed."""
        pid = os.getpid()
        # A handle inherited through fork() belongs to the parent; open our own.
        if self.fp_hdf is None or self._fp_pid != pid:
            self.fp_hdf = h5py.File(self.file_hdf, mode="r")
            self._fp_pid = pid
        return self.fp_hdf

    def __getstate__(self):
        # Open h5py handles are not picklable; workers reopen lazily.
        state = self.__dict__.copy()
        state["fp_hdf"] = None
        state["_fp_pid"] = None
        return state

    def __len__(self):
        return len(self.isic_ids)

    def __getitem__(self, index):
        isic_id = self.isic_ids[index]
        img = np.array(Image.open(BytesIO(self._hdf()[isic_id][()])))
        target = self.targets[index]

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return {
            'image': img,
            'target': target,
        }

## Augmentations

In [None]:
# Test-time preprocessing pipeline, keyed by split name.
data_transforms = dict(
    valid=A.Compose(
        transforms=[
            # Resize so the longer side equals img_size (aspect ratio kept)...
            A.LongestMaxSize(max_size=CONFIG["img_size"], interpolation=cv2.INTER_CUBIC),
            # ...then pad the shorter side with zeros up to a square img_size x img_size.
            A.PadIfNeeded(
                min_height=CONFIG["img_size"],
                min_width=CONFIG["img_size"],
                border_mode=cv2.BORDER_CONSTANT,
                fill=0,
                p=1.0,
            ),
            A.CLAHE(clip_limit=(1.5, 1.5), tile_grid_size=(12, 12), p=1.0),
            # Per-image normalization; must match what the weights were trained with.
            A.Normalize(max_pixel_value=255.0, normalization="image", p=1.0),
            ToTensorV2(),
        ],
        strict=True,
        p=1.0,
    )
)

## Model

In [None]:
class simam_module(torch.nn.Module):
    """SimAM: a parameter-free attention that rescales each activation.

    For every (sample, channel) the squared deviation of each spatial position
    from the channel mean is divided by 4 * (variance + ``e_lambda``), shifted
    by 0.5, passed through a sigmoid, and multiplied back onto the input.

    Args:
        channels: unused; accepted so the constructor signature matches other
            attention modules.
        e_lambda: regulariser added to the variance in the denominator.
    """

    def __init__(self, channels = None, e_lambda = 1e-4):
        super().__init__()

        # (sic) attribute name kept for compatibility with existing references.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        return f"{self.__class__.__name__}(lambda={self.e_lambda:f})"

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        """Apply SimAM to a 4-D tensor ``x`` (b, c, h, w); returns the same shape."""
        _, _, h, w = x.size()

        # Unbiased variance uses (h*w - 1); clamp to 1 so a 1x1 feature map
        # yields a finite result instead of 0/0 = NaN.
        n = max(w * h - 1, 1)

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

if __name__ == "__main__":
    # Quick shape check: SimAM must return a tensor shaped like its input.
    attn = simam_module()
    x = torch.randn(2, 64, 32, 32)
    y = attn(x)
    print(y.size())

In [None]:
class Model_SimAM(nn.Module):
    """
    timm EfficientNet backbone whose final feature map goes through SimAM
    before global pooling and a freshly initialised linear classifier.

    The backbone must expose conv_stem, bn1, blocks, conv_head, bn2,
    global_pool and num_features. Attribute names are kept stable so saved
    state_dicts load unchanged.
    """
    def __init__(self, model_name, pretrained=True, num_classes=1, e_lambda=1e-4):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        width = backbone.num_features

        # Borrow the backbone's layers in their original registration order.
        for attr in ("conv_stem", "bn1", "blocks", "conv_head", "bn2", "global_pool"):
            setattr(self, attr, getattr(backbone, attr))
        self.simam = simam_module(channels=width, e_lambda=e_lambda)

        # The backbone's own classifier is discarded in favour of this head.
        self.classifier = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return raw logits; assumes global_pool flattens to (batch, num_features)."""
        feats = self.bn1(self.conv_stem(x))
        for stage in self.blocks:
            feats = stage(feats)
        feats = self.bn2(self.conv_head(feats))
        pooled = self.global_pool(self.simam(feats))
        return self.classifier(pooled)
    
model = Model_SimAM(CONFIG['model_name'], pretrained=False)
# NOTE(review): torch.load unpickles the file; only point BEST_WEIGHT at trusted checkpoints.
model.load_state_dict(torch.load(BEST_WEIGHT, map_location=CONFIG['device']))
model.to(CONFIG['device'])
# Inference mode: normalization layers use their stored running statistics and
# any dropout is disabled. Without this, predictions depend on which images
# happen to share a batch.
model.eval();

## Dataloaders

In [None]:
test_dataset = ISICDataset(df, TEST_HDF, transforms=data_transforms["valid"])
# Deterministic order (shuffle=False) so predictions line up with df rows.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=CONFIG["valid_batch_size"],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

## Inference

In [None]:
# Ensure inference behaviour even if this cell is re-run after other changes:
# normalization layers use running statistics and dropout is off.
model.eval()

preds = []
with torch.no_grad():
    for data in tqdm(test_loader, total=len(test_loader)):
        # non_blocking overlaps the host-to-device copy when memory is pinned.
        images = data['image'].to(CONFIG["device"], dtype=torch.float, non_blocking=True)
        probs = torch.sigmoid(model(images))  # logits -> probabilities
        preds.append(probs.cpu().numpy())

# One probability per test row, in test_loader (i.e. df) order.
preds = np.concatenate(preds).flatten() if preds else np.empty(0, dtype=np.float32)

In [None]:
# preds follow df (test-metadata) row order. Map them onto the submission by
# isic_id rather than assuming sample_submission.csv uses the same row order.
pred_by_id = pd.Series(preds, index=df["isic_id"].values)
df_sub["target"] = df_sub["isic_id"].map(pred_by_id)

missing_ids = df_sub.loc[df_sub["target"].isna(), "isic_id"]
if not missing_ids.empty:
    raise ValueError(
        f"No prediction for {len(missing_ids)} submission ids, e.g. {missing_ids.tolist()[:5]}"
    )

df_sub.to_csv("submission.csv", index=False)

In [None]:
# Display the final submission frame as the notebook cell's output.
df_sub