In [1]:
import os
import cv2
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from src.configuration.config import datadict, TrainingDir
from src.Dataset.dataset import CustomDatasetHW, CustomDatasetHWD, CustomDataset, CustomDatasetHW_new, CustomDatasetHW_validation
from src.utils.losses import BCEDiceLoss, DiceLoss, GeneralizedDiceLoss, WeightedCrossEntropyLoss, WeightedSmoothL1Loss
from src.configuration.config import IMAGE_HEIGHT, IMAGE_WIDTH
from src.utils.utils import custom_collate, custom_collate_Variable_HW
from src.Models.D_UNet import UNet3D

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DunetModel = UNet3D(in_channels=1, out_channels=1).to(device)
checkpointpath = r"C:\Users\Rishabh\Documents\3D_Unet_Bleed\model_checkpoint.pth"
DunetModel.load_state_dict(torch.load(checkpointpath))

<All keys matched successfully>

In [3]:
os.listdir(r"C:\Users\Rishabh\Documents\3D_Unet_Bleed")

['.git',
 '.ipynb_checkpoints',
 'Inference.ipynb',
 'main.py',
 'model_checkpoint.pth',
 'Predictions',
 'requirements.txt',
 'src',
 'Untitled.ipynb']

In [4]:
train_transform = A.Compose(
    [
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

In [5]:
ImagesDir = os.path.join(TrainingDir, 'Images')
MasksDir = os.path.join(TrainingDir, 'Masks')
print(os.listdir(TrainingDir))
# data1 = CustomDatasetHW(ImagesDir, MasksDir, transform=train_transform)
data = CustomDatasetHW_validation(ImagesDir, MasksDir, transform=train_transform)

['Images', 'Masks']


In [6]:
batch_size = 1
num_workers = 0
pin_memory = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = DataLoader(
    data,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    shuffle=True,
    collate_fn=custom_collate,
)

In [7]:
def save_prediction(image, mask):
    gray_image = image
    binary_mask = mask
    
    gray_norm = gray_image / 255.0

    # Create an RGB image with grayscale as background
    overlay = np.stack([gray_norm, gray_norm, gray_norm], axis=-1)
    
    # Define lighter colors for each class (0-8)
    colors = {
        1: [1.0, 0.6, 0.6],   # Light Red
        2: [0.6, 1.0, 0.6],   # Light Green
        3: [0.6, 0.6, 1.0],   # Light Blue
        4: [1.0, 1.0, 0.6],   # Light Yellow
        5: [1.0, 0.6, 1.0],   # Light Magenta
        6: [0.6, 1.0, 1.0],   # Light Cyan
        7: [0.8, 0.7, 1.0],   # Light Purple
        8: [1.0, 0.8, 0.6]    # Light Orange
    }
    
    # Create an RGB mask initialized with zeros
    mask_rgb = np.zeros_like(overlay)

    
    # Assign colors based on binary_mask values
    for value, color in colors.items():
        mask_rgb[binary_mask == value] = color
    
    # Define transparency level
    alpha = 0.4  # Transparency level (0-1)
    
    # Blend grayscale image with the colored mask
    blended = overlay * (1 - alpha) + mask_rgb * alpha
    blended = (blended*255).astype(np.uint8)
    image = Image.fromarray(blended)
    index = len(os.listdir('Predictions'))+1
    image.save(f'Predictions/output_{index}.jpg', quality=100)

In [None]:
DunetModel.eval()
for batch_idx, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.no_grad():
        output = DunetModel(inputs)
    print(output.shape)


    # inputs = np.array(inputs)
    # targets = np.array(targets)
    
    # for batch in range(inputs.shape[0]):
    #     for sli in range(inputs.shape[2]):
            
    #         gray_image = inputs[batch, 0, sli, :, :]*255
    #         binary_mask = targets[batch, 0, sli, :, :]
            # print('gray_image:-',np.unique(gray_image))
            # # print('binary_mask:-',np.unique(binary_mask))
            # if len(np.unique(binary_mask))>1:
            #     # print('binary_mask:-',np.unique(binary_mask))
            #     save_prediction(gray_image, binary_mask)

In [14]:
# Ensure model is in evaluation mode
DunetModel.eval()

# Iterate through training data
for batch_idx, (inputs, targets) in enumerate(train_loader):
    # Ensure inputs and targets are on the same device and data type
    inputs = inputs.to(device, dtype=torch.float32)
    # targets = targets.to(device, dtype=torch.float32)
    # print(inputs[:,:,:10,:,:].shape)
    inputs = inputs[:,:,:8,:128,:128]

    # Disable gradient calculation for inference
    with torch.no_grad():
        output = DunetModel(inputs)

    output = output.cpu()
    output = np.array(output)
    print(len(np.unique(output)))
    print(np.min(np.unique(output)))
    print(np.max(np.unique(output)))
    
    

    # print(np.unique(output))

    # Print output shape
    print(f"Batch {batch_idx}: Output Shape: {output.shape}")


2.25.216239243324311492486775007119878849305


  output = np.array(output)


107395
Batch 0: Output Shape: (1, 1, 8, 128, 128)
2.25.874750920676985942236560559012010376830
55903
Batch 1: Output Shape: (1, 1, 8, 128, 128)
2.25.327971212165492878990090645563463447694
52497
Batch 2: Output Shape: (1, 1, 8, 128, 128)
2.25.812609565055494479265790573472977615559
75469
Batch 3: Output Shape: (1, 1, 8, 128, 128)
2.25.963853606161210352739966258030989557592
79005
Batch 4: Output Shape: (1, 1, 8, 128, 128)
2.25.387503757565414440314154621408994040708
103584
Batch 5: Output Shape: (1, 1, 8, 128, 128)
2.25.255390386701589077552528917500107662799
95878
Batch 6: Output Shape: (1, 1, 8, 128, 128)
1.2.826.0.1.3680043.10.511.3.50319555245010760304192407653470925
78805
Batch 7: Output Shape: (1, 1, 8, 128, 128)
