In [1]:
import random 
import torch 
from torch.utils import data 
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image
import nibabel as nib

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
IMAGE_SIZE = 128
BATCH_SIZE = 1
NUM_CLASS = 15

In [4]:
import cv2
from glob import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Down3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv3D(in_channels, out_channels)
        )

    def forward(self, x):
        return self.mpconv(x)

class Up3D(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels//2, kernel_size=2, stride=2)

        self.conv = DoubleConv3D(in_channels, out_channels)

    def forward(self, x1, x2):
        # print(x1.shape, x2.shape)
        x1    = self.up(x1)
        # print(x1.shape)
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1    = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2))
        x     = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.bilinear     = bilinear

        self.conv1    = DoubleConv3D(in_channels, 64)
        self.down1    = Down3D(64, 128)
        self.down2    = Down3D(128, 256)
        self.down3    = Down3D(256, 512)
        self.down4    = Down3D(512, 1024)
        self.up1      = Up3D(1024, 512, bilinear)
        self.up2      = Up3D(512, 256, bilinear)
        self.up3      = Up3D(256, 128, bilinear)
        self.up4      = Up3D(128, 64, bilinear)
        self.outconv  = OutConv3D(64, out_channels)

    def forward(self, x):
        # print(x.shape)
        # x = x.unsqueeze(1)
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # print(x5.shape, x4.shape)
        x6 = self.up1(x5, x4)
        x7 = self.up2(x6, x3)
        x8 = self.up3(x7, x2) 
        x9 = self.up4(x8, x1)
        output= self.outconv(x9)
        # print(x6.shape)
        # up network

        return output

In [6]:
checkpoint = torch.load('checkpoint.t7')
model = UNet3D(3,15).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

UNet3D(
  (conv1): DoubleConv3D(
    (conv): Sequential(
      (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down3D(
    (mpconv): Sequential(
      (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv3D(
        (conv): Sequential(
          (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (4): BatchNor

In [17]:
path = '/home/arshad/Downloads/amos22/amos22'


input_paths   = sorted(glob(os.path.join(path, "imagesVa","*.nii.gz")))
target_paths  = sorted(glob(os.path.join(path, "labelsVa","*.nii.gz")))

In [18]:
# test_dl      = AmosDataLoader(input_paths, target_paths)
# test_loader  = DataLoader(test_dl, batch_size = BATCH_SIZE, drop_last= False, collate_fn=test_dl.collate_fn)

In [19]:
def dice_loss(input, target):
    smooth = 1.0

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

In [10]:
def preprocess_img_input(input_im):
    # z_factor_input      = input_im.shape[2]*int(input_im.shape[0]/IMAGE_SIZE)**2
    input_im            = np.stack((input_im,)*3, axis=-1)
    input_im            = torch.tensor(input_im).float()/255
    # print(input_im.shape)
    input_im            = input_im.permute(3,2,0,1)
    
    input_im            = input_im.unsqueeze(0)
    # print('input_shape before inter',input_im.shape)
    output_size_input   = (82, IMAGE_SIZE, IMAGE_SIZE)
    input_im            = F.interpolate(input_im, size=output_size_input, mode='trilinear', align_corners=False)
    input_im            = input_im#.squeeze(0)
    return input_im

In [11]:
def preprocess_output(output_im):
    mask_cat              = np.zeros((NUM_CLASS, *output_im.shape), dtype=np.float32)
    for i in range(NUM_CLASS):
        mask_cat[i][output_im == i] = 1
    output_im             = torch.tensor(mask_cat).float()/255
    output_im             = output_im.permute(0,2,3,1)
    output_im             = output_im.unsqueeze(0)#.unsqueeze(0)
    output_size_input     = (82, IMAGE_SIZE, IMAGE_SIZE)
    output_im             = F.interpolate(output_im, size=output_size_input, mode='trilinear', align_corners=False)
    output_im             = output_im#.squeeze(0)
    return output_im

In [12]:
input_paths[:5]

['/home/arshad/Downloads/amos22/amos22/imagesTs/amos_0002.nii.gz',
 '/home/arshad/Downloads/amos22/amos22/imagesTs/amos_0003.nii.gz',
 '/home/arshad/Downloads/amos22/amos22/imagesTs/amos_0012.nii.gz',
 '/home/arshad/Downloads/amos22/amos22/imagesTs/amos_0020.nii.gz',
 '/home/arshad/Downloads/amos22/amos22/imagesTs/amos_0026.nii.gz']

In [13]:
path = 'results/'

In [14]:
# ! pip install tqdm

In [21]:
from tqdm import tqdm
# dices = []
for img_in in tqdm(input_paths ):
    name = img_in.split('/')[-1]
    # print(name)
    input_im = nib.load(img_in).get_fdata()
    # print(input_im.shape)
    input_shape = input_im.shape
    input_im = preprocess_img_input(input_im)
    
#     target_im = nib.load(img_out).get_fdata()
#     target_im = preprocess_output(target_im)
    pred = model(input_im.to(device))
    final_out = F.interpolate(pred, size=input_shape, mode='trilinear')
    final_out = final_out.squeeze(0)
    final_out = torch.sum(final_out, dim=0, keepdim=False)
    # print(final_out.shape)
    nii_img = nib.Nifti1Image(final_out.to('cpu').detach().numpy(), affine=None)
    nib.save(nii_img, path + name)
    # dice = dice_loss(pred.to('cpu'), target_im)
    # print(dice)
    # print(final_out.shape)
    # break÷
    
    

100%|█████████████████████████████████████████| 120/120 [15:03<00:00,  7.53s/it]


In [16]:
for ix, ims in enumerate(test_loader):
    
    # val_epoch_loss.append(loss)
    # val_dice_loss.append(dice_val)
    print(dice_val)
    break

NameError: name 'test_loader' is not defined

In [None]:
# checkpoint = torch.load('checkpoint.t7')
# model.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# epoch = checkpoint['epoch']