In [1]:
#!/usr/bin/python
"""
Class and helper functions for fitting the DeepN4 model in TF/Keras.
"""
import logging
import numpy as np

from dipy.data import get_fnames
from dipy.testing.decorators import doctest_skip_parser
from dipy.utils.optpkg import optional_package
from dipy.nn.utils import normalize, unnormalize, set_logger_level

tf, have_tf, _ = optional_package('tensorflow')#, min_version='2.0.0')
tfa, have_tfa, _ = optional_package('tensorflow_addons')
if have_tf and have_tfa:
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import MaxPool3D, Conv3DTranspose
    from tensorflow.keras.layers import Conv3D, LeakyReLU
    from tensorflow.keras.layers import Concatenate, Layer
    from tensorflow_addons.layers import InstanceNormalization
else:
    class Model:
        pass

    class Layer:
        pass
    logging.warning('This model requires Tensorflow and Tensorflow\
                    -addons. Please install these packages using \
                    pip. If using mac, please refer to this \
                    link for installation. \
                    https://github.com/apple/tensorflow_macos')

logging.basicConfig()
logger = logging.getLogger('synb0')

class EncoderBlock(Layer):
    def __init__(self, out_channels, kernel_size, strides, padding):
        super(EncoderBlock, self).__init__()
        self.conv3d = Conv3D(out_channels,
                             kernel_size,
                             strides=strides,
                             padding=padding,
                             use_bias=False)
        self.instnorm = InstanceNormalization(axis=-1, center=False, scale=False)
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        x = self.conv3d(input)
        x = self.instnorm(x)
        x = self.activation(x)

        return x

class DecoderBlock(Layer):
    def __init__(self, out_channels, kernel_size, strides, padding):
        super(DecoderBlock, self).__init__()
        self.conv3d = Conv3DTranspose(out_channels,
                                      kernel_size,
                                      strides=strides,
                                      padding=padding,
                                      use_bias=False)
        self.instnorm = InstanceNormalization(axis=-1, center=False, scale=False)
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        x = self.conv3d(input)
        x = self.instnorm(x)
        x = self.activation(x)

        return x

def UNet3D(input_shape):
    r"""
    Function to create model for Synb0

    Parameters
    ----------
    input_shape : tuple
        The input shape of the model

    Returns
    -------
    tf.keras.Model
    """
    inputs = tf.keras.Input(input_shape)
    # Encode
    x = EncoderBlock(32, kernel_size=3,
                     strides=1, padding='same')(inputs)
    syn0 = EncoderBlock(64, kernel_size=3,
                        strides=1, padding='same')(x)

    x = MaxPool3D()(syn0)
    x = EncoderBlock(64, kernel_size=3,
                     strides=1, padding='same')(x)
    syn1 = EncoderBlock(128, kernel_size=3,
                        strides=1, padding='same')(x)

    x = MaxPool3D()(syn1)
    x = EncoderBlock(128, kernel_size=3,
                     strides=1, padding='same')(x)
    syn2 = EncoderBlock(256, kernel_size=3,
                        strides=1, padding='same')(x)

    x = MaxPool3D()(syn2)
    x = EncoderBlock(256, kernel_size=3,
                     strides=1, padding='same')(x)
    x = EncoderBlock(512, kernel_size=3,
                     strides=1, padding='same')(x)

    # Last layer without relu
    x = Conv3D(512, kernel_size=1,
               strides=1, padding='same')(x)

    x = DecoderBlock(512, kernel_size=2,
                     strides=2, padding='valid')(x)

    x = Concatenate()([x, syn2])

    x = DecoderBlock(256, kernel_size=3,
                     strides=1, padding='same')(x)
    x = DecoderBlock(256, kernel_size=3,
                     strides=1, padding='same')(x)
    x = DecoderBlock(256, kernel_size=2,
                     strides=2, padding='valid')(x)

    x = Concatenate()([x, syn1])

    x = DecoderBlock(128, kernel_size=3,
                     strides=1, padding='same')(x)
    x = DecoderBlock(128, kernel_size=3,
                     strides=1, padding='same')(x)
    x = DecoderBlock(128, kernel_size=2,
                     strides=2, padding='valid')(x)

    x = Concatenate()([x, syn0])

    x = DecoderBlock(64, kernel_size=3,
                     strides=1, padding='same')(x)
    x = DecoderBlock(64, kernel_size=3,
                     strides=1, padding='same')(x)

    x = DecoderBlock(1, kernel_size=1,
                     strides=1, padding='valid')(x)

    # Last layer without relu
    out = Conv3DTranspose(1, kernel_size=1,
                          strides=1, padding='valid')(x)

    return Model(inputs, out)

keras_model = UNet3D(input_shape=(128, 128, 128, 1))
keras_model.summary()

  from .autonotebook import tqdm as notebook_tqdm
2024-02-16 09:54:44.824430: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-16 09:54:44.826248: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-16 09:54:44.858694: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-16 09:54:44.859501: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

TensorFlow Addons (TFA) has ended development and i

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 128, 128, 1)]   0         []                            
                                                                                                  
 encoder_block (EncoderBloc  (None, 128, 128, 128, 32)    864       ['input_1[0][0]']             
 k)                                                                                               
                                                                                                  
 encoder_block_1 (EncoderBl  (None, 128, 128, 128, 64)    55296     ['encoder_block[0][0]']       
 ock)                                                                                             
                                                                                              

In [2]:
"""
The DeepN4 model.
"""
import torch
from torch import nn

class Synbo_UNet3D(nn.Module):
    def __init__(self, n_in, n_out):
        super(Synbo_UNet3D, self).__init__()
        # Encoder
        c = 32
        self.ec0 = self.encoder_block(      n_in,    1*c, kernel_size=3, stride=1, padding=1)
        self.ec1 = self.encoder_block(        c,    c*2, kernel_size=3, stride=1, padding=1)
        self.pool0 = nn.MaxPool3d(2)
        self.ec2 = self.encoder_block(        c*2,    c*2, kernel_size=3, stride=1, padding=1)
        self.ec3 = self.encoder_block(        c*2,   c*4, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool3d(2)
        self.ec4 = self.encoder_block(       c*4,   c*4, kernel_size=3, stride=1, padding=1)
        self.ec5 = self.encoder_block(       c*4,   c*8, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool3d(2)
        self.ec6 = self.encoder_block(       c*8,   c*8, kernel_size=3, stride=1, padding=1)
        self.ec7 = self.encoder_block(       c*8,   c*16, kernel_size=3, stride=1, padding=1)
        self.el  =          nn.Conv3d(       c*16,   c*16, kernel_size=1, stride=1, padding=0)

        # Decoder
        self.dc9 = self.decoder_block(       c*16,   c*16, kernel_size=2, stride=2, padding=0)
        self.dc8 = self.decoder_block( c*16 + c*8,   c*8, kernel_size=3, stride=1, padding=1)
        self.dc7 = self.decoder_block(       c*8,   c*8, kernel_size=3, stride=1, padding=1)
        self.dc6 = self.decoder_block(       c*8,   c*8, kernel_size=2, stride=2, padding=0)
        self.dc5 = self.decoder_block( c*8 + c*4,   c*4, kernel_size=3, stride=1, padding=1)
        self.dc4 = self.decoder_block(       c*4,   c*4, kernel_size=3, stride=1, padding=1)
        self.dc3 = self.decoder_block(       c*4,   c*4, kernel_size=2, stride=2, padding=0)
        self.dc2 = self.decoder_block(  c*4 + c*2,    c*2, kernel_size=3, stride=1, padding=1)
        self.dc1 = self.decoder_block(        c*2,    c*2, kernel_size=3, stride=1, padding=1)
        self.dc0 = self.decoder_block(        c*2, n_out, kernel_size=1, stride=1, padding=0)
        self.dl  = nn.ConvTranspose3d(     n_out, n_out, kernel_size=1, stride=1, padding=0)
        # self.act  = nn.LazyLinear(n_out)
        # self.act = nn.Linear(n_out, n_out)
        # self.dl  = BSplineLayer(     4, 4, n_bases=6, shared_weights=True,bias=False, weighted_sum=False)#, kernel_size=1, stride=1, padding=0)

    def encoder_block(self, in_channels, out_channels, kernel_size, stride, padding):
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU())
        return layer

    def decoder_block(self, in_channels, out_channels, kernel_size, stride, padding):
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU())
        return layer

    def forward(self, x):
        # Encodes
        e0   = self.ec0(x)
        syn0 = self.ec1(e0)
        del e0

        e1   = self.pool0(syn0)
        e2   = self.ec2(e1)
        syn1 = self.ec3(e2)
        del e1, e2

        e3   = self.pool1(syn1)
        e4   = self.ec4(e3)
        syn2 = self.ec5(e4)
        del e3, e4

        e5   = self.pool2(syn2)
        e6   = self.ec6(e5)
        e7   = self.ec7(e6)

        # Last layer without relu
        el   = self.el(e7)
        del e5, e6, e7

        # Decode
        d9   = torch.cat((self.dc9(el), syn2), 1)
        del el, syn2

        d8   = self.dc8(d9)
        d7   = self.dc7(d8)
        del d9, d8

        d6   = torch.cat((self.dc6(d7), syn1), 1)
        del d7, syn1

        d5   = self.dc5(d6)
        d4   = self.dc4(d5)
        del d6, d5

        d3   = torch.cat((self.dc3(d4), syn0), 1)
        del d4, syn0

        d2   = self.dc2(d3)
        d1   = self.dc1(d2)
        del d3, d2

        d0   = self.dc0(d1) 
        del d1

        # Last layer without relu
        out  = self.dl(d0)#. reshape(-1, 1*128*128*128)
        return out #self.act(out)



In [3]:
"Load Pytorch model weights"
# Assuming `pytorch_model` is your model's architecture in PyTorch
pytorch_model = Synbo_UNet3D(1,1)
checkpoint =  torch.load("/home/local/VANDERBILT/kanakap/Downloads/trained_weights_checkpoint")
model_state_dict = checkpoint['model_state_dict']
pytorch_model.eval()  # Set the model to evaluation mode


Synbo_UNet3D(
  (ec0): Sequential(
    (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (ec1): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (pool0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (ec2): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (ec3): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): InstanceN

In [4]:
"Load encoder weights"
encoder_keys = ['ec0.0.weight', 'ec1.0.weight', 'ec2.0.weight', 'ec3.0.weight', 'ec4.0.weight', 'ec5.0.weight', 'ec6.0.weight', 'ec7.0.weight']
encoder_layers_keras = ['encoder_block','encoder_block_1', 'encoder_block_2','encoder_block_3', 'encoder_block_4','encoder_block_5', 'encoder_block_6','encoder_block_7']

for i,j in zip(encoder_keys,encoder_layers_keras):
    py_weights = model_state_dict[i].cpu()  # This should match your PyTorch model's layer
    print(np.shape(py_weights)) # (out_channels, in_channels, D, H, W)
    py_weights = py_weights.permute(2, 3, 4, 1, 0).numpy()  # Transpose to match Keras/TensorFlow format: (D, H, W, in_channels, out_channels)
    print(np.shape(py_weights))

    keras_layer = keras_model.get_layer(name=j)
    keras_layer.set_weights([py_weights])

    

torch.Size([32, 1, 3, 3, 3])
(3, 3, 3, 1, 32)
torch.Size([64, 32, 3, 3, 3])
(3, 3, 3, 32, 64)
torch.Size([64, 64, 3, 3, 3])
(3, 3, 3, 64, 64)
torch.Size([128, 64, 3, 3, 3])
(3, 3, 3, 64, 128)
torch.Size([128, 128, 3, 3, 3])
(3, 3, 3, 128, 128)
torch.Size([256, 128, 3, 3, 3])
(3, 3, 3, 128, 256)
torch.Size([256, 256, 3, 3, 3])
(3, 3, 3, 256, 256)
torch.Size([512, 256, 3, 3, 3])
(3, 3, 3, 256, 512)


In [5]:
"Load decoder weights"
deencoder_keys = ['dc9.0.weight', 'dc8.0.weight', 'dc7.0.weight', 'dc6.0.weight', 'dc5.0.weight', 'dc4.0.weight', 'dc3.0.weight', 'dc2.0.weight', 'dc1.0.weight', 'dc0.0.weight']
deencoder_layers_keras = ['decoder_block','decoder_block_1', 'decoder_block_2','decoder_block_3', 'decoder_block_4','decoder_block_5', 'decoder_block_6','decoder_block_7']


for i, j  in zip(deencoder_keys,deencoder_layers_keras):
    py_weights = model_state_dict[i].cpu()  # This should match your PyTorch model's layer
    py_weights = py_weights.permute(2, 3, 4, 1, 0).numpy()
    keras_layer = keras_model.get_layer(name=j)
    keras_layer.set_weights([py_weights])

In [6]:
py_weights = model_state_dict['el.weight'].cpu()  # This should match your PyTorch model's layer
py_weights = py_weights.permute(2, 3, 4, 1, 0).numpy()
py_bias = model_state_dict['el.bias'].cpu().numpy() 

# keras_layer = keras_model.get_layer(name="conv3d_8")
keras_layer = keras_model.get_layer(name="conv3d_8")
keras_layer.set_weights([py_weights, py_bias])

In [7]:
py_weights = model_state_dict['dl.weight'].cpu()  # This should match your PyTorch model's layer
py_weights = py_weights.permute(2, 3, 4, 1, 0).numpy()
py_bias = model_state_dict['dl.bias'].cpu().numpy() 

# keras_layer = keras_model.get_layer(name="conv3d_transpose_10")
keras_layer = keras_model.get_layer(name="conv3d_transpose_10")
keras_layer.set_weights([py_weights, py_bias])

In [8]:
"Save model"
keras_model.save("/nfs/masi/kanakap/projects/DeepN4/src/trained_model_tf/model_weights5.h5")




  saving_api.save_model(


In [9]:
"Load keras model"
keras_model.load_weights("/nfs/masi/kanakap/projects/DeepN4/src/trained_model_tf/model_weights5.h5")
keras_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 128, 128, 1)]   0         []                            
                                                                                                  
 encoder_block (EncoderBloc  (None, 128, 128, 128, 32)    864       ['input_1[0][0]']             
 k)                                                                                               
                                                                                                  
 encoder_block_1 (EncoderBl  (None, 128, 128, 128, 64)    55296     ['encoder_block[0][0]']       
 ock)                                                                                             
                                                                                              

In [10]:
"Test keras model"
random_array = np.random.rand(1, 128, 128, 128, 1 )
keras_model.predict(random_array)



array([[[[[-2.6313016 ],
          [-2.6053243 ],
          [-2.603585  ],
          ...,
          [-2.6133199 ],
          [-2.5287077 ],
          [-2.5490243 ]],

         [[-2.5788512 ],
          [-1.2045957 ],
          [-2.5452304 ],
          ...,
          [-2.4989452 ],
          [-1.9940536 ],
          [ 4.5190725 ]],

         [[-2.5580876 ],
          [ 1.905601  ],
          [-2.5161934 ],
          ...,
          [ 2.6142788 ],
          [ 4.678073  ],
          [ 6.6737056 ]],

         ...,

         [[-2.5260692 ],
          [ 7.5223565 ],
          [-2.2425637 ],
          ...,
          [ 9.645262  ],
          [ 6.4849277 ],
          [ 9.440056  ]],

         [[-2.5516343 ],
          [ 8.369431  ],
          [ 1.5200534 ],
          ...,
          [ 6.327391  ],
          [ 7.257959  ],
          [ 5.6582713 ]],

         [[ 6.983667  ],
          [ 4.9142737 ],
          [ 2.4202938 ],
          ...,
          [ 4.82819   ],
          [ 5.6051173 ],
          