# Fine-tuning of a model for segmentation of retinal optical coherence tomography images (AROI)

For more info, check the README.md file.

## Fine-tune the model on the dataset

This notebook fine-tunes a SegFormer model (nvidia/mit-b0) on the dataset that was created in 02_create_huggingface_dataset.ipynb.

## Citations

Information about the dataset can be found in the following publications:

M. Melinščak, M. Radmilović, Z. Vatavuk, and S. Lončarić, "Annotated retinal optical coherence tomography images (AROI) database for joint retinal layer and fluid segmentation," Automatika, vol. 62, no. 3, pp. 375-385, Jul. 2021. doi: 10.1080/00051144.2021.1973298

M. Melinščak, M. Radmilović, Z. Vatavuk, and S. Lončarić, "AROI: Annotated Retinal OCT Images database," in 2021 44th International Convention on Information, Communication and Electronic Technology (MIPRO), Sep. 2021, pp. 400-405.

M. Melinščak, "Attention-based U-net: Joint segmentation of layers and fluids from retinal OCT images," in 2023 46th International Convention on Information, Communication and Electronic Technology (MIPRO), 2023, pp. 391-396.

In [1]:
from pathlib import Path
import os
import random
import torch
import numpy as np
import io
import requests
from collections import defaultdict
from PIL import Image
from typing import List, Dict, Tuple, cast
import datasets
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.figure import Figure
from matplotlib.patches import Patch

from transformers import AutoImageProcessor, DetrForSegmentation, SegformerImageProcessor
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
from transformers.image_transforms import rgb_to_id
from transformers import EvalPrediction
import evaluate

First, load the dataset from disk:

In [2]:
split_dataset_path: str = "hf_aroi_dataset_split"
split_dataset: datasets.DatasetDict = datasets.load_from_disk(split_dataset_path)

test_dataset: datasets.Dataset = split_dataset['test']
train_dataset: datasets.Dataset = split_dataset['train']

Check if everything looks fine:

In [3]:
print(f"Test dataset length: {len(test_dataset)}")
print(f"Train dataset length: {len(train_dataset)}")

Test dataset length: 22
Train dataset length: 1115


Create an evaluation function. It is almost entirely based on an example that I found and that seems to be reused in many semantic segmentation tutorials.

In [5]:
metric: evaluate.EvaluationModule = evaluate.load("mean_iou")

def compute_metrics(eval_prediction: EvalPrediction) -> Dict:
    with torch.no_grad():
        logits: np.ndarray
        labels: np.ndarray
        logits, labels = cast(Tuple[np.ndarray, np.ndarray], eval_prediction)
        logits_tensor: torch.Tensor = torch.from_numpy(logits)
        interpolation_result: torch.Tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        logits_tensor = interpolation_result.argmax(dim=1)
        pred_labels: np.ndarray = logits_tensor.detach().cpu().numpy()
        metrics: Dict = cast(Dict, metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=255,
            reduce_labels=False,
        ))
        for key, value in metrics.items():
            if isinstance(value, np.ndarray):
                metrics[key] = value.tolist()
        return metrics
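To make the reported numbers easier to interpret, here is a minimal sketch of how per-category IoU and accuracy can be computed for a single pair of label maps. It is written in plain Python and is not the implementation behind the mean_iou metric; the function name per_category_scores and the tiny example maps are made up for illustration.

```python
from typing import Dict, List


def per_category_scores(pred: List[List[int]], ref: List[List[int]],
                        num_labels: int, ignore_index: int = 255) -> Dict[str, List[float]]:
    """Per-category IoU and accuracy for one prediction/reference pair (illustrative only)."""
    intersect = [0] * num_labels  # pixels where prediction and reference agree on class c
    pred_area = [0] * num_labels  # pixels predicted as class c
    ref_area = [0] * num_labels   # pixels labelled as class c in the reference
    for pred_row, ref_row in zip(pred, ref):
        for p, r in zip(pred_row, ref_row):
            if r == ignore_index:
                continue  # ignored pixels do not count for any class
            pred_area[p] += 1
            ref_area[r] += 1
            if p == r:
                intersect[p] += 1

    iou: List[float] = []
    accuracy: List[float] = []
    for c in range(num_labels):
        union = pred_area[c] + ref_area[c] - intersect[c]
        # A class that never occurs has no defined score, hence NaN.
        iou.append(intersect[c] / union if union else float("nan"))
        accuracy.append(intersect[c] / ref_area[c] if ref_area[c] else float("nan"))
    return {"per_category_iou": iou, "per_category_accuracy": accuracy}


# Hypothetical 2x3 label maps with three classes.
reference = [[0, 0, 1],
             [1, 2, 2]]
prediction = [[0, 1, 1],
              [1, 2, 0]]
scores = per_category_scores(prediction, reference, num_labels=3)
print(scores["per_category_accuracy"])  # [0.5, 1.0, 0.5]
```

In this sketch, accuracy only asks how many reference pixels of a class were found, while IoU also penalizes pixels that were wrongly assigned to that class. That is why the IoU of a class can never be higher than its accuracy.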

The trainer requires the mappings between IDs and their labels:

In [6]:
annotations_short: List[str] = [
    'above ILM',
    'ILM-IPL/INL',
    'IPL/INL-RPE',
    'RPE-BM',
    'under BM',
    'PED',
    'SRF',
    'IRF',
]
id2label: Dict[int, str] = {idx: name for idx, name in enumerate(annotations_short)}
label2id: Dict[str, int] = {name: idx for idx, name in id2label.items()}

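reduce_labels has to be False because the images contain no 'background' or unlabeled part: every pixel belongs to one of the eight classes. With reduce_labels set to True, label 0 ('above ILM') would be treated as background and ignored. Most examples set it to True, but that doesn't work for this dataset.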
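As an illustration, the sketch below mimics in plain Python the remapping that reduce_labels=True is documented to apply to a label map: class 0 becomes 255 (the ignore index) and every other ID shifts down by one. The helper name reduce_label_map and the tiny label map are hypothetical; the real remapping happens inside the image processor.

```python
from typing import List


def reduce_label_map(label_map: List[List[int]]) -> List[List[int]]:
    """Mimic reduce_labels=True: 0 -> 255 (ignored), every other ID shifts down by one."""
    return [[255 if value in (0, 255) else value - 1 for value in row]
            for row in label_map]


# Hypothetical 2x4 label map using AROI IDs: 0 = 'above ILM', 1 = 'ILM-IPL/INL', 7 = 'IRF'.
label_map = [[0, 0, 1, 1],
             [1, 7, 7, 1]]
reduced = reduce_label_map(label_map)
print(reduced)  # [[255, 255, 0, 0], [0, 6, 6, 0]]
```

With this remapping the 'above ILM' pixels would disappear from training and evaluation, and the remaining IDs would no longer match id2label.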

In [7]:
def train_single_model(model_name: str, output_dir: str, save_dir: str, num_epochs: int,
                       batch_size: int, learning_rate: float, save_steps: int, eval_steps: int,
                       logging_steps: int = 10, warmup_steps: int = 0):
    """Fine-tune a model on the dataset.

    Note that this function has a side effect: it sets transforms on the
    module-level test and train datasets.
    """
    print(f"Model name: {model_name}")
    feature_extractor: SegformerImageProcessor = AutoImageProcessor.from_pretrained(model_name, reduce_labels=False)

    model = AutoModelForSemanticSegmentation.from_pretrained(
        model_name, id2label=id2label, label2id=label2id
    )
    def transform_apply_feature_extractor(some_batch: Dict):
        images: List = some_batch["image"]
        labels: List = some_batch["label"]
        inputs = feature_extractor(images, labels)
        return inputs

    test_dataset.set_transform(transform_apply_feature_extractor)
    train_dataset.set_transform(transform_apply_feature_extractor)

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_total_limit=3,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_steps=save_steps,
        eval_steps=eval_steps,
        logging_steps=logging_steps,
        eval_accumulation_steps=5,
        remove_unused_columns=False,
        push_to_hub=False,
        warmup_steps=warmup_steps,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics,
    )
    
    trainer.train()
    model.save_pretrained(save_dir)

For information about the model, see https://huggingface.co/nvidia/mit-b0

Note that this is the most basic model out of six: b0 up to b5. I prefer to test with the simplest one first instead of starting with the most complicated one.

The model is stored locally in the directory 'nvidia-mit-b0-finetuned'.
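The call below trains for 150 epochs with a batch size of 60 and 50 warmup steps. The following sketch works out what that means in optimizer steps and learning rate. It assumes a single device, no gradient accumulation, and the Trainer's default linear schedule with warmup; the helper name training_schedule is made up, and the learning rates are approximations of what the scheduler produces.

```python
import math
from typing import Callable, Tuple


def training_schedule(num_samples: int, batch_size: int, num_epochs: int,
                      warmup_steps: int, peak_lr: float) -> Tuple[int, int, Callable[[int], float]]:
    """Return steps per epoch, total steps, and a linear warmup/decay learning-rate function."""
    steps_per_epoch = math.ceil(num_samples / batch_size)  # the last, smaller batch is kept
    total_steps = steps_per_epoch * num_epochs

    def lr_at(step: int) -> float:
        if step < warmup_steps:
            # Linear ramp from 0 up to the peak learning rate.
            return peak_lr * step / max(1, warmup_steps)
        # Linear decay from the peak down to 0 at the last step.
        return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    return steps_per_epoch, total_steps, lr_at


# 1115 training samples, as printed by the dataset check earlier in this notebook.
steps_per_epoch, total_steps, lr_at = training_schedule(
    num_samples=1115, batch_size=60, num_epochs=150, warmup_steps=50, peak_lr=1e-3)
print(steps_per_epoch)    # 19
print(total_steps)        # 2850
print(total_steps // 40)  # 71 evaluations with eval_steps=40
for step in (0, 25, 50, 1450, 2850):
    print(step, lr_at(step))
```

Under these assumptions an evaluation every 40 steps corresponds to roughly every two epochs, and a checkpoint is written every 20 steps, of which only the last three are kept because of save_total_limit=3.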

In [10]:
train_single_model("nvidia/mit-b0", output_dir="nvidia-mit-b0-training",
                   save_dir="nvidia-mit-b0-finetuned",
                   num_epochs=150, batch_size=60, learning_rate=1e-3,
                   save_steps=20, eval_steps=40, logging_steps=40,
                   warmup_steps=50)



Model name: nvidia/mit-b0


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, w

Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Per Category Iou,Per Category Accuracy
40,0.9838,0.187288,0.471394,0.519274,0.948794,"[0.9932763332658823, 0.8266048427335162, 0.7161208048998536, 0.12662054048058125, 0.9641020128888943, 0.1440327122183337, 0.00039413256563164027, 0.0]","[0.998351667279345, 0.9016869123887179, 0.9509127726269901, 0.1492644655116051, 0.99579295174979, 0.15779111568107002, 0.00039413256563164027, 0.0]"
80,0.1069,0.072426,0.693461,0.74905,0.977171,"[0.9966487228700531, 0.9089897406856224, 0.8833012834846384, 0.3821064971925711, 0.9861019756830381, 0.6231685580383871, 0.7673708087089358, 0.0]","[0.998791946245383, 0.957102643037413, 0.9605953691545193, 0.5061131088591043, 0.9958171193369019, 0.7217771826203401, 0.8522002878881348, 0.0]"
120,0.0507,0.048844,0.756684,0.814381,0.98318,"[0.9971643168761568, 0.9288407474780883, 0.9083248041074514, 0.47318955672525365, 0.9916985679879228, 0.7430508118333246, 0.8502893671161037, 0.16091219096334186]","[0.9981388421842933, 0.9735800260698338, 0.9573877327249656, 0.589228506047728, 0.9949625685613678, 0.9017272850458723, 0.9390979505106587, 0.16092362344582592]"
160,0.0367,0.039208,0.818605,0.871808,0.985694,"[0.9973876493605115, 0.9362427820606157, 0.9255235051362367, 0.4940775968685851, 0.9928121165889119, 0.7795170139538903, 0.8539395382051892, 0.5693430656934306]","[0.998493707974712, 0.9763499459189616, 0.9628015863541708, 0.6333932657731285, 0.9976746249775789, 0.8629570291016045, 0.9221159777914867, 0.6206749555950266]"
200,0.031,0.035962,0.835415,0.894026,0.986498,"[0.9970542956814854, 0.9358839182607822, 0.9308055899386235, 0.5038223938223938, 0.9932834059112945, 0.797467434815143, 0.8748691521056445, 0.6501307034220533]","[0.9994063548346452, 0.9692986105333222, 0.9625164139906888, 0.6398659692710036, 0.997766763903443, 0.8749531712060706, 0.9309239838234286, 0.7774777975133215]"
240,0.0282,0.033754,0.833021,0.883552,0.986933,"[0.9967991970163573, 0.9367868613622066, 0.9316969164048627, 0.5133386176696078, 0.9940634706563954, 0.8155475863654984, 0.8723217336131582, 0.6036100612537469]","[0.9972611156282045, 0.9808046981168705, 0.9650343312155417, 0.6457175547564563, 0.9981122848754331, 0.9045409981109065, 0.9188258276783878, 0.6581172291296625]"
280,0.0266,0.032336,0.835293,0.884029,0.987362,"[0.9974259531258473, 0.9426469156688787, 0.9331645323692391, 0.5168734779079207, 0.9944649495111094, 0.8158982396034552, 0.8652208730327346, 0.6166503428011754]","[0.998213873647959, 0.9735384252710986, 0.976852404511471, 0.6556554429552142, 0.9976006117420487, 0.8921702256550053, 0.907224621290013, 0.6709769094138543]"
320,0.0248,0.032135,0.842372,0.894151,0.987522,"[0.9976147422130289, 0.9419181596727184, 0.9322581607813546, 0.5118614429172078, 0.9944356090926052, 0.8229847861749702, 0.8728251930922166, 0.665076575445644]","[0.9986753879339024, 0.9825623318634384, 0.9668315802970214, 0.6506864988558353, 0.9983173317473354, 0.8757263444845644, 0.9275824251148125, 0.7528241563055063]"
360,0.0233,0.029808,0.848709,0.898904,0.988234,"[0.997558186957352, 0.941927664343051, 0.9372699608136481, 0.5299519712883306, 0.9951814833980749, 0.835041557211145, 0.8808006261880801, 0.6719437692964593]","[0.9984309458069917, 0.9645145186787586, 0.9750087320219826, 0.6564890487087284, 0.9978117005107291, 0.927353595255745, 0.9448728494070875, 0.7267495559502665]"
400,0.0227,0.029385,0.851711,0.906861,0.988424,"[0.9976898007129046, 0.9459028251948701, 0.9387889448632428, 0.5354398288747458, 0.9949302294577695, 0.8358938396026576, 0.8799413193677635, 0.6851039839639188]","[0.998796665205362, 0.9701098261086613, 0.9709875806330385, 0.6628146453089245, 0.9976247793291606, 0.9215906645304766, 0.9559085612447734, 0.7770515097690941]"




After roughly two and a half hours of training, the per-category accuracy reached the following values:

* Label 0 'above ILM': 0.9987891148693957
* Label 1 'ILM-IPL/INL': 0.9735800260698338
* Label 2 'IPL/INL-RPE': 0.9741664787623961
* Label 3 'RPE-BM': 0.6814972213141549
* Label 4 'under BM': 0.9982323675738952
* Label 5 'PED': 0.9112923152952804
* Label 6 'SRF': 0.943416272534101
* Label 7 'IRF': 0.7868561278863233

Labels 3 ('RPE-BM') and 7 ('IRF') have a much lower accuracy than the others. As I'm no ophthalmologist and can't even claim a basic understanding of the subject, I have no idea whether this is to be expected. Later I'll test more advanced segmentation models to check whether they fix this issue. Possibly the test set isn't the best selection, or maybe the dataset was labeled manually by multiple specialists who do not fully agree and applied the labels slightly differently.
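When comparing later models, it may help to flag the weak classes automatically instead of reading them off the list. The sketch below does that for the accuracy values listed above; the 0.9 threshold and the helper name weakest_labels are arbitrary choices for illustration.

```python
from typing import Dict, List

# Per-category accuracy values copied from the list above.
accuracy_by_label: Dict[str, float] = {
    'above ILM': 0.9987891148693957,
    'ILM-IPL/INL': 0.9735800260698338,
    'IPL/INL-RPE': 0.9741664787623961,
    'RPE-BM': 0.6814972213141549,
    'under BM': 0.9982323675738952,
    'PED': 0.9112923152952804,
    'SRF': 0.943416272534101,
    'IRF': 0.7868561278863233,
}


def weakest_labels(scores: Dict[str, float], threshold: float = 0.9) -> List[str]:
    """Return the labels scoring below the threshold, worst first."""
    below = [name for name, score in scores.items() if score < threshold]
    return sorted(below, key=lambda name: scores[name])


print(weakest_labels(accuracy_by_label))  # ['RPE-BM', 'IRF']
```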

The model has not yet been uploaded to HuggingFace.