In [1]:
""" Python Notebook to test the interpolation of the small dataset. """
### Imports
import os
import argparse
import torch
import math
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
import sklearn.metrics as metrics
from skimage.exposure import rescale_intensity
from skimage import morphology
from scipy import ndimage

from tqdm.notebook import tqdm
from importlib import reload

import pydicom as dicom
import pylibjpeg

import matplotlib.pyplot as plt
import cv2

#imports from vfit
# import config
import myutils
from loss import Loss
import shutil
import os
import time

from patchify import patchify


In [2]:
def build_test_dataset():
    num_adjacent_slices = 3

    class DICOM_Dataset(Dataset):
        def __init__(self, data_dir, transform=None):
            self.data_dir = data_dir
            self.transform = transform
            self.data = []
            # walk through the directory and get all the .dcm files
            for root, dirs, files in os.walk(data_dir):
                for file in files:
                    if file.endswith('.npy'):
                        self.data.append(os.path.join(root, file))
            self.data.sort()

        def __len__(self):
            return len(self.data)


        def __preproc__(self, filepath):
            """ Read in the .npy file from the filepath and allow pickling """

            pixel_array = np.load(filepath, allow_pickle=True)
            
            """ Get a random set of 3 adjacent slices """
            random_start_slice = random.randint(0, pixel_array.shape[0] - num_adjacent_slices)
            slices = pixel_array[random_start_slice:random_start_slice + num_adjacent_slices] 

            return slices


        def __getitem__(self, idx):
            filepath = self.data[idx]
            slices = self.__preproc__(filepath)
            # convert numpy array to dtype float32
            slices = slices.astype(np.float32)
            if self.transform:
                slices = self.transform(slices)

            # add a channel dimension to the slices
            slices = np.expand_dims(slices, axis=1)
            # duplicate the slices to make a 3 channel image
            slices = np.repeat(slices, 3, axis=1)
            # parse out the middle slice and store it as gt
            gt = slices[num_adjacent_slices // 2]
            # remove the middle slice from the list of slices
            slices = np.delete(slices, num_adjacent_slices // 2, 0)
            # convert the slices to a tensor
            slices = torch.from_numpy(slices)
            # convert the gt to a tensor
            gt = torch.from_numpy(gt)

            return (slices, gt)

    test_dataset = DICOM_Dataset(data_dir='data_np/testing', transform=None)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)

    return test_loader

In [3]:
import scipy.interpolate as interpolate
def bilinear_interpolation(image1, image2, w):
    # initialize as float32
    interpolated_image = np.zeros(image1.shape, dtype=np.float32)
    x_dim = np.arange(image1.shape[1])
    y_dim = np.arange(image1.shape[0])

    img1_interp = interpolate.interp2d(x_dim, y_dim, image1, kind='linear')
    img2_interp = interpolate.interp2d(x_dim, y_dim, image2, kind='linear')
    w = .5
    interpolated_image = w * img1_interp(x_dim, y_dim) + (1 - w) * img2_interp(x_dim, y_dim)

    # add 3 channels to the image
    interpolated_image = np.expand_dims(interpolated_image, axis=0)
    interpolated_image = np.repeat(interpolated_image, 3, axis=0)
    # cast to float32
    interpolated_image = interpolated_image.astype(np.float32)
    return interpolated_image

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from model.VFIT_S import UNet_3D_3D

##### Define Loss & Optimizer #####
criterion = nn.L1Loss()
        
model = UNet_3D_3D(n_inputs=2, joinType="concat")
model.to(device)
# model.load_state_dict(torch.load('checkpoints/VFIT_epoch_patches_4_batch_50.pth'))
model.load_state_dict(torch.load('checkpoints/VFIT_epoch_patches_1_batch_50.pth'))
model.eval()

test_loader = build_test_dataset()

In [5]:
for i, (images, gt_image) in enumerate(test_loader):

    with torch.no_grad():

        # make images and gt_image into numpy ndarray to use patchify
        images = images.numpy()
        gt_image = gt_image.numpy()

        # make blank numpy arrays to store the patches
        images_patches = np.zeros((images.shape[0], images.shape[1], 8, 8, images.shape[2], 256, 256), dtype=np.float32)
        gt_image_patches = np.zeros((gt_image.shape[0], 8, 8, gt_image.shape[1], 256, 256), dtype=np.float32)

        for batch in range(images.shape[0]):
            for slice in range(images.shape[1]):
                images_patches[batch, slice] = patchify(images[batch, slice], (3, 256, 256), step=256)

            gt_image_patches[batch] = patchify(gt_image[batch], (3, 256, 256), step=256)

        images = images_patches.reshape(images_patches.shape[0], images_patches.shape[1], images_patches.shape[2]*images_patches.shape[3], images_patches.shape[4], images_patches.shape[5], images_patches.shape[6])
        gt_image = gt_image_patches.reshape(gt_image_patches.shape[0], gt_image_patches.shape[1]*gt_image_patches.shape[2], gt_image_patches.shape[3], gt_image_patches.shape[4], gt_image_patches.shape[5])
            
        # make images and gt_image a torch tensor
        images = torch.from_numpy(images)
        gt_image = torch.from_numpy(gt_image)

        for patch in range(images.shape[2]):
            slices_of_image_patch = images[:, :, patch, :, :, :]
            gt_image_patch = gt_image[:, patch, :, :, :]

            slices_of_image_patch = slices_of_image_patch.to(device)
            gt_image_patch = gt_image_patch.to(device)

            slices_of_image_patch = slices_of_image_patch.float()
            gt_image_patch = gt_image_patch.float()

            """ Verify that gt patch isn't black """
            if torch.sum(gt_image_patch) == 0:
                # skip this patch
                continue 

            slices_of_image_patch = [slices_of_image_patch[:,i,:,:,:] for i in range(slices_of_image_patch.shape[1])]

            # print out number of slices in list
            print("Number of slices in this patch: ", len(slices_of_image_patch))
            # print out shape of each slice
            print("Shape of each patch: ", slices_of_image_patch[0].shape)
            # Forward
            out = model(slices_of_image_patch)

            # Calculate loss
            loss = criterion(out, gt_image_patch)

            # print out loss
            print("Loss: ", loss.item())

            # print out shape of out
            print("Shape of out: ", out.shape)
            
            # display the input slices, ground truth, and predicted output
            # use opencv
            import cv2
            import matplotlib.pyplot as plt

            # convert the list slices to the cpu and convert to numpy array
            slices_of_image_patch = [slices_of_image_patch[i].cpu().numpy() for i in range(len(slices_of_image_patch))]
            slices_of_image_patch = np.array(slices_of_image_patch)
            # get rid of batch dim 
            slices_of_image_patch = np.squeeze(slices_of_image_patch, axis=1)
            # convert the gt_image to a numpy array
            gt_image_patch = gt_image_patch[0].cpu().numpy()
            # convert the out to a numpy array
            out = out[0].cpu().numpy()

            # print out shapes of slices, gt_image, and out
            print("Shape of slices: ", slices_of_image_patch.shape)
            print("Shape of gt_image: ", gt_image_patch.shape)
            print("Shape of out: ", out.shape)

            """ Display the slices, gt_image, and out """
            # create the 3x3 grid of images with matplotlib
            fig, axs = plt.subplots(3, 3)

            # save min and max values of slices
            min_slice = np.min(slices_of_image_patch)
            max_slice = np.max(slices_of_image_patch)

            """ Insert bilinear interpolation here """
            interpolated_image = bilinear_interpolation(slices_of_image_patch[0, 0, :, :], slices_of_image_patch[1, 0, :, :], w=0.5)
            
            # calculate L1 loss by subtacting the interpolated image from the gt_image and summing the absolute values
            # divide by number of pixels
            bi_loss = np.sum(np.abs(cv2.subtract(interpolated_image[0, :, :], gt_image_patch[0, :, :])))
            bi_loss = bi_loss / (gt_image_patch.shape[1] * gt_image_patch.shape[2])
            # calculate L1 loss of the predicted output the same way
            out_loss = np.sum(np.abs(cv2.subtract(out[0, :, :], gt_image_patch[0, :, :])))
            out_loss = out_loss / (gt_image_patch.shape[1] * gt_image_patch.shape[2])

           
            # convert the interpolated_image to uint8
            interpolated_image = (interpolated_image * 255).astype(np.uint8)
            # display the interpolated image row 2 column 3
            axs[1, 2].imshow(interpolated_image[0,:,:])
            # label this as the interpolated image
            axs[1, 2].set_title("Bilinear Interpolation")


           

            # display the slices# convert the slices to uint8
            slices_of_image_patch = (slices_of_image_patch * 255).astype(np.uint8)
            # display the slices
            axs[0, 0].imshow(slices_of_image_patch[0, 0, :, :])
            # label this as slice 1
            axs[0, 0].set_title("Slice N")
            axs[1, 0].imshow(slices_of_image_patch[1, 0, :, :])
            # label this as slice 3
            axs[1, 0].set_title("Slice N+2")

            # display the difference between the slices
            axs[2, 0].imshow(cv2.subtract(slices_of_image_patch[0, 0, :, :], slices_of_image_patch[1, 0, :, :]))
            # label this as the difference between slice 3 and slice 1
            axs[2, 0].set_title("Diff (A)")
            # calculate this difference
            diff = cv2.subtract(slices_of_image_patch[0, 0, :, :], slices_of_image_patch[1, 0, :, :])
            # calculate the L1 loss by summing the absolute values and dividing by number of pixels
            diff = np.sum(np.abs(diff))
            diff = diff / (slices_of_image_patch.shape[2] * slices_of_image_patch.shape[3])

            # convert the gt_image to uint8
            gt_image_patch = (gt_image_patch * 255).astype(np.uint8)
            # display the gt_image
            axs[0, 1].imshow(gt_image_patch[0, :, :])
            # label this as the ground truth
            axs[0, 1].set_title("N+1 (Ground Truth)")
            # display gt again for bilinear comparison
            axs[0, 2].imshow(gt_image_patch[0, :, :])
            # label this as the ground truth
            axs[0, 2].set_title("N+1 (Ground Truth)")

            # map values below 0 to 0
            out[out < 0] = 0
            # map values above 1 to 1
            out[out > 1] = 1
            # map minimum value of out with min_slice and max value of out with max_slice
            out = (out - np.min(out)) / (np.max(out) - np.min(out))
            out = out * (max_slice - min_slice) + min_slice

            # convert the out to uint8
            out = (out * 255).astype(np.uint8)
            # display the out
            axs[1, 1].imshow(out[0, :, :])
            # label this as the predicted output
            axs[1, 1].set_title("Predicted Output")

            # display the difference between the gt_image and out
            axs[2, 1].imshow(cv2.subtract(gt_image_patch[0, :, :], out[0, :, :]))
            # label this as the difference between the ground truth and predicted output
            axs[2, 1].set_title("Diff (B)")

            # display the difference between the interpolated image and gt
            print("Shape of interpolated_image: ", interpolated_image.shape)
            print("Shape of gt_image_patch: ", gt_image_patch[0,:,:].shape)
            axs[2, 2].imshow(cv2.subtract(gt_image_patch[0, :, :], interpolated_image[0,:,:]))
            # label this as the difference between the interpolated image and predicted output
            axs[2, 2].set_title("Diff (C)")
            
            # print out min and max values of slices, gt_image, and out
            print("Min value of slices: ", np.min(slices_of_image_patch))
            print("Max value of slices: ", np.max(slices_of_image_patch))
            print("Min value of gt_image: ", np.min(gt_image_patch))
            print("Max value of gt_image: ", np.max(gt_image_patch))
            print("Min value of out: ", np.min(out))
            print("Max value of out: ", np.max(out))
            print("Min values of abs slice diff: ", np.min(np.abs(slices_of_image_patch[0, 0, :, :] - slices_of_image_patch[1, 0, :, :])))
            print("Max values of abs slice diff: ", np.max(np.abs(slices_of_image_patch[0, 0, :, :] - slices_of_image_patch[1, 0, :, :])))
            print("Min value of abs gt_image and out diff: ", np.min(np.abs(gt_image_patch[0, :, :] - out[0, :, :])))
            print("Max value of abs gt_image and out diff: ", np.max(np.abs(gt_image_patch[0, :, :] - out[0, :, :])))
            
            # set figure title below to the patch number and loss
            fig.suptitle("DBT Sample: "+ str(i) +"\nPatch: " + str(patch) + "\nN and N+2 weighted abs diff: " + str(diff) + "\nPred_Loss: " + str(out_loss) + "\nBi_Loss: " + str(bi_loss))

            # gives the subplots some vertical padding
            fig.tight_layout()
            # give the layout horizontal padding for the text
            # fig.subplots_adjust(top=0.85)
            

            # save the plot as a png file in test_plt_images folder
            # make a folder for the DBT_Sample number if one does not exist
            if not os.path.exists("test_plt_images/DBT_Sample_" + str(i)):
                os.makedirs("test_plt_images/DBT_Sample_" + str(i))
            # save the plot as a png file
            plt.savefig("test_plt_images/DBT_Sample_" + str(i) + "/Patch_" + str(patch) + "_ImgDiff_" + str(diff) + "_Pred_Loss_" + str(out_loss) + "_Bi_Loss_" + str(bi_loss) + ".png")



            time.sleep(10)

Number of slices in this patch:  2
Shape of each patch:  torch.Size([1, 3, 256, 256])


  "See the documentation of nn.Upsample for details.".format(mode)


Loss:  0.0006021884619258344
Shape of out:  torch.Size([1, 3, 256, 256])
Shape of slices:  (2, 3, 256, 256)
Shape of gt_image:  (3, 256, 256)
Shape of out:  (3, 256, 256)
Shape of interpolated_image:  (3, 256, 256)
Shape of gt_image_patch:  (256, 256)
Min value of slices:  0
Max value of slices:  255
Min value of gt_image:  0
Max value of gt_image:  255
Min value of out:  0
Max value of out:  255
Min values of abs slice diff:  0
Max values of abs slice diff:  255
Min value of abs gt_image and out diff:  0
Max value of abs gt_image and out diff:  255
Number of slices in this patch:  2
Shape of each patch:  torch.Size([1, 3, 256, 256])
Loss:  0.0032544354908168316
Shape of out:  torch.Size([1, 3, 256, 256])
Shape of slices:  (2, 3, 256, 256)
Shape of gt_image:  (3, 256, 256)
Shape of out:  (3, 256, 256)
Shape of interpolated_image:  (3, 256, 256)
Shape of gt_image_patch:  (256, 256)
Min value of slices:  0
Max value of slices:  255
Min value of gt_image:  0
Max value of gt_image:  255
Mi



Number of slices in this patch:  2
Shape of each patch:  torch.Size([1, 3, 256, 256])
Loss:  0.004312166012823582
Shape of out:  torch.Size([1, 3, 256, 256])
Shape of slices:  (2, 3, 256, 256)
Shape of gt_image:  (3, 256, 256)
Shape of out:  (3, 256, 256)
Shape of interpolated_image:  (3, 256, 256)
Shape of gt_image_patch:  (256, 256)
Min value of slices:  0
Max value of slices:  255
Min value of gt_image:  0
Max value of gt_image:  255
Min value of out:  0
Max value of out:  255
Min values of abs slice diff:  0
Max values of abs slice diff:  255
Min value of abs gt_image and out diff:  0
Max value of abs gt_image and out diff:  255
Number of slices in this patch:  2
Shape of each patch:  torch.Size([1, 3, 256, 256])
Loss:  0.00790172629058361
Shape of out:  torch.Size([1, 3, 256, 256])
Shape of slices:  (2, 3, 256, 256)
Shape of gt_image:  (3, 256, 256)
Shape of out:  (3, 256, 256)
Shape of interpolated_image:  (3, 256, 256)
Shape of gt_image_patch:  (256, 256)
Min value of slices:  0