In [1]:
import torch
from torch.nn import functional as F
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from unet import UNet
from torch.utils.data import DataLoader
from skimage.transform import resize
from dataset import *
from transformations import *
from model import *
from torchvision import transforms
import matplotlib.pyplot as plt

In [2]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, image_file_path, mask_file_path):
        self.images = np.load(image_file_path)
        self.masks = np.load(mask_file_path)

    def __getitem__(self, index):
        # Get an individual image and mask pair from the dataset
        image = self.images[index]
        mask = self.masks[index]

        # Convert the image and mask to PyTorch tensors
        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask)

        # Return the processed image and mask
        return image, mask

    def __len__(self):
        return len(self.images)

# ################################################ train data
# # Paths to your image and mask .npy files
# image_file_path_tr = r'/content/drive/MyDrive/data/Train_img_10frame_256.npy'
# mask_file_path_tr = r'/content/drive/MyDrive/data/Train_msk_10frame_256.npy'


# batch_size = 5
# val_split = 0.2

# # Create an instance of your custom dataset
# dataset = CustomDataset(image_file_path_tr, mask_file_path_tr)

# dataset_size = len(dataset)
# val_size = int(val_split * dataset_size)
# train_size = dataset_size - val_size
# print(dataset_size, val_size, train_size)
# train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

# # Create a DataLoader
# train_dataloader = DataLoader(train_set, batch_size=batch_size , shuffle=True, num_workers=0)
# val_dataloader = DataLoader(val_set, batch_size=batch_size , shuffle=True, num_workers=0)
# ################################################ validation data
# Paths to your image and mask .npy files
image_file_path_val = 'Test_img_10frame_256.npy'
mask_file_path_val = 'Test_msk_10frame_256.npy'


# Create an instance of your custom dataset
dataset = CustomDataset(image_file_path_val, mask_file_path_val)

# Create a DataLoader
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)


print("test_loader:", len(dataset))


# Iterate over the dataloader to access batches of image and mask pairs
# for batch in test_dataloader:
#     # Unpack the batch into images and masks
#     images, masks = batch
#     print(images.shape)
#     print(masks.shape)


test_loader: 100


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Use GPU if available
# device = torch.device("cpu")
print(device)

cuda:0


In [4]:

"""
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
Paper URL: https://arxiv.org/abs/1606.06650
Author: Amir Aghdam
"""


from torch import nn
# from torchsummary import summary
import torch
import time

class Conv3DBlock(nn.Module):
    """
    The basic block for double 3x3x3 convolutions in the analysis path
    -- __init__()
    :param in_channels -> number of input channels
    :param out_channels -> desired number of output channels
    :param bottleneck -> specifies the bottlneck block
    -- forward()
    :param input -> input Tensor to be convolved
    :return -> Tensor
    """

    def __init__(self, in_channels, out_channels, bottleneck = False) -> None:
        super(Conv3DBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels= in_channels, out_channels=out_channels//2, kernel_size=(3,3,3), padding=1)
        self.bn1 = nn.BatchNorm3d(num_features=out_channels//2)
        self.conv2 = nn.Conv3d(in_channels= out_channels//2, out_channels=out_channels, kernel_size=(3,3,3), padding=1)
        self.bn2 = nn.BatchNorm3d(num_features=out_channels)
        self.relu = nn.ReLU()
        self.bottleneck = bottleneck
        if not bottleneck:
            self.pooling = nn.MaxPool3d(kernel_size=(2,2,2), stride=2)


    def forward(self, input):
        res = self.relu(self.bn1(self.conv1(input)))
        res = self.relu(self.bn2(self.conv2(res)))
        out = None
        if not self.bottleneck:
            out = self.pooling(res)
        else:
            out = res
        return out, res




class UpConv3DBlock(nn.Module):
    """
    The basic block for upsampling followed by double 3x3x3 convolutions in the synthesis path
    -- __init__()
    :param in_channels -> number of input channels
    :param out_channels -> number of residual connections' channels to be concatenated
    :param last_layer -> specifies the last output layer
    :param num_classes -> specifies the number of output channels for dispirate classes
    -- forward()
    :param input -> input Tensor
    :param residual -> residual connection to be concatenated with input
    :return -> Tensor
    """

    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (last_layer==False and num_classes==None) or (last_layer==True and num_classes!=None), 'Invalid arguments'
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2, 2, 2), stride=2)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm3d(num_features=in_channels//2)
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))


    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual!=None:
          # print("shape input after upconv1:", out.shape)
          # print("shape residual:", residual.shape)
          if(out.shape[2]==4):
              # print("flag")
              out = F.interpolate(out, size=(5, 128, 128), mode='trilinear', align_corners=False)
          out = torch.cat((out, residual), 1)
          # print("--------- output of upblock:")
          # print("concate shape:", out.shape)
          # print("Flag none")
        out = self.relu(self.bn(self.conv1(out)))
        # print("out shape:", out.shape)
        out = self.relu(self.bn(self.conv2(out)))
        # print("out shape:", out.shape)
        if self.last_layer: out = self.conv3(out)
        # print("out shape:", out.shape)
        return out



class UNet3D(nn.Module):
    """
    The 3D UNet model
            the source code of 3D U-Net
            https://github.com/AghdamAmir/3D-UNet/blob/main/train.py
    """

    def __init__(self, in_channels, num_classes, level_channels=[64, 128, 256], bottleneck_channel=512) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck= True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)

    def forward(self, input):
        #Analysis path forward feed
        # print("shape of input: ", input.shape)
        out, residual_level1 = self.a_block1(input)
        # print("shape of out:", out.shape)
        # print("shape of residual_level1: ", residual_level1.shape)

        out, residual_level2 = self.a_block2(out)
        # print("shape of out:", out.shape)
        # print("shape of residual_level2: ", residual_level2.shape)

        out, residual_level3 = self.a_block3(out)
        # print("shape of out:", out.shape)
        # print("shape of residual_level3: ", residual_level3.shape)

        out, _ = self.bottleNeck(out)
        # print("shape of out bottleNeck: ", out.shape)

        #Synthesis path forward feed
        out = self.s_block3(out, residual_level3)
        # print("shape s_block3: ", out.shape)
        out = self.s_block2(out, residual_level2)
        # print("shape s_block2: ", out.shape)
        out = self.s_block1(out, residual_level1)
        # print("shape s_block3: ", out.shape)
        return out



In [5]:
!CUDA_LAUNCH_BLOCKING=1

In [6]:
!nvidia-smi

Tue Jul  4 15:04:10 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A10          On   | 00000000:17:00.0 Off |                    0 |
|  0%   55C    P8    23W / 150W |      2MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A10          On   | 00000000:CA:00.0 Off |                    0 |
|  0%   63C    P0    87W / 150W |  19629MiB / 23028MiB |     28%      Default |
|       

In [46]:
model = UNet3D(num_classes=4, in_channels=1).to(device)
model.load_state_dict(torch.load('final_model-4.pt'))


<All keys matched successfully>

In [47]:
import torch
import numpy as np
import monai

def one_hot_encoders(input_tensor, num_classes=4):
    tensor_list = []
    for i in range(num_classes):
        temp_prob = input_tensor == i * torch.ones_like(input_tensor)
        tensor_list.append(temp_prob)
    output_tensor = torch.cat(tensor_list, dim=1)
    return output_tensor.float()

def compute_metrics(test_dataloader, model, device, num_classes=4):
    img, msks = next(iter(test_dataloader))
    
    outputs = np.empty((len(test_dataloader), num_classes, img.size(dim=1), img.size(dim=2), img.size(dim=3)))
    labels = np.empty((len(test_dataloader), num_classes, img.size(dim=1), img.size(dim=2), img.size(dim=3)))
    i = 0
    
    model.eval()
    with torch.no_grad():
        for input, label in test_dataloader:
            input = input.to(device)
            label = label.to(device)
            input = input.unsqueeze(1)

            output = model(input)
            outputs_soft = torch.softmax(output, dim=1)
            label = one_hot_encoders(label.unsqueeze(1), num_classes=4)

            outputs[i,:,:,:,:] = outputs_soft[0].cpu()
            labels[i,:,:,:,:] = label[0].cpu()
            i += 1
            
    outputs = torch.from_numpy(outputs)
    labels = torch.from_numpy(labels)
    
    dice = monai.metrics.compute_dice(outputs, labels, include_background=False)
    hausdorff = monai.metrics.compute_hausdorff_distance(outputs, labels)
    print('çlasswise dice')
    # print(dice)
    mean_dice = torch.mean(dice, dim=0)
    cls_dice = torch.mean(dice, dim=1)
    print('Dice score')
    print(f'\n Mean Dice score: {torch.mean(mean_dice):.4f}')
    print(dice)
    
    mean_hausdorff = torch.mean(hausdorff, dim=0)
    print('Hausdorff distance')
    print(f'\n Mean Hausdorff disctance: {torch.mean(mean_hausdorff):.4f}')
    print(hausdorff)

In [36]:
for inp, label in test_dataloader:
    print(inp.shape)
    print(label.shape)
    label = one_hot_encoders(label.unsqueeze(1), num_classes=4)
    print(label.shape)
    inp = inp.to(device)
    label = label.to(device)
    inp = inp.unsqueeze(1)
    output = model(inp)
    print(output.shape)
    break

torch.Size([1, 10, 256, 256])
torch.Size([1, 10, 256, 256])
torch.Size([1, 4, 10, 256, 256])
torch.Size([1, 4, 10, 256, 256])


In [48]:
compute_metrics(test_dataloader, model, device)

çlasswise dice
Dice score

 Mean Dice score: 0.3875
tensor([[0.4078, 0.4634, 0.5253],
        [0.5491, 0.3463, 0.6434],
        [0.2045, 0.0789, 0.1948],
        [0.4183, 0.4967, 0.6193],
        [0.4986, 0.4650, 0.7187],
        [0.0254, 0.2439, 0.1381],
        [0.2787, 0.4209, 0.2110],
        [0.3685, 0.3964, 0.2685],
        [0.1523, 0.3506, 0.3327],
        [0.4610, 0.4101, 0.5030],
        [0.0237, 0.0266, 0.0059],
        [0.2859, 0.3974, 0.4686],
        [0.2476, 0.2820, 0.3501],
        [0.1501, 0.2328, 0.4158],
        [0.4695, 0.3079, 0.6291],
        [0.6399, 0.2532, 0.6664],
        [0.4009, 0.3860, 0.2370],
        [0.6183, 0.4174, 0.7399],
        [0.6841, 0.3425, 0.5220],
        [0.3700, 0.4498, 0.3325],
        [0.3878, 0.4818, 0.7123],
        [0.5443, 0.4136, 0.6955],
        [0.4966, 0.3688, 0.3226],
        [0.0873, 0.2505, 0.3445],
        [0.1641, 0.2448, 0.1609],
        [0.5534, 0.2777, 0.6434],
        [0.1331, 0.3288, 0.1871],
        [0.1952, 0.1563, 0.378