In [1]:
!pip install torchio



In [2]:
!pip install monai



In [3]:
!pip install itk



In [1]:
import argparse
import os
import sys
import pandas as pd
from tqdm import tqdm
import numpy as np
import glob
import multiprocessing
from PIL import Image

import torch
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torch import autograd

from torchio.transforms import CropOrPad
from monai.data import ArrayDataset, DataLoader, PILReader
from monai.transforms import Compose, LoadImage, AddChannel, RandFlip, RandRotate, RandRotate90, RandScaleIntensity, CenterSpatialCrop, ToTensor, ScaleIntensity, LoadPNG, RandSpatialCrop
from monai.visualize import plot_2d_or_3d_image

import FlowArrayDataset

from utils import *
from VGGLoss import *
from Generator import *
from Discriminator import *
import pytorch_ssim

from math import log10
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits



In [6]:
data_dir = "/data/*"

device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

In [7]:
# 6 input slices
inputZ01_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z01*.tif'), recursive=True))
inputZ02_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z02*.tif'), recursive=True))
inputZ03_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z03*.tif'), recursive=True))
inputZ04_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z04*.tif'), recursive=True))
inputZ05_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z05*.tif'), recursive=True))
inputZ06_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z06*.tif'), recursive=True))
inputZ07_path = sorted(glob.glob(os.path.join(data_dir, '*A04Z07*.tif'), recursive=True))

# 3 output channels
targetC01_path = sorted(glob.glob(os.path.join(data_dir, '*C01.tif'), recursive=True))
targetC02_path = sorted(glob.glob(os.path.join(data_dir, '*C02.tif'), recursive=True))
targetC03_path = sorted(glob.glob(os.path.join(data_dir, '*C03.tif'), recursive=True))

# split training/validation
inputZ01, inputZ01_val = split_train_val(inputZ01_path)
inputZ02, inputZ02_val = split_train_val(inputZ02_path)
inputZ03, inputZ03_val = split_train_val(inputZ03_path)
inputZ04, inputZ04_val = split_train_val(inputZ04_path)
inputZ05, inputZ05_val = split_train_val(inputZ05_path)
inputZ06, inputZ06_val = split_train_val(inputZ06_path)
inputZ07, inputZ07_val = split_train_val(inputZ07_path)

targetC01, targetC01_val = split_train_val(targetC01_path)
targetC02, targetC02_val = split_train_val(targetC02_path)
targetC03, targetC03_val = split_train_val(targetC03_path)

In [8]:
trans_val = Compose(
    [
        LoadImage(PILReader(), image_only=True),
        AddChannel(),
        CenterSpatialCrop(roi_size=256),
        ToTensor()
    ]
)

In [9]:
val_dataset = FlowArrayDataset.FlowArrayDataset(
    inputZ01=inputZ01_val, inputZ01_transform=trans_val,
    inputZ02=inputZ02_val, inputZ02_transform=trans_val,
    inputZ03=inputZ03_val, inputZ03_transform=trans_val,
    inputZ04=inputZ04_val, inputZ04_transform=trans_val,
    inputZ05=inputZ05_val, inputZ05_transform=trans_val,
    inputZ06=inputZ06_val, inputZ06_transform=trans_val,
    inputZ07=inputZ07_val, inputZ07_transform=trans_val,
    targetC01=targetC01_val, targetC01_transform=trans_val,
    targetC02=targetC02_val, targetC02_transform=trans_val,
    targetC03=targetC03_val, targetC03_transform=trans_val
)

In [10]:
validation_loader = DataLoader(
    val_dataset,
    batch_size=1
    #num_workers=4 #multiprocessing.cpu_count(),
)

In [11]:
netG = GeneratorUnet().to(device)
netG.eval()

GeneratorUnet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.1)
    )
  )
  (down1): Down(
    (pool_conv): Sequential(
      (0): AvgPool2d(kernel_size=3, stride=2, padding=0)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(224, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.1)
          (3): Conv2d(448, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(448, eps=1e

In [12]:
load_weight_dir = "checkpoints/tempcheckpoint/G_epoch_700.pth"
save_dir = "testresults/temptestresults"


if not os.path.exists(save_dir):
    os.makedirs(save_dir)


In [13]:
print(f'Loading checkpoint: {load_weight_dir}')
path_to_saved_weight = os.path.join(load_weight_dir)
checkpoint = torch.load(path_to_saved_weight)  # when you are loading weights saved on gpu device
netG.load_state_dict(checkpoint['model_state_dict'])

Loading checkpoint: checkpoints/tempcheckpoint/G_epoch_700.pth


<All keys matched successfully>

In [14]:
PSNRs = []
SSIMs = []

mseloss = nn.MSELoss()
    
with torch.no_grad():
    for batch_index, batch in enumerate(tqdm(validation_loader)):
        inputZ01, inputZ02, inputZ03, inputZ04, inputZ05, inputZ06, inputZ07 = \
            batch[0].to(device), batch[1].to(device), batch[2].to(device),\
            batch[3].to(device), batch[4].to(device), batch[5].to(device), batch[6].to(device)
        
        targetC01, targetC02, targetC03 = batch[7].to(device), batch[8].to(device), batch[9].to(device)
        
        # now compute output
        outputC01, outputC02, outputC03 = netG(inputZ01, inputZ02, inputZ03, inputZ04, inputZ05, inputZ06, inputZ07)
        
        # now compute loss/metric
        lossC01 = mseloss(outputC01, targetC01)
        lossC02 = mseloss(outputC02, targetC02)
        lossC03 = mseloss(outputC03, targetC03)
        content_loss = lossC01 + lossC02 + lossC03
        
        outputCs = torch.cat((outputC01, outputC02, outputC03), dim=1)
        targetCs = torch.cat((targetC01, targetC02, targetC03), dim=1)
        
        psnr = 10 * log10(1 / nn.MSELoss()(outputCs, targetCs))
        ssim = pytorch_ssim.ssim(outputCs, targetCs)
        
        PSNRs.append(psnr)
        SSIMs.append(ssim.item())
        
        # now save each channels as tif
        outputC01_image = ToPILImage()(outputC01[0].data.cpu())
        outputC01_image.save(os.path.join(save_dir,targetC01_val[batch_index][17:]))
        
        outputC02_image = ToPILImage()(outputC02[0].data.cpu())
        outputC02_image.save(os.path.join(save_dir,targetC02_val[batch_index][17:]))
        
        outputC03_image = ToPILImage()(outputC03[0].data.cpu())
        outputC03_image.save(os.path.join(save_dir,targetC03_val[batch_index][17:]))
        
        # now save RGB image as tiff 
        
        targetCs_name = f'targetCs_{batch_index}.png'
        outputCs_name = f'outputCs_{batch_index}.png'
        
        outputCs_image = ToPILImage()(outputCs[0].data.cpu())
        outputCs_image.save(os.path.join(save_dir,outputCs_name))
        
        targetCs_image = ToPILImage()(targetCs[0].data.cpu())
        targetCs_image.save(os.path.join(save_dir,targetCs_name))
        
        
        # now save RGB image as tiff 
        #outputCs_save = outputCs.squeeze(0).detach().cpu().numpy() # killing batch dimension
        #targetCs_save = targetCs.squeeze(0).detach().cpu().numpy()
        
        #outputCs_save = np.moveaxis(outputCs_save, 0, -1) # channel frist to channel last
        #targetCs_save = np.moveaxis(targetCs_save, 0, -1)
        
    ssim_mean = np.array(SSIMs).mean()
    psnr_mean = np.array(PSNRs).mean()
    
    print("SSIM:", ssim_mean, "PSNR:", psnr_mean)



100%|██████████| 63/63 [00:19<00:00,  3.75it/s]

SSIM: 0.02202738087535614 PSNR: -69.68706780364853



