#Initialization

In [None]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import sampler
import torchvision.datasets as dset
from torch.autograd import Variable

import os
import time
from google.colab.patches import cv2_imshow
import pickle
import cv2
import numpy as np
import glob
from PIL import Image
import shutil
import random
 
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab
from skimage.io import imsave
from skimage.metrics import structural_similarity as ssim
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torchsummary import summary
 
from itertools import product
from math import log10, sqrt
 
# Pick the compute device and matching legacy tensor type: GPU when one is
# available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
print(device, dtype)

#Set Config

In [None]:
class set_config:
  """Notebook-wide configuration: hyper-parameters, run mode and file locations.

  On construction it looks under ``<dir>weights/`` for at most one
  ``latest_KParameter*`` (ParamNet) checkpoint and at most one ``latest_END*``
  (END autoencoder) checkpoint. ``resume`` is True only when a ParamNet
  checkpoint is found.

  Raises:
    RuntimeError: if more than one checkpoint matches either pattern.
  """

  def __init__(self):
    self.cuda = torch.cuda.is_available()
    self.weight_decay = 0
    self.lr = 1e-3  # learning rate
    self.test_img_name = None
    self.batch_size = 8
    self.mode = 'train'  # 3 modes, 'train', 'test', 'detect'
    self.dir = 'drive/MyDrive/Projects/K-Parameter/'  # path to root directory
    # Path to ParamNet weights, or None. Finding one switches on resume
    # (training continues from the pretrained weights).
    self.param_weight_path = self._find_single(
        self.dir + 'weights/latest_KParameter*', "Multiple Param weight files detected.")
    self.resume = self.param_weight_path is not None
    # Path to END (autoencoder) weights, or None.
    self.end_weight_path = self._find_single(
        self.dir + 'weights/latest_END*', "Multiple END weight files detected.")
    self.train_file_name = "Train_1500"  # folder name in 'dataset/' holding training images
    self.test_file_name = "brainweb images"  # folder name in 'dataset/' holding test images

  @staticmethod
  def _find_single(pattern, error_message):
    """Return the single path matching glob `pattern`, or None if nothing matches.

    Raises RuntimeError(error_message) when several files match, because the
    notebook would not know which checkpoint to use.
    """
    matches = glob.glob(pattern)
    if len(matches) > 1:
      raise RuntimeError(error_message)
    return matches[0] if matches else None

config = set_config()

# Report which pretrained checkpoints were discovered when the config was built.
if config.param_weight_path is None:
  print("No Param weight file detected.")
else:
  print("Param Weight file detected.", config.param_weight_path)

if config.end_weight_path is None:
  print("No END weight file detected.")
else:
  print("END Weight file detected.", config.end_weight_path)

#Models

In [None]:
class END(nn.Module):
  """Encoder-decoder ("END") network over single-channel images.

  Layers are named conv{k} / bn{k} / relu{k} for stages k = 1..16 (stage 1 has
  no batch norm), plus a final 1x1 conv17.

  * Training mode: the full encoder (stages 1-8) and decoder (stages 9-16) run,
    then conv17 and tanh produce a one-channel output.
  * Eval mode: only the eight encoder stages run and their activations are
    returned as a list, shallowest first.
  """

  def __init__(self):
    super().__init__()

    # Stage 1: stride-1 conv, no normalisation.
    self.conv1 = nn.Conv2d(1, 64, 3, stride=1, padding=1, bias=False)
    self.relu1 = nn.LeakyReLU(0.2)

    # Stages 2-8: stride-2 Conv2d -> BatchNorm -> LeakyReLU. Submodules are
    # registered as conv{k}, bn{k}, relu{k} in that order so checkpoint keys and
    # the named_modules() order match an explicit layer-by-layer definition.
    encoder_channels = ((64, 64), (64, 128), (128, 256), (256, 512),
                        (512, 512), (512, 512), (512, 512))
    for k, (c_in, c_out) in enumerate(encoder_channels, start=2):
      setattr(self, f'conv{k}', nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False))
      setattr(self, f'bn{k}', nn.BatchNorm2d(c_out, momentum=0.5))
      setattr(self, f'relu{k}', nn.LeakyReLU(0.2))

    # Stages 9-15: stride-2 ConvTranspose2d (doubles spatial size) -> BN -> ReLU.
    decoder_channels = ((512, 512), (512, 512), (512, 512), (512, 256),
                        (256, 128), (128, 64), (64, 32))
    for k, (c_in, c_out) in enumerate(decoder_channels, start=9):
      setattr(self, f'conv{k}', nn.ConvTranspose2d(
          c_in, c_out, 3, stride=2, padding=1, output_padding=1, bias=False))
      setattr(self, f'bn{k}', nn.BatchNorm2d(c_out, momentum=0.5))
      setattr(self, f'relu{k}', nn.ReLU())

    # Stage 16: stride-1 refinement conv.
    self.conv16 = nn.Conv2d(32, 16, 3, stride=1, padding=1, bias=False)
    self.bn16 = nn.BatchNorm2d(16, momentum=0.5)
    self.relu16 = nn.ReLU()

    # Output head: 1x1 projection to a single channel.
    self.conv17 = nn.Conv2d(16, 1, 1, stride=1, bias=False)

  def _stage(self, k, x):
    """Apply conv{k} -> bn{k} -> relu{k} to `x`."""
    conv = getattr(self, f'conv{k}')
    norm = getattr(self, f'bn{k}')
    act = getattr(self, f'relu{k}')
    return act(norm(conv(x)))

  def forward(self, img):
    # Encoder, shared by both modes.
    features = [self.relu1(self.conv1(img))]
    for k in range(2, 9):
      features.append(self._stage(k, features[-1]))

    if not self.training:
      return features

    # Decoder (stages 9-16) followed by the tanh-squashed output head.
    x = features[-1]
    for k in range(9, 17):
      x = self._stage(k, x)
    return torch.tanh(self.conv17(x))

  def init_weights(self):
    """Xavier-uniform init for every (transposed) conv weight; zero any bias."""
    for module in self.modules():
      if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        continue
      nn.init.xavier_uniform_(module.weight.data)
      if module.bias is not None:
        module.bias.data.zero_()

In [None]:
class BasicBlock(nn.Module):
  """One ParamNet upsampling stage.

  The input is upsampled 2x with a transposed conv and concatenated with an
  END feature map of the new resolution. A conv/InstanceNorm/LeakyReLU follows,
  then concatenation with a slice of the broadcast parameter tensor, then two
  more conv/InstanceNorm/LeakyReLU stages ending at `out_channels` channels.

  Args:
    in_channels: channels of the incoming tensor; the END features and the
      parameter slice must also supply `in_channels` channels each.
    out_channels: channels produced by the block.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    # Kept as attributes; every ParamNet stage uses different widths.
    self.in_channels = in_channels
    self.out_channels = out_channels

    both = 2 * in_channels  # width after concatenating a same-width tensor
    self.conv_t = nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1,
                                     output_padding=1, bias=False)
    self.conv1 = nn.Conv2d(both, in_channels, 3, stride=1, padding=1, bias=False)
    self.in1 = nn.InstanceNorm2d(in_channels)
    self.relu1 = nn.LeakyReLU(0.2)
    self.conv2 = nn.Conv2d(both, in_channels, 3, stride=1, padding=1, bias=False)
    self.in2 = nn.InstanceNorm2d(in_channels)
    self.relu2 = nn.LeakyReLU(0.2)
    self.conv3 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1, bias=False)
    self.in3 = nn.InstanceNorm2d(out_channels)
    self.relu3 = nn.LeakyReLU(0.2)

  def forward(self, input, brain_features, params):
    """Run the stage.

    Args:
      input: tensor with `in_channels` channels at half the target resolution.
      brain_features: END feature map at the target resolution (`in_channels` channels).
      params: parameter tensor; its leading `in_channels` channels are cropped
        to the current spatial size and concatenated in.
    """
    upsampled = torch.cat([self.conv_t(input), brain_features], dim=1)
    hidden = self.relu1(self.in1(self.conv1(upsampled)))

    _, width, rows, cols = hidden.shape
    hidden = torch.cat([hidden, params[:, :width, :rows, :cols]], dim=1)

    hidden = self.relu2(self.in2(self.conv2(hidden)))
    return self.relu3(self.in3(self.conv3(hidden)))

  def init_weights(self):
    """Xavier-uniform init for every (transposed) conv weight; zero any bias."""
    for module in self.modules():
      if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        continue
      nn.init.xavier_uniform_(module.weight.data)
      if module.bias is not None:
        module.bias.data.zero_()

In [None]:
class ParamNet(nn.Module):
  """Decoder that synthesises the target-parameter image from END features.

  Eight BasicBlocks upsample from a 1x1 seed, each doubling the spatial size.
  Block k fuses the k-th deepest END encoder feature map and a crop of the
  broadcast parameter tensor. A 3x3 conv head then maps 32 -> 16 -> 1 channels.

  Args:
    output_tanh: squash the output with tanh into [-1, 1], the range the
      target images are normalised to. The original code computed the tanh
      into a misspelt variable (``ou10``) and returned the raw conv output, so
      checkpoints trained before this fix were trained with no output
      squashing. Pass ``output_tanh=False`` to reproduce that behaviour.
  """

  def __init__(self, output_tanh=True):
    super().__init__()
    self.output_tanh = output_tanh

    self.block1 = BasicBlock(in_channels=512,out_channels=512)          ## 2*2*512
    self.block2 = BasicBlock(in_channels=512,out_channels=512)          ## 4*4*512
    self.block3 = BasicBlock(in_channels=512,out_channels=512)          ## 8*8*512
    self.block4 = BasicBlock(in_channels=512,out_channels=256)          ## 16*16*256
    self.block5 = BasicBlock(in_channels=256,out_channels=128)          ## 32*32*128
    self.block6 = BasicBlock(in_channels=128,out_channels=64)           ## 64*64*64
    self.block7 = BasicBlock(in_channels=64,out_channels=64)            ## 128*128*64
    self.block8 = BasicBlock(in_channels=64,out_channels=32)            ## 256*256*32
    self.conv1 = nn.Conv2d(32,16,3,stride=1,padding=1,bias=False)       ## 256*256*16
    self.in1=nn.InstanceNorm2d(16)
    self.relu1=nn.LeakyReLU(0.2)
    self.conv2 = nn.Conv2d(16,1,3,stride=1,padding=1,bias=False)        ## 256*256*1

  def forward(self,brain_features,params):
    """Generate the output image.

    Args:
      brain_features: sequence of END encoder activations, shallowest first
        (what END returns in eval mode). The last eight are consumed deepest
        first. The sequence is not modified.
      params: broadcast parameter tensor; channels 0..511 are read. Assumed to
        be (batch, 512, 256, 256) as built by the notebook loops (TODO confirm
        for other callers).

    Raises:
      ValueError: if fewer than eight feature maps are supplied.
    """
    blocks = (self.block1, self.block2, self.block3, self.block4,
              self.block5, self.block6, self.block7, self.block8)
    if len(brain_features) < len(blocks):
      raise ValueError(
          f"ParamNet needs {len(blocks)} feature maps, got {len(brain_features)}")

    # Seed the decoder with the 1x1 top-left corner of the parameter tensor.
    x = params[:, 0:512, 0:1, 0:1]
    # Pair block1 with the deepest feature map, block8 with the shallowest.
    for block, skip in zip(blocks, reversed(brain_features)):
      x = block(x, skip, params)

    x = self.relu1(self.in1(self.conv1(x)))
    x = self.conv2(x)
    return torch.tanh(x) if self.output_tanh else x

  def init_weights(self):
    """Xavier-uniform init for every (transposed) conv weight; zero any bias."""
    for name,module in self.named_modules():
      if isinstance(module,nn.Conv2d) or isinstance(module,nn.ConvTranspose2d):
        nn.init.xavier_uniform_(module.weight.data)
        if module.bias is not None:
          module.bias.data.zero_()
    

#DataLoader

In [None]:
############ Unzip data from drive ################################

# Extract the train, test and default archives into the working directory,
# then count the entries each one produced.
archive_dir = config.dir + "dataset/"
for archive_name in (config.train_file_name, config.test_file_name, "Default"):
  shutil.unpack_archive(archive_dir + archive_name + ".zip", "")

train_list, test_list, default_list = [
    glob.glob(folder + '/*')
    for folder in (config.train_file_name, config.test_file_name, 'Default')
]
print(len(train_list), len(test_list), len(default_list))

###################################################################

1548 5 24


In [None]:
class MRIDataLoader(data.Dataset):
    """Dataset of (input image, target image, parameters) triples for ParamNet.

    Modes:
      'train' / 'test': files matching ``<folder>/T*`` are candidate inputs. A
        file's basename (minus its 4-character extension) is split on ',' into
        ``key=value`` fields: TR, TE, slice number, ... . The target is drawn
        at random from ``<folder>/<slice number>/``.
      'detect': a single fixed input/target pair (the DETECT_* constants).

    Each item is ``(input_image, output_image, params)``:
      * images: float64 tensors of shape (1, 256, 256) scaled to [-1, 1]
        (grayscale read, bilinear resize to 256x256).
      * params: ``tensor([input TE, input TR, output TE, output TR])``.
    """

    # Fixed pair used in 'detect' mode. If the true output image is unknown,
    # give the path of any image and ignore PSNR and MAE in that case.
    DETECT_INPUT_PATH = 'brainweb images/TR=4.5,TE=0.05.jpg'
    DETECT_OUTPUT_PATH = 'brainweb images/TR=8.0,TE=0.12.jpg'
    DETECT_INPUT_TR, DETECT_INPUT_TE = 4.5, 0.05
    DETECT_OUTPUT_TR, DETECT_OUTPUT_TE = 8.0, 0.12

    def __init__(self, config, mode='train'):
        self.mode = mode
        self.config = config
        if self.mode == 'train':
            self.data_path = config.train_file_name + '/'
        elif self.mode == 'test':
            self.data_path = config.test_file_name + '/'
            self.train_data_path = config.train_file_name + '/'

        if self.mode in ('train', 'test'):
            self.img_path_list = glob.glob(self.data_path + 'T*')

    @staticmethod
    def _parse_name(path):
        """Return (TR, TE, slice_no) parsed from an image path.

        The basename minus its last 4 characters (an extension such as '.jpg')
        must start with 'TR=<float>,TE=<float>,<key>=<int>'.
        """
        fields = os.path.basename(path)[:-4].split(',')
        tr = float(fields[0].split('=')[1])
        te = float(fields[1].split('=')[1])
        slice_no = int(fields[2].split('=')[1])
        return tr, te, slice_no

    @staticmethod
    def _to_tensor(image):
        """Scale a uint8 (H, W) array to a float64 (1, H, W) tensor in [-1, 1]."""
        scaled = torch.from_numpy(image.astype(np.float64) / 255.0).unsqueeze(0)
        return (scaled - 0.5) * 2

    @classmethod
    def _load(cls, path):
        """Read `path` as grayscale, resize to 256x256 and scale to [-1, 1]."""
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError("Could not read image: " + path)
        image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_LINEAR)
        return cls._to_tensor(image)

    def __getitem__(self, index):
        if self.mode in ('train', 'test'):
            # NOTE: `index` is ignored. Every call draws a random input and a
            # random target of the same slice, so 'test' mode is a random
            # sample rather than a fixed pass over the folder.
            input_path = random.choice(self.img_path_list)
            input_tr, input_te, slice_no = self._parse_name(input_path)
            output_path = random.choice(glob.glob(self.data_path + str(slice_no) + '/*'))
            output_tr, output_te, _ = self._parse_name(output_path)
        elif self.mode == 'detect':
            input_path, output_path = self.DETECT_INPUT_PATH, self.DETECT_OUTPUT_PATH
            input_tr, input_te = self.DETECT_INPUT_TR, self.DETECT_INPUT_TE
            output_tr, output_te = self.DETECT_OUTPUT_TR, self.DETECT_OUTPUT_TE
        else:
            raise ValueError("Unrecognised Mode Detected.")

        input_image = self._load(input_path)
        output_image = self._load(output_path)
        params = torch.tensor([input_te, input_tr, output_te, output_tr])
        # input image, ground truth image, parameters of input image and ground truth image
        return input_image, output_image, params

    def __len__(self):
        if self.mode == 'detect':
            return 1
        return len(self.img_path_list)

In [None]:
def train_collate(batch):
  """Collate (input_image, target_image, params) samples into three stacked tensors.

  Returns:
    (input_images, final_images, params), each stacked along a new dim 0.
  """
  samples = list(batch)
  input_images = torch.stack([sample[0] for sample in samples])
  final_images = torch.stack([sample[1] for sample in samples])
  params = torch.stack([sample[2] for sample in samples])
  return input_images, final_images, params

In [None]:
def show_image(img):
  """Display one (1, H, W) image tensor in [-1, 1] via cv2_imshow.

  Values are mapped to the 0-255 range; the (H, W) shape is printed first.
  """
  pixels = (img.cpu().numpy() / 2 + 0.5).transpose(1, 2, 0).squeeze(-1) * 255.0
  print(pixels.shape)
  cv2_imshow(pixels)

def show_diff_image(req_image,gen_image):
  """Display |generated - ground truth| for one image pair and return its stats.

  Both inputs are single (1, H, W) tensors in [-1, 1]. The absolute difference
  is taken in the 0-255 pixel range. It is min-max stretched to 0-255 for
  display; a constant difference (e.g. identical images) is shown as black
  instead of producing NaNs from 0/0.

  Returns:
    (min, max, mean) of the absolute difference, in 0-255 pixel units.
  """
  def to_pixels(img):
    arr = img.detach().cpu().numpy()
    return (arr / 2 + 0.5).transpose(1, 2, 0).squeeze(-1) * 255.0

  diff_image = np.abs(to_pixels(gen_image) - to_pixels(req_image))
  min_value, max_value, mean_value = np.min(diff_image), np.max(diff_image), np.mean(diff_image)
  print(min_value, max_value, mean_value, diff_image.shape)

  spread = max_value - min_value
  if spread > 0:
    display = (diff_image - min_value) / spread * 255.0
  else:
    display = np.zeros_like(diff_image)
  cv2_imshow(display)
  return min_value, max_value, mean_value

def save_weights(state, step_no, weights_dir=None):
  """Save `state` as ``latest_KParameter_<step_no>.pth.tar``, replacing the previous one.

  The new checkpoint is written *before* the old one is removed, so a failed
  save never leaves the directory without weights. If the old file has the same
  name as the new one, it is simply overwritten.

  Args:
    state: object passed to ``torch.save`` (the notebook passes a dict with
      step, model and optimizer state).
    step_no: step number embedded in the file name.
    weights_dir: directory to use; defaults to ``config.dir + 'weights/'``.

  Raises:
    RuntimeError: if more than one previous ``latest_KParameter*`` file exists.
  """
  if weights_dir is None:
    weights_dir = config.dir + 'weights/'
  previous = glob.glob(os.path.join(weights_dir, 'latest_KParameter*'))
  if len(previous) > 1:
    raise RuntimeError("Multiple weights file, delete others.")

  target = os.path.join(weights_dir, "latest_KParameter_" + str(step_no) + ".pth.tar")
  print("Saving weights as latest_KParameter_"+str(step_no))
  torch.save(state, target)

  for old in previous:
    if os.path.abspath(old) == os.path.abspath(target):
      continue
    # Truncate before deleting, as the original did (presumably so the copy
    # Google Drive keeps in its trash is empty).
    open(old, 'w').close()
    os.remove(old)

def PSNR_and_SSIM(imageA, imageB):
    """Return ``(psnr, mean_ssim)`` between two image batches in [-1, 1].

    Both inputs are (N, C, H, W) tensors. They are mapped to 0-255 and clipped.
    PSNR is computed from the MSE over the whole batch; identical batches (MSE
    of 0) report 100 dB. SSIM is computed per image and averaged over the batch,
    so it stays in [-1, 1] regardless of batch size. For multi-channel images,
    the per-channel SSIMs are averaged. ``data_range`` is passed explicitly
    because the arrays are floats spanning 0-255.
    """
    max_pixel = 255.0

    def to_pixels(batch):
        pixels = batch.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return np.clip((pixels / 2 + 0.5) * 255.0, a_min=0.0, a_max=255.0)

    imageA = to_pixels(imageA)
    imageB = to_pixels(imageB)

    mse = np.mean((imageA - imageB) ** 2)
    # Zero MSE means the images are identical and PSNR is unbounded; use the
    # conventional cap but still return both metrics so callers can unpack them.
    psnr = 100 if mse == 0 else 20 * log10(max_pixel / sqrt(mse))

    if imageA.shape[0] == 0:
        return psnr, 0.0

    ssim_total = 0.0
    for img_a, img_b in zip(imageA, imageB):
        channel_scores = [ssim(img_a[..., c], img_b[..., c], data_range=max_pixel)
                          for c in range(img_a.shape[-1])]
        ssim_total += sum(channel_scores) / len(channel_scores)
    return psnr, ssim_total / imageA.shape[0]

#Training

In [None]:
# initialize autoencoder and param-net
end_model = END().to(device)
param_model = ParamNet().to(device)

param_model.train()
end_model.eval()  # in eval mode END returns its encoder feature maps

config.mode = "train"

# initialize dataloader and optimizer
dataset = MRIDataLoader(config, config.mode)
optimizer_param = optim.Adam(param_model.parameters(), lr=config.lr, betas=(0.5, 0.999))
train_loader = DataLoader(dataset, config.batch_size, shuffle=True, collate_fn=train_collate)

# Re-scan for checkpoints: files may have changed since `config` was built.
param_candidates = glob.glob(config.dir + 'weights/latest_KParameter*')
if len(param_candidates) > 1:
  raise RuntimeError("Multiple weight files detected.")
config.param_weight_path = param_candidates[0] if param_candidates else None
config.resume = config.param_weight_path is not None

end_candidates = glob.glob(config.dir + 'weights/latest_END*')
if len(end_candidates) > 1:
  raise RuntimeError("Multiple END weight files detected.")
if not end_candidates:
  raise FileNotFoundError("No END weight file found")
config.end_weight_path = end_candidates[0]
# map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
checkpoint_end = torch.load(config.end_weight_path, map_location=device)
end_model.load_state_dict(checkpoint_end['end_model'])
print("END weight file found", config.end_weight_path)

# Resume ParamNet (and its optimizer) from the checkpoint if one was found,
# otherwise start from fresh Xavier-initialised weights.
if config.resume:
  checkpoint = torch.load(config.param_weight_path, map_location=device)
  param_model.load_state_dict(checkpoint['param_model'])
  optimizer_param.load_state_dict(checkpoint['optimizer_param'])
  print("Resuming training with", config.param_weight_path)
else:
  param_model.init_weights()
  print("No Param weights Found, Weights initilaized")

In [None]:
# Anomaly detection re-checks every backward op and slows training heavily.
# It is a debugging aid, so it stays off unless switched on here.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Training stops once `step` reaches MAX_STEP (epochs 1-18 from scratch). The
# `<` test also terminates when resuming from a checkpoint past that point.
MAX_STEP = 19
step = checkpoint['step'] + 1 if config.resume else 1
L1 = nn.L1Loss()    # not used by the loop below
MSE = nn.MSELoss()
BCE = nn.BCELoss()  # not used by the loop below
# Channel j of the broadcast parameter map carries params[:, j % 4], i.e. the
# repeating pattern [input TE, input TR, output TE, output TR].
channel_source = torch.arange(512, device=device) % 4
time_last = time.time()

while step < MAX_STEP:
  for i, (input_images, final_images, params) in enumerate(train_loader):

    input_images = input_images.to(device=device, dtype=torch.float32)
    final_images = final_images.to(device=device, dtype=torch.float32)
    params = params.to(device=device, dtype=torch.float32)
    batch = input_images.size(0)

    # (batch, 512, 256, 256) parameter map as a broadcast view, instead of a
    # dense tensor filled one channel at a time.
    param_input = params[:, channel_source].view(batch, 512, 1, 1).expand(batch, 512, 256, 256)

    param_model.zero_grad()

    # Only ParamNet is optimised, so extract END features without a graph.
    with torch.no_grad():
      brain_features = end_model(input_images)

    out_images = param_model(brain_features, param_input)
    loss = MSE(out_images.view(batch, -1), final_images.view(batch, -1))
    loss.backward()
    optimizer_param.step()

    this_time = time.time()
    if i % 10 == 0:
      print("Batch No -", i, "Completed with time", this_time - time_last, ".Loss =", loss.item())
    time_last = time.time()

  # save state of model after every epoch
  print("Epoch ", step, " done.")
  state = {'step': step,
           'param_model': param_model.state_dict(),
           'optimizer_param': optimizer_param.state_dict()}

  if step > 7 or step == 3:
    save_weights(state, step)
  with torch.no_grad():
    print("Input -")
    show_image(input_images[0])
    print("Required -")
    show_image(final_images[0])
    print("Generated -")
    show_image(out_images[0].detach())
    print("Input TE =", params[0, 0], "Input TR =", params[0, 1], "Output TE =", params[0, 2], "Output TR =", params[0, 3])
  step += 1


#Testing

In [None]:
# check for latest weights (evaluation is pinned to the step-16 ParamNet checkpoint)
param_candidates = glob.glob(config.dir + 'weights/latest_KParameter_16*')
if len(param_candidates) > 1:
  raise RuntimeError("Multiple weight files detected.")
if not param_candidates:
  raise FileNotFoundError("No Param Weight file found.")
config.param_weight_path = param_candidates[0]
print("Param weight file found", config.param_weight_path)

end_candidates = glob.glob(config.dir + 'weights/latest_END*')
if len(end_candidates) > 1:
  raise RuntimeError("Multiple END weight files detected.")
if not end_candidates:
  raise FileNotFoundError("No END weight file found")
config.end_weight_path = end_candidates[0]
print("END weight file found", config.end_weight_path)

#initialize models
end_model = END().to(device)
param_model = ParamNet().to(device)

end_model.eval()
param_model.eval()

#load weights; map_location lets GPU-saved checkpoints load on a CPU runtime
checkpoint_end = torch.load(config.end_weight_path, map_location=device)
end_model.load_state_dict(checkpoint_end['end_model'])

checkpoint_param = torch.load(config.param_weight_path, map_location=device)
param_model.load_state_dict(checkpoint_param['param_model'])
final_loss, final_psnr, final_ssim, total_no = 0.0, 0.0, 0.0, 0
psnr_values = []
min_diffs, max_diffs, mean_diffs, tot1 = [], [], [], 0

MSE = nn.MSELoss()
# Channel j of the broadcast parameter map carries params[:, j % 4].
channel_source = torch.arange(512, device=device) % 4

with torch.no_grad():

  #initialize dataloader
  test_data = MRIDataLoader(config, 'test')
  test_loader = DataLoader(test_data, config.batch_size, shuffle=False, collate_fn=train_collate)
  time_last = time.time()
  for i, (input_images, final_images, params) in enumerate(test_loader):

    input_images = input_images.to(device=device, dtype=torch.float32)
    final_images = final_images.to(device=device, dtype=torch.float32)
    params = params.to(device=device, dtype=torch.float32)
    batch = input_images.size(0)

    # (batch, 512, 256, 256) parameter map as a broadcast view
    param_input = params[:, channel_source].view(batch, 512, 1, 1).expand(batch, 512, 256, 256)

    # extract image features of input image from auto-encoder and feed them to param-net along with the parameters
    brain_features = end_model(input_images)
    out_images = param_model(brain_features, param_input)
    loss = MSE(out_images.view(batch, -1), final_images.view(batch, -1))

    # calculate psnr and ssim
    psnr_temp, ssim_temp = PSNR_and_SSIM(final_images, out_images)
    psnr_values.append(psnr_temp)
    final_psnr += psnr_temp
    final_ssim += ssim_temp
    final_loss += loss.item()
    total_no += 1

    this_time = time.time()
    print("Batch No -", i, "Completed with time", this_time - time_last, ".Loss =", loss.item())
    time_last = time.time()

    # display input image, ground truth image, generated image and difference image (|ground truth - generated image|)
    if i % 10 == 0:
      print("Input -")
      show_image(input_images[0])
      print("Required -")
      show_image(final_images[0])
      print("Generated -")
      show_image(out_images[0])
      print("Differnce Image -")
      curr_min, curr_max, curr_mean = show_diff_image(final_images[0], out_images[0])
      min_diffs.append(curr_min)
      max_diffs.append(curr_max)
      mean_diffs.append(curr_mean)
      tot1 += 1
      print("Input TE =", params[0, 0], "Input TR =", params[0, 1])
      print("Output TE =", params[0, 2], "Output TR =", params[0, 3])
      # Each params row holds only the four TE/TR values, so there is no
      # slice number to print (indexing params[0, 4] raised IndexError).

if total_no == 0:
  raise RuntimeError("The test loader produced no batches.")

# display average loss, psnr and ssim (averaged over batches), and avg pixel differences
final_loss /= total_no
final_psnr /= total_no
final_ssim /= total_no
print("Final Loss =", final_loss, "Final PSNR =", final_psnr, "Final SSIM =", final_ssim)
print("PSNR Array =", psnr_values)
# tot1 >= 1 here because batch 0 always takes the display branch.
avg_min_diff = np.sum(min_diffs) / tot1
avg_max_diff = np.sum(max_diffs) / tot1
avg_mean_diff = np.sum(mean_diffs) / tot1
print("Avg Min Difference =", avg_min_diff, "Avg Max Difference =", avg_max_diff, "Avg Mean Difference =", avg_mean_diff)

# display mean and (population) std dev of the mean pixel differences
mean = np.mean(mean_diffs)
res = np.std(mean_diffs)
print("Mean =", mean, "Res =", res)

# display mean and (population) std dev of psnr
mean = np.mean(psnr_values)
res = np.std(psnr_values)
print("Mean =", mean, "Res =", res)

#Detection

In [None]:
# check for latest weights (detection uses the step-16 ParamNet checkpoint)
param_candidates = glob.glob(config.dir + 'weights/latest_KParameter_16*')
if len(param_candidates) > 1:
  raise RuntimeError("Multiple weight files detected.")
if not param_candidates:
  raise FileNotFoundError("No Param Weight file found.")
config.param_weight_path = param_candidates[0]
print("Param weight file found", config.param_weight_path)

end_candidates = glob.glob(config.dir + 'weights/latest_END*')
if len(end_candidates) > 1:
  raise RuntimeError("Multiple END weight files detected.")
if not end_candidates:
  raise FileNotFoundError("No END weight file found")
config.end_weight_path = end_candidates[0]
print("END weight file found", config.end_weight_path)

#initialize models
end_model = END().to(device)
param_model = ParamNet().to(device)

end_model.eval()
param_model.eval()

#load weights; map_location lets GPU-saved checkpoints load on a CPU runtime
checkpoint_end = torch.load(config.end_weight_path, map_location=device)
end_model.load_state_dict(checkpoint_end['end_model'])

checkpoint_param = torch.load(config.param_weight_path, map_location=device)
param_model.load_state_dict(checkpoint_param['param_model'])
final_loss, final_psnr, final_ssim, total_no = 0.0, 0.0, 0.0, 0
psnr_values = []

MSE = nn.MSELoss()
# Channel j of the broadcast parameter map carries params[:, j % 4].
channel_source = torch.arange(512, device=device) % 4

with torch.no_grad():

  #initialize dataloader ('detect' mode yields one fixed pair, so no shuffling)
  test_data = MRIDataLoader(config, 'detect')
  test_loader = DataLoader(test_data, config.batch_size, shuffle=False, collate_fn=train_collate)
  time_last = time.time()
  for i, (input_images, final_images, params) in enumerate(test_loader):

    input_images = input_images.to(device=device, dtype=torch.float32)
    final_images = final_images.to(device=device, dtype=torch.float32)
    params = params.to(device=device, dtype=torch.float32)
    batch = input_images.size(0)

    # (batch, 512, 256, 256) parameter map as a broadcast view
    param_input = params[:, channel_source].view(batch, 512, 1, 1).expand(batch, 512, 256, 256)

    # extract image features of input image from auto-encoder and feed them to param-net along with the parameters
    brain_features = end_model(input_images)
    out_images = param_model(brain_features, param_input)
    loss = MSE(out_images.view(batch, -1), final_images.view(batch, -1))

    # calculate psnr and ssim
    psnr_temp, ssim_temp = PSNR_and_SSIM(final_images, out_images)
    psnr_values.append(psnr_temp)
    final_psnr += psnr_temp
    final_ssim += ssim_temp
    final_loss += loss.item()
    total_no += 1

    this_time = time.time()
    print("Batch No -", i, "Completed with time", this_time - time_last, ".Loss =", loss.item())
    time_last = time.time()

    # display input image, ground truth image and generated image
    if i % 10 == 0:
      print("Input -")
      show_image(input_images[0])
      print("Required -")
      show_image(final_images[0])
      print("Generated -")
      show_image(out_images[0])
      print("TE =", params[0, 0], "TR =", params[0, 1])

if total_no == 0:
  raise RuntimeError("The detect loader produced no batches.")

final_loss /= total_no
final_psnr /= total_no
final_ssim /= total_no
print("Final Loss =", final_loss, "Final PSNR =", final_psnr, "Final SSIM =", final_ssim)
print("PSNR Array =", psnr_values)
