In [38]:
import os
import nibabel as nib
import numpy as np
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.utils.losses import GeneralizedDiceLoss, DiceLoss, BCEDiceLoss
from torch.utils.data import Dataset, DataLoader
from src.utils.utils import custom_collate_BHSD
import torch.nn.functional as F
import torch
from src.configuration.config import (
    datadict, TrainingDir, batch_size, num_epochs, num_workers,
    pin_memory, LEARNING_RATE, IMAGE_HEIGHT, IMAGE_WIDTH
)

In [39]:
Dir = r"C:\Users\Rishabh\Downloads\label_192\label_192"
masks_path = os.path.join(Dir, 'ground truths')
images_path = os.path.join(Dir, 'images')
masks = os.listdir(masks_path)
images = os.listdir(images_path)
os.listdir(Dir)

['ground truths', 'images']

In [40]:
newDatadict = {
    'BackGround': 0,
    'Bleed-Subdural': 1,
    'Scalp-Hematoma': 2,
    'Bleed-Others': 3,
    'Bleed-Intraventricular': 4,
    'Bleed-Epidural': 5,}

In [41]:
class BHSD_3D(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, datadict=newDatadict):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.series = os.listdir(mask_dir)
        self.datadict = datadict
        reversed_dict = {v: k for k, v in datadict.items()}
        self.reversed_dict = reversed_dict

    def transform_volume(self, image_volume, mask_volume):
        transformed = self.transform(
                image=image_volume, 
                mask=mask_volume
            )
        images = transformed['image']
        masks = transformed['mask'].permute(2, 0, 1)
        masks = F.one_hot(masks.long(), num_classes=6)
        masks = masks.permute(3, 0, 1, 2)
        return images , masks.float()

    

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        nii_segementation = nib.load(os.path.join(self.mask_dir, self.images[index]))
        nii_image = nib.load(os.path.join(self.image_dir, self.images[index]))
        
        # Get the image data as a NumPy array
        image_data = nii_image.get_fdata()
        segementation_data = nii_segementation.get_fdata()

        if self.transform is not None:
            transformed_image_volume, transformed_mask_volume = self.transform_volume(image_data, segementation_data)

        transformed_image_volume = transformed_image_volume.unsqueeze(0)
        # transformed_mask_volume = transformed_mask_volume.unsqueeze(0)
        return transformed_image_volume, transformed_mask_volume

In [42]:
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


In [43]:
data = BHSD_3D(images_path, masks_path, train_transform)

In [44]:
x, y = data[0]

In [45]:
x.shape, y.shape

(torch.Size([1, 28, 128, 128]), torch.Size([6, 28, 128, 128]))

In [46]:
for i in range(y.shape[0]):
    print(torch.unique(y[i,:,:,:]))

tensor([0., 1.])
tensor([0.])
tensor([0., 1.])
tensor([0., 1.])
tensor([0.])
tensor([0.])


In [63]:
def new_custom_collate_BHSD(Batch):
    max_depth = 0
    for x,y in batch:
        max_depth = max(max_depth, x.shape[1])


    newImageVolume = []
    newMaskVolume = []
    for i in range(len(batch)):
        remmaining_slice = max_depth - batch[i][0].shape[1]
        # print(remmaining_slice)
        if remmaining_slice > 0:
            empty_slice = torch.zeros((1,remmaining_slice,batch[i][0].shape[2], batch[i][0].shape[3]))
            empty_slice_mask = torch.zeros((1,remmaining_slice,batch[i][0].shape[2], batch[i][0].shape[3]))
            newImageVolume.append(torch.cat((batch[i][0], empty_slice), dim=1))
            newMaskVolume.append(torch.cat((batch[i][1], empty_slice_mask), dim=1))
        else:
            newImageVolume.append(batch[i][0])
            newMaskVolume.append(batch[i][1])
    

    newImageVolume = torch.stack(newImageVolume, dim=0)
    newMaskVolume = torch.stack(newMaskVolume, dim=0)

    return newImageVolume, newMaskVolume
    for x, y in Batch:
        print(x.shape, ' ', y.shape)

In [64]:
train_loader = DataLoader(
        data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
        collate_fn=new_custom_collate_BHSD,
    )

In [65]:
for batch_idx, (inputs, targets) in enumerate(train_loader):
    print(batch_idx)

torch.Size([1, 32, 128, 128])   torch.Size([6, 32, 128, 128])
torch.Size([1, 32, 128, 128])   torch.Size([6, 32, 128, 128])
torch.Size([1, 36, 128, 128])   torch.Size([6, 36, 128, 128])
torch.Size([1, 34, 128, 128])   torch.Size([6, 34, 128, 128])


TypeError: cannot unpack non-iterable NoneType object