In [1]:
import os, sys
sys.path.append('/home/ltetrel/DeepNeuroAN/deepneuroan/')
sys.path.append('/home/ltetrel/DeepNeuroAN/')

import numpy as np
np.set_printoptions(precision=2)
import tensorflow as tf

from preproc import create_ref_grid
import deepneuroan.utils as utils
import SimpleITK as sitk

tf.keras.backend.clear_session()

if tf.config.list_physical_devices('GPU'):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(1)
print(tf.config.list_physical_devices('GPU'))

[]


In [25]:
inputs = [tf.ones((4, 220, 220, 220, 1))]

num_batch = tf.shape(inputs[0])[0]
num_channels = tf.shape(inputs[0])[-1]
ref_size = tf.shape(inputs[0])[1:-1]
ref_size_xyz = tf.concat([ref_size[1::-1], ref_size[2:]], axis=0)
tf.concat([tf.expand_dims(num_batch, axis=0), ref_size, tf.expand_dims(num_channels, axis=0)], axis=0)

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([  4, 220, 220, 220,   1], dtype=int32)>

In [2]:
# inputs creation   
batch_size = 4
x = np.empty((batch_size, 220, 220, 220, 1), dtype=np.float64)
trf = np.empty((batch_size, 7), dtype=np.float64)
data_dir = "./data"

for i in range(batch_size):
#     x[i, :, :, :, 0] = sitk.GetArrayFromImage(sitk.ReadImage(data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.nii.gz" %(i+1)
#                                    , sitk.sitkFloat64))[100:109, 110:119, 120:129]
    x[i, :, :, :, 0] = sitk.GetArrayFromImage(sitk.ReadImage(data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.nii.gz" %(i+1)
                                   , sitk.sitkFloat64))

    trf[i,] = utils.load_trf_file(data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.txt" %(i+1))
    
    # Inversing quaternions to compare volumes with base one
    q = sitk.VersorRigid3DTransform([trf[i, 1], trf[i, 2], trf[i, 3], trf[i, 0]])
    t = sitk.TranslationTransform(3, tuple(trf[i, 4:]))
    q.SetTranslation(t.GetOffset())
    q = q.GetInverse().GetParameters()
    trf[i, 1:4] = [-trf[i, 1], -trf[i, 2], -trf[i, 3]]
    trf[i, 4:] = q[3:]
    

In [32]:
class LinearTransformation(tf.keras.layers.Layer):
    def __init__(self, min_ref_grid=[-1.], max_ref_grid=[1.], interp_method="nn", padding_mode="zeros", padding_mode_value=0., **kwargs):
        self.trainable = False
        super(self.__class__, self).__init__(**kwargs)
        self.min_ref_grid = tf.constant(min_ref_grid, dtype=tf.float32)
        self.max_ref_grid = tf.constant(max_ref_grid, dtype=tf.float32)
        self.interp_method = tf.constant(interp_method, dtype=tf.string)
        self.padding_mode = tf.constant(padding_mode, dtype=tf.string)
        self.padding_mode_value = tf.constant(padding_mode_value, dtype=tf.float32)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        
        num_dims = input_shape[0].ndims - 2
        if num_dims != self.min_ref_grid.shape[0]:
            self.min_ref_grid = (-1) * tf.ones(num_dims, dtype=tf.float32)
            self.max_ref_grid = tf.ones(num_dims, dtype=tf.float32)
        
        super(self.__class__, self).build(input_shape)
    
    @tf.function
    def _transform_grid(self, ref_size_xyz, transfos, min_ref_grid, max_ref_grid):

        # constants
        num_batch = tf.shape(transfos)[0]
        num_elems = tf.reduce_prod(ref_size_xyz)
        thetas = utils.get_matrix_from_params(transfos, num_elems)

        # grid creation from volume affine
        mz, my, mx = tf.meshgrid(tf.linspace(min_ref_grid[2], max_ref_grid[2], ref_size_xyz[2])
                                , tf.linspace(min_ref_grid[1], max_ref_grid[1], ref_size_xyz[1])
                                , tf.linspace(min_ref_grid[0], max_ref_grid[0], ref_size_xyz[0])
                                , indexing='ij')

        # preparing grid for quaternion rotation
        grid = tf.concat([tf.reshape(mx, (1, -1)), tf.reshape(my, (1, -1)), tf.reshape(mz, (1, -1))], axis=0)
        grid = tf.expand_dims(grid, axis=0)
        grid = tf.tile(grid, (num_batch, 1, 1))

        # preparing grid for augmented transformation
        grid = tf.concat([grid, tf.ones((num_batch, 1, num_elems))], axis=1)

        return tf.linalg.matmul(thetas, grid)
    
    @tf.function
    def _interpolate(self, im, points, min_ref_grid, max_ref_grid, method="nn", padding_mode="zeros", padding_mode_value=0.):

        #constants
        num_batch = tf.shape(im)[0]
        vol_shape_xyz = tf.cast( tf.concat([tf.shape(im)[1:-1][1::-1], tf.shape(im)[1:-1][2:]], axis=0), dtype=tf.float32)
        width = vol_shape_xyz[0]
        height = vol_shape_xyz[1]
        depth = vol_shape_xyz[2]
        width_i = tf.cast(width, dtype=tf.int32)
        height_i = tf.cast(height, dtype=tf.int32)
        depth_i = tf.cast(depth, dtype=tf.int32)
        channels = tf.shape(im)[-1]
        num_row_major = tf.cast(tf.math.cumprod(vol_shape_xyz), dtype=tf.int32)
        zero = tf.zeros([], dtype=tf.float32)
        output = tf.zeros((num_batch, num_row_major[-1] , 1), dtype=tf.float32)
        valid = tf.ones_like(output)
        ibatch = utils.repeat(num_row_major[-1] * tf.range(num_batch, dtype=tf.int32), num_row_major[-1])

        # scale positions to [0, width/height - 1]
        coeff_x = (width - 1.)/(max_ref_grid[0] - min_ref_grid[0])
        coeff_y = (height - 1.)/(max_ref_grid[1] - min_ref_grid[1])
        coeff_z = (depth - 1.)/(max_ref_grid[2] - min_ref_grid[2])
        ix = (coeff_x * points[:, 0, :]) - (coeff_x *  min_ref_grid[0])
        iy = (coeff_y * points[:, 1, :]) - (coeff_y *  min_ref_grid[1])
        iz = (coeff_z * points[:, 2, :]) - (coeff_z *  min_ref_grid[2])

        # padding mode, for positions outside of refrence grid
        if (padding_mode == "zero") | (padding_mode == "value"):
            valid = tf.expand_dims(tf.cast(tf.less_equal(ix, width - 1.) & tf.greater_equal(ix, zero)
                                            & tf.less_equal(iy, height - 1.) & tf.greater_equal(iy, zero)
                                            & tf.less_equal(iz, depth - 1.) & tf.greater_equal(iz, zero)
                                            , dtype=tf.float32), -1)

        # if we use bilinear interpolation, we calculate each area between corners and positions to get the weights for each input pixel
        if method == "bilinear":
            # get north-west-top corner indexes based on the scaled positions
            ix_nwt = tf.clip_by_value(tf.floor(ix), zero, width - 1.)
            iy_nwt = tf.clip_by_value(tf.floor(iy), zero, height - 1.)
            iz_nwt = tf.clip_by_value(tf.floor(iz), zero, depth - 1.)
            ix_nwt_i = tf.cast(ix_nwt, dtype=tf.int32)
            iy_nwt_i = tf.cast(iy_nwt, dtype=tf.int32)
            iz_nwt_i = tf.cast(iz_nwt, dtype=tf.int32)       

            #gettings all offsets to create corners
            offset_corner = tf.constant([ [0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [0., 1., 1.], [1., 0., 0.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]], dtype=tf.float32)
            offset_corner_i =  tf.cast(offset_corner, dtype=tf.int32)

            for c in range(8):
                # getting all corner indexes from north-west-top corner
                ix_c = ix_nwt + offset_corner[-c - 1, 0]
                iy_c = iy_nwt + offset_corner[-c - 1, 1]
                iz_c = iz_nwt + offset_corner[-c - 1, 2]

                # area is computed using the opposite corner
                nc = tf.expand_dims(tf.abs((ix - ix_c) * (iy - iy_c) * (iz - iz_c)), -1)

                # current corner position
                ix_c = ix_nwt_i + offset_corner_i[c, 0]
                iy_c = iy_nwt_i + offset_corner_i[c, 1]
                iz_c = iz_nwt_i + offset_corner_i[c, 2]

                # gather input image values from corners idx, and calculate weighted pixel value
                idx_c = ibatch + tf.math.minimum( width_i - 1, ix_c) + num_row_major[0] * tf.math.minimum( height_i - 1, iy_c) + num_row_major[1] * tf.math.minimum( depth_i - 1, iz_c)
                Ic = tf.gather(tf.reshape(im, [-1, channels]), idx_c)

                output += nc * Ic


        # else if method is nearest neighbor, we get the nearest corner
        elif method == "nn":
            # get rounded indice corner based on the scaled positions
            ix_nn = tf.cast(tf.clip_by_value(tf.round(ix), zero, width - 1.), dtype=tf.int32)
            iy_nn = tf.cast(tf.clip_by_value(tf.round(iy), zero, height - 1.), dtype=tf.int32)
            iz_nn = tf.cast(tf.clip_by_value(tf.round(iz), zero, depth - 1.), dtype=tf.int32)

            # gather input pixel values from nn corner indexes
            idx_nn = ibatch + ix_nn + num_row_major[0] * iy_nn + num_row_major[1] * iz_nn

            output = tf.gather(tf.reshape(im, [-1, channels]), idx_nn)

        if padding_mode == "zero":
            output = output * valid
        elif padding_mode == "value":
            output = output * valid + padding_mode_value * (1. - valid)

        return output
        
    def call(self, inputs):
        assert isinstance(inputs, list)
        
        input_shape = tf.shape(inputs[0])
        num_batch = input_shape[0]
        num_channels = input_shape[-1]
        ref_size = input_shape[1:-1]
        ref_size_xyz = tf.concat([ref_size[1::-1], ref_size[2:]], axis=0)
        
        input_transformed = self._transform_grid(ref_size_xyz, transfos=inputs[1], min_ref_grid=self.min_ref_grid, max_ref_grid=self.max_ref_grid)
        input_transformed = self._interpolate(im=inputs[0]
                                         , points=input_transformed
                                         , min_ref_grid=self.min_ref_grid
                                         , max_ref_grid=self.max_ref_grid
                                         , method=self.interp_method
                                         , padding_mode=self.padding_mode
                                         , padding_mode_value=self.padding_mode_value)
        output = tf.reshape(input_transformed, input_shape)
        
        return output
            
    def compute_output_shape(self, input_shape):
        return input_shape[0]

In [30]:
#TODO: When it is working, uncomment

# n_batch = 7
# transfos = tf.random.uniform(shape=(n_batch, 7), seed=0) #quaternions (4,) + translations (3,) + scales (3,)
# U = tf.random.uniform((n_batch, 220, 220, 220, 1))
# out_size = [10, 10, 10]
# name='BatchSpatialTransformer3dAffine'

# ref_grid = create_ref_grid()
# sz_ref = ref_grid.GetSize()
# min_ref_grid = ref_grid.GetOrigin()
# max_ref_grid = ref_grid.TransformIndexToPhysicalPoint(sz_ref)

# with tf.compat.v1.variable_scope(name):


In [33]:
ref_grid = create_ref_grid()
sz_ref = ref_grid.GetSize()
min_ref_grid = ref_grid.GetOrigin()
max_ref_grid = ref_grid.TransformIndexToPhysicalPoint(sz_ref)
interp_method = "nn"
padding_mode = "border"

src = tf.keras.Input(shape=(220, 220, 220, 1))
conv = tf.keras.layers.Conv3D(filters=1, kernel_size=(3, 3, 3), strides=(15, 15, 15), padding='valid')(src)
conv_flatten = tf.keras.layers.Flatten()(conv)
conv1 = tf.keras.layers.Dense(units=7, kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0), activation=None)(conv_flatten)
reg_out = LinearTransformation(min_ref_grid=min_ref_grid, max_ref_grid=max_ref_grid, interp_method=interp_method, padding_mode=padding_mode)([src, conv1])
model = tf.keras.Model(inputs=[src], outputs=[reg_out])

model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-3)
              , loss=["mae"]
              , metrics=["mae"])

res = model.predict(x=[x, trf], batch_size=1)
print(res[0])

[[[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.

In [None]:
import time
tic = time.time()

src = tf.keras.Input(shape=(220, 220, 220))
y = tf.keras.Input(shape=(7))
reg_out = LinearTransformation()([src, y])
model = tf.keras.Model(inputs=[src, y], outputs=[reg_out])
res = model.predict(x=[x, trf], batch_size=2)
print(len(res))

ElpsTime = time.time() - tic
print("*** Total %1.3f s ***"%(ElpsTime))

In [None]:
#saving output
!rm data/vol*

def save_array_to_sitk(data, name, data_dir):
    ref_grid = create_ref_grid()
    sitk_img = utils.get_sitk_from_numpy(data, ref_grid)
    sitk.WriteImage(sitk_img, os.path.join(data_dir, name + ".nii.gz"))

for vol in range(res.shape[0]):
    save_array_to_sitk(data=res[vol,], name="vol%02d" %(vol+1), data_dir=data_dir)

In [None]:
n_batch = 2
input_size = (220, 220, 220)
#quaternions (4,) + translations (3,) + scales (3,)
transfos = tf.stack( [[1., 0., 0., 0., 0., 0., 0.], [0.7071, 0.7071, 0., 0., 0., 0., 0.], [0.7071, 0.7071, 0., 0., -20., -20., -20.]] )[:n_batch]

q = transfos[:, :4]

tmp=tf.stack([1 - 2.*(q[...,2]**2 + q[...,3]**2), 2*(q[...,1]*q[...,2] - q[...,0]*q[...,3]), 2*(q[...,0]*q[...,2] + q[...,1]*q[...,3]),
              2.*(q[...,1]*q[...,2] + q[...,0]*q[...,3]), 1 - 2.*(q[...,1]**2 + q[...,3]**2), 2.*(q[...,2]*q[...,3] - q[...,0]*q[...,1]),
              2.*(q[...,1]*q[...,3] - q[...,0]*q[...,2]), 2.*(q[...,0]*q[...,1] + q[...,2]*q[...,3]), 1 - 2.*(q[...,1]**2 + q[...,2]**2)],axis=-1)

M = tf.reshape(tmp, (-1, 3, 3))

In [None]:
p = np.array([ [1., 2., 3., 4., 5.], [0., 0., 0., 0., 0.], [3., 3., 3., 3., 3.] ], dtype=np.float32)
p = tf.tile(tf.expand_dims(p, axis=0), multiples=(n_batch, 1, 1))

print(M)
tf.matmul(M, p)

In [None]:
print(tf.matmul(tfg.geometry.transformation.rotation_matrix_3d.from_quaternion(q), p))
print(tf.matmul(M, p))

In [None]:
# trying with a simple transform
t_np = np.array([[1, 0, 0, 0]
                 ,[0, 1, 0, -10]
                 ,[0, 0, 1, -30]
                 ,[0, 0, 0, 1]])

affine = sitk.Euler3DTransform()
affine.SetTranslation((-10, 30, 0))

data_dir = "./data"
filepath = data_dir + "/ses-vid001_task-video_run-01_bold_vol-0001_transfo-%06d.nii.gz" %(1)
ref_grid = create_ref_grid()
source_brain = sitk.ReadImage(filepath, sitk.sitkFloat32)

import time
tic = time.time()
for i in range(1):
    brain_to_grid = sitk.Resample(source_brain, ref_grid, affine, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
ElpsTime = time.time() - tic
print("*** Total %1.3f s ***"%(ElpsTime))