In [35]:
from torch.utils.data import Dataset, DataLoader
from src.configuration.config import datadict, IMAGE_HEIGHT, IMAGE_WIDTH, batch_size, num_epochs, num_workers, pin_memory, LEARNING_RATE
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import numpy as np
from PIL import Image
class CustomDataset2D(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, datadict=datadict):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.series = os.listdir(mask_dir)
        self.datadict = datadict
        reversed_dict = {v: k for k, v in datadict.items()}
        self.reversed_dict = reversed_dict
        sorted_list = [0]
        for i in range(len(self.series)):
            first_folder = os.listdir(os.path.join(self.mask_dir, self.series[i]))[0]
            folder_path = os.path.join(self.mask_dir, self.series[i], first_folder)
            series_length = len(os.listdir(folder_path))
            sorted_list.append(sorted_list[-1] + series_length)
        self.sorted_list = sorted_list
            
            

    def __len__(self):
        return self.sorted_list[-1]


    def transform_volume(self, image_volume, mask_volume):
        # print(image_volume.transpose(1, 2, 0).shape)
        # print(mask_volume.transpose(1, 2, 0).shape)
        transformed = self.transform(
                image=image_volume.transpose(1, 2, 0), 
                mask=mask_volume.transpose(1, 2, 0)  # Change (9, 512, 512) -> (512, 512, 9)
            )
        images = transformed['image']
        masks = transformed['mask'].permute(2, 0, 1)

        # print(images.shape)
        # print(masks.shape)

        return images , masks
        
    def find_series(self, i , j, index):
        if i==j:
            return i-1
        mid_ind = (i+j)//2
        if self.sorted_list[mid_ind] >= index:
            ans = self.find_series(i , mid_ind, index)
        else:
            ans = self.find_series(mid_ind , j, index)
        return ans
            
            
    def update_index_and_find_series(self, index):
        series_indx = self.find_series(0 , len(self.sorted_list)-1, index)
        return series_indx, self.sorted_list[series_indx-1]-index-1
    def __getitem__(self, index):
        series_indx, index = self.update_index_and_find_series(index+1)
        self.series_index = series_indx
            
        Maskvolume = []
        ImageVolume = []
        flag = 0
        for key in range(len(self.reversed_dict.keys())):
            catag = self.reversed_dict[key]
            Maskcatgvolume = []
            Masks = os.path.join(self.mask_dir, os.listdir(self.mask_dir)[self.series_index], catag)
            MasksList = os.listdir(Masks)
            MasksList = sorted(MasksList)
            
            for msk in MasksList:
                pngMask = Image.open(os.path.join(Masks, msk))
                pngMask = np.array(pngMask)
                Maskcatgvolume.append(pngMask)
        
                if msk in self.images and flag == 0:
                    pngimage = Image.open(os.path.join(self.image_dir ,msk))
                    pngimage = np.array(pngimage)
                    ImageVolume.append(pngimage)
            flag = 1
                    
            Maskcatgvolume = np.stack(Maskcatgvolume, axis = 0)
            Maskvolume.append(Maskcatgvolume)
            
        Maskvolume = np.stack(Maskvolume, axis = 0)
        ImageVolume = np.stack(ImageVolume, axis = 0)

        
        newMaskVolume = []
        for i in range(Maskvolume.shape[1]):
            newMaskVolume.append(np.argmax(Maskvolume[:,i,:,:] , axis=0))
        newMaskVolume = np.stack(newMaskVolume, axis=0)
        
        newMaskVolume[newMaskVolume>0] = -1
        newMaskVolume[newMaskVolume == 0] = 1
        newMaskVolume[newMaskVolume == -1] = 0
        
        for i in range(Maskvolume.shape[1]):
            Maskvolume[0,i,:,:] = Maskvolume[0,i,:,:] + newMaskVolume[i,:,:]

        newImageVolume = []
        newMaskVolume = []
        empty_slice = np.zeros(ImageVolume[0,:,:].shape)

        middleslice = ImageVolume[index,:,:]
        middlesliceMask = Maskvolume[:,index,:,:]

        if index == 0:
            if ImageVolume.shape[0] == 1:
                newImageVolume.append(empty_slice)
                newImageVolume.append(middleslice)
                newImageVolume.append(empty_slice)
                newImageVolume = np.stack(newImageVolume, axis=0)
            else:
                lastslice = ImageVolume[index+1,:,:]
                newImageVolume.append(empty_slice)
                newImageVolume.append(middleslice)
                newImageVolume.append(lastslice)
                newImageVolume = np.stack(newImageVolume, axis=0)

        
        elif index == (ImageVolume.shape[0]-1):
            firstslice = ImageVolume[index-1,:,:]
            newImageVolume.append(firstslice)
            newImageVolume.append(middleslice)
            newImageVolume.append(empty_slice)
            newImageVolume = np.stack(newImageVolume, axis=0)

        else:
            firstslice = ImageVolume[index-1,:,:]
            lastslice = ImageVolume[index+1,:,:]
            newImageVolume.append(firstslice)
            newImageVolume.append(middleslice)
            newImageVolume.append(lastslice)
            newImageVolume = np.stack(newImageVolume, axis=0)

        if self.transform is not None:
            transformed_image_volume, transformed_mask_volume = self.transform_volume(newImageVolume, middlesliceMask)
            

        # return image, mask
        return transformed_image_volume, transformed_mask_volume
            
        return newImageVolume, middlesliceMask

In [36]:
TrainingDir = r"C:\Users\Rishabh\Documents\pytorch-3dunet\TrainingData"
image_dir = os.path.join(TrainingDir, 'Images')
mask_dir = os.path.join(TrainingDir, 'Masks')
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)
data = CustomDataset2D(image_dir, mask_dir,transform = train_transform)

In [37]:
data[0]

RecursionError: maximum recursion depth exceeded