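# Baseline

Baseline for the ISIC 2024 challenge. A pretrained timm backbone (`tf_efficientnet_b0` by default) is fine-tuned on the 2024 training images plus the positive (malignant) cases from the ISIC 2018–2020 datasets. Training uses patient-grouped stratified folds, a class-balanced sampler and a weighted BCE loss.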

In [None]:
# !python -m pip install -q lightning
# !pip install -q git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
import os, gc, sys, copy
from pathlib import Path
import glob
from collections import defaultdict, Counter
from tqdm.auto import tqdm
tqdm.pandas()

import random
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import model_selection
from sklearn import metrics
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, SequentialSampler

import albumentations as A
import timm

from transformers import get_cosine_schedule_with_warmup

import cv2
import PIL
from IPython import display

import tensorflow as tf

import warnings
warnings.filterwarnings("ignore")

In [None]:
def seeding(seed):
    """Seed Python, NumPy and PyTorch RNGs for (mostly) reproducible runs.

    cuDNN is deliberately left non-deterministic with benchmark mode on, which
    favours GPU speed over bit-exact reproducibility.

    Args:
        seed: integer seed applied to every RNG.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so setting it here mainly
    # affects processes spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        print("Seeding done ...")
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Speed over determinism: cuDNN may autotune and pick non-deterministic kernels.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    print("Seeding done ...")
    
def flush():
    """Collect Python garbage and, on CUDA machines, release cached GPU memory.

    Also resets the CUDA peak-memory statistics so the next measurement starts fresh.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
CONFIG = dict(
    seed = 42,
    nfolds = 5,
    backbone = "tf_efficientnet_b0.ns_jft_in1k", # convnext_base.fb_in1k,

    drop_rate = 0,          # timm encoder dropout; also the LSTM inter-layer dropout in ISICModel
    drop_rate_last = 0.3,   # dropout inside the classification head
    drop_rate_path = 0.,    # passed to timm as drop_path_rate
    out_dim = 1,            # NOTE(review): not read by the model code; its head hard-codes 1 output
    batch_size = 32,
    img_size = 224,
    lr = 1e-3,
    warmup = 1,             # warm-up length in epochs (multiplied by steps per epoch)
    num_cycles = 0.475,
    epochs = 5,
    patience = 7,           # NOTE(review): unused; early stopping is driven by ES_RATIO in train_and_validate
    log_wandb = True,
    with_clip = False,
    # Always a torch.device, so downstream code sees a single type.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
)

if CONFIG['log_wandb']:
    CONFIG['project_name'] = "ISIC2024-TransferLearning"
    CONFIG['artifact_name'] = "isicBaseModel"
    import wandb
    from kaggle_secrets import UserSecretsClient
    # Read the key from Kaggle secrets rather than hard-coding it, and do not
    # keep it around in the notebook namespace after logging in.
    wandb_api_key = UserSecretsClient().get_secret("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    del wandb_api_key

seeding(CONFIG['seed'])

In [None]:
# All datasets are mounted under the Kaggle input directory.
KAGGLE_INPUT_DIR = Path("/kaggle/input")
DATA_PATH_2024 = KAGGLE_INPUT_DIR / "isic-2024-challenge"
DATA_PATH_2020 = KAGGLE_INPUT_DIR / "isic-2020-jpg-256x256-resized"
DATA_PATH_2019 = KAGGLE_INPUT_DIR / "isic-2019-jpg-256x256-resized"
DATA_PATH_2018 = KAGGLE_INPUT_DIR / "isic-2018-jpg-256x256-resized"

In [None]:
def _read_train_metadata(data_root):
    """Load ``train-metadata.csv`` from a dataset root directory."""
    return pd.read_csv(data_root / "train-metadata.csv")

df_train_2024 = _read_train_metadata(DATA_PATH_2024)
df_train_2020 = _read_train_metadata(DATA_PATH_2020)
df_train_2019 = _read_train_metadata(DATA_PATH_2019)
df_train_2018 = _read_train_metadata(DATA_PATH_2018)

## Check image paths

In [None]:
def _make_image_path_getter(data_root):
    """Return a function mapping an ``isic_id`` to ``<data_root>/train-image/image/<isic_id>.jpg``."""
    image_dir = data_root / "train-image" / "image"
    return lambda isic_id: str(image_dir / f"{isic_id}.jpg")

get_image_path_2024 = _make_image_path_getter(DATA_PATH_2024)
get_image_path_2020 = _make_image_path_getter(DATA_PATH_2020)
get_image_path_2019 = _make_image_path_getter(DATA_PATH_2019)
get_image_path_2018 = _make_image_path_getter(DATA_PATH_2018)

check_path = lambda p: tf.io.gfile.exists(p)

df_train_2024['image_path'] = df_train_2024['isic_id'].progress_apply(get_image_path_2024)
df_train_2020['image_path'] = df_train_2020['isic_id'].progress_apply(get_image_path_2020)
df_train_2019['image_path'] = df_train_2019['isic_id'].progress_apply(get_image_path_2019)
df_train_2018['image_path'] = df_train_2018['isic_id'].progress_apply(get_image_path_2018)

# Only the 2018 metadata is validated against the image folder; rows whose
# image file does not exist are dropped.
print("\nChecking 2018 image files ...")
df_train_2018['exists'] = df_train_2018['image_path'].progress_apply(check_path)
display.display(df_train_2018['exists'].value_counts())
# drop=True: don't carry the old index along as a stray 'index' column.
df_train_2018 = df_train_2018[df_train_2018['exists']].reset_index(drop=True)

In [None]:
# Keep every 2024 row, but only the positive (target == 1) rows of the external
# 2018-2020 datasets, to enrich the rare positive class.
target_2020 = df_train_2020.loc[df_train_2020['target'].eq(1)]
target_2019 = df_train_2019.loc[df_train_2019['target'].eq(1)]
target_2018 = df_train_2018.loc[df_train_2018['target'].eq(1)]

combined_data = (
    pd.concat([df_train_2024, target_2020, target_2019, target_2018], axis=0)
    .reset_index(drop=True)
    .loc[:, ['patient_id', 'target', 'image_path']]
)
# pd.concat([extern_data] * 10)
# extern_data = extern_data

In [None]:
# Preview the first images of the 2024 training set, one per grid cell.
fig, axes = plt.subplots(4, 5, figsize=(7, 7))
for ax, path in zip(axes.ravel(), df_train_2024['image_path'].head(axes.size)):
    # The context manager closes the file once the image has been decoded.
    with PIL.Image.open(path) as img:
        ax.imshow(img.convert("RGB"))
    ax.axis(False)

plt.tight_layout()
plt.show()

# Distribution of targets in the datasets

In [None]:
# plt.title("2024")
# df_train_2024['target'].value_counts().plot(kind='bar', figsize=(20,4));

## Split data

In [None]:
kfold = model_selection.StratifiedGroupKFold(
    n_splits=CONFIG['nfolds'],
    shuffle=True,
    random_state=CONFIG['seed'],
)

# The data is shuffled twice with the same seed. This is redundant, but it is
# kept so the fold assignment matches earlier runs.
df = combined_data.sample(frac=1, random_state=CONFIG['seed']).reset_index(drop=True)
df = df.sample(frac=1, random_state=CONFIG['seed']).reset_index(drop=True)

# Stratify on the label and group by patient, so a patient's images never end
# up on both sides of a train/validation split.
x = df.index.values
y = df.target.astype(int)
g = df['patient_id']

fold_ids = np.full(len(df), -1)
for fold, (_, val_idx) in enumerate(kfold.split(x, y, g)):
    fold_ids[val_idx] = fold
df['fold'] = fold_ids

df.groupby('fold')['target'].value_counts()

# Create dataset

In [None]:
class ISICDataset(Dataset):
    """Image dataset over a metadata frame that has an ``image_path`` column.

    Args:
        df: frame with an ``image_path`` column and the label column ``target``.
        mode: stored on the instance but not used by this class.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)['image']``.
        target: name of the label column (default ``'target'``).

    Each item is ``{"image": float32 tensor (C, H, W), "target": tensor}``.
    """

    def __init__(self, df, mode='train', transform=None, target='target'):
        self.df = df
        self.mode = mode
        self.transform = transform
        self.target_col = target
        self.label = df[target]

    def __len__(self):
        return len(self.df)

    def _load_image(self, path):
        """Decode the image at ``path`` into an RGB uint8 array of shape (H, W, 3)."""
        # The context manager closes the file handle once the pixels are copied.
        with PIL.Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self._load_image(row['image_path'])
        if self.transform is not None:
            img = self.transform(image=img)['image']

        # HWC -> CHW. Scaling is left to the transform (e.g. A.Normalize);
        # without a transform the values stay in the raw 0-255 range.
        img = img.transpose(2, 0, 1).astype(np.float32)
        # Read the configured label column, not a hard-coded 'target'.
        target = row[self.target_col]
        return {"image": torch.tensor(img).float(), "target": torch.tensor(target)}

    def get_labels(self):
        """Return the label column as a pandas Series."""
        return self.label

# xdf = df[df['fold'] == 3].reset_index(drop=True)
# labels = xdf['target'].astype(int)
# "Balanced" class weights over the whole dataset (all folds):
# weight_c = n_samples / (n_classes * count_c). WEIGHTS[1] is later used as the
# BCE loss pos_weight.
labels = df['target'].astype(int)
class_ids = np.unique(labels)
WEIGHTS = compute_class_weight(class_weight="balanced", classes=class_ids, y=labels)
WEIGHTS

## Check distribution of labels within the batch

To avoid highly imbalanced batches, it's preferable to use larger batch sizes (at least 16).

In [None]:
def check_class_distribution(dataloader, log_every=500):
    """Print the label counts of periodic batches and return totals over all batches.

    Args:
        dataloader: iterable of batches, each a dict with a ``'target'`` tensor.
        log_every: print every ``log_every``-th batch, starting with batch 0;
            ``0`` disables printing.

    Returns:
        collections.Counter mapping label value -> number of samples seen.
    """
    totals = Counter()
    for i, batch in enumerate(dataloader):
        # tolist() yields plain Python ints, so counts print without numpy reprs.
        class_dist = Counter(batch['target'].detach().cpu().tolist())
        totals.update(class_dist)
        if log_every and i % log_every == 0:
            print(f"Batch {i}\tClass Distribution: {class_dist}")
    return totals
        
# dls = DataLoader(ds, batch_size=32, sampler=sampler, num_workers=os.cpu_count(), drop_last=True)
# check_class_distribution(dls)

## Create data loaders

In [None]:
def get_transforms(height, width, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build albumentations pipelines for training and evaluation.

    Args:
        height, width: output image size after resizing.
        mean, std: per-channel normalisation statistics (default: ImageNet),
            applied with ``max_pixel_value=255``.

    Returns:
        dict with ``"train"`` (augmentation + resize + normalise) and ``"eval"``
        (resize + normalise only) ``A.Compose`` pipelines.
    """
    def _normalize():
        # A fresh instance per pipeline rather than one shared object.
        return A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1)

    train_tsfm = A.Compose([
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=(-25, 25), p=0.5),

        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75),

        A.OneOf([
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ], p=0.7),

        A.OneOf([
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.0),
            A.ElasticTransform(alpha=3),
        ], p=0.7),

        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),

        A.Resize(height=height, width=width, p=1.0),
        A.CoarseDropout(max_holes=2, max_height=int(height * 0.275), max_width=int(width * 0.275), p=0.7),
        _normalize(),
    ])

    valid_tsfm = A.Compose([
        # height=height (not width) so non-square target sizes resize correctly.
        A.Resize(height=height, width=width, p=1.0),
        _normalize(),
    ])
    return {"train": train_tsfm, "eval": valid_tsfm}


def _balanced_sampler(labels):
    """Build a with-replacement WeightedRandomSampler that draws every class equally often in expectation.

    Each sample gets its class's "balanced" weight n_samples / (n_classes * count_c),
    so the total weight of every class is the same.
    """
    labels = np.asarray(labels)
    classes = np.unique(labels)
    class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=labels)
    # searchsorted maps each label to its position in `classes`, so labels need
    # not be the integers 0..K-1 to index the weight table.
    samples_weights = torch.as_tensor(class_weights[np.searchsorted(classes, labels)])
    return WeightedRandomSampler(weights=samples_weights,
                                 num_samples=len(samples_weights),
                                 replacement=True)


def get_dataloaders(data, cfg, split="train"):
    """Create a DataLoader for one split of ``data``.

    Args:
        data: metadata frame accepted by ``ISICDataset``.
        cfg: config dict; reads ``img_size`` and ``batch_size``.
        split: ``'train'`` gives an augmented, class-balanced loader
            (WeightedRandomSampler, ``drop_last=True``). ``'valid'`` or ``'test'``
            gives an ordered, non-augmented loader with twice the batch size.

    Raises:
        ValueError: for any other ``split``.
    """
    img_size = cfg['img_size']
    tsfm = get_transforms(height=img_size, width=img_size)
    # os.cpu_count() may return None; DataLoader needs an int.
    num_workers = os.cpu_count() or 0

    if split == 'train':
        ds = ISICDataset(data, transform=tsfm['train'])
        return DataLoader(ds,
                          batch_size=cfg['batch_size'],
                          sampler=_balanced_sampler(ds.get_labels()),
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=True)

    if split in ('valid', 'test'):
        ds = ISICDataset(data, transform=tsfm['eval'])
        return DataLoader(ds,
                          batch_size=2 * cfg['batch_size'],
                          shuffle=False,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)

    # ValueError subclasses Exception, so existing `except Exception` handlers still match.
    raise ValueError("Split should be 'train' or 'valid' or 'test'!!!")

In [None]:
dls = get_dataloaders(df, CONFIG, split='train')
# check_class_distribution(dls)

b = next(iter(dls))

# Undo the ImageNet A.Normalize from get_transforms so imshow gets values in [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
images = b['image'].detach().cpu().numpy().transpose(0, 2, 3, 1)
images = np.clip(images * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)
targets = b['target'].detach().cpu().numpy()

fig, axes = plt.subplots(4, 8, figsize=(10, 10))
for ax in axes.ravel():
    ax.axis(False)
# zip stops at the shorter of grid / batch, so any batch_size works.
for ax, img, tgt in zip(axes.ravel(), images, targets):
    ax.imshow(img)
    ax.set_title(f"{tgt}")
plt.tight_layout()
plt.show()

# Model setup

In [None]:
class ISICModel(nn.Module):
    """timm backbone -> global average pool -> 2-layer bidirectional LSTM -> MLP head.

    Produces one logit per image, shape ``(batch, 1)``.

    Args:
        backbone: timm model name.
        pretrained: whether timm loads pretrained weights.

    Dropout rates are read from the global ``CONFIG`` (``drop_rate``,
    ``drop_rate_path``, ``drop_rate_last``).
    """

    def __init__(self, backbone, pretrained=False):
        super().__init__()
        self.encoder = timm.create_model(
            backbone,
            features_only=False,
            drop_rate=CONFIG['drop_rate'],
            drop_path_rate=CONFIG['drop_rate_path'],
            pretrained=pretrained,
        )

        self.nb_fts = self.encoder.num_features
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.lstm = nn.LSTM(self.nb_fts, 256, num_layers=2, dropout=CONFIG["drop_rate"],
                            bidirectional=True, batch_first=True)
        # 512 = 2 directions x 256 hidden units.
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Map a batch of images ``x`` (B, C, H, W) to logits of shape (B, 1)."""
        feat = self.encoder.forward_features(x)   # assumes a 4-D (B, C', h, w) feature map
        feat = self.gap(feat)[:, :, 0, 0]         # (B, C')
        # Give each image its own length-1 sequence: (B, 1, C'). A 2-D input is
        # treated by nn.LSTM as ONE unbatched sequence of length B, which would
        # mix information across the images of a batch.
        feat, _ = self.lstm(feat.unsqueeze(1))    # (B, 1, 512)
        return self.head(feat[:, 0])              # (B, 1)

    def freeze_encoder(self, flag):
        """Freeze (``flag=True``) or unfreeze (``flag=False``) every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = not flag
            
# net = ISICModel('resnet18')
# x = torch.rand(8, 3, 224, 224)
# net.eval()
# net(x)

## Weighted binary cross-entropy loss

Since the data is heavily imbalanced, it is good practice to use a weighted loss for this task.

For an imbalanced binary classification dataset, treat the minority class as the positive class and the majority class as the negative class ([how-can-i-know-which-is-the-positive-class-value-and-negative-class-value](https://stackoverflow.com/questions/65304302/how-can-i-know-which-is-the-positive-class-value-and-negative-class-value-for-xg)).

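In the 2024 dataset only a **very small fraction of the samples is positive (malignant)**. The balanced class weights (`WEIGHTS`) are therefore roughly `0.5` for the negative class and `510` for the positive class.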

After adding the external malignant samples, these weights become roughly `0.5` and `33`.

More information regarding how to deal with class imbalance can be found here [imbalanced_data](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data), [class-imbalance-weighted-binary-cross-entropy](https://www.kaggle.com/code/parthdhameliya77/class-imbalance-weighted-binary-cross-entropy)

# Training with PyTorch Lightning (disabled — the cells below are commented out and kept for reference)

In [None]:
# def get_scheduler(optimizer, warmup=True):
#     if warmup:
#         scheduler = get_cosine_schedule_with_warmup(
#                 optimizer,
#                 num_warmup_steps=CONFIG["warmup"] * (CONFIG['n_steps_per_epoch']),
#                 num_training_steps=CONFIG["epochs"]* (CONFIG['n_steps_per_epoch']),
#                 num_cycles = CONFIG["num_cycles"],
#             )
#     else:
#         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"], eta_min=0)
#     return scheduler

In [None]:
# class ISICLightningModel(pl.LightningModule):
#     def __init__(self, pretrained=False):
#         super().__init__()
#         self.model = ISICModel(CONFIG['backbone'], pretrained=pretrained)
# #         print(WEIGHTS[1])
#         self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(WEIGHTS[1]))
    
#     def forward(self, image):
#         return self.model(image)
    
#     def shared_step(self, batch, stage):
#         images, labels = batch['image'], batch['target']
#         logits = self.forward(images)
#         loss = self.loss_fn(logits.squeeze(), labels)
#         preds = logits.sigmoid()
#         self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        
#         outputs = {
#             "loss": loss,
#             "preds": preds
#         }
#         return outputs
    
    
#     def training_step(self, batch, batch_idx):
#         return self.shared_step(batch, stage='train')
    
#     def validation_step(self, batch, batch_idx):
#         return self.shared_step(batch, stage='valid')
    
#     def configure_optimizers(self):
#         optimizer = torch.optim.AdamW(self.parameters(), lr=CONFIG['lr'])
#         after_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"], eta_min=0)
#         scheduler = GradualWarmupScheduler(optimizer, multiplier=1, 
#                                            total_epoch=math.ceil(CONFIG["warmup"]*CONFIG['epochs']), 
#                                            after_scheduler=after_scheduler)
# #         scheduler = get_scheduler(optimizer)
#         return [optimizer], [scheduler]

In [None]:
# for fold in range(5):
#     train_ds = df[df['fold'] != fold].reset_index(drop=True)
#     valid_ds = df[df['fold'] == fold].reset_index(drop=True)
    
#     train_loader = get_dataloaders(train_ds, CONFIG, split="train")
#     valid_loader = get_dataloaders(valid_ds, CONFIG, split="valid")
#     n_steps_per_epoch = math.ceil(len(train_loader.dataset) / CONFIG['batch_size'])
#     CONFIG['n_steps_per_epoch'] = n_steps_per_epoch
    
#     if CONFIG['log_wandb']:
#         wandb_logger = WandbLogger(
#             project=CONFIG["project_name"],
#             checkpoint_name=f'{CONFIG["artifact_name"]}_{fold}',
#             log_model="all",
#         )
        
#     logger = wandb_logger if CONFIG['log_wandb'] else None

#     callbacks = [
#         ModelCheckpoint(save_weights_only=True, 
#                         mode="min", 
#                         monitor="valid_loss"),
#         LearningRateMonitor("epoch"),
#     ]
    
#     net = ISICLightningModel(pretrained=True)
    
#     trainer = pl.Trainer(accelerator="gpu", devices=1, 
#                          precision="16-mixed",
#                          max_epochs=CONFIG['epochs'], 
#                          logger=logger, callbacks=callbacks, default_root_dir=os.getcwd())
    
#     trainer.fit(net, train_dataloaders=train_loader, val_dataloaders=valid_loader)
#     break
    
    
# if CONFIG['log_wandb']:
#     wandb.finish()

In [None]:
# timm.list_pretrained("resnet*")

# PyTorch training

In [None]:
class MetricMonitor:
    """Keep running averages of named scalar metrics for progress-bar display.

    Args:
        float_precision: number of decimals shown by ``__str__``.
    """

    def __init__(self, float_precision=4):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget all accumulated metrics."""
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        """Add one observation ``val`` (a number or single-element tensor) to ``metric_name``."""
        # float() turns a (possibly grad-tracking, on-device) 0-d tensor into a
        # plain number, so the running sum does not chain autograd nodes or keep
        # device tensors alive across the epoch.
        val = float(val)
        metric = self.metrics[metric_name]
        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            f"{name}: {metric['avg']:.{self.float_precision}f}"
            for name, metric in self.metrics.items()
        )
    
    
# def get_auc(y_true, y_preds, weights=None):
#     return metrics.roc_auc_score(y_true, y_preds, max_fpr=0.2)

# Training functions

In [None]:
def init_weights(m):
    """Xavier-uniform weights and a constant 0.01 bias for ``nn.Linear`` modules.

    Intended for ``model.apply(init_weights)``. Only modules whose type is
    exactly ``nn.Linear`` are touched; Linear subclasses and other layers are
    left unchanged.
    """
    if type(m) is nn.Linear:
        # xavier_uniform_ (trailing underscore) is the in-place, non-deprecated API.
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have no bias to fill.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

def shared_step(model, batch, criterion, device=None):
    """Run one forward pass and compute the loss for a batch.

    Args:
        model: module returning one logit per sample, shape (B, 1).
        batch: dict with ``'image'`` and ``'target'`` tensors.
        criterion: loss called as ``criterion(logits (B,), float64 targets (B,))``.
        device: where to move the batch; defaults to ``CONFIG["device"]``.

    Returns:
        dict with ``loss``, ``labels`` (targets on ``device``) and ``logits``.
        Despite its name, ``logits`` holds sigmoid probabilities of shape (B, 1).
    """
    if device is None:
        device = CONFIG["device"]
    image = batch['image'].to(device, non_blocking=True)
    target = batch['target'].to(device, non_blocking=True)
    outputs = model(image)
    # reshape(-1) rather than squeeze(): a batch of one must stay shape (1,) to
    # match the target, whereas squeeze() would collapse it to a 0-d tensor.
    loss = criterion(outputs.reshape(-1), target.to(torch.float64))

    return {
        "loss": loss,
        "labels": target,
        "logits": outputs.sigmoid(),
    }

# output = shared_step(net, b, criterion)
def train(train_loader, model, criterion, optimizer, epoch, scaler, scheduler):
    """Train ``model`` for one epoch with fp16 autocast and gradient scaling.

    Reads ``CONFIG['n_steps_per_epoch']`` and updates the global counters
    ``CONFIG['example_ct']`` and ``CONFIG['step_ct']``; the caller must
    initialise all three. Per-step metrics go to W&B when ``CONFIG['log_wandb']``
    is set. The last step is not logged here because the caller logs it together
    with the epoch summary.

    Args:
        epoch: 1-based epoch number, used for display and for the fractional
            ``train/epoch`` value.
        scheduler: optional per-step LR scheduler (``None`` to disable).

    Returns:
        dict with the last step's metrics plus ``train/epoch_loss`` (mean batch
        loss over the epoch, as a Python float).
    """
    metric_monitor = MetricMonitor()
    model.train()
    stream = tqdm(train_loader)
    n_steps = max(CONFIG['n_steps_per_epoch'], 1)
    n_batches = len(train_loader)
    train_loss = 0.0
    seen_batches = 0
    _train_metrics = {}
    for i, batch in enumerate(stream, start=1):
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = shared_step(model, batch, criterion)
            loss = outputs['loss']

        # Plain floats keep the monitor and the W&B payload free of graph-holding tensors.
        step_loss = loss.item()
        metric_monitor.update("Loss", step_loss)
        train_loss += step_loss
        seen_batches += 1
        CONFIG['example_ct'] += len(batch["image"])

        # backward pass, with gradient scaling
        scaler.scale(loss).backward()

        if CONFIG['with_clip']:
            # Gradients must be unscaled before clipping against a real-valued threshold.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

        lr = optimizer.param_groups[0]['lr']
        scaler.step(optimizer)
        scaler.update()

        _train_metrics = {
            "train/step_loss": step_loss,
            # Fractional epoch: runs from epoch-1 to epoch over this epoch's steps.
            "train/epoch": (epoch - 1) + i / n_steps,
            "train/example_ct": CONFIG['example_ct'],
            "lr": lr,
        }

        if CONFIG['log_wandb'] and i < n_batches:
            wandb.log(_train_metrics)

        CONFIG['step_ct'] += 1
        if scheduler is not None:
            scheduler.step()

        stream.set_description(
            "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
        )

    _train_metrics['train/epoch_loss'] = train_loss / max(seen_batches, 1)

    flush()
    return _train_metrics


def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` without gradients, under fp16 autocast.

    Each step's loss goes to W&B when ``CONFIG['log_wandb']`` is set.

    Returns:
        dict with the last ``valid/step_loss`` and ``valid/epoch_loss``. The
        epoch loss is the per-sample mean over the whole loader, as a Python
        float. It weights each batch by its size, so a smaller last batch is not
        over-counted.
    """
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(val_loader)
    valid_loss = 0.0
    n_samples = 0
    _valid_metrics = {}

    with torch.no_grad():
        for batch in stream:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                outputs = shared_step(model, batch, criterion)
            step_loss = outputs['loss'].item()
            batch_size = len(batch['target'])

            metric_monitor.update("Loss", step_loss)
            valid_loss += step_loss * batch_size
            n_samples += batch_size
            _valid_metrics = {"valid/step_loss": step_loss}

            if CONFIG['log_wandb']:
                wandb.log(_valid_metrics)

            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
            )

    _valid_metrics['valid/epoch_loss'] = valid_loss / max(n_samples, 1)
    flush()
    return _valid_metrics

In [None]:
def _unwrap_model(model):
    """Return the wrapped module if ``model`` is an ``nn.DataParallel``, else ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _wandb_safe_config(cfg):
    """Copy of ``cfg`` with ``torch.device`` values turned into strings for W&B."""
    return {k: str(v) if isinstance(v, torch.device) else v for k, v in cfg.items()}


def train_and_validate(model, train_dataset, val_dataset, fold=0):
    """Train ``model`` on one fold and checkpoint it whenever validation loss improves.

    Side effects: writes ``ISIC_2024_fold_{fold}_epoch_{epoch}.pth`` on every
    improvement and sets ``CONFIG['n_steps_per_epoch']``, ``CONFIG['example_ct']``
    and ``CONFIG['step_ct']``. When ``CONFIG['log_wandb']`` is set, it also logs
    the config, metrics and checkpoints to W&B.

    Training stops early after more than ``ceil(ES_RATIO * epochs)`` consecutive
    epochs without improvement.
    """
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        device_ids = list(range(torch.cuda.device_count()))
        print(f"\nUsing {len(device_ids)} GPUs to train ...\n")
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(CONFIG["device"])
#     model.apply(init_weights)
    train_loader = get_dataloaders(train_dataset, CONFIG, split="train")
    valid_loader = get_dataloaders(val_dataset, CONFIG, split="valid")

    # The train loader uses drop_last=True, so its length is the real number of
    # optimizer steps per epoch.
    CONFIG['n_steps_per_epoch'] = len(train_loader)
    CONFIG['example_ct'] = 0
    CONFIG['step_ct'] = 0

    run = None
    if CONFIG['log_wandb']:
        # Passing the config to init is what actually records it with the run.
        run = wandb.init(
            project=CONFIG["project_name"],
            config=_wandb_safe_config(CONFIG),
            resume="allow",
        )

    # NOTE(review): the training sampler already re-balances the classes
    # (WeightedRandomSampler in get_dataloaders), and pos_weight up-weights
    # positives a second time. Confirm this double correction is intended.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(WEIGHTS[1]))
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    scaler = torch.cuda.amp.GradScaler()
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=CONFIG["warmup"] * CONFIG['n_steps_per_epoch'],
        num_training_steps=CONFIG["epochs"] * CONFIG['n_steps_per_epoch'],
        num_cycles=CONFIG["num_cycles"],
    )

    best_metric = np.inf
    es = 0
    ES_RATIO = 0.3 if CONFIG["epochs"] < 30 else 0.20
    max_bad_epochs = math.ceil(ES_RATIO * CONFIG["epochs"])
    weights_file = "ISIC_2024_fold_{fold}_epoch_{epoch}.pth"
    for epoch in range(1, CONFIG["epochs"] + 1):
        _train_metrics = train(train_loader, model, criterion, optimizer, epoch, scaler, scheduler=scheduler)
        _valid_metrics = validate(valid_loader, model, criterion, epoch)

        val_loss = _valid_metrics['valid/epoch_loss']
        if run is not None:
            wandb.log({**_train_metrics, **_valid_metrics})

        if val_loss < best_metric:
            print(f"Best metric: ({best_metric:.6f} --> {val_loss:.6f}). Saving model ...")
            path = weights_file.format(fold=fold, epoch=epoch)
            # Save the unwrapped module so keys never carry a 'module.' prefix,
            # which matches how the model is wrapped above (any count > 1 GPUs).
            torch.save(_unwrap_model(model).state_dict(), path)
            best_metric = val_loss
            if run is not None:
                artifact = wandb.Artifact(f"{CONFIG['artifact_name']}_{fold}", type="model")
                artifact.add_file(path)
                run.log_artifact(artifact)
            es = 0
        else:
            es += 1

        if es > max_bad_epochs:
            print(f"Early stopping on epoch {epoch} ...")
            break

    if run is not None:
        wandb.finish()

    del model, train_loader, valid_loader
    flush()

# Model training

In [None]:
# Only the first fold is trained (the loop breaks after fold 0); remove the
# `break` to run full cross-validation over CONFIG['nfolds'] folds.
for fold in range(CONFIG['nfolds']):
    model = ISICModel(backbone=CONFIG["backbone"], pretrained=True)
#     model.freeze_encoder(False)
    train_ds = df[df['fold'] != fold].reset_index(drop=True)
    valid_ds = df[df['fold'] == fold].reset_index(drop=True)
#     valid_ds = pd.concat([valid_ds, extern_data]).sample(frac=1.0, random_state=CONFIG['seed']).reset_index(drop=True)
    train_and_validate(model, train_ds, valid_ds, fold=fold)
    # Drop the notebook-level reference so flush() can actually release GPU memory.
    del model

    break

# flush() already runs gc.collect().
flush()