**3D Reconstruction of Fractured Femoral Bone from X-Ray Based Image**

*Implemented by Mr. Danupong Buttongkum*

* Last Modified Date: **11-09-2022**
* Device: **Tyan GPU**


# Preparation 

Import the necessary toolboxes and add-ons

In [None]:
# Simple Toolbox 
from __future__ import print_function, division
import matplotlib
import matplotlib.pyplot as plt
import os
import skimage
from skimage import io, transform, measure
import scipy
from scipy.spatial.distance import directed_hausdorff
from sklearn.neighbors import NearestNeighbors
import numpy as np
import cv2
from PIL import Image
import sys
import ipyvolume as ipv
import k3d
import pythreejs   # for controlling the Camera using with ipyvolume
import ipywidgets
import pandas as pd
import datetime
import time
import joblib
import ipywidgets as widgets
from tqdm.notebook import tqdm
%matplotlib inline

# Advance Toolbox 
import torch
import torch.nn as nn
import torch.nn.functional as F                                      # useful stateless functions
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.autograd import Function, Variable

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import sampler

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision import utils, transforms

import optuna
from optuna.trial import TrialState
from optuna.visualization import plot_contour, plot_edf, plot_intermediate_values, plot_optimization_history
from optuna.visualization import plot_parallel_coordinate, plot_param_importances, plot_slice

# Silence every warning category for the rest of the session, then put
# matplotlib into interactive mode.
import warnings
warnings.simplefilter("ignore")
plt.ion()   # interactive mode


# Interactive dataset selection: sets root_dir, training_file and val_file.
# Invalid menu choices raise ValueError right away. The previous `assert`s let
# 0 / negative choices fall through with training_file undefined, and asserts
# disappear entirely under `python -O`.
selectdataset = int(input('Select dataset: \n1) Dataset3 \n2) Dataset5 \n3) Dataset2 \n4) CU_Dataset2 \n5) CU_Dataset4_Unaligned \n6) Siriraj_dataset \n\n Input : '))
if selectdataset==1:
    root_dir = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset3'     
    training_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset3\Training_set.xlsx'
    val_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset3\Validation_set.xlsx'
elif selectdataset==2:
    root_dir = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset5'
    training_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset5\Training_set.xlsx'
    val_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset5\Validation_set.xlsx'
elif selectdataset==3:
    root_dir = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset2'
    training_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset2\Training_set2.xlsx'
    val_file = r'D:\FEW PhD\Datasets\HNSC\Cleaning Data\Dataset2\Validation_set2.xlsx'
elif selectdataset==4:
    root_dir = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2'
    augment_dataset = int(input('Select : [1] Non-Fracture-Augmention  [2] Fracture-Augmentation  = '))
    if augment_dataset == 1:
        training_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Training_set.xlsx'
        val_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Validation_set.xlsx'
    elif augment_dataset == 2:
        training_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Training_set2.xlsx'
        val_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Validation_set.xlsx'
    else:
        raise ValueError('Fail to select dataset')
elif selectdataset==5:
    root_dir = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned'    # including fractural augmented data
    augment_dataset = int(input('Select : [1] Scale12 [2] Scale123  [3] Extra  = '))
    if augment_dataset == 1:
        training_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TrainingSet_Scale12.xlsx'
        val_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_Scale12.xlsx'
    elif augment_dataset == 2:
        training_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TrainingSet_Scale123.xlsx'
        val_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_Scale123.xlsx'
    elif augment_dataset == 3:
        training_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TrainingSet_Scale123_extra.xlsx'
        val_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_Scale12.xlsx'
    else:
        raise ValueError('Fail to select dataset')
elif selectdataset==6:
    root_dir = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData'
    # Siriraj only has a test list, so it is used for both training_file and val_file.
    training_file = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData\Siriraj_testset.xlsx'
    val_file = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData\Siriraj_testset.xlsx'
else:
    raise ValueError('Fail to select dataset')

    
print('Training set {}'.format(training_file))
print('Test set {}'.format(val_file))

# Choose the compute device interactively. A GPU request silently falls back to
# the CPU when CUDA is not available.
use_gpu = int(input('Input: [1] Use GPU  [2] Use CPU : '))
USE_GPU = use_gpu == 1
print('\nDevice: GPU' if USE_GPU else '\nDevice: CPU')
dtype1 = torch.float32        # floating dtype used throughout this notebook

if not (USE_GPU and torch.cuda.is_available()):
    print('Using device : CPU')
    device = torch.device('cpu')
else:
    print('Using device : CUDA')
    device = torch.device("cuda")
    # List every visible GPU with its properties.
    for i in range(torch.cuda.device_count()):
        print('   Device name({}): {}'.format(i, torch.cuda.get_device_name(i)))
        print('     {}'.format(torch.cuda.get_device_properties(i)))
    print('Current device: {}  >>>  {}'.format(torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device())))

# How often (in iterations) the training loop prints its loss.
print_every = 100

# Project-local helpers: datasets/transforms, metrics and losses.
#from utils import show_sample, FemurDataset, NormalizeSample, ToTensor6, ToTensor7
from utils import show_sample, FemurDataset, FemurDataset2, NormalizeSample, NormalizeSample2, NormalizeSample3, ToTensor7, ToTensor8, ToTensor8Plus, ToTensor9
from matricesOperator import iou, TP, FP, TN, FN, union, hausdorff_voxel, overlap_based_metrices, mesh_surface_nearest_distance, surface_distance_measurement
#from losses import DiceLoss, BCEDiceLoss, MulticlassBCEDiceLoss, MulticlassBCEDiceLoss2, BCEHNMDiceLoss
from losses import FocalLossMulticlass
#from LossToolbox.FocalLoss.focal_loss import FocalLoss_Ori, BinaryFocalLoss, FEWFocalLoss, FEWFocalLoss2
#from LossToolbox.TverskyLoss.binarytverskyloss import FocalBinaryTverskyLoss, BinaryTverskyLossV2
from HausdorffLoss.hausdorff_metric import HausdorffDistance, HausdorffDistanceV2
#from HausdorffLoss.hausdorff_loss import HausdorffDTLoss, HausdorffERLoss

print('\n--- END ---')

#device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
# Pin CUDA device 0 only when the notebook actually selected CUDA.
# torch.cuda.set_device / current_device raise on CPU-only machines, and also
# when the user chose "[2] Use CPU" above.
if device.type == 'cuda':
    torch.cuda.set_device(0)
    print('Current device: {}'.format(torch.cuda.current_device()))
else:
    print('Current device: {}'.format(device))

In [None]:
## Visualize one training sample: loads target / view1 / view2 for row n of the training list
files_frame = pd.read_excel(training_file)    # if file is .csv  use pd.read_csv() instead
num_rows = len(files_frame)                   # number of listed samples (same as len(files_frame.count(1)))
print('Total training sample = 0-{}'.format(num_rows-1))
#n = int(input('Input training sample = '))
n = 50    # well-aligned sample
if not 0 <= n < num_rows:
    raise IndexError('Sample index n={} is outside 0-{}'.format(n, num_rows-1))
print('File name: ' + files_frame.iloc[n,1] + ' | '+files_frame.iloc[n,2] + ' | ' + files_frame.iloc[n,3])
target = np.load(os.path.join(root_dir, files_frame.iloc[n,1]))    # Intensity[n,0] or Mask[n,1]
view1 = np.load(os.path.join(root_dir, files_frame.iloc[n,2]))
view2 = np.load(os.path.join(root_dir, files_frame.iloc[n,3]))
print('target classes = {}'.format(np.unique(target)))   # np.unique already returns sorted values
# Optional: keep a single class as a binary mask
#i = np.uint8(input('Input i = '))
#target = np.uint8(target==i)
# Plotting is done with show_sample(...) or the plot cell below. The previous
# disabled snippet here referenced `ap`/`lat`, which this cell does not define.

## Dataset implementation

**Manually select a dataset sample**

In [None]:
# Visualize one training sample: load the label volume and both DRR projections for row n
files_frame = pd.read_excel(training_file)
print('Total training sample = 0-{}'.format(len(files_frame)-1))
# n = int(input('Input training sample = '))   # interactive alternative
# (Previously written as `%n = ...`. IPython parses that as the unknown line
#  magic `%n` and aborts the cell before `n = 50` runs.)
n = 50
print('File name: ' + files_frame.iloc[n,1] + ' | '+files_frame.iloc[n,2] + ' | ' + files_frame.iloc[n,3])
target = np.load(os.path.join(root_dir, files_frame.iloc[n,1]))    # Intensity[n,0] or Mask[n,1]
ap = np.load(os.path.join(root_dir, files_frame.iloc[n,2]))
lat = np.load(os.path.join(root_dir, files_frame.iloc[n,3]))

classes = np.unique(target)
print('classes = {}'.format(classes))
# To get a tensor-style [1,H,W] image format, enable:
# if ap.ndim == 2:
#     ap = ap[np.newaxis,...]
#     lat = lat[np.newaxis,...]
print('target = {}   {}'.format(target.shape, target.dtype))
print('ap = {}   {}'.format(ap.shape, ap.dtype))
print('lat = {}   {}'.format(lat.shape, lat.dtype))

In [None]:
# Show the AP and LAT DRR projections side by side, then render label 1 of
# the target volume in 3D.
fig, ax = plt.subplots(1, 2, figsize=(12, 12))
fig.tight_layout()   # fig.tight_layout(pad=10)
for panel, image, title in zip(ax, (ap, lat), ('DRR AP Projection', 'DRR LAT Projection')):
    panel.imshow(image, cmap='gray')
    panel.set_title(title)
plt.show()

i = 1
ipv.figure(lighting=False)
ipv.volshow(np.uint8(target == i))
ipv.show()

In [None]:
# For ToTensor7 : Chula Dataset
# Build a one-hot 4D array of shape (3, *target.shape) from the label volume:
#   target2[0] = background (every voxel that is neither femur nor fracture)
#   target2[1] = boneMask   (labels 1..9)
#   target2[2] = fracMask   (label 40)
# Every channel is binary (0/1): <0.5 means "not this class", >0.5 means "this class".
target2 = np.zeros((3,*target.shape))
print('target2 = {}   {}'.format(target2.shape, target2.dtype))
t1 = time.time()
#target2[2][(target==40)] = 1                        # logical-indexing alternative
target2[2] = np.uint8(target==40)                     # was `* 2`, which made this channel {0,2} unlike the others
t2 = time.time()
#target2[1][(target>0)*(target<10)] = 1              # logical
#target2[1] = np.where((target>0)*(target<10),1,0)   # np.where
target2[1] = np.uint8((target>0)*(target<10))        # create new np.uint8
t3 = time.time()
# Background is the complement of the two foreground channels. It used to be
# left all-zero, which made any background statistic computed from target2[0]
# meaningless.
target2[0] = 1 - np.maximum(target2[1], target2[2])
print('Sampling time = {} + {} = {} sec'.format(t2-t1, t3-t2, t3-t1))
print('target2 = {}   {}'.format(target2.shape, target2.dtype))
intersec = target2[2]
print('intersec = {}   {}'.format(intersec.shape, intersec.dtype))
print('intersec.sum = {:,.2f}'.format(intersec.sum()))

In [None]:
# Inspect the one-hot channels: shapes, voxel counts and the femur/background balance.
I0 = target2[0]   # background channel
I1 = target2[1]   # femur (bone) channel
I2 = target2[2]   # fracture channel
print('I0 = {}   {}\nI1 = {}   {}'.format(I0.shape,I0.dtype,I1.shape,I1.dtype))
print('I0 = {:,.2f}\nI1 = {:,.2f}\nI = {:,.2f}'.format(I0.sum(),I1.sum(),I0.sum()+I1.sum()))
background_count = I0.sum()
if background_count > 0:
    print('foreground/background ratio = {:,.4f}'.format(I1.sum()/background_count))
else:
    # Dividing by an empty background channel would print inf/nan, and the
    # numpy RuntimeWarning is hidden by the global "ignore" warnings filter.
    print('foreground/background ratio = undefined (background channel is empty)')
ipv.figure()
ipv.volshow(I0)
ipv.show()
ipv.figure()
ipv.volshow(I1)
ipv.show()

**Use DataLoader**

In [None]:
from utils import FemurDataset , show_sample , NormalizeSample, ToTensor, ToTensor2, ToTensor3, ToTensor4, ToTensor5, ToTensor6, ToTensor7, ToTensor8
# Read by the DataLoader cell: True means the sample iterator (femurIter) still has to be created.
first = True

In [None]:
class ToTensorTest(object):
    """Debug transform: convert a sample's ndarrays into torch tensors, logging shapes and timing.

    Input label values in ``sample['Target']`` (volume laid out as [D,W,-H]):
        [0]            = background
        [1,2,3,...,10] = fragment of femur bone
        [20]           = pelvic bone
        [30]           = soft-tissue
        [40]           = fracture
        [50]           = comminute
    NOTE(review): the femur mask below is ``0 < label < 10``, so a label of exactly
    10 would be mapped to background. Confirm which range is intended.

    Returns a dict:
        'Target': int32 tensor [D,H,W] (then ``squeeze``d) with classes
                  0 == background (also pelvic / soft-tissue / comminute),
                  1 == femurMask, 2 == fracMask
        'view1', 'view2': float32 tensors. A 2D image gets a leading channel axis -> [1,H,W].
    """

    @staticmethod
    def _with_channel_axis(image):
        """Return ``image`` with a leading channel axis if it is 2D, otherwise unchanged."""
        return image[np.newaxis, ...] if image.ndim == 2 else image

    def __call__(self, sample):        # callable classes
        print('### ToTensorTest ###')
        target, view1, view2 = sample['Target'], sample['view1'], sample['view2']   # target = [D,W,-H]
        classes = np.unique(target)
        print('    target = {}   {}   {}'.format(target.shape,type(target), target.dtype))
        print('    classes = {}'.format(classes))
        # Check each view separately. Previously only view1.ndim was tested, so a
        # 2D view2 paired with a 3D view1 was left without a channel axis.
        view1 = self._with_channel_axis(view1)    # 1-view
        view2 = self._with_channel_axis(view2)    # 2-view

        # Map raw labels to {0: background, 1: femur, 2: fracture}
        t1 = time.time()
        target2 = np.zeros(target.shape)    # [D,W,-H]
        target2[target == 40] = 2
        target2[(target > 0) & (target < 10)] = 1

        # transpose: [D,W,-H] to [D,H,W]; copy() gives positive strides for torch.from_numpy
        target2 = np.flip(target2.transpose(0,2,1), axis=1).copy()
        t2 = time.time()
        print('   target2 = {}   {}   {}'.format(target2.shape, type(target2), target2.dtype))
        print('   view1 = {}   {}   {}'.format(view1.shape, type(view1), view1.dtype))
        print('   view2 = {}   {}   {}'.format(view2.shape, type(view2), view2.dtype))
        print('   Time = {} sec.'.format(t2-t1))

        return {'Target': torch.from_numpy(target2).int().squeeze(),
                'view1': torch.from_numpy(view1).float(),
                'view2':torch.from_numpy(view2).float() }
    
print(' --- END --- ')

In [None]:
# Build the training dataset/loader and pull one sample for visual inspection.
train_transformedFemur = FemurDataset2(csv_file=training_file, root_dir=root_dir, 
                                      transform=transforms.Compose([NormalizeSample2(),ToTensorTest()]))
trainLoader = DataLoader(train_transformedFemur, batch_size=1, shuffle=True, num_workers=0)

if first:
    print('First testing sample')
    femurIter = iter(trainLoader)
    first = False
else:
    print('Next testing sample')
    
t1 = time.time()
try:
    sample = next(femurIter)
except StopIteration:
    # One full pass is done. Start a new (reshuffled) pass instead of failing
    # on every later run of this cell.
    print('End of dataset reached - restarting iterator')
    femurIter = iter(trainLoader)
    sample = next(femurIter)
t2 = time.time()
print(sample.keys())
view1 = sample['view1'].squeeze()
view2 = sample['view2'].squeeze()
target = sample['Target'].squeeze()
print('target = {}   {}   fractureCount={:,.0f}'.format(target.size(),target.dtype,(target==2).sum()))
print('Sampling time = {} sec.\n'.format(t2-t1))
print('Classes list ={}'.format(torch.unique(target)))

fig, ax = plt.subplots(1,2, figsize=(12,12))
ax[0].imshow(view1, cmap='gray')
ax[1].imshow(view2, cmap='gray')
plt.show()   # fig.show() only warns under the non-GUI inline backend

ipv.figure()
ipv.volshow(np.uint8((target==1).numpy()))   # femur mask (class 1), as a numpy uint8 volume
ipv.show()

# Optional: fracture voxels (class 2)
# ipv.figure()
# ipv.volshow(np.uint8((target==2).numpy()))
# ipv.view(270, 90)
# ipv.show()

In [None]:
# Example for 2D Convolutional and ConvolutionTranspose layer

# With square kernels and equal stride
m1 = nn.ConvTranspose2d(16, 33, 3, stride=1, padding=0)
# non-square kernels and unequal stride, no padding
m2 = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(0, 0))
input1 = torch.randn(20, 16, 50, 100)
print('input1.size = {}'.format(input1.size()))
output1 = m1(input1)     # (50-1)*1+3 = 52, (100-1)*1+3 = 102
print('output1.size = {}'.format(output1.size()))
output2 = m2(input1)     # (50-1)*2+3 = 101, (100-1)*1+5 = 104
print('output2.size = {}\n'.format(output2.size()))


# Same-size convolution vs. 2x transpose-convolution upsampling
input2 = torch.randn(1, 16, 256, 256)
print('input2.size = {}'.format(input2.size()))
downsample = nn.Conv2d(16, 32, 7, stride=1, padding=3)           # F=2*P+1, S=1 -> size kept
upsample = nn.ConvTranspose2d(16, 32, 2, stride=2, padding=0)    # (256-1)*2+2 = 512
DownOutput = downsample(input2)
print('DownOutput.size = {}'.format(DownOutput.size()))
UpOutput = upsample(input2)
print('UpOutput.size = {}'.format(UpOutput.size()))



####################################################################
#   Output features size calculation of Convolutional layer        #
#   (P=padding, F=kernel, S=stride, D=dilation, OP=output_padding) #
#                                                                  #
#  1) nn.Conv2d:          output = floor((input+2*P-D*(F-1)-1)/S)+1#
#                         D=1:  floor((input+2*P-F)/S)+1           #
#                                                                  #
#  2) nn.ConvTranspose2d: output = (input-1)*S-2*P+D*(F-1)+OP+1    #
#                         D=1, OP=0:  (input-1)*S-2*P+F            #
#    Hint: use F=2*P+1  and  S=1 to get output.size = input.size   #
#                                                                  #
#          This calculation also applies to nn.Conv3d              #
####################################################################


In [None]:
# Example for 3D-Convolutional and 3D-ConvolutionTranspose layer
# The input is called `input3d`, not `input`. In a notebook a global named
# `input` shadows the builtin and breaks every later `int(input(...))` prompt.

# With square kernels and equal stride
m = nn.Conv3d(1, 4, kernel_size=3, stride=1, padding=1)             # reference only; not applied below
m2 = nn.ConvTranspose3d(2, 8, kernel_size=4, stride=2, padding=1)   # (32-1)*2-2*1+4 = 64 per spatial dim

input3d = torch.randn(2, 2, 32, 32, 32)
print('input.size() = {}'.format(input3d.size()))

t1 = time.time()
output = m2(input3d)
t2 = time.time()
print('output.size = {}'.format(output.size()))
print('Total calculation time = {} sec.'.format(t2-t1))


In [None]:
# Example for nn.BatchNorm2d and nn.LayerNorm
# `m.parameters` is a bound method, so printing it only shows the module repr.
# The parameter shapes are collected explicitly instead.

m1 = nn.BatchNorm2d(100)                # With Learnable Parameters
#m1 = nn.BatchNorm2d(100, affine=False)  # Without Learnable Parameters
input1 = torch.randn(20, 100, 35, 45)
output1 = m1(input1)
m1_param_shapes = {name: tuple(p.shape) for name, p in m1.named_parameters()}
print('output1.size = {}'.format(output1.size()))
print('m1.parameter = {} \n'.format(m1_param_shapes))

input2 = torch.randn(20, 5, 10, 10)
print('input of LayerNorm = {}'.format(input2.size()[1:]))
#m2 = nn.LayerNorm(input2.size()[1:])    # With Learnable Parameters 
#m2 = nn.LayerNorm(input2.size()[1:], elementwise_affine=False)     # Without Learnable Parameters
m2 = nn.LayerNorm([10, 10])             # Normalize over last two dimensions
#m2 = nn.LayerNorm(10)                   # Normalize over last dimension of size 10
output2 = m2(input2)
m2_param_shapes = {name: tuple(p.shape) for name, p in m2.named_parameters()}
print('output2.size = {}'.format(output2.size()))
print('m2.parameters = {}'.format(m2_param_shapes))

In [None]:
# Example for MaxPooling and MaxUnPooling with overlapping (stride-1) 2x2 windows

pool = nn.MaxPool2d(2, stride=1, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=1)
input1 = torch.arange(1., 17.).reshape(1, 1, 4, 4)   # values 1..16, row by row

print('\n     ***** Down Sampling *****')
output, indices = pool(input1)
for label, value in (('MaxPool2d result = ', output), ('MaxPool2d indices = ', indices)):
    print(label, value)

print('\n     ***** Up Sampling *****')
unpool_result = unpool(output, indices)
print('MaxUnPool2d result = ', unpool_result)

# Variant with an explicit output size (disabled):
# unpool_result2 = unpool(output, indices, output_size=torch.Size([1, 1, 5, 5]))
# print('MaxUnPool2d result with specified output_size = ')
# print(unpool_result2)

In [None]:
# Example for nn.MaxUnpool3d: pool a 5D tensor and scatter the maxima back
# NOTE(review): originally labelled ">>> Error!!!". The code mirrors the example in
# the PyTorch nn.MaxUnpool3d docs; confirm whether the error still occurs.

pool = nn.MaxPool3d(kernel_size=3, stride=2, return_indices=True)
unpool = nn.MaxUnpool3d(kernel_size=3, stride=2)
input1 = torch.randn(20, 16, 51, 33, 15)

output, indices = pool(input1)
for label, value in (('output', output), ('indices', indices)):
    print('{}.size = {}'.format(label, value.size()))

unpooled_output = unpool(output, indices)   # default size (in-1)*stride+kernel -> (51, 33, 15)
print('unpooled_output.size = {}'.format(unpooled_output.size()))

In [None]:
# Example for Tensor.expand(size): broadcast a 3x3 matrix along a new leading axis without copying

q = torch.arange(1., 10.).reshape(3, 3)
print('q = {}'.format(q))
print('size_q = {} \n'.format(q.size()))

q2 = q.unsqueeze(0).expand(3, -1, -1)   # equivalent to q.expand(3,3,3): stride 0 on the new axis
print('q2 = {}'.format(q2))

In [None]:
# Example: Maxpooling and MaxUnpooling (non-overlapping 2x2 windows)

pool = nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)
input1 = torch.tensor([[[[ 1.,  2,  3,  4],
                            [ 5,  6,  7,  8],
                            [ 9, 10, 11, 12],
                            [13, 14, 15, 16]]]])

print('     ***** Down Sampling *****')
output, indices = pool(input1)
print('output = {}'.format(output))
print('output.size = {}'.format(output.size()))
print('indices = {}'.format(indices))
print('indices.size = {}'.format(indices.size()))
print('\n     ***** Up Sampling *****')
unpool_result = unpool(output, indices)
print('MaxUnPool2d result = {} \n'.format(unpool_result))

print('     ***** Up Sampling 3D*****')
# Duplicate the pooled map into a second (halved) channel and unpool both
# channels with the same argmax positions.
output2 = output.expand(1,2,2,2)/2
print('output2 = {}'.format(output2))
# Indices must be an int64 tensor shaped like output2. Expanding the real argmax
# indices gives [[5, 7], [13, 15]] per channel. (The old float literal was
# never used, the old call passed output_size=[2,1,2,2], too small for those
# indices, and then printed the undefined name `unpooled_result2`.)
indices2 = indices.expand(1,2,2,2).contiguous()

print('indices2 = {}'.format(indices2))
unpool_result2 = unpool(output2, indices2)    # -> (1, 2, 4, 4)
print('unpool_result2 = {}'.format(unpool_result2))

In [None]:
# Example: nn.Upsample(scale_factor, size, mode) enlarges a tensor by interpolation

#input1 = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
input2 = torch.randint(1,10, size=(1,1,2,2), dtype=torch.float32)  # try to vary size=(N,2D-4D)
print('input = {}'.format(input2))
print('input.size = {}\n'.format(input2.size()))
###  {mode} for nn.Upsample class  ###
# nearest           (N,2D-4D) = (N,C,(D),(H),W)    >>  (N,C,(2D),(2H),2W)
# linear            (N,2D) = (N,C,W)               >>  (N,C,2W)
# bilinear          (N,3D) = (N,C,H,W)             >>  (N,C,2H,2W)
# bicubic           (N,3D) = (N,C,H,W)             >>  (N,C,2H,2W)
# trilinear         (N,4D) = (N,C,D,H,W)           >>  (N,C,2D,2H,2W)
########################################
mode_selected = 'bilinear'
# Same upsampler twice: default corner handling vs. align_corners=True
model1, model2 = (nn.Upsample(scale_factor=2, mode=mode_selected, align_corners=corners)
                  for corners in (None, True))
output1, output2 = model1(input2), model2(input2)
for tag, result in (('output1', output1), ('output2', output2)):
    print('{} = {}'.format(tag, result))
print('output.size = {}'.format(output1.size()))

# Neural Network

Specify the data path and set up the computational device.

## Architecture  (2D>3D mapping)

**Basic convolution blocks and the nn.Sequential class**

In [None]:
# Normalization layer classes referenced by the 2D / 3D block builders that follow.
# Swap in the commented BatchNorm lines to change the normalization everywhere at once.
norm2d = nn.InstanceNorm2d
#norm2d = nn.BatchNorm2d
norm3d = nn.InstanceNorm3d
#norm3d = nn.BatchNorm3d

def single_dense_block(in_f, out_f, *args, **kwargs):
    """Build one pre-activation dense unit: norm2d -> ReLU -> Conv2d.

    Args:
        in_f:  number of input channels (also the norm layer's channel count).
        out_f: number of output channels of the Conv2d.
        *args, **kwargs: forwarded unchanged to nn.Conv2d (kernel_size, padding, ...).

    Returns:
        nn.Sequential containing the three layers in that order.
    """
    layers = [norm2d(in_f), nn.ReLU(), nn.Conv2d(in_f, out_f, *args, **kwargs)]
    # nn.LeakyReLU(0.01) is the alternative activation that was tried.
    return nn.Sequential(*layers)

def encode_block(in_f, out_f, res, *args, **kwargs):
    """Build one encoder unit: Conv2d -> norm2d -> ReLU.

    Args:
        in_f:  input channels of the Conv2d.
        out_f: output channels of the Conv2d and channel count of the norm layer.
        res:   feature resolution. Only the commented-out LayerNorm variant
               (``norm2d([out_f, res, res])``) would use it; currently unused.
        *args, **kwargs: forwarded unchanged to nn.Conv2d.
    """
    conv = nn.Conv2d(in_f, out_f, *args, **kwargs)
    # Alternatives: norm2d([out_f, res, res]) for LayerNorm, nn.LeakyReLU(0.01) for the activation.
    return nn.Sequential(conv, norm2d(out_f), nn.ReLU())

def decode_block(in_f, out_f, res, *args, **kwargs):
    """Build one decoder unit: Conv2d -> norm2d -> ReLU.

    Args:
        in_f:  input channels of the Conv2d.
        out_f: output channels of the Conv2d and channel count of the norm layer.
        res:   feature resolution for the (disabled) LayerNorm variant; unused.
        *args, **kwargs: forwarded unchanged to nn.Conv2d.
    """
    conv = nn.Conv2d(in_f, out_f, *args, **kwargs)
    # Alternatives tried: nn.Dropout2d(p=0.2) after the norm, norm2d([out_f,res,res]), nn.LeakyReLU(0.01).
    return nn.Sequential(conv, norm2d(out_f), nn.ReLU())

def fully_conv2d(in_f, out_f, res, *args, **kwargs):
    """Obsolete!!! Build one Conv2d -> norm2d -> ReLU unit (same layout as conv2d_block).

    Args:
        in_f:  input channels of the Conv2d.
        out_f: output channels of the Conv2d and channel count of the norm layer.
        res:   feature resolution for the (disabled) LayerNorm variant; unused.
        *args, **kwargs: forwarded unchanged to nn.Conv2d.
    """
    conv = nn.Conv2d(in_f, out_f, *args, **kwargs)
    # Alternatives: norm2d([out_f,res,res]) / nn.LeakyReLU(0.01).
    return nn.Sequential(conv, norm2d(out_f), nn.ReLU())

def conv2d_block(in_f, out_f, res, *args, **kwargs):
    """Build one 2D convolution unit: Conv2d -> norm2d -> ReLU.

    Args:
        in_f:  input channels of the Conv2d.
        out_f: output channels of the Conv2d and channel count of the norm layer.
        res:   feature resolution for the (disabled) LayerNorm variant; unused.
        *args, **kwargs: forwarded unchanged to nn.Conv2d.
    """
    conv = nn.Conv2d(in_f, out_f, *args, **kwargs)
    # Alternatives: norm2d([out_f,res,res]) / nn.LeakyReLU(0.01).
    return nn.Sequential(conv, norm2d(out_f), nn.ReLU())

def conv3d_block(in_f, out_f, res, *args, **kwargs):
    """Build one 3D convolution unit: Conv3d -> norm3d -> ReLU.

    Args:
        in_f:  input channels of the Conv3d.
        out_f: output channels of the Conv3d and channel count of the norm layer.
        res:   feature resolution for the (disabled) LayerNorm variant; unused.
        *args, **kwargs: forwarded unchanged to nn.Conv3d.
    """
    conv = nn.Conv3d(in_f, out_f, *args, **kwargs)
    # Alternatives: norm3d([out_f,res,res,res]) / nn.LeakyReLU(0.01).
    return nn.Sequential(conv, norm3d(out_f), nn.ReLU())

def convtranspose2d_block(in_f, out_f, res, *args, **kwargs):
    """Build one 2D transpose-convolution unit: ConvTranspose2d -> norm2d -> ReLU.

    Args:
        in_f:  input channels of the ConvTranspose2d.
        out_f: output channels of the ConvTranspose2d and channel count of the norm layer.
        res:   feature resolution; unused.
        *args, **kwargs: forwarded unchanged to nn.ConvTranspose2d.
    """
    up = nn.ConvTranspose2d(in_f, out_f, *args, **kwargs)
    # nn.LeakyReLU(0.01) is the alternative activation.
    return nn.Sequential(up, norm2d(out_f), nn.ReLU())

def convtranspose3d_block(in_f, out_f, res, *args, **kwargs):
    """Build one 3D transpose-convolution unit: ConvTranspose3d -> norm3d -> ReLU.

    Args:
        in_f:  input channels of the ConvTranspose3d.
        out_f: output channels of the ConvTranspose3d and channel count of the norm layer.
        res:   feature resolution; unused.
        *args, **kwargs: forwarded unchanged to nn.ConvTranspose3d.
    """
    up = nn.ConvTranspose3d(in_f, out_f, *args, **kwargs)
    # nn.LeakyReLU(0.01) is the alternative activation.
    return nn.Sequential(up, norm3d(out_f), nn.ReLU())


 ##################  Weight Initialization ########################
def linear_initialize_sequence(sequential):
    """Initialize, in place, every nn.Linear found in ``sequential``.

    Weights get Kaiming-normal init and biases are set to 0.1. Layers built with
    ``bias=False`` (``bias is None``) only get their weight initialized.

    Args:
        sequential: iterable of modules, typically an nn.Sequential whose items are
            nn.Sequential blocks. Each item and everything nested inside it is
            visited, so flat and nested containers both work.
    """
    for block in sequential:
        # block.modules() yields block itself and all of its descendants.
        for module in block.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.1)

def conv2d_initialize_sequence(sequential):
    """Initialize, in place, every nn.Conv2d found in ``sequential``.

    Weights get Xavier-uniform init and biases are set to 0.1. Layers built with
    ``bias=False`` (``bias is None``) only get their weight initialized.

    Args:
        sequential: iterable of modules, typically an nn.Sequential whose items are
            nn.Sequential blocks. Each item and everything nested inside it is
            visited, so flat and nested containers both work.
    """
    for block in sequential:
        # block.modules() yields block itself and all of its descendants.
        for module in block.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.1)
                
def conv3d_initialize_sequence(sequential):
    """Initialize, in place, every nn.Conv3d found in ``sequential``.

    Weights get Xavier-uniform init and biases are set to 0.1. Layers built with
    ``bias=False`` (``bias is None``) only get their weight initialized.

    Args:
        sequential: iterable of modules, typically an nn.Sequential of nn.Sequential
            blocks. Each item and everything nested inside it is visited.
    """
    for block in sequential:
        # block.modules() yields block itself and all of its descendants.
        for module in block.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.1)
                
def convtranspose2d_initialize_sequence(sequential):
    """Initialize, in place, every nn.ConvTranspose2d found in ``sequential``.

    Weights get Xavier-uniform init and biases are set to 0.1. Layers built with
    ``bias=False`` (``bias is None``) only get their weight initialized.

    Args:
        sequential: iterable of modules, typically an nn.Sequential of nn.Sequential
            blocks. Each item and everything nested inside it is visited.
    """
    for block in sequential:
        # block.modules() yields block itself and all of its descendants.
        for module in block.modules():
            if isinstance(module, nn.ConvTranspose2d):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.1)

def convtranspose3d_initialize_sequence(sequential):
    """Initialize, in place, every nn.ConvTranspose3d found in ``sequential``.

    Weights get Xavier-uniform init and biases are set to 0.1. Layers built with
    ``bias=False`` (``bias is None``) only get their weight initialized.

    Args:
        sequential: iterable of modules, typically an nn.Sequential of nn.Sequential
            blocks. Each item and everything nested inside it is visited.
    """
    for block in sequential:
        # block.modules() yields block itself and all of its descendants.
        for module in block.modules():
            if isinstance(module, nn.ConvTranspose3d):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.1)

                
 ################## Class layer ################## 
class LinearLayer(nn.Module):
    """Stack of linear units chained through consecutive sizes in ``lin_sz``.

    Unit k is ``single_linear_block(lin_sz[k], lin_sz[k+1])``. That helper is not
    defined in this section; TODO confirm where it lives. The stack is then
    initialized with ``linear_initialize_sequence``. Extra ``*args``/``**kwargs``
    are accepted but ignored.
    """
    def __init__(self, lin_sz, *args, **kwargs):
        super(LinearLayer, self).__init__()
        units = (single_linear_block(n_in, n_out) for n_in, n_out in zip(lin_sz, lin_sz[1:]))
        self.linear_block = nn.Sequential(*units)
        linear_initialize_sequence(self.linear_block)

    def forward(self, x):
        """Apply the linear units in order."""
        return self.linear_block(x)
    
class LinearLayer_bn(nn.Module):
    """Stack of linear units (with normalization) chained through consecutive sizes in ``lin_sz``.

    Unit k is ``single_linear_bn_block(lin_sz[k], lin_sz[k+1], bn_sz)``. That helper
    is not defined in this section; TODO confirm what ``bn_sz`` controls there.
    Extra ``*args``/``**kwargs`` are accepted but ignored.
    """
    def __init__(self, lin_sz, bn_sz, *args, **kwargs):
        super(LinearLayer_bn, self).__init__()
        linear_bn_block = [ single_linear_bn_block(in_f, out_f, bn_sz) 
                        for in_f,out_f in zip(lin_sz,lin_sz[1:])  ]
        self.linear_bn_block = nn.Sequential( *linear_bn_block )
        # Initialize the stack that was just built. This used to reference the
        # non-existent attribute self.linear_block, so every construction raised
        # AttributeError.
        linear_initialize_sequence(self.linear_bn_block)

    def forward(self, x):
        """Apply the linear units in order."""
        return self.linear_bn_block(x)
    
class MyEncoder(nn.Module):
    """2D encoder: a chain of ``encode_block`` units.

    Args:
        en_sz: flat list of channel counts ``[c0, c1, ..., cN]``. Unit k maps
            ``c_k -> c_{k+1}``, so ``len(en_sz) - 1`` units are created.
        *args, **kwargs: forwarded to every ``encode_block`` call (``res`` first,
            then the Conv2d options).
    """
    def __init__(self, en_sz, *args, **kwargs):
        super(MyEncoder, self).__init__()
        units = []
        for n_in, n_out in zip(en_sz, en_sz[1:]):
            units.append(encode_block(n_in, n_out, *args, **kwargs))
        self.encode_blocks = nn.Sequential(*units)
        conv2d_initialize_sequence(self.encode_blocks)

    def forward(self, x):
        """Run ``x`` through every encoder unit."""
        return self.encode_blocks(x)
    
class MyDenseLayer(nn.Module):
    """DenseNet-style layer: each unit sees the concatenation of the input and all earlier outputs.

    Args:
        in_f:  channels of the input tensor.
        k:     growth rate; every unit adds ``k`` channels.
        level: number of ``single_dense_block`` units.
        *args, **kwargs: forwarded to each ``single_dense_block`` (Conv2d options).

    The output has ``in_f + level * k`` channels: the input concatenated (dim=1)
    with every unit's output.
    """
    def __init__(self, in_f, k, level, *args, **kwargs):
        super(MyDenseLayer, self).__init__()
        units = [single_dense_block(in_f + step * k, k, *args, **kwargs) for step in range(level)]
        self.dense_blocks = nn.Sequential(*units)
        conv2d_initialize_sequence(self.dense_blocks)

    def forward(self, x):
        """Grow the feature map by concatenating each unit's output along dim 1."""
        features = x
        for unit in self.dense_blocks:
            features = torch.cat((features, unit(features)), dim=1)
        return features

class MyDecoder(nn.Module):
    """ Chain of decode_block modules built from consecutive sizes.
    Args:   de_sz = flat sequence of sizes, e.g. [in_c, mid_c, out_c]; each
                    adjacent pair becomes decode_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every decode_block
    """
    def __init__(self, de_sz, *args, **kwargs):
        super().__init__()
        size_pairs = zip(de_sz, de_sz[1:])
        self.decode_blocks = nn.Sequential(
            *(decode_block(in_f, out_f, *args, **kwargs) for in_f, out_f in size_pairs)
        )
        conv2d_initialize_sequence(self.decode_blocks)

    def forward(self, x):
        decoded = self.decode_blocks(x)
        return decoded
    
class AxialFusion(nn.Module):
    """ Fuse view1 and view2 2-D features with a chain of conv2d_block modules.
    Args:   de3d_sz = flat sequence of sizes; each adjacent pair becomes
                      conv2d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every conv2d_block
    """
    def __init__(self, de3d_sz, *args, **kwargs):
        super().__init__()
        blocks = []
        for in_f, out_f in zip(de3d_sz[:-1], de3d_sz[1:]):
            blocks.append(conv2d_block(in_f, out_f, *args, **kwargs))
        self.fusion_blocks = nn.Sequential(*blocks)
        conv2d_initialize_sequence(self.fusion_blocks)

    def forward(self, x):
        fused = self.fusion_blocks(x)
        return fused
    
class Fusion2d(nn.Module):
    """ Fuse view1 and view2 features with a chain of decode_block modules.
    Args:   de_sz = flat sequence of sizes; each adjacent pair becomes
                    decode_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every decode_block
    """
    def __init__(self, de_sz, *args, **kwargs):
        super().__init__()
        size_pairs = zip(de_sz, de_sz[1:])
        self.fusion_blocks = nn.Sequential(
            *(decode_block(in_f, out_f, *args, **kwargs) for in_f, out_f in size_pairs)
        )
        conv2d_initialize_sequence(self.fusion_blocks)

    def forward(self, x):
        fused = self.fusion_blocks(x)
        return fused
    
class MyDecoder3d(nn.Module):
    """ Chain of conv3d_block modules built from consecutive sizes.
    Args:   de3d_sz = flat sequence of sizes; each adjacent pair becomes
                      conv3d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every conv3d_block
    Note:   unlike FullyConv3d, no conv3d_initialize_sequence call is made,
            so the blocks keep the initialization conv3d_block gives them.
    """
    def __init__(self, de3d_sz, *args, **kwargs):
        super().__init__()
        blocks = []
        for in_f, out_f in zip(de3d_sz[:-1], de3d_sz[1:]):
            blocks.append(conv3d_block(in_f, out_f, *args, **kwargs))
        self.decode3d_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        decoded = self.decode3d_blocks(x)
        return decoded
    
class FullyConv3d(nn.Module):  # like MyDecoder3d, plus conv3d_initialize_sequence
    """ Chain of conv3d_block modules built from consecutive sizes.
    Args:   final_sz = 1D-list [final_sz1, final_sz2, ..., final_szN]; each
                       adjacent pair becomes conv3d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every conv3d_block
    """
    def __init__(self, final_sz, *args, **kwargs):
        super().__init__()
        size_pairs = zip(final_sz, final_sz[1:])
        self.conv3d_blocks = nn.Sequential(
            *(conv3d_block(in_f, out_f, *args, **kwargs) for in_f, out_f in size_pairs)
        )
        conv3d_initialize_sequence(self.conv3d_blocks)

    def forward(self, x):
        out = self.conv3d_blocks(x)
        return out
    
class UpConv2d(nn.Module):
    """ Chain of convtranspose2d_block modules built from consecutive sizes.
    Args:   final_sz = flat sequence of sizes; each adjacent pair becomes
                       convtranspose2d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every convtranspose2d_block
    """
    def __init__(self, final_sz, *args, **kwargs):
        super().__init__()
        blocks = []
        for in_f, out_f in zip(final_sz[:-1], final_sz[1:]):
            blocks.append(convtranspose2d_block(in_f, out_f, *args, **kwargs))
        self.convtranspose2d_blocks = nn.Sequential(*blocks)
        convtranspose2d_initialize_sequence(self.convtranspose2d_blocks)

    def forward(self, x):
        upsampled = self.convtranspose2d_blocks(x)
        return upsampled
    
class UpConv3d(nn.Module):
    """ Chain of convtranspose3d_block modules built from consecutive sizes.
    Args:   final_sz = flat sequence of sizes; each adjacent pair becomes
                       convtranspose3d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every convtranspose3d_block
    """
    def __init__(self, final_sz, *args, **kwargs):
        super().__init__()
        size_pairs = zip(final_sz, final_sz[1:])
        self.convtranspose3d_blocks = nn.Sequential(
            *(convtranspose3d_block(in_f, out_f, *args, **kwargs) for in_f, out_f in size_pairs)
        )
        convtranspose3d_initialize_sequence(self.convtranspose3d_blocks)

    def forward(self, x):
        upsampled = self.convtranspose3d_blocks(x)
        return upsampled
    
class FinalClassify3d(nn.Module):    # New Version
    """ Chain of conv3d_block modules used as a 3-D classification head.
    Args:   final_sz = 1D-list [final_sz1, final_sz2, ..., final_szN]; each
                       adjacent pair becomes conv3d_block(in_f, out_f, ...)
            *args, **kwargs = passed unchanged to every conv3d_block
    Note:   conv3d_initialize_sequence is intentionally not applied here
            (it is commented out in this version).
    """
    def __init__(self, final_sz, *args, **kwargs):
        super().__init__()
        blocks = []
        for in_f, out_f in zip(final_sz[:-1], final_sz[1:]):
            blocks.append(conv3d_block(in_f, out_f, *args, **kwargs))
        self.classify3d_blocks = nn.Sequential(*blocks)

        #conv3d_initialize_sequence(self.classify3d_blocks)

    def forward(self, x):
        scores = self.classify3d_blocks(x)
        return scores
# Disabled previous implementation of FinalClassify3d, kept as a string literal for reference only.
# NOTE(review): it calls conv3d_initialize_sequence(self.classify3d_blocks) although it only
# defines self.classify_blocks, so it would raise AttributeError if re-enabled as written.
'''    
class FinalClassify3d(nn.Module):    # Old version
    """ Creater a bundle of level
    Args:   final_sz = 1D-list [final_sz1,final_sz2,...,final_szN]
            res = H and W dimension resolution 
            *args, **kwargs = up on fully_conv2d
    """
    def __init__(self, final_sz, *args, **kwargs):
        super().__init__()
        classify_blocks = [ conv3d_block( in_f, out_f, *args, **kwargs) 
                           for in_f , out_f in zip(final_sz,final_sz[1:])]
        self.classify_blocks = nn.Sequential( *classify_blocks )
        conv3d_initialize_sequence(self.classify3d_blocks)
        
    def forward(self, x):
        return self.classify_blocks(x)
'''

class UpsampleHW(nn.Module):
    """ Upsample a 5-D tensor (N, C, D, H, W) along H and W only,
        keeping N, C and D unchanged.
    Args:   scale_factor = multiplier applied to H and W (e.g. 2)
            mode = 2-D interpolation mode understood by nn.Upsample,
                   e.g. 'bilinear' (default) or 'bicubic'
            align_corners = True or False, forwarded to nn.Upsample
    Input:  5-dimensional torch.Tensor (N, C, D, H, W)
    Output: torch.Tensor (N, C, D, H', W') with H', W' given by nn.Upsample
    """
    def __init__(self, scale_factor, mode='bilinear', align_corners=True):
        super(UpsampleHW, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.upsample = nn.Upsample(scale_factor=self.scale_factor,
                                    mode=self.mode,
                                    align_corners=self.align_corners )

    def forward(self, x):
        assert x.dim() == 5 , 'Upsample3d: input should be 5D-torch.tensor'
        N, C, D, H, W = x.size()
        # Fold C and D into one channel axis so a single 2-D upsample handles
        # every (c, d) plane at once. 2-D interpolation treats channels
        # independently, so this equals upsampling each channel separately
        # and stacking, without the repeated torch.cat copies (O(C^2) memory traffic).
        # reshape (not view) also accepts non-contiguous inputs.
        x_up = self.upsample(x.reshape(N, C * D, H, W))
        return x_up.reshape(N, C, D, x_up.size(-2), x_up.size(-1))
        
print('--- END ---')

## Working Model

### Recon2X3D5

In [None]:
# Single-GPU or DataParallel multi-GPU
# Disabled debugging aid (a string literal, not executed): module-level lists,
# presumably for collecting intermediate encoder/decoder/fusion features.
'''
encode_feat1 = list()
encode_feat2 = list()
decode_feat1 = list()
decode_feat2 = list()
fusion_feat = list()
fusionUp_feat = list()
'''
class Recon2X3D5(nn.Module):   # Single GPU, convolution without dilation
    """ Two-view 2D-to-3D reconstruction network with concatenation fusion.

    Each input view (x1, x2) has its own stem Conv2d, dense-connection encoder
    (MyDenseLayer + MaxPool2d) and 2-D decoder (MyDecoder + UpConv2d). At every
    decoder level, deepest first, the two 2-D decoder outputs are reshaped into
    3-D volumes of depth H, the x2 volume is re-oriented with
    transpose(4, 2).flip(4), and both are concatenated (together with the
    up-sampled 3-D result of the previous level) and passed through a 3-D
    decoder (MyDecoder3d) followed by UpConv3d. A 1x1x1 Conv3d with Softmax
    over channels produces the 3-channel output.

    Args:   in_f     = number of input channels of each view
            en_sz    = per encoder level [in_channels, growth rate k, number of dense layers]
            de_sz    = per decoder level channel sizes; the decoder input size
                       (skip + previous decoder channels) is prepended automatically
            de3d_sz  = stored as self.de3d_sz; not used to build layers in this model
            final_sz = per level channel sizes of the 3-D fusion decoders
            *args, **kwargs = accepted but unused
    Note:   input_res is fixed to 256.
    """
    def __init__(self, in_f, en_sz, de_sz, de3d_sz, final_sz, *args, **kwargs):
        super(Recon2X3D5, self).__init__()
        assert len(en_sz) == len(de_sz), \
            f'These input {en_sz} and {de_sz} can not build Recon2X3D5'

        # Feature resolution of every level, e.g. [256, 128, ..., 4] for 7 levels
        input_res = 256
        levels = [round(input_res * 0.5 ** i) for i in range(len(en_sz))]
        self.res1 = levels[::-1]         # coarsest -> finest, one per decoder level
        self.res2 = levels[:-1][::-1]    # res1 without the coarsest level (used by upconv3d)
        self.res3 = levels[:-2][::-1]    # res1 without the two coarsest levels (unused here)
        self.res4 = levels[1:]           # finest -> coarsest without the input level (unused here)

        self.en_sz = en_sz

        # Decoder input channels: the deepest decoder sees only the deepest dense
        # output; every shallower decoder concatenates the matching encoder skip
        # feature with the previous decoder's output (de_sz[j][-1] channels).
        dense_out = [sz[0] + sz[1] * sz[2] for sz in en_sz]   # in_f + k*level per encoder level
        skip_ch = dense_out[::-1][1:]                          # second-deepest -> shallowest
        cat = [dense_out[-1]] + [s + prev[-1] for s, prev in zip(skip_ch, de_sz)]
        self.de_sz = [[c, *sz] for c, sz in zip(cat, de_sz)]

        self.de3d_sz = de3d_sz
        self.final_sz = final_sz

        ############################################
        # Layer flow: MyDenseLayer >> MyDecoder >> UpConv2d >> MyDecoder3d >> UpConv3d >> nn.Conv3d
        # Modules are registered in the same order as before, so parameter order
        # (optimizer state) and random initialization order are unchanged.

        # Stem convolution, one per view
        self.first_layer1 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)
        self.first_layer2 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)

        # Dense-connection encoders, one per view
        self.dense_layer1 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2],
                                                        kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        self.dense_layer2 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2],
                                                        kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        # Down-sampling between encoder levels (the returned indices are discarded)
        self.pool2d = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)

        # 2-D decoders, one stack per view
        self.decode_layer1 = nn.ModuleList([MyDecoder(self.de_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                            for i in range(len(self.de_sz))])
        self.decode_layer2 = nn.ModuleList([MyDecoder(self.de_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                            for i in range(len(self.de_sz))])

        # 2-D up-sampling between decoder levels, one stack per view
        self.upconv2d1 = nn.ModuleList([UpConv2d([self.de_sz[i][-1], self.de_sz[i][-1]], self.res1[i],
                                                 kernel_size=2, stride=2, padding=0)
                                        for i in range(len(self.de_sz))])
        self.upconv2d2 = nn.ModuleList([UpConv2d([self.de_sz[i][-1], self.de_sz[i][-1]], self.res1[i],
                                                 kernel_size=2, stride=2, padding=0)
                                        for i in range(len(self.de_sz))])

        # 3-D fusion decoder per level (input: concatenated view volumes [+ previous level])
        self.final_layer = nn.ModuleList([MyDecoder3d(self.final_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                          for i in range(len(self.final_sz))])

        # 3-D up-sampling between fusion levels, sizes [final_sz[i][-1], 3]
        self.upconv3d = nn.ModuleList([UpConv3d([self.final_sz[i][-1], 3], self.res2[i],
                                                kernel_size=2, stride=2, padding=0)
                                       for i in range(len(self.de_sz) - 1)])

        # Final classification by Conv3d(1x1x1) + Softmax over the 3 output channels
        self.final_layer2 = nn.Sequential(nn.Conv3d(self.final_sz[-1][-1], 3, kernel_size=1, stride=1, padding=0),
                                          nn.Softmax(dim=1))  # Change output channel to {2,3} upto ToTensor version

    def forward(self, x1, x2):
        """ x1, x2 = the two input views, expected (N, in_f, 256, 256) to match
        the hard-coded input_res. Returns the Softmax(dim=1) output of final_layer2.
        """
        # ---- Encoder: keep every level's output for the skip connections
        x1 = self.first_layer1(x1)
        x2 = self.first_layer2(x2)
        encode_trace1, encode_trace2 = [], []
        last_enc = len(self.dense_layer1) - 1
        for i, (layer1, layer2) in enumerate(zip(self.dense_layer1, self.dense_layer2)):
            x1 = layer1(x1)
            x2 = layer2(x2)
            encode_trace1.append(x1)
            encode_trace2.append(x2)
            if i != last_enc:               # no pooling after the deepest level
                x1, _ = self.pool2d(x1)
                x2, _ = self.pool2d(x2)

        # ---- Decoder + 3-D fusion, deepest level first
        # res[i] = number of depth-H volumes packed into the channels of level i
        res = [round(sz[-1] / r) for sz, r in zip(self.de_sz, self.res1)]
        last_dec = len(self.decode_layer1) - 1
        levels = zip(self.decode_layer1, self.decode_layer2, self.final_layer)
        for i, (layer1, layer2, fusionlayer) in enumerate(levels):
            if i == 0:
                x1 = encode_trace1[last_enc - i]
                x2 = encode_trace2[last_enc - i]
            else:   # encoder skip + previous decoder output
                x1 = torch.cat((encode_trace1[last_enc - i], x1), dim=1)
                x2 = torch.cat((encode_trace2[last_enc - i], x2), dim=1)
            _, _, H, W = x1.size()
            x1 = layer1(x1)
            x2 = layer2(x2)

            # Convert to 3-D volumes (5-D tensors); re-orient the second view
            X1 = x1.view(-1, res[i], H, H, W)
            X2 = x2.view(-1, res[i], H, H, W).transpose(4, 2).flip(4)

            # 3-D volume fusion by concatenation (averaging and product-based
            # variants were tried previously and are no longer used)
            if i == 0:
                X = torch.cat((X1, X2), dim=1)
            else:
                X = torch.cat((X1, X2, X), dim=1)
            X = fusionlayer(X)

            if i != last_dec:
                # Each view uses its own up-sampling stack. Earlier revisions applied
                # upconv2d1 to both views, leaving upconv2d2 unused and untrained;
                # checkpoints trained that way carry untrained upconv2d2 weights.
                x1 = self.upconv2d1[i](x1)
                x2 = self.upconv2d2[i](x2)
                X = self.upconv3d[i](X)

        # ---- Final classification
        return self.final_layer2(X)
        
def test_Recon2X3D5():
    """ Smoke test for Recon2X3D5: build the model (wrapped in nn.DataParallel),
    run one autocast forward pass on random (n, 1, 256, 256) inputs, print the
    elapsed time and check the output shape (n, 3, 256, 256, 256).
    Uses the notebook-level `device`.
    """
    torch.cuda.empty_cache()
    # en_sz entry = [in_channels, growth rate k, number of dense layers]
    in_c = 1
    # Averaging AP & LAT features
    '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    final_sz = [[1,32,32],[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]] # 2-fusion layer'''

    # Concatenation AP & LAT features
    en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]   # work
    # NOTE(review): with this de_sz each view contributes 1 volume per level, and
    # assuming UpConv3d([c, 3]) emits 3 channels, levels after the first receive
    # 1 + 1 + 3 = 5 channels while final_sz[i][0] is 3 here; final_sz[2] also has
    # one more entry than the other levels. Confirm these sizes match the model.
    final_sz = [[2,16,16],[3,16,16],[3,16,16,16],[3,16,16],[3,16,16],[3,16,16],[3,16,16]]

    # Concatenation from AP & LAT sub-feature-volumes
    '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4*8],[128,8*8],[128,16*4],[256,32*4],[256,64*2],[256,128*2],[256,256]] 
    final_sz = [[16,32,32],[16+4,32,32],[8+4,32,32],[8+4,16,16],[4+4,16,16],[4+4,16,16],[2+4,16,16]]'''

    de3d_sz = [[8,4],[16,8],[32,16],[64,32],[128,64],[256,128],[512,256]]
    model = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=device)
    model = nn.DataParallel(model)
    n = 3
    x1 = torch.randn(n, 1, 256, 256).to(device=device)
    x2 = torch.randn(n, 1, 256, 256).to(device=device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the timing covers the real work
        if x1.is_cuda:
            torch.cuda.synchronize(x1.device)

    _sync()
    time1 = time.perf_counter()
    # no_grad: a shape/timing smoke test does not need the autograd graph (saves memory)
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        output = model(x1, x2)
    _sync()
    time2 = time.perf_counter()
    print('\n Total running time = {} sec. \n'.format(time2-time1))
    assert output.size()==torch.Size([n,3,256,256,256]) , 'output size error!'
    
#test_Recon2X3D5()
print('--- END ---')

### Recon2X3D6 (Axial + Fusion)

In [None]:
class Recon2X3D6(nn.Module):   # Single GPU, convolution without dilation
    """ Two-view 2D-to-3D reconstruction network with axial + 3-D fusion.

    Same encoder/decoder layout as Recon2X3D5. At each decoder level the
    second view is re-oriented (transpose(3, 1).flip(3)) and concatenated with
    the first view along the channel, H and W axes in turn; each concatenation
    is fused by its own AxialFusion stack (axialD, axialH, axialW). The three
    results are reshaped to 3-D volumes, concatenated with the up-sampled 3-D
    features of the previous level, fused by MyDecoder3d and up-sampled by
    UpConv3d. A 1x1x1 Conv3d with Softmax over channels gives the output.

    Args:   in_f     = number of input channels of each view
            en_sz    = per encoder level [in_channels, growth rate k, number of dense layers]
            de_sz    = per decoder level channel sizes; the decoder input size is prepended automatically
            de3d_sz  = per level channel sizes of the AxialFusion stacks
            final_sz = per level channel sizes of the 3-D fusion decoders
            *args, **kwargs = accepted but unused
    Note:   input_res is fixed to 256.
    """
    def __init__(self, in_f, en_sz, de_sz, de3d_sz, final_sz, *args, **kwargs):
        super(Recon2X3D6, self).__init__()
        assert len(en_sz) == len(de_sz), \
            f'These input {en_sz} and {de_sz} can not build Recon2X3D6'

        # Feature resolution of every level, e.g. [256, 128, ..., 4] for 7 levels
        input_res = 256
        levels = [round(input_res * 0.5 ** i) for i in range(len(en_sz))]
        self.res1 = levels[::-1]         # coarsest -> finest, one per decoder level
        self.res2 = levels[:-1][::-1]    # res1 without the coarsest level (used by upconv3d)
        self.res3 = levels[:-2][::-1]    # res1 without the two coarsest levels (unused here)
        self.res4 = levels[1:]           # finest -> coarsest without the input level (unused here)

        self.en_sz = en_sz

        # Decoder input channels: deepest dense output, then encoder skip +
        # previous decoder output (de_sz[j][-1] channels) for shallower levels.
        dense_out = [sz[0] + sz[1] * sz[2] for sz in en_sz]   # in_f + k*level per encoder level
        skip_ch = dense_out[::-1][1:]                          # second-deepest -> shallowest
        cat = [dense_out[-1]] + [s + prev[-1] for s, prev in zip(skip_ch, de_sz)]
        self.de_sz = [[c, *sz] for c, sz in zip(cat, de_sz)]
        self.de3d_sz = de3d_sz
        self.final_sz = final_sz

        ############################################
        # Layer flow: MyDenseLayer >> MyDecoder >> UpConv2d >> AxialFusion >> MyDecoder3d >> UpConv3d >> nn.Conv3d
        # Modules are registered in the same order as before, so parameter order
        # (optimizer state) and random initialization order are unchanged.

        # Stem convolution, one per view
        self.first_layer1 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)
        self.first_layer2 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)

        # Dense-connection encoders, one per view
        self.dense_layer1 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2],
                                                        kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        self.dense_layer2 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2],
                                                        kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        # Down-sampling between encoder levels (the returned indices are discarded)
        self.pool2d = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)

        # 2-D decoders, one stack per view
        self.decode_layer1 = nn.ModuleList([MyDecoder(self.de_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                            for i in range(len(self.de_sz))])
        self.decode_layer2 = nn.ModuleList([MyDecoder(self.de_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                            for i in range(len(self.de_sz))])

        # 2-D up-sampling between decoder levels, one stack per view
        self.upconv2d1 = nn.ModuleList([UpConv2d([self.de_sz[i][-1], self.de_sz[i][-1]], self.res1[i],
                                                 kernel_size=2, stride=2, padding=0)
                                        for i in range(len(self.de_sz))])
        self.upconv2d2 = nn.ModuleList([UpConv2d([self.de_sz[i][-1], self.de_sz[i][-1]], self.res1[i],
                                                 kernel_size=2, stride=2, padding=0)
                                        for i in range(len(self.de_sz))])

        # Axial fusion of the two views along D (channels), H and W
        self.axialD = nn.ModuleList([AxialFusion(self.de3d_sz[i], self.res1[i], kernel_size=3, stride=1, padding=1)
                                     for i in range(len(self.de3d_sz))])
        self.axialH = nn.ModuleList([AxialFusion(self.de3d_sz[i], self.res1[i], kernel_size=3, stride=1, padding=1)
                                     for i in range(len(self.de3d_sz))])
        self.axialW = nn.ModuleList([AxialFusion(self.de3d_sz[i], self.res1[i], kernel_size=3, stride=1, padding=1)
                                     for i in range(len(self.de3d_sz))])

        # 3-D fusion decoder per level
        self.final_layer = nn.ModuleList([MyDecoder3d(self.final_sz[i], self.res1[i],
                                                      kernel_size=3, stride=1, padding=1)
                                          for i in range(len(self.final_sz))])

        # 3-D up-sampling between fusion levels, sizes [final_sz[i][-1], 3]
        self.upconv3d = nn.ModuleList([UpConv3d([self.final_sz[i][-1], 3], self.res2[i],
                                                kernel_size=2, stride=2, padding=0)
                                       for i in range(len(self.de_sz) - 1)])

        # Final classification by Conv3d(1x1x1) + Softmax over the 3 output channels
        self.final_layer2 = nn.Sequential(nn.Conv3d(self.final_sz[-1][-1], 3, kernel_size=1, stride=1, padding=0),
                                          nn.Softmax(dim=1))  # Change output channel to {2,3} upto ToTensor version

    def forward(self, x1, x2):
        """ x1, x2 = the two input views, expected (N, in_f, 256, 256) to match
        the hard-coded input_res. Returns the Softmax(dim=1) output of final_layer2.
        """
        # ---- Encoder: keep every level's output for the skip connections
        x1 = self.first_layer1(x1)
        x2 = self.first_layer2(x2)
        encode_trace1, encode_trace2 = [], []
        last_enc = len(self.dense_layer1) - 1
        for i, (layer1, layer2) in enumerate(zip(self.dense_layer1, self.dense_layer2)):
            x1 = layer1(x1)
            x2 = layer2(x2)
            encode_trace1.append(x1)
            encode_trace2.append(x2)
            if i != last_enc:               # no pooling after the deepest level
                x1, _ = self.pool2d(x1)
                x2, _ = self.pool2d(x2)

        # ---- Decoder + axial fusion + 3-D fusion, deepest level first
        # res[i] = number of depth-H volumes packed into the channels of level i
        res = [round(sz[-1] / r) for sz, r in zip(self.de_sz, self.res1)]
        last_dec = len(self.decode_layer1) - 1
        levels = zip(self.decode_layer1, self.decode_layer2,
                     self.axialD, self.axialH, self.axialW, self.final_layer)
        for i, (layer1, layer2, axialD, axialH, axialW, fusion3d) in enumerate(levels):
            if i == 0:
                x1 = encode_trace1[last_enc - i]
                x2 = encode_trace2[last_enc - i]
            else:   # encoder skip + previous decoder output
                x1 = torch.cat((encode_trace1[last_enc - i], x1), dim=1)
                x2 = torch.cat((encode_trace2[last_enc - i], x2), dim=1)
            _, _, H, W = x1.size()
            x1 = layer1(x1)
            x2 = layer2(x2)

            # Re-oriented second view: N,C,H,W -> N,W,H,C, flipped on the last axis.
            # Computed once and shared by the three axial branches; concatenating it
            # with x1 requires C == W.
            x2_rot = x2.transpose(3, 1).flip(3)
            # Shape comments assume each AxialFusion maps 2C -> C (as in test_Recon2X3D6)
            xd = axialD(torch.cat((x1, x2_rot), dim=1))                                  # N,2C,H,W -> N,C,H,W
            xh = axialH(torch.cat((x1, x2_rot), dim=2).transpose(2, 1)).transpose(2, 1)  # N,C,2H,W -> N,2H,C,W -> N,H,C,W -> N,C,H,W
            xw = axialW(torch.cat((x1, x2_rot), dim=3).transpose(3, 1)).transpose(3, 1)  # N,C,H,2W -> N,2W,H,C -> N,W,H,C -> N,C,H,W

            # Convert to 3-D volumes; reshape (not view) because xh / xw are
            # transposed, non-contiguous tensors
            XD = xd.reshape(-1, res[i], H, H, W)
            XH = xh.reshape(-1, res[i], H, H, W)
            XW = xw.reshape(-1, res[i], H, H, W)

            # 3-D fusion by concatenation (averaging variants were tried previously)
            if i == 0:
                X = torch.cat((XD, XH, XW), dim=1)
            else:
                X = torch.cat((XD, XH, XW, X), dim=1)
            X = fusion3d(X)

            if i != last_dec:
                # Each view uses its own up-sampling stack. Earlier revisions applied
                # upconv2d1 to both views, leaving upconv2d2 unused and untrained;
                # checkpoints trained that way carry untrained upconv2d2 weights.
                x1 = self.upconv2d1[i](x1)
                x2 = self.upconv2d2[i](x2)
                X = self.upconv3d[i](X)

        # ---- Final classification
        return self.final_layer2(X)
        
def test_Recon2X3D6():
    """Smoke-test Recon2X3D6 on a random view pair and report its forward time.

    Builds the model with the layer sizes used for the experiments, runs one
    mixed-precision forward pass on two random (n, 1, 256, 256) inputs and
    asserts that the output is (n, 3, 256, 256, 256).
    Relies on the notebook-level ``device`` and on ``Recon2X3D6`` from an
    earlier cell.
    """
    torch.cuda.empty_cache()
    # Axial + Fusion of AP & LAT features
    in_c = 1
    en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    de3d_sz = [[8,8,4],[16,16,8],[32,32,16],[64,64,32],[128,128,64],[256,256,128],[512,512,256]] # 2-axial layer per level
    final_sz = [[3,32,32],[6,32,32],[6,32,32],[6,16,16],[6,16,16],[6,16,16],[6,16,16]] # 2-fusion layer per level
    model = Recon2X3D6(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=device)
    n = 1
    x1 = torch.randn(n, 1, 256, 256).to(device=device)
    x2 = torch.randn(n, 1, 256, 256).to(device=device)
    time1 = time.time()
    with torch.cuda.amp.autocast(enabled=True):
        output = model(x1, x2)
    # CUDA kernels run asynchronously: wait for them so the measured time covers
    # the actual forward pass and not only the kernel launches.
    if x1.is_cuda:
        torch.cuda.synchronize()
    time2 = time.time()
    assert output.size() == torch.Size([n, 3, 256, 256, 256]), 'output size error!'
    print('\n Total running time = {} sec. \n'.format(time2 - time1))


#test_Recon2X3D6()
print('--- END ---')

### Recon2X3D7 (Axial+Inception+Fusion)

In [None]:
class Recon2X3D7(nn.Module):
    """Axial + Inception + Fusion reconstruction of a 3-channel volume from two 2D views.

    Two parallel dense 2D encoder/decoder branches process ``x1`` and ``x2``
    (the notebook calls them the AP and LAT features). At every decoder level
    the two decoded 2D maps are concatenated along the D, H and W axes. Each
    concatenation goes through an inception-style block (3x3 max-pool plus 1x1,
    3x3 and 5x5 ``AxialFusion`` branches), and a 1x1 ``AxialFusion`` merges the
    branches. The three results are viewed as 3D volumes, concatenated with the
    upsampled 3D volume of the previous level, and fused by ``MyDecoder3d``.
    A final 1x1x1 Conv3d followed by a softmax over dim 1 produces 3 channels.

    Args:
        in_f: number of channels of each input image.
        en_sz: per-level ``MyDenseLayer`` settings ``[c, a, b]``. The decoder
            input widths are derived as ``c + a * b`` for each level.
        de_sz: per-level ``MyDecoder`` settings. The concatenated input width
            (encoder skip + previous decoder output) is prepended automatically
            to form ``self.de_sz``. Must have the same length as ``en_sz``.
        de3d_sz: per-level ``AxialFusion`` settings. The 1/3/5 branches use
            ``de3d_sz[i][0:2]`` and the merge layer uses
            ``[5 * de3d_sz[i][1], de3d_sz[i][-1]]``. NOTE(review): this
            presumably requires ``de3d_sz[i][0] == 2 * de3d_sz[i][1]``
            (pool branch + three axial branches) - confirm against AxialFusion.
        final_sz: per-level ``MyDecoder3d`` settings. ``final_sz[-1][-1]`` is
            the input width of the final 1x1x1 Conv3d.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_f, en_sz, de_sz, de3d_sz, final_sz, *args, **kwargs):
        super().__init__()
        assert len(en_sz) == len(de_sz), \
            f'These input {en_sz} and {de_sz} can not build Recon2X3D7'

        # Feature resolution per level; the model is built for 256x256 inputs.
        input_res = 256
        full_res = [round(input_res * 0.5 ** i) for i in range(len(en_sz))]  # finest -> coarsest
        self.res1 = full_res[::-1]        # coarsest -> finest, one entry per decoder level
        self.res2 = full_res[:-1][::-1]   # same, without the coarsest level
        self.res3 = full_res[:-2][::-1]   # same, without the two coarsest levels
        self.res4 = full_res[1:]          # finest -> coarsest, without the finest level
        print('self.res1 = {}'.format(self.res1))
        print('self.res2 = {}'.format(self.res2))
        print('self.res3 = {}'.format(self.res3))
        print('self.res4 = {}\n'.format(self.res4))

        self.en_sz = en_sz
        # Decoder input width per level, coarsest first. Level 0 only receives the
        # deepest encoder output. Level i > 0 receives the encoder skip of the
        # matching resolution concatenated with the previous decoder output.
        enc_out = [x[0] + x[1] * x[2] for x in en_sz][::-1]
        cat = [enc_out[0]] + [skip + prev[-1] for skip, prev in zip(enc_out[1:], de_sz)]
        self.de_sz = [[c, *sz] for c, sz in zip(cat, de_sz)]
        self.de3d_sz = de3d_sz
        print('self.de3d_sz = {}\n'.format(self.de3d_sz))
        self.final_sz = final_sz

        ############################################
        # Layer order: MyDenseLayer >> MyDecoder >> UpConv2d >> AxialFusion >> MyDecoder3d >> UpConv3d >> nn.Conv3d
        # Starting layers, one per view
        self.first_layer1 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)
        self.first_layer2 = nn.Conv2d(in_f, en_sz[0][0], kernel_size=3, stride=1, padding=1)

        # Dense-connection encoder layers
        self.dense_layer1 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2], kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        self.dense_layer2 = nn.ModuleList([MyDenseLayer(sz[0], sz[1], sz[2], kernel_size=3, stride=1, padding=1)
                                           for sz in en_sz])
        # 2x2 max-pooling between encoder levels (indices are returned but unused)
        self.pool2d = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)

        # 2D decoder layers
        self.decode_layer1 = nn.ModuleList([MyDecoder(sz, r, kernel_size=3, stride=1, padding=1)
                                            for sz, r in zip(self.de_sz, self.res1)])
        self.decode_layer2 = nn.ModuleList([MyDecoder(sz, r, kernel_size=3, stride=1, padding=1)
                                            for sz, r in zip(self.de_sz, self.res1)])

        # 2D upsampling between decoder levels, one list per view
        self.upconv2d1 = nn.ModuleList([UpConv2d([sz[-1], sz[-1]], r, kernel_size=2, stride=2, padding=0)
                                        for sz, r in zip(self.de_sz, self.res1)])
        self.upconv2d2 = nn.ModuleList([UpConv2d([sz[-1], sz[-1]], r, kernel_size=2, stride=2, padding=0)
                                        for sz, r in zip(self.de_sz, self.res1)])

        # Axial inception fusion
        self.inceptPool2d = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        n_levels_3d = len(self.de3d_sz)

        def axial_branch(kernel_size, padding):
            # One AxialFusion per level built from de3d_sz[i][0:2].
            return nn.ModuleList([AxialFusion(self.de3d_sz[i][0:2], self.res1[i],
                                              kernel_size=kernel_size, stride=1, padding=padding)
                                  for i in range(n_levels_3d)])

        def axial_merge():
            # 1x1 AxialFusion merging the concatenated inception branches.
            return nn.ModuleList([AxialFusion([self.de3d_sz[i][1] * 5, self.de3d_sz[i][-1]], self.res1[i],
                                              kernel_size=1, stride=1, padding=0)
                                  for i in range(n_levels_3d)])

        self.axialD1 = axial_branch(1, 0)
        self.axialH1 = axial_branch(1, 0)
        self.axialW1 = axial_branch(1, 0)
        self.axialD3 = axial_branch(3, 1)
        self.axialH3 = axial_branch(3, 1)
        self.axialW3 = axial_branch(3, 1)
        self.axialD5 = axial_branch(5, 2)
        self.axialH5 = axial_branch(5, 2)
        self.axialW5 = axial_branch(5, 2)
        self.axialFusionD = axial_merge()
        self.axialFusionH = axial_merge()
        self.axialFusionW = axial_merge()

        # 3D fusion per level
        self.final_layer = nn.ModuleList([MyDecoder3d(self.final_sz[i], self.res1[i], kernel_size=1, stride=1, padding=0)
                                          for i in range(len(self.final_sz))])

        # 3D upsampling between fusion levels (to a single volume)
        self.upconv3d = nn.ModuleList([UpConv3d([self.final_sz[i][-1], 3], self.res2[i],
                                                kernel_size=2, stride=2, padding=0)
                                       for i in range(len(self.de_sz) - 1)])

        # Final classification by Conv3d(1x1x1) + softmax over the channel axis
        self.final_layer2 = nn.Sequential(nn.Conv3d(self.final_sz[-1][-1], 3, kernel_size=1, stride=1, padding=0),
                                          nn.Softmax(dim=1))  # Change output channel to {2,3} upto ToTensor version

    def forward(self, x1, x2):
        """Reconstruct the volume from two 2D views.

        Args:
            x1: first view batch (N, in_f, 256, 256).
            x2: second view batch, same shape as ``x1``.

        Returns:
            Softmax probabilities with 3 channels along dim 1.
        """
        # --- Encoder: two parallel dense 2D encoders ---
        x1 = self.first_layer1(x1)
        x2 = self.first_layer2(x2)
        encode_trace1 = []
        encode_trace2 = []
        last_encoder = len(self.dense_layer1) - 1
        for i, (layer1, layer2) in enumerate(zip(self.dense_layer1, self.dense_layer2)):
            x1 = layer1(x1)
            x2 = layer2(x2)
            encode_trace1.append(x1)   # kept for the decoder skip connections
            encode_trace2.append(x2)
            if i != last_encoder:
                x1, _ = self.pool2d(x1)
                x2, _ = self.pool2d(x2)

        # --- Decoder and axial inception fusion, coarsest level first ---
        # Number of 3D channels obtained when a 2D map is viewed as a volume.
        res = [round(sz[-1] / r) for sz, r in zip(self.de_sz, self.res1)]
        last_decoder = len(self.decode_layer1) - 1
        levels = zip(self.decode_layer1, self.decode_layer2,
                     self.axialD1, self.axialH1, self.axialW1,
                     self.axialD3, self.axialH3, self.axialW3,
                     self.axialD5, self.axialH5, self.axialW5,
                     self.axialFusionD, self.axialFusionH, self.axialFusionW,
                     self.final_layer)
        for i, (layer1, layer2,
                axialD1, axialH1, axialW1,
                axialD3, axialH3, axialW3,
                axialD5, axialH5, axialW5,
                axialFuseD, axialFuseH, axialFuseW,
                fusion3dlayer) in enumerate(levels):
            if i == 0:
                x1 = encode_trace1[-1]
                x2 = encode_trace2[-1]
            else:   # encoder skip + previous decoder output
                x1 = torch.cat((encode_trace1[-1 - i], x1), dim=1)
                x2 = torch.cat((encode_trace2[-1 - i], x2), dim=1)
            _, _, H, W = x1.size()
            x1 = layer1(x1)
            x2 = layer2(x2)

            # Axial concatenation along D, H and W. The second view is transposed
            # (dims 1 <-> 3) and flipped along dim 3 before every concatenation.
            x2_t = x2.transpose(3, 1).flip(3)
            xd = torch.cat((x1, x2_t), dim=1)                     # N,2C,H,W
            xh = torch.cat((x1, x2_t), dim=2).transpose(2, 1)     # N,C,2H,W >> N,2H,C,W
            xw = torch.cat((x1, x2_t), dim=3).transpose(3, 1)     # N,C,H,2W >> N,2W,H,C

            # Inception: max-pool + 1x1 / 3x3 / 5x5 axial branches
            xd = torch.cat([self.inceptPool2d(xd), axialD1(xd), axialD3(xd), axialD5(xd)], dim=1)
            xh = torch.cat([self.inceptPool2d(xh), axialH1(xh), axialH3(xh), axialH5(xh)], dim=1)
            xw = torch.cat([self.inceptPool2d(xw), axialW1(xw), axialW3(xw), axialW5(xw)], dim=1)
            xd = axialFuseD(xd)
            xh = axialFuseH(xh).transpose(2, 1)
            xw = axialFuseW(xw).transpose(3, 1)

            # Convert to 3D volumes
            XD = xd.view(-1, res[i], H, H, W)
            XH = xh.view(-1, res[i], H, H, W)
            XW = xw.view(-1, res[i], H, H, W)

            parts = (XD, XH, XW) if i == 0 else (XD, XH, XW, X)
            X = fusion3dlayer(torch.cat(parts, dim=1))

            if i != last_decoder:
                x1 = self.upconv2d1[i](x1)
                # The second view has its own upsampling weights. Checkpoints trained
                # before this fix reused upconv2d1 here and left upconv2d2 untrained.
                x2 = self.upconv2d2[i](x2)
                X = self.upconv3d[i](X)

        return self.final_layer2(X)
        
def test_Recon2X3D7():
    """Smoke-test Recon2X3D7 on a random view pair and report its forward time.

    Prints the model, runs one mixed-precision forward pass on two random
    (n, 1, 256, 256) inputs and asserts that the output is (n, 3, 256, 256, 256).
    Relies on the notebook-level ``device`` and on ``Recon2X3D7``.
    """
    torch.cuda.empty_cache()
    # AxialInceptionFusion of AP & LAT features
    in_c = 1
    en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4]]
    de_sz = [[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    de3d_sz = [[16,8,8],[32,16,16],[64,32,32],[128,64,64],[256,128,128],[512,256,256]] # 2-axial layer per level
    final_sz = [[3,16,16],[6,16,16],[6,16,16],[6,16,16],[6,16,16],[6,16,16]] # 2-fusion layer per level
    model = Recon2X3D7(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=device)
    print(model, '\n')
    n = 1
    x1 = torch.randn(n, 1, 256, 256).to(device=device)
    x2 = torch.randn(n, 1, 256, 256).to(device=device)
    time1 = time.time()
    with torch.cuda.amp.autocast(enabled=True):
        output = model(x1, x2)
    # CUDA kernels run asynchronously: wait for them so the measured time covers
    # the actual forward pass and not only the kernel launches.
    if x1.is_cuda:
        torch.cuda.synchronize()
    time2 = time.time()
    assert output.size() == torch.Size([n, 3, 256, 256, 256]), 'output size error!'
    print('\n Total running time = {} sec. \n'.format(time2 - time1))


test_Recon2X3D7()
print('--- END ---')

### Discriminator3D for GAN

In [None]:
class Discriminator3D(nn.Module):
    """3D convolutional GAN discriminator producing one sigmoid score per sample.

    Seven stages of Conv3d(3x3x3, no bias) -> [InstanceNorm3d] -> LeakyReLU(0.2)
    -> MaxPool3d(2) shrink a 256^3 volume down to 2^3. The first stage has no
    InstanceNorm. Stage widths are ndf, 2ndf, 4ndf, 8ndf, 8ndf, 16ndf, 32ndf.
    The remaining features are folded into the channel axis and scored by two
    1x1x1 convolutions followed by a sigmoid.

    Args:
        nc: number of input channels.
        ndf: base number of feature maps.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        self.nc = nc
        self.ndf = ndf

        stage_widths = [ndf, ndf * 2, ndf * 4, ndf * 8, ndf * 8, ndf * 16, ndf * 32]
        modules = []
        channels_in = nc
        for stage, channels_out in enumerate(stage_widths):
            modules.append(nn.Conv3d(channels_in, channels_out, kernel_size=3, stride=1, padding=1, bias=False))
            if stage > 0:
                modules.append(nn.InstanceNorm3d(channels_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
            modules.append(nn.MaxPool3d(kernel_size=2, stride=2))
            channels_in = channels_out
        self.main = nn.Sequential(*modules)

        # A 256^3 input leaves a 2x2x2 map after seven poolings; the 1x1x1
        # convolutions act as fully connected layers on the flattened features.
        flattened = stage_widths[-1] * 2 * 2 * 2
        self.classify = nn.Sequential(
            nn.Conv3d(flattened, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Conv3d(512, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return scores of shape (batch, 1, 1, 1, 1) for a (batch, nc, D, H, W) input."""
        batch = input.size(0)
        features = self.main(input)
        folded = features.reshape(batch, -1, 1, 1, 1)
        return self.classify(folded)

def test_Discriminator3D():
    """Smoke-test Discriminator3D on a random batch and report its forward time.

    Runs one mixed-precision forward pass on an (8, 3, 256, 256, 256) random
    input and asserts one score per sample, shaped (8, 1, 1, 1, 1).
    Relies on the notebook-level ``device``.
    """
    #torch.cuda.empty_cache()
    n = 8
    nc = 3
    ndf = 32
    model = Discriminator3D(nc, ndf).to(device=device)
    volume = torch.randn(n, nc, 256, 256, 256).to(device=device)
    print('input = {}'.format(volume.size()))
    time1 = time.time()
    with torch.cuda.amp.autocast(enabled=True):
        output = model(volume)
        print('output = {}'.format(output.size()))
    # CUDA kernels run asynchronously: wait for them so the measured time covers
    # the actual forward pass and not only the kernel launches.
    if volume.is_cuda:
        torch.cuda.synchronize()
    time2 = time.time()
    print('\n Total running time = {} sec. \n'.format(time2 - time1))
    assert output.size() == torch.Size([n, 1, 1, 1, 1]), 'output size error!'


#test_Discriminator3D()
print('--- END ---')

## Training Loop

### Mixed Precision - Simple and DataParallel (Version 2)

In [None]:
# Placeholders for the training objects; the cells below rebind them.
model = optimizer = scheduler = criterion = None

def train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, batch_sz, epochs, saved_name, saved_dict):
    ''' Mixed precision training Version2 
        Build-in model saver, saved_dict
    '''
    print('Train on: {}'.format(device))
    print('Optimizer = {} \n'.format(optimizer))
    scaler = torch.cuda.amp.GradScaler()
    for e in range(epochs):
        time1 = time.time()
        print('----- Epoch = {} ----- # Learning rate = {:.4e}'.format(e+1, optimizer.param_groups[0]["lr"]))
        train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0   # reset every epoch
        
        for t, train_sample in enumerate(trainLoader):
            #print('trainLoader = {}'.format(len(trainLoader)))
            #print('batch size = {} \n'.format(train_sample['Target'].size()))
            if train_sample['Target'].size(0)%batch_sz != 0 or t==len(trainLoader)-1:  # exclude inequal batch
                print('Final training accuracy = {:.4f}'.format(acc))
                break
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):
                model.train(mode=True)    # put model to training mode
                target = train_sample['Target'].to(device=device, dtype=dtype1)
                view1 = train_sample['view1'].to(device=device, dtype=dtype1)
                view2 = train_sample['view2'].to(device=device, dtype=dtype1)
                #print('target autocast = {}   {}\nap autocast = {}   {}'.format(target.size(),target.dtype, view1.size(),view1.dtype))
                
                # calculation of output
                output = model(view1, view2)
                #print('output autocast 1 = {}   {}\n'.format(output.size(),output.dtype))
                
                # Calculate loss of this sample batch
                loss = criterion(output, target.long())          # for ToTensor6-9 with FocalLossMulticlass
                #print('loss = {}   {}'.format(loss,loss.dtype))
                train_loss += loss.item()
                
                # Check accuracy
                acc = iou((output[:,1]>=0.5).float(), (target==1).float())                       # for ToTensor6-9
                #acc2 = iou((output[:,2]>=0.5).float(), (target==2).float())                      # for ToTensor6-9
                #acc2 = hausdorff_voxel((output[:,2]>=0.5).float(), (target==2).float())          # for ToTensor6-9
                train_acc += acc.detach().item()
                
                if t%(round(len(trainLoader)/5)) == 0:
                    #print('Iteration: {}   |   Loss = {:.4f}   |   Accuracy = {:.4f} {:.4f}'.format(t, loss.item(), acc, acc2))
                    print('Iteration: {}   |   Loss = {:.4f}   |   Accuracy = {:.4f}'.format(t, loss.item(), acc))
                if t==len(trainLoader)-1:
                    print('Final training accuracy = {:.4f}'.format(acc))
                
            scaler.scale(loss).backward()
            torch.cuda.synchronize()
            scaler.step(optimizer)
            scaler.update()
            
        # Append Loss and Accuracy history
        train_loss /= round(len(trainLoader))    # mean
        train_acc /= round(len(trainLoader))     # mean
        saved_dict['train_loss_history'].append(train_loss)
        saved_dict['train_acc_history'].append(train_acc)
        print('Training loss = {:.4f}'.format(train_loss))
        print('Training accuracy = {:.4f}'.format(train_acc))
        print('Max. of [Mean Training accuracy] = {:.4f}'.format(max(saved_dict['train_acc_history'])))
        time2 = time.time()
        print('Duration training time = {} Min.\n'.format((time2-time1)/60))
        
###########################################################################################################################
        
        print('### Validation loop ###')
        model.eval()
        time3 = time.time()
        
        for t, val_sample in enumerate(valLoader):
            if val_sample['Target'].size(0)%batch_sz%batch_sz != 0 or t==len(valLoader)-1: # exclude inequal batch
                print('Final validation accuracy = {:,.4f}'.format(acc))
                break
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    target = val_sample['Target'].to(device=device, dtype=dtype1)
                    view1 = val_sample['view1'].to(device=device, dtype=dtype1)
                    view2 = val_sample['view2'].to(device=device, dtype=dtype1)
                    output = model(view1,view2)
                    
                    loss = criterion(output ,target.long())         # for ToTensor6-9
                    val_loss += loss.item()
                    
                    acc = iou((output[:,1]>=0.5).float(), (target==1).float())                        # for ToTensor6-9
                    #acc2 = iou((output[:,2]>=0.5).float(), (target==2).float())                       # for ToTensor6-9
                    #acc2 = hausdorff_voxel((output[:,2]>=0.5).float(), (target==2).float())           # for ToTensor6-9
                    val_acc += acc.detach().item()
                    
                    if t%(round(len(valLoader)/5)) == 0:
                        #print('Iteration: {}   |   Loss = {:,.4f}   |   Accuracy = {:.4f} {:.4f}'.format(t, loss.item(), acc, acc2))
                        print('Iteration: {}   |   Loss = {:.4f}   |   Accuracy = {:.4f}'.format(t, loss.item(), acc))
                    if t==len(valLoader)-1:
                        print('Final validation accuracy = {:,.4f}'.format(acc))
        
        val_loss /= round(len(valLoader))    # mean
        val_acc /= round(len(valLoader))     # mean
        saved_dict['val_loss_history'].append(val_loss)
        saved_dict['val_acc_history'].append(val_acc)
        saved_dict['timestamp'] = str(datetime.datetime.now())     # update timestamp
        torch.cuda.synchronize()
        scheduler.step(val_acc)
        print('# Validation loss = {:.4f}'.format(val_loss))
        print('Validation accuracy = {:.4f}'.format(val_acc))
        print('Max. of [Mean Validation accuracy] = {:.4f}'.format(max(saved_dict['val_acc_history'])))
        time4 = time.time()
        print('Duration validation time = {} Min. \n'.format((time4-time3)/60))
        
        ### Save the best trained model and training history ###
        if val_acc >= max(saved_dict['val_acc_history']):   # save the best trained
            print(' *** Update the best model state dict at epoch: {}  at time: {} ***'.format(e+1, str(datetime.datetime.now())))
            saved_dict['model_state_dict'] = model.module.state_dict()    # update the best model
        else:   # just save training history and keep the best trained model
            print('\n *** Update training history and the last model state dict at epoch = {}  at time: {} ***\n'.format(e+1, str(datetime.datetime.now())))
            # don't update any params in the trained model
            
        ### End of training ###
        ### Save trained parameters at the last epoch and training history, keeping the best trained model   
        saved_dict['optimizer_state_dict'] = optimizer.state_dict()
        saved_dict['scheduler_state_dict'] = scheduler.state_dict()
        torch.save(saved_dict, saved_name)
        print(' *** Saved End ***\n')
        print('-'*120,'\n\n')
    
    return saved_dict
    
print('--- END ---')

In [None]:
### K-fold cross validation ###
# Trains one Recon2X3D5 per fold on the fold's own training sheet, evaluates it on
# the fold's testing sheet, and saves each fold's history/best weights via train_mixed.

# Training and validation loop
num_workers = 4
history = {'train_loss': [], 'test_loss': [],'train_acc':[],'test_acc':[]}
batch_sz = 3
learning_rate = 1e-4             # default = 5e-4
weight = None
epochs = 10

# Select mode
toTensorMode = int(input('Select ToTensor : [8] Auxiliary class  [9] Native = '))
if toTensorMode == 8:
    print('ToTensor8 selected')
    ToTensor = ToTensor8()
    weight = torch.tensor([0.15,0.25,0.6], device=device)     # For ToTensor8
elif toTensorMode == 9:
    print('Totensor9 selected')
    ToTensor = ToTensor9()
    weight = torch.tensor([0.5,0.5], device=device)           # For ToTensor9
elif toTensorMode < 8:
    print('Totensor1-7 selected  change NormalizeSample() and FemurDataset')
    ToTensor = ToTensor7()
else:
    raise ValueError('Invalid ToTensor input')
print('ToTensor = {}'.format(ToTensor))
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')

root_dir = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2'

for fold in range(1, 6):

    # Preparing dataset: every fold has its own training and testing sheet
    saved_name = r'trained\FracReconNet_Fold{}.pt'.format(fold)
    train_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\With augmentation - Fold{} Training.xlsx'.format(fold)
    test_file = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\With augmentation - Fold{} Testing.xlsx'.format(fold)
    print('Training on: {}'.format(train_file))
    print('Testing on: {}'.format(test_file))
    train_transformedFemur = FemurDataset2(csv_file=train_file, root_dir=root_dir,
                                           transform=transforms.Compose([NormalizeSample2(), ToTensor]))
    test_transformedFemur = FemurDataset2(csv_file=test_file, root_dir=root_dir,
                                          transform=transforms.Compose([NormalizeSample2(), ToTensor]))
    trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=num_workers)
    testLoader = DataLoader(test_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=num_workers)

    # Building model
    in_c = 1
    en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    de3d_sz = None
    final_sz = [[2,32,32],[5,32,32],[5,32,32],[5,16,16],[5,16,16],[5,16,16],[5,16,16]]
    # NOTE(review): `note` is not defined in this cell; it must come from an earlier cell.
    saved_dict = {'timestamp':None ,'note':note ,'in_c':in_c ,'en_sz':en_sz,'de_sz':de_sz,'de3d_sz':de3d_sz ,'final_sz':final_sz,
                  'train_loss_history':list(),'train_acc_history':list(),'val_loss_history':list(),'val_acc_history':list(),
                  'model_state_dict':None,'optimizer_state_dict':None,'scheduler_state_dict':None }
    model = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3,
                                  threshold_mode='rel', cooldown=0, min_lr=1e-12, verbose=True)

    # Training and Testing
    timeT1 = time.time()
    saved_dict = train_mixed(model, trainLoader, testLoader, optimizer, scheduler, criterion, batch_sz, epochs, saved_name, saved_dict)
    timeT2 = time.time()
    print('Save model = {}'.format(saved_name))
    print('Total training time = {} hours'.format((timeT2-timeT1)/3600))
    print('############################################################### Clear parameters \n\n')
    # `model` is not deleted: the next fold rebinds it (releasing this network once
    # nothing else references it), and the last fold's network is saved below.
    del train_transformedFemur, test_transformedFemur, trainLoader, testLoader, optimizer, scheduler, saved_dict

torch.save(model,'k_cross_CNN.pt')

### Generative and Adversarial Training - Mixed Precision

#### GAN without Autocast

In [None]:
### Without Autocast ###
# Dataset
batch_sz = 1
train_transformedFemur = FemurDataset2(csv_file=training_file, root_dir=root_dir, 
                                      transform=transforms.Compose([NormalizeSample2(),ToTensor8()]))
val_transformedFemur = FemurDataset2(csv_file=val_file, root_dir=root_dir, 
                                    transform=transforms.Compose([NormalizeSample2(),ToTensor8()]))
trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=0)
valLoader = DataLoader(val_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=0)
print('   Training set : {}'.format(len(trainLoader)))
print('   Validation set : {}'.format(len(valLoader)))

# Define the models
nc = 3    # channels of the one-hot target volume fed to D (labels 0/1/2)
ndf = 4   # ndf should be higher than nc
netD = Discriminator3D(nc, ndf).to(device=device)
# Previous from-scratch generator configuration (disabled):
# in_c = 1
# en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
# de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
# de3d_sz = None
# final_sz = [[2,32,32] , [5,32,32] , [5,32,32] , [5,16,16] , [5,16,16] , [5,16,16] , [5,16,16]]
# netG = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=device)
# NOTE(review): the generator is the global `model` from another cell; the
# supervised-training loop ends each pass with `del model`, so re-create it first.
netG = model.to(device=device)

# Wrap both networks in nn.DataParallel (netD first, then netG).
netD, netG = nn.DataParallel(netD), nn.DataParallel(netG)

# Training setup
real_label, fake_label = 1., 0.   # BCE targets for real / generated volumes
lr = 1e-4      # default = 0.0002
beta1 = 0.5    # default = 0.5
# Reconstruction-loss options, not used in this cell:
#weight = torch.tensor([0.15,0.25,0.6], device=device)     # For ToTensor8
#weight = torch.tensor([0.5,0.5], device=device)           # For ToTensor9
#criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')
criterion_gan = nn.BCELoss()   # requires D outputs to be probabilities in [0, 1]
#criterion_gan = nn.BCEWithLogitsLoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
epochs = 1

# Progress history, appended to by the training loop
img_list, Recon_losses = [], []
G_losses, D_losses = [], []
D_Gx1, D_Gx2, D_X, IoU = [], [], [], []
acc, iters = 0, 0

print("Starting Training Loop...")
t1 = time.time()
# Print stats ~50 times per epoch. max(1, ...) prevents `i % 0`
# (ZeroDivisionError) when the loader has fewer than 25 batches.
log_every = max(1, round(len(trainLoader) / 50))
for epoch in range(epochs):
    for i, train_sample in enumerate(trainLoader, 0):
        netD.train(mode=True)
        netG.train(mode=True)
        ##################################################################################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        netD.zero_grad()
        optimizerD.zero_grad()
        optimizerG.zero_grad()

        # Format batch
        target = train_sample['Target'].to(device=device, dtype=dtype1)
        view1 = train_sample['view1'].to(device=device, dtype=dtype1)
        view2 = train_sample['view2'].to(device=device, dtype=dtype1)
        # One-hot encode the label volume (labels 0, 1, 2) into nc channels.
        targetc = torch.zeros((target.size(0), nc, target.size(1), target.size(2), target.size(3)), device=device)
        targetc[:, 0] = target == 0
        targetc[:, 1] = target == 1
        targetc[:, 2] = target == 2
        targetc = targetc.to(device=device, dtype=dtype1)
        b_size = target.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        ## Train with all-real batch
        output = netD(targetc).view(-1)
        errD_real = criterion_gan(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        fake = netG(view1, view2)      # reconstructed 3D volume
        label.fill_(fake_label)
        # detach: this pass updates D only, no gradient flows into G
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion_gan(output, label)
        errD_fake.backward()           # accumulates onto the real-batch gradients
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake   # logging only; gradients already computed
        optimizerD.step()

        ########################################################################################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # D was just updated, so classify G's output again
        output = netD(fake).view(-1)
        errG = criterion_gan(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        # IoU of the class-1 channel thresholded at 0.5. `.item()` stores a plain
        # float (as the autocast cell does), so the IoU history can be plotted.
        acc = iou((fake[:, 1] >= 0.5).float(), targetc[:, 1].float()).item()
        optimizerG.step()

        # Output training stats
        if i % log_every == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t  D(x): %.4f\tD(G(z)): %.4f / %.4f\tIoU: %.4f'
                  % (epoch, epochs, i, len(trainLoader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, acc))

        # Save per-iteration values for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        D_X.append(D_x)
        D_Gx1.append(D_G_z1)
        D_Gx2.append(D_G_z2)
        IoU.append(acc)

        # Disabled snapshot of G's output (needs `vutils`, which is not imported):
        # if (iters % 500 == 0) or ((epoch == epochs-1) and (i == len(trainLoader)-1)):
        #     with torch.no_grad():
        #         fake = netG(view1,view2).detach().cpu()
        #     img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

t2 = time.time()
print('\n ### Total GAN training time = {} min.'.format((t2-t1)/60))


#### GAN with Autocast

In [None]:
### Autocast ###
# Two GPU handles: the generator runs on cuda0 and the discriminator on cuda1.
torch.cuda.empty_cache()
cuda0 = torch.device('cuda:0')
cuda1 = torch.device('cuda:1')

# Dataset: training/validation femur samples, each passed through
# NormalizeSample2 then ToTensor8.
batch_sz = 2
train_transformedFemur = FemurDataset2(
    csv_file=training_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor8()]),
)
val_transformedFemur = FemurDataset2(
    csv_file=val_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor8()]),
)
# Four worker processes load batches in the background.
trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz,
                         shuffle=True, num_workers=4)
valLoader = DataLoader(val_transformedFemur, batch_size=batch_sz,
                       shuffle=True, num_workers=4)
print(f'   Training set : {len(trainLoader)}')
print(f'   Validation set : {len(valLoader)}')

# Define the models and training setup
real_label = 1.
fake_label = 0.
lr = 0.0001
beta1 = 0.5
running_mode = int(input('Running mode: [1] Training from scratch  [2] Transfer model  [3] Resume previous GAN model = '))
if running_mode == 1:
    print(' ### Training from scratch ###')
    nc = 3
    ndf = 4  # ndf should be higher than nc
    netD = Discriminator3D(nc, ndf).to(device=cuda1)
    in_c = 1
    en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    de3d_sz = None
    final_sz = [[2,32,32],[5,32,32],[5,32,32],[5,16,16],[5,16,16],[5,16,16],[5,16,16]]
    netG = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=cuda0)
elif running_mode == 2:  # run 3.2 Main Execution first so `model` holds the Recon2X3D to transfer
    print(' ### Transfer model ###')
    nc = 3
    ndf = 4  # ndf should be higher than nc
    netD = Discriminator3D(nc, ndf).to(device=cuda1)
    # NOTE(review): in_c/en_sz/de_sz/de3d_sz/final_sz are not set in this branch,
    # but the per-epoch GAN checkpoint stores them, so they must already exist
    # as globals from another cell.
    netG = model.to(device=cuda0)
elif running_mode == 3:
    print(' ### Resume previous GAN model ###')
    # Backslashes are written as '\\': '\R' and '\ ' are invalid escape
    # sequences (SyntaxWarning since Python 3.12). The string values are unchanged.
    saved_name = 'trained\\Recon2X3D5GAN_22022201.pt'      # Recon2X3D5GAN_22022101
    confirm_saved_name = input('Confirm Save : ' + saved_name + ' [y/n] ? ')
    while confirm_saved_name != 'y':
        saved_name = 'trained\\' + input('Enter saved_name = trained\\ ') + '.pt'
        confirm_saved_name = input('Confirm Save : ' + saved_name + ' [y/n] ? ')
    print('*** Confirm saved_name = {} ***\n'.format(saved_name))
    saved_dict = torch.load(saved_name, map_location=device)
    # Architecture hyper-parameters stored in the checkpoint
    nc = saved_dict['nc']
    ndf = saved_dict['ndf']
    in_c = saved_dict['in_c']
    en_sz = saved_dict['en_sz']
    de_sz = saved_dict['de_sz']
    de3d_sz = saved_dict['de3d_sz']
    final_sz = saved_dict['final_sz']
    print('NetD.parameter: \n\tnc={} \n\tndf={}'.format(nc, ndf))
    print('NetG.parameter: \n\tin_c={} \n\ten_sz={} \n\tde_sz={} \n\tfinal_sz={}'.format(in_c, en_sz, de_sz, final_sz))
    # Rebuild both networks on their GPUs, then restore weights and optimizer state
    netD = Discriminator3D(nc, ndf).to(device=cuda1)
    netG = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz).to(device=cuda0)
    netD.load_state_dict(saved_dict['netD_state_dict'])
    netG.load_state_dict(saved_dict['netG_state_dict'])
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD.load_state_dict(saved_dict['optimizerD_state_dict'])
    optimizerG.load_state_dict(saved_dict['optimizerG_state_dict'])
    # Restore training history so the resumed run keeps appending to it
    D_losses = saved_dict['D_losses']
    G_losses = saved_dict['G_losses']
    Recon_losses = saved_dict['Recon_losses']
    D_Gx1 = saved_dict['D_Gx1']
    D_Gx2 = saved_dict['D_Gx2']
    D_X = saved_dict['D_X']
    IoU = saved_dict['IoU']
    print('Timestamp = {}'.format(saved_dict['timestamp']))
    print('Total iterations = {:,}'.format(len(G_losses)))
else:
    raise ValueError("Invalid running_mode input !!! \n")

# Both networks stay on their own GPU; DataParallel wrapping is disabled.
#netD = nn.DataParallel(netD)
#netG = nn.DataParallel(netG)

# Adversarial loss: BCE on logits, summed over the batch.
# NOTE(review): the no-autocast cell pairs the same Discriminator3D with
# nn.BCELoss, which expects probabilities. If Discriminator3D ends in a sigmoid,
# BCEWithLogitsLoss applies it twice. Confirm which is intended.
#criterion_gan = nn.BCELoss(reduction='sum')
criterion_gan = nn.BCEWithLogitsLoss(reduction='sum')
# Reconstruction loss: class-weighted FocalLossMulticlass (gamma=2, mean-reduced)
weight = torch.tensor([0.15, 0.25, 0.6], device=device)   # For ToTensor8
#weight = torch.tensor([0.5,0.5], device=device)          # For ToTensor9
criterion_recon = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='mean')

# Fresh runs (modes 1/2) get new optimizers and empty histories. Mode 3 keeps
# the ones restored from its checkpoint.
if running_mode in (1, 2):
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    img_list = []
    D_losses, G_losses, Recon_losses = [], [], []
    D_Gx1, D_Gx2, D_X = [], [], []
    IoU = []
    iters = 0
    

training_logit = int(input('\nDo you want to start training? [1] Yes  [0] No = '))
if training_logit == 1:
    print("\n\t... Starting Training Loop ...")
    epochs = int(input('[Input] epochs = '))
    scaler = torch.cuda.amp.GradScaler()
    # Checkpoint path for this session. A resumed run (mode 3) writes back to the
    # file it was loaded from. In modes 1/2 this cell never sets `saved_name`: it
    # is stale from another cell (e.g. a supervised checkpoint) or undefined. So
    # use the GAN default rather than overwriting that file.
    saved_name2 = saved_name if running_mode == 3 else 'trained\\Recon2X3D5GAN_22022201.pt'
    # Print stats ~20 times per epoch. max(1, ...) prevents `i % 0` when the
    # loader has fewer than 10 batches.
    log_every = max(1, round(len(trainLoader) / 20))
    acc = float('nan')   # last-batch IoU; defined even if no batch gets trained
    for epoch in range(epochs):
        errD_sum, errG_gan_sum, errG_recon_sum, D_x_sum, D_G_z1_sum, D_G_z2_sum, IoU_sum = 0, 0, 0, 0, 0, 0, 0
        n_batches = 0   # batches actually trained on in this epoch
        for i, train_sample in enumerate(trainLoader, 0):
            if i == len(trainLoader) - 1:  # skip the last (possibly smaller) batch
                print('Final training accuracy = {:.4f} \n'.format(acc))
                break
            netD.train(mode=True)
            netG.train(mode=True)
            ##################################################################################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            optimizerD.zero_grad()
            optimizerG.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):
                # Format batch: target volume on D's GPU, both views on G's GPU
                target = train_sample['Target'].to(device=cuda1, dtype=dtype1)
                view1 = train_sample['view1'].to(device=cuda0, dtype=dtype1)
                view2 = train_sample['view2'].to(device=cuda0, dtype=dtype1)
                # One-hot encode labels 0/1/2 into nc channels. Allocated on cuda1,
                # where `target` lives, so each channel copy stays on one device.
                targetc = torch.zeros((target.size(0), nc, target.size(1), target.size(2), target.size(3)), device=cuda1)
                targetc[:, 0] = target == 0
                targetc[:, 1] = target == 1
                targetc[:, 2] = target == 2
                targetc = targetc.to(device=cuda1, dtype=dtype1)
                b_size = target.size(0)
                label = torch.full((b_size,), real_label, dtype=torch.float, device=cuda1)

                # Forward pass real batch through D and compute its loss
                output = netD(targetc).view(-1)
                errD_real = criterion_gan(output, label)
                D_x = output.mean().item()
                D_x_sum += D_x
            # Backward on the scaled loss, outside the autocast region
            scaler.scale(errD_real).backward()
            torch.cuda.synchronize()

            ## Train with all-fake batch
            with torch.cuda.amp.autocast(enabled=True):
                fake = netG(view1, view2)      # reconstructed 3D volume
                label.fill_(fake_label)
                # detach: this pass updates D only, no gradient flows into G
                output = netD(fake.detach().to(device=cuda1)).view(-1)
                errD_fake = criterion_gan(output, label)
                D_G_z1 = output.mean().item()
                D_G_z1_sum += D_G_z1
                errD = errD_real + errD_fake   # logging only; grads come from the two backward calls
                errD_sum += errD.item()
            scaler.scale(errD_fake).backward()   # accumulates onto the real-batch gradients
            torch.cuda.synchronize()
            scaler.step(optimizerD)

            ########################################################################################
            # (2) Update G network: maximize log(D(G(z))) plus 50 x reconstruction loss
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            with torch.cuda.amp.autocast(enabled=True):
                # D was just updated, so classify G's output again
                output = netD(fake.to(device=cuda1)).view(-1)
                errG_gan = criterion_gan(output, label)
                errG_recon = criterion_recon(fake.to(device=device), target.long().to(device=device))
                errG = errG_gan.to(device=device) + 50*errG_recon.to(device=device)
                errG_gan_sum += errG_gan.item()
                errG_recon_sum += errG_recon.item()
                D_G_z2 = output.mean().item()
                D_G_z2_sum += D_G_z2
                # IoU of the class-1 channel thresholded at 0.5
                acc = iou((fake[:, 1] >= 0.5).float().to(device=cuda1), targetc[:, 1].float().to(device=cuda1)).item()
                IoU_sum += acc
            scaler.scale(errG).backward()
            torch.cuda.synchronize()
            scaler.step(optimizerG)
            scaler.update()   # one scale update per iteration, after both optimizer steps
            n_batches += 1

            # Output training stats
            if i % log_every == 0:
                print('Time: {}'.format(str(datetime.datetime.now())))
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_recon: %.4f\t D(x): %.4f D(G(z)): %.4f/%.4f IoU: %.4f'
                      % (epoch, epochs-1, i, len(trainLoader)-1,
                         errD.item(), errG_gan.item(), errG_recon.item(), D_x, D_G_z1, D_G_z2, acc))

        # Per-epoch means over the batches actually trained on. Dividing by
        # len(trainLoader) counted the skipped last batch and biased every mean low.
        n_avg = max(n_batches, 1)
        D_losses.append(errD_sum / n_avg)
        G_losses.append(errG_gan_sum / n_avg)
        Recon_losses.append(errG_recon_sum / n_avg)
        D_X.append(D_x_sum / n_avg)
        D_Gx1.append(D_G_z1_sum / n_avg)
        D_Gx2.append(D_G_z2_sum / n_avg)
        IoU.append(IoU_sum / n_avg)

        # Checkpoint after every epoch (overwrites saved_name2)
        saved_dict2 = {'timestamp':str(datetime.datetime.now()),
                       'note':'GAN training model transfered from Scratch using lossG = log(D(G(x))) + 50*FocalLoss([0.15,0.25,0.6])',
                       'in_c':in_c,
                       'en_sz':en_sz,
                       'de_sz':de_sz,
                       'de3d_sz':de3d_sz,
                       'final_sz':final_sz,
                       'netG_state_dict':netG.state_dict(),
                       'nc':nc,
                       'ndf':ndf,
                       'netD_state_dict':netD.state_dict(),
                       'G_losses':G_losses,
                       'D_losses':D_losses,
                       'Recon_losses':Recon_losses,
                       'D_Gx1':D_Gx1,
                       'D_Gx2':D_Gx2,
                       'D_X':D_X,
                       'IoU':IoU,
                       'optimizerG_state_dict': optimizerG.state_dict(),
                       'optimizerD_state_dict': optimizerD.state_dict(),
                      }
        print('Save as: {}\n'.format(saved_name2))
        #save_logic = int(input('Confirm saved name = {} : [1] Save [0] Not Save ?'.format(saved_name2)))
        save_logic = 1
        if save_logic == 1:
            torch.save(saved_dict2, saved_name2)

print(' ### End Session ### ')

# Check how the generator is doing by saving G's output on fixed_noise
# NOTE(review): disabled. The snippet below is a bare string literal, a no-op
# statement, so it never runs. To re-enable it, it needs `vutils`, which the
# imports at the top of the notebook do not define (torchvision's utils module
# is imported as `utils`). Also, `iters` is never incremented in this cell.
'''
if (iters % 500 == 0) or ((epoch == epochs-1) and (i == len(trainLoader)-1)):
    with torch.no_grad():
        fake = netG(view1,view2).detach().cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
'''

#### Visualize and Save GAN Model

In [None]:
# Adversarial losses (per iteration in the no-autocast cell, per epoch in the autocast cell)
plt.figure(figsize=(30, 8))
plt.plot(G_losses, color='r', label='G losses')
plt.plot(D_losses, color='b', label='D losses')
plt.legend()
plt.show()

# D on real volumes (D_X), D on fakes before (D_Gx1) / after (D_Gx2) the D step, and IoU
plt.figure(figsize=(30, 8))
plt.plot(D_X, color='b', label='D_X')
plt.plot(D_Gx1, color='r', label='D_Gx1')
plt.plot(D_Gx2, color='g', label='D_Gx2')
plt.plot(IoU, marker='o', linestyle='-', label='IoU')
plt.legend()
plt.show()

In [None]:
# Snapshot the GAN session (architecture, weights, training curves and
# optimizer states) into one checkpoint dictionary.
saved_dict2 = dict(
    timestamp=str(datetime.datetime.now()),
    note='GAN training model transfered from Scratch using lossG = log(D(G(x))) + 50*FocalLoss([0.15,0.25,0.6])',
    # Generator architecture + weights
    in_c=in_c,
    en_sz=en_sz,
    de_sz=de_sz,
    de3d_sz=de3d_sz,
    final_sz=final_sz,
    netG_state_dict=netG.state_dict(),
    # Discriminator architecture + weights
    nc=nc,
    ndf=ndf,
    netD_state_dict=netD.state_dict(),
    # Training curves
    G_losses=G_losses,
    D_losses=D_losses,
    Recon_losses=Recon_losses,
    D_Gx1=D_Gx1,
    D_Gx2=D_Gx2,
    D_X=D_X,
    IoU=IoU,
    # Optimizer states, restored when resuming (running_mode 3)
    optimizerG_state_dict=optimizerG.state_dict(),
    optimizerD_state_dict=optimizerD.state_dict(),
)

saved_name2 = 'trained\\Recon2X3D5GAN_22022201.pt'
print(f'Save as: {saved_name2}')
#save_logic = int(input('Confirm saved name = {} : [1] Save [0] Not Save ?'.format(saved_name2)))
save_logic = 1
if save_logic == 1:
    torch.save(saved_dict2, saved_name2)
    print(' --- END --- ')

# Main Execution

Running modes (the `runningMode` value prompted by the Main Execution cell):

* mode = 1:   Transfer learning
* mode = 2:   Training from scratch
* mode = 3:   Evaluation model
* mode = 4:   Hyperparameter tuning

## Main Execution (obsolete)

In [None]:
torch.cuda.empty_cache()
runningMode = int(input('Select running mode: '))
# Training modes prompt for epochs and batch size; evaluation uses batch size 1.
if runningMode in (1, 2):
    epoch_number = int(input('Number of epoch: '))
    batch_sz = int(input('Batch size: '))
elif runningMode == 3:
    batch_sz = 1
# NOTE(review): for any other mode, `batch_sz` must already exist from an earlier cell.

# Define dataLoader
train_transformedFemur = FemurDataset2(
    csv_file=training_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor8()]),
)
val_transformedFemur = FemurDataset2(
    csv_file=val_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor8()]),
)
trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz,
                         shuffle=True, num_workers=4)
valLoader = DataLoader(val_transformedFemur, batch_size=batch_sz,
                       shuffle=True, num_workers=4)
print(f'   Training set : {len(trainLoader)}')
print(f'   Validation set : {len(valLoader)}')
learning_rate = 1e-4             # default = 5e-4

# Alternative criteria tried previously (disabled):
#criterion = BCEDiceLoss(0.75,0.25).to(device=device)
# posWeight = (0.005,0.015,0.98)   # tuning class imbalance
# criterion = MulticlassBCEDiceLoss(0.75, 0.25, posWeight, 'mean').to(device=device)
# criterion = FocalBCETverskyLoss(ratio=0.8 ,alpha1=1, gamma1=2,
#                                 alpha2=1, beta2=1, gamma2=1,
#                                 reduction='mean', smooth=1e-4,
#                                )
#criterion = BinaryTverskyLossV2(alpha=0.35, beta=0.65, reduction='mean')
#criterion = FocalBinaryTverskyLoss(alpha=0.5, beta=0.5, gamma=1.0, reduction='mean')
#criterion = FEWFocalLoss2(alpha=0.5, gamma=5, reduction='mean')    # alpha + (1-alpha) = 1
#criterion = BinaryFocalLoss(alpha=1, gamma=2, reduction='mean')    # alpha > positive real number
# Active criterion: class-weighted FocalLossMulticlass (gamma=2, reduction='sum')
weight = torch.tensor([0.15, 0.25, 0.6], device=device)
#weight = torch.tensor([0.5805325277511839, 0.8178515226802764, 0.853697532760545], device=device)  # Tuning
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')
#criterion = HausdorffERLoss(alpha=2.0)

# Select running mode
if runningMode==1 or runningMode==3:      # Transfer learning
    model_name = 'trained\Recon2X3D5_21100402.pt'
    print('   ### Loading trained model : {} ### \n'.format(model_name))
    checkpoint = torch.load(model_name, map_location=device)
    ### Load checkpoint ###
    timestamp = checkpoint['timestamp']
    note = checkpoint['note']
    en_sz = checkpoint['en_sz']
    de_sz = checkpoint['de_sz']
    de3d_sz = checkpoint['de3d_sz']
    final_sz = checkpoint['final_sz']
    train_loss_history = checkpoint['train_loss_history']
    train_acc_history = checkpoint['train_acc_history']
    val_loss_history = checkpoint['val_loss_history']
    val_acc_history = checkpoint['val_acc_history']
    
    ### Print state ###
    print('Model: en_sz = {}'.format(en_sz))
    print('Model: de_sz = {}'.format(de_sz))
    print('Model: de3d_sz = {}'.format(de3d_sz))
    print('Model: final_sz = {}'.format(final_sz))
    print('Note: {} \n'.format(note))
    print('Timestamp = {}'.format(timestamp))
    print('Total iterations = {:,}'.format(len(train_loss_history)))
    print('Average training accuracy = {:.4f}'.format(max(train_acc_history)))
    print('Average validation accuracy = {:.4f}'.format(max(val_acc_history)))
    
    ### Create Model ###
    model_select = int(input('Select model \n1) Recon3DUNet Old\n2) Recon3DUNet\n3) Recon3DDenseUNet'\
                             '\n4) Recon3DDenseUNet2 \n5) Recon2X3D \n6) Recon2X3D2 \n7) Recon2X3D3 '\
                             '\n8) Recon2X3D4 \n9) Recon2X3D5 \n10) Recon2X3D6 \n\n INPUT: '))
    if model_select==2:
        model = Recon3DUNet(1, en_sz, de_sz, final_sz)
    elif model_select==3:
        model = Recon3DDenseUNet(1, en_sz, de_sz, final_sz)
    elif model_select==4:
        de3d_sz = checkpoint['de3d_sz']
        model = Recon3DDenseUNet2(1, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==5:
        model = Recon2X3D(1, en_sz, de_sz, final_sz)
    elif model_select==6:
        de3d_sz = checkpoint['de3d_sz']
        model = Recon2X3D2(1, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==7:
        de3d_sz = checkpoint['de3d_sz']
        model = Recon2X3D3(1, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==8:
        de3d_sz = checkpoint['de3d_sz']
        model = Recon2X3D4(1, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==9:
        model = Recon2X3D5(1, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==10:
        model = Recon2X3D6(1, en_sz, de_sz, de3d_sz, final_sz)
    
    ### Load state dict to Model ###
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=0, verbose=True)
    model.to(device=device)
    model.load_state_dict(checkpoint['model_state_dict'])                 # default: strict=False
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])         # default: strict=False
    #print('Model parameters: \n   {}'.format(model.parameters))
    #print('optimizer parameters: {}\n'.format(optimizer))
    print('\n----- Finished loading -----\n')
    
    
elif runningMode == 2:    # Training from scratch
    print('   ### Start training from scratch ### \n')
    train_loss_history = list()
    train_acc_history = list()
    val_loss_history = list()
    val_acc_history = list()
    
    # Define new
    model_select = int(input('Select model \n1) Recon3DUNet Old\n2) Recon3DUNet\n3) Recon3DDenseUNet'\
                             '\n4) Recon3DDenseUNet2 \n5) Recon2X3D \n6) Recon2X3D2 \n7) Recon2X3D3 '\
                             '\n8) Recon2X3D4 \n9) Recon2X3D5 \n10) Recon2X3D6 \n\n INPUT: '))
    if model_select==2:    # Recon3DUNet
        print('model = Recon3DUNet\n')
        en_sz = [[2,4],[8,16],[32,64],[128,256],[256,256]]      # Encode-dimension
        de_sz = [[256,256],[256,256],[256,256],[256,256]]           # Decode-dimension
        final_sz = [256]                                        # Classify-dimension
        model = Recon3DUNet(1, en_sz, de_sz, final_sz)
    elif model_select==3:    # Recon3DDenseUNet
        print('model = Recon3DDenseUNet\n')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]       # encoder with dense connection
        de_sz = [[256,256],[256,256],[256,256],[256,256]]                    # decoder dimension
        #final_sz = [256]           # For 2D-classifier dimension use with self.final_sz
        final_sz = [16,16]        # For 3D-classifier dimension use with self.final_sz2 and self.final_sz3
        model = Recon3DDenseUNet(in_c, en_sz, de_sz ,final_sz)
    elif model_select==4:    # Recon3DDenseUNet2
        print('model = Recon3DDenseUNet2 (with 3D-Conv Decoder) \n')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]   # Encoder with dense-connection [input_C,deep,level]
        de_sz = [[256,256],[256,256],[256,256],[256,256]]             # Decoder dimension
        de3d_sz = [[1,8,8],[1,8,8],[1,8,8],[1,8,8]]                   # 3D-decoder
        final_sz = [[16,8],[16,8],[16,8]]                     # 3D-Reconstruction
        model = Recon3DDenseUNet2( in_c, en_sz , de_sz , de3d_sz, final_sz)
    elif model_select==5:    # Recon2X3D
        print('model = Recon2X3D\n')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]         # encoder with dense connection
        de_sz = [[256,256],[256,256],[256,256],[256,256]]                       # decoder dimension
        final_sz = [[2,1],[3,1],[3,1],[3,3]]                              # classifier by nn.Conv3d
        #final_sz = [[256],[256],[256],[256]]                               # classifier by nn.Conv2d
        model = Recon2X3D(in_c, en_sz, de_sz, final_sz)
    elif model_select==6:    # Recon2X3D2
        print('model = Recon2X3D2')
        in_c = 1
        en_sz = [[16,16,4],[80,16,4],[144,16,4],[208,16,4],[272,16,4],[336,16,4]]    # encoder with dense connection
        de_sz = [[128,16],[128,32],[128,64],[128,128],[256,256]]              # decoder dimension
        de3d_sz = [[1,4],[1,4],[1,4],[1,4],[1,4]]
        final_sz = [[8,4],[12,4],[12,4],[12,4],[12,12]]
        model = Recon2X3D2(in_c, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==7:
        print('model = Recon2X3D3')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]    # encoder with dense connection
        de_sz = [[128,32],[128,64],[128,128],[256,256]]                # decoder dimension
        de3d_sz = [[1,4],[1,4],[1,4],[1,4]]
        final_sz = [[4,8],[4,8],[4,8],[4,8]]                           # 3D Tensor pyramid connection
        model = Recon2X3D3(in_c, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==8:
        print('model = Recon2X3D4')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]    # encoder with dense connection
        de_sz = [[128,32],[128,64],[128,64],[256,128]]               # decoder dimension
        de3d_sz = [[256*256, 1024*4, 256*256]]             # Linear fusion of AP and LAT view
        final_sz = [[1,8]]                                        # 3D Tensor pyramid connection
        model = Recon2X3D4(in_c, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==9:
        print('model = Recon2X3D5')
        in_c = 1
        de3d_sz = [[32,8],[24,6],[14,4],[8,2]] # don't use now
        
        # Six level (Averaging feature)
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4]]    # encoder with dense connection
        de_sz = [[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]              # decoder dimension
        final_sz = [[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]]'''
        
        # Seven level (Averaging feature)
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        final_sz = [[1,32,32],[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]] # 2-fusion layer
        #final_sz = [[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16]] # 3-fusion layer
        
        # Concatenation wtih arithmetic features (+,*,**2)
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        #de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]               # work
        #final_sz = [[2,32,32],[2,32,32],[2,32,32],[2,16,16],[2,16,16],[2,16,16],[2,16,16]]   # work
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        #final_sz = [[2,32],[3,32],[3,32],[3,16],[3,16],[3,16],[3,16]]                                  # 1-fusion layer
        final_sz = [[2,32,32],[3,32,32],[3,32,32],[3,16,16],[3,16,16],[3,16,16],[3,16,16]]              # 2-fusion layer
        #final_sz = [[2,16,16,16],[3,16,16,16],[3,16,16,16],[3,16,16,16],[3,16,16,16],[3,16,16,16],[3,16,16,16]] # 3-fusion layer
        '''
        # Concatenation of Sub-feature-volume
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4*8],[128,8*8],[128,16*4],[256,32*4],[256,64*2],[256,128*2],[256,256]] 
        final_sz = [[16,32,32],[16+4,32,32],[8+4,32,32],[8+4,16,16],[4+4,16,16],[4+4,16,16],[2+4,16,16]]'''
        
        model = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz)
        
    elif model_select==10:
        print('model = Recon2X3D6')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        de3d_sz = [[8,8,4],[16,16,8],[32,32,16],[64,64,32],[128,128,64],[256,256,128],[512,512,256]] # 2-axial layer per level
        final_sz = [[3,32,32],[6,32,32],[6,32,32],[6,16,16],[6,16,16],[6,16,16],[6,16,16]]   # cat(X, XD,XH,XW)
        #final_sz = [[3,32,32],[3,32,32],[3,32,32],[3,16,16],[3,16,16],[3,16,16],[3,16,16]]   #  avg[X ,cat(XD,XH,Xw)]
        
        # 6-Levels
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4]]
        de_sz = [[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        de3d_sz = [[16,16,8],[32,32,16],[64,64,32],[128,128,64],[256,256,128],[512,512,256]] # 2-axial layer per level
        final_sz = [[3,32,32],[11,32,32],[11,32,32],[11,16,16],[11,16,16],[11,16,16]]   # cat(X, XD,XH,XW)'''
        
        model = Recon2X3D6( in_c, en_sz, de_sz, de3d_sz, final_sz)
        
    
    # Create new model
    #model = nn.DataParallel(model, device_ids=[0, 1])         # Training on all available GPU
    model.to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=1e-12, verbose=True)
    # Print state
    print('Model parameters: {}\n'.format(model.parameters))

elif runningMode == 4:
    print('   ### Hyperparameter Tuning ### \n')
    train_loss_history = list()
    train_acc_history = list()
    val_loss_history = list()
    val_acc_history = list()
    model = define_model_trial(trial).to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=1e-12, verbose=True)
    

if runningMode not in (1, 2):
    # Inference / tuning paths never wrap the model.
    data_parallel_mode = 0
else:
    # Training paths: ask whether to split each batch across all visible GPUs.
    data_parallel_mode = int(input('Activate nn.DataParallet  [0] No  [1] Yes  : '))
    if data_parallel_mode == 1:
        model = nn.DataParallel(model)

# Execution loop
# Runs training for modes 1/2 (only mixed-precision GPU training is supported) and
# prints timing; mode 4 is a placeholder for the hyperparameter search.
if runningMode == 1 or runningMode == 2:
    print('\nTotal iterations = {:,}'.format(len(train_loss_history)))
    timeStr = str(datetime.datetime.now())
    print('Start training at : {}'.format(timeStr))
    
    timeT1 = time.time()
    if USE_GPU and torch.cuda.is_available():
        print('\n --- Use mixed precision training ---\n')
        #train_loss_history,train_acc_history,val_loss_history,val_acc_history = train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, epochs=epoch_number)
        saved_dict = train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, epochs=epoch_number)
        #train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, epochs=1):
    else:
        # The previous `assert '<message>'` asserted a non-empty string, which is always
        # true, so execution silently fell through to CPU training despite the message.
        # Fail explicitly instead (an assert would also be stripped under `python -O`).
        raise RuntimeError('!!! Training with CPU not available !!!')
    timeT2 = time.time()
    print('\nTotal training time = {} min. \nTotal epoch = {}\n'.format((timeT2-timeT1)/60, len(val_acc_history)))
    timeStr = str(datetime.datetime.now()) 
    print('Finish task at : {}\n'.format(timeStr))
elif runningMode == 4:
    print('\n##### Hyperameter Tuning #####')
    timeStr = str(datetime.datetime.now())
    print('Start training at : {}'.format(timeStr))
    timeT1 = time.time()
    #Running Hyperparameter Search
    timeT2 = time.time()
    print('\nTotal training time = {} min. \nTotal epoch = {}\n'.format((timeT2-timeT1)/60, len(val_acc_history)))
    timeStr = str(datetime.datetime.now()) 
    print('Finish task at : {}\n'.format(timeStr))

    

## Main Execution

Recon2X3D5_21101901		FracReconNet

Recon2X3D5_21102201		3DReconNet-Ac

Recon2X3D5_21102101		FracAug only

Recon2X3D5_21102102		Bare

Recon2X3D5_21100402		3DReconNet

NormalizeSample2() is used with FemurDataset2.

In [None]:
ToTensor

In [None]:
# torch.cuda.empty_cache()
# Choose what this cell does; training modes also need an epoch count and batch size,
# while mode 3 (testing / inference) always runs one sample per batch.
runningMode = int(input('Select running mode: '))
if runningMode in (1, 2):
    epoch_number = int(input('Number of epoch: '))
    batch_sz = int(input('Batch size: '))
elif runningMode == 3:
    batch_sz = 1
    
# Pick the target-encoding transform: 8 -> ToTensor8, 9 -> ToTensor9, anything below 8 ->
# ToTensor7; values above 9 are rejected.
toTensorMode = int(input('Select ToTensor : [8] Auxiliary class  [9] Native = '))
if toTensorMode > 9:
    raise ValueError ('Invalid ToTensor input')
if toTensorMode < 8:
    print('Totensor1-7 selected  change NormalizeSample() and FemurDataset')
    ToTensor = ToTensor7()
elif toTensorMode == 8:
    print('ToTensor8 selected')
    ToTensor = ToTensor8()
else:
    print('Totensor9 selected')
    ToTensor = ToTensor9()
print('ToTensor = {}'.format(ToTensor))

# Define dataLoader
# Both splits use the same normalization and the ToTensor instance chosen above;
# the printed sizes are len(DataLoader), i.e. the number of batches.
num_workers = 4

train_transformedFemur = FemurDataset2(
    csv_file=training_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor]),
)
val_transformedFemur = FemurDataset2(
    csv_file=val_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample2(), ToTensor]),
)
trainLoader = DataLoader(
    train_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=num_workers
)
valLoader = DataLoader(
    val_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=num_workers
)
print('   Training set : {}'.format(len(trainLoader)))
print('   Validation set : {}'.format(len(valLoader)))

learning_rate = 1e-4             # default = 5e-4
weight = torch.tensor([0.15,0.25,0.6], device=device)     # For ToTensor8
#weight = torch.tensor([0.5,0.5], device=device)           # For ToTensor9
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')

# Checkpoint path. Raw strings keep the Windows backslash without relying on invalid
# escape sequences ('\R', '\ '), which raise SyntaxWarning on Python 3.12+.
saved_name = r'trained\Recon2X3D5_21101901.pt' # Recon2X3D5_21101901 Recon2X3D5GAN_22022201  Recon2X3D6_22060701  unalign{Recon2X3D5_22081001 vs Recon2X3D5_22081701}
confirm_saved_name = str(input('Confirm Save : ' + saved_name + ' [y/n] ? '))
while confirm_saved_name!='y':
    # Ask for a new base name until the user confirms it.
    saved_name = 'trained\\' + input(r'Enter saved_name = trained\ ') + '.pt'
    confirm_saved_name = str(input('Confirm Save : ' + saved_name + ' [y/n] ? '))
print('*** Confirm saved_name = {} ***\n'.format(saved_name))

# Select running mode
# 1/3: rebuild a model from a checkpoint (1 = continue training, 3 = testing / inference)
# 2:   build a new model from scratch and an empty `saved_dict` record
# 4:   hyperparameter tuning (unfinished)
if runningMode==1 or runningMode==3:      # Transfer learning, Testing or Inference
    print('### Loading trained model : {} ### \n'.format(saved_name))
    saved_dict = torch.load(saved_name, map_location=device)
    
    ### Print state ###
    # Obsolete checkpoints were saved without 'in_c'; default it to 1 (the only value used
    # when creating models below) without overwriting a value that is already stored.
    saved_dict.setdefault('in_c', 1)
    print('Model: in_c  = {}'.format(saved_dict['in_c']))
    print('Model: en_sz = {}'.format(saved_dict['en_sz']))
    print('Model: de_sz = {}'.format(saved_dict['de_sz']))
    print('Model: de3d_sz = {}'.format(saved_dict['de3d_sz']))
    print('Model: final_sz = {}'.format(saved_dict['final_sz']))
    print('Note: {} \n'.format(saved_dict['note']))
    print('Timestamp = {}'.format(saved_dict['timestamp']))
    #print('Total iterations = {:,}'.format(len(saved_dict['train_loss_history'])))
    #print('Average training accuracy = {:.4f}'.format(max(saved_dict['train_acc_history'])))
    #print('Average validation accuracy = {:.4f}'.format(max(saved_dict['val_acc_history'])))
    
    ### Create Model ###
    model_select = int(input('Select model \n1) Recon2X3D5 \n2) Recon2X3D6 \n3) Recon2X3D5-GAN \n\n INPUT: '))
    if model_select==1:
        model = Recon2X3D5(saved_dict['in_c'], saved_dict['en_sz'], saved_dict['de_sz'], 
                           saved_dict['de3d_sz'], saved_dict['final_sz'])  # 
    elif model_select==2:
        model = Recon2X3D6(saved_dict['in_c'], saved_dict['en_sz'], saved_dict['de_sz'], 
                           saved_dict['de3d_sz'], saved_dict['final_sz'])
    elif model_select==3:
        netD = Discriminator3D(saved_dict['nc'],saved_dict['ndf']).to(device=device)
        netG = Recon2X3D5(saved_dict['in_c'], saved_dict['en_sz'], saved_dict['de_sz'], 
                          saved_dict['de3d_sz'], saved_dict['final_sz']).to(device=device)
        # The loading code below works on `model`; without this binding it raised NameError
        # (or silently reused a stale `model` from an earlier cell).
        # NOTE(review): assumes saved_dict['model_state_dict'] holds the generator weights — confirm.
        model = netG
    else:
        raise ValueError('Invalid model_select input')
    
    ### Load state dict to Model ###
    
    # Alternative kept for reference:
    #   runningMode==1 -> model.load_state_dict(saved_dict['model_state_dict_last'])  # last model, to train again
    #   runningMode==3 -> model.load_state_dict(saved_dict['model_state_dict'])       # best model, to test / infer
    
    model.load_state_dict(saved_dict['model_state_dict'])          # Load the best model to test and inference
    model.to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=0, verbose=True)
    # Optimizer.load_state_dict has no `strict` argument; param groups must match the model.
    optimizer.load_state_dict(saved_dict['optimizer_state_dict'])
    #scheduler.load_state_dict(saved_dict['scheduler_state_dict'])
    print('Model parameters: \n   {}'.format(model.parameters))
    print('optimizer parameters: {}\n'.format(optimizer))
    print('scheduler parameters: {}\n'.format(scheduler))
    print('\n----- Finished loading -----\n')
    
elif runningMode == 2:    # Training from scratch
    print('   ### Start training from scratch ### \n')
    note = 'Recon2X3D5 (IN/IN/IN) CAT(X1,X2,X)*2-fusion layer \
            (last fusion layer [3,16,16][16,3] + ToTensor8 using FocalMulticlass({0.15,0.25,0.6}, 2, sum)'
    #note = 'Recon2X3D5 (IN/IN/IN) CAT[XD,XH,XW,X(C=3))*2-fusion layer \
    #(last fusion layer [3,16,16][16,2] + ToTensor9 using FocalMulticlass({0.5 0.5}, 2, sum)'
    
    # Define new
    model_select = int(input('Select model \n0) Recon2X3D4 \n1) Recon2X3D5 \n2) Recon2X3D6 \n\n INPUT: '))
    if model_select==0:
        print('model = Recon2X3D4')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4]]    # encoder with dense connection
        de_sz = [[128,32],[128,64],[128,64],[256,128]]               # decoder dimension
        de3d_sz = [[256*256, 1024*4, 256*256]]             # Linear fusion of AP and LAT view
        final_sz = [[1,8]]                                        # 3D Tensor pyramid connection
        model = Recon2X3D4(in_c, en_sz, de_sz, de3d_sz, final_sz)
    elif model_select==1:
        print('model = Recon2X3D5')
        in_c = 1
        de3d_sz = None       # don't use now
        
        # Six level (Averaging feature)
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4]]    # encoder with dense connection
        de_sz = [[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]              # decoder dimension
        final_sz = [[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]]'''
        
        # Seven level (Averaging feature)
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        final_sz = [[1,32,32],[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]] # 2-fusion layer
        #final_sz = [[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16,16],[1,16,16]] # 3-fusion layer'''
        
        # Concatenation with arithmetic features (+,*,**2)
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        final_sz = [[2,32,32],[5,32,32],[5,32,32],[5,16,16],[5,16,16],[5,16,16],[5,16,16]]
        
        # Concatenation of Sub-feature-volume
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4*8],[128,8*8],[128,16*4],[256,32*4],[256,64*2],[256,128*2],[256,256]] 
        final_sz = [[16,32,32],[16+4,32,32],[8+4,32,32],[8+4,16,16],[4+4,16,16],[4+4,16,16],[2+4,16,16]]'''
        
        model = Recon2X3D5(in_c, en_sz, de_sz, de3d_sz, final_sz)
        
    elif model_select==2:
        print('model = Recon2X3D6')
        in_c = 1
        en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
        de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        de3d_sz = [[8,8,4],[16,16,8],[32,32,16],[64,64,32],[128,128,64],[256,256,128],[512,512,256]] # 2-axial layer per level
        final_sz = [[3,32,32],[6,32,32],[6,32,32],[6,16,16],[6,16,16],[6,16,16],[6,16,16]]    # cat(X, XD, XH, XW)
        #final_sz = [[3,32,32],[3,32,32],[3,32,32],[3,16,16],[3,16,16],[3,16,16],[3,16,16]]   #  avg[X ,cat(XD, XH, Xw)]
        
        # 6-Levels
        '''en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4]]
        de_sz = [[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
        de3d_sz = [[16,16,8],[32,32,16],[64,64,32],[128,128,64],[256,256,128],[512,512,256]] # 2-axial layer per level
        final_sz = [[3,32,32],[11,32,32],[11,32,32],[11,16,16],[11,16,16],[11,16,16]]   # cat(X, XD,XH,XW)'''
        
        model = Recon2X3D6(in_c, en_sz, de_sz, de3d_sz, final_sz)
    else:
        # Without this, `model` would be unbound (or stale) in the lines below.
        raise ValueError('Invalid model_select input')
        
    # Create new model
    #model = nn.DataParallel(model, device_ids=[0, 1])         # Training on all available GPU
    model.to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=1e-12, verbose=True)
    
    # Fresh training record; filled in and saved by train_mixed.
    saved_dict = {'timestamp':None ,
                  'note':note ,
                  'in_c':in_c ,
                  'en_sz':en_sz ,
                  'de_sz':de_sz ,
                  'de3d_sz':de3d_sz ,
                  'final_sz':final_sz ,
                  'train_loss_history':list() ,
                  'train_acc_history':list() ,
                  'val_loss_history':list() ,
                  'val_acc_history':list() ,
                  'model_state_dict':None ,
                  #'model_state_dict_last':None ,
                  'optimizer_state_dict':None,
                  'scheduler_state_dict':None }
    
    # Print state
    #print('Model parameters: {}\n'.format(model.parameters))

elif runningMode == 4:
    # do not finish
    print('   ### Hyperparameter Tuning ### \n')
    model = define_model_trial(trial).to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                              threshold_mode='rel', cooldown=0, min_lr=1e-12, verbose=True)
    

if runningMode not in (1, 2):
    # Inference / tuning paths never wrap the model.
    data_parallel_mode = 0
else:
    # Training paths: ask whether to split each batch across all visible GPUs.
    data_parallel_mode = int(input('Activate nn.DataParallet  [0] No  [1] Yes  : '))
    if data_parallel_mode == 1:
        model = nn.DataParallel(model)

# Execution loop
# Modes 1/2 train with mixed precision via train_mixed, which returns the updated
# `saved_dict` record; mode 4 is a placeholder for the hyperparameter search.
if runningMode == 1 or runningMode == 2:
    print('\nTotal iterations = {:,}'.format(len(saved_dict['train_loss_history'])))
    print('Start training at : {}'.format(str(datetime.datetime.now())))
    
    timeT1 = time.time()
    # Only mixed-precision GPU training is implemented. Raise explicitly rather than
    # `assert`, which is stripped when Python runs with -O.
    if not (USE_GPU and torch.cuda.is_available()):
        raise RuntimeError('!!! Training with CPU not available !!!')
    print('\n --- Use mixed precision training ---\n')
    saved_dict = train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, 
                             batch_sz, epoch_number, saved_name, saved_dict)
        
    timeT2 = time.time()
    print('\nTotal training time = {} hours \nTotal epoch = {}\n'.format((timeT2-timeT1)/3600, len(saved_dict['val_acc_history'])))
    print('Saved Name: {}'.format(saved_name))
    print('Finish task at : {}\n'.format(str(datetime.datetime.now()) ))
    
elif runningMode == 4:
    print('\n##### Hyperameter Tuning #####')
    print('Start training at : {}'.format(str(datetime.datetime.now()) ))
    timeT1 = time.time()
    #Running Hyperparameter Search
    timeT2 = time.time()
    # NOTE(review): val_acc_history is not defined in this cell; it relies on a global
    # created by another cell (e.g. the hyperparameter-tuning setup) — confirm.
    print('\nTotal training time = {} hours \nTotal epoch = {}\n'.format((timeT2-timeT1)/3600, len(val_acc_history)))
    print('Finish task at : {}\n'.format(str(datetime.datetime.now()) ))   
#trainLoader = 43906
#batch size = 256

## Hyperparameter Tuning

In [None]:
# Hyperparameter tuning
''' Tuning parameter list:
    
    network_deep = depth of overall network  (optional)
    en_sz = encoder dimension : [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
    de_sz = decoder dimension : [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
    final_sz = fusion dimension : [[1,32,32],[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]]
    
    weight = [w1,w2,w3] for background, bone and fracture occupancy, respectively in the Focal loss
    gamma = focus rate of the Focal loss
    lr = learning rate
'''

# Short tuning runs: few epochs, small batches, ToTensor6 targets.
torch.cuda.empty_cache()
epochs, batch_sz = 3, 3

train_transformedFemur = FemurDataset(
    csv_file=training_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample(), ToTensor6()]),
)
val_transformedFemur = FemurDataset(
    csv_file=val_file,
    root_dir=root_dir,
    transform=transforms.Compose([NormalizeSample(), ToTensor6()]),
)
trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=4)
valLoader = DataLoader(val_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=4)

# Per-epoch history shared across all trials (objective appends to these).
train_loss_history, train_acc_history, val_loss_history, val_acc_history = [], [], [], []

def define_model_trial(trial):
    """Build the fixed 7-level Recon2X3D5 network used by every tuning trial.

    `trial` is accepted to fit Optuna's calling convention, but the architecture
    itself is not searched: only loss/optimizer settings are tuned in `objective`.
    """
    in_channels = 1
    encoder_sizes = [[16, 4, 4], [32, 8, 4], [64, 16, 4], [128, 16, 4],
                     [192, 16, 4], [256, 16, 4], [320, 16, 4]]
    decoder_sizes = [[64, 4], [64, 8], [128, 16], [128, 32],
                     [128, 64], [128, 128], [256, 256]]
    unused_de3d_sizes = [[32, 8], [24, 6], [14, 4], [8, 2]]   # ignored by Recon2X3D5
    fusion_sizes = [[1, 32, 32], [1, 32, 32], [1, 32, 32],
                    [1, 16, 16], [1, 16, 16], [1, 16, 16], [1, 16, 16]]
    return Recon2X3D5(in_channels, encoder_sizes, decoder_sizes,
                      unused_de3d_sizes, fusion_sizes)

def objective(trial):
    """Train and validate one Optuna trial; return its two objective values.

    Samples the Focal-loss class weights ``w1``/``w2``/``w3``, the focusing
    parameter ``gamma`` and the learning rate ``lr``. It then trains the fixed
    model from ``define_model_trial`` for the module-level ``epochs`` on
    ``trainLoader`` and evaluates on ``valLoader`` after every epoch.

    Returns:
        tuple[float, float]: ``(val_acc, val_acc2)`` from the last epoch. These are
        the mean per-batch IoU of output channel 1 vs ``target == 1`` and of
        channel 2 vs ``target == 2``. The study maximizes both.

    Side effects:
        Appends per-epoch means to the module-level ``*_history`` lists. It also
        rebinds the module-level ``model`` and ``optimizer``: ``callback`` reads
        those globals when saving the best trial, so it must see this trial's
        objects rather than ones left over from another cell.

    Note:
        Epoch means are now divided by the number of batches actually processed.
        The earlier code divided by ``round(len(loader)/batch_sz)``, which
        inflated them by about ``batch_sz``, so values stored by older trials of
        the same study are on a different scale.
    """
    global model, optimizer   # read by `callback` to save this trial's weights
    # Setup model and hyperparameter
    torch.cuda.empty_cache()
    time.sleep(3)
    model = define_model_trial(trial).to(device=device)
    model = nn.DataParallel(model, device_ids=[0,1])
    w1 = trial.suggest_float("w1", 1e-2, 1.00, log=False)   # for background occupancy
    w2 = trial.suggest_float("w2", 1e-2, 1.00, log=False)   # for bone occupancy
    w3 = trial.suggest_float("w3", 1e-2, 1.00, log=False)   # for fracture occupancy
    weight = torch.tensor([w1,w2,w3], device=device)
    # suggest_discrete_uniform is deprecated; suggest_float(step=...) is its replacement.
    gamma = trial.suggest_float("gamma", 1, 5, step=0.25)
    criterion = FocalLossMulticlass(weight=weight, gamma=gamma, reduction='sum')
    learning_rate = trial.suggest_categorical("lr", [2.5e-4, 1e-4, 5e-5, 2.5e-5, 1e-5])
    #optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    #optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3, 
                                  threshold_mode='rel', cooldown=0, min_lr=0, verbose=True)
    # Traning and Validation loop
    scaler = torch.cuda.amp.GradScaler()
    # Log about 10 times per epoch; max(1, ...) avoids modulo-by-zero on loaders with < 5 batches.
    train_log_every = max(1, round(len(trainLoader)/10))
    val_log_every = max(1, round(len(valLoader)/10))
    for e in range(epochs):
        time1 = time.time()
        print('----- Epoch = {} ----- # Learning rate = {:.4e}'.format(e+1, optimizer.param_groups[0]["lr"]))
        max_train_acc, train_loss, train_acc, max_val_acc, val_loss, val_acc, val_acc2 = 0, 0, 0, 0, 0, 0, 0
        n_train_batches, n_val_batches = 0, 0   # batches actually used, for the epoch means
        hd = HausdorffDistance()   # only referenced by the commented-out metric below
        for t, train_sample in enumerate(trainLoader):
            # Exclude an incomplete final batch. The sample is a dict of batched tensors,
            # so its batch size is the first dimension of a tensor, not len(dict).
            if len(train_sample['AP']) != batch_sz:
                break
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):
                model.train(mode=True)    # put model to training mode
                target = train_sample['Target'].to(device=device, dtype=dtype1)
                ap = train_sample['AP'].to(device=device, dtype=dtype1)
                lat = train_sample['LAT'].to(device=device, dtype=dtype1)
                
                # calculation of output
                output = model(ap,lat)
                
                # Calculate loss of this sample batch
                loss = criterion(output, target.long())          # for ToTensor6 with FocalLossMulticlass
                train_loss += loss.item()
                
                # Check accuracy
                #acc2 = hd.compute((output[:,2:3]>=0.5).float(), (target[:,2:3]>=0.5).float())  # for ToTensor1-5
                acc = iou((output[:,1]>=0.5).float(), (target==1).float())         # for ToTensor6
                acc2 = iou((output[:,2]>=0.5).float(), (target==2).float())        # for ToTensor6
                train_acc += acc.detach()
                n_train_batches += 1
                
                if acc > max_train_acc:    # record max.accuracy in current batch
                    max_train_acc = acc
                if t % train_log_every == 0:
                    print('Iteration: {}   |   Loss = {:.4f}   |   Accuracy = {:.4f} {:.4f}'.format(t, loss.item(), acc, acc2))
                if t==len(trainLoader)-1:
                    print('Final training accuracy = {:.4f}'.format(acc))
                
            scaler.scale(loss).backward()
            torch.cuda.synchronize()
            scaler.step(optimizer)
            scaler.update()
        
        # Append Loss and Accuracy history (means over the processed batches)
        train_loss /= max(1, n_train_batches)
        train_acc /= max(1, n_train_batches)
        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        print('# Training loss = {:.4f}'.format(train_loss))
        print('Training accuracy = {:.4f}'.format(train_acc))
        print('Max. Training accuracy = {:.4f}'.format(max_train_acc))
        
        time2 = time.time()
        print('Duration training time = {} Min.\n'.format((time2-time1)/60))
        
###########################################################################################################################
        
        print('### Validation loop ###')
        model.eval()
        time3 = time.time()
        for t, val_sample in enumerate(valLoader):
            if len(val_sample['AP']) != batch_sz: # exclude an incomplete final batch
                break
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    target = val_sample['Target'].to(device=device, dtype=dtype1)
                    ap = val_sample['AP'].to(device=device, dtype=dtype1)
                    lat = val_sample['LAT'].to(device=device, dtype=dtype1)
                    output = model(ap,lat)
                    
                    loss = criterion(output ,target.long())   # for ToTensor6
                    val_loss += loss.item()
                    
                    acc = iou((output[:,1]>=0.5).float(), (target==1).float())                        # for ToTensor6
                    acc2 = iou((output[:,2]>=0.5).float(), (target==2).float())                       # for ToTensor6
                    #acc2 = hausdorff_voxel((output[:,2]>=0.5).float(), (target==2).float())          # for ToTensor6
                    val_acc += acc.detach()
                    val_acc2 += acc2.detach()
                    n_val_batches += 1
                    if acc > max_val_acc:    # record max. validation accuracy in current batch
                        max_val_acc = acc
                    if t % val_log_every == 0:
                        print('Iteration: {}   |   Loss = {:,.4f}   |   Accuracy = {:.4f} {:.4f}'.format(t, loss.item(), acc, acc2))
                    if t==len(valLoader)-1:
                        print('Final validation accuracy = {:,.4f}'.format(acc))
        
        val_loss /= max(1, n_val_batches)
        val_acc /= max(1, n_val_batches)
        val_acc2 /= max(1, n_val_batches)
        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)
        torch.cuda.synchronize()
        print('# Validation loss = {:.4f}'.format(val_loss))
        print('Validation accuracy = {:.4f}'.format(val_acc))
        print('Max. Validation accuracy = {:.4f}'.format(max_val_acc))
        scheduler.step(val_acc)
        time4 = time.time()
        print('Duration validation time = {} Min. \n'.format((time4-time3)/60))
        print('                      *****    Total epoch time = {} Min.    ***** \n'.format((time4-time1)/60))
        print('-'*120,'\n')
        
        # Trial.report / should_prune are not supported for multi-objective studies, so
        # pruning is disabled here.
        
    # Optuna expects plain numbers; convert (possibly CUDA) 0-d tensors explicitly.
    return float(val_acc), float(val_acc2)

def callback(study, trial):
    """Save the trained model when the finished trial is on the study's Pareto front.

    Passed to ``study.optimize(..., callbacks=[callback])`` and called after each
    trial. The previous check compared only against ``study.best_trials[0]``,
    i.e. the lowest-numbered Pareto-optimal trial, so later improvements were
    almost never saved. The weights come from the module-level ``model``
    (wrapped in ``nn.DataParallel``, hence ``.module``) and ``optimizer``.

    The checkpoint is written to ``trained\\hyper_<study_name>.pt``, overwriting
    any previous one.
    """
    pareto_numbers = {t.number for t in study.best_trials}   # multi-objective "best" trials
    if trial.number in pareto_numbers:
        print(' ***** Callback : Save Best Trial {} ***** '.format(trial.number))
        timeStr = str(datetime.datetime.now()) 
        note = 'Recon2X3D5 (batch2d/batch2d/batch3d) avg(X1,X2,X)*2-fusion layer (last fusion layer [1,16,16][16,3] + ToTensor6 using FocalMulticlass({0.1 0.3 0.6}, 2, sum)'
        save_dict = {'timestamp':timeStr ,
                     'note':note ,
                     'in_c':1 ,                                          # define_model_trial always uses in_c = 1
                     'en_sz':en_sz ,
                     'de_sz':de_sz ,
                     'de3d_sz':de3d_sz ,
                     'final_sz':final_sz ,
                     'train_loss_history':train_loss_history ,
                     'train_acc_history':train_acc_history ,
                     'val_loss_history':val_loss_history ,
                     'val_acc_history':val_acc_history ,
                     #'model_state_dict':model.state_dict() ,              # without nn.DataParallel
                     'model_state_dict':model.module.state_dict() ,        # with    nn.DataParallel
                     'optimizer_state_dict': optimizer.state_dict()}
        # Save Model state_dict (raw string avoids the invalid '\h' escape)
        saved_name = r'trained\hyper_' + study_name +  '.pt'   # file name = hyper_Recon2X3D5_210609.pt  (example)
        torch.save(save_dict, saved_name)
        print('\n   ##### Model is saved @ {}  ;\n   ##### Trial =  {} \n'.format(timeStr,trial))
    
    
#################### Starting trial ####################
# Create or load the two-objective (val_acc, val_acc2) study, optionally run it,
# then report its Pareto-optimal result.
torch.cuda.empty_cache()
#study_name = str(input(' Input Study Name = '))
study_name = 'Recon2X3D5_21060902'   # 'test'
# Architecture recorded by `callback` in the saved checkpoint.
de3d_sz = [[32,8],[24,6],[14,4],[8,2]]  # ignore
en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
final_sz = [[1,32,32],[1,32,32],[1,32,32],[1,16,16],[1,16,16],[1,16,16],[1,16,16]]

hypertuneLogic = int(input('Run Hyperparameter Tuning : [1] Yes  [0] No  = '))
storageLogic = int(input('Select saving storage [1] Remote database [2] Local : '))
if storageLogic==1:
    storage = 'sqlite:///recon2x3d5.db'       # 'sqlite:///example.db'
    print('   Use Remote databases (RDB)\nStudy name = {}   At {}\n'.format(study_name, storage))
    study = optuna.create_study(study_name=study_name , directions=['maximize','maximize'], 
                                storage=storage, load_if_exists=True)
elif storageLogic==2:
    print('   Use Local memory')
    creatnewstudyLogic = int(input('   Input [1] Create new study: <{}>  [2] Load existing: <{}.pkl> :'.format(study_name,study_name)))
    if creatnewstudyLogic==1:   # create new study
        # A multi-objective study needs `directions` (plural); `direction` takes a single string.
        study = optuna.create_study(study_name=study_name , directions=['maximize','maximize'])
    elif creatnewstudyLogic==2:  # load existing study
        study = joblib.load(r'trained\hyper_' + study_name + '.pkl')
    else:
        raise ValueError('Invalid study option')
else:
    raise ValueError('Invalid storage option')
        
print('   Study = {} | {}'.format(study_name ,study))
print("   Number of finished trials: ", len(study.trials))
print('   Sampler = {} '.format(study.sampler.__class__.__name__))
print('   Number of Training set : {}'.format(len(trainLoader)))
print('   Number of Validation set : {}'.format(len(valLoader)))

if hypertuneLogic==1:   # Hyperparameter Tuning
    n_trials = int(input('Input number of trials (n_trials) = '))
    print()
    time1 = time.time()
    #study.optimize(objective, n_trials=n_trials)
    study.optimize(objective, n_trials=n_trials, callbacks=[callback])  # Callback to save the best trained pytorch model
    # Snapshot the study to a local pickle. This is required for in-memory studies
    # (storageLogic == 2), which option [2] reloads; previously it ran only for the RDB case.
    joblib.dump(study, r'trained\hyper_' + study_name + '.pkl')
    time2 = time.time()
    print(' ##### Total optimizing time = {} Hours #####\n'.format((time2-time1)/3600))

print("  Study : {} ".format(study_name))
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(study.get_trials(deepcopy=False, states=[TrialState.PRUNED])))
print("  Number of complete trials: ", len(study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])))

if study.best_trials:   # empty until at least one trial completes
    print("\nBest trial number: {}".format(study.best_trials[0].number))
    trial = study.best_trials[0]
    print("  Value: ", trial.values)
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
else:
    print("\nNo completed trials yet.")

In [None]:
# Show the first Pareto-optimal trial: its number, objective values and parameters.
trial = study.best_trials[0]
print("\nBest trial number: {}".format(trial.number))
print("  Value: ", trial.values)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

In [None]:
study.best_trials[0].number

In [None]:
# `best_trials` is a list (the Pareto front), so take one element; multi-objective
# trials expose `.values`, whereas `.value` raises for them.
trial = study.best_trials[0]
print("  Value: ", trial.values)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

In [None]:
#help(study.trials[2])
# List every trial with its objective values, wall-clock duration and sampled parameters.
for i, trial in enumerate(study.trials):
    print('Trial:{}   value = {}   Duration={} Hr.Min.Sec.milSec'.format(i, trial.values, trial.duration))
    print('Params: {}\n'.format(trial.params))

In [None]:
# Display the study's best trials (for this two-objective study, a list of trials).
#print("   Number of finished trials: ", len(study.trials))
study.best_trials    # Get best trial's information.
#study.trials        # Get all trials' information.

#study.best_params   # Get best parameters for the objective function.
#study.best_value    # Get best objective value.

In [None]:
# Delete the study from the remote database, after an explicit confirmation.
prompt = 'Are you sure to delete The Study: < {} > ? ; [1] Yes  [0] No = '.format(study_name)
delete_logic = int(input(prompt))
if delete_logic != 1:
    print('*** Do not delete the study ***')
else:
    optuna.delete_study(study_name=study_name, storage=storage)
    print('*** The study deleted ***')

**Visualize Hyperparameters**

In [None]:
# Parallel-coordinate plot of the hyperparameters against the 2nd objective,
# "val_acc2" (index 1, matching target_names in the Pareto-front cell).
# Optuna's `target` must be a callable FrozenTrial -> float; a list such as
# ['val_acc2'] is not callable and fails when the plot is built.
plot_parallel_coordinate(study, target=lambda t: t.values[1], target_name='val_acc2')

In [None]:
# Contour plots for pairs of the selected hyperparameters.
# NOTE(review): if `study` is multi-objective (the Pareto-front cell plots two
# objectives), Optuna also requires `target=` here -- confirm.
#plot_contour(study)
plot_contour(study, params=['w1','w2','w3'])   # params [Str] = gamma , w1 , w2 , w3 , lr

In [None]:
plot_slice(study)   # NOTE(review): a multi-objective study needs target= here -- confirm

In [None]:
plot_edf(study)   # NOTE(review): a multi-objective study needs target= here -- confirm

In [None]:
# Multi-objective: Pareto front of the two objectives, labelled val_acc and val_acc2.
optuna.visualization.plot_pareto_front(study, target_names=["val_acc", "val_acc2"])

## Learning curve

In [None]:
# Dump the raw learning-curve histories stored in the loaded checkpoint dict.
# (Old manual fix-ups, kept for reference:
#  val_loss_history.append(125478.0054); val_acc_history.append(2.402127273))
for history_key in ('train_loss_history', 'train_acc_history', 'val_loss_history', 'val_acc_history'):
    print(saved_dict[history_key])

In [None]:
# Visualize learning curve and accuracy from the loaded checkpoint dict.
print('Total iterations = {:}'.format(len(saved_dict['train_loss_history'])))
summary_rows = (
    ('Min. Training loss = {:,.4f}', min, 'train_loss_history'),
    ('Max. Training accuracy = {:.4f}', max, 'train_acc_history'),
    ('Min. Validation loss = {:,.4f}', min, 'val_loss_history'),
    ('Max. Validation accuracy = {:.4f}', max, 'val_acc_history'),
)
for row_fmt, reducer, history_key in summary_rows:
    print(row_fmt.format(reducer(saved_dict[history_key])))

# One figure per metric: training curve in red, validation curve in green.
curve_specs = (
    ('Loss history', 'Loss', 'train_loss_history', 'val_loss_history'),
    ('Accuracy history', 'Accuracy', 'train_acc_history', 'val_acc_history'),
)
for y_label, fig_title, train_key, val_key in curve_specs:
    plt.figure(figsize=(16,8))
    plt.plot(saved_dict[train_key], 'r', label='Training')
    plt.plot(saved_dict[val_key], 'g', label='Validation')
    plt.ylabel(y_label)
    plt.xlabel('Iteration')
    plt.title(fig_title)
    plt.legend()
    plt.show()

# Save model

**CNN + FCN Model**

In [None]:
# Bundle architecture sizes, learning curves and weights into one checkpoint dict.
timeStr = str(datetime.datetime.now())
#note = 'Recon2X3D5 (instanceNorm/instanceNorm/instanceNorm) AVG(X1,X2,X)*2-fusion layer (last fusion layer [3,16,16][16,3] + ToTensor7 using FocalMulticlass({0.1 0.3 0.6}, 2, sum)'
note = 'Recon2X3D6 (instanceNorm/instanceNorm/instanceNorm) CAT[XD,XH,XW,X(C=3))*2-fusion layer (last fusion layer [3,16,16][16,3] + ToTensor7 using FocalMulticlass({0.15 0.25 0.6}, 2, sum)'

# nn.DataParallel / DistributedDataParallel keep the real network in `.module`.
# Unwrap only when wrapped, so the saved keys carry no 'module.' prefix and the
# cell also works for a plain model (plain `model.module` raised AttributeError).
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
    net_to_save = model.module
else:
    net_to_save = model

save_dict = {'timestamp': timeStr,
             'note': note,
             'en_sz': en_sz,
             'de_sz': de_sz,
             'de3d_sz': de3d_sz,
             'final_sz': final_sz,
             'train_loss_history': train_loss_history,
             'train_acc_history': train_acc_history,
             'val_loss_history': val_loss_history,
             'val_acc_history': val_acc_history,
             'model_state_dict': net_to_save.state_dict(),
             'optimizer_state_dict': optimizer.state_dict()}

# Save Model state_dict.
# Raw string: in a plain literal '\R' is an invalid escape sequence (a warning
# today, slated to become an error); the resulting path is unchanged.
saved_name = r'trained\Recon2X3D6_21101501.pt'
confirm_saved_name = input('Confirm Save : ' + saved_name + ' [y/n] ? ').strip().lower()
#confirm_saved_name = 'y'
if confirm_saved_name == 'y':
    torch.save(save_dict, saved_name)
    print('model.parameters:')
    print(model.parameters)
    print('\n   ### Model is saved @ {}  ; Total Epoch =  {} ### \n'.format(timeStr, len(train_loss_history)))
else:
    print('\n   ### Model is not saved !!!')

**AutoEncoder Model**

In [None]:
# Bundle the autoencoder's sizes, learning curves and both sub-model weights.
note = 'Autoencoder: Encode3d using AdaptiveMaxPool3d'
timeStr = str(datetime.datetime.now())
save_dict = {'timestamp': timeStr,
             'note': note,
             'en3d_sz': en3d_sz,
             'de3d_sz': de3d_sz,
             'en_sz': en_sz,
             'final_sz': final_sz,
             'train_loss_history': train_loss_history,
             'train_acc_history': train_acc_history,
             'val_loss_history': val_loss_history,
             'val_acc_history': val_acc_history,
             'model1_state_dict': model1.state_dict(),    # for encode3d or encode2d
             'model2_state_dict': model2.state_dict(),    # for encode3d
             'optimizer_state_dict': optimizer.state_dict()}

# Save Model state_dict.
# Raw string: '\T' is an invalid escape sequence in a plain literal; the
# resulting path is unchanged.
saved_name = r'trained\TLNetAuto_21012502.pt'
confirm_saved_name = input('Confirm save : ' + saved_name + ' [y/n] ? ').strip().lower()
#confirm_saved_name = 'y'
if confirm_saved_name == 'y':
    torch.save(save_dict, saved_name)
    print('model1.parameters: \n{}'.format(model1.parameters))
    print()
    print('model2.parameters: \n{}'.format(model2.parameters))
    print()
    print('\n   ### Model is saved @ {}  ; Total Epoch =  {} ### \n'.format(timeStr, len(train_loss_history)))
else:
    print('\n   ### Model is not saved !!!')

# Testing Model

## Original

In [None]:
torch.cuda.empty_cache()   # release cached-but-unused GPU memory held by PyTorch before testing

In [None]:
# For aligned Dataset2: build the test dataset/loader from an interactively chosen split.
first = True   # makes the shuffled-sample test cell create a fresh iterator

test_mode = int(input('Select test mode : [1] All sample  [2] Specify sample : '))

if test_mode == 1:
    test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Validation_set.xlsx'
    # bool(input(...)) is True for every non-empty answer, including '0';
    # compare the typed text instead.
    shuffle_logic = input('Input Shuffle : [1] True  [0] False ').strip() == '1'

elif test_mode == 2:
    shuffle_logic = False
    sampletype = int(input('Select sample type : [1] Intact  [2] Nondisplaced  [3] Displaced : '))
    if sampletype == 1:
        print('Intact        id = {0 5 8 14 16} ')
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\IntactLogTotal0.xlsx'
    elif sampletype == 2:
        print('Nondisplaced  id = {0 1 2 3 4 5*} ')
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\NondisplaceLogTotal0.xlsx'
    elif sampletype == 3:
        print('Displaced     id = {0 3 4* 5 8 9} ')
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\DisplaceLogTotal0.xlsx'
    else:
        raise ValueError('Error sampletype input : invalid integer input !!! ')
else:
    raise ValueError('Error test_model input')

# Older pipeline (FemurDataset + NormalizeSample + ToTensor7), kept for reference:
# test_transformedFemur = FemurDataset(csv_file=test_set, root_dir=root_dir,
#                                      transform=transforms.Compose([NormalizeSample(),ToTensor7()]))

test_transformedFemur = FemurDataset2(csv_file=test_set, root_dir=root_dir,
                                      transform=transforms.Compose([NormalizeSample2(), ToTensor8()]))  # For ToTensor8-9

testLoader = DataLoader(test_transformedFemur, batch_size=1, shuffle=shuffle_logic, num_workers=4)

print(' --- END ---')

## k3D

In [None]:
#root_dir = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData'
#training_file = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData\Siriraj_testset.xlsx'
#val_file = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData\Siriraj_testset.xlsx'

In [None]:
# For unaligned Dataset4 (and the Siriraj set): build the test dataset/loader.
first = True   # makes the shuffled-sample test cell create a fresh iterator

sampletype_prompt = 'Select sample type : [1] Intact  [2] Nondisplaced  [3] Displaced [4] Siriraj: '
test_mode = int(input('Select test mode : [1] All sample  [2] Specify sample : '))
if test_mode == 1:
    test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TrainingSet_Scale12.xlsx'      # TestingSet_Scale12.xlsx
    # bool(input(...)) is True for every non-empty answer, including '0'.
    shuffle_logic = input('Input Shuffle : [1] True  [0] False ').strip() == '1'
    sampletype = int(input(sampletype_prompt))   # only used to tag saved file names
    if sampletype not in (1, 2, 3, 4):
        raise ValueError('Error sampletype input : invalid integer input !!! ')

elif test_mode == 2:
    shuffle_logic = False
    sampletype = int(input(sampletype_prompt))
    if sampletype == 1:
        print('Intact {0 - 1625}          id = unalign{} ')    # 0 - 1625
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_IntactLog2.xlsx'
    elif sampletype == 2:
        print('Nondisplaced {1626 - 2321} id = unalign{} ')   # 1626 - 2321
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_NondisplaceLog2.xlsx'
    elif sampletype == 3:
        print('Displaced {2322 - 3365}    id = unalign{} ')   # 2322 - 3365
        test_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_DisplaceLog2.xlsx'
    elif sampletype == 4:
        # Previously printed the "Displaced" label by copy-paste mistake.
        print('Siriraj test set           id = unalign{} ')
        test_set = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\Siriraj_UnalignData\Siriraj_testset.xlsx'
    else:
        raise ValueError('Error sampletype input : invalid integer input !!! ')
else:
    raise ValueError('Error test_model input')

# Older pipeline (FemurDataset + NormalizeSample + ToTensor7), kept for reference:
# test_transformedFemur = FemurDataset(csv_file=test_set, root_dir=root_dir,
#                                      transform=transforms.Compose([NormalizeSample(),ToTensor7()]))

test_transformedFemur = FemurDataset2(csv_file=test_set, root_dir=root_dir,
                                      transform=transforms.Compose([NormalizeSample2(), ToTensor8()]))  # For ToTensor8-9

testLoader = DataLoader(test_transformedFemur, batch_size=1, shuffle=shuffle_logic, num_workers=4)

print(' --- END ---')

In [None]:
# Original: run the model on one test sample, report IoU / ASSD, and render the
# ground truth and the prediction as k3d volumes from two camera views.
torch.cuda.empty_cache()
test_mode = 2   # [1] next sample from testLoader  [2] sample picked by index


def _show_bone_volume(volume, obj_name, plot_name, label, camera):
    """Create, display and return a k3d plot rendering `volume` with the Bone colormap.

    volume    -- array handed to k3d.volume (here a boolean voxel mask)
    obj_name  -- name of the k3d volume object
    plot_name -- name of the k3d plot
    label     -- text drawn in the lower-left corner of the plot
    camera    -- 9-element k3d camera list, applied after the plot is displayed
    """
    plot = k3d.plot(name=plot_name)
    vol = k3d.volume(volume, name=obj_name,
                     color_map=k3d.colormaps.matplotlib_color_maps.Bone,
                     gradient_step=0.005,
                     shadow='dynamic',
                     shadow_delay=10)
    plot += vol + k3d.text2d(text=label, color=0, size=1, position=(0.01, 0.025), label_box=False)
    plot.display()
    plot.camera = list(camera)
    return plot


if test_mode == 1:
    print('Testing on Shuffle sample')
    if first:
        print('First testing sample')
        femurIter = iter(testLoader)
        first = False
    else:
        print('Next testing sample')
    sample = next(femurIter)
    target, ap, lat = sample['Target'], sample['view1'], sample['view2']
    print('Raw: target={}  ap={}  lat={}'.format(target.size(), ap.size(), lat.size()))
    # Loader batches have batch_size=1: drop the batch dim of target, keep one for the inputs (ToTensor6-8).
    target, ap, lat = target[0], ap[0].unsqueeze(0).to(device=device), lat[0].unsqueeze(0).to(device=device)
    print('Final: target={}  ap={}  lat={}'.format(target.size(), ap.size(), lat.size()))

elif test_mode == 2:
    print('Choose Sample Number : 0 - {} ?'.format(len(testLoader)-1))
    sample_number = int(input('Sample number : '))    # using 123
    sample = test_transformedFemur[sample_number]
    target, ap, lat = sample['Target'], sample['view1'], sample['view2']
    print('Raw Dataset = {} {} {}'.format(target.size(), ap.size(), lat.size()))
    # Dataset items are unbatched: add a batch dim to the inputs only (ToTensor6).
    target, ap, lat = target, ap.unsqueeze(0).to(device=device), lat.unsqueeze(0).to(device=device)
    print('Final = {} {} {}'.format(target.size(), ap.size(), lat.size()))

model.eval()
t1 = time.time()
# no_grad: .detach() alone still lets autograd record the forward pass (and keep
# its intermediate tensors) before detaching; no_grad skips that bookkeeping.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    output = model(ap, lat).detach()
print('\nOutput = {}   {}'.format(output.size(), output.dtype))
t2 = time.time()
print('Calculating time = {:,.4f} sec.'.format(t2-t1))


##### OUTPUT: Bone Visualization #####
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(ap.detach().cpu().numpy().squeeze(), cmap='gray')
ax[0].set_title('Input view 1')
ax[1].imshow(lat.detach().cpu().numpy().squeeze(), cmap='gray')
ax[1].set_title('Input view 2')
plt.show()

gt = (target == 1).detach().cpu().numpy()
ot = (output[0, 1, :] > 0.5).detach().cpu().numpy()   # channel 1 thresholded at 0.5
print('IoU = {:,.3f}'.format(iou((target == 1).float().cpu(), (output[0, 1, :] > 0.5).float().cpu())), end=', ')
# NOTE(review): the ASSD call receives the un-thresholded channel-1 output,
# unlike `ot` above -- confirm surface_distance_measurement thresholds it.
dist_metrics = surface_distance_measurement(output[0, 1].cpu().numpy(), (target == 1).cpu().numpy(), res=0.5, verbose=False)
print('ASSD = {:.3f} mm'.format(dist_metrics['ASSD']))

# Camera views: [position(3), look-at target(3), up vector(3)]
cam_view1 = [0,0,-1.5, 0,0,0 ,0,-1,0]   # view1
cam_view2 = [1.5,0,0, 0,0,0 ,1,-1,0]    # view2

### Ground Truth : Bone Class ###
plot1 = _show_bone_volume(gt, 'Ground Truth', 'Plot 1 : [view1]', 'Ground Truth (view1)', cam_view1)
plot3 = _show_bone_volume(gt, 'Ground Truth', 'Plot 1 : [view2]', 'Ground Truth (view2)', cam_view2)

### Output : Bone Class ###
plot2 = _show_bone_volume(ot, 'Output', 'Plot 2 : [view1]', 'Output (view1)', cam_view1)
plot4 = _show_bone_volume(ot, 'Output', 'Plot 2 : [view2]', 'Output (view2)', cam_view2)
print(' --- END --- ')
### For align dataset2 ###
# Intact        id = {0 5 8 14 16}
# Nondisplaced  id = {6 1 2 3 4 5'}
# Displaced     id = {0 3 4' 5 8 9}

### For unalign dataset4 ###
### For Training set ###
# Intact {0 - 1625}          id = {14771(0),14772(+10),14773(+7.5),14774(+5),14775(+2.5)}           {(0),(+10),(+7.5),(+5),(+2.5)}
# Nondisplaced {1626 - 2321} id = {16453(0),16454(+10),16455(+7.5),16456(+5),16457(+2.5)}
#                               = {(0),(+10),(+7.5),(+5),(+2.5)}
#                               = {(0),(+10),(+7.5),(+5),(+2.5)}
# Displaced {2322 - 3365}    id = {17700(0),17701(+10),17702(+7.5),17703(+5),17704(+2.5)}
#                            id = {18077(0),18078(+10),18079(+7.5),18080(+5),18081(+2.5)}
### For Testing set ###
# Intact                     id = {39(0),40(+10),41(+7.5),42(+5),43(+2.5)}           {(0),(+10),(+7.5),(+5),(+2.5)}
#                                 {(0),(+10),(+7.5),(+5),(+2.5)}
# Nondisplace                id = {39(0),40(+10),41(+7.5),42(+5),43(+2.5)}   (not used)
#                            id = {10(0),11(+10),12(+7.5),13(+5),14(+2.5)}
#                            id = {68(0),69(+10),70(+7.5),71(+5),72(+2.5)}
# Displace                   id = {10(0),11(+10),12(+7.5),13(+5),14(+2.5)}
#                            id = {155(0),156(+10),157(+7.5),158(+5),159(+2.5)}

In [None]:
# Copy the current camera of the output view-2 plot (plot4) onto the
# ground-truth view-2 plot (plot3) so both show the same viewpoint.
plot3.camera = plot4.camera

In [None]:
# Copy the current camera of the ground-truth view-2 plot (plot3) onto the
# output view-2 plot (plot4) so both show the same viewpoint.
plot4.camera = plot3.camera

In [None]:
### Show Volume rotation360 ###
# Orbit the cameras of the two view-2 plots (GT: plot3, output: plot4) N full turns.
# (A stray `k3d.plot()` that created an unused, never-displayed plot was removed.)
N = 1   # number of rounds
r_orbit = 1.75
camera_rotate = [[-r_orbit*np.sin(t), -0.2, r_orbit*np.cos(t), 0, 0, 0, 0, -1, 0]
                 for t in np.linspace(0, 2*np.pi*N, num=360*N)]   # 360 views per round
plot3.grid_visible = False
plot4.grid_visible = False

for view in camera_rotate:
    plot3.camera = view
    plot4.camera = view
    time.sleep(8/360)   # about 8 s per round

print('--- Rotation END ---')

**Save**

In [None]:
### Save the input DRRs (.png) and the k3d volume widgets (.html) ###


def _drr_to_png(drr, size=(1024, 1024)):
    """Return DRR tensor `drr` as a 16-bit PIL image resized to `size`.

    Intensities are min-max rescaled to [0, 65535]. A constant image
    (max == min) becomes all zeros instead of dividing by zero.
    """
    arr = drr.detach().cpu().numpy().squeeze()
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        arr = (arr - lo)*65535/(hi - lo)
    else:
        arr = np.zeros_like(arr)
    arr = arr.astype(np.uint16)   # float > uint16 for .png
    arr = cv2.resize(arr, dsize=size, interpolation=cv2.INTER_CUBIC)
    return Image.fromarray(arr)


drr1_file = _drr_to_png(ap)
drr2_file = _drr_to_png(lat)

# Name the files
k3dResultLocation = r'trained\K3DResult\\FracReconNet_Recon2X3D5_21101901_unalignTrain (feak)'   # root directory
confirm_savePath = int(input('Confirm savePath = {}\n    [1] Yes  [0] No = '.format(k3dResultLocation)))
if confirm_savePath == 1:
    name_tag = '-type-' + str(sampletype) + '-NO-' + str(sample_number)
    drr1_filename = k3dResultLocation + r'\DRRs\drr1' + name_tag + '.png'
    drr2_filename = k3dResultLocation + r'\DRRs\drr2' + name_tag + '.png'
    gt_filename = k3dResultLocation + r'\Volume\volumeGT' + name_tag + '.html'   # For only 0.0 degree alignment
    ot_filename = k3dResultLocation + r'\Volume\volumeOT' + name_tag + '.html'

    # Note (translated from Thai): an earlier run saved into the wrong folder,
    # FracReconNet_Recon2X3D5_22081001_unalignTest.
    # Create the sub-folders if they do not exist yet.
    for sub_dir in (r'\DRRs', r'\Volume'):
        os.makedirs(k3dResultLocation + sub_dir, exist_ok=True)

    # Save the files
    drr1_file.save(drr1_filename)
    drr2_file.save(drr2_filename)
    # Write the HTML snapshots as UTF-8 rather than the OS default code page.
    with open(gt_filename, 'w', encoding='utf-8') as fp:     # For only 0.0 degree alignment
        fp.write(plot3.get_snapshot())
    with open(ot_filename, 'w', encoding='utf-8') as fp:
        fp.write(plot4.get_snapshot())
    print(' Saving complete :\n      {}\n      {}\n      {}\n      {}'.format(drr1_filename, drr2_filename, gt_filename, ot_filename))

In [None]:
### Save the raw volumetric masks as .npy files
gt_np = (target == 1).detach().cpu().numpy()
ot_np = (output[0, 1, :] > 0.5).detach().cpu().numpy()
print('gt = {}\not = {}'.format(gt_np.shape, ot_np.shape))

# Both file names share the "-type-<sampletype>-NO-<sample_number>.npy" suffix.
name_suffix = '-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.npy'
gt_array_filename = k3dResultLocation + r'\Numpy Array\gtArray' + name_suffix
ot_array_filename = k3dResultLocation + r'\Numpy Array\otArray' + name_suffix

for array_path, mask in ((gt_array_filename, gt_np), (ot_array_filename, ot_np)):
    np.save(array_path, mask)
print(' Saving complete :\n      {}\n      {}'.format(gt_array_filename, ot_array_filename))

In [None]:
### Snapshot view.png files (one PNG per camera angle; combined into a GIF later)
folder_dir = k3dResultLocation + r"\GIF_Volume"

N = 1
r_orbit = 1.75
# endpoint=False: an angle of exactly 2*pi*N repeats the first view, which shows
# up as a duplicated frame when the PNGs are looped as a GIF.
camera_rotate = [[r_orbit*np.sin(t), -0.2, r_orbit*np.cos(t), 0, 0, 0, 0, -1, 0]
                 for t in np.linspace(0, 2*np.pi*N, num=45*N, endpoint=False)]

out = ipywidgets.Output()
gt_or_ot = int(input('Select: [1] Ground-truth [2] Output: '))
if gt_or_ot == 1:
    os.makedirs(folder_dir + r'\GT', exist_ok=True)
    plot3.grid_visible = False
    @plot3.yield_screenshots   # driven by plot3's screenshots, runs in the background
    def coroutine():
        print('Running coroutine for ground-truth ...')
        for i, view in enumerate(camera_rotate):
            print(i)
            plot3.camera = view
            plot3.fetch_screenshot()
            screenshot = yield   # image bytes sent in by yield_screenshots
            with open(folder_dir + r'\GT\viewGT_%03d.png'%i, 'wb') as f:
                f.write(screenshot)
            with out:
                print('viewGT_%03d.png saved.'%i)
        with out:
            print('Done!!! \n')

elif gt_or_ot == 2:
    os.makedirs(folder_dir + r'\OT', exist_ok=True)
    plot4.grid_visible = False
    @plot4.yield_screenshots   # driven by plot4's screenshots, runs in the background
    def coroutine():
        print('Running coroutine for output ...')
        for i, view in enumerate(camera_rotate):
            print(i)
            plot4.camera = view
            plot4.fetch_screenshot()
            screenshot = yield   # image bytes sent in by yield_screenshots
            with open(folder_dir + r'\OT\viewOT_%03d.png'%i, 'wb') as f:
                f.write(screenshot)
            with out:
                print('viewOT_%03d.png '%i, end=' ')
        with out:
            print('Done!!! \n')

else:
    # Without this, `coroutine` would be undefined (or a stale one from an earlier run).
    raise ValueError('Select 1 (ground-truth) or 2 (output), got {}'.format(gt_or_ot))

coroutine()
print('Saving folder: {}'.format(folder_dir))
out   # last expression of the cell, so the notebook renders the progress widget

In [None]:
### Generate the Volume_GIF file from the PNG snapshots


def _pngs_to_gif(png_dir, gif_path, duration=250):
    """Join the .png frames in `png_dir`, sorted by file name, into a looping GIF.

    Frame files are zero-padded (view_000.png, view_001.png, ...), so sorting by
    name restores capture order; os.listdir() alone guarantees no order.
    `duration` is how long each frame is shown. Returns the list of frames.
    Raises FileNotFoundError when `png_dir` contains no .png files.
    """
    frame_names = sorted(name for name in os.listdir(png_dir) if name.endswith('.png'))
    if not frame_names:
        raise FileNotFoundError('No .png frames found in {}'.format(png_dir))
    frames = []
    for name in frame_names:
        print(name, end=' ')
        frames.append(Image.open(os.path.join(png_dir, name)))
    frames[0].save(gif_path, format='GIF', append_images=frames[1:], save_all=True,
                   duration=duration, loop=0)
    return frames


multi_GTimgs = list()
multi_OTimgs = list()
name_tag = '-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.gif'
t1 = time.time()
if gt_or_ot == 1:
    print('Generating .GIF for Ground-truth')
    gif_GTname = folder_dir + r'\GT\GIFVolumeGT' + name_tag
    multi_GTimgs = _pngs_to_gif(folder_dir + r'\GT', gif_GTname)
    print('\n ### Save .gif name: {}'.format(gif_GTname))
elif gt_or_ot == 2:
    print('Generating .GIF for Output')
    gif_OTname = folder_dir + r'\OT\GIFVolumeOT' + name_tag
    multi_OTimgs = _pngs_to_gif(folder_dir + r'\OT', gif_OTname)
    print('\n ### Save .gif name: {}'.format(gif_OTname))
else:
    raise ValueError('gt_or_ot must be 1 (ground-truth) or 2 (output), got {}'.format(gt_or_ot))
t2 = time.time()
print('Done in {} sec.'.format(t2-t1))

## ASSD

In [None]:
# Surface distances between ground truth and prediction, rendered on each mesh.
t1 = time.time()
surf_dist_result = surface_distance_measurement(gt, ot, res=0.5, return_vert_dist=True, verbose=False)
t2 = time.time()
print('\nTotal time of calculation = {} sec.'.format(t2-t1))
print('key result = {}'.format(surf_dist_result.keys()))

background_color = 0xFFFFFF   # white (== 65536*255 + 256*255 + 255)


def _distance_mesh_plot(surf, title, camera):
    """Display and return (plot, mesh): `surf`'s mesh coloured by its per-vertex distances.

    surf   -- one entry of surf_dist_result ('surface_dist_gt' or 'surface_dist_ot');
              its 'target_vert', 'target_face' and 'distances' items are used
    title  -- text drawn in the lower-left corner of the plot
    camera -- 9-element k3d camera list, applied after the plot is displayed
    """
    plot = k3d.plot(background_color=background_color)
    text = k3d.text2d(text=title, color=0, size=1, position=(0.01, 0.025), label_box=False)
    mesh = k3d.mesh(surf['target_vert'], surf['target_face'],
                    name='surface distance (mm)',
                    color_map=k3d.colormaps.basic_color_maps.Jet,
                    color_range=[0, 3],   # [0,3] for align dataset   |   [0,6] for unalign dataset
                    attribute=surf['distances'].astype(np.float32))
    plot += text + mesh
    plot.display()
    plot.camera = list(camera)
    return plot, mesh


front_camera = [128,-100,512, 128,-128,128,  0,1,0]
side_camera = [-224,-100,128, 128,-128,128,  0,1,0]

pltmesh, meshsurf = _distance_mesh_plot(surf_dist_result['surface_dist_gt'],
                                        'Min. Surface Distance (GroundTruth-based)', front_camera)
pltmesh3, meshsurf3 = _distance_mesh_plot(surf_dist_result['surface_dist_gt'],
                                          'Min. Surface Distance (GroundTruth-based2)', side_camera)
pltmesh2, meshsurf2 = _distance_mesh_plot(surf_dist_result['surface_dist_ot'],
                                          'Min. Surface Distance (Output-based)', front_camera)
print('--- END ---')

In [None]:
# Grab the current render of `pltmesh` as a numpy image array (img_np).
import base64
from io import BytesIO   # the bare name `io` is skimage.io in this notebook (see the imports)

pltmesh.fetch_screenshot()
# NOTE(review): fetch_screenshot() only asks the browser for a new image and
# `pltmesh.screenshot` is filled in later, so an immediate read may be empty
# or stale -- re-run this cell, or use the yield_screenshots pattern.
if not pltmesh.screenshot:
    raise RuntimeError('pltmesh.screenshot is empty; wait for the widget to respond and re-run this cell.')
base64_decoded = base64.b64decode(pltmesh.screenshot)
img_snap = Image.open(BytesIO(base64_decoded))
img_np = np.array(img_snap)

In [None]:
# Show the captured screenshot, keeping channels 0-2 (presumably dropping alpha).
rgb_snapshot = img_np[..., 0:3]
plt.figure(figsize=(16, 8))
plt.imshow(rgb_snapshot)
plt.show()

In [None]:
# Animate the surface-distance colouring on both meshes.
# k3d treats a dict keyed by time strings as a time series; frame t shows the
# per-vertex distances scaled by 0.5*t, for 100 values of t in [0, 1.5].
# The distances come from surf_dist_result (the ASSD cell). The bare names
# surface_dist_gt / surface_dist_ot are not assigned in any of the cells above,
# so the previous code raised NameError.
gt_dist = surf_dist_result['surface_dist_gt']['distances'].astype(np.float32)
ot_dist = surf_dist_result['surface_dist_ot']['distances'].astype(np.float32)
meshsurf.attribute = {str(t): (0.5*gt_dist*t).astype(np.float32) for t in np.linspace(0, 1.5, num=100)}
meshsurf2.attribute = {str(t): (0.5*ot_dist*t).astype(np.float32) for t in np.linspace(0, 1.5, num=100)}
pltmesh.start_auto_play()
pltmesh2.start_auto_play()

In [None]:
# Stop the distance-colouring animation on both mesh plots.
for animated_plot in (pltmesh, pltmesh2):
    animated_plot.stop_auto_play()

In [None]:
### Show Mesh rotation360 ###
# Orbit both mesh plots' cameras around the look-at point (128, -128, 128) for N turns.
N = 1   # number of rounds
r_orbit = 384
orbit_angles = np.linspace(np.pi/2, 2*np.pi*N + np.pi/2, num=360*N)   # 360 views per round
camera_rotate = []
for angle in orbit_angles:
    camera_rotate.append([128 + r_orbit*np.sin(angle), -90, 128 - r_orbit*np.cos(angle),
                          128, -128, 128, 0, 1, 0])

for mesh_plot in (pltmesh, pltmesh2):
    mesh_plot.grid_visible = False

for view in camera_rotate:
    pltmesh.camera = view
    pltmesh2.camera = view
    time.sleep(8/360)   # about 8 s per round

print('--- Rotation END ---')

**Save**

In [None]:
### Save the mesh widgets as standalone k3d HTML snapshots ###
os.makedirs(k3dResultLocation + r'\Mesh', exist_ok=True)   # create the folder if missing
gt_filename = k3dResultLocation + r'\Mesh\meshGT-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.html'
ot_filename = k3dResultLocation + r'\Mesh\meshOT-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.html'
# Write as UTF-8 explicitly instead of relying on the OS default code page.
with open(gt_filename, 'w', encoding='utf-8') as fp:
    fp.write(pltmesh.get_snapshot())
with open(ot_filename, 'w', encoding='utf-8') as fp:
    fp.write(pltmesh2.get_snapshot())
print(' Saving complete :\n      {}\n      {}'.format(gt_filename, ot_filename))

In [None]:
### Snapshot view.png files for the mesh GIF (one PNG per camera angle)
folder_dir2 = k3dResultLocation + r"\GIF_Mesh"
print('Saving folder: {}'.format(folder_dir2))
os.makedirs(folder_dir2, exist_ok=True)
N = 1
r_orbit = 384
# endpoint=False: the last angle (pi/2 + 2*pi*N) is the first view again,
# which would put a duplicated frame into the looping GIF.
camera_rotate = [[128 + r_orbit*np.sin(t), -90, 128 - r_orbit*np.cos(t), 128, -128, 128, 0, 1, 0]
                 for t in np.linspace(np.pi/2, 2*np.pi*N + np.pi/2, num=45*N, endpoint=False)]

pltmesh.grid_visible = False
pltmesh2.grid_visible = False

out = ipywidgets.Output()
@pltmesh.yield_screenshots   # driven by pltmesh's screenshots, runs in the background
def coroutine():
    print('coroutine is running ...')
    for i, view in enumerate(camera_rotate):
        print(i)
        pltmesh.camera = view
        pltmesh.fetch_screenshot()
        screenshot = yield   # image bytes sent in by yield_screenshots
        with open(folder_dir2 + r'\view_%03d.png'%i, 'wb') as f:
            f.write(screenshot)
        with out:
            print('view_%03d.png'%i, end=' ')
    with out:
        print('Done!!! \n')

coroutine()
out   # last expression of the cell, so the notebook renders the progress widget

In [None]:
### Generate the mesh GIF file from the PNG snapshots


def _pngs_to_gif(png_dir, gif_path, duration=250):
    """Join the .png frames in `png_dir`, sorted by file name, into a looping GIF.

    Frame files are zero-padded (view_000.png, view_001.png, ...), so sorting by
    name restores capture order; os.listdir() alone guarantees no order.
    `duration` is how long each frame is shown. Returns the list of frames.
    Raises FileNotFoundError when `png_dir` contains no .png files.
    """
    frame_names = sorted(name for name in os.listdir(png_dir) if name.endswith('.png'))
    if not frame_names:
        raise FileNotFoundError('No .png frames found in {}'.format(png_dir))
    frames = []
    for name in frame_names:
        print(name, end=' ')
        frames.append(Image.open(os.path.join(png_dir, name)))
    frames[0].save(gif_path, format='GIF', append_images=frames[1:], save_all=True,
                   duration=duration, loop=0)
    return frames


print('Converting images.png to .gif')
gif_name = folder_dir2 + r'\GIFMeshGT-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.gif'
t1 = time.time()
multi_imgs = _pngs_to_gif(folder_dir2, gif_name)   # duration = time to the next picture
t2 = time.time()
print('\n\n### Done in {} sec. : {}'.format(t2-t1, gif_name))

**Obsolete**

In [None]:
# OBSOLETE -- does not work (kept for reference). The intent was to orbit both
# mesh plots and write the screenshots straight into a GIF with imageio.
# NOTE(review): as written it cannot work:
#   * save_gif is decorated with @pltmesh.yield_screenshots but never called in
#     this cell; it also takes an argument and contains no `yield`, unlike the
#     working `coroutine()` generators in the snapshot cells.
#   * str.decode('base64') is a Python 2 idiom, so "method1" fails on Python 3.
#   * `import io` below rebinds the global name `io`, which the notebook's
#     imports bound to skimage.io.
# Do not work !!!
import imageio 
import base64
import io

gif_name = k3dResultLocation + r'\GIF\meshGT-type-' + str(sampletype) + '-NO-' + str(sample_number) + '.gif'
print('GIF file name = {}\n'.format(gif_name))
pltmesh2.camera_auto_fit = False
N = 1   # number of round
r_orbit = 384
camera_rotate = list([128-r_orbit*np.sin(t),-90,128-r_orbit*np.cos(t), 128,-128,128, 0,1,0] for t in np.linspace(0, 2*np.pi*N, num=5*N) )  # 360*N
#k3d.plot()
#pltmesh2.camera_mode = 'orbit'
pltmesh.grid_visible = False
pltmesh2.grid_visible = False


@pltmesh.yield_screenshots
def save_gif(camera_rotate):
    # Intended: move both cameras through `camera_rotate`, grab a screenshot of
    # pltmesh at each view, and write the RGB frames to `gif_name`.
    gif = []   # for save .gif files
    for i, view in enumerate(camera_rotate):
        print('Order:{} | {}'.format(i,view))
        pltmesh.camera = view
        pltmesh2.camera = view
        time.sleep(8/360)   

        # Generating .gif file
        pltmesh.fetch_screenshot()
        #time.sleep(10)  # Waiting for the widgets to synchronize behind the scenes, before calling the next cell.
        try:
            print('method1')
            img_snap = pltmesh.screenshot.decode('base64')   # Python 2 only
        except:
            print('method2')
            #base64_decoded = base64.b64decode(pltmesh.screenshot)
            # NOTE(review): this reads `.screenshot` off the value returned by
            # fetch_screenshot(); confirm that call returns the plot at all.
            base64_decoded = base64.b64decode(pltmesh.fetch_screenshot().screenshot)
            img_snap = Image.open(io.BytesIO(base64_decoded))

        img_np = np.array(img_snap)
        print(img_np.shape, type(img_np), img_np.dtype)
        gif.append(img_np[:,:,0:3])   # keep channels 0-2 only
        
    imageio.mimwrite(gif_name, gif, duration=1/30)
    print('### Save .GIF is done!!! ###')
    return gif

print('--- END ---')

## ipyvolume plotting

model = netG

In [None]:
# Original
#torch.cuda.empty_cache()
encode_feat1 = list()
encode_feat2 = list()
decode_feat1 = list()
decode_feat2 = list()
fusion_feat = list()
fusionUp_feat = list()
test_mode = 2

if test_mode==1:
    print('Testing on Shuffle sample')
    if first==True:
        print('First testing sample')
        femurIter = iter(testLoader)
        first = False
    else:
        print('Next testing sample')
    sample = next(femurIter)
    target , ap , lat = sample['Target'] , sample['view1'] , sample['view2']
    #target , ap , lat = sample['Target'] , sample['AP'] , sample['LAT']
    print('Raw: target={}  ap={}  lat={}'.format(target.size(),ap.size(),lat.size()))
    #target,ap,lat=target,ap[0].unsqueeze(0).to(device=device),lat[0].unsqueeze(0).to(device=device) # for ToTensor1-5
    target,ap,lat=target[0],ap[0].unsqueeze(0).to(device=device),lat[0].unsqueeze(0).to(device=device)  # for ToTensor6-8
    print('Final: target={}  ap={}  lat={}'.format(target.size(), ap.size(), lat.size()))
    
elif test_mode==2:
    print('Choose Sample Number : 0 - {} ?'.format(len(testLoader)-1))
    sample_number = int(input('Sample number : '))    # using 123
    #valLoader = DataLoader(val_transformedFemur, batch_size=1, shuffle=False, num_workers=0)
    sample = test_transformedFemur[sample_number]
    target , ap , lat = sample['Target'] , sample['view1'] , sample['view2']
    #target , ap , lat = sample['Target'] , sample['AP'] , sample['LAT']
    print('Raw Dataset = {} {} {}'.format(target.size(),ap.size(),lat.size()))
    #target,ap,lat=target,ap[0].unsqueeze(0).to(device=device),lat[0].unsqueeze(0).to(device=device) # for ToTensor1-5
    target,ap,lat=target,ap.unsqueeze(0).to(device=device),lat.unsqueeze(0).to(device=device)  # for ToTensor6
    print('Final = {} {} {}'.format(target.size(), ap.size(), lat.size()))
    
with torch.cuda.amp.autocast(enabled=True):
    t1 = time.time()
    #output = model(ap).detach()
    #model = nn.DataParallel(model, device_ids=[0,1])
    #model = model.module
    #model.to(device='cpu')
    #model.to(device='cuda:0')
    model.eval()
    output = model(ap,lat).detach()
    #torch.cuda.synchronize()
    print('\nOutput = {}   {}'.format(output.size(), output.dtype))
    #print('output.device: {} \ntarget.device: {} '.format(output.device,target.device))
    t2 = time.time()
    print('Calculating time = {:,.4f} sec.'.format(t2-t1))
    '''fig, ax = plt.subplots(1,2, figsize=(16,16))
    ax[0].imshow(ap.detach().cpu().numpy().squeeze(), cmap='gray')
    ax[0].set_title('Input view 1')
    ax[1].imshow(lat.detach().cpu().numpy().squeeze(), cmap='gray')
    ax[1].set_title('Input view 2')
    plt.show()'''
    #print(' --- END ---')
    
    
### OUTPUT: Bone Visualization ###
print('\n ### OUTPUT: Bone Visualization ###')
print('### Bone ###')
# Binarize the bone channel once at the 0.5 threshold used for IoU. Feed that same
# float mask to the surface-distance metrics, as the evaluation loop does, so
# IoU, ASSD and BHD all describe one predicted surface.
bone_pred = (output[0,1]>=0.5).float()
bone_true = (target==1).float()
print('IOU = {:.4f} '.format(iou(bone_pred, bone_true.to(device=device)))) # ToTensor6-9
# res=0.5: presumably the voxel spacing in mm (results are printed in mm) - TODO confirm.
dist_metrics = surface_distance_measurement(bone_pred.cpu().numpy(), bone_true.cpu().numpy(), res=0.5, verbose=False)
print('ASSD = {:.3f} mm'.format(dist_metrics['ASSD']))
print('BHD = {:.3f} mm'.format(dist_metrics['BHD']))


# 3-D rendering of the bone masks with ipyvolume.
# Only the two "ToTensor8-9" output views further down are active. The
# triple-quoted strings are disabled variants for other ToTensor layouts and
# for ground truth; they are never executed.
#fig = ipv.figure()
'''control = pythreejs.OrbitControls(controlling=fig.camera)
fig.controls = control                # assigning to fig.controls will overwrite the builtin controls
control.autoRotate = True
control.autoRotateSpeed = 10.0
fig.render_continuous = True          # the controls does not update itself, ut if we toggle this setting, ipyvolume will update the controls
'''

### Ground Truth : Bone Class ###
'''ipv.figure()   # for ToTensor7 view1
ipv.volshow((target==1).transpose(2,1).flip(0).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.style.use(['minimal','light'])
ipv.show()
time.sleep(2)

ipv.figure()   # for ToTensor7 view2
ipv.volshow((target==1).float().detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
ipv.view(270,90)
ipv.style.use(['minimal','light'])   
ipv.show()
time.sleep(2)'''


'''ipv.figure()   # for ToTensor8-9 view1
ipv.volshow((target==1).float().flip(0).flip(1).float().detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
#ipv.view(270,90)
ipv.style.use(['minimal','light'])   
ipv.show()
time.sleep(2)
ipv.figure()

ipv.figure()   # for ToTensor8-9 view2
ipv.volshow((target==1).transpose(2,0).flip(1).float().detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
#ipv.view(270,90)
ipv.style.use(['minimal','light'])   
ipv.show()
time.sleep(2)'''

### Output : Bone Class ###
'''ipv.figure()   # for ToTensor7
ipv.volshow((output[0,1]>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
#ipv.volshow(((output[0].argmax(dim=0)==1).float().transpose(2,0).flip(1)).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
#ipv.view(270,90)
ipv.style.use(['minimal','light'])
ipv.show()
time.sleep(2)'''


# Active view 1: predicted bone mask (channel 1 >= 0.5), flipped on axes 0 and 1.
ipv.figure()     # for ToTensor8-9   view1
ipv.volshow((output[0,1].flip(0).flip(1)>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
#ipv.volshow(((output[0].argmax(dim=0)==1).float().flip(0).flip(1)).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
#ipv.volshow((output[0,1].transpose(2,1).flip(0)>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.style.use(['minimal','light']) 
ipv.show()
# Short pause between widgets - presumably so each renders before the next appears.
time.sleep(2)
# Active view 2: same mask, axes 2 and 0 swapped and then flipped on axis 1.
ipv.figure()   # for ToTensor8-9 view2
ipv.volshow((output[0,1].transpose(2,0).flip(1).float()>=0.5).detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
#ipv.view(270,90)
ipv.style.use(['minimal','light'])   
ipv.show()
print(' --- END --- ')

# Test-sample ids noted per fracture category (meaning of the trailing ' not stated here):
# Intact        id = {0 5 8 14 16}
# Nondisplaced  id = {0 1 2 3 4 5'}
# Displaced     id = {0 3 4' 5 8 9}

In [None]:
# Shape of the network output tensor (displayed as the cell result).
output.size()

In [None]:
# Dump the ground-truth label volume and the raw network output as .npy files
# for offline analysis.
# os.path.join replaces 'dataset\gt_...', whose '\g' is an invalid escape
# sequence (a SyntaxWarning on newer Pythons). It yields the same path on Windows
# and a real subdirectory on other systems.
os.makedirs('dataset', exist_ok=True)   # np.save does not create missing directories
gt = target.detach().cpu().numpy()
ot = output.detach().cpu().numpy()
np.save(os.path.join('dataset', 'gt_displaced8.npy'), gt)
np.save(os.path.join('dataset', 'ot_displaced8.npy'), ot)
print(' --- END --- ')

In [None]:
# Predicted bone mask rendered in the two "ToTensor7" orientations (active),
# followed by a disabled ToTensor8-9 variant.
# NOTE(review): the disabled auto-rotate snippet assigns `fig.controls = control2`
# although control2 is built for fig2; `fig2.controls` was probably intended if re-enabled.
'''fig2 = ipv.figure()
control2 = pythreejs.OrbitControls(controlling=fig2.camera)
fig.controls = control2
control2.autoRotate = True
control2.autoRotateSpeed = 10.0
fig2.render_continuous = True''' 


# View 1: axes 2 and 1 swapped, then flipped on axis 0.
ipv.figure()   # for ToTensor7 view1
ipv.volshow((output[0,1].transpose(2,1).flip(0)>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.style.use(['minimal','light'])
ipv.show()
time.sleep(2)

# View 2: untransformed volume seen from camera angles (270, 90).
ipv.figure()   # for ToTensor7 view2
ipv.volshow((output[0,1]>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.view(270,90)
ipv.style.use(['minimal','light'])
ipv.show()


'''ipv.figure()    # for ToTensor8-9 view1
ipv.volshow((output[0,1].flip(0).flip(1)>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.style.use(['minimal','light']) 
ipv.show()

ipv.figure()    # for ToTensor8-9 view2
ipv.volshow((output[0,1].transpose(2,0).flip(1)>=0.5).cpu().detach().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1], max_opacity=3)
ipv.style.use(['minimal','light']) 
ipv.show()
time.sleep(2)'''
print(' --- END --- ')

In [None]:
# Fracture (label 2 / output channel 2): IoU, surface-distance metrics and 3-D views.
print('### Fracture ###')
# A single 0.5-thresholded float mask serves both IoU and the surface distances,
# as in the bone cell and the evaluation loops.
frac_pred = (output[0,2]>=0.5).float()
frac_true = (target==2).float()
print('IOU = {:.4f} '.format(iou(frac_pred, frac_true.to(device=device))))  # ToTensor6
t1 = time.time()
# res=0.5 matches the bone cell and both evaluation loops. The helper's default
# `res` is not visible here, so without it these "mm" values may not be
# comparable to the bone ones.
dist_metrics = surface_distance_measurement(frac_pred.cpu().numpy(), frac_true.cpu().numpy(), res=0.5, verbose=False)
print('ASSD = {:.3f} mm'.format(dist_metrics['ASSD']))
print('BHD = {:.3f} mm'.format(dist_metrics['BHD']))
t2 = time.time()
print('      Times calculating surface distance metrices = {:.3f} sec.'.format(t2-t1))

# Ground-truth fracture mask (axes 2 and 0 swapped, then flipped on axis 1).
ipv.figure()
ipv.volshow((target==2).transpose(2,0).flip(1).float().detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
ipv.style.use(['minimal','light'])
ipv.show()
time.sleep(2)

# Predicted fracture mask in the same orientation.
ipv.figure()
ipv.volshow((output[0,2].transpose(2,0).flip(1)>=0.5).detach().cpu().numpy().squeeze(), level=[0.1, 0.5, 0.67], opacity=[0.01, 0.1, 0.1],)
ipv.style.use(['minimal','light'])
ipv.show()
time.sleep(2)

In [None]:
# Simplify Hausdorff-distance
# Extract surface vertices of the predicted and ground-truth fracture masks with
# marching cubes. Then compute both directed Hausdorff distances between the
# vertex sets and plot the point pairs that realise them.
# NOTE(review): marching_cubes presumably fails if either mask has no fracture
# voxels (constant volume) - guard before running on intact samples.
t1 = time.time()
output_vert,_,_,_ = measure.marching_cubes((output[0,2]>=0.5).detach().cpu().numpy().squeeze())
target_vert,_,_,_ = measure.marching_cubes((target==2).detach().cpu().numpy().squeeze())
t2 = time.time()
print('output_vert = {}   |   target_vert = {}'.format(output_vert.shape, target_vert.shape))
print('   Time marching : {:,.5f} sec.'.format(t2-t1))
# directed_hausdorff(u, v) returns (distance, index into u, index into v) per SciPy docs:
# h1 goes output -> target, h2 goes target -> output.
h1 = directed_hausdorff(output_vert,target_vert)
h2 = directed_hausdorff(target_vert,output_vert)
# Two-point segments joining the vertex pairs that realise each directed distance.
line1 = np.vstack((output_vert[h1[1],:] , target_vert[h1[2],:]))
line2 = np.vstack((target_vert[h2[1],:] , output_vert[h2[2],:]))
print('h1 = {:,.4f}   |   h2 = {:,.4f}'.format(h1[0], h2[0]))
t3 = time.time()
print('   Time hausdorff distance : {:,.5f} sec'.format(t3-t2))

# Side-by-side point clouds (every 100th vertex) with the Hausdorff endpoints marked.
fig = plt.figure(figsize=(16,16))
ax = fig.add_subplot(221, projection='3d')
ax.scatter(output_vert[0::100,0],output_vert[0::100,1],output_vert[0::100,2],c='r',marker='x')
ax.plot(output_vert[h1[1],0],output_vert[h1[1],1],output_vert[h1[1],2],c='r',marker='x')
ax.plot(output_vert[h2[2],0],output_vert[h2[2],1],output_vert[h2[2],2],c='r',marker='x')
ax.set_title('Output fracture')
ax = fig.add_subplot(222, projection='3d')
ax.scatter(target_vert[0::100,0],target_vert[0::100,1],target_vert[0::100,2],c='g',marker='.')
ax.plot(target_vert[h1[2],0],target_vert[h1[2],1],target_vert[h1[2],2],c='g',marker='.')
ax.plot(target_vert[h2[1],0],target_vert[h2[1],1],target_vert[h2[1],2],c='g',marker='.')
ax.set_title('Ground truth fracture')
plt.show()

# Overlay of both clouds (every 150th vertex) with the two Hausdorff segments.
# NOTE: range(0, 1) runs exactly once; `angle` is unused (left over from a
# commented-out rotating-view experiment).
fig = plt.figure(figsize=(16,16))
for angle in range(0, 1):
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(output_vert[0::150,0],output_vert[0::150,1],output_vert[0::150,2],c='r', marker='x',label='output')
    ax.scatter(target_vert[0::150,0],target_vert[0::150,1],target_vert[0::150,2],c='g', marker='.',label='target')
    ax.plot(line1[:,0],line1[:,1],line1[:,2],'black',label='h1')
    ax.plot(line2[:,0],line2[:,1],line2[:,2],'black',label='h2')
    ax.plot(output_vert[h1[1],0],output_vert[h1[1],1],output_vert[h1[1],2],c='r',marker='x')
    ax.plot(target_vert[h1[2],0],target_vert[h1[2],1],target_vert[h1[2],2],c='g',marker='.')
    ax.plot(target_vert[h2[1],0],target_vert[h2[1],1],target_vert[h2[1],2],c='g',marker='.')
    ax.plot(output_vert[h2[2],0],output_vert[h2[2],1],output_vert[h2[2],2],c='r',marker='x')
    ax.legend()
    ax.set_title('Comparing Output vs Ground truth fracture')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    #ax.view_init(20, 30)
    #plt.draw()
    #plt.pause(.001)
    plt.show()

In [None]:
# Voxel-wise confusion matrix for one class of the current sample, plus 3-D
# views of its false-positive and false-negative voxels.
# 0=background  1=boneMask  2=fracMask
vidx = 1

shp = output.size()
# Voxels per volume. Used for accuracy and printed next to the FP/FN counts,
# replacing the hard-coded "16.7M" label (presumably 256^3).
n_voxels = shp[-3]*shp[-2]*shp[-1]
# ToTensor6-9: output channel `vidx` thresholded at 0.5 vs. label map == vidx.
pred_mask = (output[0,vidx]>=0.5).float()
true_mask = (target==vidx).float().to(device=device)
tp = TP(pred_mask, true_mask)
tn = TN(pred_mask, true_mask)
fp = FP(pred_mask, true_mask)
fn = FN(pred_mask, true_mask)
tp_sum = torch.sum(tp.view(-1))
tn_sum = torch.sum(tn.view(-1))
fp_sum = torch.sum(fp.view(-1))
fn_sum = torch.sum(fn.view(-1))

# Rates derived from the confusion matrix
precision = tp_sum/(tp_sum+fp_sum)
recall = tp_sum/(tp_sum+fn_sum)    # = true positive rate (TPR) = sensitivity
accuracy = (tp_sum+tn_sum)/n_voxels
Specificity = tn_sum/(tn_sum+fp_sum)
fpr = fp_sum/(tn_sum+fp_sum)       # = 1 - specificity
F1 = 2*(precision*recall)/(precision+recall)
print('True positive = {:,.1f}'.format(tp_sum))
print('False positive = {:,.1f}'.format(fp_sum))
print('False negative = {:,.1f}'.format(fn_sum))
print('True negative = {:,.1f}\n'.format(tn_sum))
print('Precision = {:,.4f}'.format(precision))
print('Recall = {:,.4f}'.format(recall))
print('F1-score = {:,.4f}\n'.format(F1))
#print('Accuracy = {:,.4f}'.format(accuracy))
print('True positive rate = {:,.4f}'.format(recall))
print('False positive rate = {:,.4f}\n'.format(fpr))


print('Number of <False positive> voxel = {:,.1f}/{:,}'.format(fp.view(-1).sum(), n_voxels))
# Host copy in the display orientation also used by the ToTensor8-9 view2 renders.
fp = fp.transpose(2,0).flip(1).detach().cpu().numpy().squeeze()
ipv.figure()
ipv.volshow(fp)
ipv.style.use(['minimal','light'])
#ipv.view(270,90)
ipv.show()
time.sleep(2)
print('\n\n')

print('Number of <False negative> voxel = {:,.1f}/{:,}'.format(fn.view(-1).sum(), n_voxels))
fn = fn.transpose(2,0).flip(1).detach().cpu().numpy().squeeze()
ipv.figure()
ipv.volshow(fn)
ipv.style.use(['minimal','light'])
#ipv.view(270,90)
ipv.show()

In [None]:
# NOTE(review): this opens an empty figure that nothing draws into;
# sliceviewer2 creates its own figure on every call.
plt.figure()
def sliceviewer2(voxel1, voxel2, voxel3, voxel4, x):
    """Show slice ``x`` (taken along the last axis) of four volumes on a 2x2 grid.

    Top row: ground truth and output ('bone' colormap).
    Bottom row: false positives and false negatives ('seismic' colormap).

    Returns:
        ``x`` unchanged.
    """
    layout = (
        (voxel1, 'bone', 'Ground truth'),
        (voxel2, 'bone', 'Output'),
        (voxel3, 'seismic', 'False positive'),
        (voxel4, 'seismic', 'False negative'),
    )
    fig, grid = plt.subplots(2, 2, figsize=(14, 14))
    # grid.flat is row-major: [0,0], [0,1], [1,0], [1,1]
    for panel, (volume, cmap_name, caption) in zip(grid.flat, layout):
        panel.imshow(volume[:, :, x], cmap=cmap_name)
        panel.set_title(caption)
    return x

# Interactive slice-by-slice comparison of ground-truth and predicted bone masks.
# All four volumes are built here in the same (untransposed) orientation, on the
# CPU. The `fp`/`fn` arrays left by the confusion-matrix cell are transposed and
# flipped copies for 3-D display, so their slices did not line up with voxel1/voxel2.
voxel1 = (target==1).cpu()
voxel2 = (output[0,1]>=0.5).cpu()
voxel3 = voxel2 & ~voxel1   # false positive: predicted bone where ground truth has none
voxel4 = voxel1 & ~voxel2   # false negative: ground-truth bone the prediction missed
print(voxel1.size(), type(voxel1), voxel1.dtype)
print(voxel2.size(), type(voxel2), voxel2.dtype)
widgets.interact( sliceviewer2,
                 voxel1 = widgets.fixed(voxel1),
                 voxel2 = widgets.fixed(voxel2),
                 voxel3 = widgets.fixed(voxel3),
                 voxel4 = widgets.fixed(voxel4),
                 x=(0,voxel1.shape[2]-1)
                )


## Feature visualization

In [None]:
# Shapes (and dtypes) of a few captured intermediate feature maps.
print(encode_feat1[0].size())
print(decode_feat1[6].size())
print(fusion_feat[6].size(), fusion_feat[6].dtype)
# dtype now belongs to the same tensor whose size is printed (it printed fusion_feat[6]'s).
print(fusionUp_feat[5].size(), fusionUp_feat[5].dtype)
print()

In [None]:
# NOTE(review): nothing is drawn in this cell, so plt.colorbar() presumably has no
# mappable to attach to and raises; looks like scratch - confirm before relying on it.
plt.colorbar()

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def enc_dec_feature_visualize(enc1, enc2, dec1, dec2, x1, x2, x3, x4):
    """Plot one channel of four feature-map stacks on a 2x2 grid, each with its own colorbar.

    ``enc1``/``enc2``/``dec1``/``dec2`` are indexed as ``[0, channel, :, :]``.
    ``x1``..``x4`` select the channel shown for each of them respectively.

    Returns:
        The tuple ``(x1, x2, x3, x4)`` unchanged.
    """
    fig, ax = plt.subplots(2, 2, figsize=(14, 14))
    fig.suptitle('Feature map visulization ')
    panels = [
        ((0, 0), enc1, x1, 'Encode_feat1'),
        ((0, 1), enc2, x2, 'Encode_feat2'),
        ((1, 0), dec1, x3, 'Decode_feat1'),
        ((1, 1), dec2, x4, 'Decode_feat2'),
    ]
    for position, feature, channel, caption in panels:
        axis = ax[position]
        image = axis.imshow(feature[0, channel, :, :], cmap='seismic')
        axis.set_title(caption)
        # Dedicated narrow axis on the right so the colorbar matches the image height.
        color_axis = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=color_axis)
    return x1, x2, x3, x4

def fusion_feature_visualize(fusion, fusionUp, x1, x2):
    """Volume-render channel ``x1`` of ``fusion`` and channel ``x2`` of ``fusionUp``.

    Both inputs are indexed as ``[0, channel, :, :, :]``. Each is shown in its own
    ipyvolume figure, fusion first.

    Returns:
        The tuple ``(x1, x2)`` unchanged.
    """
    for volume, channel in ((fusion, x1), (fusionUp, x2)):
        ipv.figure()
        ipv.volshow(volume[0, channel, :, :, :])
        ipv.style.use(['minimal', 'light'])
        #ipv.view(270,90)
        ipv.show()
    return x1, x2

In [None]:
# Choose a network level and browse encoder/decoder feature maps at that level.
# Decoder lists are indexed from the end, presumably so decoder level N pairs
# with encoder level N - confirm against the model definition.
print('Encode Level = 0 - {}'.format(len(encode_feat1)-1))
print('Decode Level = 0 - {}'.format(len(decode_feat1)-1))
print('Fusion Level Up = 0 - {}'.format(len(fusion_feat)-1))
print('FusionUp Level Up = 0 - {}'.format(len(fusionUp_feat)-1))
print('---------------------------------- \n')
level = int(input('Level select : '))

# Each decoder list is indexed by its own length. dec2's data previously used
# len(decode_feat1), while its slider range (x4) used len(decode_feat2).
dec_idx1 = len(decode_feat1)-1-level
dec_idx2 = len(decode_feat2)-1-level
widgets.interact(enc_dec_feature_visualize,
                 enc1 = widgets.fixed(encode_feat1[level].cpu().detach().numpy().astype(float)),
                 enc2 = widgets.fixed(encode_feat2[level].cpu().detach().numpy().astype(float)),
                 dec1 = widgets.fixed(decode_feat1[dec_idx1].cpu().detach().numpy().astype(float)),
                 dec2 = widgets.fixed(decode_feat2[dec_idx2].cpu().detach().numpy().astype(float)),
                 x1=(0,encode_feat1[level].shape[1]-1),
                 x2=(0,encode_feat2[level].shape[1]-1),
                 x3=(0,decode_feat1[dec_idx1].shape[1]-1),
                 x4=(0,decode_feat2[dec_idx2].shape[1]-1),
                )



In [None]:
# Browse fusion / fusion-up feature volumes at the `level` chosen in the previous
# cell (indexed from the end of each list). Sliders pick the channel.
widgets.interact(fusion_feature_visualize, 
                 fusion = widgets.fixed(fusion_feat[len(fusion_feat)-1-level].cpu().detach().numpy().astype(float)), 
                 fusionUp = widgets.fixed(fusionUp_feat[len(fusionUp_feat)-1-level].cpu().detach().numpy().astype(float)), 
                 x1=(0,fusion_feat[len(fusion_feat)-1-level].shape[1]-1),
                 x2=(0,fusionUp_feat[len(fusionUp_feat)-1-level].shape[1]-1),
                )

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# Makes the next cell restart the validation iterator from the first sample.
first = True

In [None]:
# Legacy check: pass the next validation sample's target volume through
# model1 -> model2 and compare the thresholded reconstruction with it.
# The iterator restarts when `first` is set (see the previous cell).
if first:
    print('First testing sample')
    femurIter = iter(valLoader)
    first = False
else:
    print('Next testing sample')

sample = next(femurIter)

target , ap , lat = sample['Target'] , sample['AP'] , sample['LAT']
target , ap , lat= target[0,1] , ap[0] , lat[0]
target = target.unsqueeze(0).unsqueeze(0).to(device=device, dtype=dtype1)
ap = ap.unsqueeze(0).to(device=device, dtype=dtype1)
lat = lat.unsqueeze(0).to(device=device, dtype=dtype1)

# Inference only: no autograd graph is needed.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    embeded = model1(target)
    output = model2(embeded)
    print('output = {}   {}'.format(output.size(), output.dtype))
    output = output[:,1].unsqueeze(1)
    print('output = {}   {}'.format(output.size(), output.dtype))
    # NOTE(review): threshold 0.25 differs from the 0.5 used elsewhere - confirm intended.
    output = (output>=0.25).float()
    # The value is iou(), so it is labelled IoU (it was labelled "Accuracy").
    print('***  IoU = {:.4f}  *** \n'.format( iou(output,target.to(device=device)) ))

    output2 =output.detach().cpu().numpy().squeeze()

sampleOutput = {'Target':target.cpu().numpy().squeeze() ,
                'AP':ap.cpu().numpy().squeeze() ,
                'LAT':lat.cpu().numpy().squeeze() ,
                'Output':output2}

show_sample(sampleOutput, view='all', showOutput=True, detail=False)

In [None]:
# Value range of the current `output` tensor.
print(output.max())
print(output.min())

In [None]:
# Voxel-wise confusion matrix for the legacy single-channel `output` / `target` pair.
N,C,D,H,W = output.size()
# Voxels per volume (presumably N = C = 1 after the slicing in the legacy cell).
# Used for accuracy and in place of the hard-coded "16.7M" label.
n_voxels = D*H*W
print(output.size(), target.size())
tp = TP(output,target.to(device=device))
tn = TN(output,target.to(device=device))
fp = FP(output,target.to(device=device))
fn = FN(output,target.to(device=device))
tp_sum = torch.sum(tp.view(-1))
tn_sum = torch.sum(tn.view(-1))
fp_sum = torch.sum(fp.view(-1))
fn_sum = torch.sum(fn.view(-1))

# Rates derived from the confusion matrix
precision = tp_sum/(tp_sum+fp_sum)
recall = tp_sum/(tp_sum+fn_sum)    # = true positive rate (TPR) = sensitivity
accuracy = (tp_sum+tn_sum)/n_voxels
Specificity = tn_sum/(tn_sum+fp_sum)
fpr = fp_sum/(tn_sum+fp_sum)       # = 1 - specificity
F1 = 2*(precision*recall)/(precision+recall)
print('True positive = {:,.1f}'.format(tp_sum))
print('False positive = {:,.1f}'.format(fp_sum))
print('False negative = {:,.1f}'.format(fn_sum))
print('True negative = {:,.1f}\n'.format(tn_sum))
print('Precision = {:,.4f}'.format(precision))
print('Recall(TPR or Sensivity) = {:,.4f}'.format(recall))
print('Accuracy = {:,.4f}'.format(accuracy))
print('True positive rate = {:,.4f}'.format(recall))
print('False positive rate = {:,.4f}'.format(fpr))
print('F1-score = {:,.4f}\n\n'.format(F1))

print('Number of <False positive> voxel = {:,.1f}/{:,}'.format(fp.view(-1).sum(), n_voxels))
fp = fp.detach().cpu().numpy().squeeze()
ipv.figure()
ipv.volshow(fp)
ipv.view(270,90)
ipv.show()
time.sleep(2)
print('\n\n')

print('Number of <False negative> voxel = {:,.1f}/{:,}'.format(fn.view(-1).sum(), n_voxels))
fn = fn.detach().cpu().numpy().squeeze()
ipv.figure()
ipv.volshow(fn)
ipv.view(270,90)
ipv.show()

# Evaluation

**Evaluate on the testing set**

## On the perfectly orthogonal dataset

In [None]:
# Release cached-but-unused GPU memory before building the evaluation loader.
torch.cuda.empty_cache()

In [None]:
# Evaluate the generator network `netG` (defined in an earlier cell).
model = netG

In [None]:
# Evaluation set for orthogonal views: sample list (.xlsx) read through
# FemurDataset2 with NormalizeSample2 + ToTensor8, one sample per batch, in file order.
# NOTE(review): root_dir comes from the dataset chosen at the top of the notebook,
# while csv_file is hard-coded to Chula Dataset2 - confirm they refer to the same data.
eva_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Validation_set.xlsx'
eva_transformedFemur = FemurDataset2(csv_file=eva_set, root_dir=root_dir, 
                                    transform=transforms.Compose([NormalizeSample2(),ToTensor8()]))
evaLoader = DataLoader(eva_transformedFemur, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Focal loss with per-class weights. The order is presumably background, bone,
# fracture (matching the 0/1/2 labels used elsewhere) - confirm.
weight = torch.tensor([0.15,0.25,0.6], device=device)     # For ToTensor8: Auxiliary Class
#weight = torch.tensor([0.5,0.5], device=device)           # For ToTensor9: Augmentation only 
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')

In [None]:
# Evaluate `model` on every sample of evaLoader and collect one row per sample:
# loss, bone overlap metrics, surface-distance metrics, fracture overlap metrics.
print('### Evaluation loop ###')
torch.cuda.empty_cache()
model.eval()
time1 = time.time()
column_name = ['loss','IoU1','TP1','FP1','FN1','TN1','ASD_gt','ASD_ot','ASSD','RMSD_gt','RMSD_ot','HD_gt','HD_ot','BHD','IoU2','TP2','FP2','FN2','TN2']   # Auxiliary Class
result = np.zeros((len(evaLoader),len(column_name)))
progressbar = tqdm(enumerate(evaLoader), total=len(evaLoader), desc="Process:   ")
for t, eva_sample in progressbar:
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            target = eva_sample['Target'].to(device=device, dtype=dtype1)
            view1 = eva_sample['view1'].to(device=device, dtype=dtype1)
            view2 = eva_sample['view2'].to(device=device, dtype=dtype1)
            output = model(view1, view2)

            loss = criterion(output,target.long())   # for ToTensor6-7 with FocalMultiClass
            # Plain Python float. A CUDA tensor inside the list below cannot be
            # converted when it is assigned into the NumPy `result` row.
            loss_value = loss.item()

            # From here on target/output are rebound to CPU binary masks.
            target_aux = (target[0,:]==2).float().detach().cpu()       # for Auxiliary (only) or Auxiliary + FracAugmentation
            output_aux = (output[0,2]>=0.5).float().detach().cpu()     # for Auxiliary (only) or Auxiliary + FracAugmentation
            target = (target[0,:]==1).float().detach().cpu()
            output = (output[0,1]>=0.5).float().detach().cpu()

            metrices = overlap_based_metrices(output, target, sum_of_matrix=True)                        # for ToTensor7-9
            metrices2 = overlap_based_metrices(output_aux, target_aux, sum_of_matrix=True)               # for ToTensor7-9
            bound_metrices = surface_distance_measurement(target.numpy(), output.numpy(), res=0.5, return_vert_dist=False, verbose=False)

            result[t,:] = [loss_value,
                           *[ i.item() for i in metrices.values()],    # 5 metrics
                           *bound_metrices.values(),                   # 8 metrics
                           *[ i.item() for i in metrices2.values()]    # 5 metrics
                          ]
            progressbar.set_description('Loss:{:,.0f}   IOU:{:.3f}   ASSD:{:,.2f}'.format(loss_value,metrices['IoU'],bound_metrices['ASSD']))

time2 = time.time()
eva_loss_acc = pd.DataFrame(data=result, columns=column_name)
print('   Duration Testing time = {} Min. '.format((time2-time1)/60))
print('--- END ---')

In [None]:
# Reference list (an unused string, never executed) of saved model names and the
# training variant each corresponds to; then show which model is currently loaded.
'''
Recon2X3D5_21101901		Both
Recon2X3D5_21102201		Aux only
Recon2X3D5_21102101		FracAug only
Recon2X3D5_21102102		Bare
Recon2X3D5GAN_22022201		FracReconNet-GAN
'''
print(saved_name)

In [None]:
# Display the per-sample evaluation table.
eva_loss_acc

In [None]:
# Save the evaluation table. os.path.join replaces 'trained\Recon...', whose
# '\R' is an invalid escape sequence; it gives the same path on Windows and a
# real subdirectory elsewhere.
eva_loss_acc.to_csv(path_or_buf=os.path.join('trained', 'Recon2X3D5GAN_22022201.csv'))

In [None]:
# Surface-distance values computed for the last evaluated sample.
list(bound_metrices.values())

In [None]:
# Auxiliary-class overlap metrics of the last sample as Python numbers, and how many there are.
x = [ i.item() for i in metrices2.values()]
print(x, len(x))

In [None]:
# Distributions of per-sample loss, bone IoU and ASSD over the evaluation set.
# Columns are looked up by name in `column_name`, so the plots still pick the
# right data when the unaligned loop (which prepends 'diff_angle') ran last.
plt.figure(figsize=(16,8))
plt.hist(result[:,column_name.index('loss')], bins=100)
plt.xlabel('Loss')
plt.ylabel('Samples')
plt.title('Loss')
plt.show()
fig, ax = plt.subplots(1,2,figsize=(16,8))
ax[0].hist(result[:,column_name.index('IoU1')], bins=100)    # IoU_bone
ax[0].set_title('Intersection-Over-Union of Bone')
ax[0].set_xlabel('IoU')
ax[0].set_ylabel('Samples')
ax[1].hist(result[:,column_name.index('ASSD')], bins=100)    # ASSD
ax[1].set_title('Average Symmetrical Surface Distances')
ax[1].set_xlabel('ASSD (mm)')
ax[1].set_ylabel('Samples')
plt.show()

In [None]:
# Scratch: check plt.hist with bins='auto' on 1000 random integers in [0, 10);
# unrelated to the evaluation results.
cc = np.random.randint(0, 10, size=(1000))
print(cc.shape)
#cc_hist = np.histogram(cc, bins=100)
plt.figure(figsize=(16,8))
plt.hist(cc, bins='auto')
plt.show()

In [None]:
# NOTE(review): after an evaluation loop, `target` and `output` hold the last
# sample's binarized *bone* masks (see the rebinding inside the loop). So
# `target==2` is all False and `output[2]` is a single slice, not the fracture
# channel. This cell looks stale; re-derive the fracture masks before using it.
ipv.figure()
ipv.volshow(target==2)
ipv.show()
time.sleep(3)
ipv.figure()
ipv.volshow(output[2]>=0.5)
ipv.show()

In [None]:
# Shapes of the current `output` / `target` tensors.
print(output.size())
print(target.size())

In [None]:
# Convert both tensors to squeezed NumPy arrays on the host (rebinding the names).
output = output.detach().cpu().numpy().squeeze()
target = target.detach().cpu().numpy().squeeze()
print(type(output), output.shape, output.dtype)

In [None]:
# Save both arrays to the current working directory.
np.save('output.npy', output)
np.save('target.npy', target)

In [None]:
# NOTE(review): indexes `output[2]` and compares `target==2`; after the evaluation
# loop both hold bone masks, so these presumably are not the fracture channel - confirm.
print(output[2].shape)
print((target==2).shape)

### Miscellaneous

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import cv2 as cv
import argparse

from torchvision import models, transforms

In [None]:
# Leftover from a command-line script. Inside a notebook, parse_args() sees the
# kernel's own arguments (typically "-f <connection file>") and, with a required
# --image, always exits. Naming the parser `ap` also overwrote the AP-view tensor
# used above. parse_known_args() ignores unrecognised arguments, and --image is
# optional here (None when absent).
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-i', '--image', required=False, default=None,
    help='path to image')
args = vars(arg_parser.parse_known_args()[0])

In [None]:
# load the model
#model = models.resnet50(pretrained=True)
print(model)
model_weights = [] # we will save the conv layer weights in this list
conv_layers = [] # we will save the 49 conv layers in this list
# get all the model children as list
model_children = list(model.children())
print(model_children)

In [None]:
# counter to keep count of the conv layers
counterConv2d = 0
counterConv3d = 0
counterConvTranspose3d = 0
# append all the conv layers and their respective weights to the list
for i,level in enumerate(model_children):
    print('Level:{}'.format(i))
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
        model_weights.append(level.weight)
        conv_layers.append(level)
    for j,layer in enumerate(level):
        print('   Layer:{}'.format(j))
        for k,module in enumerate(layer):
            print('      Module:{}'.format(k))
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
                model_weights.append(module.weight)
                conv_layers.append(module)
print('Total Conv2d = {}'.format(counterConv2d))
print('Total Conv3d = {}'.format(counterConv3d))
print('Total ConvTranspose3d = {}'.format(counterConvTranspose3d))

'''
for i in range(len(model_children)):
    if type(model_children[i]) == nn.Conv2d:
        counter += 1
        model_weights.append(model_children[i].weight)
        conv_layers.append(model_children[i])
    elif type(model_children[i]) == nn.Sequential:
        for j in range(len(model_children[i])):
            for child in model_children[i][j].children():
                if type(child) == nn.Conv2d:
                    counter += 1
                    model_weights.append(child.weight)
                    conv_layers.append(child)
print(f"Total convolutional layers: {counter}")
'''

In [None]:
# visualize the first conv layer filters
plt.figure(figsize=(20, 17))
for i, filter in enumerate(model_weights[0]):
    plt.subplot(8, 8, i+1) # (8, 8) because in conv0 we have 7x7 filters and total of 64 (see printed shapes)
    plt.imshow(filter[0, :, :].detach(), cmap='gray')
    plt.axis('off')
    plt.savefig('../outputs/filter.png')
plt.show()

In [None]:
# Passing the image through all the layers
results = [conv_layers[0](img)]
for i in range(1, len(conv_layers)):
    # pass the result from the last layer to the next layer
    results.append(conv_layers[i](results[-1]))
# make a copy of the `results`
outputs = results

## On non-orthogonal (unaligned) data

In [None]:
# Release cached-but-unused GPU memory before building the unaligned evaluation loader.
torch.cuda.empty_cache()

In [None]:
# Unaligned (non-orthogonal) test set, read with NormalizeSample3 + ToTensor8Plus.
# The samples carry DRR file names ('drr1_name'/'drr2_name'), used below to
# recover the view angles.
# NOTE(review): root_dir still comes from the dataset chosen at the top of the notebook - confirm it matches this csv.
#eva_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2\Validation_set.xlsx'
eva_set = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\TestingSet_Scale12.xlsx'
eva_transformedFemur = FemurDataset2(csv_file=eva_set, root_dir=root_dir, 
                                     transform=transforms.Compose([NormalizeSample3(),ToTensor8Plus()]))
evaLoader = DataLoader(eva_transformedFemur, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Same weighted focal loss as for the orthogonal evaluation (class order
# presumably background, bone, fracture - confirm).
weight = torch.tensor([0.15,0.25,0.6], device=device)     # For ToTensor8: Auxiliary Class
#weight = torch.tensor([0.5,0.5], device=device)           # For ToTensor9: Augmentation only 
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')

def findOccurrencesAngle(strr, ch):
    """Parse the number between the last two ``ch`` separators of a DRR file name.

    Args:
        strr: a file name, or a sequence of names (e.g. a DataLoader batch of
            size 1) whose first element is used.
        ch: single-character separator, e.g. ``'-'``.

    Returns:
        The text between the second-to-last and the last ``ch``, as a float.

    Raises:
        ValueError: if the name contains fewer than two ``ch`` separators, or
            the text between them is not a number.
    """
    # Accept a bare string too; previously strr[0] would have taken its first character.
    name = strr if isinstance(strr, str) else strr[0]
    # rsplit with maxsplit=2 leaves [head, angle_text, tail] - the same slice
    # the old index scan took between the last two separators.
    parts = name.rsplit(ch, 2)
    if len(parts) < 3:
        raise ValueError('expected at least two {!r} separators in {!r}'.format(ch, name))
    return float(parts[1])

In [None]:
# Per-angle result containers. They are not used in the cells shown here (the
# per-angle split below filters the DataFrame instead).
set00 = dict()
set25 = dict()
set50 = dict()
set75 = dict()
set100 = dict()

In [None]:
# Bare references: they raise NameError if the metric helpers were not defined
# by an earlier cell, before the long evaluation loop starts.
overlap_based_metrices
overlap_based_metrices
surface_distance_measurement
print('Checked Evaluation metrics: Done')

In [None]:
# Evaluate `model` on the unaligned test set. Each row also records diff_angle,
# the absolute difference between the two view angles parsed from the DRR file names.
print('### Evaluation loop of Unaligned Dataset ###')
torch.cuda.empty_cache()
model.eval()

time1 = time.time()
column_name = ['diff_angle','loss','IoU1','TP1','FP1','FN1','TN1','ASD_gt','ASD_ot','ASSD','RMSD_gt','RMSD_ot','HD_gt','HD_ot','BHD','IoU2','TP2','FP2','FN2','TN2']  # Auxiliary Class
#column_name = ['loss','IoU1','TP1','FP1','FN1','TN1','ASD_gt','ASD_ot','ASSD','RMSD_gt','RMSD_ot','HD_gt','HD_ot','BHD']                                  # Without Auxiliary Class
result = np.zeros((len(evaLoader),len(column_name)))
progressbar = tqdm(enumerate(evaLoader), total=len(evaLoader), desc="Process:   ")
for  t, eva_sample in progressbar:
#for t, val_sample in tqdm(enumerate(valLoader)):
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            target = eva_sample['Target'].to(device=device, dtype=dtype1)
            view1 = eva_sample['view1'].to(device=device, dtype=dtype1)
            view2 = eva_sample['view2'].to(device=device, dtype=dtype1)
            # View angles are taken from between the last two '-' of each DRR file name.
            drr1_name = eva_sample['drr1_name']
            drr2_name = eva_sample['drr2_name']
            ang1 = findOccurrencesAngle(drr1_name,'-')
            ang2 = findOccurrencesAngle(drr2_name,'-')
            #print('drr1_name = {} {}'.format(drr1_name,findOccurrencesAngle(drr1_name,'-')))
            #print('drr2_name = {} {}'.format(drr2_name,findOccurrencesAngle(drr2_name,'-')))
            diff_angle = abs(ang1 - ang2)
            output = model(view1, view2)
            loss = criterion(output,target.long())   # for ToTensor6-7 with FocalMultiClass
            #loss_eva.append(loss.item())
            
            # From here on target/output are rebound to CPU binary masks
            # (aux = fracture label 2 / channel 2, main = bone label 1 / channel 1).
            target_aux = (target[0,:]==2).float().detach().cpu()       # for Auxiliary (only) or Auxiliary + FracAugmentation
            output_aux = (output[0,2]>=0.5).float().detach().cpu()     # for Auxiliary (only) or Auxiliary + FracAugmentation
            target = (target[0,:]==1).float().detach().cpu()
            output = (output[0,1]>=0.5).float().detach().cpu()

            metrices = overlap_based_metrices(output, target, sum_of_matrix=True)                        # for ToTensor7-9
            metrices2 = overlap_based_metrices(output_aux, target_aux, sum_of_matrix=True)               # for ToTensor7-9
            bound_metrices = surface_distance_measurement(target.numpy(), output.numpy(), res=0.5, return_vert_dist=False, verbose=False)

            # Row layout must follow column_name; the metric dicts are presumably
            # ordered like the column names - confirm.
            result[t,:] = [diff_angle, loss.cpu().numpy(),
                           *[ i.item() for i in metrices.values()],   # 5 metrics
                           *bound_metrices.values(),                  # 8 metrics
                           *[ i.item() for i in metrices2.values()]   # 5 metrics
                          ]
            progressbar.set_description('Loss:{:,.0f}   IOU:{:.3f}   ASSD:{:,.2f}'.format(loss,metrices['IoU'],bound_metrices['ASSD']))

time2 = time.time()
eva_loss_acc = pd.DataFrame(data=result, columns=column_name)
print('   Duration Testing time = {} Min. '.format((time2-time1)/60))
print('--- END ---')

In [None]:
# First 15 rows of the evaluation table.
eva_loss_acc.head(15)

In [None]:
# Summary statistics per metric column.
eva_loss_acc.describe()

In [None]:
# Column dtypes and non-null counts.
eva_loss_acc.info(verbose=True)

In [None]:
# Save Evaluation sample
# The CSV is written only after two interactive confirmations.
# os.path.join replaces 'trained\Recon...', whose '\R' is an invalid escape
# sequence; it gives the same path on Windows and a real subdirectory elsewhere.
eval_savename = os.path.join('trained', 'Recon2X3D6_22100401_unalign12.csv')

eval_save_logic = int(input('Save the evaluation [1] Yes [0] No ?'))
if eval_save_logic == 1:
    verify_eval_savename = int(input('Evaluation savename = {} :\n   [1] Yes [0] No = '.format(eval_savename)))
    if verify_eval_savename == 1:
        eva_loss_acc.to_csv(path_or_buf=eval_savename)
        print('--- Save is done ---')

In [None]:
# Short alias for the evaluation table; show its (rows, columns) and first rows.
df = eva_loss_acc
print(df.shape)
df.head()

In [None]:
# Samples whose two views differ by 0 degrees. This is an exact float match, so it
# relies on the parsed angles being exactly representable (e.g. multiples of 0.5) - confirm.
df00 = df[df['diff_angle']==0.0]
df00

In [None]:
# Samples whose two views differ by 2.5 degrees (exact float match).
df25 = df[df['diff_angle']==2.5]
df25

In [None]:
# Samples whose two views differ by 5 degrees (exact float match).
df50 = df[df['diff_angle']==5]
df50

In [None]:
# Samples whose two views differ by 7.5 degrees (exact float match).
df75 = df[df['diff_angle']==7.5]
df75

In [None]:
# Samples whose two views differ by 10 degrees (exact float match).
df100 = df[df['diff_angle']==10]
df100

## Test with real-world X-rays

In [None]:
# Load a pair of real X-ray images (stored as .npy), scale them, and display them inverted.
path = r'D:\FEW PhD\Datasets\Siriraj Dataset\Cleaning Data\3D Intensity Volume (RAW)\Xray'   

xray1 = np.load(path+r'\Xray1-52827012-Scale-0.7.npy')   # Intact
xray2 = np.load(path+r'\Xray2-52827012-Scale-0.7.npy')   # Intact
# The alternatives below use non-raw strings ('\X...'); switch them to r'' like
# the active lines if they are re-enabled.
#xray1 = np.load(path+'\Xray1-53568941-Scale-0.6.npy')   # Nondisplaced
#xray2 = np.load(path+'\Xray2-53568941-Scale-0.6.npy')   # Nondisplaced
#xray1 = np.load(path+'\Xray1--Scale-0.6.npy')   # Displaced
#xray2 = np.load(path+'\Xray2--Scale-0.6.npy')   # Displaced
print('xray1 = {}  {}  {}\nvalue = {} - {} '.format(xray1.shape, type(xray1), xray1.dtype,xray1.min(),xray1.max()))
print('xray2 = {}  {}  {}\nvalue = {} - {} '.format(xray2.shape, type(xray2), xray2.dtype,xray2.min(),xray2.max()))
# NOTE(review): both images are divided by xray1's maximum, so xray2 stays within
# [0, 1] only if its max does not exceed xray1's. A shared scale may be intended - confirm.
maxx = xray1.max()
xray1 = xray1/maxx
xray2 = xray2/maxx
print('xray1 = {}  {}  {}\nvalue = {} - {} '.format(xray1.shape, type(xray1), xray1.dtype,xray1.min(),xray1.max()))
print('xray2 = {}  {}  {}\nvalue = {} - {} '.format(xray2.shape, type(xray2), xray2.dtype,xray2.min(),xray2.max()))
# Displayed as 1 - image (intensity-inverted).
fig, ax = plt.subplots(1,2,figsize=(24,24))
ax[0].imshow(1-xray1, cmap='gray')
ax[0].set_title('Judet Xray1')
ax[1].imshow(1-xray2, cmap='gray')
ax[1].set_title('Judet Xray2')
plt.show()

In [None]:
# Convert both X-rays to float32 tensors with two leading axes added
# (dim 0 and dim 1) and move them to the GPU.
# Assumes xray1/xray2 are 2-D (H, W), giving (1, 1, H, W) — TODO confirm.
x1 = torch.tensor(xray1,dtype=torch.float).unsqueeze(dim=0).unsqueeze(dim=1).cuda()
x2 = torch.tensor(xray2,dtype=torch.float).unsqueeze(dim=0).unsqueeze(dim=1).cuda()
print('x1 = {} {} {} cuda{}'.format(x1.size(),type(x1),x1.dtype,x1.get_device()))
print('x2 = {} {} {} cuda{}'.format(x2.size(),type(x2),x2.dtype,x2.get_device()))

In [None]:
# Reconstruct from the real X-ray pair; the network receives inverted images (1 - x).
# torch.no_grad(): inference only, so skip building the autograd graph and
# save GPU memory. (predict is later used via .detach(), which still works.)
with torch.no_grad():
    predict = model(1-x1,1-x2)

In [None]:
# Two preset camera vectors; presumably k3d's 9-value layout
# [position(3), look-at target(3), up vector(3)] — verify against k3d docs.
# Only cam_view2 is applied below.
cam_view1 = [0,0,-1.5, 0,0,0 ,0,-1,0]   # view1
cam_view2 = [1.5,0,0, 0,0,0 ,1,-1,0]    # view2
plot = k3d.plot(name='Plot output')
# Volume-render predict[0, 2] (presumably batch 0, channel 2 — TODO confirm
# the channel meaning) with the matplotlib "Bone" colormap.
obj = k3d.volume(predict[0,2].detach().cpu().numpy(), name='predict', 
                 color_map=k3d.colormaps.matplotlib_color_maps.Bone,
                 gradient_step=0.005,
                 shadow='dynamic',
                 shadow_delay=10,
                )
plot += obj + k3d.text2d(text='Predict', color=0, size=1 ,position=(0.01,0.025), label_box=False)
plot.display()
plot.camera = cam_view2
#plot.camera = cam_view2

In [None]:
# Save the current k3d plot as a standalone HTML snapshot.
xray_filename = r'D:\FEW PhD\Program\3D_Reconstruction\femur3dnet\trained\K3DResult\Siriraj_Dataset\Volume\sample.html'
# Make sure the output folder exists before opening the file for writing.
os.makedirs(os.path.dirname(xray_filename), exist_ok=True)
# Explicit UTF-8: without it Windows uses the locale code page, which can
# raise UnicodeEncodeError on non-ASCII characters in the snapshot HTML.
with open(xray_filename, 'w', encoding='utf-8') as fp:
    fp.write(plot.get_snapshot())

In [None]:
# Load one .npy file from the Chula Dataset4_Unaligned folder and print its
# shape, type, dtype and value range.
# Raw string: '\F', '\D', '\C' are invalid escape sequences in a normal
# literal (DeprecationWarning, SyntaxWarning since Python 3.12). Same path value.
x = np.load(r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset4_Unaligned\DRR1-20140225CT0101-1-1-41-4.npy')
print('x = {} {} {}'.format(x.shape,type(x),x.dtype))
print('x values = {} - {}'.format(x.min(),x.max()))

# ETC.

## TensorBoard

In [None]:
# Load the TensorBoard notebook extension (enables the %tensorboard magic).
%load_ext tensorboard

In [None]:
# Take one batch from femurLoader and preview its 'AP' images as a grid.
femurBatched = next(iter(femurLoader))
print('femurBatched.size = {}'.format(femurBatched['AP'].size()))
img = femurBatched['AP']
#img = img.unsqueeze(dim=0)
# Device copy used for graph tracing below; .to() returns a new tensor.
img = img.to(device=device)                # img.to(device=torch.device('cpu'))
print('img dtype = {}'.format(img.type()))
print('img.size = {}'.format(img.size()))
# The grid is built from the un-moved batch (presumably on CPU), so .numpy() works.
imgs = utils.make_grid(femurBatched['AP'])
# (C, H, W) -> (H, W, C) for matplotlib.
imgs_show = imgs.numpy().transpose(1,2,0)
plt.figure(figsize=(18,18))
plt.imshow(imgs_show, cmap='gray')

In [None]:
# Log the image grid and the model graph to TensorBoard.
# `writer` is defined elsewhere — presumably a torch.utils.tensorboard SummaryWriter.
# NOTE(review): elsewhere the model is called with two inputs
# (model(1-x1, 1-x2)); tracing with a single `img` may fail — verify.
writer.add_image('AP',imgs)
writer.add_graph(model,img.to(device=device))

In [None]:
# Show the log directory used for this run (defined earlier in the notebook).
print(logRecordDir)

In [None]:
# Launch TensorBoard inline for the given run directory.
%tensorboard --logdir='runs/Recon3DUNet_201023'

## Register Hook

https://pytorch.org/docs/stable/generated/torch.nn.Module.html

In [None]:
class SaveFeatures():
    """Forward hook that keeps the most recent output of ``module``.

    The output is stored as-is (no copy). It therefore stays attached to the
    autograd graph, so a loss computed from ``self.features`` back-propagates
    to the network input. Activation maximisation relies on that. The previous
    ``torch.tensor(output, requires_grad=True)`` made a detached copy, so the
    gradients never reached the input.

    Parameters
    ----------
    module : torch.nn.Module
        Module whose forward output is recorded.
    """
    def __init__(self, module):
        self.features = None  # filled on every forward pass of `module`
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        # Keep the live activation tensor (same device, same graph).
        self.features = output
    def close(self):
        """Remove the hook from the module."""
        self.hook.remove()
        
class FilterVisualizer():
    """Activation-maximisation visualiser for one convolution filter.

    Starts from a small random RGB image. Each round optimises the image
    pixels with Adam to maximise the mean activation of channel ``filter`` in
    child ``layer`` of the notebook-global ``model``, then upscales the image.
    The final image is saved as ``layer_<layer>_filter_<filter>.jpg``.

    NOTE(review): ``visualize`` calls ``tfms_from_model``, ``vgg16`` and ``V``,
    none of which are defined or imported in the visible notebook (they look
    like the fastai v0.7 API). Unless they are provided elsewhere, this raises
    NameError. The image it builds has 3 channels, while the X-ray tensors
    above appear to be single-channel — verify the model input before use.
    """
    def __init__(self, size=56, upscaling_steps=12, upscaling_factor=1.2):
        # size: side length (px) of the initial random image
        # upscaling_steps: number of optimise-then-upscale rounds
        # upscaling_factor: side-length multiplier applied after each round
        self.size, self.upscaling_steps, self.upscaling_factor = size, upscaling_steps, upscaling_factor
        # Uses the notebook-global `model`, moved to GPU and set to eval mode.
        self.model = model.cuda().eval()
        #print(model.parameters)
        #self.model = vgg16(pretrained=True).cuda().eval()
        #set_trainable(self.model, False)

    def visualize(self, layer, filter, lr=0.1, opt_steps=20, blur=None):
        """Optimise an input image to maximise one filter's mean activation.

        layer: index into ``list(self.model.children())``; that child's output is hooked.
        filter: channel index of the hooked output whose mean is maximised.
        lr: Adam learning rate applied to the image pixels.
        opt_steps: Adam steps per upscaling round.
        blur: if not None, box-blur kernel size applied after each upscale.
        The final image is written to disk by ``save``.
        """
        sz = self.size
        img = np.uint8(np.random.uniform(150, 180, (sz, sz, 3)))/255  # generate random image
        activations = SaveFeatures(list(self.model.children())[layer])  # register hook

        for _ in range(self.upscaling_steps):  # scale the image up upscaling_steps times
            train_tfms, val_tfms = tfms_from_model(vgg16, sz)
            img_var = V(val_tfms(img)[None], requires_grad=True)  # convert image to Variable that requires grad
            optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
            for n in range(opt_steps):  # optimize pixel values for opt_steps times
                optimizer.zero_grad()
                self.model(img_var)  # the forward hook stores the layer output
                # Maximise the activation by minimising its negative.
                loss = -activations.features[0, filter].mean()
                loss.backward()
                optimizer.step()
            img = val_tfms.denorm(img_var.data.cpu().numpy()[0].transpose(1,2,0))
            self.output = img
            sz = int(self.upscaling_factor * sz)  # calculate new image size
            img = cv2.resize(img, (sz, sz), interpolation = cv2.INTER_CUBIC)  # scale image up
            if blur is not None: img = cv2.blur(img,(blur,blur))  # blur image to reduce high frequency patterns
        self.save(layer, filter)
        activations.close()
        
    def save(self, layer, filter):
        """Write ``self.output``, clipped to [0, 1], to ``layer_<layer>_filter_<filter>.jpg``."""
        plt.imsave("layer_"+str(layer)+"_filter_"+str(filter)+".jpg", np.clip(self.output, 0, 1))
        


In [None]:
# Visualise filter 265 of child layer 40 of `model` and save the result as a .jpg.
# NOTE(review): the global name `filter` shadows the Python builtin for the
# rest of the notebook session.
#vgg16 = torchvision.models.vgg16(pretrained=True)  # load and save in C:\Users\BDML\.cache\torch\hub\checkpoints\vgg16-397923af.pth
layer = 40
filter = 265
FV = FilterVisualizer(size=56, upscaling_steps=12, upscaling_factor=1.2)
print(type(FV))
FV.visualize(layer, filter, blur=5)

In [None]:
class Hook():
    """Record a module's inputs and outputs on the forward or backward pass.

    Parameters
    ----------
    module : torch.nn.Module
        Any module, e.g. ``nn.Linear`` or ``nn.Conv2d``.
    backward : bool, default False
        False: forward hook. ``self.input`` is the tuple of positional inputs
        and ``self.output`` is the module output.
        True: backward hook. ``self.input`` is ``grad_input`` and
        ``self.output`` is ``grad_output`` (tuples of tensors). These are only
        set after a ``.backward()`` call passes through the module.
    """
    def __init__(self, module, backward=False):    # module can be any torch.nn.Module such as nn.Linear, nn.Conv2d
        if not backward:
            self.hook = module.register_forward_hook(self.hook_fn)
        elif hasattr(module, 'register_full_backward_hook'):
            # register_backward_hook is deprecated: for modules made of several
            # operations it may report only the last operation's gradients.
            # The "full" variant (PyTorch >= 1.8) reports module-level gradients.
            self.hook = module.register_full_backward_hook(self.hook_fn)
        else:
            # Fallback for older PyTorch versions.
            self.hook = module.register_backward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output
    def close(self):
        """Remove the hook from the module."""
        self.hook.remove()

In [None]:
# Attach a forward and a backward Hook to each top-level child of `model`,
# stopping at the first MaxPool2d / Upsample layer. Run one batch through the
# model, back-propagate, then print what every hook recorded.
#hookF = [Hook(layer[1]) for layer in list(model._modules.items())]          # forward pass hooked
#hookB = [Hook(layer[1],backward=True) for layer in list(model._modules.items())]     # backward pass hooked

# Keep every hook in a list so that all hooked layers can be printed below.
hookF = []
hookB = []
for _, layer_module in model._modules.items():
    # Test the module itself (the dict value), not the (name, module) pair.
    if isinstance(layer_module, (nn.MaxPool2d, nn.Upsample)):
        break
    hookF.append(Hook(layer_module))
    hookB.append(Hook(layer_module, backward=True))

# run a data batch
out=model(ap_I).to(device=device)
# Backward hooks only fire during a backward pass. Back-propagate a scalar
# (the sum of the output) so that hookB records grad_input / grad_output.
out.sum().backward()


print('***'*3 + '  Forward Hooks Inputs & Outputs  '+'***'*3)
for hook in hookF:
    print('layer input = {}'.format(hook.input))
    print('layer output = {}'.format(hook.output))
    print('---'*17)
print('\n')

print('***'*3 + '  Backward Hooks Inputs & Outputs  '+'***'*3)
for hook in hookB:             
    print('layer input = {}'.format(hook.input))
    print('layer output = {}'.format(hook.output))       
    print('---'*17)

# Remove the hooks so re-running this cell does not stack duplicates on the
# model. The recorded input/output stay available on each Hook object.
for hook in hookF + hookB:
    hook.close()
    print('---'*17)

## Feature Visualization

https://towardsdatascience.com/visualizing-convolution-neural-networks-using-pytorch-3dfa8443e74e
https://towardsdatascience.com/visualizing-convolution-neural-networks-using-pytorch-3dfa8443e74e

In [None]:
# Function for plotting CNN parameters

def plot_filters_single_channel_big(t):
    """Draw all 2-D kernels of a 4-D weight tensor as one collated heatmap.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, in_channels, kH, kW). It may
        live on the GPU and may require grad.

    NOTE(review): uses ``sns`` (seaborn), which is not imported in the visible
    part of this notebook — confirm it is imported elsewhere.
    """
    # Size of the collated image: kernels tiled out_channels x in_channels.
    nrows = t.shape[0]*t.shape[2]
    ncols = t.shape[1]*t.shape[3]

    # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
    npimg = np.array(t.detach().cpu().numpy(), np.float32)
    # (out, in, kH, kW) -> (out, kH, in, kW) so the reshape tiles whole kernels.
    npimg = npimg.transpose((0, 2, 1, 3))
    npimg = npimg.ravel().reshape(nrows, ncols)

    npimg = npimg.T

    fig, ax = plt.subplots(figsize=(ncols/10, nrows/200))
    sns.heatmap(npimg, xticklabels=False, yticklabels=False, cmap='gray', ax=ax, cbar=False)

    
def plot_filters_single_channel(t):    # t = 4-Dimension tensor
    """Plot every (out_channel, in_channel) kernel of a 4-D tensor in a grayscale grid.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, in_channels, kH, kW). It may
        live on the GPU and may require grad.
    """
    # One subplot per 2-D kernel: out_channels * in_channels.
    nplots = t.shape[0]*t.shape[1]
    ncols = 12
    # Ceil division gives exactly enough rows (at least one).
    # 1 + nplots // ncols adds an empty row whenever nplots % ncols == 0.
    nrows = max(1, -(-nplots // ncols))

    count = 0
    fig = plt.figure(figsize=(ncols, nrows))
    # Loop over all kernels of all channels.
    for i in range(t.shape[0]):
        for j in range(t.shape[1]):
            count += 1
            ax1 = fig.add_subplot(nrows, ncols, count)
            # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
            npimg = np.array(t[i, j].detach().cpu().numpy(), np.float32)
            #standardize the numpy image
            #npimg = (npimg - np.mean(npimg)) / np.std(npimg)
            #npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
            ax1.imshow(npimg,cmap='gray')
            ax1.set_title(str(i) + ',' + str(j))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()


def plot_normlayer_single_channel(t):    # t = 3-Dimension tensor
    """Plot each 2-D slice ``t[i]`` of a 3-D tensor as a grayscale subplot.

    Parameters
    ----------
    t : torch.Tensor
        3-D tensor (N, H, W). It may live on the GPU and may require grad.
    """
    # One subplot per slice along dim 0.
    nplots = t.shape[0]
    ncols = 12
    # Ceil division gives exactly enough rows (at least one).
    # 1 + nplots // ncols adds an empty row whenever nplots % ncols == 0.
    nrows = max(1, -(-nplots // ncols))

    fig = plt.figure(figsize=(20,20))
    for i in range(nplots):
        ax1 = fig.add_subplot(nrows, ncols, i + 1)
        # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
        npimg = np.array(t[i,:,:].detach().cpu().numpy(), np.float32)
        #standardize the numpy image
        #npimg = (npimg - np.mean(npimg)) / np.std(npimg)
        #npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
        ax1.imshow(npimg,cmap='gray')
        ax1.set_title(str(i))
        ax1.axis('off')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()
    
def plot_filters_multi_channel(t):
    """Plot each 3-channel kernel of a 4-D weight tensor as an RGB tile.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, 3, kH, kW); ``plot_weights``
        only calls this when ``shape[1] == 3``. It may live on the GPU.

    The figure is also saved to ``myimage.png`` (dpi=100).
    """
    # Number of kernels to draw.
    num_kernels = t.shape[0]

    # 12 columns; ceil division gives just enough rows. One row per kernel
    # left most of the figure empty.
    num_cols = 12
    num_rows = max(1, -(-num_kernels // num_cols))

    fig = plt.figure(figsize=(num_cols,num_rows))

    for i in range(num_kernels):
        ax1 = fig.add_subplot(num_rows,num_cols,i+1)
        # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
        npimg = np.array(t[i].detach().cpu().numpy(), np.float32)
        # imshow needs channels last: (3, kH, kW) -> (kH, kW, 3).
        npimg = npimg.transpose((1, 2, 0))
        # imshow clips float RGB data outside [0, 1]; min-max rescale instead
        # (a constant kernel is shown as all zeros).
        lo, hi = npimg.min(), npimg.max()
        npimg = (npimg - lo) / (hi - lo) if hi > lo else np.zeros_like(npimg)
        ax1.imshow(npimg)
        ax1.axis('off')
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.tight_layout()
    # Save after tight_layout so the file matches the displayed layout.
    plt.savefig('myimage.png', dpi=100)
    plt.show()

    
def plot_weights(model, layer_num, single_channel = True, collated = False):
    """Plot the kernels of ``model.features[layer_num]`` if it is an ``nn.Conv2d``.

    single_channel=True draws every 2-D kernel, either collated into one
    heatmap (collated=True) or as a subplot grid. single_channel=False draws
    RGB tiles and is only supported when the layer has 3 input channels.
    Unsupported cases print a message and return None.
    """
    target = model.features[layer_num]
    if not isinstance(target, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    kernels = target.weight.data
    if not single_channel:
        if kernels.shape[1] != 3:
            print("Can only plot weights with three channels with single channel = False")
            return
        plot_filters_multi_channel(kernels)
    elif collated:
        plot_filters_single_channel_big(kernels)
    else:
        plot_filters_single_channel(kernels)

def plotFilterWeight(model):
    """Print (and plot) the first 21 named parameters of ``model``.

    For each parameter the name and size are printed. 4-D tensors (conv
    kernels) are drawn with ``plot_filters_single_channel``; 1-D tensors are
    printed as ``Bias = ...``. Other ranks are only listed. After 21
    parameters a '!!! Break !!!' marker is printed and iteration stops.
    """
    named = list(model.named_parameters())
    print('Number of Parameter = {} \n'.format(len(named)))
    for idx, (pname, pvalue) in enumerate(named):
        if idx > 20:
            print('!!! Break !!!')
            break

        param = pvalue.detach()
        print('Name: {}   |   Parameter.size = {}'.format(pname, param.size()))
        if param.ndim == 4:
            plot_filters_single_channel(param)
        elif param.ndim == 1:
            print('Bias = {}'.format(param))
        # 3-D parameters are skipped (plot_normlayer_single_channel disabled).
        
        
        
        
        
        
# List the first parameters of `model` and draw any 4-D (conv) kernels.
#plot_weights(model, 1, single_channel = False)
plotFilterWeight(model)

### Useful functions

In [None]:
# For every named parameter, report which kind of layer owns it and count
# the parameters per layer type.
count1 , count2 , count3 , count4 = 0 , 0 , 0 , 0
# named_parameters() yields dotted *names* (e.g. 'enc.0.weight'), never module
# objects, so each name is first mapped to its owning module. The prefix ''
# maps to the model itself for top-level parameters.
module_by_name = dict(model.named_modules())
for n , w in model.named_parameters():
    owner = module_by_name.get(n.rpartition('.')[0])
    if isinstance(owner , nn.Conv2d):
        count1 += 1
        print('Conv2d = {} :'.format(n) , end=' :  ')
    elif isinstance(owner , nn.LayerNorm):
        count2 += 1
        print('LayerNorm = {} '.format(n) , end=' :  ')
    elif isinstance(owner , nn.MaxPool2d):
        # Pooling layers have no parameters; kept for parity with the next cell.
        count3 += 1
        print('MaxPool2d = {}'.format(n) , end=' :  ')
    elif isinstance(owner , nn.Upsample):
        # Upsample has no parameters either.
        count4 += 1
        print('Upsample = {}'.format(n) , end=' :  ')
    print('{}'.format(n))

In [None]:
# Print every direct child module of `model`. Conv2d, LayerNorm, MaxPool2d and
# Upsample children are prefixed with a running, zero-padded index per type.
count1 , count2 , count3 , count4 = 0 , 0 , 0 , 0
for n in model.children():
    if isinstance(n , nn.Conv2d):
        count1 += 1
        print('Conv2d-({})   '.format(str(count1).zfill(2)) , end=' :  ')
    elif isinstance(n , nn.LayerNorm):
        count2 += 1
        print('LayerNorm-({})'.format(str(count2).zfill(2)) , end=' :  ')
    elif isinstance(n , nn.MaxPool2d):
        count3 += 1
        print('MaxPool2d-({})'.format(str(count3).zfill(2)) , end=' :  ')
    elif isinstance(n , nn.Upsample):
        count4 += 1
        print('Upsample-({}) '.format(str(count4).zfill(2)) , end=' :  ')
    # Print the module itself (its repr) after the optional prefix.
    print('{}'.format(n))

In [None]:
# Collect parameter names and tensors of `model` into two parallel lists,
# consumed one at a time by the next cell.
nameList = list()
paramList = list()
for name, param in model.named_parameters():
    #print('Layer = {}    Size = {}  '.format(name, param.size()))
    nameList.append(name)
    paramList.append(param)

In [None]:
# Take the next (name, parameter) pair off the front of both lists and show it.
# Re-run the cell to step through the parameters.
name = nameList.pop(0)
param = paramList.pop(0)
print('Name = {} :   Weight size = {}'.format(name,param.size()))
#plot_filters_single_channel(param.detach())

In [None]:
# Function for plotting CNN parameters

def plot_filters_single_channel_big(t):
    """Draw all 2-D kernels of a 4-D weight tensor as one collated heatmap.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, in_channels, kH, kW). It may
        live on the GPU and may require grad.

    NOTE(review): uses ``sns`` (seaborn), which is not imported in the visible
    part of this notebook — confirm it is imported elsewhere.
    """
    # Size of the collated image: kernels tiled out_channels x in_channels.
    nrows = t.shape[0]*t.shape[2]
    ncols = t.shape[1]*t.shape[3]

    # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
    npimg = np.array(t.detach().cpu().numpy(), np.float32)
    # (out, in, kH, kW) -> (out, kH, in, kW) so the reshape tiles whole kernels.
    npimg = npimg.transpose((0, 2, 1, 3))
    npimg = npimg.ravel().reshape(nrows, ncols)

    npimg = npimg.T

    fig, ax = plt.subplots(figsize=(ncols/10, nrows/200))
    sns.heatmap(npimg, xticklabels=False, yticklabels=False, cmap='gray', ax=ax, cbar=False)

    
def plot_filters_single_channel(t):
    """Plot every (out_channel, in_channel) kernel of a 4-D tensor, standardised.

    Each kernel is standardised to zero mean / unit std, shifted by 0.5 and
    clipped to [0, 1] before display with the default colormap.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, in_channels, kH, kW). It may
        live on the GPU and may require grad.
    """
    # One subplot per 2-D kernel: out_channels * in_channels.
    nplots = t.shape[0]*t.shape[1]
    ncols = 12
    # Ceil division gives exactly enough rows (at least one).
    # 1 + nplots // ncols adds an empty row whenever nplots % ncols == 0.
    nrows = max(1, -(-nplots // ncols))

    count = 0
    fig = plt.figure(figsize=(ncols, nrows))

    # Loop over all kernels of all channels.
    for i in range(t.shape[0]):
        for j in range(t.shape[1]):
            count += 1
            ax1 = fig.add_subplot(nrows, ncols, count)
            # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
            npimg = np.array(t[i, j].detach().cpu().numpy(), np.float32)
            # Standardise; a constant kernel (std == 0, e.g. 1x1) would
            # otherwise become all-NaN.
            std = np.std(npimg)
            npimg = (npimg - np.mean(npimg)) / (std if std > 0 else 1.0)
            npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
            ax1.imshow(npimg)
            ax1.set_title(str(i) + ',' + str(j))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()

    
def plot_filters_multi_channel(t):
    """Plot each 3-channel kernel of a 4-D weight tensor as a standardised RGB tile.

    Parameters
    ----------
    t : torch.Tensor
        Weight tensor of shape (out_channels, 3, kH, kW); ``plot_weights``
        only calls this when ``shape[1] == 3``. It may live on the GPU.

    The figure is also saved to ``myimage.png`` (dpi=100).
    """
    # Number of kernels to draw.
    num_kernels = t.shape[0]

    # 12 columns; ceil division gives just enough rows. One row per kernel
    # left most of the figure empty.
    num_cols = 12
    num_rows = max(1, -(-num_kernels // num_cols))

    fig = plt.figure(figsize=(num_cols,num_rows))

    for i in range(num_kernels):
        ax1 = fig.add_subplot(num_rows,num_cols,i+1)

        # detach + cpu: .numpy() raises on CUDA or grad-tracking tensors.
        npimg = np.array(t[i].detach().cpu().numpy(), np.float32)
        # Standardise; guard the constant-kernel case (std == 0 -> NaN).
        std = np.std(npimg)
        npimg = (npimg - np.mean(npimg)) / (std if std > 0 else 1.0)
        npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
        # imshow needs channels last: (3, kH, kW) -> (kH, kW, 3).
        npimg = npimg.transpose((1, 2, 0))
        ax1.imshow(npimg)
        ax1.axis('off')
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.tight_layout()
    # Save after tight_layout so the file matches the displayed layout.
    plt.savefig('myimage.png', dpi=100)
    plt.show()

    
def plot_weights(model, layer_num, single_channel = True, collated = False):
    """Plot the kernels of ``model.features[layer_num]`` if it is an ``nn.Conv2d``.

    single_channel=True draws every 2-D kernel, either collated into one
    heatmap (collated=True) or as a subplot grid. single_channel=False draws
    RGB tiles and is only supported when the layer has 3 input channels.
    Unsupported cases print a message and return None.
    """
    target = model.features[layer_num]
    if not isinstance(target, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    kernels = target.weight.data
    if not single_channel:
        if kernels.shape[1] != 3:
            print("Can only plot weights with three channels with single channel = False")
            return
        plot_filters_multi_channel(kernels)
    elif collated:
        plot_filters_single_channel_big(kernels)
    else:
        plot_filters_single_channel(kernels)
        
        
# Draw layer 1 of model.features as RGB tiles.
# NOTE(review): plot_weights indexes `model.features` (a VGG-style attribute);
# confirm `model` defines it, otherwise this raises AttributeError.
plot_weights(model, 1, single_channel = False)

## Visualize Feature Map

https://github.com/utkuozbulak/pytorch-cnn-visualizations#convolutional-neural-network-filter-visualization

In [None]:
from pytorchCNNVisualization.src import cnn_layer_visualization

In [None]:
# Visualise filter `filter_pos` of layer `cnn_layer` in a pretrained VGG16
# feature extractor, using the pytorch-cnn-visualizations helper.
cnn_layer = 17
filter_pos = 5

# Fully connected layer is not needed
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; check against the installed version.
pretrained_model = torchvision.models.vgg16(pretrained=True).features
layer_vis = cnn_layer_visualization.CNNLayerVisualization(pretrained_model, cnn_layer, filter_pos)

# Layer visualization with pytorch hooks
layer_vis.visualise_layer_with_hooks()

# Layer visualization without pytorch hooks
# layer_vis.visualise_layer_without_hooks()