# RIGA Optic Disk Segmentation (Magrabia)

In [12]:
from pathlib import Path
from fundus_odmac_toolkit.models.segmentation import segment
from fundus_odmac_toolkit.models.hf_hub import list_models
from fundus_data_toolkit.functional import open_image
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import cv2
from torchmetrics import JaccardIndex
import re
import torch


In [2]:


root = Path('/home/clement/Documents/data/RIGA-dataset/Magrabia/')

all_gts = list(root.glob('**/*-*.tif'))

In [3]:
stored_images = {}
stored_masks = {}

In [4]:


def extract_mask(img_gt, ref_img):
    mask_border = np.abs(img_gt - ref_img).max(-1) > 25
    
    mask = cv2.floodFill(mask_border.astype(np.uint8), None, (0, 0), 255)[1]
    return mask
    


for img_gt in tqdm(all_gts):
    img_name = img_gt.stem.split('-')[0]
    img_parent = img_gt.parent
    img_id = f'{img_parent.stem}/{img_name.lower()}'
    if img_id in stored_images:
        continue
    
    # Ugly way to find all the possible image names given the lack of standardization in the dataset
    img_path = [f for f in img_parent.iterdir() if re.search(fr'{img_name}prime.tif|{img_name.capitalize()}prime.tif|{img_name.capitalize()}.tif|{img_name}.tif', str(f))]
    img_path = list(img_parent.glob(f'{img_name}prime.tif')) + list(img_parent.glob(f'{img_name.capitalize()}prime.tif'))  + list(img_parent.glob(f'{img_name.capitalize()}.tif')) + list(img_parent.glob(f'{img_name}.tif'))
    if len(img_path) == 0:
        print(f'No image found for {img_name}')
        continue
    else:
        img_path = img_path[0]
    current_gt_masks = list(img_parent.glob(f'{img_name}-*.tif'))+list(img_parent.glob(f'{img_name.capitalize()}-*.tif'))
    try:
        img = open_image(img_path)
    except:
        print(f'Failed to open {img_name}')
        
        continue
    gts = (np.stack([extract_mask(img, open_image(f)) for f in current_gt_masks]))
    gts =  1 - (np.mean(gts, axis=0) > gts.shape[0] / 2)
    stored_images[img_id] = img
    stored_masks[img_id] = gts

  0%|          | 0/570 [00:00<?, ?it/s]

No image found for Image7
No image found for Image8


In [13]:
all_models = list_models()

Architecture | [94m Encoder | [92m Variants
[1munet [94mseresnet50 [92m (1 variants)
[1munet [94mmaxvit_tiny_tf_512 [92m (1 variants)
[1munet [94mmaxvit_base_tf_512 [92m (1 variants)
[1munet [94mmobilevitv2_100 [92m (1 variants)
[1munetplusplus [94mseresnet50 [92m (1 variants)
[1munet [94mmobilenetv3_small_050 [92m (1 variants)
[1munetplusplus [94mmobilenetv3_small_050 [92m (1 variants)
[1munet [94mmaxvit_small_tf_512 [92m (1 variants)


In [14]:

def infer(arch, encoder):
    jaccard_index = JaccardIndex(task='binary').cuda()
    for img_id in tqdm(stored_images.keys()):
        img = stored_images[img_id]
        gt = stored_masks[img_id]
        gt = torch.from_numpy(gt).cuda() > 0
        predicted_od = segment(img, arch=arch, encoder=encoder)
        predicted_od = (predicted_od.argmax(0) == 1).long()
        jaccard_index.update(predicted_od, gt)
        
    print(f"Model: {arch}-{encoder}, Jaccard Index: {jaccard_index.compute().item():.2%}")

    
for arch, encoder in all_models:
    infer(arch, encoder)

  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-seresnet50, Jaccard Index: 87.78%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-maxvit_tiny_tf_512, Jaccard Index: 87.17%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-maxvit_base_tf_512, Jaccard Index: 80.66%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-mobilevitv2_100, Jaccard Index: 87.30%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unetplusplus-seresnet50, Jaccard Index: 81.95%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-mobilenetv3_small_050, Jaccard Index: 86.41%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unetplusplus-mobilenetv3_small_050, Jaccard Index: 82.62%


  0%|          | 0/95 [00:00<?, ?it/s]

Model: unet-maxvit_small_tf_512, Jaccard Index: 86.54%
