In [None]:
import os, sys, warnings
import numpy as np
np.set_printoptions(precision=2)
import tensorflow as tf
import SimpleITK as sitk

# Make the deepneuroan package importable from this notebook: relative locations first,
# then an absolute checkout. The absolute root defaults to the original author's path;
# set the DEEPNEUROAN_ROOT environment variable to point at another checkout.
DEEPNEUROAN_ROOT = os.environ.get("DEEPNEUROAN_ROOT", "/home/ltetrel/DeepNeuroAN")
sys.path.extend([
    "../",
    "../deepneuroan",
    os.path.join(DEEPNEUROAN_ROOT, "deepneuroan", ""),
    os.path.join(DEEPNEUROAN_ROOT, ""),
])
# Silences every Python warning for the rest of the kernel session; comment this out
# while debugging so real problems are not hidden.
warnings.filterwarnings('ignore')

import deepneuroan.utils as utils
from deepneuroan.preproc import create_ref_grid
from deepneuroan.data_generator import DataGenerator
from deepneuroan.models import ChannelwiseConv3D, LinearTransformation
from deepneuroan.generate_train_data import generate_random_quaternions, generate_random_transformations, quaternion_from_euler

In [None]:
# Inputs creation: one source (moving) volume per batch item, the shared target volume,
# and the 7 transform parameters stored next to each generated source.
batch_size = 1
vol_shape = (220, 220, 220)
x = np.empty((batch_size, *vol_shape, 1), dtype=np.float64)
truth = np.empty((batch_size, *vol_shape, 1), dtype=np.float64)
trf = np.empty((batch_size, 7), dtype=np.float64)
# Default is the original author's location; override with DEEPNEUROAN_DATA_DIR.
data_dir = os.environ.get(
    "DEEPNEUROAN_DATA_DIR",
    "/home/ltetrel/Documents/data/neuromod/derivatives/deepneuroan/training/generated_data",
)
target_name = "ses-game001_task-shinobi1_run-01_bold_vol-0001"
source_name = target_name + "_transfo-{:06d}"

# Every batch item uses the same target file, so read it once rather than once per item.
target_path = os.path.join(data_dir, target_name + ".nii.gz")
target_arr = sitk.GetArrayFromImage(sitk.ReadImage(target_path, sitk.sitkFloat64))

for i in range(batch_size):
    # Source files are numbered from 1: <target>_transfo-000001.{nii.gz,txt}, ...
    stem = os.path.join(data_dir, source_name.format(i + 1))
    x[i, :, :, :, 0] = sitk.GetArrayFromImage(sitk.ReadImage(stem + ".nii.gz", sitk.sitkFloat64))
    truth[i, :, :, :, 0] = target_arr
    trf[i] = utils.load_trf_file(stem + ".txt")

# Invert the transforms so the warped sources can be compared with the base volume:
# index 0 is kept and indices 1-6 are negated. This conjugates the quaternion, assuming
# the layout is [w, x, y, z, tx, ty, tz] -- TODO confirm against utils.load_trf_file.
# NOTE(review): negating the translation alone is the exact inverse only for some
# parameterizations; a general rigid inverse uses t' = -R^T t. Verify against how
# LinearTransformation applies these parameters.
trf[:, 1:] *= -1
    

In [None]:
# The reference grid defines the physical bounding box handed to LinearTransformation.
ref_grid = create_ref_grid()
sz_ref = ref_grid.GetSize()
min_ref_grid, max_ref_grid = (
    ref_grid.GetOrigin(),
    # NOTE(review): valid ITK indices run 0..size-1, so index `sz_ref` is one voxel past
    # the last one. Confirm LinearTransformation expects this exclusive upper bound.
    ref_grid.TransformIndexToPhysicalPoint(sz_ref),
)
params_reg = {
    "min_ref_grid": min_ref_grid,
    "max_ref_grid": max_ref_grid,
    "interp_method": "bilinear",
    "padding_mode": "zeros",
}

In [None]:
import time
tic = time.perf_counter()

# Model whose only operation is LinearTransformation: it warps `src` with the
# per-sample parameters in `transfo`. Input shapes are taken from the arrays built
# above (x: volume + channel axis, trf: 7 parameters). The original `shape=(7)` was the
# int 7, not a 1-tuple.
src = tf.keras.Input(shape=x.shape[1:])
transfo = tf.keras.Input(shape=trf.shape[1:])
reg_out = LinearTransformation(**params_reg)([src, transfo])
model = tf.keras.Model(inputs=[src, transfo], outputs=[reg_out])
# No compile() needed: predict() performs inference only.
res = model.predict(x=[x, trf], batch_size=batch_size)

print(res.shape)
ElpsTime = time.perf_counter() - tic
print("*** Total %1.3f s ***"%(ElpsTime))

In [None]:
#saving output
def remove_saved_volumes(out_dir="./", pattern="vol*.nii.gz"):
    """Delete volumes written by a previous run of this cell.

    Replaces the shell command ``rm vol*``. Only regular files matching ``pattern``
    inside ``out_dir`` are removed, rather than every file whose name starts with
    "vol". Finding nothing to delete is not an error.

    Args:
        out_dir: directory to clean.
        pattern: glob pattern of the files to delete.

    Returns:
        Sorted list of the paths that were removed.
    """
    import glob
    removed = sorted(p for p in glob.glob(os.path.join(out_dir, pattern)) if os.path.isfile(p))
    for path in removed:
        os.remove(path)
    return removed

remove_saved_volumes()

def save_array_to_sitk(data, name, data_dir, ref_grid=None):
    """Convert ``data`` with ``utils.get_sitk_from_numpy`` and write ``<data_dir>/<name>.nii.gz``.

    Args:
        data: numpy volume passed to ``utils.get_sitk_from_numpy`` together with the
            reference grid. NOTE(review): callers pass arrays that may keep a trailing
            channel axis of size 1; confirm that helper handles it.
        name: output file stem, without extension.
        data_dir: directory the file is written into.
        ref_grid: optional precomputed reference grid. When None (the default, and the
            previous behaviour) a new one is built with ``create_ref_grid()``. Pass one
            in when saving many volumes to avoid rebuilding it on every call.
    """
    if ref_grid is None:
        ref_grid = create_ref_grid()
    sitk_img = utils.get_sitk_from_numpy(data, ref_grid)
    sitk.WriteImage(sitk_img, os.path.join(data_dir, name + ".nii.gz"))

# Write each predicted volume (first axis of `res`) as vol01.nii.gz, vol02.nii.gz, ...
# in the current working directory.
for vol_idx, vol_data in enumerate(res, start=1):
    save_array_to_sitk(data=vol_data, name="vol%02d" % vol_idx, data_dir="./")