# In [6]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import cv2

# In [7]:
# 6x6 all-ones uint8 structuring element; Medical.__getitem__ passes it to
# cv2.dilate when loading each mask.
kernel = np.full(shape=(6, 6), fill_value=1, dtype=np.uint8)

# In [8]:
class Medical(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    Every regular file in ``image_dir`` whose companion mask file exists in
    ``mask_dir`` becomes one sample. The mask name is derived from the image
    name by replacing a trailing ``.png`` with ``_Annotation.png`` (for
    example ``000_HC.png`` -> ``000_HC_Annotation.png``). Names without a
    ``.png`` suffix are looked up in ``mask_dir`` unchanged.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the annotation masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that must return a mapping with
            ``"image"`` and ``"mask"`` keys (presumably an albumentations-style
            pipeline -- confirm against the caller).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort so that sample indices are reproducible: os.listdir returns
        # entries in arbitrary, filesystem-dependent order. isfile (rather
        # than exists) also skips sub-directories, which Image.open cannot
        # read.
        self.images = [
            name
            for name in sorted(os.listdir(image_dir))
            if os.path.isfile(os.path.join(image_dir, name))
            and os.path.isfile(os.path.join(mask_dir, self._mask_name(name)))
        ]

    @staticmethod
    def _mask_name(image_name):
        """Return the mask file name paired with ``image_name``."""
        # Rewrite only a trailing ".png"; str.replace would also rewrite any
        # ".png" appearing earlier in the file name.
        if image_name.endswith(".png"):
            return image_name[: -len(".png")] + "_Annotation.png"
        return image_name

    def __len__(self):
        """Return the number of image/mask pairs found at construction."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` as an ``(image, mask)`` pair.

        The image is read as single-channel ("L") float32. The mask is read
        the same way, dilated with the module-level ``kernel``, and pixels
        equal to 255 are set to 1.0. If ``transform`` is set, both arrays are
        passed through it and its outputs are returned instead.

        Returns ``None`` (after printing a message) if either file no longer
        exists. Callers must be prepared for that, e.g. with a collate_fn
        that drops ``None`` samples.
        """
        name = self.images[index]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(name))
        if not os.path.exists(img_path):
            print(f"Image file {img_path} does not exist.")
            return None
        if not os.path.exists(mask_path):
            print(f"Mask file {mask_path} does not exist.")
            return None

        # Context managers release the underlying file handles promptly.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("L"), dtype=np.float32)
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        # Grow the non-zero (annotated) regions using the module-level kernel.
        mask = cv2.dilate(mask, kernel, iterations=1)
        # NOTE(review): only pixels exactly equal to 255 become 1.0; any other
        # non-zero grey level is left unchanged. Confirm the masks are
        # strictly 0/255.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask