### Dataset Creation

This notebook generates the augmented datasets.

It contains the following steps:
1. Image pre-processing with Ben Graham preprocessing
2. Image segmentation using CLAHE and OpenCV
3. Image segmentation using UNET

This notebook assumes the following project structure:
```bash
Root
├── notebooks
│   └── notebook1.ipynb
└── input
    └── Data
        ├── DDR
        │   ├── Train
        │   └── Test
        ├── BEN
        │   ├── Train
        │   └── Test
        ├── CLAHE
        │   ├── Train
        │   └── Test
        ├── UNET_binary
        │   ├── Train
        │   └── Test
        └── UNET_multiclass
            ├── Train
            └── Test
```
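
As a rough sketch, the layout above could be created with `pathlib`. The helper name `make_project_tree` and the use of a temporary directory are assumptions made for illustration; the folder names are copied from the tree above.

```python
import tempfile
from pathlib import Path

# Dataset variants and splits taken from the tree above.
VARIANTS = ["DDR", "BEN", "CLAHE", "UNET_binary", "UNET_multiclass"]
SPLITS = ["Train", "Test"]

def make_project_tree(root):
    """Create notebooks/ and input/Data/<variant>/<split>/ under root (hypothetical helper)."""
    root = Path(root)
    (root / "notebooks").mkdir(parents=True, exist_ok=True)
    created = []
    for variant in VARIANTS:
        for split in SPLITS:
            folder = root / "input" / "Data" / variant / split
            folder.mkdir(parents=True, exist_ok=True)
            created.append(folder)
    return created

# Demonstrate in a throwaway directory so nothing real is touched.
with tempfile.TemporaryDirectory() as tmp:
    folders = make_project_tree(tmp)
    print(len(folders))  # 10 (5 variants x 2 splits)
```

Passing the real project root instead of the temporary directory would create the same folders in place; existing folders are left untouched because of `exist_ok=True`.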

If you do not have the dataset, please download it from our Google Drive

In [13]:
### BASE Class for dataset creation
from abc import ABC

class Dataset_Creator(ABC):
    def __init__(self):
        pass

    def create_dataset(self, source_path:str, dest_path:str, limit:int=None, show_output:bool=False):
        pass

### Splitting Dataset into Train, Val, Test

The existing dataset does not come with a train/validation/test split.

We randomly split it into train, validation, and test sets with ratios of 0.7, 0.15, and 0.15.
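
The arithmetic behind a 0.7 / 0.15 / 0.15 split can be sketched in plain Python. When the split is done in two stages (test first, then validation from the remainder), the second stage has to take 0.15 / 0.85 ≈ 0.176 of what is left to end up with 15% of the full dataset. The function `three_way_split` below is a hypothetical illustration that slices a shuffled list directly; it is not the scikit-learn call used in this notebook.

```python
import random

def three_way_split(items, val_frac=0.15, test_frac=0.15, seed=42):
    """Shuffle items and cut them into train/val/test lists (illustrative helper)."""
    items = list(items)
    random.Random(seed).shuffle(items)  # seeded so the split is reproducible
    n_test = round(len(items) * test_frac)
    n_val = round(len(items) * val_frac)
    test = items[:n_test]
    val = items[n_test:n_test + n_val]
    train = items[n_test + n_val:]
    return train, val, test

train, val, test = three_way_split(range(1000))
print(len(train), len(val), len(test))  # 700 150 150

# Fraction of the non-test remainder that must go to validation in a two-stage split.
print(round(0.15 / 0.85, 4))
```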

In [24]:
import pandas as pd
import os
import shutil

# Path to the CSV file that lists image names and labels
csv_file = r'H:\Diabetic_Retinopathy_Detection\input\grading_images\DR_grading.csv'
# Path to the directory containing the images
image_directory = r'H:\Diabetic_Retinopathy_Detection\input\grading_images\DR_grading\DR_grading'
# Read the CSV file
df = pd.read_csv(csv_file, header=0)

# Loop through the DataFrame
for index, row in df.iterrows():
    image_name = row['id_code']  # Column name for image names
    label = row['diagnosis']  # Column name for labels
    
    # Create the directory for the label if it doesn't exist
    label_directory = os.path.join(image_directory, str(label))
    if not os.path.exists(label_directory):
        os.makedirs(label_directory)
    
    # Construct the source and destination paths
    src = os.path.join(image_directory, image_name)
    dest = os.path.join(label_directory, image_name)
    
    # Move the image
    shutil.move(src, dest)


In [28]:
import os
import shutil
from glob import glob
from sklearn.model_selection import train_test_split

# Base directory where the DDR class folders (0-4) are located
base_dir = r"H:\Diabetic_Retinopathy_Detection\input\grading_images\DDR"
dest_dir = r"H:\Diabetic_Retinopathy_Detection\input\grading_images\DDR"  # Destination directory for the split datasets

# Collect all image paths
image_paths = []
for i in range(5):  # Assuming subfolders are named 0, 1, 2, 3, 4
    image_paths.extend(glob(os.path.join(base_dir, str(i), '*')))

# Shuffle and split the data
train_val, test = train_test_split(image_paths, test_size=0.15, random_state=42)
train, val = train_test_split(train_val, test_size=0.15 / 0.85, random_state=42)

# Function to copy files to the new directory structure
def copy_files(files, dataset_type):
    for file_path in files:
        # Determine the class directory based on the file path
        class_dir = os.path.basename(os.path.dirname(file_path))
        dest_path = os.path.join(dest_dir, dataset_type, class_dir)
        
        if not os.path.exists(dest_path):
            os.makedirs(dest_path)
        
        # Copy the file
        shutil.copy(file_path, dest_path)

# Copy the files to their new locations
copy_files(train, 'train')
copy_files(val, 'val')
copy_files(test, 'test')

print("Data split and copied successfully.")


Data split and copied successfully.


### Image Pre-processing with Ben Graham Preprocessing algorithm

Inspired by Ben Graham, who won the EyePACS diabetic retinopathy challenge.

Because applying the Ben Graham preprocessing on the fly as a torchvision transform is too computationally expensive, it is applied once to the original images and the results are saved so that they can be used directly.
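
The core of the method is a weighted blend of each pixel with its Gaussian-blurred (local average) value: `4 * pixel - 4 * blurred + 128`, clamped to the 0-255 range. The sketch below works this through on single pixel values in plain Python so it runs without OpenCV; `ben_graham_pixel` is a hypothetical helper, and in the notebook the same blend is done on whole images with `cv2.addWeighted`.

```python
def ben_graham_pixel(pixel, local_mean, weight=4, offset=128):
    """Blend one pixel with its blurred (local mean) value and clamp to 0..255."""
    value = weight * pixel - weight * local_mean + offset
    return max(0, min(255, round(value)))

# Flat region: the pixel equals its blurred value, so the output is mid-gray.
print(ben_graham_pixel(90, 90))    # 128

# A pixel 10 levels darker than its surroundings ends up 40 levels below mid-gray.
print(ben_graham_pixel(80, 90))    # 88

# Large deviations saturate at the ends of the 8-bit range.
print(ben_graham_pixel(200, 90))   # 255
```

Subtracting the local mean removes slow lighting variation, while the factor of 4 amplifies local detail such as vessels and lesions.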

In [45]:
import cv2
import numpy as np
import os

class Ben_process(Dataset_Creator):
    def __init__(self):
        pass

    def create_dataset(self, source_path:str, dest_path:str, limit:int=None, show_output:bool=False):
        """
        Main function that runs image masking and Ben Graham's preprocessing with OpenCV

        Ben Graham's preprocessing blends a Gaussian-blurred copy of the image with the original image.
        This highlights the edges and increases the contrast, making it easier to identify the blood vessels and lesions.

        A mask is created to remove the background of the image so that the focus is on the eye.

        All images in the folder are resized to 224x224, preprocessed, and saved to the destination folder with the same image name

        Args:
            source_path (str): Source path of the folder
            dest_path (str): Dest path of the folder
            limit (int, optional): Maximum number of images to process. Defaults to None.
            show_output (bool, optional): Show the processed image. Defaults to False
        
        """
        
        #create destination directory
        if not os.path.isdir(dest_path):
          os.makedirs(dest_path)

        #variable for scaling factor for image size  
        scale = 500
        folder_files = os.listdir(source_path)

        for i, file in enumerate(folder_files, start=1):
            base_image = cv2.imread(f'{source_path}/{file}') #open image

            #resize the image to improve the effectiveness of Ben Graham preprocessing (the Gaussian blur scale is size-specific)
            resize_image = cv2.resize(base_image, (224, 224))

            #create background mask
            mask = self._mask_image(resize_image)

            #Ben Graham Processing
            processed_img = cv2.addWeighted(resize_image, 4, cv2.GaussianBlur(resize_image, (0, 0), scale/30), -4, 128) #overlaps gaussian blur
            np_zero_mask = np.zeros(processed_img.shape) #creates black background
            cv2.circle(np_zero_mask, (processed_img.shape[1] // 2, processed_img.shape[0] // 2), int(scale * 0.9), (1, 1, 1), -1, 8, 0) #estimates a circle where the eye will be
            enhanced_image = processed_img * np_zero_mask + 128 * (1 - np_zero_mask) #fill everything outside the estimated circle with mid-gray (128); fallback for scenarios where the mask fails

            #remove background
            enhanced_image[mask == 255] = 0 
            
            if show_output:
                result = np.hstack((resize_image, enhanced_image))
                cv2.imshow("enhanced image", result)
                cv2.waitKey(10000)
                cv2.destroyAllWindows()

            cv2.imwrite(f"{dest_path}/{file}", enhanced_image) 
            print(f"Image {i}/{len(folder_files)} processed. Output path:{dest_path}/{file}")

            if limit is not None and i >= limit:
                break



    def _mask_image(self,image:np.ndarray):
        """
        Function to create a mask of the eyeball in black against a white background
        
        The background of the image will be white to act as the mask

        Args:
            image (np.ndarray): Image to create mask of

        Return:
            mask (np.ndarray): Mask of image
        """

        #Convert image to grayscale to improve background removal
        tmp = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        
        #Create border threshold of 5-255
        _, mask = cv2.threshold(tmp, 5, 255, cv2.THRESH_BINARY)

        #Blur the mask to smoothen threshold edges
        mask = cv2.GaussianBlur(mask, (11,11), 0)

        #Create base white background
        white_bg = np.full_like(image, 255)

        #Create mask and invert it such that we get the eyeball is black and background is white
        mask = cv2.bitwise_not(mask)

        #Overlay the mask on white background. Background will be white, eyeball will be black
        masked_image = cv2.bitwise_or(white_bg, white_bg, mask=mask)

        return masked_image


In [39]:
source_path = "../input/grading_images/DDR"
dest_path = "../input/grading_images/BEN"
folder_list = ["train", "val", "test"]
class_lists = ['0', '1', '2', '3', '4']

Ben_Processor = Ben_process()

for folder_name in folder_list:
    for class_idx in class_lists:
        image_src_path = f"{source_path}/{folder_name}/{class_idx}"
        mask_dest_path = f"{dest_path}/{folder_name}/{class_idx}"
        Ben_Processor.create_dataset(image_src_path, mask_dest_path)

Image 1/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0007-000.jpg
Image 2/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0008-000.jpg
Image 3/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0009-000.jpg
Image 4/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0022-000.jpg
Image 5/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0023-000.jpg
Image 6/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0025-000.jpg
Image 7/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0026-000.jpg
Image 8/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0028-000.jpg
Image 9/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0029-000.jpg
Image 10/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0030-000.jpg
Image 11/4334 processed. Output path:../input/grading_images/BEN/train/0/007-0032-000.jpg
Image 12/4334 proce

### Image Segmentation with OpenCV and CLAHE

Inspired by Detecting Diabetic Retinopathy in Fundus Images using Combined Enhanced Green and Value Planes (CEGVP) with k-NN:
<br>
https://thesai.org/Publications/ViewPaper?Volume=13&Issue=1&Code=IJACSA&SerialNo=32


The image has its contrast enhanced before the green channel is extracted. The green channel then undergoes edge detection with different thresholds to segment the blood vessels and the lesions.

The background will be black, with the extracted blood vessels in white and the lesions will be in gray.
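
The way the two masks are combined can be illustrated on a tiny example. The sketch below uses nested lists in plain Python rather than OpenCV arrays, and `merge_masks` is a hypothetical helper; it assumes both inputs are binary masks holding 0 or 255, and that vessels take priority wherever the two masks overlap.

```python
VESSEL, LESION, BACKGROUND = 255, 100, 0

def merge_masks(vessel_mask, lesion_mask):
    """Combine two binary masks (values 0 or 255), letting vessels win on overlap."""
    merged = []
    for vessel_row, lesion_row in zip(vessel_mask, lesion_mask):
        row = []
        for v, l in zip(vessel_row, lesion_row):
            if v == 255:
                row.append(VESSEL)      # vessel pixels stay white
            elif l == 255:
                row.append(LESION)      # lesion-only pixels become gray
            else:
                row.append(BACKGROUND)  # everything else is black
        merged.append(row)
    return merged

vessels = [[255, 255, 0], [0, 0, 0]]
lesions = [[0, 255, 255], [0, 255, 0]]
print(merge_masks(vessels, lesions))
# [[255, 255, 100], [0, 100, 0]]
```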


Because applying the segmentation on the fly as a torchvision transform is too computationally expensive, it is applied once to the original images and the results are saved so that they can be used directly.

In [29]:
import os
import cv2
import numpy as np

class CLAHE_segmentation(Dataset_Creator):
    def __init__(self):
        pass
   
    def create_dataset(self, source_path:str, dest_path:str, limit:int=None, show_output:bool=False):
        """
        Main function that runs image segmentation with OpenCV and CLAHE
        Image segmentation is run using edge detection; the background will be black while the vessels are white and the lesions are gray

        All images in the folder are preprocessed and saved to the destination folder at their original image size with the same image name

        Args:
            source_path (str): Source path of the folder
            dest_path (str): Dest path of the folder
            limit (int, optional): Maximum number of images to process. Defaults to None.
            show_output (bool, optional): Show the processed image. Defaults to False
        
        """

        #create destination directory
        if not os.path.isdir(dest_path):
          os.makedirs(dest_path)

        folder_files = os.listdir(source_path)

        for i, file in enumerate(folder_files, start=1):
            base_image = cv2.imread(f'{source_path}/{file}') #open image

            #enhance contrast of image
            enhanced_image = self._enhance_contrast(base_image, False)

            #enhance green channel of the image
            max_pixel_diff, enhanced_green, g = self._enhance_green_channel(image=enhanced_image,
                                                show_channels=False,
                                                show_enhanced=False)

            #extract blood vessel
            blood_vessel_mask = self._extract_blood_vessels(image=enhanced_green,
                                                    show_extracted=False)

            #extract lesion 
            lesion_mask = self._extract_lesion(image=g, max_pixel_diff=max_pixel_diff, show_extracted=False)

            #calculate the overlap between blood_vessel and lesion mask, prioritise blood vessel
            overlap_mask = cv2.bitwise_and(blood_vessel_mask, lesion_mask)
            lesion_mask = np.subtract(lesion_mask, overlap_mask) #remove overlapping parts from lesion mask
            lesion_mask[lesion_mask == 255] = 100 #convert lesion mask to gray

            merged_mask = cv2.bitwise_or(blood_vessel_mask, lesion_mask)
            
            rgb_mask = cv2.cvtColor(merged_mask,cv2.COLOR_GRAY2RGB) 

            if show_output:
                resized = cv2.resize(rgb_mask, (224,224))
                result = np.hstack((cv2.resize(base_image, (224,224)), resized))
                cv2.imshow("output", result)
                cv2.waitKey(10000)
                cv2.destroyAllWindows()

            cv2.imwrite(f"{dest_path}/{file}", rgb_mask) 
            print(f"Image {i}/{len(folder_files)} processed. Output path:{dest_path}/{file}")

            if limit is not None and i >= limit:
                break


    def _enhance_contrast(self, image:np.ndarray, show_enchanced:bool=False):
        """
        Function to enhance the contrast of image using CLAHE

        Args:
            image (np.ndarray): Image to be processed
            show_enchanced (bool): Boolean to show before and after image. Defaults to False.

        Returns:
            enhanced image (np.ndarray)
        """
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) #convert BGR image to LAB colour space
        L, a, b = cv2.split(lab) #split LAB channels

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) #create Contrast Limited Adaptive Histogram Equalization (CLAHE)
        cL = clahe.apply(L) #apply CLAHE to the lightness channel to increase contrast

        new_img = cv2.merge((cL,a,b)) #create new LAB enhanced image
        enhanced_img = cv2.cvtColor(new_img, cv2.COLOR_LAB2BGR) #convert from LAB back to BGR

        if show_enchanced:
          result = np.hstack((cv2.resize(image, (224,224)), (cv2.resize(enhanced_img, (224,224)))))
          cv2.imshow("enhanced image", result)
          cv2.waitKey(10000)
          cv2.destroyAllWindows()

        return enhanced_img


    def _enhance_green_channel(self, image:np.ndarray, kernel_size:tuple=(75,75), show_channels:bool=False, show_enhanced:bool=False):
        """
        Function to enhance the contrast of the GREEN channel of the image by extracting the detail removed by a morphological opening

        Args:
            image (np.ndarray): Image to be processed
            kernel_size (tuple, optional): Kernel size for the morphological opening. Defaults to (75,75).
            show_channels (bool, optional): Boolean to show the r,g and b channel of image. Defaults to False.
            show_enhanced (bool, optional): Boolean to show enhanced image. Defaults to False.

        Returns:
            max pixel diff (int): maximum pixel difference in enhanced image
            enhanced green: enhanced green channel image
            g: green channel image without enhancement
        """
        b,g,r = cv2.split(image)
        if show_channels:
          cv2.imshow("blue channel", b)
          cv2.imshow("green channel", g)
          cv2.imshow("red channel", r)

        kernel = np.ones(kernel_size, np.uint8) #creating kernel
        opening = cv2.morphologyEx(g, cv2.MORPH_OPEN, kernel) #morphological transformations to remove noise, smoothens the image
        morph_green = cv2.subtract(g, opening) #subtract to extract out the outliers/noise that were removed, highlights the key areas and increases contrast
    
        max_pixel_diff = np.max(morph_green)

        min_intensity = np.min(morph_green)
        max_intensity = np.max(morph_green)

        enhanced_green = ((morph_green - min_intensity) / (max_intensity - min_intensity)) * 255 #normalise image
        enhanced_green = enhanced_green.astype(np.uint8)

        if show_enhanced:
          result = np.hstack((cv2.resize(g, (224,224)), cv2.resize(enhanced_green, (224,224))))
          cv2.imshow("enhanced green channel", result)
          cv2.waitKey(10000)
          cv2.destroyAllWindows()

        return max_pixel_diff, enhanced_green, g


    def _extract_blood_vessels(self, image:np.ndarray, block_size:int=31, calculated_mean:int=21, area_limit:int=3, show_extracted:bool=False):
        """
        Function to extract out blood vessels through the use of edge detection

        Args:
            image (np.ndarray): Image to be processed
            block_size (int, optional): Size of the local region around each pixel used for adaptive thresholding; the larger it is, the more global the threshold. MUST be an odd number. Defaults to 31.
            calculated_mean (int, optional): Constant subtracted from the weighted local mean when computing the threshold. Defaults to 21.
            area_limit (int, optional): Minimum contour area for a contour to be included in the mask. Defaults to 3.
            show_extracted (bool, optional): Boolean to show extracted vessel mask. Defaults to False.

        Returns:
            mask (np.ndarray): Extracted vessel mask
        """
        if block_size % 2 ==0:
            raise Exception("block_size must be an odd integer")

        while True:
            #adaptive threshold for edge detection
            adaptive_threshold = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, block_size, calculated_mean)
            contours, _ = cv2.findContours(adaptive_threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) #edge detection

            #sort contours by max area
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            contours = [cnt for cnt in contours if cv2.contourArea(cnt) > area_limit] #sieve and extract contours larger than limit
            if len(contours) > 0:
              max_contour_area = max([cv2.contourArea(cnt) for cnt in contours]) #get the largest contour area
              if max_contour_area >= 0.5 * image.shape[0] * image.shape[1]: #if the largest contour covers at least 50% of the image (edges not defined enough), increase block size and mean and repeat
                  block_size += 2
                  calculated_mean += 3
              
              else:
                 break
            
            else:
                break

        #create base mask
        mask = np.zeros_like(image)
        cv2.drawContours(mask, contours, -1, (255), thickness=cv2.FILLED) #update mask with contours

        if show_extracted:
          result = np.hstack((cv2.resize(image, (224,224)), cv2.resize(mask, (224,224))))
          cv2.imshow("vein extraction", result)
          cv2.waitKey(10000)
          cv2.destroyAllWindows()

        return mask


    def _extract_lesion(self, image: np.ndarray, max_pixel_diff: int = 125, show_extracted: bool = False):
        """
        Function to extract out lesions through the use of edge detection

        Args:
            image (np.ndarray): Green channel of image that isn't processed
            max_pixel_diff (int, optional): Max pixel diff from processing. Defaults to 125.
            show_extracted (bool, optional): Boolean to show extracted lesion mask. Defaults to False.

        Returns:
            mask (np.ndarray): Extracted lesion mask
        """
        # Initialize variables to None
        binary = binary2 = binary3 = None

        # Conditional logic for threshold boundaries depending on max_pixel_diff due to different lighting conditions
        # Binary: all somewhat light parts of the image beyond a certain threshold
        # Binary2: stricter threshold, aim is for Binary2 - Binary to just return the lesions without any light spots on image
        # Binary3: extremely strict threshold, only return lightest part of image, aim to retrieve optic disk
        if max_pixel_diff < 165:
            _, binary = cv2.threshold(image, max_pixel_diff, 255, cv2.THRESH_BINARY)
            _, binary2 = cv2.threshold(image, max_pixel_diff, 255, cv2.THRESH_BINARY)
            _, binary3 = cv2.threshold(image, max_pixel_diff * 1.1, 255, cv2.THRESH_BINARY)

        elif 165<= max_pixel_diff <200:
          _, binary = cv2.threshold(image, max_pixel_diff*0.8, 255, cv2.THRESH_BINARY)
          _, binary2 = cv2.threshold(image, max_pixel_diff*0.85, 255, cv2.THRESH_BINARY)
          _, binary3 = cv2.threshold(image, max_pixel_diff*1.1, 255, cv2.THRESH_BINARY)

        elif 200<= max_pixel_diff <215:
          _, binary = cv2.threshold(image, max_pixel_diff*0.7, 255, cv2.THRESH_BINARY)
          _, binary2 = cv2.threshold(image, max_pixel_diff*0.75, 255, cv2.THRESH_BINARY)
          _, binary3 = cv2.threshold(image, max_pixel_diff*0.8, 255, cv2.THRESH_BINARY)

        elif 215 <= max_pixel_diff < 230:
          _, binary = cv2.threshold(image, max_pixel_diff*0.75, 255, cv2.THRESH_BINARY)
          _, binary2 = cv2.threshold(image, max_pixel_diff*0.85, 255, cv2.THRESH_BINARY)
          _, binary3 = cv2.threshold(image, max_pixel_diff*0.9, 255, cv2.THRESH_BINARY)

        elif 230<= max_pixel_diff <245:
          _, binary = cv2.threshold(image, max_pixel_diff*0.45, 255, cv2.THRESH_BINARY)
          _, binary2 = cv2.threshold(image, max_pixel_diff*0.5, 255, cv2.THRESH_BINARY)
          _, binary3 = cv2.threshold(image, max_pixel_diff, 255, cv2.THRESH_BINARY)

        elif max_pixel_diff >= 245: #includes 245 so that every value falls into a branch
          _, binary = cv2.threshold(image, max_pixel_diff*0.7, 255, cv2.THRESH_BINARY)
          _, binary2 = cv2.threshold(image, max_pixel_diff*0.75, 255, cv2.THRESH_BINARY)
          _, binary3 = cv2.threshold(image, max_pixel_diff, 255, cv2.THRESH_BINARY)

        #retrieve contours
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 
        contours2, _ = cv2.findContours(binary2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours3, _ = cv2.findContours(binary3, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        #generate empty mask
        mask = np.zeros_like(image)
        mask2 = np.zeros_like(image)
        mask3 = np.zeros_like(image)

        #adding contours to the different mask accordingly
        cv2.drawContours(mask, contours, -1, (255), thickness=cv2.FILLED)
        cv2.drawContours(mask2, contours2, -1, (255), thickness=cv2.FILLED)
        cv2.drawContours(mask3, contours3, -1, (255), thickness=cv2.FILLED) #optic disk

        #subtracting mask2 from mask to keep only the band between the two thresholds (the small lesions)
        merge_mask = np.subtract(mask, mask2)
        #adding the optic disk
        merge_mask = np.add(merge_mask, mask3)

        if show_extracted:
          result = np.hstack((cv2.resize(image, (224,224)), cv2.resize(merge_mask, (224,224))))
          cv2.imshow("outlier extraction", result)
          cv2.waitKey(10000)
          cv2.destroyAllWindows()

        return merge_mask

In [30]:
source_path = "../input/grading_images/DDR"
dest_path = "../input/grading_images/CLAHE"
folder_list = ["train", "val", "test"]
class_lists = ['0', '1', '2', '3', '4']

CLAHE_segmenter = CLAHE_segmentation()

for folder_name in folder_list:
    for class_idx in class_lists:
        image_src_path = f"{source_path}/{folder_name}/{class_idx}"
        mask_dest_path = f"{dest_path}/{folder_name}/{class_idx}"
        CLAHE_segmenter.create_dataset(image_src_path, mask_dest_path)

Image 1/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0007-000.jpg
Image 2/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0008-000.jpg
Image 3/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0009-000.jpg
Image 4/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0022-000.jpg
Image 5/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0023-000.jpg
Image 6/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0025-000.jpg
Image 7/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0026-000.jpg
Image 8/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0028-000.jpg
Image 9/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0029-000.jpg
Image 10/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0030-000.jpg
Image 11/4334 processed. Output path:../input/grading_images/CLAHE/train/0/007-0032-000.j

### Image Segmentation with UNET

Requires: two pretrained UNET models, one for vessel segmentation and one for lesion segmentation.

There are two ways to create the dataset:
1. Binary class: vessel and lesion segmentation will return a binary mask which will be merged
2. Multi Class: vessel segmentation will be deemed class 1 while the lesion segmentation will take class 2-6. Background will be black while the different classes will get varying shades of gray depending on their class index.
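
The multi-class merge can be sketched on a handful of pixels in plain Python. This is an illustration under stated assumptions: the lesion model is assumed to output class indices 0-5 with 0 as background, `merge_labels` and `to_gray` are hypothetical helpers, and the gray levels are the ones in the `colour_map` defined in this notebook.

```python
# Gray level per merged class index (0 = background, 1 = vessel, 2-6 = lesion classes).
COLOUR_MAP = {0: 0, 1: 255, 2: 128, 3: 100, 4: 150, 5: 200, 6: 50}

def merge_labels(vessel, lesion):
    """Merge a binary vessel mask with lesion class indices, vessels taking priority."""
    merged = []
    for v, l in zip(vessel, lesion):
        shifted = l + 1 if l >= 1 else 0  # shift lesion classes 1-5 to 2-6 to free up class 1
        merged.append(1 if v else shifted)
    return merged

def to_gray(labels):
    """Map merged class indices to gray levels."""
    return [COLOUR_MAP[label] for label in labels]

vessel = [0, 1, 1, 0, 0]
lesion = [0, 0, 3, 1, 5]
labels = merge_labels(vessel, lesion)
print(labels)           # [0, 1, 1, 2, 6]
print(to_gray(labels))  # [0, 255, 255, 128, 50]
```

In the binary variant the same priority rule applies, but every foreground pixel is simply written as 255.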

Because applying the segmentation on the fly as a torchvision transform is too computationally expensive, it is applied once to the original images and the results are saved so that they can be used directly.

In [31]:
import torch
import torchvision
from PIL import Image

In [32]:
class conv_block(torch.nn.Module):
  """Convolutional block for the UNET"""
  def __init__(self, in_channels:int, out_channels:int):
    super(conv_block, self).__init__()
    self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)
    self.bn1 = torch.nn.BatchNorm2d(out_channels)
    self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, padding=1)
    self.bn2 = torch.nn.BatchNorm2d(out_channels)
    self.relu = torch.nn.ReLU()

  def forward(self, inputs):
    x = self.relu(self.bn1(self.conv1(inputs)))
    x = self.relu(self.bn2(self.conv2(x)))
    return x
  
class encoder_block(torch.nn.Module):
  """ 
  encoder block that includes convolutional block and maxpooling
  returns both values before maxpool and after maxpool (for skip connections)
  """ 
  def __init__(self, in_channels:int, out_channels:int):
    super(encoder_block, self).__init__()
    self.conv = conv_block(in_channels, out_channels)
    self.maxpool = torch.nn.MaxPool2d((2,2))

  def forward(self, inputs):
    x = self.conv(inputs)
    p = self.maxpool(x)
    return x, p

class decoder_block(torch.nn.Module):
  """
  decoder block that upsamples images and takes in skip connections
  """
  def __init__(self, in_channels:int, out_channels:int):
    super(decoder_block, self).__init__()
    self.upsample = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    self.conv = conv_block(out_channels+out_channels, out_channels)

  def forward(self, inputs, skip_connections):
    x = self.upsample(inputs)
    x = torch.cat((x, skip_connections), 1)
    return self.conv(x)

class uNetModel(torch.nn.Module):
  """UNET architecture"""
  def __init__(self, n_classes):
    super(uNetModel, self).__init__()
    #--------------------------
    # Encoder
    #--------------------------
    self.encoder1 = encoder_block(3, 64)
    self.encoder2 = encoder_block(64, 128)
    self.encoder3 = encoder_block(128, 256)
    self.encoder4 = encoder_block(256, 512)

    #--------------------------
    # Bottleneck
    #--------------------------
    self.bottleneck = conv_block(512, 1024)

    #--------------------------
    # Decoder
    #--------------------------
    self.decoder1 = decoder_block(1024, 512)
    self.decoder2 = decoder_block(512, 256)
    self.decoder3 = decoder_block(256, 128)
    self.decoder4 = decoder_block(128, 64)

    #--------------------------
    # Classifier
    #--------------------------
    self.classifier = torch.nn.Conv2d(64, n_classes, 1)

  def forward(self, inputs):
    x1, p1 = self.encoder1(inputs)
    x2, p2 = self.encoder2(p1)
    x3, p3 = self.encoder3(p2)
    x4, p4 = self.encoder4(p3)
    b = self.bottleneck(p4)

    d1 = self.decoder1(b, x4)
    d2 = self.decoder2(d1, x3)
    d3 = self.decoder3(d2, x2)
    d4 = self.decoder4(d3, x1)

    output = self.classifier(d4)
    return output

In [33]:
class UNET_segmentation(Dataset_Creator):
    def __init__(self, multiclass:bool, vessel_unet_path:str, lesion_unet_path:str, device:str, mean=[0.2816, 0.2817, 0.2816], std=[0.1992, 0.1991, 0.1991] ):
        """
        Initialising UNET segmentation

        Args:
            multiclass (bool): Boolean if the dataset will be binary or multiclass
            vessel_unet_path (str): Path to UNET vessel model
            lesion_unet_path (str): Path to UNET lesion model
            device (str): cuda or cpu
            mean (list, optional): mean to normalise image UNET model was trained on. Defaults to [0.2816, 0.2817, 0.2816].
            std (list, optional): std to normalise image UNET model was trained on. Defaults to [0.1992, 0.1991, 0.1991].
        """
        self.multiclass = multiclass
        self.vessel_unet = torch.load(vessel_unet_path).to(device)
        self.lesion_unet = torch.load(lesion_unet_path).to(device)
        self.mean =  torch.tensor(mean).reshape(-1, 1, 1)
        self.std = torch.tensor(std).reshape(-1, 1, 1)
        self.device = device

        self.colour_map = {
            0: 0,  # Black for background (label 0)
            1: 255,  # White for label 1
            2: 128,  # Gray for label 2
            3: 100,  # Gray for label 3
            4: 150,  # Gray for label 4
            5: 200,  # Gray for label 5
            6: 50 
        }
        

    def create_dataset(self, source_path:str, dest_path:str, limit:int=None, show_output:bool=False):
        """
        Main function that will run image segmentation with UNET

        All images in the folder will be preprocessed and saved to the correct location in 512x512 with the same image name

        Args:
            source_path (str): Source path of the folder
            dest_path (str): Dest path of the folder
            limit (int, optional): Maximum number of images to process. Defaults to None.
            show_output (bool, optional): Show the processed image. Defaults to False
        """
        #create destination directory
        if not os.path.isdir(dest_path):
          os.makedirs(dest_path)

        folder_files = os.listdir(source_path)

        for i, file in enumerate(folder_files, start=1):
            base_image = Image.open(f'{source_path}/{file}') #open image
            resized_image = base_image.resize((512,512))
            enhanced_image = self._enhance_contrast(resized_image, False)
            tensor_image = torchvision.transforms.ToTensor()(enhanced_image)
            #normalising image
            input = ( tensor_image - self.mean )/self.std #normalise image

            #run vessel UNET and get the predictions with sigmoid
            vessel_mask = self.vessel_unet(input.unsqueeze(0).to(self.device))
            vessel_outputs = torch.nn.functional.sigmoid(vessel_mask)
            vessel_predictions = (vessel_outputs > 0.5)
            vessel_predictions = vessel_predictions.cpu().numpy()
            
            #run lesion UNET
            lesion_mask = self.lesion_unet(input.unsqueeze(0).to(self.device))
            
            result = np.zeros((512, 512), dtype=np.uint8)

            if self.multiclass:
                #get multiclass prediction
                lesion_outputs = torch.nn.functional.softmax(lesion_mask, dim=1) #softmax over the class dimension
                lesion_predictions = torch.argmax(lesion_outputs, dim=1).cpu().numpy()
                lesion_predictions[lesion_predictions >= 1] += 1 #increase the class index such that it is now 0,2,3,4,5, to prevent overlap with vessel
                #merge mask, prioritise the vessel segmentation over lesion
                merged_mask = np.where(vessel_predictions != 0, vessel_predictions, lesion_predictions)
                
                #update the colour mapping
                for label, intensity in self.colour_map.items():
                    result[merged_mask.squeeze() == label] = intensity
            
                if show_output:
                    # result is single-channel, so expand it to 3 channels before stacking
                    result_rgb = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
                    merged_result = np.hstack((np.array(resized_image), result_rgb))
                    cv2.imshow("mask", merged_result)
                    cv2.waitKey(10000)
                    cv2.destroyAllWindows()

                # save the single-channel mask (no colour conversion is needed)
                cv2.imwrite(f"{dest_path}/{file}", result)
                print(f"Image {i}/{len(folder_files)} processed. Output path:{dest_path}/{file}")

            else:
                # get binary prediction by thresholding the sigmoid probabilities at 0.5
                lesion_outputs = torch.sigmoid(lesion_mask)
                lesion_predictions = (lesion_outputs > 0.5).cpu().numpy()

                # merge the masks, giving vessel pixels priority over lesion pixels
                merged_mask = np.where(vessel_predictions != 0, vessel_predictions, lesion_predictions)
                result[merged_mask.squeeze() == 1] = 255  # map foreground pixels to white

                if show_output:
                    # result is single-channel, so expand it to 3 channels before stacking
                    result_rgb = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
                    merged_result = np.hstack((np.array(resized_image), result_rgb))
                    cv2.imshow("mask", merged_result)
                    cv2.waitKey(10000)
                    cv2.destroyAllWindows()

                # save the single-channel mask (no colour conversion is needed)
                cv2.imwrite(f"{dest_path}/{file}", result)
                print(f"Image {i}/{len(folder_files)} processed. Output path:{dest_path}/{file}")
            
            # stop early once the optional image limit is reached
            if limit is not None and i >= limit:
                break


    def _enhance_contrast(self, image: Image.Image, show_image: bool):
        """
        Enhance the contrast of the green channel of an RGB image with CLAHE.

        Args:
            image (Image.Image): RGB input image
            show_image (bool): Whether to display the enhanced image

        Returns:
            enhanced_image (Image.Image): Contrast-enhanced image
        """
        image = np.array(image)
        _, g, _ = cv2.split(image)  # keep only the green channel

        green_channel = cv2.cvtColor(g, cv2.COLOR_GRAY2RGB)  # replicate the green channel into 3 channels
        lab = cv2.cvtColor(green_channel, cv2.COLOR_RGB2LAB)  # convert to the LAB colour space

        L, a, b = cv2.split(lab)  # split into lightness (L) and colour (a, b) channels

        clahe = cv2.createCLAHE(clipLimit=3, tileGridSize=(8, 8))  # CLAHE with 8x8 tiles
        cL = clahe.apply(L)  # enhance contrast on the lightness channel only

        new_img = cv2.merge((cL, a, b))  # recombine into an enhanced LAB image
        enhanced_img = cv2.cvtColor(new_img, cv2.COLOR_LAB2RGB)  # convert back from LAB to RGB


        if show_image:
            resized_enhanced_img = cv2.resize(enhanced_img, (224,224))
            cv2.imshow("", resized_enhanced_img)
            cv2.waitKey(10000)
            cv2.destroyAllWindows()
        
        return Image.fromarray(enhanced_img)
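
The mask-merging step in `create_dataset` can be looked at in isolation. The sketch below is a minimal NumPy example, assuming tiny hand-made 3x3 arrays in place of real UNET predictions. `merge_masks` is a hypothetical helper name used only for illustration and is not part of the class above.

```python
import numpy as np

def merge_masks(vessel_mask: np.ndarray, lesion_classes: np.ndarray) -> np.ndarray:
    """Combine a binary vessel mask with a lesion class map (hypothetical helper).

    Vessel pixels receive label 1. Lesion classes 1..N are shifted to 2..N+1
    so they never collide with the vessel label. Vessels take priority.
    """
    shifted = lesion_classes.copy()      # leave the caller's array untouched
    shifted[shifted >= 1] += 1           # background (0) keeps its label
    return np.where(vessel_mask != 0, 1, shifted)

# Toy inputs (assumed values, not model output)
vessel = np.array([[1, 0, 0],
                   [1, 0, 0],
                   [0, 0, 0]], dtype=np.uint8)
lesion = np.array([[2, 2, 0],
                   [0, 1, 0],
                   [0, 0, 4]], dtype=np.int64)

merged = merge_masks(vessel, lesion)
print(merged)
```

Where both masks are non-zero (top-left pixel here), the vessel label wins, which mirrors the priority rule used in the class.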

In [34]:
# Clearing GPU memory
import gc
torch.cuda.empty_cache()
gc.collect()

55

In [36]:
source_path = "../input/grading_images/DDR"
dest_path = "../input/grading_images/UNET_Binary"
folder_list = ["train", "val", "test"]
class_lists = ['0', '1', '2', '3', '4']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

uNet_segmenter = UNET_segmentation(multiclass=False,
                                   vessel_unet_path="../models/vessel/FocalTverskyLossBase.pt",
                                   lesion_unet_path="../models/lesion/binary/DiceLoss.pt",
                                   device=device)

for folder_name in folder_list:
    for class_idx in class_lists:
        image_src_path = f"{source_path}/{folder_name}/{class_idx}"
        mask_dest_path = f"{dest_path}/{folder_name}/{class_idx}"
        uNet_segmenter.create_dataset(image_src_path, mask_dest_path)

Image 1/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0007-000.jpg
Image 2/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0008-000.jpg
Image 3/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0009-000.jpg
Image 4/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0022-000.jpg
Image 5/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0023-000.jpg
Image 6/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0025-000.jpg
Image 7/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0026-000.jpg
Image 8/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0028-000.jpg
Image 9/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0029-000.jpg
Image 10/4334 processed. Output path:../input/grading_images/UNET_Binary/train/0/007-0030-000.jpg
Image 11/4334 processed. Outp
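
The nested loop over splits and class folders can also be written as a small generator, which makes the source/destination pairing easy to check without loading any model. This is a sketch only: `iter_class_dirs` is a hypothetical helper, and the folder names simply mirror the cell above.

```python
from itertools import product
from typing import Iterator, Sequence, Tuple

def iter_class_dirs(source_root: str, dest_root: str,
                    splits: Sequence[str],
                    classes: Sequence[str]) -> Iterator[Tuple[str, str]]:
    """Yield (source_dir, dest_dir) for every split/class combination."""
    for split, cls in product(splits, classes):
        yield f"{source_root}/{split}/{cls}", f"{dest_root}/{split}/{cls}"

pairs = list(iter_class_dirs("../input/grading_images/DDR",
                             "../input/grading_images/UNET_Binary",
                             ["train", "val", "test"],
                             ["0", "1", "2", "3", "4"]))
print(len(pairs))
print(pairs[0])
```

Each pair could then be passed to `create_dataset(src, dst)` in a single loop.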

In [37]:
source_path = "../input/grading_images/DDR"
dest_path = "../input/grading_images/UNET_Multiclass"
folder_list = ["train", "val", "test"]
class_lists = ['0', '1', '2', '3', '4']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Clearing GPU memory
import gc
torch.cuda.empty_cache()
gc.collect()


uNet_segmenter = UNET_segmentation(multiclass=True,
                                   vessel_unet_path="../models/vessel/FocalTverskyLossBase.pt",
                                   lesion_unet_path="../models/lesion/multiclass/DiceLoss_2.pt",
                                   device=device)

for folder_name in folder_list:
    for class_idx in class_lists:
        image_src_path = f"{source_path}/{folder_name}/{class_idx}"
        mask_dest_path = f"{dest_path}/{folder_name}/{class_idx}"
        uNet_segmenter.create_dataset(image_src_path, mask_dest_path)


Image 1/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0007-000.jpg
Image 2/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0008-000.jpg
Image 3/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0009-000.jpg
Image 4/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0022-000.jpg
Image 5/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0023-000.jpg
Image 6/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0025-000.jpg
Image 7/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0026-000.jpg
Image 8/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0028-000.jpg
Image 9/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-0029-000.jpg
Image 10/4334 processed. Output path:../input/grading_images/UNET_Multiclass/train/0/007-00
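
The multiclass branch turns the lesion UNET logits into a class map by taking a softmax over the class channel and then an argmax. The sketch below reproduces that step with NumPy on a tiny made-up logit array of shape (batch, classes, height, width). The `softmax` helper and the values are illustrative assumptions, not the notebook's model output.

```python
import numpy as np

def softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Assumed toy logits: 1 image, 3 classes, 1x2 pixels
logits = np.array([[[[2.0, 0.1]],
                    [[0.5, 3.0]],
                    [[-1.0, 0.2]]]])

probs = softmax(logits, axis=1)        # normalise across the class channel
class_map = probs.argmax(axis=1)       # most likely class per pixel
print(class_map)
```

Because softmax is monotonic within each pixel, the argmax of the probabilities equals the argmax of the raw logits. The softmax only matters when the probabilities themselves are needed.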