In [1]:
import tensorflow as tf
import numpy as np
from DepthwiseConv3D import DepthwiseConv3D
import time
import SimpleITK as sitk

In [2]:
# Check how TensorFlow applies hand-set convolution weights.
# The 1x1 filter bank maps 3 input channels to 2 filters: filter 0 is all zeros
# and filter 1 all ones, so output channel 1 should equal the per-pixel sum of
# the input channels.
in_tensor = tf.random.uniform(shape=(2, 5, 5, 3), minval=0, maxval=1, seed=0)
f = np.stack([np.ones((1, 1))] * 3, axis=2)                # (1, 1, 3)
f = np.stack([f * 0, f * 1], axis=3).astype(np.float32)    # (1, 1, 3, 2), same dtype as in_tensor

# The filter is passed positionally because its keyword is `filter` in TF1
# but `filters` in TF2.
out_tensor = tf.nn.conv2d(
    in_tensor,
    f,
    strides=[1, 1, 1, 1],
    padding="VALID")
split_inputs = tf.split(in_tensor, int(in_tensor.shape[-1]), axis=-1)

# Evaluate all three tensors in ONE run so they are all computed from the same
# sample of the random `in_tensor`.
sess = tf.compat.v1.Session()
result = sess.run([out_tensor, split_inputs, in_tensor])

# The split (result[1]) must reproduce channel 0 of the input (result[2]).
print(result[2][0, :, :, 0])
print(result[1][0][0, :, :, 0])
# The convolution sums the input channels: compare pixel by pixel with a float
# tolerance instead of testing two floating-point totals for exact equality.
print(np.allclose(np.sum(result[2][0], axis=-1), result[0][0, :, :, 1]))
print(result[0][0, :, :, 1])

[[0.10086262 0.04828131 0.844468   0.6552025  0.37789786]
 [0.4662125  0.7741407  0.97539485 0.79527795 0.94884145]
 [0.36522734 0.20397687 0.3132056  0.45779288 0.04997289]
 [0.65840876 0.42018986 0.25117612 0.14169264 0.6759862 ]
 [0.30540073 0.9572604  0.6609794  0.35055983 0.8883625 ]]
[[0.10086262 0.04828131 0.844468   0.6552025  0.37789786]
 [0.4662125  0.7741407  0.97539485 0.79527795 0.94884145]
 [0.36522734 0.20397687 0.3132056  0.45779288 0.04997289]
 [0.65840876 0.42018986 0.25117612 0.14169264 0.6759862 ]
 [0.30540073 0.9572604  0.6609794  0.35055983 0.8883625 ]]
True
[[1.9197936  0.874279   1.7715013  2.0103426  2.064529  ]
 [2.154733   1.7008486  2.0440488  2.459166   2.1362352 ]
 [1.8843015  0.84527624 1.6288254  1.520882   1.3406955 ]
 [2.0531244  1.1880996  2.0071173  1.3580973  1.431062  ]
 [0.9358636  1.642369   1.1738203  1.6314249  2.2915263 ]]


In [3]:
# Check how Keras applies hand-set weights passed through the `weights=` argument.
# 1x1 filter bank, 3 input channels -> 2 filters: filter 0 all zeros,
# filter 1 all ones, zero bias.
in_tensor = tf.random.uniform(shape=(2, 5, 5, 3), minval=0, maxval=1, seed=0)
f = np.stack([np.ones((1, 1))] * 3, axis=2)
f = np.stack([f * 0, f * 1], axis=3)   # (1, 1, 3, 2)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(input_shape=(5, 5, 3), filters=2, kernel_size=(1, 1),
                                 weights=[f, np.zeros(2)], activation='relu'))
# `in_tensor` already holds the whole batch, so a single step is enough; extra
# steps would evaluate the random tensor again and only add samples that are
# never inspected.
res_keras = model.predict(in_tensor, steps=1)

print(model.get_weights())
print(res_keras[0, :, :, 1])

W0827 15:23:55.262417 140561068492608 deprecation.py:506] From /home/ltetrel/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


[array([[[[0., 1.],
         [0., 1.],
         [0., 1.]]]], dtype=float32), array([0., 0.], dtype=float32)]
[[1.9197936  0.874279   1.7715013  2.0103426  2.064529  ]
 [2.154733   1.7008486  2.0440488  2.459166   2.1362352 ]
 [1.8843015  0.84527624 1.6288254  1.520882   1.3406955 ]
 [2.0531244  1.1880996  2.0071173  1.3580973  1.431062  ]
 [0.9358636  1.642369   1.1738203  1.6314249  2.2915263 ]]


In [4]:
# Standard (not depthwise) Conv3D with hand-set 1x1x1 weights:
# filter 0 is all zeros and filter 1 all ones, so output channel 1 must equal
# the per-voxel sum of the 3 input channels.
# A fixed NumPy input is used (instead of a random tensor, which is re-drawn on
# every evaluation) so the result can be compared with the input directly.
in_array = np.random.RandomState(0).uniform(0, 1, size=(2, 5, 5, 5, 3)).astype(np.float32)
f = np.stack([np.ones((1, 1, 1))] * 3, axis=3)
f = np.stack([f * 0, f * 1], axis=4)   # (1, 1, 1, 3, 2)
print(f.shape)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv3D(input_shape=(5, 5, 5, 3), filters=2, kernel_size=(1, 1, 1),
                                 weights=[f, np.zeros(2)], activation='relu'))
res_keras = model.predict(in_array)

assert np.allclose(res_keras[..., 0], 0)
assert np.allclose(res_keras[..., 1], in_array.sum(axis=-1))

(1, 1, 1, 3, 2)


In [2]:
# Channel-wise 3D convolution test layer.
# Input  (batch, height, width, depth, channels), or
#        (batch, height, width, depth, fmaps, channels) when stacked on another
#        ChannelwiseConv3D.
# Output (batch, new_height, new_width, new_depth, filters, channels)


class ChannelwiseConv3D(tf.keras.layers.Layer):
    """Convolve every input channel separately with one shared Conv3D kernel bank.

    Only channels-last data is supported. The same `kernel` and `bias` are
    applied to each channel, and the per-channel results are stacked on a new
    last axis.

    Args:
        filters: number of output filters (per channel).
        kernel_size: 3 spatial kernel sizes.
        dilation_rate: 3 spatial dilation rates.
        padding: "SAME" or "VALID" (case-insensitive).
        strides: 3 spatial strides.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `weights`, `name`).
    """

    def __init__(self, filters=1, kernel_size=(1, 1, 1), dilation_rate=(1, 1, 1), padding='SAME', strides=(1, 1, 1), **kwargs):
        # Initialise the Keras base class before assigning attributes so that
        # Keras attribute tracking is already set up.
        super(ChannelwiseConv3D, self).__init__(**kwargs)
        self.filters = filters
        # tuple() accepts lists as well as tuples.
        self.kernel_size = tuple(kernel_size)
        # tf.nn.conv3d takes 5-element strides/dilations (1 on batch and channel axes).
        self.dilation_rate = (1,) + tuple(dilation_rate) + (1,)
        # tf.nn.conv3d only accepts the upper-case padding names.
        self.padding = padding.upper()
        self.strides = (1,) + tuple(strides) + (1,)

    def build(self, input_shape):
        # A 6-D input already carries several feature maps per channel (axis -2);
        # a 5-D input is treated as a single feature map per channel.
        self.n_input_fmps = 1
        if len(input_shape) > 5:
            self.n_input_fmps = int(input_shape[-2])

        # One kernel bank, shared by all channels.
        self.kernel = self.add_weight(name='kernel',
                                      shape=self.kernel_size + (self.n_input_fmps,) + (self.filters,),
                                      initializer='glorot_uniform',
                                      trainable=True)
        # NOTE(review): biases are usually zero-initialised; kept as-is.
        self.bias = self.add_weight(name='bias',
                                    shape=(self.filters,),
                                    initializer='glorot_uniform',
                                    trainable=True)
        super(ChannelwiseConv3D, self).build(input_shape)  # must be called last

    def call(self, x):
        is_6d = len(x.shape) > 5
        outputs = []
        for channel in tf.split(x, int(x.shape[-1]), axis=-1):
            if is_6d:
                # (b, h, w, d, fmaps, 1) -> (b, h, w, d, fmaps)
                channel = tf.squeeze(channel, axis=-1)
            # A 5-D input keeps its size-1 split axis as the conv input-channel axis.
            out = tf.nn.conv3d(channel,
                               self.kernel,
                               strides=self.strides,
                               padding=self.padding,
                               dilations=self.dilation_rate)
            outputs.append(tf.nn.bias_add(out, self.bias))
        return tf.stack(outputs, axis=-1)

    def compute_output_shape(self, input_shape):
        # Keep the batch axis, resize the 3 spatial axes according to
        # kernel/stride/dilation/padding, then append (filters, channels).
        shape = tf.TensorShape(input_shape).as_list()
        spatial = [self._output_length(shape[axis], self.kernel_size[axis - 1], self.padding,
                                       self.strides[axis], self.dilation_rate[axis])
                   for axis in (1, 2, 3)]
        return tf.TensorShape([shape[0]] + spatial + [self.filters, shape[-1]])

    def get_config(self):
        # Constructor arguments, so the layer can be re-created from its config.
        config = super(ChannelwiseConv3D, self).get_config()
        config.update({'filters': self.filters,
                       'kernel_size': self.kernel_size,
                       'dilation_rate': self.dilation_rate[1:-1],
                       'padding': self.padding,
                       'strides': self.strides[1:-1]})
        return config

    @staticmethod
    def _output_length(length, kernel, padding, stride, dilation):
        """Spatial output length of a dilated, strided convolution; None stays None."""
        if length is None:
            return None
        if padding == 'VALID':
            # Only positions where the dilated kernel fits entirely are kept.
            length -= (kernel - 1) * dilation
        return (length + stride - 1) // stride  # ceil(length / stride)

In [4]:
# Check ChannelwiseConv3D against hand-computed values.
# conv3da: filter k has weight k (k = 0, 1)          -> out[..., k, c] = k * x[..., c]
# conv3db: filter k has weight k on both input fmaps (k = 0..3)
#          -> out[..., k, c] = k * (0 * x[..., c] + 1 * x[..., c]) = k * x[..., c]
# A fixed NumPy input is used so the outputs can be compared with the exact
# values that produced them (a random tensor is re-drawn on every evaluation).
vol_size = 220
in_array = np.random.RandomState(0).uniform(0, 1, size=(2, vol_size, vol_size, vol_size, 3)).astype(np.float32)
f = np.ones((1, 1, 1, 1))
fa = np.stack([f * 0, f * 1], axis=4)                  # (1, 1, 1, 1, 2)
f = np.ones((1, 1, 1, 2))
fb = np.stack([f * 0, f * 1, f * 2, f * 3], axis=4)    # (1, 1, 1, 2, 4)

inp = tf.keras.Input(shape=in_array.shape[1:])
cw_conv3da = ChannelwiseConv3D(filters=2, kernel_size=(1, 1, 1), weights=[fa, np.zeros(2)])(inp)
cw_conv3db = ChannelwiseConv3D(filters=4, kernel_size=(1, 1, 1), weights=[fb, np.zeros(4)])(cw_conv3da)
model = tf.keras.Model(inputs=inp, outputs=[cw_conv3da, cw_conv3db])
# batch_size=1 bounds memory: at this volume size each sample of the second
# output alone is about 0.5 GB of float32.
res = model.predict(in_array, batch_size=1)
res_conv3da = res[0]
res_conv3db = res[1]

assert res_conv3da.shape == in_array.shape[:4] + (2, 3)
assert res_conv3db.shape == in_array.shape[:4] + (4, 3)
for k in range(2):
    assert np.allclose(res_conv3da[..., k, :], k * in_array)
for k in range(4):
    assert np.allclose(res_conv3db[..., k, :], k * in_array)

In [42]:
# Gaussian smoothing (stride 2) followed by a Laplacian, applied channel-wise
# with ChannelwiseConv3D to a template and a source image.

data_dir = "/home/ltetrel/Documents/data/neuromod/derivatives/deepneuroan/training/generated_data"
out_dir = "/home/ltetrel"


def array_to_image(array, reference, spacing_scale=1):
    """Wrap `array` as a SimpleITK image carrying `reference`'s geometry.

    Origin and direction are copied from `reference`; its spacing is multiplied
    by `spacing_scale` (2 for volumes produced with a stride of 2).
    """
    image = sitk.GetImageFromArray(array)
    image.SetOrigin(reference.GetOrigin())
    image.SetSpacing(tuple(spacing_scale * np.array(reference.GetSpacing())))
    image.SetDirection(reference.GetDirection())
    return image


# Preprocess template and source (normalisation currently disabled) and write
# them back with their original geometry.
template = sitk.ReadImage(data_dir + "/template_on_grid.nii.gz", sitk.sitkFloat64)
volume_data = sitk.GetArrayFromImage(template)
# volume_data = (volume_data - np.mean(volume_data)) / np.std(volume_data)
sitk.WriteImage(array_to_image(volume_data, template), out_dir + "/template_preprocessed.nii.gz")

img = sitk.ReadImage(data_dir + "/ses-vid001_task-video_run-01_bold_vol-001_transfo-000001.nii.gz", sitk.sitkFloat64)
img_data = sitk.GetArrayFromImage(img)
# img_data = (img_data - np.mean(img_data)) / np.std(img_data)
sitk.WriteImage(array_to_image(img_data, img), out_dir + "/source_preprocessed.nii.gz")

# Keras input: a batch of one, template in channel 0 and source in channel 1.
# The shape follows the data instead of being hard-coded.
x = np.stack([volume_data, img_data], axis=-1)[np.newaxis]

# 3x3x3 binomial (Gaussian-like) kernel: outer product of [1, 2, 1] / 4 along
# each axis. Its weights sum to 1, so smoothing preserves the intensity scale.
binomial = np.array([1., 2., 1.]) / 4
gaussian_kernel = (binomial[:, None, None]
                   * binomial[None, :, None]
                   * binomial[None, None, :]).reshape(3, 3, 3, 1, 1)

# Negated 7-point Laplacian: 6 at the centre, -1 on the 6 face neighbours (sums to 0).
laplacian_kernel = np.zeros((3, 3, 3, 1, 1))
laplacian_kernel[:, :, 0, 0, 0] = [[0, 0, 0], [0, -1, 0], [0, 0, 0]]
laplacian_kernel[:, :, 1, 0, 0] = [[0, -1, 0], [-1, 6, -1], [0, -1, 0]]
laplacian_kernel[:, :, 2, 0, 0] = [[0, 0, 0], [0, -1, 0], [0, 0, 0]]

inp = tf.keras.Input(shape=x.shape[1:])
gauss = ChannelwiseConv3D(filters=1, strides=(2, 2, 2), kernel_size=(3, 3, 3), weights=[gaussian_kernel, np.zeros(1)])(inp)
laplacian = ChannelwiseConv3D(filters=1, strides=(1, 1, 1), kernel_size=(3, 3, 3), weights=[laplacian_kernel, np.zeros(1)])(gauss)
model = tf.keras.Model(inputs=inp, outputs=[gauss, laplacian])
res = model.predict(x, steps=1)

# Outputs are (batch, 3 spatial axes, filters, channels); both are on the
# stride-2 grid, hence the doubled spacing.
sitk.WriteImage(array_to_image(res[0][0, :, :, :, 0, 0], template, 2), out_dir + "/template_gaussian.nii.gz")
sitk.WriteImage(array_to_image(res[1][0, :, :, :, 0, 0], template, 2), out_dir + "/template_laplace.nii.gz")
sitk.WriteImage(array_to_image(res[0][0, :, :, :, 0, 1], img, 2), out_dir + "/source_gaussian.nii.gz")
# NOTE(review): "glaplace" looks like a typo of "laplace"; the file name is kept
# so anything already reading it still finds it.
sitk.WriteImage(array_to_image(res[1][0, :, :, :, 0, 1], img, 2), out_dir + "/source_glaplace.nii.gz")

In [28]:
# Voxel spacing of `img` doubled, i.e. the spacing of a stride-2 grid.
tuple(np.multiply(img.GetSpacing(), 2))

(2.0, 2.0, 2.0)

In [156]:
# Channel-wise 3D max pooling test layer.
# Input  (batch, height, width, depth, filters, channels), or
#        (batch, height, width, depth, channels), which gains a size-1 filters axis.
# Output (batch, new_height, new_width, new_depth, filters, channels)


class ChannelwiseMaxpool3D(tf.keras.layers.Layer):
    """Max-pool every input channel independently (channels-last only).

    Args:
        pool_size: 3 spatial window sizes.
        padding: "SAME" or "VALID" (case-insensitive).
        strides: 3 spatial strides; defaults to `pool_size` (non-overlapping windows).
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, pool_size=(1, 1, 1), padding="SAME", strides=None, **kwargs):
        # Initialise the Keras base class before assigning attributes.
        super(ChannelwiseMaxpool3D, self).__init__(**kwargs)
        self.pool_size = tuple(pool_size)
        # tf.nn.max_pool3d only accepts the upper-case padding names.
        self.padding = padding.upper()
        self.strides = self.pool_size if strides is None else tuple(strides)

    def build(self, input_shape):
        # No weights to create.
        super(ChannelwiseMaxpool3D, self).build(input_shape)  # must be called last

    def call(self, x):
        is_6d = len(x.shape) > 5
        pooled = []
        for channel in tf.split(x, int(x.shape[-1]), axis=-1):
            if is_6d:
                # (b, h, w, d, filters, 1) -> (b, h, w, d, filters)
                channel = tf.squeeze(channel, axis=-1)
            # A 5-D input keeps its size-1 split axis, which becomes the filters axis.
            pooled.append(tf.nn.max_pool3d(channel,
                                           ksize=self.pool_size,
                                           strides=self.strides,
                                           padding=self.padding))
        return tf.stack(pooled, axis=-1)

    def compute_output_shape(self, input_shape):
        # (batch, 3 pooled spatial axes, filters, channels); a 5-D input gets a
        # size-1 filters axis, as produced by `call`.
        shape = tf.TensorShape(input_shape).as_list()
        spatial = [self._output_length(shape[axis], self.pool_size[axis - 1],
                                       self.padding, self.strides[axis - 1])
                   for axis in (1, 2, 3)]
        n_filters = shape[4] if len(shape) > 5 else 1
        return tf.TensorShape([shape[0]] + spatial + [n_filters, shape[-1]])

    def get_config(self):
        # Constructor arguments, so the layer can be re-created from its config.
        config = super(ChannelwiseMaxpool3D, self).get_config()
        config.update({'pool_size': self.pool_size,
                       'padding': self.padding,
                       'strides': self.strides})
        return config

    @staticmethod
    def _output_length(length, window, padding, stride):
        """Spatial output length of a strided pooling; None stays None."""
        if length is None:
            return None
        if padding == 'VALID':
            # Only positions where the whole window fits are kept.
            length -= window - 1
        return (length + stride - 1) // stride  # ceil(length / stride)
    
# Check ChannelwiseMaxpool3D against a NumPy reference on a fixed input
# (a random tensor would be re-drawn on every evaluation).
in_array = np.random.RandomState(0).uniform(0, 1, size=(2, 256, 256, 256, 2)).astype(np.float32)

inp = tf.keras.Input(shape=in_array.shape[1:])
cw_pool3da = ChannelwiseMaxpool3D(pool_size=(2, 2, 2), padding="VALID")(inp)
model = tf.keras.Model(inputs=inp, outputs=[cw_pool3da])
res = model.predict(in_array, batch_size=1)

# Reference: max over each non-overlapping 2x2x2 block; the 5-D input gains a
# size-1 filters axis just before the channel axis.
n_batch, n_h, n_w, n_d, n_ch = in_array.shape
expected = in_array.reshape(n_batch, n_h // 2, 2, n_w // 2, 2, n_d // 2, 2, n_ch).max(axis=(2, 4, 6))
assert np.array_equal(res, expected[..., np.newaxis, :])
res.shape

(2, 128, 128, 128, 1, 2)

In [154]:
# Benchmark: channel-wise max pooling done by folding (filters, channels) into a
# single axis, pooling with the built-in MaxPool3D, then unfolding.
# Input  (batch, height, width, depth, filters, channels)
# Output (batch, new_height, new_width, new_depth, filters, channels)
in_tensor = tf.random.uniform(shape=(2, 5, 5, 5, 2, 3), minval=0, maxval=1, seed=0)

# Timing starts before model construction, so it includes graph building.
tic = time.time()

inp = tf.keras.Input(shape=(5, 5, 5, 2, 3))
inp_dims = inp.shape.as_list()
if len(inp_dims) > 5:
    # MaxPool3D pools every last-axis entry independently, so folding
    # (filters, channels) into one axis keeps the pooling channel-wise.
    folded = tf.keras.backend.reshape(x=inp, shape=[-1] + inp_dims[1:4] + [inp_dims[4] * inp_dims[5]])
    pooled = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding="VALID")(folded)
    # Unfold using the pooled spatial shape rather than hard-coded sizes.
    output = tf.keras.backend.reshape(x=pooled, shape=[-1] + pooled.shape.as_list()[1:4] + inp_dims[4:6])
else:
    # 5-D input: plain channel-wise pooling, nothing to fold.
    output = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding="VALID")(inp)
model = tf.keras.Model(inputs=inp, outputs=[output])

res = model.predict(in_tensor, steps=1000)

elapsed = time.time() - tic
print("%1.2f s" % (elapsed))

0.42 s
