In [1]:
import torch
import os
from src.Models.D_UNet import UNet2D, ResidualUNet2D
from src.Dataset.dataset import CustomDataset2D
from torch.utils.data import Dataset, DataLoader
from src.configuration.config import datadict
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet2D(in_channels=3, out_channels=9, f_maps=128).to(device)
checkpoint_path = r'C:\Users\Rishabh\Documents\3D_Unet_Bleed\model_checkpoint_0.pth'
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [2]:
os.listdir(r'C:\Users\Rishabh\Documents\3D_Unet_Bleed')

['.git',
 '.ipynb_checkpoints',
 'Datasetfor2Dunet.ipynb',
 'Debugging',
 'Inference.ipynb',
 'lossfuctioncreating.ipynb',
 'main.py',
 'main3d.py',
 'main3d_9classes.py',
 'model_checkpoint_0.pth',
 'New3dClaas.ipynb',
 'new_main_9clss.py',
 'Predictions',
 'requirements.txt',
 'src',
 'testingscript.py',
 'UnderstandingDataset.ipynb',
 'Unet25train.py',
 'Unet2D.ipynb',
 'Unet2_5d_inference.ipynb',
 'Untitled.ipynb']

In [4]:
class CustomDataset2D_numpy(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, datadict=datadict):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.series = os.listdir(mask_dir)
        self.datadict = datadict
        reversed_dict = {v: k for k, v in datadict.items()}
        self.reversed_dict = reversed_dict

    def __len__(self):
        count = 0
        for i in range(len(self.series)):
            first_folder = os.listdir(os.path.join(self.mask_dir, self.series[i]))[0]
            folder_path = os.path.join(self.mask_dir, self.series[i], first_folder)
            series_length = len(os.listdir(folder_path))
            count = count + series_length
        return count


    def transform_volume(self, image_volume, mask_volume):
        # print(image_volume.transpose(1, 2, 0).shape)
        # print(mask_volume.transpose(1, 2, 0).shape)
        transformed = self.transform(
                image=image_volume.transpose(1, 2, 0), 
                mask=mask_volume.transpose(1, 2, 0)  # Change (9, 512, 512) -> (512, 512, 9)
            )
        images = transformed['image']
        masks = transformed['mask'].permute(2, 0, 1)

        # print(images.shape)
        # print(masks.shape)

        return images , masks

        
    def __getitem__(self, index):
        # print("log1")
        count = 0
        index = index + 1
        for i in range(len(self.series)):
            first_folder = os.listdir(os.path.join(self.mask_dir, self.series[i]))[0]
            folder_path = os.path.join(self.mask_dir, self.series[i], first_folder)
            series_length = len(os.listdir(folder_path))

            if count+series_length > index:
                self.series_index = i
                index = index-count-1
                break
            elif count+series_length == index:
                self.series_index = i
                index = series_length - 1
                break
            else:
                count = count + series_length

        # print("log2")
            
        Maskvolume = []
        ImageVolume = []
        flag = 0
        for key in range(len(self.reversed_dict.keys())):
            catag = self.reversed_dict[key]
            Maskcatgvolume = []
            Masks = os.path.join(self.mask_dir, os.listdir(self.mask_dir)[self.series_index], catag)
            MasksList = os.listdir(Masks)
            MasksList = sorted(MasksList)
            
            for msk in MasksList:
                pngMask = Image.open(os.path.join(Masks, msk))
                pngMask = np.array(pngMask)
                Maskcatgvolume.append(pngMask)
        
                if msk in self.images and flag == 0:
                    pngimage = Image.open(os.path.join(self.image_dir ,msk))
                    pngimage = np.array(pngimage)
                    ImageVolume.append(pngimage)
            flag = 1
                    
            Maskcatgvolume = np.stack(Maskcatgvolume, axis = 0)
            Maskvolume.append(Maskcatgvolume)
            
        Maskvolume = np.stack(Maskvolume, axis = 0)
        ImageVolume = np.stack(ImageVolume, axis = 0)

        # print("log3")
        
        newMaskVolume = []
        for i in range(Maskvolume.shape[1]):
            newMaskVolume.append(np.argmax(Maskvolume[:,i,:,:] , axis=0))
        newMaskVolume = np.stack(newMaskVolume, axis=0)
        
        newMaskVolume[newMaskVolume>0] = -1
        newMaskVolume[newMaskVolume == 0] = 1
        newMaskVolume[newMaskVolume == -1] = 0
        
        for i in range(Maskvolume.shape[1]):
            Maskvolume[0,i,:,:] = Maskvolume[0,i,:,:] + newMaskVolume[i,:,:]


        # print("log4")



        newImageVolume = []
        newMaskVolume = []
        empty_slice = np.zeros(ImageVolume[0,:,:].shape)
        # empty_slice_mask = np.zeros(Maskvolume[:,0,:,:].shape)
        middleslice = ImageVolume[index,:,:]
        middlesliceMask = Maskvolume[:,index,:,:]

        # print("log5")
        
        if index == 0:
            if ImageVolume.shape[0] == 1:
                newImageVolume.append(empty_slice)
                newImageVolume.append(middleslice)
                newImageVolume.append(empty_slice)
                newImageVolume = np.stack(newImageVolume, axis=0)
                
            else:
                lastslice = ImageVolume[index+1,:,:]
                newImageVolume.append(empty_slice)
                newImageVolume.append(middleslice)
                newImageVolume.append(lastslice)
                newImageVolume = np.stack(newImageVolume, axis=0)

                
        elif index == (ImageVolume.shape[0]-1):
            firstslice = ImageVolume[index-1,:,:]
            newImageVolume.append(firstslice)
            newImageVolume.append(middleslice)
            newImageVolume.append(empty_slice)
            newImageVolume = np.stack(newImageVolume, axis=0)

        else:
            firstslice = ImageVolume[index-1,:,:]
            lastslice = ImageVolume[index+1,:,:]
            newImageVolume.append(firstslice)
            newImageVolume.append(middleslice)
            newImageVolume.append(lastslice)
            newImageVolume = np.stack(newImageVolume, axis=0)

        if self.transform is not None:
            transformed_image_volume, transformed_mask_volume = self.transform_volume(newImageVolume, middlesliceMask)
            

        # return image, mask
        return transformed_image_volume, transformed_mask_volume


        return newImageVolume, middlesliceMask
        # return ImageVolume ,Maskvolume


In [5]:
train_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


In [6]:
import os
Dir = r"C:\Users\Rishabh\Documents\pytorch-3dunet\TrainingData"
image_dir = os.path.join(Dir, 'Images')
mask_dir = os.path.join(Dir, 'Masks')
data = CustomDataset2D_numpy(image_dir, mask_dir,transform=train_transform)

In [98]:
index = 8
newImageVolume_np, middlesliceMask = data[index]
newImageVolume = torch.tensor(newImageVolume_np, dtype=torch.float32)
# newImageVolume = newImageVolume.unsqueeze(0)
newImageVolume = newImageVolume.to(device)

  newImageVolume = torch.tensor(newImageVolume_np, dtype=torch.float32)


In [100]:
newImageVolume.shape

torch.Size([3, 256, 256])

In [8]:
for index in range(40,1000):
    newImageVolume, middlesliceMask = data[index]
    # newImageVolume = torch.tensor(newImageVolume_np, dtype=torch.float32)
    newImageVolume = newImageVolume.unsqueeze(0)
    newImageVolume = newImageVolume.to(device)
    output = model(newImageVolume)
    output = torch.sigmoid(output)
    output = (output > 0.5).float()
    output = output.cpu()
    output = np.array(output)
    output = np.argmax(output, axis=0)
    middlesliceMask = np.argmax(np.array(middlesliceMask.cpu()), axis=0)
    print(index," ",np.unique(output), " ", np.unique(middlesliceMask))

  output = np.array(output)
  middlesliceMask = np.argmax(np.array(middlesliceMask.cpu()), axis=0)


40   [0]   [0]
41   [0]   [0]
42   [0]   [0]
43   [0]   [0]
44   [0]   [0]
45   [0]   [0 7]
46   [0]   [0 7]
47   [0]   [0]
48   [0]   [0 7]
49   [0]   [0]
50   [0]   [0]
51   [0]   [0]
52   [0]   [0]
53   [0]   [0]
54   [0]   [0]
55   [0]   [0]
56   [0]   [0]
57   [0]   [0]
58   [0]   [0]
59   [0]   [0]
60   [0]   [0]
61   [0]   [0 7]
62   [0]   [0]
63   [0]   [0]
64   [0]   [0]
65   [0]   [0]
66   [0]   [0]
67   [0]   [0]
68   [0]   [0]
69   [0]   [0]
70   [0]   [0]
71   [0]   [0]
72   [0]   [0]
73   [0]   [0]
74   [0]   [0 7]
75   [0]   [0]
76   [0]   [0]
77   [0]   [0]
78   [0]   [0]
79   [0]   [0]
80   [0]   [0]
81   [0]   [0]
82   [0]   [0]
83   [0]   [0]
84   [0]   [0]
85   [0]   [0]
86   [0]   [0]
87   [0]   [0]
88   [0]   [0]
89   [0]   [0]
90   [0]   [0]
91   [0]   [0]
92   [0]   [0]
93   [0]   [0]
94   [0]   [0]
95   [0]   [0]
96   [0]   [0]
97   [0]   [0]
98   [0]   [0]
99   [0]   [0]
100   [0]   [0]
101   [0]   [0]
102   [0]   [0]
103   [0]   [0]
104   [0]   [0]
105   [0] 

KeyboardInterrupt: 

In [76]:
middlesliceMask.shape

(9, 512, 512)

In [39]:
def save_prediction(image, mask):
    gray_image = image
    binary_mask = mask
    
    gray_norm = gray_image / 255.0

    # Create an RGB image with grayscale as background
    overlay = np.stack([gray_norm, gray_norm, gray_norm], axis=-1)
    
    # Define lighter colors for each class (0-8)
    colors = {
        1: [1.0, 0.6, 0.6],   # Light Red
        2: [0.6, 1.0, 0.6],   # Light Green
        3: [0.6, 0.6, 1.0],   # Light Blue
        4: [1.0, 1.0, 0.6],   # Light Yellow
        5: [1.0, 0.6, 1.0],   # Light Magenta
        6: [0.6, 1.0, 1.0],   # Light Cyan
        7: [0.8, 0.7, 1.0],   # Light Purple
        8: [1.0, 0.8, 0.6]    # Light Orange
    }
    
    # Create an RGB mask initialized with zeros
    mask_rgb = np.zeros_like(overlay)

    
    # Assign colors based on binary_mask values
    for value, color in colors.items():
        mask_rgb[binary_mask == value] = color
    
    # Define transparency level
    alpha = 0.4  # Transparency level (0-1)
    
    # Blend grayscale image with the colored mask
    blended = overlay * (1 - alpha) + mask_rgb * alpha
    blended = (blended*255).astype(np.uint8)
    image = Image.fromarray(blended)
    index = len(os.listdir('Predictions'))+1
    image.save(f'Predictions/output_{index}.jpg', quality=100)

In [77]:
gray_image = newImageVolume_np[1, :, :]*255
binary_mask = output
binary_mask = np.argmax(binary_mask, axis=0)
# print('gray_image:-',np.unique(gray_image))
print('binary_mask:-',np.unique(binary_mask))
if len(np.unique(binary_mask))>1:
    print('binary_mask:-',np.unique(binary_mask))
    save_prediction(gray_image, binary_mask)

binary_mask:- [0]
