# In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from lib.utils.transform_3d import *
import torch
import numpy as np


def tensor_backto_unnormalization_image(input_image, mean, std):
  '''
  Undo a mean/std (Gaussian) normalization and return the result as numpy.

  Computes ``image * std + mean`` on a detached, CPU, float32 copy of the
  tensor. Non-tensor input is returned unchanged (no de-normalization).

  :param input_image: torch.Tensor of any shape; anything else is passed through
  :param mean: mean used by the forward normalization (scalar or broadcastable)
  :param std: standard deviation used by the forward normalization
  :return: numpy.ndarray with the same shape as ``input_image``
  '''
  if not isinstance(input_image, torch.Tensor):
    return input_image
  # detach() drops autograd history so .numpy() is allowed; .float() forces float32.
  image = input_image.detach().cpu().float().numpy()
  return image * std + mean


class CT_XRAY_Data_Augmentation(object):
  '''
  Preprocessing pipeline for a (CT, X-ray) pair.

  Each stage is a pair ``(ct_transform, xray_transform)`` handed to
  ``List_Compose``; a ``None`` slot presumably leaves that input untouched
  (confirm against ``List_Compose`` in lib.utils.transform_3d).
  Stage order: Permute (CT only), Resize_image to ``fine_size``,
  Limit_Min_Max_Threshold (CT only), Normalization with the *_MIN_MAX
  ranges, Normalization_gaussian with the *_MEAN_STD values, ToTensor.
  '''

  def __init__(self, opt=None):
    # opt must provide ct_channel, xray_channel, fine_size, CT_MIN_MAX,
    # XRAY1_MIN_MAX, CT_MEAN_STD and XRAY1_MEAN_STD.
    ct_shape = (opt.ct_channel, opt.fine_size, opt.fine_size)
    xray_shape = (opt.xray_channel, opt.fine_size, opt.fine_size)
    stages = [
      (Permute((1, 0, 2)), None),
      (Resize_image(size=ct_shape), Resize_image(size=xray_shape)),
      (Limit_Min_Max_Threshold(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]), None),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None),  # currently disabled
      (ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img_list):
    '''
    Apply the pipeline to ``img_list`` ([ct, xray]).

    :return: whatever ``List_Compose`` returns; callers in this notebook
      unpack it as ``ct, xray``.
    '''
    return self.augment(img_list)

class CT_XRAY_Data_Test(object):
  '''
  Test-time preprocessing for a (CT, X-ray) pair.

  Same stage list as CT_XRAY_Data_Augmentation: each entry is a pair
  ``(ct_transform, xray_transform)`` for ``List_Compose``, with ``None``
  presumably meaning "skip this input" (confirm in lib.utils.transform_3d).
  Stage order: Permute (CT only), Resize_image to ``fine_size``,
  Limit_Min_Max_Threshold (CT only), Normalization (*_MIN_MAX),
  Normalization_gaussian (*_MEAN_STD), ToTensor.
  '''

  def __init__(self, opt=None):
    ct_shape = (opt.ct_channel, opt.fine_size, opt.fine_size)
    xray_shape = (opt.xray_channel, opt.fine_size, opt.fine_size)
    stages = [
      (Permute((1, 0, 2)), None),
      (Resize_image(size=ct_shape), Resize_image(size=xray_shape)),
      (Limit_Min_Max_Threshold(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]), None),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None),  # currently disabled
      (ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img):
    '''
    Apply the pipeline to ``img`` ([ct, xray]).

    :return: whatever ``List_Compose`` returns; callers in this notebook
      unpack it as ``ct, xray``.
    '''
    return self.augment(img)

class CT_XRAY_Data_AugmentationM(object):
  '''
  Training preprocessing for a (CT, X-ray) pair with random crop and mirror.

  Inputs are resized to ``resize_size`` and then randomly cropped to
  ``fine_size``. The crop and mirror stages are single-element tuples, which
  ``List_Compose`` presumably applies to the whole list jointly (confirm in
  lib.utils.transform_3d).
  NOTE(review): unlike CT_XRAY_Data_Augmentation there is no
  Limit_Min_Max_Threshold stage here -- confirm this is intended.
  '''

  def __init__(self, opt=None):
    ct_shape = (opt.ct_channel, opt.resize_size, opt.resize_size)
    xray_shape = (opt.xray_channel, opt.resize_size, opt.resize_size)
    stages = [
      (Permute((1, 0, 2)), None),
      (Resize_image(size=ct_shape), Resize_image(size=xray_shape)),
      (List_Random_cropYX(size=(opt.fine_size, opt.fine_size)),),
      (List_Random_mirror(2),),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None),  # currently disabled
      (ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img_list):
    '''
    Apply the pipeline to ``img_list`` ([ct, xray]).

    :return: whatever ``List_Compose`` returns.
    '''
    return self.augment(img_list)

class CT_XRAY_Data_TestM(object):
  '''
  Test-time counterpart of CT_XRAY_Data_AugmentationM.

  No random crop or mirror: both inputs are resized straight to
  ``fine_size``. Stage order: Permute (CT only), Resize_image,
  Normalization (*_MIN_MAX), Normalization_gaussian (*_MEAN_STD), ToTensor.
  '''

  def __init__(self, opt=None):
    ct_shape = (opt.ct_channel, opt.fine_size, opt.fine_size)
    xray_shape = (opt.xray_channel, opt.fine_size, opt.fine_size)
    stages = [
      (Permute((1, 0, 2)), None),
      (Resize_image(size=ct_shape), Resize_image(size=xray_shape)),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None),  # currently disabled
      (ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img):
    '''
    Apply the pipeline to ``img`` ([ct, xray]).

    :return: whatever ``List_Compose`` returns.
    '''
    return self.augment(img)

class CT_XRAY_Data_Augmentation_Multi(object):
  '''
  Preprocessing for a (CT, X-ray view 1, X-ray view 2) triple.

  Each stage is a triple ``(ct_t, xray1_t, xray2_t)`` for ``List_Compose``.
  Both X-ray views are resized to ``(xray_channel, fine_size, fine_size)``.
  View 1 uses the XRAY1_* statistics and view 2 uses the XRAY2_* statistics.
  '''

  def __init__(self, opt=None):
    ct_shape = (opt.ct_channel, opt.fine_size, opt.fine_size)
    xray_shape = (opt.xray_channel, opt.fine_size, opt.fine_size)
    stages = [
      (Permute((1, 0, 2)), None, None),
      (Resize_image(size=ct_shape),
       Resize_image(size=xray_shape),
       Resize_image(size=xray_shape)),
      (Limit_Min_Max_Threshold(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]), None, None),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1]),
       Normalization(opt.XRAY2_MIN_MAX[0], opt.XRAY2_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY2_MEAN_STD[0], opt.XRAY2_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None, None),  # currently disabled
      (ToTensor(), ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img_list):
    '''
    Apply the pipeline to ``img_list`` ([ct, xray1, xray2]).

    :return: whatever ``List_Compose`` returns.
    '''
    return self.augment(img_list)

class CT_XRAY_Data_Test_Multi(object):
  '''
  Test-time preprocessing for a (CT, X-ray view 1, X-ray view 2) triple.

  Same stage list as CT_XRAY_Data_Augmentation_Multi: each entry is a
  triple ``(ct_t, xray1_t, xray2_t)`` passed to ``List_Compose``.
  '''

  def __init__(self, opt=None):
    ct_shape = (opt.ct_channel, opt.fine_size, opt.fine_size)
    xray_shape = (opt.xray_channel, opt.fine_size, opt.fine_size)
    stages = [
      (Permute((1, 0, 2)), None, None),
      (Resize_image(size=ct_shape),
       Resize_image(size=xray_shape),
       Resize_image(size=xray_shape)),
      (Limit_Min_Max_Threshold(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]), None, None),
      (Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
       Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1]),
       Normalization(opt.XRAY2_MIN_MAX[0], opt.XRAY2_MIN_MAX[1])),
      (Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1]),
       Normalization_gaussian(opt.XRAY2_MEAN_STD[0], opt.XRAY2_MEAN_STD[1])),
      # (Get_Key_slice(opt.select_slice_num), None, None),  # currently disabled
      (ToTensor(), ToTensor(), ToTensor()),
    ]
    self.augment = List_Compose(stages)

  def __call__(self, img):
    '''
    Apply the pipeline to ``img`` ([ct, xray1, xray2]).

    :return: whatever ``List_Compose`` returns.
    '''
    return self.augment(img)


'''
Data Augmentation
'''
class CT_Data_Augmentation(object):
  '''
  CT-only preprocessing chained with ``Compose``.

  Stage order: Permute((1, 0, 2)), Normalization (CT_MIN_MAX),
  Normalization_gaussian (CT_MEAN_STD), Get_Key_slice(select_slice_num),
  ToTensor.
  '''

  def __init__(self, opt=None):
    stages = [
      Permute((1, 0, 2)),
      Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
      Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
      Get_Key_slice(opt.select_slice_num),
      ToTensor(),
    ]
    self.augment = Compose(stages)

  def __call__(self, img):
    '''
    Apply the CT pipeline to ``img``.

    :return: whatever ``Compose`` returns (the last stage is ToTensor).
    '''
    return self.augment(img)

class Xray_Data_Augmentation(object):
  '''
  X-ray-only preprocessing chained with ``Compose``.

  Stage order: Normalization (XRAY1_MIN_MAX), Normalization_gaussian
  (XRAY1_MEAN_STD), ToTensor.
  '''

  def __init__(self, opt=None):
    stages = [
      Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1]),
      Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1]),
      ToTensor(),
    ]
    self.augment = Compose(stages)

  def __call__(self, img):
    '''
    Apply the X-ray pipeline to ``img``.

    :return: whatever ``Compose`` returns (the last stage is ToTensor).
    '''
    return self.augment(img)

class CT_Data_Test(object):
  '''
  Test-time CT-only preprocessing; same stages as CT_Data_Augmentation.

  Stage order: Permute((1, 0, 2)), Normalization (CT_MIN_MAX),
  Normalization_gaussian (CT_MEAN_STD), Get_Key_slice(select_slice_num),
  ToTensor.
  '''

  def __init__(self, opt=None):
    stages = [
      Permute((1, 0, 2)),
      Normalization(opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]),
      Normalization_gaussian(opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1]),
      Get_Key_slice(opt.select_slice_num),
      ToTensor(),
    ]
    self.augment = Compose(stages)

  def __call__(self, img):
    '''
    Apply the CT pipeline to ``img``.

    :return: whatever ``Compose`` returns (the last stage is ToTensor).
    '''
    return self.augment(img)

class Xray_Data_Test(object):
  '''
  Test-time X-ray-only preprocessing; same stages as Xray_Data_Augmentation.

  Stage order: Normalization (XRAY1_MIN_MAX), Normalization_gaussian
  (XRAY1_MEAN_STD), ToTensor.
  '''

  def __init__(self, opt=None):
    stages = [
      Normalization(opt.XRAY1_MIN_MAX[0], opt.XRAY1_MIN_MAX[1]),
      Normalization_gaussian(opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1]),
      ToTensor(),
    ]
    self.augment = Compose(stages)

  def __call__(self, img):
    '''
    Apply the X-ray pipeline to ``img``.

    :return: whatever ``Compose`` returns (the last stage is ToTensor).
    '''
    return self.augment(img)

# ##########################################
# Test
# ##########################################




# In [None]:

test_file = r"C:\Users\Babak\Desktop\X-ray_to_CT_project\X2CT\3DGAN\data\LIDC-HDF5-256\1.3.6.1.4.1.9328.50.4.0368.nii\ct_xray_data.h5"
import h5py
import matplotlib.pyplot as plt

from lib.config.config import cfg, merge_dict_and_yaml
opt = merge_dict_and_yaml(dict(), cfg)

# Visual sanity check of the preprocessing on one HDF5 sample:
# 1. run the training pipeline and the test pipeline;
# 2. undo the mean/std normalization;
# 3. rescale with Normalization_to_range;
# 4. show the result with OpenCV and matplotlib.

# Materialize the datasets inside a ``with`` block so the HDF5 handle is
# closed afterwards instead of staying open for the rest of the session.
with h5py.File(test_file, 'r') as hdf:
  ct = np.asarray(hdf['ct'])
  xray = np.asarray(hdf['xray1'])
# Add a leading axis (the resize target for X-rays is (xray_channel, H, W)).
xray = np.expand_dims(xray, 0)
print(xray.shape)

SLICE_INDEX = 80  # CT slice (index along the first axis) shown below

transforma = CT_XRAY_Data_Augmentation(opt)
transform_normal = CT_XRAY_Data_Test(opt)
ct_normal, xray_normal = transform_normal([ct, xray])
ct_trans, xray_trans = transforma([ct, xray])
ct_trans = tensor_backto_unnormalization_image(ct_trans, opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1])
xray_trans = tensor_backto_unnormalization_image(xray_trans, opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])
ct_normal = tensor_backto_unnormalization_image(ct_normal, opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1])
xray_normal = tensor_backto_unnormalization_image(xray_normal, opt.XRAY1_MEAN_STD[0], opt.XRAY1_MEAN_STD[1])
bb = Normalization_to_range()
ct_trans = bb(ct_trans)
xray_trans = bb(xray_trans)

import cv2
print(ct_trans.shape, ct_normal.shape)
cv2.imshow('1', xray_trans[0].astype(np.uint8))
cv2.imshow('2', ct_trans[SLICE_INDEX, :, :].astype(np.uint8))
cv2.imshow('1-1', bb(xray_normal)[0].astype(np.uint8))
cv2.imshow('2-1', bb(ct_normal)[SLICE_INDEX, :, :].astype(np.uint8))
cv2.waitKey(0)
# Release the OpenCV windows; otherwise they linger in the notebook process.
cv2.destroyAllWindows()
plt.figure(1)
plt.imshow(xray_trans[0], cmap=plt.cm.bone)
plt.figure(2)
plt.imshow(ct_trans[SLICE_INDEX, :, :], cmap=plt.cm.bone)
plt.show()

# In [None]:
import argparse
from lib.config.config import cfg_from_yaml, cfg, merge_dict_and_yaml, print_easy_dict
from lib.dataset.factory import get_dataset
from lib.model.factory import get_model
import copy
import torch
import time
import os


def create_args():
  '''
  Build the hard-coded training arguments.

  Stands in for command-line parsing when training from the notebook.

  :return: argparse.Namespace with the fields read by ``main()``
  '''
  settings = dict(
    data='LIDC256',
    tag='d2_multiview2500',
    dataroot='./data/LIDC-HDF5-256',
    dataset='train',
    valid_dataset='test',
    datasetfile='./data/train.txt',
    valid_datasetfile='./data/test.txt',
    ymlpath='./experiment/multiview2500/d2_multiview2500.yml',
    gpuid='0',
    dataset_class='align_ct_xray_views_std',
    model_class='MultiViewCTGAN',
    check_point=None,
    load_path=None,
    latest=False,
    verbose=False,
  )
  return argparse.Namespace(**settings)

def _next_batch(batch_iter, dataloader):
  '''
  Return ``(batch, iterator)`` holding the next batch from ``batch_iter``.

  When ``batch_iter`` is exhausted, a fresh iterator over ``dataloader`` is
  created and its first batch is returned, so the caller does not lose a
  step. Any other exception propagates to the caller.
  '''
  try:
    return next(batch_iter), batch_iter
  except StopIteration:
    batch_iter = iter(dataloader)
    return next(batch_iter), batch_iter


def main():
  '''
  Train the model configured by ``create_args()`` and its YAML file.

  Steps:
  1. Resolve GPU ids.
  2. Merge the YAML config and the arguments into ``opt``.
  3. Build the training dataloader, plus a validation dataloader when
     ``valid_dataset`` is set.
  4. Create the model with ``get_model(opt.model_class)``.
  5. For every epoch in ``[opt.epoch_count, opt.niter + opt.niter_decay]``:
     - call ``optimize_parameters()`` once per batch;
     - call ``optimize_D()`` ``opt.critic_times - 1`` extra times on further
       batches;
     - save every ``opt.save_epoch_freq`` epochs (from
       ``opt.begin_save_epoch`` on);
     - update the learning rate.
  '''
  args = create_args()

  # check gpu
  if args.gpuid == '':
    args.gpu_ids = []
  elif torch.cuda.is_available():
    args.gpu_ids = [int(i) for i in str(args.gpuid).split(',')]
  else:
    print('There is no gpu!')
    exit(0)

  # check point: continue counting after the given epoch
  if args.check_point is None:
    args.epoch_count = 1
  else:
    args.epoch_count = int(args.check_point) + 1

  # merge config with yaml
  if args.ymlpath is not None:
    cfg_from_yaml(args.ymlpath)
  # merge config with argparse
  opt = copy.deepcopy(cfg)
  opt = merge_dict_and_yaml(args.__dict__, opt)
  print_easy_dict(opt)

  # add data_augmentation
  datasetClass, augmentationClass, dataTestClass, collateClass = get_dataset(opt.dataset_class)
  opt.data_augmentation = augmentationClass

  # valid dataset (built here; the per-epoch validation pass is currently disabled)
  if args.valid_dataset is not None:
    valid_opt = copy.deepcopy(opt)
    valid_opt.data_augmentation = dataTestClass
    valid_opt.datasetfile = opt.valid_datasetfile

    valid_dataset = datasetClass(valid_opt)
    print('Valid DataSet is {}'.format(valid_dataset.name))
    valid_dataloader = torch.utils.data.DataLoader(
      valid_dataset,
      batch_size=1,
      shuffle=False,
      num_workers=int(valid_opt.nThreads),
      collate_fn=collateClass)
    valid_dataset_size = len(valid_dataloader)
    print('#validation images = %d' % valid_dataset_size)
  else:
    valid_dataloader = None

  # get dataset
  dataset = datasetClass(opt)
  print('DataSet is {}'.format(dataset.name))
  dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=int(opt.nThreads),
    collate_fn=collateClass)

  dataset_size = len(dataloader)
  print('#training images = %d' % dataset_size)

  # get model
  gan_model = get_model(opt.model_class)()
  print('Model --{}-- will be Used'.format(gan_model.name))
  gan_model.init_process(opt)
  # NOTE(review): the epoch value returned by setup() is ignored; the loop
  # below starts from opt.epoch_count -- confirm this is intended.
  total_steps, _ = gan_model.setup(opt)

  # set to train
  gan_model.train()

  # visualizer
  from lib.utils.visualizer import Visualizer
  visualizer = Visualizer(log_dir=os.path.join(gan_model.save_root, 'train_log'))

  # Separate iterator that feeds the extra discriminator (critic) steps.
  dataloader_iter_for_discriminator = iter(dataloader)

  last_epoch = opt.niter + opt.niter_decay
  for epoch in range(opt.epoch_count, last_epoch + 1):
    epoch_start_time = time.time()

    for epoch_i, data in enumerate(dataloader):
      total_steps += 1

      gan_model.set_input(data)
      t0 = time.time()
      gan_model.optimize_parameters()
      t1 = time.time()

      # loss
      loss_dict = gan_model.get_current_losses()
      total_loss = visualizer.add_total_scalar('Total loss', loss_dict, step=total_steps)

      if total_steps % opt.print_freq == 0:
        print('total step: {} timer: {:.4f} sec.'.format(total_steps, t1 - t0))
        print('epoch {}/{}, step{}:{} || total loss:{:.4f}'.format(epoch, last_epoch,
                                                                   epoch_i, dataset_size, total_loss))
        print('||'.join(['{}: {:.4f}'.format(k, v) for k, v in loss_dict.items()]))
        print('')

      # WGAN: (critic_times - 1) extra discriminator updates on further
      # batches. Only iterator exhaustion is handled (by restarting the
      # iterator); real training errors propagate instead of being swallowed.
      for _ in range(opt.critic_times - 1):
        data, dataloader_iter_for_discriminator = _next_batch(
          dataloader_iter_for_discriminator, dataloader)
        gan_model.set_input(data)
        gan_model.optimize_D()
      del loss_dict

    # save model several epoch
    if epoch % opt.save_epoch_freq == 0 and epoch >= opt.begin_save_epoch:
      print('saving the model at the end of epoch %d, iters %d' %
            (epoch, total_steps))
      gan_model.save_networks(epoch, total_steps)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, last_epoch, time.time() - epoch_start_time))
    gan_model.update_learning_rate(epoch)


if __name__ == '__main__':
  main()

# In [None]:
import argparse
from lib.config.config import cfg_from_yaml, cfg, merge_dict_and_yaml, print_easy_dict
from lib.dataset.factory import get_dataset
from lib.model.factory import get_model
from lib.utils import html
from lib.utils.visualizer import tensor_back_to_unnormalization, save_images, tensor_back_to_unMinMax
#from lib.utils.metrics_np import MAE, MSE, Peak_Signal_to_Noise_Rate, Structural_Similarity, Cosine_Similarity
from lib.utils import ct as CT
import copy
import tqdm
import torch
import numpy as np
import os

class Args:
  '''Empty attribute bag used in place of parsed command-line arguments.'''

if __name__ == '__main__':
  # Test-time export. For the first ``how_many`` test samples, save:
  # - an HTML image page;
  # - the real and generated CT volumes as .mha files under ``resultdir``.

  # Hard-coded arguments (stand-in for parse_args()).
  args = Args()
  args.data = 'LIDC256'
  args.tag = 'd2_multiview2500'
  args.dataroot = './data/LIDC-HDF5-256'
  args.dataset = 'test'
  args.datasetfile = './data/test.txt'
  args.ymlpath = './experiment/multiview2500/d2_multiview2500.yml'
  args.gpuid = '0'
  args.dataset_class = 'align_ct_xray_views_std'
  args.model_class = 'MultiViewCTGAN'
  args.check_point = '100'
  args.latest = False
  args.verbose = False
  args.load_path = None
  args.how_many = 50
  args.resultdir = './results'

  # check gpu
  if args.gpuid == '':
    args.gpu_ids = []
  elif torch.cuda.is_available():
    args.gpu_ids = [int(i) for i in str(args.gpuid).split(',')]
  else:
    print('There is no gpu!')
    exit(0)

  # check point
  if args.check_point is None:
    args.epoch_count = 1
  else:
    args.epoch_count = int(args.check_point)

  # merge config with yaml
  if args.ymlpath is not None:
    cfg_from_yaml(args.ymlpath)
  # merge config with argparse
  opt = copy.deepcopy(cfg)
  opt = merge_dict_and_yaml(args.__dict__, opt)
  print_easy_dict(opt)

  opt.serial_batches = True

  # add data_augmentation
  datasetClass, _, dataTestClass, collateClass = get_dataset(opt.dataset_class)
  opt.data_augmentation = dataTestClass

  # get dataset
  dataset = datasetClass(opt)
  print('DataSet is {}'.format(dataset.name))
  dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=int(opt.nThreads),
    collate_fn=collateClass)

  dataset_size = len(dataloader)
  print('#Test images = %d' % dataset_size)

  # get model
  gan_model = get_model(opt.model_class)()
  print('Model --{}-- will be Used'.format(gan_model.name))

  # set to test
  gan_model.eval()

  gan_model.init_process(opt)
  total_steps, epoch_count = gan_model.setup(opt)

  # Set test mode again: after setup, model.training may be False while
  # individual layers (e.g. BN) are still in training mode.
  if opt.verbose:
    print('## Model Mode: {}'.format('Training' if gan_model.training else 'Testing'))
    for i, v in gan_model.named_modules():
      print(i, v.training)

  if 'batch' in opt.norm_G:
    gan_model.eval()
  elif 'instance' in opt.norm_G:
    gan_model.eval()
    # instance norm in training mode is better
    for module_name, m in gan_model.named_modules():
      if m.__class__.__name__.startswith('InstanceNorm'):
        m.train()
  else:
    raise NotImplementedError()

  if opt.verbose:
    print('## Change to Model Mode: {}'.format('Training' if gan_model.training else 'Testing'))
    for i, v in gan_model.named_modules():
      print(i, v.training)

  web_dir = os.path.join(opt.resultdir, opt.data, '%s_%s' % (opt.dataset, opt.check_point))
  webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.data, opt.dataset, opt.check_point))
  ctVisual = CT.CTVisual()

  for epoch_i, data in tqdm.tqdm(enumerate(dataloader)):
    # Export exactly ``how_many`` samples (indices 0 .. how_many - 1) and stop
    # before running the model on the next one.
    if epoch_i >= opt.how_many:
      break

    gan_model.set_input(data)
    gan_model.test()

    visuals = gan_model.get_current_visuals()
    img_path = gan_model.get_image_paths()

    # Image HTML
    save_images(webpage, visuals, img_path, gan_model.get_normalization_list(), max_image=50)

    # CT Source
    generate_CT = visuals['G_fake'].data.clone().cpu().numpy()
    real_CT = visuals['G_real'].data.clone().cpu().numpy()
    # To NDHW
    if 'std' in opt.dataset_class or 'baseline' in opt.dataset_class:
      generate_CT_transpose = generate_CT
      real_CT_transpose = real_CT
    else:
      generate_CT_transpose = np.transpose(generate_CT, (0, 2, 1, 3))
      real_CT_transpose = np.transpose(real_CT, (0, 2, 1, 3))
    # Reverse the depth axis
    generate_CT_transpose = generate_CT_transpose[:, ::-1, :, :]
    real_CT_transpose = real_CT_transpose[:, ::-1, :, :]
    # To [0, 1]
    generate_CT_transpose = tensor_back_to_unnormalization(generate_CT_transpose, opt.CT_MEAN_STD[0],
                                                           opt.CT_MEAN_STD[1])
    real_CT_transpose = tensor_back_to_unnormalization(real_CT_transpose, opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1])
    # Clip generate_CT
    generate_CT_transpose = np.clip(generate_CT_transpose, 0, 1)

    # To HU coordinate: undo the CT_MIN_MAX scaling, then subtract 1024
    generate_CT_transpose = tensor_back_to_unMinMax(generate_CT_transpose, opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]).astype(np.int32) - 1024
    real_CT_transpose = tensor_back_to_unMinMax(real_CT_transpose, opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]).astype(np.int32) - 1024
    # Save
    name1 = os.path.splitext(os.path.basename(img_path[0][0]))[0]
    name2 = os.path.split(os.path.dirname(img_path[0][0]))[-1]
    name = name2 + '_' + name1
    image_root = os.path.join(web_dir, 'CT', name)
    if not os.path.exists(image_root):
      os.makedirs(image_root)
    save_path = os.path.join(image_root, 'fake_ct.mha')
    ctVisual.save(generate_CT_transpose.squeeze(0), spacing=(1.0, 1.0, 1.0), origin=(0,0,0), path=save_path)
    save_path = os.path.join(image_root, 'real_ct.mha')
    ctVisual.save(real_CT_transpose.squeeze(0), spacing=(1.0, 1.0, 1.0), origin=(0, 0, 0), path=save_path)

    del visuals, img_path
  webpage.save()



# In [None]:
import argparse
from lib.config.config import cfg_from_yaml, cfg, merge_dict_and_yaml, print_easy_dict
from lib.dataset.factory import get_dataset
from lib.model.factory import get_model
from lib.utils.visualizer import tensor_back_to_unnormalization, tensor_back_to_unMinMax
from lib.utils.metrics_np import MAE, MSE, Peak_Signal_to_Noise_Rate, Structural_Similarity, Cosine_Similarity, \
  Peak_Signal_to_Noise_Rate_3D
import copy
import tqdm
import torch
import numpy as np
import os


class Args:
  '''Empty attribute bag used in place of parsed command-line arguments.'''

if __name__ == '__main__':
  # Hard-coded evaluation arguments (stand-in for parse_args()). They are
  # consumed by the ``evaluate(args)`` call at the end of this cell.
  args = Args()
  args.data = 'LIDC256'
  args.tag = 'd2_multiview2500'
  args.dataroot = './data/LIDC-HDF5-256'
  args.dataset = 'test'
  args.datasetfile = './data/test.txt'
  args.ymlpath = './experiment/multiview2500/d2_multiview2500.yml'
  args.gpuid = '0'
  args.dataset_class = 'align_ct_xray_views_std'
  args.model_class = 'MultiViewCTGAN'
  args.check_point = '100'
  args.latest = False
  args.verbose = False
  args.load_path = None
  args.how_many = 50
  args.resultdir = './multiview'




def evaluate(args):
  '''
  Run the configured model over the test split and report reconstruction metrics.

  For every test sample the generated CT (``G_fake``) is compared with the
  ground truth (``G_real``) in two spaces:

  - After undoing the mean/std normalization and clipping the fake to [0, 1]:
    MAE0, MSE0, cosine similarity and SSIM (PIXEL_MAX=1.0).
  - After ``tensor_back_to_unMinMax`` with CT_MIN_MAX, cast to int32: MAE,
    MSE, PSNR (as returned by Peak_Signal_to_Noise_Rate) and 3-D PSNR, both
    with PIXEL_MAX=4095.

  Samples whose cosine similarity is NaN or greater than 1 are printed and
  skipped.

  :param args: attribute container with data, tag, dataroot, dataset,
    datasetfile, ymlpath, gpuid, dataset_class, model_class, check_point,
    latest, verbose, load_path, how_many and resultdir
  :return: dict mapping metric name -> mean over all kept samples
  '''
  # check gpu
  if args.gpuid == '':
    args.gpu_ids = []
  elif torch.cuda.is_available():
    args.gpu_ids = [int(i) for i in str(args.gpuid).split(',')]
  else:
    print('There is no gpu!')
    exit(0)

  # check point
  if args.check_point is None:
    args.epoch_count = 1
  else:
    args.epoch_count = int(args.check_point)

  # merge config with yaml
  if args.ymlpath is not None:
    cfg_from_yaml(args.ymlpath)
  # merge config with argparse
  opt = copy.deepcopy(cfg)
  opt = merge_dict_and_yaml(args.__dict__, opt)
  print_easy_dict(opt)

  opt.serial_batches = True

  # add data_augmentation
  datasetClass, _, dataTestClass, collateClass = get_dataset(opt.dataset_class)
  opt.data_augmentation = dataTestClass

  # get dataset
  dataset = datasetClass(opt)
  print('DataSet is {}'.format(dataset.name))
  dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=int(opt.nThreads),
    collate_fn=collateClass)

  dataset_size = len(dataloader)
  print('#Test images = %d' % dataset_size)

  # get model
  gan_model = get_model(opt.model_class)()
  print('Model --{}-- will be Used'.format(gan_model.name))

  # set to test
  gan_model.eval()

  gan_model.init_process(opt)
  total_steps, epoch_count = gan_model.setup(opt)

  # Set test mode again: after setup, model.training may be False while
  # individual layers (e.g. BN) are still in training mode.
  if opt.verbose:
    print('## Model Mode: {}'.format('Training' if gan_model.training else 'Testing'))
    for i, v in gan_model.named_modules():
      print(i, v.training)

  if 'batch' in opt.norm_G:
    gan_model.eval()
  elif 'instance' in opt.norm_G:
    gan_model.eval()
    # instance norm in training mode is better
    for module_name, m in gan_model.named_modules():
      if m.__class__.__name__.startswith('InstanceNorm'):
        m.train()
  else:
    raise NotImplementedError()

  if opt.verbose:
    print('## Change to Model Mode: {}'.format('Training' if gan_model.training else 'Testing'))
    for i, v in gan_model.named_modules():
      print(i, v.training)

  result_dir = os.path.join(opt.resultdir, opt.data, '%s_%s' % (opt.dataset, opt.check_point))
  if not os.path.exists(result_dir):
    os.makedirs(result_dir)

  avg_dict = dict()
  for epoch_i, data in tqdm.tqdm(enumerate(dataloader)):

    gan_model.set_input(data)
    gan_model.test()

    visuals = gan_model.get_current_visuals()
    img_path = gan_model.get_image_paths()

    #
    # Evaluate Part
    #
    generate_CT = visuals['G_fake'].data.clone().cpu().numpy()
    real_CT = visuals['G_real'].data.clone().cpu().numpy()
    # To NDHW
    if 'std' in opt.dataset_class or 'baseline' in opt.dataset_class:
      generate_CT_transpose = generate_CT
      real_CT_transpose = real_CT
    else:
      generate_CT_transpose = np.transpose(generate_CT, (0, 2, 1, 3))
      real_CT_transpose = np.transpose(real_CT, (0, 2, 1, 3))
    # To [0, 1]
    generate_CT_transpose = tensor_back_to_unnormalization(generate_CT_transpose, opt.CT_MEAN_STD[0],
                                                           opt.CT_MEAN_STD[1])
    real_CT_transpose = tensor_back_to_unnormalization(real_CT_transpose, opt.CT_MEAN_STD[0], opt.CT_MEAN_STD[1])
    # clip generate_CT
    generate_CT_transpose = np.clip(generate_CT_transpose, 0, 1)

    # CT range 0-1
    mae0 = MAE(real_CT_transpose, generate_CT_transpose, size_average=False)
    mse0 = MSE(real_CT_transpose, generate_CT_transpose, size_average=False)
    cosinesimilarity = Cosine_Similarity(real_CT_transpose, generate_CT_transpose, size_average=False)
    ssim = Structural_Similarity(real_CT_transpose, generate_CT_transpose, size_average=False, PIXEL_MAX=1.0)
    # CT range 0-4096
    generate_CT_transpose = tensor_back_to_unMinMax(generate_CT_transpose, opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]).astype(
      np.int32)
    real_CT_transpose = tensor_back_to_unMinMax(real_CT_transpose, opt.CT_MIN_MAX[0], opt.CT_MIN_MAX[1]).astype(
      np.int32)
    psnr_3d = Peak_Signal_to_Noise_Rate_3D(real_CT_transpose, generate_CT_transpose, size_average=False, PIXEL_MAX=4095)
    psnr = Peak_Signal_to_Noise_Rate(real_CT_transpose, generate_CT_transpose, size_average=False, PIXEL_MAX=4095)
    mae = MAE(real_CT_transpose, generate_CT_transpose, size_average=False)
    mse = MSE(real_CT_transpose, generate_CT_transpose, size_average=False)

    name1 = os.path.splitext(os.path.basename(img_path[0][0]))[0]
    name2 = os.path.split(os.path.dirname(img_path[0][0]))[-1]
    name = name2 + '_' + name1
    print(cosinesimilarity, name)
    # ``x is np.nan`` never matches a NaN produced by arithmetic (and never
    # matches an array), so test element-wise with np.isnan instead.
    cosine_values = np.asarray(cosinesimilarity, dtype=np.float64)
    if np.any(np.isnan(cosine_values)) or np.any(cosine_values > 1):
      print(name1)
      del visuals, img_path
      continue

    metrics_list = [('MAE0', mae0), ('MSE0', mse0), ('MAE', mae), ('MSE', mse), ('CosineSimilarity', cosinesimilarity),
                    ('psnr-3d', psnr_3d), ('PSNR-1', psnr[0]),
                    ('PSNR-2', psnr[1]), ('PSNR-3', psnr[2]), ('PSNR-avg', psnr[3]),
                    ('SSIM-1', ssim[0]), ('SSIM-2', ssim[1]), ('SSIM-3', ssim[2]), ('SSIM-avg', ssim[3])]

    # np.ravel accepts arrays, lists and scalars alike, so 0-d metric values
    # no longer break the list concatenation.
    for key, value in metrics_list:
      avg_dict.setdefault(key, []).extend(np.ravel(value).tolist())

    del visuals, img_path

  for key, values in avg_dict.items():
    mean_value = np.mean(values)
    print('### --{}-- total: {}; avg: {} '.format(key, len(values), np.round(mean_value, 7)))
    avg_dict[key] = mean_value

  return avg_dict


if __name__ == '__main__':
  evaluate(args)
