In [1]:
import os
import cv2
import numpy as np
from skimage.measure import label, regionprops
from operator import attrgetter
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from glob import glob


In [None]:
raw_root = './data/brazilian-ophthalmological/1.0.1/fundus_photos'
# Only files directly inside raw_root are matched (the pattern has no "**").
# Sorted because glob's order is filesystem-dependent; a fixed order keeps
# processing (and any error log) reproducible across runs.
raw_files = sorted(glob(os.path.join(raw_root, '*.jpg')))

print(len(raw_files))

In [3]:
def fill_crop(img, min_idx, max_idx):
    """Extract the box ``[min_idx, max_idx)`` from ``img``, zero-padding outside it.

    Unlike plain slicing, the requested box may extend past the image on any
    side (including negative start indices); the parts of the box that fall
    outside ``img`` are filled with zeros.

    Args:
        img: N-dimensional array to crop from.
        min_idx: Inclusive start index for every axis of ``img``.
        max_idx: Exclusive end index for every axis of ``img``.

    Returns:
        A new array of shape ``max_idx - min_idx`` with dtype ``img.dtype``.
    """
    # int64 rather than int16 so coordinates of large images cannot wrap around.
    start = np.asarray(min_idx, dtype=np.int64)
    stop = np.asarray(max_idx, dtype=np.int64)
    crop_shape = stop - start
    img_shape = np.asarray(img.shape, dtype=np.int64)
    crop = np.zeros(crop_shape, dtype=img.dtype)

    # Region of the output buffer that overlaps the source image.
    crop_low = np.clip(-start, 0, crop_shape)
    crop_high = crop_shape - np.clip(stop - img_shape, 0, crop_shape)
    # Matching region of the source image.
    img_low = np.clip(start, 0, img_shape)
    img_high = np.clip(stop, 0, img_shape)

    crop_slices = tuple(slice(lo, hi) for lo, hi in zip(crop_low, crop_high))
    img_slices = tuple(slice(lo, hi) for lo, hi in zip(img_low, img_high))
    crop[crop_slices] = img[img_slices]
    return crop


def fundus_crop(image, shape=(512, 512), margin=5, threshold=30):
    """Cut the fundus disc out of a photo as a square box and resize it.

    Foreground is every pixel whose sum over the last axis exceeds
    ``threshold``. The largest connected foreground component is treated as
    the fundus. A square box is cut out, zero-padded where it leaves the
    image, and resized to ``shape``. The box has side
    ``max(bbox height, bbox width)``, is centred on that component's
    centroid, and is widened by ``margin`` pixels on every side.

    Args:
        image: ``(H, W, C)`` array, channels last (e.g. from ``cv2.imread``).
        shape: Output size handed to ``cv2.resize`` (OpenCV's ``(width, height)``).
        margin: Extra pixels added around the square box on every side.
        threshold: Channel-sum value above which a pixel counts as foreground.

    Returns:
        The resized crop, with the same dtype as ``image``.

    Raises:
        ValueError: if no pixel exceeds ``threshold`` (e.g. an all-black image).
    """
    foreground = image.sum(axis=-1) > threshold
    regions = regionprops(label(foreground))
    if not regions:
        raise ValueError(f"no foreground pixels above threshold {threshold}")
    region = max(regions, key=attrgetter('area'))

    # regionprops bbox for a 2-D label image is (min_row, min_col, max_row, max_col).
    bbox = np.asarray(region.bbox, dtype=np.int64)
    side = int((bbox[2:4] - bbox[0:2]).max())
    # floor (not truncation toward zero) keeps the box exactly side x side even
    # when it starts above/left of the image, as happens for clipped fundus discs.
    top_left = np.floor(np.asarray(region.centroid) - side / 2).astype(np.int64)
    bottom_right = top_left + side

    channels = image.shape[2]
    cropped = fill_crop(
        image,
        [top_left[0] - margin, top_left[1] - margin, 0],
        [bottom_right[0] + margin, bottom_right[1] + margin, channels],
    )
    return cv2.resize(cropped, tuple(shape), interpolation=cv2.INTER_LINEAR)

In [None]:
def process_single_image(cur_path):
    """Crop/resize one fundus photo and write it into a ``fundus_photos_preprocessed`` dir.

    The output keeps the input's file name; only the directory part of the
    path has ``fundus_photos`` replaced by ``fundus_photos_preprocessed``.
    Any failure is reported with ``print`` and swallowed, so one bad file
    does not abort a batch run.
    """
    try:
        image = cv2.imread(cur_path)
        # cv2.imread signals a missing/unreadable file by returning None, not raising.
        if image is None:
            raise IOError("could not read image")
        image_crop = fundus_crop(image, shape=(512, 512), margin=5)

        # Rewrite only the directory so a file name that happens to contain
        # "fundus_photos" is left intact.
        src_dir, file_name = os.path.split(cur_path)
        save_dir = src_dir.replace('fundus_photos', 'fundus_photos_preprocessed')
        if save_dir == src_dir:
            # Without this guard the raw photo would be overwritten in place.
            raise ValueError("path has no 'fundus_photos' directory; refusing to overwrite source")
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, file_name)
        # cv2.imwrite reports failure via its boolean return value.
        if not cv2.imwrite(save_path, image_crop):
            raise IOError(f"could not write {save_path}")
    except Exception as e:
        print(f"error {cur_path}: {str(e)}")

def process_images_parallel(raw_files, num_threads=8):
    """Preprocess every path in ``raw_files`` concurrently on a thread pool.

    Each path is passed to ``process_single_image``. A tqdm bar labelled
    "processing" tracks completed files, and the per-file results are
    drained and discarded.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        outcomes = pool.map(process_single_image, raw_files)
        file_count = len(raw_files)
        # Iterating the map() result waits for each task in submission order,
        # which is what advances the progress bar.
        for _ in tqdm(outcomes, total=file_count, desc="processing"):
            pass

# Run the preprocessing over every discovered photo with the default 8 worker threads.
process_images_parallel(raw_files, num_threads=8)