In [16]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from src.Models.D_UNet import UNet2D, ResidualUNet2D
from src.configuration.config import datadict

In [24]:
class CustomDataset2D(Dataset):
    """Segmentation dataset returning a series' image volume and its per-category mask volume.

    On-disk layout, as read by the code below:

        image_dir/<file>                        flat folder of image files
        mask_dir/<series>/<category>/<file>     one folder per series, one per category

    Images are paired with masks by identical file name.

    Args:
        image_dir: Folder holding the image files.
        mask_dir: Folder holding one sub-folder per series.
        transform: Stored on the instance but not applied anywhere in this class.
            TODO: apply it in ``__getitem__`` once its expected signature is settled.
        datadict: Mapping ``category folder name -> integer label``. The labels must be
            exactly ``0..len(datadict)-1``, because ``__getitem__`` looks them up in that order.
    """

    def __init__(self, image_dir, mask_dir, transform=None, datadict=datadict):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so that index -> series is deterministic (os.listdir order is arbitrary).
        self.images = sorted(os.listdir(image_dir))
        self.series = sorted(os.listdir(mask_dir))
        self.datadict = datadict
        # label -> category folder name
        self.reversed_dict = {label: name for name, label in datadict.items()}

    def __len__(self):
        """Return the number of series folders under ``mask_dir``.

        NOTE(review): ``__getitem__`` treats its index as a *global slice* index across all
        series, but this returns the number of *series*. A DataLoader therefore only requests
        indices ``0..len(self.series)-1``, which usually all fall inside the first series.
        TODO: either index series directly, or return the total slice count (and yield 2-D
        slices) so the two methods agree.
        """
        return len(self.series)

    def __getitem__(self, index):
        """Load the whole series that contains global slice ``index``.

        Slices are counted series by series in ``self.series`` order. ``index`` selects the
        series whose cumulative slice range covers it, and the full series is returned
        regardless of which slice inside it was addressed. The located positions are stored
        on ``self.series_index`` and ``self.slice_index``.

        Returns:
            (ImageVolume, Maskvolume):
                ``ImageVolume`` stacks, along axis 0, every image whose file name matches a
                mask file of label 0.
                ``Maskvolume`` stacks one mask stack per label (0..K-1) along axis 0, with
                channel 0 incremented by ``_add_background``.

        Raises:
            IndexError: if ``index`` is negative or not smaller than the total slice count.
        """
        self.series_index, self.slice_index = self._locate(index)
        ImageVolume, Maskvolume = self._load_series(self.series_index)
        self._add_background(Maskvolume)
        return ImageVolume, Maskvolume

    def _locate(self, index):
        """Map a global slice index to ``(series_index, slice_index_within_series)``.

        A series' length is the number of files in its alphabetically-first category folder.
        A series folder with no category folders counts as length 0.
        """
        if index < 0:
            raise IndexError(f"index {index} is negative")
        count = 0
        for series_index, series in enumerate(self.series):
            series_path = os.path.join(self.mask_dir, series)
            category_dirs = sorted(os.listdir(series_path))
            series_length = (
                len(os.listdir(os.path.join(series_path, category_dirs[0])))
                if category_dirs else 0
            )
            if index < count + series_length:
                # Offset of `index` inside this series.
                return series_index, index - count
            count += series_length
        raise IndexError(f"index {index} out of range for {count} slices")

    def _load_series(self, series_index):
        """Read every category's masks (and the matching images) for one series.

        Returns ``(images, masks)``:
            ``images`` is stacked along axis 0.
            ``masks`` stacks one per-label stack along axis 0, in label order 0..K-1.
        Images are collected only while walking label 0's mask files.
        """
        series_path = os.path.join(self.mask_dir, self.series[series_index])
        image_names = set(self.images)  # O(1) membership instead of a list scan per mask
        image_slices = []
        mask_channels = []
        for label in range(len(self.reversed_dict)):
            category_dir = os.path.join(series_path, self.reversed_dict[label])
            category_masks = []
            for mask_name in sorted(os.listdir(category_dir)):
                category_masks.append(self._read_array(os.path.join(category_dir, mask_name)))
                if label == 0 and mask_name in image_names:
                    image_slices.append(self._read_array(os.path.join(self.image_dir, mask_name)))
            mask_channels.append(np.stack(category_masks, axis=0))
        masks = np.stack(mask_channels, axis=0)
        images = np.stack(image_slices, axis=0)
        return images, masks

    @staticmethod
    def _read_array(path):
        """Read one image file into a NumPy array, closing the file handle afterwards."""
        with Image.open(path) as img:
            return np.array(img)

    @staticmethod
    def _add_background(mask_volume):
        """Add 1 to channel 0 wherever channel 0 holds the (first) maximum over channels.

        ``np.argmax`` returns the first maximal channel, so pixels where every channel is
        equal (e.g. unlabeled in all categories) are also incremented. Presumably channel 0
        is the background category -- TODO confirm.

        Modifies ``mask_volume`` in place and returns it.

        NOTE(review): the sum is computed in a wider integer type and cast back on
        assignment. If masks are uint8 0/255, a pixel already at 255 in channel 0 wraps to
        0. Confirm the mask encoding.
        """
        winners = np.argmax(mask_volume, axis=0)
        unlabeled = (winners == 0).astype(winners.dtype)
        mask_volume[0] = mask_volume[0] + unlabeled
        return mask_volume

In [25]:
# Root of the training data. Set the TRAINING_DATA_DIR environment variable to point
# elsewhere instead of editing this user-specific absolute path.
Dir = os.environ.get("TRAINING_DATA_DIR", r"C:\Users\Rishabh\Documents\pytorch-3dunet\TrainingData")
image_dir = os.path.join(Dir, 'Images')
mask_dir = os.path.join(Dir, 'Masks')
data = CustomDataset2D(image_dir, mask_dir)

In [29]:
# Global slice index; the dataset returns the whole series whose slice range contains it.
SAMPLE_INDEX = 5
ImageVolume, Maskvolume = data[SAMPLE_INDEX]

In [30]:
np.shape(ImageVolume)

(155, 512, 512)

In [31]:
# Images are only collected for masks that have a same-named image file, so check that
# the image stack is aligned slice-for-slice with the mask stacks.
assert ImageVolume.shape[0] == Maskvolume.shape[1], "image/mask volume depths differ"
Maskvolume.shape

(9, 155, 512, 512)