# Setup

In [None]:
!pip install timm -q
!pip install albumentations --upgrade -q
!pip install segmentation_models_pytorch -q

# General imports.
import gc
import os
import cv2
import timm
import torch
import random
import sklearn
import numpy as np
import pandas as pd
import tensorflow as tf
import albumentations as A
import segmentation_models_pytorch


# Specific Imports.
from torch import nn
from tqdm import tqdm
from tensorflow import keras
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from kaggle_datasets import KaggleDatasets
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedShuffleSplit
from segmentation_models_pytorch.encoders import get_encoder
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from segmentation_models_pytorch.base import initialization as init
from torch.utils.data.sampler import SequentialSampler, RandomSampler

import warnings
warnings.filterwarnings("ignore")

In [None]:
!pip install wandb -qqq
import wandb
from wandb.keras import WandbCallback
wandb.login()

# Utility Functions

In [None]:
def seed_everything(SEED):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy, TensorFlow and PyTorch (CPU and all CUDA
    devices), and sets the hash-seed / cuDNN-determinism environment variables.

    Args:
        SEED: Integer seed applied to every generator.
    """
    # Note: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. worker subprocesses); the running kernel's str hashing is unchanged.
    os.environ['PYTHONHASHSEED'] = str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    # The PyTorch model/DataLoaders draw from torch's own generators.
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # TF parses this variable as a boolean ("0"/"1"/"true"/"false"); any other
    # value (such as the seed itself) is rejected and determinism stays off.
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    
def auto_select_accelerator():
    """Pick a TPU distribution strategy when one can be set up, else TF's default.

    Prints which accelerator was chosen and how many replicas the strategy
    synchronises across.

    Returns:
        A ``tf.distribute`` strategy object.
    """
    try:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        selected = tf.distribute.experimental.TPUStrategy(tpu_resolver)
        print("Running on TPU:", tpu_resolver.master())
    except ValueError:
        # Presumably raised when no TPU is reachable: fall back to the default strategy.
        selected = tf.distribute.get_strategy()
    replica_count = selected.num_replicas_in_sync
    print(f"Running on {replica_count} replicas")
    return selected


# Brief Descriptive Analysis and EDA (a fuller EDA can be done later)

In [None]:
# Ref: https://www.kaggle.com/andrewmvd/isic-2019.
# Note: Everything is done within a Kaggle Notebook, so all inputs live under
# the read-only /kaggle/input mount of the "isic-2019" dataset.

# Folder holding every training JPEG (train/val/test are all carved from it).
train_val_test_img_path = r"/kaggle/input/isic-2019/ISIC_2019_Training_Input/ISIC_2019_Training_Input"

# One-hot diagnosis labels and per-image patient metadata.
gt_path, metadata_path = (
    r"/kaggle/input/isic-2019/ISIC_2019_Training_GroundTruth.csv",
    r"/kaggle/input/isic-2019/ISIC_2019_Training_Metadata.csv",
)

ground_truth_df, metadata_df = map(pd.read_csv, (gt_path, metadata_path))

In [None]:
# Show both raw tables, separated by a blank line.
print("Ground Truth DataFrame")
display(ground_truth_df)
print("")
print("Metadata DataFrame")
display(metadata_df)

In [None]:
# Exploring the Ground Truth DataFrame: size, uniqueness of ids, dtypes.
print(f"Shape of Dataset: {ground_truth_df.shape}")
print(f"Number of Unique Image Identifiers (ID): {ground_truth_df['image'].nunique()}", end="\n\n")
print("<=======Info=======>")
ground_truth_df.info()
print()

## Every image id is expected to carry the "ISIC_" prefix; report any that don't.
for idx, image_name in enumerate(ground_truth_df["image"]):
    if "ISIC_" in image_name:
        continue
    print(f"Row {idx} has an invalid image name.")

## Sorted value counts for every column (for the one-hot class columns these
## are the per-class totals).
print("<=======Value Counts=======>")
for column in ground_truth_df.columns:
    counts = ground_truth_df[column].value_counts().sort_index()
    display(counts)
    print()

In [None]:
# Exploring the Metadata DataFrame: size, uniqueness of ids, dtypes.
print(f"Shape of Dataset: {metadata_df.shape}")
print(f"Number of Unique Image Identifiers (ID): {metadata_df['image'].nunique()}", end="\n\n")
print("<=======Info=======>")
metadata_df.info()
print()

## Every image id is expected to carry the "ISIC_" prefix; report any that don't.
for idx, image_name in enumerate(metadata_df["image"]):
    if "ISIC_" in image_name:
        continue
    print(f"Row {idx} has an invalid image name.")

## Sorted value counts for every metadata column.
print("<=======Value Counts=======>")
for column in metadata_df.columns:
    counts = metadata_df[column].value_counts().sort_index()
    display(counts)
    print()

In [None]:
# Exploring the images folder: count its entries and flag anything that is not a .jpg.
files = os.listdir(train_val_test_img_path)
print(f"Number of files in train folder: {len(files)}")
for file in files:
    if ".jpg" in file:
        continue
    print(f"Non-Image File Found: {file}")

# Hyperparameters and Pre-defined Terms

In [None]:
classes = [
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis', # Also: (solar lentigo / seborrheic keratosis / lichen planus-like keratosis).
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown' # Used for unlabelled scans.
]

# Column names in the ground-truth CSV, in the same order as `classes`.
classes_abbrev = ["MEL","NV","BCC","AK","BKL","DF","VASC","SCC","UNK"]
assert len(classes) == len(classes_abbrev)

# Final classes dictionary (abbreviation -> full name) which excludes "Unknown".
CLASSES_DICT = dict(zip(classes_abbrev[:-1], classes[:-1]))

seed = 42
n_splits = 1

# The accelerator strategy must exist before anything that reads its replica
# count (batch_size below); defining it later fails on a fresh kernel.
strategy = auto_select_accelerator()
batch_size = strategy.num_replicas_in_sync * 20

# PyTorch / segmentation_models_pytorch encoder settings.
encoder_name = "timm-efficientnet-b5"
in_channels = 3
depth = 5
pretrained_weights = "noisy-student"
in_features = 1024  # Width of the hidden FC layer in the classification head.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
init_lr = 0.0001

epochs = 20

MODEL_SAVE_PATH = '{}.pth'.format(encoder_name)

# Datasets, Splits and Data Pipelines

In [None]:
class skin_cancer_ds(Dataset):
    """ISIC 2019 image dataset yielding normalized image tensors (and labels).

    Args:
        df: DataFrame with an ``image`` column (file stem inside
            ``train_val_test_img_path``) and, for 'train'/'valid', one column per
            key of ``CLASSES_DICT``; the argmax over those columns is the label.
        image_size: Side length in pixels of the square output image.
        mode: 'train' (random crop/flip/blur/cutout augmentation), 'valid' or
            'test' (plain resize). 'test' items are returned without a label.
    """

    def __init__(self, df, image_size, mode):
        super(skin_cancer_ds, self).__init__()
        self.df = df
        self.image_size = image_size
        assert mode in ['train', 'valid', 'test']
        self.mode = mode

        if self.mode == 'train':
            self.transform = A.Compose([
                A.RandomResizedCrop(height=self.image_size, width=self.image_size, scale=(0.25, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=1, p=1.0),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=30, interpolation=1, border_mode=0, value=0, p=0.25),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.OneOf([
                    A.MotionBlur(p=.2),
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.Blur(blur_limit=3, p=0.1),
                ], p=0.25),
                A.Cutout(num_holes=4, max_h_size=32, max_w_size=32, fill_value=0, p=0.25),
                A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
                ToTensorV2(),
            ])

        else:
            self.transform = A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``image`` for 'test', else ``(image, label)`` with an int64 class index."""
        img_path = train_val_test_img_path + f'/{self.df.loc[index]["image"]}.jpg'
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise surface later as an opaque np.stack error.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): colour is discarded and the gray channel is replicated 3x
        # to feed a 3-channel ImageNet-normalized encoder — confirm this is intended.
        image = np.stack([image, image, image], axis=-1)
        image = self.transform(image=image)["image"]
        if self.mode == 'test':
            return image
        # Select the class columns with a list: pandas 2.x raises TypeError
        # when a dict is used as an indexer.
        one_hot = self.df.loc[index, list(CLASSES_DICT)].to_numpy(dtype=np.float32)
        label = torch.tensor(np.argmax(one_hot))
        return image, label

In [None]:
seed_everything(seed)

# Per-row integer class index (argmax of the one-hot columns) drives the
# stratification. Columns are selected with a list: pandas 2.x raises
# TypeError when a dict is used as an indexer.
class_columns = list(CLASSES_DICT)

# Splitting train+val (90%) from test (10%) via a stratified shuffle split.
# With n_splits > 1 only the last split produced would be kept.
trainval_test_split = StratifiedShuffleSplit(n_splits=n_splits, train_size=0.9, random_state=seed)
for train_val_index, test_index in trainval_test_split.split(
    ground_truth_df["image"].values,
    np.argmax(ground_truth_df[class_columns].values, axis=1),
):
    train_val_df = ground_truth_df.loc[train_val_index].reset_index(drop=True)
    test_df = ground_truth_df.loc[test_index].reset_index(drop=True)

# Splitting train (90%) and val (10%) of the train+val portion the same way.
train_val_split = StratifiedShuffleSplit(n_splits=n_splits, train_size=0.9, random_state=seed)
for train_index, val_index in train_val_split.split(
    train_val_df["image"].values,
    np.argmax(train_val_df[class_columns].values, axis=1),
):
    train_df = train_val_df.loc[train_index].reset_index(drop=True)
    val_df = train_val_df.loc[val_index].reset_index(drop=True)

# Sanity check: the three splits together cover the ground truth exactly.
assert len(train_df) + len(val_df) + len(test_df) == len(ground_truth_df)

In [None]:
# Training split: 90% of the 90% train+val portion (~81% of all labelled rows).
train_df

In [None]:
# Validation split: 10% of the train+val portion (~9% of all labelled rows).
val_df

In [None]:
# Held-out test split: 10% of all labelled rows, never used for training.
test_df

In [None]:
IMAGE_SIZE = 384  # Side length (pixels) of the square images fed to the PyTorch model.

# Only the training split gets random augmentation (see skin_cancer_ds).
train_ds, val_ds, test_ds = (
    skin_cancer_ds(frame, IMAGE_SIZE, split_mode)
    for frame, split_mode in ((train_df, "train"), (val_df, "valid"), (test_df, "test"))
)

In [None]:
# Training batches are reshuffled every epoch; val/test iterate in row order,
# so outputs line up with the rows of val_df / test_df.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [None]:
# GCS location of the Kaggle "isic-2019" dataset. The tf.data pipeline below
# reads images from here rather than /kaggle/input (presumably because TPU
# workers cannot see the local mount — confirm).
DS_GCS_PATH = KaggleDatasets().get_gcs_path("isic-2019")  # Trying TPU with tf.

In [None]:
# Folder of the training JPEGs relative to the dataset's GCS root.
tmp_path = "/ISIC_2019_Training_Input/ISIC_2019_Training_Input/"


def gcs_image_paths(frame):
    """Return the full GCS JPEG path for every ``image`` id in ``frame``, in row order."""
    return [DS_GCS_PATH + tmp_path + image_name + ".jpg" for image_name in frame.image.values]


train_paths = gcs_image_paths(train_df)
val_paths = gcs_image_paths(val_df)
test_paths = gcs_image_paths(test_df)

# One-hot label matrices, one column per CLASSES_DICT key ("UNK" excluded).
# Columns are selected with a list: pandas 2.x raises TypeError for dict indexers.
label_columns = list(CLASSES_DICT)
train_labels = train_df[label_columns].values
val_labels = val_df[label_columns].values
test_labels = test_df[label_columns].values

In [None]:
def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    """Build a tf.data map function that reads, decodes and resizes image files.

    Args:
        with_labels: If True the returned function maps ``(path, label)`` to
            ``(image, label)``; otherwise it maps ``path`` to ``image``.
        target_size: ``(height, width)`` passed to ``tf.image.resize`` (the
            aspect ratio is not preserved).
        ext: Image file extension: 'png', 'jpg' or 'jpeg' (case-insensitive).

    Returns:
        The decode function. Images come out as 3-channel float32 scaled by
        1/255, i.e. in [0, 1].

    Raises:
        ValueError: If ``ext`` is unsupported. Checked here, when the decoder is
            built, instead of later inside tf.data's graph tracing.
    """
    ext = ext.lower()
    if ext not in ('png', 'jpg', 'jpeg'):
        raise ValueError("Image extension not supported")

    def decode(path):
        file_bytes = tf.io.read_file(path)

        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        else:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode

def build_augmenter(with_labels=True):
    """Build a tf.data map function applying random flip/colour augmentation.

    The image is randomly flipped left-right then up-down, then jittered in
    saturation (factor 0.9-1.1), contrast (factor 0.9-1.1) and brightness
    (delta up to 0.1), in that order. The result is not clipped.

    Args:
        with_labels: If True the returned function maps ``(img, label)`` to
            ``(augmented_img, label)``; otherwise it maps ``img`` only.
    """
    def augment(img):
        for flip in (tf.image.random_flip_left_right, tf.image.random_flip_up_down):
            img = flip(img)
        img = tf.image.random_saturation(img, 0.9, 1.1)
        img = tf.image.random_contrast(img, 0.9, 1.1)
        return tf.image.random_brightness(img, 0.1)

    if not with_labels:
        return augment

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels


def build_dataset(paths, labels=None, bsize=128, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, shuffle=1024,
                  cache_dir=""):
    """Assemble a batched, prefetched ``tf.data`` pipeline from image paths.

    Stages, in order: slice -> decode -> (cache) -> (augment) -> (shuffle)
    -> batch -> prefetch.

    Args:
        paths: Image file paths.
        labels: Optional per-path labels; when None the pipeline yields images only.
        bsize: Batch size.
        cache: Cache decoded elements (to ``cache_dir``; an empty string caches in memory).
        decode_fn / augment_fn: Map functions; default to ``build_decoder`` /
            ``build_augmenter`` matched to whether labels are given.
        augment: Whether to apply ``augment_fn``.
        shuffle: Shuffle buffer size; any falsy value disables shuffling.
        cache_dir: Cache location; created on disk when non-empty and ``cache is True``.
    """
    has_labels = labels is not None

    if cache is True and cache_dir != "":
        os.makedirs(cache_dir, exist_ok=True)

    decode_fn = decode_fn if decode_fn is not None else build_decoder(has_labels)
    augment_fn = augment_fn if augment_fn is not None else build_augmenter(has_labels)

    autotune = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if has_labels else paths

    pipeline = tf.data.Dataset.from_tensor_slices(source)
    pipeline = pipeline.map(decode_fn, num_parallel_calls=autotune)
    if cache:
        pipeline = pipeline.cache(cache_dir)
    if augment:
        pipeline = pipeline.map(augment_fn, num_parallel_calls=autotune)
    if shuffle:
        pipeline = pipeline.shuffle(shuffle)
    return pipeline.batch(bsize).prefetch(autotune)

In [None]:
# Shared decoder: read JPEGs, resize to 512x512 and keep the one-hot label.
decoder = build_decoder(with_labels=True, target_size=(512, 512), ext='jpg')

# NOTE(review): build_dataset's default cache=True with cache_dir="" keeps every
# decoded float32 512x512x3 image (~3 MB each) in memory — confirm this fits.

# Training pipeline: augmented and shuffled.
train_dataset = build_dataset(train_paths, train_labels,
                              bsize=batch_size, decode_fn=decoder)

# Validation pipeline: fixed order, no augmentation.
valid_dataset = build_dataset(val_paths, val_labels,
                              bsize=batch_size, decode_fn=decoder,
                              shuffle=False, augment=False)

In [None]:
with strategy.scope():
    model = keras.Sequential([
                # build_decoder scales pixels to [0, 1], but Keras' EfficientNet
                # models contain their own rescaling/normalization and expect raw
                # [0, 255] pixel values, so undo the /255 before the backbone.
                tf.keras.layers.Lambda(lambda x: x * 255.0, input_shape=(512, 512, 3)),
                keras.applications.efficientnet.EfficientNetB3(include_top=False,
                                                               input_shape=(512, 512, 3)),
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dense(len(CLASSES_DICT.keys()), activation='softmax')
            ])

    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss='categorical_crossentropy',
                  metrics=[tf.keras.metrics.AUC(multi_label=True)])

# Checkpoint and LR schedule track the held-out validation loss; monitoring the
# training loss would keep saving (and never reduce LR for) an overfitting model.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
        'model.h5', save_best_only=True, monitor='val_loss', mode='min')
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', patience=3, min_lr=1e-6, mode='min')

history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint, lr_reducer])

# Building the PyTorch Model

In [None]:
efnb5_noisy_student_encoder = get_encoder(encoder_name,
                                          in_channels=in_channels,
                                          depth=depth,
                                          weights=pretrained_weights)


class EfficientNetB5ClsHead(nn.Module):
    """Classifier = SMP EfficientNet encoder + backbone pooling head + 2-layer MLP.

    The encoder's deepest feature map is passed through the backbone's own
    ``conv_head``/``bn2``/``act2``/``global_pool`` layers (moved out of the
    encoder into ``flatten_block``), then Linear -> ReLU -> Dropout(0.5) ->
    Linear producing one raw logit per class in ``CLASSES_DICT``.

    Args:
        encoder: Encoder from ``segmentation_models_pytorch.encoders.get_encoder``;
            its forward returns a list of feature maps whose last entry is the deepest.
        in_features: Width of the hidden fully-connected layer.
    """

    # Backbone layers that turn the deepest feature map into a pooled vector, in
    # forward order. Some timm versions may fold act2 into bn2, so names the
    # encoder does not have are skipped.
    _HEAD_LAYER_NAMES = ("conv_head", "bn2", "act2", "global_pool")

    def __init__(self, encoder, in_features):
        super(EfficientNetB5ClsHead, self).__init__()
        self.encoder = encoder

        # Pick the head layers by name instead of slicing children()[-4:], which
        # depends on the exact module layout of the installed timm version.
        head_names = [name for name in self._HEAD_LAYER_NAMES if hasattr(self.encoder, name)]
        self.flatten_block = nn.Sequential(*[getattr(self.encoder, name) for name in head_names])

        # Width of the pooled vector = channels produced by conv_head
        # (2048 for timm-efficientnet-b5); read it before conv_head is removed.
        pooled_features = self.encoder.conv_head.out_channels

        # Remove the moved layers from the encoder so each is registered once
        # (otherwise the same weights would appear under two state_dict keys).
        for name in head_names:
            delattr(self.encoder, name)

        self.fc = nn.Linear(pooled_features, in_features, bias=True)
        self.cls_head = nn.Linear(in_features, len(CLASSES_DICT.keys()), bias=True)

        # Xavier uniform weight initialization.
        init.initialize_head(self.fc)
        init.initialize_head(self.cls_head)

    def forward(self, x):
        """Return raw class logits of shape (batch_size, len(CLASSES_DICT))."""
        x = self.encoder(x)[-1]  # Deepest feature map, (batch_size, C, h, w).
        x = self.flatten_block(x)  # (batch_size, pooled_features) — fc below needs a 2-D input.
        x = F.relu(self.fc(x))  # (batch_size, in_features).
        x = F.dropout(x, p=0.5, training=self.training)
        return self.cls_head(x)  # (batch_size, len(CLASSES_DICT)).


model = EfficientNetB5ClsHead(efnb5_noisy_student_encoder, in_features=in_features)

# Training

In [None]:
model.to(device)

# To handle class imbalance we can weigh each class.
# Do something like this and pass it into the Loss function:

# CE_weights = torch.zeros(len(CLASSES_DICT.keys()))  # This takes into account the imbalanced dataset.
# Increment CE_weights e.g. class 0 has 2439 counts then CE_weights[0] has 2439.
# CE_weights = 1. / CE_weights.clamp_(min=1.)  # Weights should be inversely related to count.
# CE_weights = (CE_weights * len(CLASSES_DICT) / CE_weights.sum()).to(device)

criterion = nn.CrossEntropyLoss(weight=None)
optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)

# One cosine cycle over the run, stepped once per epoch AFTER that epoch's
# optimizer updates. Stepping first would skip init_lr for epoch 1 (and
# PyTorch warns about it). With T_max=epochs the final epoch still gets a non-zero LR.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# Mixed precision only makes sense on CUDA; elsewhere both are disabled explicitly.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_one_epoch(loader, epoch):
    """Run one mixed-precision training pass over ``loader``; return the mean batch loss."""
    model.train()
    batch_losses = []
    loop = tqdm(loader, desc='Epoch {:02d}/{:02d}'.format(epoch, epochs))
    for images, labels in loop:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs.float(), labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(loss.item())
        loop.set_postfix(loss=np.mean(batch_losses))
    return np.mean(batch_losses)


def evaluate(loader):
    """Return the per-sample mean loss of ``model`` over ``loader`` (no gradients)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images = images.to(device)
            labels = labels.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs.float(), labels)
            total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


val_loss_min = np.inf  # Save model with best performance on val_loss.

run = wandb.init(project="cancer_classification_IH2021", name="efnb5-noisy-stdnt")  # Initialize a project.
try:
    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(train_loader, epoch)
        val_loss = evaluate(val_loader)
        scheduler.step()

        print('train loss: {:.5f} | val_loss: {:.5f}'.format(train_loss, val_loss))
        wandb.log({"epoch": epoch,
                   "loss": train_loss,
                   "val_loss": val_loss,
                   })

        if val_loss < val_loss_min:
            print('Valid loss improved from {:.5f} to {:.5f}, saving model to wandb.'.format(val_loss_min, val_loss))
            val_loss_min = val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            artifact = wandb.Artifact(encoder_name, type='model')
            artifact.add_file(MODEL_SAVE_PATH, name=f"model{epoch}.pt")
            run.log_artifact(artifact)

        # Release cached memory once per epoch; doing it every batch stalls training.
        gc.collect()
        torch.cuda.empty_cache()
finally:
    # Close the W&B run even if training is interrupted or raises.
    run.finish()