# Demo FracReconNet

We recommend running this notebook in JupyterLab for compatibility with the K3D visualization toolbox.

In [None]:
# Simple Toolbox 
%matplotlib inline
from __future__ import print_function, division
import matplotlib
import matplotlib.pyplot as plt
import os
#import skimage
#from skimage import io, transform, measure
import scipy
from scipy.spatial.distance import directed_hausdorff
from sklearn.neighbors import NearestNeighbors
import numpy as np
#import cv2
#from PIL import Image
#import sys
import ipyvolume as ipv
import k3d
import pythreejs   # for controlling the Camera using with ipyvolume
import ipywidgets
import pandas as pd
import datetime
import time
import joblib
import ipywidgets as widgets
from tqdm.notebook import tqdm

# Advance Toolbox 
import torch
import torch.nn as nn
import torch.nn.functional as F                                      # useful stateless functions
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.autograd import Function, Variable

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import sampler

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision import utils, transforms

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion()   # interactive mode

from model.utils import show_sample, FemurDataset, NormalizeSample, ToTensor, ToTensor2
from model.matricesOperator import iou, TP, FP, TN, FN, union, hausdorff_voxel, overlap_based_metrices, mesh_surface_nearest_distance, surface_distance_measurement
from model.losses import FocalLossMulticlass
import model.model as Model
from model.training import train_mixed

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float32        # floating-point type used throughout this notebook
print(f'Available device = {device}')

In [None]:
### FracReconNet's Configuration ###

# Load the saved checkpoint dictionary.
# os.path.join builds the same '.\weight\FracReconNet.pt' on Windows but also works on
# POSIX, and avoids the invalid escape sequences ('\w', '\F') of the old plain string literal.
saved_name = os.path.join('.', 'weight', 'FracReconNet.pt')
saved_dict = torch.load(saved_name, map_location=device)
print('Model: in_c  = {}'.format(saved_dict['in_c']))
print('Model: en_sz = {}'.format(saved_dict['en_sz']))
print('Model: de_sz = {}'.format(saved_dict['de_sz']))
print('Model: de3d_sz = {}'.format(saved_dict['de3d_sz']))
print('Model: final_sz = {}'.format(saved_dict['final_sz']))
print('Note: {}'.format(saved_dict['note']))
print('Timestamp = {}\n'.format(saved_dict['timestamp']))

# Rebuild the architecture from the hyper-parameters stored in the checkpoint,
# then restore the trained weights for testing and inference.
fracReconNet = Model.fracReconNet(saved_dict['in_c'], saved_dict['en_sz'], saved_dict['de_sz'],
                                  saved_dict['de3d_sz'], saved_dict['final_sz'])
fracReconNet.load_state_dict(saved_dict['model_state_dict'])
fracReconNet.to(device=device)

# Optimizer configuration (not used by the inference cells below; kept for resuming training).
learning_rate = 1e-4
optimizer = optim.Adam(fracReconNet.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3,
                              threshold_mode='rel', cooldown=0, min_lr=0, verbose=True)
# Optimizer.load_state_dict has no `strict` argument: the saved state must match this
# optimizer's parameter groups.
optimizer.load_state_dict(saved_dict['optimizer_state_dict'])
# NOTE(review): the scheduler state is not restored; if the checkpoint carries a
# 'scheduler_state_dict' from a ReduceLROnPlateau, consider loading it before resuming training.

# Inference and Visualization

In [None]:
# Test-set location, relative to the current working directory.
# os.path.join gives '.\samples\...' on Windows and also works on POSIX, without
# relying on invalid escape sequences such as '\s' and '\T'.
root_dir = os.path.join('.', 'samples')
test_set = os.path.join(root_dir, 'Testing_set.xlsx')

# Test samples use the same NormalizeSample + ToTensor pipeline as the training cell.
test_transformedFemur = FemurDataset(csv_file=test_set, root_dir=root_dir,
                                     transform=transforms.Compose([NormalizeSample(), ToTensor()]))  # For Testing
# One sample per batch, in file order, so the inference cell can step through the set.
testLoader = DataLoader(test_transformedFemur, batch_size=1, shuffle=False, num_workers=4)
# Read by the inference cell: True means "create a fresh iterator over testLoader".
first = True

In [None]:
# Fetch the next test sample (re-run this cell to step through the test set).
torch.cuda.empty_cache()
if first:
    print('First testing sample')
    femurIter = iter(testLoader)
    first = False
else:
    print('Next testing sample')
try:
    sample = next(femurIter)
except StopIteration:
    # The test set is exhausted: wrap around instead of crashing the cell.
    print('End of testing set reached, restarting from the first sample')
    femurIter = iter(testLoader)
    sample = next(femurIter)

target, ap, lat = sample['Target'], sample['view1'], sample['view2']
print('Raw: target={}  ap={}  lat={}'.format(target.size(), ap.size(), lat.size()))
# Take the single sample out of the loader batch. For the two input views,
# [0].unsqueeze(0) keeps a leading dimension of size 1 for the network (note from the
# original code: "for ToTensor6-8").
target = target[0]
ap = ap[0].unsqueeze(0).to(device=device)
lat = lat[0].unsqueeze(0).to(device=device)
print('Final: target={}  ap={}  lat={}'.format(target.size(), ap.size(), lat.size()))

### Inference ###
on_cuda = device.type == 'cuda'
fracReconNet.eval()   # switch to inference behaviour before timing
# no_grad: no autograd graph is needed for inference (saves memory and time).
# autocast is enabled only on CUDA; torch.cuda.amp.autocast is a no-op (with a warning) otherwise.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=on_cuda):
    if on_cuda:
        torch.cuda.synchronize()   # CUDA kernels run asynchronously; flush pending work before timing
    t1 = time.time()
    output = fracReconNet(ap, lat)
    if on_cuda:
        torch.cuda.synchronize()   # wait for the forward pass to actually finish
    t2 = time.time()
print('Inference time = {:,.4f} sec.'.format(t2 - t1))

##### OUTPUT: Bone Visualization #####
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(ap.detach().cpu().numpy().squeeze(), cmap='gray')
ax[0].set_title('Input view 1')
ax[1].imshow(lat.detach().cpu().numpy().squeeze(), cmap='gray')
ax[1].set_title('Input view 2')
plt.show()

# Accuracy check: boolean volumes of ground-truth class 1 and thresholded output channel 1.
gt = (target == 1).detach().cpu().numpy()
ot = (output[0, 1, :] > 0.5).detach().cpu().numpy()
# Optional metrics (uncomment to print):
#print('IoU = {:,.3f}'.format(iou((target==1).float().cpu(), (output[0,1,:]>0.5).float().cpu())), end=', ')
#dist_metrics = surface_distance_measurement(output[0,1].cpu().numpy(), (target==1).cpu().numpy(), res=0.5, verbose=False)
#print('ASSD = {:.3f} mm'.format(dist_metrics['ASSD']))

# Camera view presets (presumably k3d's [position xyz, look-at xyz, up xyz] layout)
cam_view1 = [0, 0, -1.5, 0, 0, 0, 0, -1, 0]   # view1
cam_view2 = [1.5, 0, 0, 0, 0, 0, 1, -1, 0]    # view2

### Output : Bone Class ###
plot = k3d.plot(name='Plot 2 : [view2]')
obj = k3d.volume(ot, name='Output',
                 color_map=k3d.colormaps.matplotlib_color_maps.Bone,
                 gradient_step=0.005,
                 shadow='dynamic',
                 shadow_delay=10,
                 )
plot += obj + k3d.text2d(text='Output', color=0, size=1, position=(0.01, 0.025), label_box=False)
plot.display()
plot.camera = cam_view2

In [None]:
### Show Volume rotation360 ###
# Orbit the camera of `plot` (the volume view from the inference cell) around the origin.
N = 3          # number of full revolutions
r_orbit = 1.5  # orbit radius
camera_rotate = [[-r_orbit * np.sin(t), -0.2, r_orbit * np.cos(t), 0, 0, 0, 0, -1, 0]
                 for t in np.linspace(0, 2 * np.pi * N, num=360 * N)]  # 360 poses per revolution
# Tip: set plot.grid_visible = False to hide the grid during the animation.
for view in camera_rotate:
    plot.camera = view
    time.sleep(6 / 360)   # 360 sleeps per revolution -> about 6 s of sleeping per revolution
print('--- Rotation END ---')

In [None]:
# Surface-distance evaluation between the ground truth (gt) and the prediction (ot).
t1 = time.time()
surf_dist_result = surface_distance_measurement(gt, ot, res=0.5, return_vert_dist=True, verbose=False)
t2 = time.time()
print('\nTotal time of calculation = {} sec.'.format(t2 - t1))
print('key result = {}'.format(surf_dist_result.keys()))

background_color = 0xFFFFFF   # white, 0xRRGGBB (== 65536*255 + 256*255 + 255)


def _surface_distance_plot(dist_result, title):
    """Build, display and return a k3d plot of one surface-distance mesh.

    dist_result: one entry of `surf_dist_result` (e.g. its 'surface_dist_gt' or
        'surface_dist_ot' value), read for 'target_vert', 'target_face' and 'distances'.
    title: text drawn in the top-left corner of the plot.
    Returns the k3d plot so later cells can animate its camera.
    """
    fig = k3d.plot(background_color=background_color)
    title_text = k3d.text2d(text=title, color=0, size=1, position=(0.01, 0.025), label_box=False)
    mesh = k3d.mesh(dist_result['target_vert'], dist_result['target_face'],
                    name='surface distance (mm)',
                    color_map=k3d.colormaps.basic_color_maps.Jet,
                    color_range=[0, 3],   # [0,3] for align dataset   |   [0,6] for unalign dataset
                    # astype takes the dtype itself (np.float32), not a scalar instance np.float32()
                    attribute=dist_result['distances'].astype(np.float32),
                    )
    fig += title_text + mesh
    fig.display()
    fig.camera = [128, -100, 512, 128, -128, 128, 0, 1, 0]
    return fig


pltmesh = _surface_distance_plot(surf_dist_result['surface_dist_gt'],
                                 'Min. Surface Distance (GroundTruth-based)')
pltmesh2 = _surface_distance_plot(surf_dist_result['surface_dist_ot'],
                                  'Min. Surface Distance (Output-based)')
print('--- END ---')

In [None]:
### Show Mesh rotation360 ###
# Orbit both surface-distance plots in lockstep around the look-at point (128, -128, 128).
N = 3            # number of full revolutions
r_orbit = 384    # orbit radius
camera_rotate = [[128 + r_orbit * np.sin(t), -90, 128 - r_orbit * np.cos(t), 128, -128, 128, 0, 1, 0]
                 for t in np.linspace(np.pi / 2, 2 * np.pi * N + np.pi / 2, num=360 * N)]  # 360 poses per revolution
# Tip: set pltmesh.grid_visible = False / pltmesh2.grid_visible = False to hide the grids.

for view in camera_rotate:
    pltmesh.camera = view
    pltmesh2.camera = view
    time.sleep(6 / 360)   # about 6 s of sleeping per revolution

print('--- Rotation END ---')

# Training

In [None]:
# Training configuration
# Training requires CUDA. Fail fast with an explicit exception rather than an `assert`,
# which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('!!! Training with CPU not available !!!')

saved_name = os.path.join('.', 'weight', 'test_of_training.pt')   # passed to train_mixed (presumably the checkpoint path)
batch_sz = 3
epoch_number = 1
num_workers = 4
learning_rate = 1e-4
weight = torch.tensor([0.15, 0.25, 0.6], device=device)   # per-class weights passed to the focal loss
criterion = FocalLossMulticlass(weight=weight, gamma=2.0, reduction='sum')

# Dataset preparation
root_dir = r'D:\FEW PhD\Datasets\Chula DICOM 2021\Cleaning Data\Dataset2'   # machine-specific; adjust locally
training_file = os.path.join(root_dir, 'Training_set.xlsx')
val_file = os.path.join(root_dir, 'Validation_set.xlsx')
train_transformedFemur = FemurDataset(csv_file=training_file, root_dir=root_dir,
                                      transform=transforms.Compose([NormalizeSample(), ToTensor()]))
val_transformedFemur = FemurDataset(csv_file=val_file, root_dir=root_dir,
                                    transform=transforms.Compose([NormalizeSample(), ToTensor()]))
trainLoader = DataLoader(train_transformedFemur, batch_size=batch_sz, shuffle=True, num_workers=num_workers)
# Validation needs no random order; a fixed order keeps validation runs reproducible.
valLoader = DataLoader(val_transformedFemur, batch_size=batch_sz, shuffle=False, num_workers=num_workers)


# FracReconNet's Configuration
in_c = 1
de3d_sz = None       # don't use now
en_sz = [[16,4,4],[32,8,4],[64,16,4],[128,16,4],[192,16,4],[256,16,4],[320,16,4]]
de_sz = [[64,4],[64,8],[128,16],[128,32],[128,64],[128,128],[256,256]]
final_sz = [[2,32,32],[5,32,32],[5,32,32],[5,16,16],[5,16,16],[5,16,16],[5,16,16]]  # fusion layer
note = 'Test training loop'

fracReconNet = Model.fracReconNet(in_c, en_sz, de_sz, de3d_sz, final_sz)
fracReconNet.to(device=device)
optimizer = optim.Adam(fracReconNet.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1, threshold=1e-3,
                              threshold_mode='rel', cooldown=0, min_lr=0, verbose=True)

# Checkpoint skeleton handed to (and returned by) train_mixed; the keys mirror
# those read by the model-loading cell.
saved_dict = {'timestamp':None ,'note':note , 'in_c':in_c ,'en_sz':en_sz ,'de_sz':de_sz ,'de3d_sz':de3d_sz ,'final_sz':final_sz ,
              'train_loss_history':list(), 'train_acc_history':list(), 'val_loss_history':list(), 'val_acc_history':list() ,
              'model_state_dict':None,'optimizer_state_dict':None,'scheduler_state_dict':None }

print(' ##### Start training loop #####')
# len(DataLoader) is the number of batches per epoch. The old code printed the length of a
# freshly created (empty) history list, which was always 0.
print('   Total iterations = {:,}'.format(len(trainLoader) * epoch_number))
print('   Start training at : {}'.format(str(datetime.datetime.now())))
timeT1 = time.time()
# NOTE(review): nn.DataParallel prefixes state_dict keys with 'module.'. If train_mixed saves
# model.state_dict(), loading that checkpoint into a bare fracReconNet (as the model-loading
# cell does) needs the prefix stripped. Confirm in model/training.py.
model = nn.DataParallel(fracReconNet)
print('   Use mixed precision training')
saved_dict = train_mixed(model, trainLoader, valLoader, optimizer, scheduler, criterion, batch_sz, epoch_number, saved_name, saved_dict, device)

timeT2 = time.time()
print('\nTotal training time = {} hours \nTotal epoch = {}\n'.format((timeT2 - timeT1) / 3600, len(saved_dict['val_acc_history'])))
print('Saved Name: {}'.format(saved_name))
print('Finish task at : {}\n'.format(str(datetime.datetime.now())))