In [1]:
from src.Models.D_UNet import UNet2D, ResidualUNet2D
from src.utils.losses import BCEDiceLoss, DiceLoss, GeneralizedDiceLoss, WeightedCrossEntropyLoss, WeightedSmoothL1Loss

In [None]:
class CustomDataset3D(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, datadict=datadict,  output_size=(256, 256), output_depth=5):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.series = os.listdir(mask_dir)
        self.datadict = datadict
        reversed_dict = {v: k for k, v in datadict.items()}
        self.reversed_dict = reversed_dict

        self.output_size = (IMAGE_HEIGHT, IMAGE_WIDTH)  # (H, W)
        self.output_depth = output_depth  # New Depth

    def __len__(self):
        return len(self.series)


    def __getitem__(self, index):
        Maskvolume = []
        ImageVolume = []
        # print(self.series[index])
        flag = 0
        for key in range(len(self.reversed_dict.keys())):
            catag = self.reversed_dict[key]
            Maskcatgvolume = []
            Masks = os.path.join(self.mask_dir, os.listdir(self.mask_dir)[index], catag)
            MasksList = os.listdir(Masks)
            MasksList = sorted(MasksList)
            
            for msk in MasksList:
                pngMask = Image.open(os.path.join(Masks, msk))
                pngMask = np.array(pngMask)
                Maskcatgvolume.append(pngMask)
    
                if msk in self.images and flag == 0:
                    pngimage = Image.open(os.path.join(self.image_dir ,msk))
                    pngimage = np.array(pngimage)
                    ImageVolume.append(pngimage)
            flag = 1
                    
            Maskcatgvolume = np.stack(Maskcatgvolume, axis = 0)
            Maskvolume.append(Maskcatgvolume)
            
        Maskvolume = np.stack(Maskvolume, axis = 0)
        ImageVolume = np.stack(ImageVolume, axis = 0)
        ImageVolume = np.expand_dims(ImageVolume, axis=0)
        newMaskVolume = []
        for i in range(Maskvolume.shape[1]):
            newMaskVolume.append(np.argmax(Maskvolume[:,i,:,:] , axis=0))
        newMaskVolume = np.stack(newMaskVolume, axis=0)
        newMaskVolume = np.expand_dims(newMaskVolume, axis=0)



        

        resized_images = np.array([cv2.resize(img, self.output_size, interpolation=cv2.INTER_LINEAR) for img in ImageVolume[0]])
        resized_masks = np.array([cv2.resize(mask, self.output_size, interpolation=cv2.INTER_NEAREST) for mask in newMaskVolume[0]])

        # print('resized_images:-',resized_images.shape)
        # print('resized_masks:-',resized_masks.shape)
        # print(np.unique(resized_images))


        new_images = []
        new_masks = []
        if self.transform is not None:
            for slic in range(resized_images.shape[0]):
                image = resized_images[slic,:,:]
                mask = resized_masks[slic,:,:]

                augmentations = self.transform(image=image, mask=mask)
                image = augmentations["image"].squeeze(0)
                mask = augmentations["mask"].squeeze(0)
                new_images.append(image)
                new_masks.append(mask)
        
        return torch.stack(new_images).unsqueeze(0), torch.stack(new_masks).unsqueeze(0)



In [None]:
f = [16, 32, 64, 128, 256, 512, 1024]
for i in range(len(f)):
    model = UNet2D(in_channels=1, out_channels=9, f_maps=f[i])
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters {f[i]}: {total_params}")

In [2]:
import os
from src.configuration.config import datadict
Dir = r"C:\Users\Rishabh\Documents\pytorch-3dunet\TrainingData"
image_dir = os.path.join(Dir, 'Images')
mask_dir = os.path.join(Dir, 'Masks')
images = os.listdir(image_dir)
series = os.listdir(mask_dir)
datadict = datadict
reversed_dict = {v: k for k, v in datadict.items()}

In [3]:
from PIL import Image
import numpy as np
index = 3
Maskvolume = []
ImageVolume = []
flag = 0
for key in range(len(reversed_dict.keys())):
    catag = reversed_dict[key]
    Maskcatgvolume = []
    Masks = os.path.join(mask_dir, os.listdir(mask_dir)[index], catag)
    MasksList = os.listdir(Masks)
    MasksList = sorted(MasksList)
    
    for msk in MasksList:
        pngMask = Image.open(os.path.join(Masks, msk))
        pngMask = np.array(pngMask)
        Maskcatgvolume.append(pngMask)

        if msk in images and flag == 0:
            pngimage = Image.open(os.path.join(image_dir ,msk))
            pngimage = np.array(pngimage)
            ImageVolume.append(pngimage)
    flag = 1
            
    Maskcatgvolume = np.stack(Maskcatgvolume, axis = 0)
    Maskvolume.append(Maskcatgvolume)
    
Maskvolume = np.stack(Maskvolume, axis = 0)
ImageVolume = np.stack(ImageVolume, axis = 0)

newMaskVolume = []
for i in range(Maskvolume.shape[1]):
    newMaskVolume.append(np.argmax(Maskvolume[:,i,:,:] , axis=0))
newMaskVolume = np.stack(newMaskVolume, axis=0)

newMaskVolume[newMaskVolume>0] = -1
newMaskVolume[newMaskVolume == 0] = 1
newMaskVolume[newMaskVolume == -1] = 0

for i in range(Maskvolume.shape[1]):
    Maskvolume[0,i,:,:] = Maskvolume[0,i,:,:] + newMaskVolume[i,:,:]

In [None]:
Maskvolume.shape

In [None]:
Maskvolume[1:,:,:,:].shape

In [None]:
ImageVolume[:3,:,:].shape

In [None]:
np.unique(ImageVolume)

In [4]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = ImageVolume[:3,:,:]
inputs = inputs/255
inputs = torch.tensor(inputs, dtype=torch.float32)
inputs = inputs.unsqueeze(0)
inputs = inputs.to(device)
targets = Maskvolume[:,1,:,:]
targets = torch.tensor(targets, dtype=torch.float32)
targets = targets.unsqueeze(0)
targets = targets.to(device)
print(inputs.shape)

torch.Size([1, 3, 512, 512])


In [None]:
model = UNet2D(in_channels=3, out_channels=9, f_maps=128).to(device)
outputs = model(inputs)

In [None]:
outputs.shape

In [None]:
targets.shape

In [None]:
lossfn = DiceLoss()

In [None]:
loss = lossfn(outputs, targets)
print(loss)

In [5]:
import torch.optim as optim
import torch.nn as nn
model = UNet2D(in_channels=3, out_channels=9, f_maps=128).to(device)

LEARNING_RATE = 0.001
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

# lossfn = DiceLoss()
lossfn = nn.BCEWithLogitsLoss()
# lossfn = BCEDiceLoss()

  scaler = torch.cuda.amp.GradScaler()


In [6]:
model.train()
for i in range(10000):
    inputs = inputs.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()
    
    # Forward pass with mixed precisio/n
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = lossfn(outputs, targets)
        print(loss.item())

    # loss.backward()
    # optimizer.step()
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # if i%100 == 0:
    #     print(loss.item())

  with torch.cuda.amp.autocast():


nan
nan


KeyboardInterrupt: 

In [None]:
np.unique(np.array(targets.cpu()))

In [None]:
np.unique(np.array(outputs.cpu()))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_loss(preds, targets, smooth=1e-6):
    """
    Compute the Dice Loss.
    
    Args:
        preds (torch.Tensor): Model predictions (logits or probabilities) with shape (N, C, H, W) or (N, C).
        targets (torch.Tensor): Ground truth labels with the same shape as preds.
        smooth (float): Smoothing factor to avoid division by zero.
    
    Returns:
        torch.Tensor: Dice loss value.
    """
    preds = torch.sigmoid(preds)  # Ensure predictions are in [0, 1] if logits are provided
    
    intersection = torch.sum(preds * targets, dim=(1, 2, 3))
    union = torch.sum(preds, dim=(1, 2, 3)) + torch.sum(targets, dim=(1, 2, 3))
    
    dice_score = (2.0 * intersection + smooth) / (union + smooth)
    
    return 1 - dice_score.mean()

# Example usage:
preds = torch.randn(4, 1, 256, 256)  # Example tensor with batch size 4
targets = torch.randint(0, 2, (4, 1, 256, 256)).float()

loss = dice_loss(preds, targets)
print("Dice Loss:", loss.item())


In [None]:
preds = torch.randn(2, 9, 256, 256)  # Example tensor with batch size 4
targets = torch.randint(0, 11, (2, 9, 256, 256)).float()
lossfn = GeneralizedDiceLoss()
# lossfn = WeightedCrossEntropyLoss()
# loss = dice_loss(targets, targets)
loss = lossfn(preds, targets)
loss

In [None]:
loss