**1. Import Libraries and Define Functions**

In [None]:
# basic
import numpy as np
import pandas as pd
import math

# PyTorch
import torch
import torchmetrics
import timm
import timm.optim
import timm.scheduler
from torch import nn
from torch.utils.data import Dataset, DataLoader

# others
import os, sys, datetime
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
from collections import Counter
import pathlib
from PIL import Image
from typing import Tuple, Dict, List
from tqdm import tqdm 
from copy import deepcopy
from accelerate import Accelerator
from accelerate.utils import set_seed

# data augmentation
def get_transforms(image_size):
    """Build the albumentations pipelines used for training and validation.

    Args:
        image_size: side length handed to ``Resize`` (output is image_size x image_size);
            it also sets the maximum ``Cutout`` hole size (37.5% of image_size).

    Returns:
        (transforms_train, transforms_val): the training pipeline applies the
        augmentations listed below before resizing/normalizing; the validation
        pipeline only resizes, normalizes and converts to a tensor.
    """
    A = albumentations
    max_hole = int(image_size*0.375)

    def model_input_tail():
        # fresh instances for every pipeline, exactly one Normalize + ToTensorV2 each
        return [A.Normalize(), ToTensorV2()]

    # exactly one of these blur/noise transforms is picked when the group fires
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ],
        p=0.7,
    )
    # exactly one of these geometric distortions is picked when the group fires
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.),
            A.ElasticTransform(alpha=3),
        ],
        p=0.7,
    )

    train_steps = [
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightness(limit=0.2, p=0.75),
        A.RandomContrast(limit=0.2, p=0.75),
        blur_or_noise,
        distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        A.Resize(image_size, image_size),
        A.Cutout(max_h_size=max_hole, max_w_size=max_hole, num_holes=1, p=0.7),
    ]
    transforms_train = A.Compose(train_steps + model_input_tail())

    val_steps = [A.Resize(image_size, image_size)]
    transforms_val = A.Compose(val_steps + model_input_tail())

    return transforms_train, transforms_val

# 'torchvision.datasets.ImageFolder()' customized for applying data augmentation with albumentations library

# make function to find classes in target directory
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Collect the class names (sub-directory names) of a dataset directory.

    Every immediate sub-directory of ``directory`` is treated as one class;
    plain files are ignored.

    Args:
        directory (str): directory whose sub-directories name the classes.

    Returns:
        Tuple[List[str], Dict[str, int]]: the sorted class names and a mapping
        from each class name to its position in that sorted list.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directory.

    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, ...})
    """
    # gather the names of all sub-directories, then order them alphabetically
    class_names: List[str] = []
    for entry in os.scandir(directory):
        if entry.is_dir():
            class_names.append(entry.name)
    class_names.sort()

    # no sub-directory means there is nothing to classify
    if len(class_names) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # numeric label = index of the class name in the sorted list
    index_by_name = dict(zip(class_names, range(len(class_names))))
    return class_names, index_by_name

# write a customized dataset class (inherits from torch.utils.data.Dataset)
# 1. subclass torch.utils.data.Dataset
class CustomizedImageFolder(Dataset):
    """Image-folder dataset whose transform is an albumentations pipeline.

    Expects the layout ``img_dir/class_name/image.jpg``. Only ``.jpg`` files
    are collected (``.png`` / ``.jpeg`` files are ignored by the glob pattern).

    Attributes are set in ``__init__``:
        paths: sorted list of image paths found under ``img_dir``.
        transform: callable invoked as ``transform(image=ndarray)['image']``.
        classes / class_to_idx: result of ``find_classes(img_dir)``.
    """

    def __init__(self, img_dir: str, transform) -> None:
        # Path.glob yields entries in arbitrary (filesystem-dependent) order;
        # sorting keeps the index -> file mapping reproducible across runs/machines.
        self.paths = sorted(pathlib.Path(img_dir).glob("*/*.jpg")) # .png, .jpeg
        # albumentations-style transform, called with the keyword argument `image`
        self.transform = transform
        # class names and their numeric labels, derived from the sub-directories
        self.classes, self.class_to_idx = find_classes(img_dir)

    def load_image(self, index: int) -> Image.Image:
        "Opens the image at position `index` and returns it as an RGB image."
        image_path = self.paths[index]
        # the context manager closes the underlying file; convert() returns an
        # independent, fully loaded image, so it stays valid after the close.
        # Forcing RGB guarantees 3 channels for the colour transforms/Normalize.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, (img, label): (X, y)."
        img = self.load_image(index)
        # the label is taken from the parent folder name: img_dir/class_name/image.jpg
        class_name = self.paths[index].parent.name
        label = self.class_to_idx[class_name]
        # albumentations works on numpy arrays and returns a dict keyed by 'image'
        img = np.array(img)
        return self.transform(image=img)['image'], label

# train
def colorful(obj, color="cyan", display_type="shine"):
    """Wrap ``str(obj)`` in an ANSI escape sequence for coloured terminal output.

    Args:
        obj: any object; it is converted with ``str``.
        color: one of black/red/green/yellow/blue/purple/cyan/white.
        display_type: one of plain/highlight/underline/shine/inverse/invisible.

    Returns:
        The text prefixed with ``ESC[<display>;<color>m`` and followed by the
        reset sequence ``ESC[0m``. An unknown ``color`` or ``display_type``
        contributes an empty code instead of raising.
    """
    ansi_colors = {
        "black": "30", "red": "31", "green": "32", "yellow": "33",
        "blue": "34", "purple": "35", "cyan": "36", "white": "37",
    }
    ansi_styles = {
        "plain": "0", "highlight": "1", "underline": "4",
        "shine": "5", "inverse": "7", "invisible": "8",
    }
    text = str(obj)
    style_code = ansi_styles.get(display_type, "")
    color_code = ansi_colors.get(color, "")
    return f"\033[{style_code};{color_code}m{text}\033[0m"

class StepRunner:
    """Runs one batch: forward pass, optional optimizer step, loss/metric logging.

    Args:
        accelerator: object providing ``backward(loss)`` and ``gather(tensor)``
            (an ``accelerate.Accelerator`` in this file).
        net: the model, called as ``net(features)``.
        loss_fn: called as ``loss_fn(preds, labels)``.
        metrics_dict: mapping name -> callable ``metric(preds, labels)`` whose
            result supports ``.item()``; ``None`` means "no metrics".
        stage: prefix for the logged keys; the optimizer is only stepped when
            this is ``"train"``.
        optimizer: optimizer to step during training (``None`` disables updates).
        lr_scheduler: stored on the instance; it is not stepped by this class.
    """

    def __init__(self, accelerator, net, loss_fn, metrics_dict=None, stage='train', optimizer=None, lr_scheduler=None):
        self.net, self.loss_fn, self.stage = net, loss_fn, stage
        # the default None used to crash on `.items()`; fall back to an empty
        # mapping so a runner without metrics works (an explicitly passed empty
        # container is kept as-is)
        self.metrics_dict = metrics_dict if metrics_dict is not None else {}
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator

    def __call__(self, batch):
        """Process one ``(features, labels)`` batch.

        Returns:
            (step_losses, step_metrics): ``{"<stage>_loss": float}`` and
            ``{"<stage>_<name>": float, ...}``; during training ``step_metrics``
            additionally carries the current learning rate under ``"lr"``.
        """
        features, labels = batch
        is_training = self.optimizer is not None and self.stage == "train"

        # forward
        preds = self.net(features)
        loss = self.loss_fn(preds, labels.long().flatten())

        # backward (gradients are cleared after the step, ready for the next batch)
        if is_training:
            self.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

        # collect predictions/labels/loss through the accelerator; the gathered
        # loss values are summed into a single scalar
        all_preds = self.accelerator.gather(preds)
        all_labels = self.accelerator.gather(labels)
        all_loss = self.accelerator.gather(loss).sum()

        # losses
        step_losses = {self.stage+"_loss":all_loss.item()}

        # metrics
        step_metrics = {self.stage+"_"+name:metric_fn(all_preds, all_labels.long().flatten()).item()
                        for name, metric_fn in self.metrics_dict.items()}
        if is_training:
            # learning rate of the first parameter group only
            step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']

        return step_losses, step_metrics

class EpochRunner:
    """Drives a step runner over one full pass of a dataloader.

    Constructing the runner switches ``steprunner.net`` to train mode when the
    step runner's stage is ``"train"`` and to eval mode otherwise.
    """

    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        # explicit branch instead of a conditional expression used for side effects
        if self.stage == "train":
            self.steprunner.net.train()
        else:
            self.steprunner.net.eval()
        self.accelerator = self.steprunner.accelerator

    def __call__(self, dataloader):
        """Run every batch of ``dataloader`` and return the epoch log.

        Returns:
            dict with the mean of each step loss over the epoch plus, for every
            metric in ``steprunner.metrics_dict``, the value of
            ``metric.compute()`` after the last batch (metrics are ``reset()``
            afterwards). Any extra key the last step reported (e.g. ``lr``) is
            kept. An empty dataloader yields an empty dict.
        """
        num_steps = len(dataloader)
        # progress bar is only shown on the local main process
        loop = tqdm(enumerate(dataloader, start=1), total=num_steps, file=sys.stdout, disable=not self.accelerator.is_local_main_process, ncols=100)

        epoch_losses = {}
        # defined up front so that a dataloader without batches returns {}
        # instead of raising UnboundLocalError at the return statement
        epoch_log = {}
        for step, batch in loop:
            if self.stage == "train":
                step_losses, step_metrics = self.steprunner(batch)
            else:
                with torch.no_grad():
                    step_losses, step_metrics = self.steprunner(batch)

            step_log = dict(step_losses, **step_metrics)
            # running sum of the step losses, turned into a mean on the last step
            for k, v in step_losses.items():
                epoch_losses[k] = epoch_losses.get(k, 0.0) + v

            if step != num_steps:
                loop.set_postfix(**step_log)
            else:
                # last batch: replace the per-step metric values with the
                # values the metric objects report via compute()
                epoch_metrics = step_metrics
                epoch_metrics.update({self.stage+"_"+name:metric_fn.compute().item()
                                 for name, metric_fn in self.steprunner.metrics_dict.items()})
                epoch_losses = {k:v/step for k, v in epoch_losses.items()}
                epoch_log = dict(epoch_losses, **epoch_metrics)
                loop.set_postfix(**epoch_log)
                # clear metric state so the next epoch starts fresh
                for name, metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()

        return epoch_log

class Model(nn.Module):
    """Keras-style wrapper bundling a network, loss, metrics, optimizer and scheduler.

    ``train(...)`` runs the whole fitting loop (training, optional validation,
    checkpointing of the best epoch, early stopping); ``evaluate(...)`` runs a
    single validation pass.

    NOTE: ``train`` has the same name as ``nn.Module.train``. To keep
    ``model.eval()`` / ``model.train()`` / ``model.train(False)`` usable, a
    boolean first argument is forwarded to ``nn.Module.train`` (see ``train``).
    """
    # runner classes are class attributes so that subclasses can swap them
    StepRunner, EpochRunner = StepRunner, EpochRunner

    def __init__(self, net, loss_fn, metrics_dict=None, optimizer=None, lr_scheduler=None):
        """Store the components; Adam(lr=1e-3) is used when no optimizer is given."""
        super().__init__()
        self.net, self.loss_fn, self.metrics_dict = net, loss_fn, nn.ModuleDict(metrics_dict)
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.net.parameters(), lr=1e-3)
        self.lr_scheduler = lr_scheduler

    def forward(self, x):
        # call the module (not .forward) so that registered hooks are honoured
        return self.net(x)

    def train(self, train_data=True, val_data=None, epochs=30, patience=30, ckpt_path='checkpoint.pt', monitor="val_loss", mode="min", mixed_precision='no', callbacks=None):
        """Fit the network, or switch train/eval mode when given a boolean.

        Args:
            train_data: training dataloader. If a ``bool`` is passed (as
                ``nn.Module.eval()`` / ``nn.Module.train()`` do), the call is
                forwarded to ``nn.Module.train`` and no fitting happens.
            val_data: optional validation dataloader.
            epochs: maximum number of epochs.
            patience: stop once the monitored value has not improved for this
                many epochs.
            ckpt_path: file the best network state_dict is written to.
            monitor: key of ``self.history`` used to pick the best epoch.
            mode: ``"max"`` if larger monitor values are better, anything else
                is treated as "smaller is better".
            mixed_precision: forwarded to ``Accelerator``.
            callbacks: optional list of objects with ``on_fit_start``,
                ``on_train_epoch_end``, ``on_validation_epoch_end`` and
                ``on_fit_end`` methods, each called with ``model=self``.

        Returns:
            On the local main process: a DataFrame of ``self.history`` (the
            best checkpoint is loaded back into ``self.net`` first). On other
            processes: ``None``. In nn.Module mode: ``self``.

        Raises:
            KeyError: if ``monitor`` is not a key of the recorded history.
        """
        # nn.Module compatibility: model.eval() calls self.train(False)
        if isinstance(train_data, bool):
            return super().train(mode if isinstance(mode, bool) else train_data)

        # expose the call arguments as attributes (e.g. for callbacks) without
        # storing a reference to `self` inside the instance dict
        self.__dict__.update(dict(
            train_data=train_data, val_data=val_data, epochs=epochs, patience=patience,
            ckpt_path=ckpt_path, monitor=monitor, mode=mode,
            mixed_precision=mixed_precision, callbacks=callbacks))
        self.accelerator = Accelerator(mixed_precision=mixed_precision)
        device = str(self.accelerator.device)
        device_type = '🐌' if 'cpu' in device else '⚡️'
        self.accelerator.print(colorful("<<<<<< "+device_type+" "+device+" is used >>>>>>"))

        self.net, self.loss_fn, self.metrics_dict, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.net, self.loss_fn, self.metrics_dict, self.optimizer, self.lr_scheduler)

        train_dataloader, val_dataloader = self.accelerator.prepare(train_data, val_data)

        self.history = {}
        self.callbacks = self.accelerator.prepare(callbacks) if callbacks is not None else []

        if self.accelerator.is_local_main_process:
            for callback_obj in self.callbacks:
                callback_obj.on_fit_start(model=self)

        for epoch in range(1, epochs+1):
            nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            self.accelerator.print('\n'+'=========='*4+'%s'%nowtime+'=========='*4)
            self.accelerator.print("Epoch {0} / {1}".format(epoch, epochs)+"\n")

            # 1. train the model (metrics are deep-copied so each runner starts
            #    from the untouched metric objects)
            train_step_runner = self.StepRunner(
                net=self.net,
                loss_fn=self.loss_fn,
                metrics_dict=deepcopy(self.metrics_dict),
                stage="train",
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                accelerator=self.accelerator)
            train_epoch_runner = self.EpochRunner(train_step_runner)
            train_metrics = {'epoch':epoch}
            train_metrics.update(train_epoch_runner(train_dataloader))
            if self.lr_scheduler is not None:
                # the scheduler is stepped once per epoch with the epoch number
                self.lr_scheduler.step(epoch)

            for name, metric in train_metrics.items():
                self.history[name] = self.history.get(name, []) + [metric]

            if self.accelerator.is_local_main_process:
                for callback_obj in self.callbacks:
                    callback_obj.on_train_epoch_end(model=self)

            # 2. validate the model
            if val_dataloader:
                val_step_runner = self.StepRunner(
                    net=self.net,
                    loss_fn=self.loss_fn,
                    metrics_dict=deepcopy(self.metrics_dict),
                    stage="val",
                    accelerator=self.accelerator)
                val_epoch_runner = self.EpochRunner(val_step_runner)
                with torch.no_grad():
                    val_metrics = val_epoch_runner(val_dataloader)

                for name, metric in val_metrics.items():
                    self.history[name] = self.history.get(name, []) + [metric]

                if self.accelerator.is_local_main_process:
                    for callback_obj in self.callbacks:
                        callback_obj.on_validation_epoch_end(model=self)

            # 3. save the best model
            self.accelerator.wait_for_everyone()
            if monitor not in self.history:
                # same exception type as the bare lookup, but with an actionable message
                raise KeyError("monitor '{}' was not recorded; available keys: {}".format(
                    monitor, sorted(self.history)))
            arr_scores = self.history[monitor]
            best_score_idx = np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)

            # the newest epoch is the best one so far -> checkpoint it
            if best_score_idx == len(arr_scores) - 1:
                unwrapped_net = self.accelerator.unwrap_model(self.net)
                self.accelerator.save(unwrapped_net.state_dict(), ckpt_path)
                self.accelerator.print(colorful("<<<<<< reach best {0} : {1} >>>>>>".format(monitor, arr_scores[best_score_idx])))

            # no improvement for `patience` epochs -> stop early
            if len(arr_scores) - best_score_idx > patience:
                self.accelerator.print(colorful("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(monitor, patience)))
                break

        # only the local main process reloads the best checkpoint and returns
        # the history; every other process falls through and returns None
        if self.accelerator.is_local_main_process:
            for callback_obj in self.callbacks:
                callback_obj.on_fit_end(model=self)

            self.net = self.accelerator.unwrap_model(self.net)
            self.net.load_state_dict(torch.load(ckpt_path))
            dfhistory = pd.DataFrame(self.history)
            self.accelerator.print(dfhistory)
            return dfhistory

    @torch.no_grad()
    def evaluate(self, val_data):
        """Run one validation pass over ``val_data`` and return its epoch log.

        A fresh ``Accelerator`` is created for the call; the network, loss and
        metrics stored on the model are replaced by their prepared versions.
        """
        accelerator = Accelerator()
        self.net, self.loss_fn, self.metrics_dict = accelerator.prepare(self.net, self.loss_fn, self.metrics_dict)
        val_data = accelerator.prepare(val_data)
        val_step_runner = self.StepRunner(
            net=self.net,
            loss_fn=self.loss_fn,
            metrics_dict=deepcopy(self.metrics_dict),
            stage="val",
            accelerator=accelerator)
        val_epoch_runner = self.EpochRunner(val_step_runner)
        val_metrics = val_epoch_runner(val_data)
        return val_metrics

**2. Load and Pre-Process the Data**

In [None]:
# image side length used by both transform pipelines, and the number of target classes
img_size, num_classes = 300, 6
transform_augmentation, transform_normal = get_transforms(img_size)

# Stratified 5-Fold Cross Validation: suffix selecting the fold directory
fold = '_Fold1'
img_train_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/train/'
img_val_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/val/'
img_test_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/test/'

# only the training split is augmented; val/test are just resized and normalized
ds_train = CustomizedImageFolder(img_dir=img_train_dir, transform=transform_augmentation)
ds_val = CustomizedImageFolder(img_dir=img_val_dir, transform=transform_normal)
ds_test = CustomizedImageFolder(img_dir=img_test_dir, transform=transform_normal)

# training batches are shuffled and the last incomplete batch is dropped
dl_train = DataLoader(dataset=ds_train, batch_size=64, shuffle=True, drop_last=True)
dl_val = DataLoader(dataset=ds_val, batch_size=64, shuffle=False)
dl_test = DataLoader(dataset=ds_test, batch_size=64, shuffle=False)

# print('Examine Numerical Labels: ', ds_train.class_to_idx)
# for features, labels in dl_train:
#     # shape of features: [batch_size; channels, height, width]
#     print('Examine Batched Data Shapes: ', features.shape, labels.shape)
#     break

In [None]:
# calculate normalized inverse class frequencies

# Read the labels straight from the dataset's file paths (folder name -> index),
# the same mapping CustomizedImageFolder.__getitem__ uses. Iterating dl_train
# instead would decode and augment every image just to obtain its label, and
# because that loader shuffles with drop_last=True it would silently leave a
# random partial batch out of the counts.
all_labels = [ds_train.class_to_idx[path.parent.name] for path in ds_train.paths]

# count the instances of each class
class_counts = Counter(all_labels)

# calculate the total number of instances
total_instances = len(all_labels)

# calculate inverse class frequencies
inverse_class_frequencies = {class_label: total_instances/count for class_label, count in class_counts.items()}

# convert to tensor or list; the tensor is sized by the number of class folders
# so every label index is valid (a class without any image keeps weight 0)
num_classes = len(ds_train.classes)
inverse_freq_tensor = torch.zeros(num_classes)
for class_label, freq in inverse_class_frequencies.items():
    inverse_freq_tensor[class_label] = freq
# scale the weights so that they sum to 1
normalized_inverse_freq_tensor = inverse_freq_tensor / inverse_freq_tensor.sum()
normalized_inverse_freq = normalized_inverse_freq_tensor.tolist()
print('Normalized Inverse Class Frequencies: ', normalized_inverse_freq)

**3. Define the Model and Metrics**

In [None]:
# EfficientNetB3 from timm with pretrained weights, built as a full classifier
# (features_only=False) with `num_classes` outputs
net = timm.create_model(
    'efficientnet_b3',
    pretrained=True,
    num_classes=num_classes,
    features_only=False,
)
# torchkeras.summary(net, input_shape=(3, img_size, img_size))

In [None]:
# multiclass accuracy
class MulticlassAccuracy(torchmetrics.Accuracy):
    """``torchmetrics.Accuracy`` (default ``average='micro'``) fed with raw model outputs.

    ``update`` takes per-class scores, reduces them to class indices with
    ``argmax`` over the last dimension and flattens the targets before
    delegating to the parent metric.
    """

    def __init__(self, multiclass=True, num_classes=6, average='micro', dist_sync_on_step=False):
        # NOTE(review): the `multiclass` keyword is accepted by the torchmetrics
        # version this notebook was written for — confirm against the installed version
        parent_kwargs = dict(
            multiclass=multiclass,
            num_classes=num_classes,
            average=average,
            dist_sync_on_step=dist_sync_on_step,
        )
        super().__init__(**parent_kwargs)

    def update(self, preds:torch.Tensor, targets:torch.Tensor):
        predicted_classes = preds.argmax(dim=-1)
        flat_targets = targets.long().flatten()
        super().update(predicted_classes, flat_targets)

    def compute(self):
        result = super().compute()
        return result

# balanced multiclass accuracy
class BalancedMulticlassAccuracy(torchmetrics.Accuracy):
    """``torchmetrics.Accuracy`` with ``average='macro'`` by default, fed with raw outputs.

    ``update`` takes per-class scores, reduces them to class indices with
    ``argmax`` over the last dimension and flattens the targets before
    delegating to the parent metric.
    """

    def __init__(self, multiclass=True, num_classes=6, average='macro', dist_sync_on_step=False):
        # NOTE(review): the `multiclass` keyword is accepted by the torchmetrics
        # version this notebook was written for — confirm against the installed version
        parent_kwargs = dict(
            multiclass=multiclass,
            num_classes=num_classes,
            average=average,
            dist_sync_on_step=dist_sync_on_step,
        )
        super().__init__(**parent_kwargs)

    def update(self, preds:torch.Tensor, targets:torch.Tensor):
        predicted_classes = preds.argmax(dim=-1)
        flat_targets = targets.long().flatten()
        super().update(predicted_classes, flat_targets)

    def compute(self):
        result = super().compute()
        return result
    
# AUROC
class AUROC(torchmetrics.AUROC):
    """``torchmetrics.AUROC`` (default ``average='macro'``) fed with raw model outputs.

    ``update`` applies a softmax over dimension 1 to the predictions and
    flattens the targets before delegating to the parent metric.
    """

    def __init__(self, num_classes=6, average='macro', dist_sync_on_step=False):
        parent_kwargs = dict(
            num_classes=num_classes,
            average=average,
            dist_sync_on_step=dist_sync_on_step,
        )
        super().__init__(**parent_kwargs)

    def update(self, preds:torch.Tensor, targets:torch.Tensor):
        class_probabilities = torch.softmax(preds, dim=1)
        flat_targets = targets.long().flatten()
        super().update(class_probabilities, flat_targets)

    def compute(self):
        result = super().compute()
        return result

# focal loss
class FocalLoss(nn.Module):
    """Focal loss: ``-(1 - p_t) ** gamma * alpha_t * log(p_t)``.

    Args:
        gamma: exponent of the ``(1 - p_t)`` modulating factor.
        alpha: optional per-class weights. A number ``a`` becomes the
            two-class weight tensor ``[a, 1 - a]``; a list is converted to a
            tensor; any other value (e.g. a tensor or ``None``) is stored as given.
        size_average: return the mean of the per-element losses if True,
            their sum otherwise.
    """

    def __init__(self, gamma=2.0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average
        if isinstance(alpha, (float, int)):
            alpha = torch.Tensor([alpha, 1-alpha])
        elif isinstance(alpha, list):
            alpha = torch.Tensor(alpha)
        self.alpha = alpha

    def forward(self, input, target):
        """Compute the loss for scores ``input`` and integer class ``target``.

        Inputs with more than two dimensions are rearranged so that dimension 1
        (the classes) becomes the last one and every remaining position is
        treated as its own sample.
        """
        if input.dim() > 2:
            n_batch, n_classes = input.size(0), input.size(1)
            input = input.view(n_batch, n_classes, -1).transpose(1, 2)
            input = input.contiguous().view(-1, n_classes)
        target = target.view(-1, 1)

        # log-probability (and probability) of the target class of every sample
        log_probs = torch.log_softmax(input, dim=1)
        logpt = log_probs.gather(1, target).view(-1)
        pt = logpt.exp()

        if self.alpha is not None:
            # keep the stored weights in the same tensor type as the input
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            class_weights = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt*class_weights

        loss = -1*(1-pt)**self.gamma*logpt
        return loss.mean() if self.size_average else loss.sum()

**4. Train and Save the Model**

In [None]:
# fix the random seeds before the loss, optimizer and scheduler are built
set_seed(4317)

# patience equals the epoch budget, so early stopping cannot trigger before the last epoch
epochs = patience = 100

# focal loss weighted with the normalized inverse class frequencies
loss_fn = FocalLoss(alpha=normalized_inverse_freq)
metrics_dict = dict(BACC=BalancedMulticlassAccuracy())
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# cosine schedule over all epochs with a warm-up covering the first tenth of them
lr_scheduler = timm.scheduler.CosineLRScheduler(
    optimizer=optimizer,
    t_initial=epochs,
    lr_min=1e-5,
    warmup_t=math.ceil(epochs/10),
    warmup_lr_init=1e-5,
)

model = Model(net, loss_fn, metrics_dict, optimizer, lr_scheduler=lr_scheduler)

In [None]:
# fit the model; the checkpoint with the highest validation balanced accuracy is kept
dfhistory = model.train(
    train_data=dl_train,
    val_data=dl_val,
    epochs=epochs,
    patience=patience,
    monitor='val_BACC',
    mode='max',
    ckpt_path=f'ImageOnlyDNN{fold}.pt',
    mixed_precision='no',
)

**5. Load and Evaluate the Model**

In [None]:
# evaluate with accuracy, balanced accuracy and AUROC
metrics_dict = {'ACC':MulticlassAccuracy(), 'BACC':BalancedMulticlassAccuracy(), 'AUROC':AUROC()}
model = Model(net, loss_fn, metrics_dict, optimizer, lr_scheduler=lr_scheduler)

# pick the device the same way the export cells do, so the cell also runs on
# CPU-only machines; map_location lets a checkpoint saved on GPU load on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.net.load_state_dict(torch.load('/.../checkpoints/ImageOnlyDNN'+str(fold)+'.pt', map_location=device))
model.net = model.net.to(device)
model.net.eval()

# print(model.evaluate(dl_train))
print(model.evaluate(dl_val))
print(model.evaluate(dl_test))

**6. Save Classification Results**

In [None]:
# re-create the training set WITHOUT augmentation for exporting network outputs
ds_train = CustomizedImageFolder(img_train_dir, transform=transform_normal)

# choose the device once and make sure the network lives there in eval mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.net.to(device)
model.net.eval()

cnn_train = []
id_train = []
# inference only: no_grad avoids building an autograd graph for every sample
with torch.no_grad():
    for i, path in enumerate(ds_train.paths):
        img, label = ds_train[i]
        tensor = img.to(device)
        # add a batch dimension of 1 and flatten the output to one vector per image;
        # NOTE(review): these are the raw network outputs — no softmax is applied
        # in this cell despite the name `y_prob`; confirm what downstream code expects
        y_prob = model.net(tensor[None, ...]).reshape(-1).cpu().numpy()
        cnn_train.append(y_prob)
        id_train.append(os.path.basename(path))

# keep the DataFrame / id list available instead of overwriting them with the
# None that `to_csv` returns
dnn_train = pd.DataFrame(data=cnn_train)
dnn_train.to_csv('cnn_train.csv')
pd.DataFrame(data=id_train).to_csv('id_train.csv')

In [None]:
# validation set (no augmentation) for exporting network outputs
ds_val = CustomizedImageFolder(img_val_dir, transform=transform_normal)

# choose the device once and make sure the network lives there in eval mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.net.to(device)
model.net.eval()

cnn_val = []
id_val = []
# inference only: no_grad avoids building an autograd graph for every sample
with torch.no_grad():
    for i, path in enumerate(ds_val.paths):
        img, label = ds_val[i]
        tensor = img.to(device)
        # add a batch dimension of 1 and flatten the output to one vector per image;
        # NOTE(review): these are the raw network outputs — no softmax is applied
        # in this cell despite the name `y_prob`; confirm what downstream code expects
        y_prob = model.net(tensor[None, ...]).reshape(-1).cpu().numpy()
        cnn_val.append(y_prob)
        id_val.append(os.path.basename(path))

# keep the DataFrame / id list available instead of overwriting them with the
# None that `to_csv` returns
dnn_val = pd.DataFrame(data=cnn_val)
dnn_val.to_csv('cnn_val.csv')
pd.DataFrame(data=id_val).to_csv('id_val.csv')

In [None]:
# test set (no augmentation) for exporting network outputs
ds_test = CustomizedImageFolder(img_test_dir, transform=transform_normal)

# choose the device once and make sure the network lives there in eval mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.net.to(device)
model.net.eval()

cnn_test = []
id_test = []
# inference only: no_grad avoids building an autograd graph for every sample
with torch.no_grad():
    for i, path in enumerate(ds_test.paths):
        img, label = ds_test[i]
        tensor = img.to(device)
        # add a batch dimension of 1 and flatten the output to one vector per image;
        # NOTE(review): these are the raw network outputs — no softmax is applied
        # in this cell despite the name `y_prob`; confirm what downstream code expects
        y_prob = model.net(tensor[None, ...]).reshape(-1).cpu().numpy()
        cnn_test.append(y_prob)
        id_test.append(os.path.basename(path))

# keep the DataFrame / id list available instead of overwriting them with the
# None that `to_csv` returns
dnn_test = pd.DataFrame(data=cnn_test)
dnn_test.to_csv('cnn_test.csv')
pd.DataFrame(data=id_test).to_csv('id_test.csv')