First, create a shortcut (direct access) to the `/datasets` folder in your personal Google Drive.

In [None]:
# Mount Google Drive so that the dataset and template paths under /content/drive are readable.
from google.colab import drive
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

Mounted at /content/drive



## Install dependencies

In [None]:
! pip install SimpleITK
! pip install antspyx

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting SimpleITK
  Downloading SimpleITK-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting antspyx
  Downloading antspyx-0.3.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (326.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m326.3/326.3 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Collecting chart-studio (from antspyx)
  Downloading chart_studio-1.1.0-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
Collecting retrying>


## Load images to current session

In [None]:
! mkdir -v data
! unzip "/content/drive/MyDrive/integradora_fiec/datasets/NATIVE_FILTERED_MANUALLY.zip" -d "/data"

mkdir: created directory 'data'
Archive:  /content/drive/MyDrive/integradora_fiec/datasets/NATIVE_FILTERED_MANUALLY.zip
   creating: /data/NATIVE_FILTERED_MANUALLY/test/
   creating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/
   creating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0005t01/
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0005t01/c0003s0005t01.nii.gz  
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0005t01/c0003s0005t01_LesionRaw.voi  
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0005t01/c0003s0005t01_LesionSmooth.nii.gz  
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0005t01/c0003s0005t01_LesionSmooth.voi  
   creating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0012t01/
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0012t01/c0003s0012t01.nii.gz  
  inflating: /data/NATIVE_FILTERED_MANUALLY/test/Lacunar/c0003s0012t01/c0003s0012t01_LesionRaw.voi  
  inflating: /data/NATIVE_F

## Preprocessing functions

In [None]:
%matplotlib inline
import os

import ants
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

# Record library versions so results can be reproduced later.
for lib_name, lib in (('AntsPy', ants), ('SimpleITK', sitk)):
    print(f'{lib_name} version = {lib.__version__}')

AntsPy version = 0.3.8
SimpleITK version = 2.2.1


In [None]:
mni_T1_path = TEMPLATE_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a.nii'

def load_template_ants() -> ants.ANTsImage:
    """Read the MNI T1 template stored at TEMPLATE_PATH as an ANTs image."""
    return load_img_ants(TEMPLATE_PATH)

def load_img_ants(path: str) -> ants.ANTsImage:
    """Read the image file at ``path`` with ``ants.image_read``."""
    return ants.image_read(path)

def register_to_mni(img: ants.ANTsImage, mask: ants.ANTsImage,
                    type_of_transform: str = 'SyN',
                    mask_interpolator: str = 'linear') -> tuple[ants.ANTsImage, ants.ANTsImage]:
    """Register an MRI image to the MNI template and carry its mask along.

    ``img`` (moving) is registered to the template returned by
    ``load_template_ants()`` (fixed) with ``ants.registration``. The resulting
    forward transforms are then applied to ``mask``, so both outputs lie on the
    template grid.

    Args:
        img: image in native space.
        mask: mask defined on the same grid as ``img`` (assumed — TODO confirm).
        type_of_transform: transform model passed to ``ants.registration``.
        mask_interpolator: interpolator used by ``ants.apply_transforms`` for the
            mask. The default 'linear' reproduces the original behaviour, but it
            produces fractional values along lesion borders. Use
            'nearestNeighbor' or 'genericLabel' to keep a hard label mask.

    Returns:
        A tuple ``(img_registered, mask_registered)``.
    """
    template_img = load_template_ants()
    transformation = ants.registration(fixed=template_img, moving=img,
                                       type_of_transform=type_of_transform)

    img_registered = transformation['warpedmovout']
    mask_registered = ants.apply_transforms(
        fixed=template_img,
        moving=mask,
        transformlist=transformation['fwdtransforms'],
        interpolator=mask_interpolator,
    )
    return img_registered, mask_registered

## Register

In [None]:
from glob import glob

xpaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01.nii.gz'))
ypaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth.nii.gz'))
assert len(xpaths) == len(ypaths)
# Images and masks are later paired positionally with zip(), so check that each
# pair really comes from the same case folder, not just that the counts match.
assert all(os.path.dirname(x) == os.path.dirname(y) for x, y in zip(xpaths, ypaths)), \
    'image/mask path lists are misaligned'

In [None]:
n_samples = len(xpaths)
print("Number of samples:", n_samples)
for x_path, y_path in zip(xpaths, ypaths):
    x_tail, y_tail = x_path[-35:], y_path[-48:]
    print(x_tail, "|", y_tail)

Number of samples: 20
/c0003s0003t01/c0003s0003t01.nii.gz | /c0003s0003t01/c0003s0003t01_LesionSmooth.nii.gz
/c0003s0014t01/c0003s0014t01.nii.gz | /c0003s0014t01/c0003s0014t01_LesionSmooth.nii.gz
/c0003s0015t01/c0003s0015t01.nii.gz | /c0003s0015t01/c0003s0015t01_LesionSmooth.nii.gz
/c0003s0019t01/c0003s0019t01.nii.gz | /c0003s0019t01/c0003s0019t01_LesionSmooth.nii.gz
/c0003s0020t01/c0003s0020t01.nii.gz | /c0003s0020t01/c0003s0020t01_LesionSmooth.nii.gz
/c0004s0003t01/c0004s0003t01.nii.gz | /c0004s0003t01/c0004s0003t01_LesionSmooth.nii.gz
/c0005s0010t01/c0005s0010t01.nii.gz | /c0005s0010t01/c0005s0010t01_LesionSmooth.nii.gz
/c0005s0042t01/c0005s0042t01.nii.gz | /c0005s0042t01/c0005s0042t01_LesionSmooth.nii.gz
/c0006s0010t01/c0006s0010t01.nii.gz | /c0006s0010t01/c0006s0010t01_LesionSmooth.nii.gz
/c0007s0010t01/c0007s0010t01.nii.gz | /c0007s0010t01/c0007s0010t01_LesionSmooth.nii.gz
/c0003s0030t01/c0003s0030t01.nii.gz | /c0003s0030t01/c0003s0030t01_LesionSmooth.nii.gz
/c0004s0011t01/c0004s

In [None]:
# Register every training image (and its lesion mask) to MNI space and save the
# results next to the originals as <case>_registered.nii.gz and
# <case>_LesionSmooth_registered.nii.gz.
for i, (xpath, ypath) in enumerate(zip(xpaths, ypaths)):
    # Build output names from path components instead of fixed-width slicing,
    # which silently breaks if a case id is not exactly 13 characters long.
    folder = os.path.dirname(xpath)
    file_name = os.path.basename(xpath)[:-len('.nii.gz')]  # e.g. 'c0003s0003t01'
    x_registered_path = os.path.join(folder, file_name + '_registered.nii.gz')
    y_registered_path = os.path.join(folder, file_name + '_LesionSmooth_registered.nii.gz')

    x3d = load_img_ants(xpath)
    y3d = load_img_ants(ypath)

    x3d_registered, y3d_registered = register_to_mni(img=x3d, mask=y3d)

    print(i, x_registered_path)
    print(i, y_registered_path)

    x3d_registered.to_file(x_registered_path)
    y3d_registered.to_file(y_registered_path)


0 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0003t01/c0003s0003t01_registered.nii.gz
0 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0003t01/c0003s0003t01_LesionSmooth_registered.nii.gz
1 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0014t01/c0003s0014t01_registered.nii.gz
1 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0014t01/c0003s0014t01_LesionSmooth_registered.nii.gz
2 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0015t01/c0003s0015t01_registered.nii.gz
2 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0015t01/c0003s0015t01_LesionSmooth_registered.nii.gz
3 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0019t01/c0003s0019t01_registered.nii.gz
3 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0019t01/c0003s0019t01_LesionSmooth_registered.nii.gz
4 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0020t01/c0003s0020t01_registered.nii.gz
4 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0020t01/c0003s0020t01_LesionSmooth_registered.nii.gz
5 /da

## Bias Field Correction

In [None]:
xpaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01_registered.nii.gz'))
ypaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)
# Pairing is positional (zip), so also require both files of a pair to share a case folder.
assert all(os.path.dirname(x) == os.path.dirname(y) for x, y in zip(xpaths, ypaths)), \
    'image/mask path lists are misaligned'

In [None]:
n_samples = len(xpaths)
print("Number of samples:", n_samples)
for x_path, y_path in zip(xpaths, ypaths):
    x_tail, y_tail = x_path[-35:], y_path[-48:]
    print(x_tail, "|", y_tail)

Number of samples: 20
t01/c0003s0003t01_registered.nii.gz | t01/c0003s0003t01_LesionSmooth_registered.nii.gz
t01/c0003s0014t01_registered.nii.gz | t01/c0003s0014t01_LesionSmooth_registered.nii.gz
t01/c0003s0015t01_registered.nii.gz | t01/c0003s0015t01_LesionSmooth_registered.nii.gz
t01/c0003s0019t01_registered.nii.gz | t01/c0003s0019t01_LesionSmooth_registered.nii.gz
t01/c0003s0020t01_registered.nii.gz | t01/c0003s0020t01_LesionSmooth_registered.nii.gz
t01/c0004s0003t01_registered.nii.gz | t01/c0004s0003t01_LesionSmooth_registered.nii.gz
t01/c0005s0010t01_registered.nii.gz | t01/c0005s0010t01_LesionSmooth_registered.nii.gz
t01/c0005s0042t01_registered.nii.gz | t01/c0005s0042t01_LesionSmooth_registered.nii.gz
t01/c0006s0010t01_registered.nii.gz | t01/c0006s0010t01_LesionSmooth_registered.nii.gz
t01/c0007s0010t01_registered.nii.gz | t01/c0007s0010t01_LesionSmooth_registered.nii.gz
t01/c0003s0030t01_registered.nii.gz | t01/c0003s0030t01_LesionSmooth_registered.nii.gz
t01/c0004s0011t01_reg

In [None]:
def bias_field_correction(img: sitk.Image, shrink_factor: int = 4) -> sitk.Image:
    """Apply N4 bias-field correction to ``img``.

    The bias field is estimated on a copy of the image downsampled by
    ``shrink_factor`` along every axis (for speed), restricted to a head mask
    obtained by Li thresholding. The log bias field is then evaluated at full
    resolution and divided out of the original image.

    Args:
        img: input image (``load_img_sitk`` reads it as float32).
        shrink_factor: per-axis downsampling used only for the estimation; must be >= 1.

    Returns:
        The corrected image at full resolution, cast to ``sitk.sitkFloat32``.

    Raises:
        ValueError: if ``shrink_factor`` < 1.
    """
    if shrink_factor < 1:
        raise ValueError(f'shrink_factor must be >= 1, got {shrink_factor}')

    # Binary head mask: rescale to [0, 255], then Li threshold with insideValue=0, outsideValue=1.
    head_mask = sitk.RescaleIntensity(img, 0, 255)
    head_mask = sitk.LiThreshold(head_mask, 0, 1)

    shrink = [shrink_factor] * img.GetDimension()
    input_small = sitk.Shrink(img, shrink)
    mask_small = sitk.Shrink(head_mask, shrink)

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(input_small, mask_small)

    # Evaluate the estimated field on the full-resolution grid and divide it out.
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(img)
    result = img / sitk.Exp(log_bias_field)

    # The division yields a 64-bit pixel type; cast back to float32 for compatibility.
    return sitk.Cast(result, sitk.sitkFloat32)

def load_img_sitk(path: str) -> sitk.Image:
    """Read the image file at ``path`` with SimpleITK, forcing a float32 pixel type."""
    return sitk.ReadImage(path, sitk.sitkFloat32)


In [None]:
# Bias-field correct every registered image and save it as <stem>_BF.nii.gz next to it.
for i, (xpath, ypath) in enumerate(zip(xpaths, ypaths)):
    # Derive the output name from path components; the previous fixed-width
    # slicing only produced the right name by coincidence for this file naming.
    stem = os.path.basename(xpath)[:-len('.nii.gz')]  # e.g. 'c0003s0003t01_registered'
    x_out_path = os.path.join(os.path.dirname(xpath), stem + '_BF.nii.gz')

    x3d = load_img_sitk(xpath)
    x3d_bf_corrected = bias_field_correction(x3d)

    sitk.WriteImage(x3d_bf_corrected, x_out_path)

    print(i, x_out_path)



0 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0003t01/c0003s0003t01_registered_BF.nii.gz
1 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0014t01/c0003s0014t01_registered_BF.nii.gz
2 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0015t01/c0003s0015t01_registered_BF.nii.gz
3 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0019t01/c0003s0019t01_registered_BF.nii.gz
4 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0003s0020t01/c0003s0020t01_registered_BF.nii.gz
5 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0004s0003t01/c0004s0003t01_registered_BF.nii.gz
6 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0005s0010t01/c0005s0010t01_registered_BF.nii.gz
7 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0005s0042t01/c0005s0042t01_registered_BF.nii.gz
8 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0006s0010t01/c0006s0010t01_registered_BF.nii.gz
9 /data/NATIVE_FILTERED_MANUALLY/train/Lacunar/c0007s0010t01/c0007s0010t01_registered_BF.nii.gz
10 /data/NATIVE_FILTERED_MANUALLY/train/

## Prepare training data

In [None]:
xpaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*01_registered_BF.nii.gz'))
ypaths = sorted(glob('/data/NATIVE_FILTERED_MANUALLY/train/*/*/*_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)
# Pairing is positional (zip), so also require both files of a pair to share a case folder.
assert all(os.path.dirname(x) == os.path.dirname(y) for x, y in zip(xpaths, ypaths)), \
    'image/mask path lists are misaligned'

In [None]:
n_samples = len(xpaths)
print("Number of samples:", n_samples)
for x_path, y_path in zip(xpaths, ypaths):
    x_tail, y_tail = x_path[-35:], y_path[-48:]
    print(x_tail, "|", y_tail)

Number of samples: 20
/c0003s0003t01_registered_BF.nii.gz | t01/c0003s0003t01_LesionSmooth_registered.nii.gz
/c0003s0014t01_registered_BF.nii.gz | t01/c0003s0014t01_LesionSmooth_registered.nii.gz
/c0003s0015t01_registered_BF.nii.gz | t01/c0003s0015t01_LesionSmooth_registered.nii.gz
/c0003s0019t01_registered_BF.nii.gz | t01/c0003s0019t01_LesionSmooth_registered.nii.gz
/c0003s0020t01_registered_BF.nii.gz | t01/c0003s0020t01_LesionSmooth_registered.nii.gz
/c0004s0003t01_registered_BF.nii.gz | t01/c0004s0003t01_LesionSmooth_registered.nii.gz
/c0005s0010t01_registered_BF.nii.gz | t01/c0005s0010t01_LesionSmooth_registered.nii.gz
/c0005s0042t01_registered_BF.nii.gz | t01/c0005s0042t01_LesionSmooth_registered.nii.gz
/c0006s0010t01_registered_BF.nii.gz | t01/c0006s0010t01_LesionSmooth_registered.nii.gz
/c0007s0010t01_registered_BF.nii.gz | t01/c0007s0010t01_LesionSmooth_registered.nii.gz
/c0003s0030t01_registered_BF.nii.gz | t01/c0003s0030t01_LesionSmooth_registered.nii.gz
/c0004s0011t01_regist

In [None]:
# Load the MNI152 brain mask and T1 template as float32 images. preprocess_ximg
# histogram-matches inputs to the template and multiplies them by the mask.
TEMPLATE_BRAIN_MASK_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a_mask.nii'
mni152_brain_mask, mni152_T1 = (
    sitk.ReadImage(reference_path, sitk.sitkFloat32)
    for reference_path in (TEMPLATE_BRAIN_MASK_PATH, TEMPLATE_PATH)
)


In [None]:
def preprocess_ximg(ximg: sitk.Image, flipped=False) -> np.ndarray:
    """Turn a registered, bias-corrected volume into a model input volume.

    Steps:
    1. Histogram-match to the MNI152 T1 template.
    2. Zero everything outside the MNI152 brain mask. This assumes ``ximg``
       shares the template grid, as the Multiply requires.
    3. Apply one light iteration of curvature anisotropic diffusion.
    4. Optionally flip along the first SimpleITK index axis.
    5. Crop the numpy array to (130, 224, 176), divide by 255 and append a
       channel axis.

    Args:
        ximg: registered input image.
        flipped: if True, return the mirrored copy (used as augmentation).

    Returns:
        numpy array of shape (130, 224, 176, 1).
    """
    processed = sitk.HistogramMatching(ximg, mni152_T1)
    processed = sitk.Multiply(processed, mni152_brain_mask)  # brain extraction
    processed = sitk.CurvatureAnisotropicDiffusion(
        processed, conductanceParameter=1, numberOfIterations=1)  # light denoising
    if flipped:
        processed = sitk.Flip(processed, (True, False, False))

    volume = sitk.GetArrayFromImage(processed)[30:160, 4:228, 14:190]
    volume = (volume / 255.0)[..., np.newaxis]
    assert volume.shape == (130, 224, 176, 1)
    return volume

def preprocess_yimg(yimg: sitk.Image, flipped=False) -> np.ndarray:
    """Turn a registered lesion mask into a model target volume.

    Applies the same optional flip, crop and /255 scaling as ``preprocess_ximg``,
    but no histogram matching, masking or denoising, so targets stay
    voxel-aligned with the inputs.

    Args:
        yimg: registered lesion mask.
        flipped: if True, return the mirrored copy (matches the flipped input).

    Returns:
        numpy array of shape (130, 224, 176, 1).
    """
    y3d = yimg

    if flipped:
        y3d = sitk.Flip(y3d, (True, False, False))

    y3d = sitk.GetArrayFromImage(y3d)
    y3d = y3d[30:160, 4:228, 14:190]  # crop to size -> (130, 224, 176)
    y3d = y3d / 255.0  # NOTE(review): assumes masks are stored as 0/255 — confirm
    y3d = np.expand_dims(y3d, 3)  # add channel -> (130, 224, 176, 1)
    # Validate this function's own output rather than any global variable.
    assert y3d.shape == (130, 224, 176, 1)
    return y3d


In [None]:
ROW_SIZE = 224 # shapes of model input
COL_SIZE = 176

# Collect per-subject chunks and concatenate once at the end. Concatenating
# inside the loop re-copies the whole growing array on every iteration.
x_chunks = []
y_chunks = []

for xpath, ypath in zip(xpaths, ypaths):

    ximg        = sitk.ReadImage(xpath, sitk.sitkFloat32)
    x3d         = preprocess_ximg(ximg)
    flipped_x3d = preprocess_ximg(ximg, flipped=True)

    yimg        = sitk.ReadImage(ypath, sitk.sitkFloat32)
    y3d         = preprocess_yimg(yimg)
    flipped_y3d = preprocess_yimg(yimg, flipped=True)

    # 130 original slices followed by their 130 flipped copies for this subject.
    x3d = np.concatenate((x3d, flipped_x3d), axis=0)
    y3d = np.concatenate((y3d, flipped_y3d), axis=0)

    assert x3d.shape == (260, ROW_SIZE, COL_SIZE, 1)
    assert y3d.shape == (260, ROW_SIZE, COL_SIZE, 1)

    x_chunks.append(x3d)
    y_chunks.append(y3d)

    print('.', end='')

# The empty float32 seed keeps the result dtype identical to the old
# incremental version and makes an empty path list yield (0, ROW_SIZE, COL_SIZE, 1).
_empty = np.empty((0, ROW_SIZE, COL_SIZE, 1), dtype=np.float32)
X = np.concatenate([_empty, *x_chunks], axis=0)
Y = np.concatenate([_empty, *y_chunks], axis=0)
del x_chunks, y_chunks  # free the per-subject copies

....................

In [None]:
print(*(arr.shape for arr in (X, Y)))

(5200, 224, 176, 1) (5200, 224, 176, 1)


In [None]:
X[..., 0].shape

(5200, 224, 176)

## Double check slices

In [None]:
def get_x2d_marked(x2d, y2d, threshold=0.5):
    """Return a copy of ``x2d`` with the outline of the lesion mask ``y2d`` drawn in.

    The mask is binarised at ``threshold``, dilated and reduced to its contour.
    The contour pixels are then set to 0 in a copy of ``x2d`` via
    ``sitk.MaskNegated``.

    Args:
        x2d: 2D image slice (numpy array).
        y2d: 2D mask slice with the same shape as ``x2d``. After
            ``preprocess_yimg`` its values lie in [0, 1] (assuming 0/255 masks).
        threshold: mask values strictly above this count as lesion.

    Returns:
        numpy array with the same shape as ``x2d``.
    """
    dilation_level = 4
    # Binarise explicitly: a plain .astype('uint8') truncates every fractional
    # mask value (e.g. 0.98 after interpolation and /255 scaling) down to 0.
    m = (y2d > threshold).astype('uint8')
    m = sitk.GetImageFromArray(m)
    # NOTE(review): 3-entry radius for a 2D image; presumably only the first two
    # entries are used — confirm.
    m = sitk.BinaryDilate(m, (dilation_level, 1, 1))
    m = sitk.BinaryContour(m)

    x2d_marked = sitk.GetImageFromArray(x2d)
    x2d_marked = sitk.MaskNegated(x2d_marked, sitk.Cast(m, sitk.sitkFloat32))
    return sitk.GetArrayFromImage(x2d_marked)

def show_slices(slices: list[np.ndarray], cmap: str = 'gray'):
    """Display 2D image slices stacked vertically in one 15x15 figure.

    Works for any number of slices >= 1; about three gives the best layout.

    Args:
        slices: 2D arrays to show, top to bottom.
        cmap: matplotlib colormap name passed to ``imshow``.
    """
    # squeeze=False keeps `axes` 2D even for a single slice. Otherwise
    # plt.subplots returns a bare Axes and `axes[i]` raises.
    fig, axes = plt.subplots(len(slices), 1, figsize=(15, 15), squeeze=False)
    for row, image in enumerate(slices):
        axes[row, 0].imshow(image, cmap=cmap)

In [None]:
# Visual sanity check: look at every STEPS-th slice, skip those whose mask is
# constant (e.g. no lesion), and plot at most 10 marked/unmarked pairs.
STEPS = 150
c = 0
for i in range(0, len(X), STEPS):
    x, y = X[i], Y[i]
    if np.unique(y).size > 1:
        x_slice = x[:, :, 0]
        x2d_marked = get_x2d_marked(x_slice, y[:, :, 0])
        show_slices([x2d_marked, x_slice])
        c += 1
        if c == 10:
            break

Output hidden; open in https://colab.research.google.com to view.

## Save training dataset as npy

In [None]:
from numpy import save
# Slices processed from 20 native ATLAS images through: registration to MNI,
# bias-field correction, histogram matching, brain extraction and denoising.
X_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_X.npy'
Y_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_Y.npy'

for array_to_save, output_path in ((X, X_output_path), (Y, Y_output_path)):
    save(output_path, array_to_save)

## Load train set [JUMP HERE IF DATA AVAILABLE]

In [None]:
from numpy import load
# Both arrays live in the directory written by the "Save training dataset" cell
# ('paper lesions extended', with a space). Deriving both paths from one
# constant keeps them from drifting apart.
DATASET_DIR = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/'
X_input_path = DATASET_DIR + 'dataset_clinet_input_processed_X.npy'
Y_input_path = DATASET_DIR + 'dataset_clinet_input_processed_Y.npy'

X = load(X_input_path)
Y = load(Y_input_path)

In [None]:
print(X.shape, Y.shape)

(5200, 224, 176, 1) (5200, 224, 176, 1)


In [None]:
from sklearn.model_selection import train_test_split

from sklearn.model_selection import GroupShuffleSplit

# X/Y are stacked subject by subject: 130 slices plus their 130 flipped copies,
# i.e. 260 consecutive rows per subject (see "Prepare training data").
# A random per-slice split would put neighbouring, near-identical slices and
# flipped copies of the same brain in both train and validation, inflating the
# validation metrics. Split by subject instead. With 20 subjects this gives
# 16/4 subjects, i.e. the same 4160/1040 slice counts as before.
SLICES_PER_SUBJECT = 260
assert len(X) == len(Y) and len(X) % SLICES_PER_SUBJECT == 0
subject_ids = np.arange(len(X)) // SLICES_PER_SUBJECT
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, valid_idx = next(splitter.split(X, Y, groups=subject_ids))

X_train, X_valid = X[train_idx], X[valid_idx]
y_train, y_valid = Y[train_idx], Y[valid_idx]
print(X_train.shape, y_train.shape)
print(X_valid.shape, y_valid.shape)

(4160, 224, 176, 1) (4160, 224, 176, 1)
(1040, 224, 176, 1) (1040, 224, 176, 1)


## Define the model (CLCI-Net)

In [None]:
from keras import *
from keras.layers import *
import tensorflow as tf
# L2 weight decay (1e-5) on kernels and biases is switched off: both end up None.
USE_L2_REGULARIZATION = False
kernel_regularizer = regularizers.l2(1e-5) if USE_L2_REGULARIZATION else None
bias_regularizer = regularizers.l2(1e-5) if USE_L2_REGULARIZATION else None

def conv_lstm(input1, input2, channel=256):
    """Fuse two feature maps with a ConvLSTM2D over a 2-step sequence.

    Each 4D input ``(batch, H, W, C)`` is reshaped to ``(batch, 1, H, W, C)``.
    The two are concatenated along axis 1 into a length-2 sequence and fed to a
    ``ConvLSTM2D`` with ``channel`` filters, 'same' padding and
    ``return_sequences`` left at its default. H, W and C must match between the
    inputs for the concatenation to succeed.
    """
    def _as_time_step(t):
        # Each input keeps its own spatial size and channel count.
        h, w, c = t.shape.as_list()[1:4]
        return Reshape((1, h, w, c))(t)

    lstm_input1 = _as_time_step(input1)
    lstm_input2 = _as_time_step(input2)

    lstm_input = custom_concat(axis=1)([lstm_input1, lstm_input2])
    x = ConvLSTM2D(channel, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal', kernel_regularizer=kernel_regularizer)(lstm_input)
    return x

def conv_2(inputs, filter_num, kernel_size=(3,3), strides=(1,1), kernel_initializer='glorot_uniform', kernel_regularizer = kernel_regularizer):
    """Two stacked Conv2D -> BatchNormalization -> ReLU units ('same' padding).

    ``strides`` is applied in both convolutions.
    """
    x = inputs
    for _ in range(2):
        x = Conv2D(filter_num, kernel_size=kernel_size, strides=strides, padding='same',
                   kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

def conv_2_init(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """``conv_2`` with he_normal initialisation and the module-level kernel regularizer."""
    return conv_2(inputs, filter_num, kernel_size, strides, 'he_normal', kernel_regularizer)

def conv_2_init_regularization(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """``conv_2`` with he_normal initialisation and a fresh L2(5e-4) kernel regularizer."""
    l2_penalty = regularizers.l2(5e-4)
    return conv_2(inputs, filter_num, kernel_size, strides, 'he_normal', l2_penalty)

def conv_1(inputs, filter_num, kernel_size=(3,3), strides=(1,1), kernel_initializer='glorot_uniform', kernel_regularizer = kernel_regularizer):
    """A single Conv2D -> BatchNormalization -> ReLU unit with 'same' padding."""
    stack = (
        Conv2D(filter_num, kernel_size=kernel_size, strides=strides, padding='same',
               kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer),
        BatchNormalization(),
        Activation('relu'),
    )
    out = inputs
    for layer in stack:
        out = layer(out)
    return out

def conv_1_init(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """``conv_1`` with he_normal initialisation and the module-level kernel regularizer."""
    return conv_1(inputs, filter_num, kernel_size, strides, 'he_normal', kernel_regularizer)

def conv_1_init_regularization(inputs, filter_num, kernel_size=(3,3), strides=(1,1)):
    """``conv_1`` with he_normal initialisation and a fresh L2(5e-4) kernel regularizer."""
    l2_penalty = regularizers.l2(5e-4)
    return conv_1(inputs, filter_num, kernel_size, strides, 'he_normal', l2_penalty)

def dilate_conv(inputs, filter_num, dilation_rate):
    """Dilated (atrous) 3x3 Conv2D -> BatchNormalization -> ReLU with 'same' padding."""
    atrous = Conv2D(filter_num, kernel_size=(3, 3), dilation_rate=dilation_rate, padding='same',
                    kernel_initializer='he_normal', kernel_regularizer=kernel_regularizer)
    normalized = BatchNormalization()(atrous(inputs))
    return Activation('relu')(normalized)

class custom_concat(Layer):
    """Concatenate a list of tensors along ``axis`` with ``tf.concat``.

    ``conv_lstm`` uses it with ``axis=1`` to stack two ``(batch, 1, H, W, C)``
    tensors into a length-2 sequence.
    """

    def __init__(self, axis=-1, **kwargs):
        super(custom_concat, self).__init__(**kwargs)
        self.axis = axis

    def build(self, input_shape):
        # No weights to create; just mark the layer as built.
        self.built = True
        super(custom_concat, self).build(input_shape)

    def call(self, x):
        self.res = tf.concat(x, self.axis)

        return self.res

    def compute_output_shape(self, input_shape):
        """Sum the input shapes along ``self.axis``.

        The result is ``None`` on that axis if any input's size there is unknown.
        """
        output_shape = list(input_shape[0])

        for shape in input_shape[1:]:
            if output_shape[self.axis] is None or shape[self.axis] is None:
                output_shape[self.axis] = None
                break
            output_shape[self.axis] += shape[self.axis]

        return tuple(output_shape)

    def get_config(self):
        """Include ``axis`` so the layer can be rebuilt from its serialized config."""
        config = super(custom_concat, self).get_config()
        config.update({'axis': self.axis})
        return config


class BilinearUpsampling(Layer):
    """Resize a 4D ``(batch, H, W, C)`` tensor by integer factors with bilinear interpolation.

    With ``upsampling=(fh, fw)`` the output is ``(batch, fh*H, fw*W, C)``.
    H and W of the input must be statically known.
    """

    def __init__(self, upsampling=(2, 2), **kwargs):
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.upsampling = upsampling

    def compute_output_shape(self, input_shape):
        height = self.upsampling[0] * input_shape[1] if input_shape[1] is not None else None
        width = self.upsampling[1] * input_shape[2] if input_shape[2] is not None else None
        return (input_shape[0], height, width, input_shape[3])

    def call(self, inputs):
        new_size = (int(inputs.shape[1] * self.upsampling[0]),
                    int(inputs.shape[2] * self.upsampling[1]))
        # 'bilinear' is tf.image.resize's default method; spelled out for clarity.
        return tf.image.resize(inputs, new_size, method='bilinear')

    def get_config(self):
        """Include ``upsampling`` so the layer can be rebuilt from its serialized config."""
        config = super(BilinearUpsampling, self).get_config()
        config.update({'upsampling': self.upsampling})
        return config



def concat_pool(conv, pool, filter_num, strides=(2, 2)):
    """Downsample ``conv`` and join it with ``pool`` along the last axis.

    The downsampling is a strided 3x3 Conv2D -> BatchNormalization -> ReLU.
    """
    down = Conv2D(filter_num, (3, 3), strides=strides, padding='same',
                  kernel_initializer='he_normal', kernel_regularizer=kernel_regularizer)(conv)
    down = BatchNormalization()(down)
    down = Activation('relu')(down)
    return Concatenate()([down, pool])
######################################
from keras.optimizers import Adam
import keras.backend as K
#from custom_layer import *


def dice_coef(y_true, y_pred):
    """Smoothed (soft) Dice coefficient over all elements of the batch.

    Computes (2*sum(t*p) + 1) / (sum(t*t) + sum(p*p) + 1). The +1 terms keep it
    defined (equal to 1) when both tensors are all zeros.
    """
    t = K.flatten(y_true)
    p = K.flatten(y_pred)
    numerator = 2. * K.sum(t * p) + 1
    denominator = K.sum(t * t) + K.sum(p * p) + 1
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: one minus ``dice_coef``."""
    dice = dice_coef(y_true, y_pred)
    return 1. - dice

def CLCI_Net(input_shape=(224, 176, 1), num_class=1):
    """Build the CLCI-Net segmentation model.

    Encoder: four levels of ``conv_2_init`` + 2x2 max pooling. Each level
    concatenates its pooled features with strided-conv downsamples of every
    shallower level (``concat_pool``), then fuses them with a 1x1 conv.
    Bottleneck: ``conv_2_init`` + Dropout, followed by ``CLF_ASPP``.
    Decoder: four upsampling steps. Each merges the upsampled features with a
    1x1-projected skip connection through ``conv_lstm``.

    Args:
        input_shape: (rows, cols, channels). Rows and cols must be multiples of
            16 because of the four 2x2 poolings.
        num_class: 1 gives a sigmoid output; more than 1 gives a softmax output.

    Returns:
        An uncompiled Keras ``Model``.

    Layer creation order here determines layer names and the order in which
    saved weights are matched, so do not reorder statements.
    """
    # The row and col of input should be resized or cropped to an integer multiple of 16.
    inputs = Input(shape=input_shape)

    # Level 1: full resolution -> fused at 1/2.
    conv1 = conv_2_init(inputs, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    concat_pool11 = concat_pool(conv1, pool1, 32, strides=(2, 2))
    fusion1 = conv_1_init(concat_pool11, 64 * 4, kernel_size=(1, 1))

    # Level 2: 1/2 -> fused at 1/4 (conv1 brought down with stride 4).
    conv2 = conv_2_init(fusion1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    concat_pool12 = concat_pool(conv1, pool2, 64, strides=(4, 4))
    concat_pool22 = concat_pool(conv2, concat_pool12, 64, strides=(2, 2))
    fusion2 = conv_1_init(concat_pool22, 128 * 4, kernel_size=(1, 1))

    # Level 3: 1/4 -> fused at 1/8.
    conv3 = conv_2_init(fusion2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    concat_pool13 = concat_pool(conv1, pool3, 128, strides=(8, 8))
    concat_pool23 = concat_pool(conv2, concat_pool13, 128, strides=(4, 4))
    concat_pool33 = concat_pool(conv3, concat_pool23, 128, strides=(2, 2))
    fusion3 = conv_1_init(concat_pool33, 256 * 4, kernel_size=(1, 1))

    # Level 4: 1/8 -> fused at 1/16.
    conv4 = conv_2_init(fusion3, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    concat_pool14 = concat_pool(conv1, pool4, 256, strides=(16, 16))
    concat_pool24 = concat_pool(conv2, concat_pool14, 256, strides=(8, 8))
    concat_pool34 = concat_pool(conv3, concat_pool24, 256, strides=(4, 4))
    concat_pool44 = concat_pool(conv4, concat_pool34, 256, strides=(2, 2))
    fusion4 = conv_1_init(concat_pool44, 512 * 4, kernel_size=(1, 1))

    # Bottleneck at 1/16 resolution.
    conv5 = conv_2_init(fusion4, 512)
    conv5 = Dropout(0.5)(conv5)

    clf_aspp = CLF_ASPP(conv5, conv1, conv2, conv3, conv4, input_shape)

    # Decoder step 1: 1/16 -> 1/8, merged with conv4 via ConvLSTM.
    up_conv1 = UpSampling2D(size=(2, 2))(clf_aspp)
    up_conv1 = conv_1_init(up_conv1, 256, kernel_size=(2, 2))
    skip_conv4 = conv_1_init(conv4, 256, kernel_size=(1, 1))
    context_inference1 = conv_lstm(up_conv1, skip_conv4, channel=256)
    conv6 = conv_2_init(context_inference1, 256)

    # Decoder step 2: 1/8 -> 1/4, merged with conv3.
    up_conv2 = UpSampling2D(size=(2, 2))(conv6)
    up_conv2 = conv_1_init(up_conv2, 128, kernel_size=(2, 2))
    skip_conv3 = conv_1_init(conv3, 128, kernel_size=(1, 1))
    context_inference2 = conv_lstm(up_conv2, skip_conv3, channel=128)
    conv7 = conv_2_init(context_inference2, 128)

    # Decoder step 3: 1/4 -> 1/2, merged with conv2.
    up_conv3 = UpSampling2D(size=(2, 2))(conv7)
    up_conv3 = conv_1_init(up_conv3, 64, kernel_size=(2, 2))
    skip_conv2 = conv_1_init(conv2, 64, kernel_size=(1, 1))
    context_inference3 = conv_lstm(up_conv3, skip_conv2, channel=64)
    conv8 = conv_2_init(context_inference3, 64)

    # Decoder step 4: 1/2 -> full resolution, merged with conv1.
    up_conv4 = UpSampling2D(size=(2, 2))(conv8)
    up_conv4 = conv_1_init(up_conv4, 32, kernel_size=(2, 2))
    skip_conv1 = conv_1_init(conv1, 32, kernel_size=(1, 1))
    context_inference4 = conv_lstm(up_conv4, skip_conv1, channel=32)
    conv9 = conv_2_init(context_inference4, 32)


    # Per-pixel classifier: sigmoid for a single (binary) channel, softmax otherwise.
    if num_class == 1:
        conv10 = Conv2D(num_class, (1, 1), activation='sigmoid')(conv9)
    else:
        conv10 = Conv2D(num_class, (1, 1), activation='softmax')(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model


def CLF_ASPP(conv5, conv1, conv2, conv3, conv4, input_shape):
    """Cross-level-fusion ASPP head applied to the bottleneck ``conv5``.

    Builds nine 256-channel branches at ``conv5``'s resolution (input / 16),
    concatenates them and projects the result with a 1x1 conv to 1024 channels,
    followed by Dropout(0.5). The branches are:
      * b0: 1x1 conv of conv5;
      * b1-b3: 3x3 convs of conv5 dilated by 2, 4 and 6;
      * b4: image-level branch (average pool over the whole 1/16 map, 1x1 conv,
        bilinear upsampling back);
      * clf1-clf4: conv1..conv4 brought down to 1/16 with strided 3x3 convs.

    ``input_shape`` is only used to compute the 1/16 spatial size for b4.
    Layer creation order matters for weight loading; do not reorder.
    """

    b0 = conv_1_init(conv5, 256, (1, 1))
    b1 = dilate_conv(conv5, 256, dilation_rate=(2, 2))
    b2 = dilate_conv(conv5, 256, dilation_rate=(4, 4))
    b3 = dilate_conv(conv5, 256, dilation_rate=(6, 6))

    # Spatial size after four 2x2 poolings (e.g. 224x176 -> 14x11).
    out_shape0 = input_shape[0] // pow(2, 4)
    out_shape1 = input_shape[1] // pow(2, 4)
    b4 = AveragePooling2D(pool_size=(out_shape0, out_shape1))(conv5)
    b4 = conv_1_init(b4, 256, (1, 1))
    b4 = BilinearUpsampling((out_shape0, out_shape1))(b4)

    # Encoder features downsampled to the bottleneck resolution.
    clf1 = conv_1_init(conv1, 256, strides=(16, 16))
    clf2 = conv_1_init(conv2, 256, strides=(8, 8))
    clf3 = conv_1_init(conv3, 256, strides=(4, 4))
    clf4 = conv_1_init(conv4, 256, strides=(2, 2))

    outs = Concatenate()([clf1, clf2, clf3, clf4, b0, b1, b2, b3, b4])

    outs = conv_1_init(outs, 256 * 4, (1, 1))
    outs = Dropout(0.5)(outs)

    return outs

## Training

In [None]:
from keras.metrics import  Recall, Precision
# https://stats.stackexchange.com/questions/323154/precision-vs-recall-acceptable-limits
# https://www.kdnuggets.com/2016/12/4-reasons-machine-learning-model-wrong.html#:~:text=Precision%20is%20a%20measure%20of,positive%20class%20are%20actually%20true.&text=Hence%2C%20a%20situation%20of%20Low,positive%20values%20are%20never%20predicted.
# Pre and Post processing # https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
model = CLCI_Net()
# `lr` is a deprecated alias in Keras optimizers; `learning_rate` is the supported name.
model.compile(optimizer=Adam(learning_rate=1e-4), loss=dice_coef_loss,
              metrics=[dice_coef, 'acc', Recall(), Precision()])

  super().__init__(name, **kwargs)


In [None]:
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

checkpoint_filepath = '/content/drive/MyDrive/integradora_fiec/modelos/clcinet-native-filtered-v2-{epoch:03d}-{dice_coef:03f}-{val_dice_coef:03f}.h5'
# Keep only weights of epochs that improve validation Dice (higher is better).
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True)

# mode='max' is required. In 'auto' mode Keras treats a monitored quantity as
# "higher is better" only if its name contains 'acc', so 'val_dice_coef' would be
# minimised and the LR would be cut while Dice is still improving.
reduce_lr = ReduceLROnPlateau(monitor='val_dice_coef', mode='max', factor=0.2, patience=2, min_lr=2e-6)

callbacks = [
    model_checkpoint_callback,
    reduce_lr
]

In [None]:
# First training run: 60 epochs, batch size 8, validated on the held-out split.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_valid, y_valid),
    batch_size=8,
    epochs=60,
    callbacks=callbacks,
    verbose=1,
)

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60

Resume training after the Colab session timed out: reload the best checkpoint saved so far and continue training.

In [None]:
# Best checkpoint saved before the timeout (file name encodes epoch 029, dice_coef, val_dice_coef).
RESUME_WEIGHTS_PATH = "/content/drive/MyDrive/integradora_fiec/modelos/clcinet-native-filtered-v2-029-0.821825-0.787639.h5"
model.load_weights(RESUME_WEIGHTS_PATH)

In [None]:
# Resume from the epoch-29 checkpoint ('...-029-...h5') loaded before this cell.
# initial_epoch keeps Keras' epoch counter, and therefore the {epoch:03d} part of
# checkpoint file names, continuous with the interrupted run; epochs=60 still
# trains 60 - 29 = 31 more epochs. Only weights were checkpointed
# (save_weights_only=True), so optimizer state is not restored from the file.
history2 = model.fit(
      X_train, y_train,
      batch_size=8,
      initial_epoch=29,
      epochs=60,
      verbose=1,
      callbacks=callbacks,
      validation_data=(X_valid, y_valid))

Epoch 1/31
Epoch 2/31
Epoch 3/31
Epoch 4/31
Epoch 5/31
Epoch 6/31
Epoch 7/31
Epoch 8/31
Epoch 9/31
Epoch 10/31
Epoch 11/31
Epoch 12/31
Epoch 13/31
Epoch 14/31
Epoch 15/31
Epoch 16/31
Epoch 17/31
Epoch 18/31
Epoch 19/31
Epoch 20/31
Epoch 21/31
Epoch 22/31
Epoch 23/31
Epoch 24/31
Epoch 25/31
Epoch 26/31
Epoch 27/31
Epoch 28/31
Epoch 29/31
Epoch 30/31
Epoch 31/31

## Next steps...

In [None]:
# Good result: the resumed run reached about the same val_dice_coef as the previous run (~0.84).
# Next: apply the same preprocessing to all images (or only these 20?) from both the Lacunar and MCA groups.
# Build a dataset (.npy).
# Extract features from this dataset (.csv).
# Evaluate model performance.
#