# Fine-tuning of a model for segmentation of retinal optical coherence tomography images (AROI)

For more info, check the README.md file.

This notebook creates a Hugging Face dataset from the patient_images list, which is created in 01_load_patient_images.ipynb.

## Citations

Information about the dataset can be found in the following publications:

M. Melinščak, M. Radmilović, Z. Vatavuk, and S. Lončarić, "Annotated retinal optical coherence tomography images (AROI) database for joint retinal layer and fluid segmentation," Automatika, vol. 62, no. 3, pp. 375-385, Jul. 2021. doi: 10.1080/00051144.2021.1973298

M. Melinščak, M. Radmilović, Z. Vatavuk, and S. Lončarić, "AROI: Annotated Retinal OCT Images database," in 2021 44th International Convention on Information, Communication and Electronic Technology (MIPRO), Sep. 2021, pp. 400-405.

M. Melinščak, "Attention-based U-net: Joint segmentation of layers and fluids from retinal OCT images," in 2023 46th International Convention on Information, Communication and Electronic Technology (MIPRO), 2023, pp. 391-396.

In [None]:
%run 01_load_patient_images.ipynb

Now create a Hugging Face dataset. The colour_mask images are left out, because we don't need them to fine-tune a segmentation model.

The DatasetInfo below lacks some metadata (the citation, homepage and license are left empty), but that doesn't matter, as I'm not planning to upload the dataset to the Hugging Face Hub: that decision isn't up to me.

In [2]:
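def create_huggingface_dataset(patient_images: List[PatientImage]) -> datasets.Dataset:
    """Create a Hugging Face dataset with one row per PatientImage.
    The 'image' and 'label' columns are first filled with file paths, which are cast to
    datasets.Image() features at the end, so they are decoded into PIL images on access."""
    # Dataset.from_dict() expects one list per column, all of the same length
    images: List[str] = []
    segmentation_maps: List[str] = []
    widths: List[int] = []
    heights: List[int] = []
    patient_numbers: List[int] = []
    image_numbers: List[int] = []

    pi: PatientImage
    for pi in patient_images:
        images.append(pi.get_raw_image_as_rgb().as_posix())  # path of the raw OCT image in RGB
        segmentation_maps.append(pi.number_mask_path.as_posix())  # path of the mask with label IDs as pixel values
        widths.append(PatientImage.width)  # width and height are class attributes: all images have the same size
        heights.append(PatientImage.height)
        patient_numbers.append(pi.patient_number)
        image_numbers.append(pi.image_number)

    # 'image' and 'label' start as path strings; they're cast to images after the dataset is created
    features: datasets.Features = datasets.Features({
        'image': datasets.Value(dtype='string'),
        'label': datasets.Value(dtype='string'),
        'width': datasets.Value(dtype='int16'),
        'height': datasets.Value(dtype='int16'),
        'patient_number': datasets.Value(dtype='int16'),
        'image_number': datasets.Value(dtype='int16'),
    })
    # Minimal metadata: citation, homepage and license are left empty
    info: datasets.DatasetInfo = datasets.DatasetInfo(
        description="AROI", citation="", homepage="", license="",
        dataset_name="AROI", version="0.0.1", features=features)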

    print("Creating the dataset, with paths for the images...")
    ds = datasets.Dataset.from_dict(
        mapping={
            'image': images,
            'label': segmentation_maps,
            'width': widths,
            'height': heights,
            'patient_number': patient_numbers,
            'image_number': image_numbers},
        features=features,
        info=info
    )

    print("Converting the raw image paths into images...")
    ds = ds.cast_column("image", datasets.Image())
    print("Converting the labeled image paths into images...")
    ds = ds.cast_column("label", datasets.Image())
    return ds

complete_dataset = create_huggingface_dataset(patient_images)
complete_dataset_path: str = "hf_aroi_dataset_complete"
complete_dataset.save_to_disk(complete_dataset_path)

Creating the dataset, with paths for the images...
Converting the raw image paths into images...
Converting the labeled image paths into images...


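Dataset.from_dict() takes a column-oriented mapping: one list per feature, all of the same length. The loop in create_huggingface_dataset builds those lists field by field. The same row-to-column transposition can be sketched generically; ToyRecord, records_to_columns and the file paths below are hypothetical, for illustration only.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List


@dataclass
class ToyRecord:
    """Hypothetical stand-in for the fields read from one PatientImage."""
    image: str
    label: str
    width: int
    height: int
    patient_number: int
    image_number: int


def records_to_columns(records: List[ToyRecord]) -> Dict[str, List[Any]]:
    """Transpose row records into the {column_name: [values, ...]} mapping used by from_dict."""
    columns: Dict[str, List[Any]] = {f.name: [] for f in fields(ToyRecord)}
    for record in records:
        for name, values in columns.items():
            values.append(getattr(record, name))
    return columns


# Hypothetical paths; the notebook uses the paths of the AROI images
toy_rows = [
    ToyRecord("raw/p9_106.png", "mask/p9_106.png", 512, 1024, 9, 106),
    ToyRecord("raw/p9_107.png", "mask/p9_107.png", 512, 1024, 9, 107),
]
toy_columns = records_to_columns(toy_rows)
print(toy_columns["image_number"])
```

Keeping the transposition in one place guarantees that every column has exactly one value per row, which from_dict requires.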

Let's check the contents of the dataset:

In [8]:
print("Full dataset:")
print(complete_dataset)
print("\nFirst entry:")
print(complete_dataset[0])
print("\nFirst entry key and value pairs:")
for k,v in complete_dataset[0].items():
    print(f"Key: {k}, value: {v}")
print("\nHuggingFace features:")
print(complete_dataset.features)
print("\nHuggingFace features as key and value pairs:")
for k,v in complete_dataset.features.items():
    print(f"Key: {k}, value: {v}")

Full dataset:
Dataset({
    features: ['image', 'label', 'width', 'height', 'patient_number', 'image_number'],
    num_rows: 1137
})

First entry:
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x1024 at 0x7FFC51CE1FD0>, 'label': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x1024 at 0x7FFC51CE2E90>, 'width': 512, 'height': 1024, 'patient_number': 9, 'image_number': 106}

First entry key and value pairs:
Key: image, value: <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x1024 at 0x7FFC51CE3650>
Key: label, value: <PIL.PngImagePlugin.PngImageFile image mode=L size=512x1024 at 0x7FFC51CE1210>
Key: width, value: 512
Key: height, value: 1024
Key: patient_number, value: 9
Key: image_number, value: 106

HuggingFace features:
{'image': Image(decode=True, id=None), 'label': Image(decode=True, id=None), 'width': Value(dtype='int16', id=None), 'height': Value(dtype='int16', id=None), 'patient_number': Value(dtype='int16', id=None), 'image_number': Value(dtype='i

The following code is untested. In principle, it should allow anyone to upload the dataset to the Hugging Face Hub.

In [17]:
import typing

import datasets
import huggingface_hub

# Can be used on the 'complete_dataset' or on the 'split_dataset' (a DatasetDict), which is defined later in this notebook.
# Replace <username> with your Hugging Face user or organisation name.
def example_upload_to_huggingface_hub(ds: typing.Union[datasets.Dataset, datasets.DatasetDict],
                                      repo_id: str = '<username>/retinal_optical_coherence_tomography_images_complete'):
    hf_token: typing.Optional[str] = huggingface_hub.HfFolder.get_token()
    if hf_token is None:
        print("You first need to log in to use the Hugging Face Hub")
        return
    ds.push_to_hub(repo_id, private=False, token=hf_token)
    # You'll need to write a README.md with some metadata at the top. An example:
    # https://github.com/DriesVerachtert/basic_shapes_object_detection_dataset/blob/main/README.md
    readme_file: str = "SOME_README.md"
    readme_contents: str
    with open(readme_file, 'r') as file:
        readme_contents = file.read()
    card = huggingface_hub.RepoCard(readme_contents)
    card.push_to_hub(repo_id, token=hf_token, repo_type="dataset")

Dictionaries that map the numeric label IDs to their short descriptions, and vice versa:

In [9]:
id2label: Dict[int,str] = dict(enumerate(annotations_short))  # label ID -> short description
label2id: Dict[str,int] = {label: label_id for label_id, label in id2label.items()}  # short description -> label ID
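A minimal sketch of the same two mappings with three made-up names (the notebook takes the real names from annotations_short) shows the round trip. Inverting a dictionary only works if every short name is unique; a duplicate name would silently collapse two IDs into one key, so the sketch checks for that. Mappings of this shape are also what Transformers model configs accept as id2label and label2id.

```python
from typing import Dict, List

# Hypothetical short names, for illustration only
toy_names: List[str] = ["background", "layer", "fluid"]

toy_id2label: Dict[int, str] = dict(enumerate(toy_names))
toy_label2id: Dict[str, int] = {label: label_id for label_id, label in toy_id2label.items()}

# The inverse mapping loses entries if two IDs share the same name
assert len(toy_label2id) == len(toy_id2label), "duplicate label names"
print(toy_id2label)
print(toy_label2id)
```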

When creating a train set and a test set, we have to make sure that both contain images with every label: for example, a test set without any image containing label 5 couldn't tell us how well the model segments that class. For each label, let's count how many images contain it:
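The counting idea can first be tried on tiny synthetic masks. The sketch below uses three hypothetical 2x3 masks, not AROI data, and a hypothetical helper count_images_per_class. Because np.unique lists each class ID present in a mask only once, every mask adds at most 1 to each class count, so the result is "number of images containing the class", not a pixel count.

```python
from collections import Counter
from typing import Iterable

import numpy as np


def count_images_per_class(masks: Iterable[np.ndarray]) -> Counter:
    """Count, for every class ID, how many masks contain at least one pixel of it."""
    counts: Counter = Counter()
    for mask in masks:
        # np.unique returns the distinct class IDs in the mask, each exactly once
        counts.update(int(c) for c in np.unique(mask))
    return counts


# Hypothetical toy masks with label IDs as pixel values
toy_masks = [
    np.array([[0, 0, 1], [1, 2, 2]], dtype=np.uint8),
    np.array([[0, 1, 1], [1, 1, 1]], dtype=np.uint8),
    np.array([[0, 0, 0], [2, 7, 7]], dtype=np.uint8),
]
print(count_images_per_class(toy_masks))
```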

In [12]:
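def check_distribution_of_labels_across_dataset(ds: datasets.Dataset):
    """Print, for every label ID, how many images in ds contain at least one pixel with that label."""
    num_images_per_class: Dict[int,int] = defaultdict(int)

    for ds_entry in ds:
        img: Dict = cast(Dict, ds_entry)
        # np.unique lists each label ID in the mask once, so every image counts at most once per label
        unique_label_ids: np.ndarray = np.unique(np.asarray(img['label']))

        for np_id in unique_label_ids:
            num_images_per_class[int(np_id)] += 1

    for k, v in sorted(num_images_per_class.items()):
        print(f"{v} images use class {k}")
    if len(num_images_per_class) == len(id2label):
        print("All labels are represented")
    else:
        missing_labels = sorted(set(id2label) - set(num_images_per_class))
        print(f"Labels without any image: {missing_labels}")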


check_distribution_of_labels_across_dataset(complete_dataset)

1137 images use class 0
1137 images use class 1
1137 images use class 2
1137 images use class 3
1137 images use class 4
1014 images use class 5
649 images use class 6
228 images use class 7
All labels are represented


Only about 1 in 5 images (228 of 1137) contains label 7, so we can't simply pick random images for the test set: a small random sample could easily contain too few of them, or none at all.
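To make this concrete: if the test set were a uniformly random sample of 22 images (the size of the test set selected below), the number of images with label 7 in it would follow a hypergeometric distribution. The sketch below uses only the counts printed above; the helper prob_fewer_than is illustrative, not part of the notebook.

```python
from math import comb


def prob_fewer_than(k_min: int, population: int, successes: int, sample: int) -> float:
    """P(X < k_min) for a hypergeometric X: the number of 'success' items in a
    sample of `sample` items drawn uniformly at random without replacement."""
    total = comb(population, sample)
    favourable = sum(
        comb(successes, k) * comb(population - successes, sample - k)
        for k in range(min(k_min, sample + 1))
    )
    return favourable / total


N_IMAGES = 1137        # rows in complete_dataset
N_WITH_LABEL_7 = 228   # images containing label 7, from the output above
TEST_SIZE = 22         # size of a hypothetical, purely random test set

expected_label_7 = TEST_SIZE * N_WITH_LABEL_7 / N_IMAGES
p_fewer_than_10 = prob_fewer_than(10, N_IMAGES, N_WITH_LABEL_7, TEST_SIZE)
print(f"Expected number of images with label 7: {expected_label_7:.2f}")
print(f"Probability of fewer than 10 such images: {p_fewer_than_10:.4f}")
```

The expected count is only about 4.4 images, so a random test set of this size would almost never reach 10 images with label 7.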

The following function randomly selects a set of images that contains at least 10 images for each label.

In [16]:
def select_indexes_for_test_dataset(ds: datasets.Dataset, min_images_per_label: int = 10) -> List[int]:
    """ Select indexes of entries within this dataset that could be used as a test dataset
    Make sure that each label is represented by at least min_images_per_label images
    Note: uses """
    def enough_images_selected(num_images_per_label: Dict[int,int], num_classes: int) -> bool:
        for label_id in range(0, num_classes):
            if num_images_per_label[label_id] < min_images_per_label:
                return False
        return True

    def check_if_contains_needed_labels(num_images_per_label: Dict[int,int], labels_of_image: np.ndarray) -> bool:
        """ Some labels might not have enough images yet within num_images_per_label.
        Check whether the randomly selected image contains labels that are not yet sufficiently represented """
        for np_id in labels_of_image:
            if num_images_per_label[int(np_id)] < min_images_per_label:
                return True
        return False


    num_images_per_label: Dict[int,int] = defaultdict(int)
    selected_indexes: List[int] = []

    # Visit the entries in random order. Unlike repeatedly drawing random.randint(),
    # this never revisits an index, and the loop ends even if the requirement can't be met.
    candidate_indexes: List[int] = list(range(len(ds)))
    random.shuffle(candidate_indexes)
    for random_index in candidate_indexes:
        if enough_images_selected(num_images_per_label, len(id2label)):
            break
        entry: Dict = cast(Dict, ds[random_index])  # decode the entry only once
        if 'label' not in entry:
            print(f"no 'label' key for entry at index {random_index}")
            raise ValueError("This method can only be used before adding the transform method which uses the feature extractor")
        labels_of_image: np.ndarray = np.unique(np.asarray(entry['label']))
        if check_if_contains_needed_labels(num_images_per_label, labels_of_image):
            # Let's add this image
            selected_indexes.append(random_index)
            for np_id in labels_of_image:
                num_images_per_label[int(np_id)] += 1
    if not enough_images_selected(num_images_per_label, len(id2label)):
        raise ValueError(f"The dataset doesn't contain at least {min_images_per_label} images for every label")
    return selected_indexes


indexes_for_test_dataset = select_indexes_for_test_dataset(complete_dataset, min_images_per_label=10)
print(f"Selected indexes for a test dataset: {indexes_for_test_dataset}")

Selected indexes for a test dataset: [228, 51, 563, 501, 457, 285, 209, 1116, 178, 864, 65, 61, 191, 859, 865, 318, 704, 928, 727, 664, 292, 877]
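The selection logic can be tried in isolation on made-up data. The sketch below is a simplified, hypothetical stand-alone version (greedy_select, the toy label sets and the seed are illustrations, not part of the notebook): it visits items in a seeded random order and keeps an item only if it contains at least one label that is still under-represented. Passing a fixed seed to random.Random makes the selection reproducible, which the notebook's function, relying on the global random module, is not unless random.seed() is called first.

```python
import random
from collections import defaultdict
from typing import Dict, List, Set


def greedy_select(label_sets: List[Set[int]], num_classes: int,
                  min_per_label: int, seed: int = 0) -> List[int]:
    """Return indexes of items so that every class appears in at least min_per_label of them."""
    rng = random.Random(seed)  # private generator: same seed, same selection
    order = list(range(len(label_sets)))
    rng.shuffle(order)
    counts: Dict[int, int] = defaultdict(int)
    selected: List[int] = []
    for idx in order:
        if all(counts[c] >= min_per_label for c in range(num_classes)):
            break
        # Keep the item only if it helps an under-represented class
        if any(counts[c] < min_per_label for c in label_sets[idx]):
            selected.append(idx)
            for c in label_sets[idx]:
                counts[c] += 1
    return selected


# Toy data: every item has labels 0 and 1; only every third item also has the rare label 2
toy_label_sets = [{0, 1, 2} if i % 3 == 0 else {0, 1} for i in range(30)]
picked = greedy_select(toy_label_sets, num_classes=3, min_per_label=2, seed=42)
print(picked)
```

As in the notebook, common labels end up over-represented in the selection (here labels 0 and 1 appear in every picked item), because an image is kept as a whole whenever it contains a rare label.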


Let's make a train and a test dataset from the complete dataset:
* The test dataset will contain the small number of images that were selected above.
* The train dataset will contain all remaining images.
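The index bookkeeping behind this split can be sketched in plain Python. The helper complementary_split is hypothetical; it checks that the test indexes are unique and in range, and returns train indexes that are exactly the complement, so every row ends up in exactly one split. The resulting lists are what would be passed to Dataset.select(). The example reuses the 22 indexes printed above for the 1137-row dataset.

```python
from typing import List, Sequence, Tuple


def complementary_split(num_rows: int, test_indexes: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Return (train_indexes, test_indexes) so that every row is in exactly one split."""
    test_set = set(test_indexes)  # set for fast membership tests
    if len(test_set) != len(test_indexes):
        raise ValueError("test_indexes contains duplicates")
    if any(not 0 <= i < num_rows for i in test_set):
        raise ValueError("test_indexes contains an out-of-range index")
    train_indexes = [i for i in range(num_rows) if i not in test_set]
    return train_indexes, list(test_indexes)


# The indexes printed above for the 1137-row complete dataset
printed_test_indexes = [228, 51, 563, 501, 457, 285, 209, 1116, 178, 864, 65,
                        61, 191, 859, 865, 318, 704, 928, 727, 664, 292, 877]
toy_train_idx, toy_test_idx = complementary_split(1137, printed_test_indexes)
print(len(toy_train_idx), len(toy_test_idx))
```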


In [22]:
test_dataset: datasets.Dataset = complete_dataset.select(indexes_for_test_dataset)
print(f"ds_test dataset has {len(test_dataset)} entries with the following distribution of label usage:")
check_distribution_of_labels_across_dataset(test_dataset)

test_indexes_set = set(indexes_for_test_dataset)  # set for fast membership tests
train_dataset: datasets.Dataset = complete_dataset.select([x for x in range(len(complete_dataset)) if x not in test_indexes_set])
print(f"ds_train dataset has {len(train_dataset)} entries with the following distribution of label usage:")
check_distribution_of_labels_across_dataset(train_dataset)

ds_test dataset has 22 entries with the following distribution of label usage:
22 images use class 0
22 images use class 1
22 images use class 2
22 images use class 3
22 images use class 4
22 images use class 5
14 images use class 6
10 images use class 7
All labels are represented
ds_train dataset has 1115 entries with the following distribution of label usage:
1115 images use class 0
1115 images use class 1
1115 images use class 2
1115 images use class 3
1115 images use class 4
992 images use class 5
635 images use class 6
218 images use class 7
All labels are represented


Create a DatasetDict containing the two splits, 'train' and 'test':

In [23]:
split_dataset: datasets.DatasetDict = datasets.DatasetDict({
    'train': train_dataset,
    'test': test_dataset
})

Also store this dataset on disk:

In [24]:
split_dataset_path: str = "hf_aroi_dataset_split"
split_dataset.save_to_disk(split_dataset_path)
