First, create a shortcut to the /datasets folder in your personal Google Drive.

In [None]:
# Mount drive if needed
# force_remount=True remounts Google Drive even if /content/drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive



## Install dependencies

In [None]:
! pip install SimpleITK
! pip install antspyx


## Load images to current session

In [None]:
! mkdir -v data
! unzip "/content/drive/MyDrive/integradora_fiec/datasets/NATIVE_FILTERED_MANUALLY.zip" -d "/data"

## Preprocessing functions

In [None]:
%matplotlib inline
import os

import ants
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

# Record the library versions used, for reproducibility.
print(f'AntsPy version = {ants.__version__}')
print(f'SimpleITK version = {sitk.__version__}')

AntsPy version = 0.3.8
SimpleITK version = 2.2.1


In [None]:
# MNI ICBM152 T1 template (per the file name) used as the fixed image for registration.
# Functions below read TEMPLATE_PATH; mni_T1_path is only an alias of the same string.
mni_T1_path = TEMPLATE_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a.nii'

def load_template_ants() -> ants.ANTsImage:
    """
    Read the registration template stored at TEMPLATE_PATH with ANTsPy.

    Returns:
        ants.ANTsImage: The template volume.
    """
    return ants.image_read(TEMPLATE_PATH)

def load_img_ants(path: str) -> ants.ANTsImage:
    """
    Read an image file from disk with ANTsPy.

    Args:
        path (str): Location of the image file (NIfTI volumes in this notebook).

    Returns:
        ants.ANTsImage: The image stored at ``path``.
    """
    return ants.image_read(path)

def register_to_mni(
    img: ants.ANTsImage,
    mask: ants.ANTsImage,
    interpolator: str = 'linear',
) -> tuple[ants.ANTsImage, ants.ANTsImage]:
    """
    Register an MRI image to the MNI template and carry its mask along.

    ``img`` is registered to the template returned by ``load_template_ants``
    with a SyN transform. The forward transforms found by that registration
    are then applied to ``mask`` so that both end up aligned in MNI space.

    Args:
        img (ants.ANTsImage): The MRI image to be registered (moving image).
        mask (ants.ANTsImage): The mask associated with ``img``.
        interpolator (str): Interpolator used when resampling ``mask``.
            Defaults to ``'linear'`` (ANTs' default and the previous
            behaviour), which can produce fractional values along mask
            borders. Pass ``'nearestNeighbor'`` or ``'genericLabel'`` to keep
            a binary mask binary.

    Returns:
        tuple[ants.ANTsImage, ants.ANTsImage]: ``(img_registered, mask_registered)``,
        the image and its mask in MNI space.
    """
    template_img = load_template_ants()
    transformation = ants.registration(fixed=template_img, moving=img, type_of_transform='SyN')

    # The registration result already contains the warped moving image.
    img_registered = transformation['warpedmovout']

    # Re-use the same forward transforms so the mask stays aligned with the image.
    mask_registered = ants.apply_transforms(
        fixed=template_img,
        moving=mask,
        transformlist=transformation['fwdtransforms'],
        interpolator=interpolator,
    )
    return img_registered, mask_registered

## Register

In [None]:
from glob import glob

# Native T1 images (x) and their lesion masks (y). Both lists are sorted and
# later paired positionally with zip(), so they must list the same subjects.
# NOTE(review): only the counts are checked here, not that each pair matches.
xpaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth.nii.gz'))
assert len(xpaths) == len(ypaths)

In [None]:
# Print the tail of each (image, mask) path pair to eyeball that they line up.
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

In [None]:
# Register every (image, mask) pair to MNI space and write both results next
# to the original image.
for i, (xpath, ypath) in enumerate(zip(xpaths, ypaths)):
  # Input path without its 7-character '.nii.gz' extension.
  stem = xpath[:-7]
  x_registered_path = stem + '_registered.nii.gz'
  y_registered_path = stem + '_LesionSmooth_registered.nii.gz'

  x3d = load_img_ants(xpath)
  y3d = load_img_ants(ypath)
  x3d_registered, y3d_registered = register_to_mni(img=x3d, mask=y3d)

  for out_path in (x_registered_path, y_registered_path):
    print(i, out_path)

  x3d_registered.to_file(x_registered_path)
  y3d_registered.to_file(y_registered_path)


## Bias Field Correction

In [None]:
# Outputs of the registration step: registered images (x) and registered masks (y).
xpaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01_registered.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)

In [None]:
# Print the tail of each registered (image, mask) path pair for a visual check.
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

In [None]:
def bias_field_correction(img: sitk.Image, shrink_factor: int = 4) -> sitk.Image:
    """
    Perform bias field correction on the input image using N4.

    The bias field is estimated on copies of the image and of a rough head
    mask that are downsampled by ``shrink_factor`` (to save time). It is then
    evaluated on the full-resolution grid of ``img`` and divided out.

    Args:
        img (sitk.Image): The input image to be bias corrected.
        shrink_factor (int): Downsampling factor applied along every axis
            before estimating the bias field. Defaults to 4.

    Returns:
        sitk.Image: The bias-corrected image, cast to float32.
    """
    # Rough head mask: rescale to 0..255, then Li thresholding with
    # insideValue=0 / outsideValue=1.
    head_mask = sitk.RescaleIntensity(img, 0, 255)
    head_mask = sitk.LiThreshold(head_mask, 0, 1)

    shrink = [shrink_factor] * img.GetDimension()
    small_img = sitk.Shrink(img, shrink)
    small_mask = sitk.Shrink(head_mask, shrink)

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(small_img, small_mask)

    # Evaluate the estimated log bias field at the full resolution of img.
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(img)
    result = img / sitk.Exp(log_bias_field)

    # output of division has 64 pixel type, we cast it to float32 to keep compatibility
    return sitk.Cast(result, sitk.sitkFloat32)

def load_img_sitk(path: str) -> sitk.Image:
    """
    Read an image file with SimpleITK, converting its voxels to float32.

    Args:
        path (str): Location of the image file.

    Returns:
        sitk.Image: The image stored at ``path`` with pixel type sitkFloat32.
    """
    return sitk.ReadImage(path, sitk.sitkFloat32)


In [None]:
for i, (xpath, ypath) in enumerate(zip(xpaths, ypaths)):
  # The corrected image is written next to the input as
  # <input path without '.nii.gz'>_BF.nii.gz
  x_out_path = f"{xpath[:-7]}_BF.nii.gz"

  x3d = load_img_sitk(xpath)
  x3d_bf_corrected = bias_field_correction(x3d)
  sitk.WriteImage(x3d_bf_corrected, x_out_path)

  # Progress report: sample index and output path.
  print(i, x_out_path)



## Prepare training data

In [None]:
# Training pairs: bias-corrected registered images (x) and registered masks (y).
xpaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01_registered_BF.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)

In [None]:
# Print the tail of each training (image, mask) path pair for a visual check.
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

In [None]:
# Load the MNI152 brain mask and T1 template as float32 images. preprocess_ximg
# uses the template as the histogram-matching reference and multiplies by the mask.
TEMPLATE_BRAIN_MASK_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a_mask.nii'
mni152_brain_mask = sitk.ReadImage(TEMPLATE_BRAIN_MASK_PATH, sitk.sitkFloat32)
mni152_T1 = sitk.ReadImage(TEMPLATE_PATH, sitk.sitkFloat32)


In [None]:
def preprocess_ximg(ximg: sitk.Image, flipped = False) -> np.ndarray:
  """
  Turn a registered, bias-corrected volume into a cropped network input.

  Pipeline: histogram matching to the MNI T1 template, multiplication by the
  MNI brain mask, one iteration of curvature anisotropic diffusion, an
  optional mirror along the first image axis, cropping, scaling by 1/255,
  and appending a channel axis.

  Args:
      ximg (sitk.Image): The input image in SimpleITK format.
      flipped (bool, optional): If True, mirror the volume along its first
          SimpleITK axis (the last NumPy axis) before cropping.
          Defaults to False.

  Returns:
      np.ndarray: Array of shape (130, 224, 176, 1).
  """
  matched = sitk.HistogramMatching(ximg, mni152_T1)
  brain_only = sitk.Multiply(matched, mni152_brain_mask)
  smoothed = sitk.CurvatureAnisotropicDiffusion(brain_only, conductanceParameter=1, numberOfIterations=1)

  if flipped:
    smoothed = sitk.Flip(smoothed, (True, False, False))

  # NumPy axis order is (z, y, x); crop to (130, 224, 176), then add a channel axis.
  volume = sitk.GetArrayFromImage(smoothed)[30:160, 4:228, 14:190] / 255.0
  volume = volume[..., np.newaxis]
  assert volume.shape == (130, 224, 176, 1)
  return volume

def preprocess_yimg(yimg: sitk.Image, flipped=False) -> np.ndarray:
  """
  Preprocess the target (lesion mask) image to match preprocess_ximg's output.

  Applies the same optional mirror and the same crop as preprocess_ximg, so
  that image and mask slices stay aligned, then divides by 255 and appends a
  channel axis.

  Args:
      yimg (sitk.Image): The target image in SimpleITK format.
      flipped (bool, optional): If True, mirror the volume along its first
          SimpleITK axis before cropping. Defaults to False.

  Returns:
      np.ndarray: The preprocessed target of shape (130, 224, 176, 1).
  """
  y3d = yimg

  if flipped:
    y3d = sitk.Flip(y3d, (True, False, False))

  y3d = sitk.GetArrayFromImage(y3d)
  y3d = y3d[30:160, 4:228, 14:190]  # crop to size -> (130, 224, 176)
  # NOTE(review): assumes mask voxels are stored on a 0..255 scale - confirm.
  y3d = y3d / 255.0
  y3d = np.expand_dims(y3d, 3)  # add channel -> (130, 224, 176, 1)
  # Validate this function's own output. The previous version asserted on the
  # unrelated global `x3d`, so a wrongly shaped mask went unnoticed.
  assert y3d.shape == (130, 224, 176, 1)
  return y3d


In [None]:
ROW_SIZE = 224 # Height of the model input
COL_SIZE = 176 # Width of the model input

# Collect per-subject arrays and concatenate them once after the loop.
# Growing X/Y with np.concatenate inside the loop re-copied everything
# accumulated so far on every iteration (quadratic copying). The empty
# float32 seed keeps the resulting shape/dtype well defined even with no paths.
x_chunks = [np.empty((0,ROW_SIZE,COL_SIZE,1), dtype=np.float32)]
y_chunks = [np.empty((0,ROW_SIZE,COL_SIZE,1), dtype=np.float32)]

for i,(xpath, ypath) in enumerate(zip(xpaths, ypaths)):

    ximg        =   sitk.ReadImage(xpath, sitk.sitkFloat32)
    x3d         =  preprocess_ximg(ximg)
    flipped_x3d =  preprocess_ximg(ximg, flipped=True)

    yimg        =   sitk.ReadImage(ypath, sitk.sitkFloat32)
    y3d         =  preprocess_yimg(yimg)
    flipped_y3d =  preprocess_yimg(yimg, flipped=True)

    # Stack original and mirrored slices along the sample axis.
    x3d = np.concatenate((x3d, flipped_x3d), axis=0)
    y3d = np.concatenate((y3d, flipped_y3d), axis=0)

    # Each subject contributes 2 * 130 = 260 slices.
    assert x3d.shape  == (260,224,176, 1)
    assert y3d.shape  == (260,224,176, 1)

    x_chunks.append(x3d)
    y_chunks.append(y3d)

    print('.', end='')

X = np.concatenate(x_chunks, axis=0)  # Preprocessed input slices
Y = np.concatenate(y_chunks, axis=0)  # Preprocessed label slices

....................

In [None]:
# Both should be (number_of_subjects * 260, 224, 176, 1).
print(X.shape, Y.shape)

(5200, 224, 176, 1) (5200, 224, 176, 1)


In [None]:
# Shape of the inputs with the single channel axis dropped.
X[:,:,:,0].shape

(5200, 224, 176)

## Double check slices

In [None]:
def get_x2d_marked(x2d,y2d):
  """
  Return a copy of x2d with the outline of the (dilated) label y2d burned in.

  The label is cast to uint8, dilated with kernel radius (4, 1, 1), and
  reduced to its binary contour. MaskNegated then sets the contour pixels of
  x2d to its default outside value (0).
  NOTE(review): astype('uint8') truncates, so label values below 1.0 become 0.

  Args:
      x2d (np.ndarray): The 2D input image.
      y2d (np.ndarray): The 2D label image.

  Returns:
      np.ndarray: The marked version of the input image.
  """
  dilation_level = 4
  outline = sitk.GetImageFromArray(y2d.astype('uint8'))
  outline = sitk.BinaryContour(sitk.BinaryDilate(outline, (dilation_level, 1, 1)))

  image = sitk.GetImageFromArray(x2d)
  marked = sitk.MaskNegated(image, sitk.Cast(outline, sitk.sitkFloat32))
  return sitk.GetArrayFromImage(marked)

def show_slices(slices: list[np.ndarray], cmap: str ='gray'):
  """
  Display a list of 2D image slices stacked vertically in a single figure.

  Args:
      slices (list[np.ndarray]): 2D arrays to show, one subplot row each.
          An empty list draws nothing.
      cmap (str, optional): The colormap to be used for visualization. Defaults to 'gray'.
  """
  if not slices:
    return
  # squeeze=False always yields a 2D array of Axes. With the default, a single
  # slice returns a bare Axes object and indexing it (axes[0]) raises.
  _, axes = plt.subplots(len(slices), 1, figsize=(15, 15), squeeze=False)
  for ax, image in zip(axes[:, 0], slices):
    ax.imshow(image, cmap=cmap)

In [None]:
# Visually check up to 10 slices (sampling every STEPS-th slice of X) whose
# label is not constant, i.e. slices that contain some lesion voxels.
STEPS = 150
c=0  # number of slices shown so far
for i in range(0,len(X),STEPS):
  x, y = X[i], Y[i]
  if len(np.unique(y)) == 1:  # label is all one value -> nothing to outline
    continue
  x2d_marked = get_x2d_marked(x[:,:,0],y[:,:,0])
  show_slices([x2d_marked,x[:,:,0]])
  c+=1
  if c==10:
    break

Output hidden; open in https://colab.research.google.com to view.

## Save training dataset as npy

In [None]:
from numpy import save
# Contains data processed from 20 native ATLAS images through: registration to MNI, bias field correction, histogram matching, brain extraction, denoising.
X_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_X.npy'
Y_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_Y.npy'


# Arrays are written uncompressed in NumPy .npy format.
save(X_output_path, X)
save(Y_output_path, Y)

## Load train set [JUMP HERE IF DATA AVAILABLE]

In [None]:
from numpy import load

# Same folder the save cell writes to ("paper lesions extended", with a space).
# The Y path previously said "paper_lesions extended", a folder that is never
# written, so loading Y failed. A shared constant keeps both paths in sync.
DATASET_DIR = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended'
X_input_path = f'{DATASET_DIR}/dataset_clinet_input_processed_X.npy'
Y_input_path = f'{DATASET_DIR}/dataset_clinet_input_processed_Y.npy'

X = load(X_input_path)
Y = load(Y_input_path)

In [None]:
# Sanity-check the shapes of the loaded arrays.
print(X.shape, Y.shape)

(5200, 224, 176, 1) (5200, 224, 176, 1)


In [None]:
from sklearn.model_selection import GroupShuffleSplit, train_test_split

# X/Y are assembled subject by subject: each subject contributes a contiguous
# block of 260 slices (130 original + 130 mirrored). A per-slice random split
# puts neighbouring slices and mirrored copies of the same subject in both
# train and validation, which leaks information and inflates validation
# scores. Split by subject instead (20 subjects * 0.2 -> 4 validation subjects).
SLICES_PER_SUBJECT = 260
assert len(X) % SLICES_PER_SUBJECT == 0, 'X does not hold a whole number of subjects'
subject_ids = np.arange(len(X)) // SLICES_PER_SUBJECT

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, valid_idx = next(splitter.split(X, Y, groups=subject_ids))
# The splitter returns indices in ascending order; shuffle the training order.
np.random.default_rng(42).shuffle(train_idx)

X_train, X_valid = X[train_idx], X[valid_idx]
y_train, y_valid = Y[train_idx], Y[valid_idx]
print(X_train.shape, y_train.shape)
print(X_valid.shape, y_valid.shape)

(4160, 224, 176, 1) (4160, 224, 176, 1)
(1040, 224, 176, 1) (1040, 224, 176, 1)


## Define the model to train (CLCI-Net)

In [None]:
from keras import *
from keras.layers import *
import tensorflow as tf
# Optional L2 weight decay for the convolution helpers below. The 1e-5
# regularizers are kept for reference but are immediately overridden with None,
# so regularization is DISABLED for every helper that defaults to these globals.
# bias_regularizer is not passed to any layer in this section.
kernel_regularizer = regularizers.l2(1e-5)
bias_regularizer = regularizers.l2(1e-5)
kernel_regularizer = None
bias_regularizer = None

def conv_lstm(input1, input2, channel=256):
    """
    Fuse two feature maps with a ConvLSTM2D layer.

    Each (batch, H, W, C) input gets a length-1 time axis. The two are
    concatenated along that axis into a 2-step sequence, which is fed to a
    ConvLSTM2D layer with ``channel`` filters, a 3x3 kernel, 'same' padding,
    'he_normal' initialization and the module-level ``kernel_regularizer``.
    Both inputs must share height, width and channel count.

    Args:
        input1 (tf.Tensor): The first time step (4D tensor).
        input2 (tf.Tensor): The second time step (4D tensor).
        channel (int): The number of output channels (filters) of the ConvLSTM layer. Default is 256.

    Returns:
        tf.Tensor: The output tensor of the ConvLSTM layer.
    """
    def _as_time_step(tensor):
        # (batch, H, W, C) -> (batch, 1, H, W, C), using the tensor's OWN
        # channel count (previously input2 was reshaped with input1's).
        _, height, width, channels = tensor.shape.as_list()
        return Reshape((1, height, width, channels))(tensor)

    lstm_input1 = _as_time_step(input1)
    lstm_input2 = _as_time_step(input2)

    lstm_input = custom_concat(axis=1)([lstm_input1, lstm_input2])
    x = ConvLSTM2D(channel, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal', kernel_regularizer=kernel_regularizer)(lstm_input)
    return x

def conv_2(inputs, filter_num, kernel_size=(3,3), strides=(1,1), kernel_initializer='glorot_uniform', kernel_regularizer = kernel_regularizer):
    """
    Apply two Conv2D -> BatchNormalization -> ReLU units in sequence.

    Both convolutions use ``filter_num`` filters, 'same' padding and the same
    ``kernel_size``, ``strides``, ``kernel_initializer`` and
    ``kernel_regularizer``. Non-unit strides therefore apply twice.

    Args:
        inputs (tf.Tensor): The input tensor to the block.
        filter_num (int): Number of filters in each convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).
        kernel_initializer (str): Kernel initializer name. Default 'glorot_uniform'.
        kernel_regularizer: Kernel regularizer or None. The default is the
            module-level ``kernel_regularizer`` captured when this function
            is defined (None in this notebook).

    Returns:
        tf.Tensor: Output of the second ReLU.
    """
    features = inputs
    for _ in range(2):
        conv_layer = Conv2D(
            filter_num,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
        )
        features = conv_layer(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
    return features

def conv_2_init(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """
    ``conv_2`` (two Conv-BN-ReLU units) with 'he_normal' initialization.

    The kernel regularizer is the module-level ``kernel_regularizer`` as it
    is when this function is called.

    Args:
        inputs (tf.Tensor): The input tensor to the block.
        filter_num (int): Number of filters in each convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).

    Returns:
        tf.Tensor: The output tensor of the convolutional block.
    """
    options = dict(
        kernel_size=kernel_size,
        strides=strides,
        kernel_initializer='he_normal',
        kernel_regularizer=kernel_regularizer,
    )
    return conv_2(inputs, filter_num, **options)

def conv_2_init_regularization(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """
    ``conv_2`` (two Conv-BN-ReLU units) with 'he_normal' initialization and
    a fixed L2(5e-4) kernel regularizer on both convolutions.

    Args:
        inputs (tf.Tensor): The input tensor to the block.
        filter_num (int): Number of filters in each convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).

    Returns:
        tf.Tensor: The output tensor of the convolutional block.
    """
    weight_decay = regularizers.l2(5e-4)
    return conv_2(
        inputs,
        filter_num,
        kernel_size=kernel_size,
        strides=strides,
        kernel_initializer='he_normal',
        kernel_regularizer=weight_decay,
    )

def conv_1(inputs, filter_num, kernel_size=(3,3), strides=(1,1), kernel_initializer='glorot_uniform', kernel_regularizer = kernel_regularizer):
    """
    Apply a single Conv2D -> BatchNormalization -> ReLU unit.

    Args:
        inputs (tf.Tensor): The input tensor to the convolutional layer.
        filter_num (int): Number of filters in the convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).
        kernel_initializer (str): Kernel initializer name. Default 'glorot_uniform'.
        kernel_regularizer: Kernel regularizer or None. The default is the
            module-level ``kernel_regularizer`` captured when this function
            is defined (None in this notebook).

    Returns:
        tf.Tensor: Output of the ReLU.
    """
    convolved = Conv2D(
        filter_num,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )(inputs)
    normalized = BatchNormalization()(convolved)
    return Activation('relu')(normalized)

def conv_1_init(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """
    ``conv_1`` (one Conv-BN-ReLU unit) with 'he_normal' initialization.

    There is no regularizer parameter. The module-level
    ``kernel_regularizer``, as it is when this function is called, is always
    used.

    Args:
        inputs (tf.Tensor): The input tensor to the convolutional layer.
        filter_num (int): Number of filters in the convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).

    Returns:
        tf.Tensor: The output tensor of the convolutional block.
    """
    initializer = 'he_normal'
    return conv_1(
        inputs,
        filter_num,
        kernel_size=kernel_size,
        strides=strides,
        kernel_initializer=initializer,
        kernel_regularizer=kernel_regularizer,
    )

def conv_1_init_regularization(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """
    ``conv_1`` (one Conv-BN-ReLU unit) with 'he_normal' initialization and a
    fixed L2(5e-4) kernel regularizer.

    Args:
        inputs (tf.Tensor): The input tensor to the convolutional layer.
        filter_num (int): Number of filters in the convolution.
        kernel_size (tuple): Convolution kernel (height, width). Default (3, 3).
        strides (tuple): Convolution strides (height, width). Default (1, 1).

    Returns:
        tf.Tensor: The output tensor of the convolutional block.
    """
    weight_decay = regularizers.l2(5e-4)
    return conv_1(
        inputs,
        filter_num,
        kernel_size=kernel_size,
        strides=strides,
        kernel_initializer='he_normal',
        kernel_regularizer=weight_decay,
    )

def dilate_conv(inputs, filter_num, dilation_rate):
    """
    Apply a dilated 3x3 Conv2D -> BatchNormalization -> ReLU unit.

    The convolution uses 'same' padding, 'he_normal' initialization and the
    module-level ``kernel_regularizer`` as it is at call time.

    Args:
        inputs (tf.Tensor): The input tensor to the convolutional layer.
        filter_num (int): Number of filters in the convolution.
        dilation_rate (tuple): Dilation rate (height, width).

    Returns:
        tf.Tensor: The output tensor of the dilated convolutional block.
    """
    dilated = Conv2D(
        filter_num,
        kernel_size=(3, 3),
        dilation_rate=dilation_rate,
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=kernel_regularizer,
    )(inputs)
    normalized = BatchNormalization()(dilated)
    return Activation('relu')(normalized)

class custom_concat(Layer):
    """
    Keras layer that concatenates a list of tensors along one axis.

    Equivalent to ``tf.concat(inputs, axis)``: all inputs must share their
    shape except along ``axis``, whose sizes are summed in the output.

    Args:
        axis (int): The axis along which to concatenate the inputs. The default value is -1,
                    which corresponds to the last axis.

    Attributes:
        axis (int): The axis along which the concatenation is performed.
    """
    def __init__(self, axis=-1, **kwargs):
        """
        Initializes the custom_concat layer.

        Args:
            axis (int): The axis along which to concatenate the inputs.
            **kwargs: Forwarded to the Keras ``Layer`` base class (e.g. ``name``).
        """
        super(custom_concat, self).__init__(**kwargs)
        self.axis = axis

    def build(self, input_shape):
        """
        Builds the layer. It owns no weights, so this only marks it as built.

        Args:
            input_shape (tuple): The shape of the input tensor(s).
        """
        super(custom_concat, self).build(input_shape)

    def call(self, x):
        """
        Performs the concatenation operation on the input tensors.

        Args:
            x (list): A list of input tensors to be concatenated.

        Returns:
            tf.Tensor: The concatenated tensor.
        """
        # No longer stored on the layer (was `self.res`): keeping a reference to
        # the symbolic output tensor served no purpose.
        return tf.concat(x, self.axis)

    def compute_output_shape(self, input_shape):
        """
        Computes the output shape of the custom_concat layer.

        Args:
            input_shape (list): The shapes of the input tensors.

        Returns:
            tuple: The shape of the concatenated output tensor; the ``axis``
            dimension is None if any input has an unknown size there.
        """
        output_shape = list(input_shape[0])

        for shape in input_shape[1:]:
            if output_shape[self.axis] is None or shape[self.axis] is None:
                output_shape[self.axis] = None
                break
            output_shape[self.axis] += shape[self.axis]

        return tuple(output_shape)

    def get_config(self):
        """
        Return the layer configuration, including ``axis``, so that models
        using this layer can be saved and reloaded (pass
        ``custom_objects={'custom_concat': custom_concat}`` when loading).
        """
        config = super(custom_concat, self).get_config()
        config.update({'axis': self.axis})
        return config


class BilinearUpsampling(Layer):
    """
    Keras layer that resizes feature maps by fixed per-axis factors.

    Uses ``tf.image.resize`` with its default method (bilinear) on a
    (batch, height, width, channels) tensor.

    Args:
        upsampling (tuple): (height_factor, width_factor). Default (2, 2).
    """
    def __init__(self, upsampling=(2, 2), **kwargs):
        """
        Initializes the BilinearUpsampling layer.

        Args:
            upsampling: A tuple specifying the upsampling factor along the height and width dimensions.
                       Default is (2, 2), i.e., upsampling by a factor of 2 in both height and width.
            **kwargs: Additional keyword arguments to pass to the base class (Layer).
        """
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.upsampling = upsampling

    def compute_output_shape(self, input_shape):
        """
        Computes the output shape of the layer based on the input shape.

        Args:
            input_shape: A tuple representing the input shape (batch_size, height, width, channels).

        Returns:
            Tuple (batch_size, new_height, new_width, channels); unknown
            spatial sizes stay None.
        """
        height = None if input_shape[1] is None else self.upsampling[0] * input_shape[1]
        width = None if input_shape[2] is None else self.upsampling[1] * input_shape[2]
        return (input_shape[0], height, width, input_shape[3])

    def call(self, inputs):
        """
        Resize ``inputs`` to (height * factor_h, width * factor_w).

        Args:
            inputs: The input tensor (batch_size, height, width, channels).

        Returns:
            The resized tensor.
        """
        target_size = (int(inputs.shape[1] * self.upsampling[0]),
                       int(inputs.shape[2] * self.upsampling[1]))
        return tf.image.resize(inputs, target_size)

    def get_config(self):
        """
        Return the layer configuration, including ``upsampling``, so that
        models using this layer can be saved and reloaded (pass
        ``custom_objects={'BilinearUpsampling': BilinearUpsampling}`` when loading).
        """
        config = super(BilinearUpsampling, self).get_config()
        config.update({'upsampling': self.upsampling})
        return config



def concat_pool(conv, pool, filter_num, strides=(2, 2)):
    """
    Downsample ``conv`` with a strided conv unit and concatenate it with ``pool``.

    ``conv`` goes through a 3x3 strided Conv2D ('same' padding, 'he_normal',
    module-level ``kernel_regularizer``) -> BatchNormalization -> ReLU. The
    result is concatenated with ``pool`` along the last axis, so the strides
    must bring ``conv`` to ``pool``'s spatial size.

    Args:
        conv: Feature map to downsample.
        pool: Feature map to concatenate with.
        filter_num: Number of filters for the strided convolution.
        strides: Strides of the convolution. Default is (2, 2).

    Returns:
        The concatenation [downsampled conv, pool].
    """
    downsampled = Conv2D(
        filter_num,
        (3, 3),
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=kernel_regularizer,
    )(conv)
    downsampled = BatchNormalization()(downsampled)
    downsampled = Activation('relu')(downsampled)
    return Concatenate()([downsampled, pool])
######################################
from keras.optimizers import Adam
import keras.backend as K
#from custom_layer import *


def dice_coef(y_true, y_pred):
    """
    Smoothed (squared-denominator) Dice coefficient over the whole batch.

    Computes (2 * sum(t * p) + 1) / (sum(t * t) + sum(p * p) + 1) on the
    flattened tensors. The +1 smoothing term keeps the ratio defined when
    both masks are empty.

    Args:
        y_true (tf.Tensor): The ground truth segmentation mask.
        y_pred (tf.Tensor): The predicted segmentation mask.

    Returns:
        tf.Tensor: The Dice coefficient (scalar).
    """
    smooth = 1
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth * truth) + K.sum(pred * pred) + smooth
    return (2. * overlap + smooth) / denominator

def dice_coef_loss(y_true, y_pred):
    """
    Loss form of ``dice_coef``: one minus the Dice coefficient.

    Args:
        y_true (tf.Tensor): The ground truth segmentation mask.
        y_pred (tf.Tensor): The predicted segmentation mask.

    Returns:
        tf.Tensor: 1 - dice_coef(y_true, y_pred).
    """
    score = dice_coef(y_true, y_pred)
    return 1. - score

def CLCI_Net(input_shape=(224, 176, 1), num_class=1):
    """
    Creates the CLCI_Net model for semantic segmentation.

    The encoder has four 2x2 max-pooling stages. Each stage fuses strided-conv
    downsamplings of all earlier stages with its pooled features, then applies
    a 1x1 conv. A CLF_ASPP block sits at the bottleneck. The decoder upsamples
    four times and combines each upsampled map with a 1x1-projected encoder
    skip through a ConvLSTM2D. The returned model is not compiled.

    Args:
        input_shape: Tuple representing the shape of the input tensor (height, width, channels).
                     The row and column of the input should be resized or cropped to an integer multiple of 16
                     (four 2x2 poolings).
        num_class: Number of classes for segmentation. For binary segmentation, num_class should be set to 1
                   (sigmoid output); otherwise a softmax output with num_class channels is used.

    Returns:
        Model: Keras model representing the CLCI_Net for semantic segmentation.

    """
    
    # The row and col of input should be resized or cropped to an integer multiple of 16.
    inputs = Input(shape=input_shape)

    # Encoder stage 1 (full resolution -> 1/2).
    conv1 = conv_2_init(inputs, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    concat_pool11 = concat_pool(conv1, pool1, 32, strides=(2, 2))
    fusion1 = conv_1_init(concat_pool11, 64 * 4, kernel_size=(1, 1))

    # Encoder stage 2 (-> 1/4): fuse conv1 (stride 4) and conv2 (stride 2) with pool2.
    conv2 = conv_2_init(fusion1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    concat_pool12 = concat_pool(conv1, pool2, 64, strides=(4, 4))
    concat_pool22 = concat_pool(conv2, concat_pool12, 64, strides=(2, 2))
    fusion2 = conv_1_init(concat_pool22, 128 * 4, kernel_size=(1, 1))

    # Encoder stage 3 (-> 1/8).
    conv3 = conv_2_init(fusion2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    concat_pool13 = concat_pool(conv1, pool3, 128, strides=(8, 8))
    concat_pool23 = concat_pool(conv2, concat_pool13, 128, strides=(4, 4))
    concat_pool33 = concat_pool(conv3, concat_pool23, 128, strides=(2, 2))
    fusion3 = conv_1_init(concat_pool33, 256 * 4, kernel_size=(1, 1))

    # Encoder stage 4 (-> 1/16).
    conv4 = conv_2_init(fusion3, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    concat_pool14 = concat_pool(conv1, pool4, 256, strides=(16, 16))
    concat_pool24 = concat_pool(conv2, concat_pool14, 256, strides=(8, 8))
    concat_pool34 = concat_pool(conv3, concat_pool24, 256, strides=(4, 4))
    concat_pool44 = concat_pool(conv4, concat_pool34, 256, strides=(2, 2))
    fusion4 = conv_1_init(concat_pool44, 512 * 4, kernel_size=(1, 1))

    # Bottleneck at 1/16 resolution.
    conv5 = conv_2_init(fusion4, 512)
    conv5 = Dropout(0.5)(conv5)

    # Multi-scale context (dilated convs + pooled branch + all encoder levels).
    clf_aspp = CLF_ASPP(conv5, conv1, conv2, conv3, conv4, input_shape)

    # Decoder: each step upsamples x2, applies a 2x2 conv, projects the matching
    # encoder feature map with a 1x1 conv, and fuses both with a ConvLSTM.
    up_conv1 = UpSampling2D(size=(2, 2))(clf_aspp)
    up_conv1 = conv_1_init(up_conv1, 256, kernel_size=(2, 2))
    skip_conv4 = conv_1_init(conv4, 256, kernel_size=(1, 1))
    context_inference1 = conv_lstm(up_conv1, skip_conv4, channel=256)
    conv6 = conv_2_init(context_inference1, 256)

    up_conv2 = UpSampling2D(size=(2, 2))(conv6)
    up_conv2 = conv_1_init(up_conv2, 128, kernel_size=(2, 2))
    skip_conv3 = conv_1_init(conv3, 128, kernel_size=(1, 1))
    context_inference2 = conv_lstm(up_conv2, skip_conv3, channel=128)
    conv7 = conv_2_init(context_inference2, 128)

    up_conv3 = UpSampling2D(size=(2, 2))(conv7)
    up_conv3 = conv_1_init(up_conv3, 64, kernel_size=(2, 2))
    skip_conv2 = conv_1_init(conv2, 64, kernel_size=(1, 1))
    context_inference3 = conv_lstm(up_conv3, skip_conv2, channel=64)
    conv8 = conv_2_init(context_inference3, 64)

    up_conv4 = UpSampling2D(size=(2, 2))(conv8)
    up_conv4 = conv_1_init(up_conv4, 32, kernel_size=(2, 2))
    skip_conv1 = conv_1_init(conv1, 32, kernel_size=(1, 1))
    context_inference4 = conv_lstm(up_conv4, skip_conv1, channel=32)
    conv9 = conv_2_init(context_inference4, 32)


    # Output head: per-pixel sigmoid for binary, softmax for multi-class.
    if num_class == 1:
        conv10 = Conv2D(num_class, (1, 1), activation='sigmoid')(conv9)
    else:
        conv10 = Conv2D(num_class, (1, 1), activation='softmax')(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model


def CLF_ASPP(conv5, conv1, conv2, conv3, conv4, input_shape):
    """
    Creates the CLF-ASPP block: Atrous Spatial Pyramid Pooling (ASPP) on the
    deepest encoder features, fused with cross-level features (CLF) taken
    from the four earlier encoder stages.

    Branches applied to ``conv5``:
      * b0: 1x1 convolution;
      * b1-b3: dilated convolutions with dilation rates 2, 4 and 6;
      * b4: image-level feature -- average pooling over the whole conv5 grid,
        a 1x1 convolution, then ``BilinearUpsampling`` back to that grid.
    Cross-level branches: ``conv_1_init`` with strides 16, 8, 4 and 2 applied
    to ``conv1``..``conv4`` respectively, so that each is reduced to the
    conv5 grid. All nine branches are concatenated, projected by
    ``conv_1_init`` (presumably to 256 * 4 filters) and passed through 50%
    dropout.

    Args:
        conv5: Output of the 5th (deepest) encoder stage; assumed to be the
            input downsampled by 2**4 (see the pooling size below).
        conv1: Output of the 1st encoder stage (reduced here with stride 16).
        conv2: Output of the 2nd encoder stage (reduced here with stride 8).
        conv3: Output of the 3rd encoder stage (reduced here with stride 4).
        conv4: Output of the 4th encoder stage (reduced here with stride 2).
        input_shape: Shape of one network input WITHOUT the batch axis,
            i.e. ``(height, width, channels)``; only ``input_shape[0]``
            (height) and ``input_shape[1]`` (width) are read.

    Returns:
        Output tensor after the CLF-ASPP block.

    Raises:
        ValueError: If height or width is smaller than 16, which would give
            the image-level pooling branch a pool size of 0.
    """
    # The code assumes conv5 is the input downsampled by 2**4 (e.g. four 2x2
    # poolings in the encoder), so the conv5 grid is input_shape // 16.
    # Validated first, before any layer is created.
    downsample = pow(2, 4)
    out_shape0 = input_shape[0] // downsample
    out_shape1 = input_shape[1] // downsample
    if out_shape0 < 1 or out_shape1 < 1:
        raise ValueError(
            f"CLF_ASPP requires height and width >= {downsample}; "
            f"got input_shape={input_shape!r}"
        )

    # ASPP branches on the deepest features.
    b0 = conv_1_init(conv5, 256, (1, 1))
    b1 = dilate_conv(conv5, 256, dilation_rate=(2, 2))
    b2 = dilate_conv(conv5, 256, dilation_rate=(4, 4))
    b3 = dilate_conv(conv5, 256, dilation_rate=(6, 6))

    # Image-level branch: pooling over the full conv5 grid, 1x1 conv, then
    # BilinearUpsampling with the same (out_shape0, out_shape1) argument,
    # presumably restoring the conv5 grid so it can be concatenated.
    b4 = AveragePooling2D(pool_size=(out_shape0, out_shape1))(conv5)
    b4 = conv_1_init(b4, 256, (1, 1))
    b4 = BilinearUpsampling((out_shape0, out_shape1))(b4)

    # Cross-level features: stride each earlier stage down to the conv5 grid.
    clf1 = conv_1_init(conv1, 256, strides=(16, 16))
    clf2 = conv_1_init(conv2, 256, strides=(8, 8))
    clf3 = conv_1_init(conv3, 256, strides=(4, 4))
    clf4 = conv_1_init(conv4, 256, strides=(2, 2))

    outs = Concatenate()([clf1, clf2, clf3, clf4, b0, b1, b2, b3, b4])

    # Fuse the concatenated branches and regularise.
    outs = conv_1_init(outs, 256 * 4, (1, 1))
    outs = Dropout(0.5)(outs)

    return outs

## Training

In [None]:
from keras.metrics import Recall, Precision
# https://stats.stackexchange.com/questions/323154/precision-vs-recall-acceptable-limits
# https://www.kdnuggets.com/2016/12/4-reasons-machine-learning-model-wrong.html#:~:text=Precision%20is%20a%20measure%20of,positive%20class%20are%20actually%20true.&text=Hence%2C%20a%20situation%20of%20Low,positive%20values%20are%20never%20predicted.
# Pre and Post processing # https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
model = CLCI_Net()
#model.summary()
# Optimise the Dice loss with Adam, and also report Dice, accuracy, recall and
# precision. Adam's `lr` argument is deprecated in favour of `learning_rate`
# (tf.keras 2.x only warns about it; Keras 3 no longer accepts it).
model.compile(optimizer=Adam(learning_rate=1e-4), loss=dice_coef_loss, metrics=[dice_coef,'acc',Recall(), Precision()])

  super().__init__(name, **kwargs)


In [None]:
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

checkpoint_filepath = '/content/drive/MyDrive/integradora_fiec/modelos/clcinet-native-filtered-v2-{epoch:03d}-{dice_coef:03f}-{val_dice_coef:03f}.h5'
# Save the weights only when val_dice_coef improves (higher is better); the
# epoch number and both Dice values are formatted into the file name.
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True)
# Multiply the learning rate by 0.2 when val_dice_coef has not improved for
# 2 epochs, never going below 2e-6.
# mode='max' is required: in 'auto' mode Keras only infers "higher is better"
# for metric names containing 'acc', so 'val_dice_coef' would be treated like
# a loss and the LR would be cut while the Dice score was still rising.
reduce_lr = ReduceLROnPlateau(monitor='val_dice_coef', factor=0.2, patience=2, min_lr=2e-6, mode='max')
# Callbacks passed to every model.fit call below.
callbacks = [
    model_checkpoint_callback,
    reduce_lr
]

In [None]:
# First training run: 60 epochs with mini-batches of 8, validated on
# (X_valid, y_valid) after every epoch so that the checkpoint and
# learning-rate callbacks can track val_dice_coef.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    callbacks=callbacks,
    epochs=60,
    batch_size=8,
    verbose=1,
)

Resume training after the session timed out: reload the saved checkpoint and continue.

In [None]:
# Restore the weights checkpointed at epoch 29 before the session timed out
# (train dice 0.821825, val dice 0.787639, per the checkpoint file-name pattern).
model.load_weights(
    "/content/drive/MyDrive/integradora_fiec/modelos/"
    "clcinet-native-filtered-v2-029-0.821825-0.787639.h5"
)

In [None]:
# Resume the interrupted run. The reloaded checkpoint was written at epoch 29,
# so start at epoch index 29 (shown as epoch 30) and stop at the original
# 60-epoch budget. That is the same 31 epochs as before, but with
# `initial_epoch` the epoch numbers in the checkpoint file names and in
# history2 continue from the first run instead of restarting at 1.
history2 = model.fit(
      X_train, y_train,
      batch_size=8,
      initial_epoch=29,
      epochs=60,
      verbose=1,
      callbacks=callbacks,
      validation_data=(X_valid,y_valid))

## Next steps...

In [None]:
# Results look good: we reached about the same val_dice_coef as the previous run (~0.84).
# Next, apply the same preprocessing to all images (or only the 20?) across the lacunar and MCA cases.
# Build a dataset (.npy).
# Extract features from this dataset (.csv).
# Evaluate the models' performance ->
#