In [1]:
import os, sys
sys.path.append('/home/ltetrel/DeepNeuroAN/deepneuroan/')
sys.path.append('/home/ltetrel/DeepNeuroAN/')

import numpy as np
import tensorflow as tf

from preproc import create_ref_grid
import deepneuroan.utils as utils
import SimpleITK as sitk

tf.keras.backend.clear_session()

In [2]:
# inputs creation   
batch_size = 10
x = np.empty((batch_size, 9, 9, 9, 1), dtype=np.float64)
trf = np.empty((batch_size, 7), dtype=np.float64)
data_dir = "./data"

for i in range(batch_size):
    x[i, :, :, :, 0] = sitk.GetArrayFromImage(sitk.ReadImage(data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.nii.gz" %(i+1)
                                   , sitk.sitkFloat64))[100:109, 110:119, 120:129]

    trf[i,] = utils.load_trf_file(data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.txt" %(i+1))
    
    # Inversing quaternions to compare volumes with base one
    q = sitk.VersorRigid3DTransform([trf[i, 1], trf[i, 2], trf[i, 3], trf[i, 0]])
    t = sitk.TranslationTransform(3, tuple(trf[i, 4:]))
    q.SetTranslation(t.GetOffset())
    q = q.GetInverse().GetParameters()
    trf[i, 1:4] = [-trf[i, 1], -trf[i, 2], -trf[i, 3]]
    trf[i, 4:] = q[3:]

In [3]:
class LinearTransformation(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        self.trainable = False
        super(self.__class__, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        if len(input_shape) > 2:
            raise Exception("LinearRegistration must be called on a list of length 2. "
                            "First argument is the volume, second is the quaternion transform.")

        super(self.__class__, self).build(input_shape)

    def call(self, inputs):
        assert isinstance(inputs, list)
        
        # map transform across batch
        return tf.map_fn(self._tf_single_transform, inputs, dtype=tf.float32)
            
    def compute_output_shape(self, input_shape):
        return input_shape[0]   
    
    def _single_transform(self, src, trf):
        ref_grid = create_ref_grid()

        sitk_src = utils.get_sitk_from_numpy(src, ref_grid)
        sitk_tgt = utils.transform_volume(sitk_src, ref_grid, rigid=trf)

        return sitk.GetArrayFromImage(sitk_tgt)
    
    @tf.function
    def _tf_single_transform(self, inputs): 
        out = tf.numpy_function(self._single_transform, inp=[inputs[0], inputs[1]], Tout=tf.float32) 
        return out


In [5]:
src = tf.keras.Input(shape=(9, 9, 9, 1))
inp_flatten = tf.keras.layers.Flatten()(src)
conv = tf.keras.layers.Dense(units=7, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0), activation=None)(inp_flatten)
reg_out = LinearTransformation()([src, conv])
model = tf.keras.Model(inputs=[src], outputs=[reg_out])

model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-3)
              , loss=["mae"]
              , metrics=["mae"])

res = model.predict(x=[x, trf], batch_size=1)
print(res[0])

TypeError: object of type 'NoneType' has no len()

In [5]:
src = tf.keras.Input(shape=(220, 220, 220))
y = tf.keras.Input(shape=(7))
reg_out = LinearTransformation()([src, y])
model = tf.keras.Model(inputs=[src, y], outputs=[reg_out])
res = model.predict(x=[x, trf], batch_size=2)
print(len(res))

ValueError: Error when checking input: expected input_2 to have shape (220, 220, 220) but got array with shape (9, 9, 9)

In [5]:
#saving output
!rm data/vol*

def save_array_to_sitk(data, name, data_dir):
    ref_grid = create_ref_grid()
    sitk_img = utils.get_sitk_from_numpy(data, ref_grid)
    sitk.WriteImage(sitk_img, os.path.join(data_dir, name + ".nii.gz"))

for vol in range(res.shape[0]):
    save_array_to_sitk(data=res[vol,], name="vol%02d" %(vol+1), data_dir=data_dir)