First create a direct access to /datasets folder in your personal drive

In [None]:
# Mount drive if needed
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive



## Install dependencies

In [1]:
# ! pip install SimpleITK
# ! pip install antspyx


## Load images to current session

In [2]:
# ! mkdir -v data
# ! unzip "/content/drive/MyDrive/integradora_fiec/datasets/NATIVE.zip" -d "/data"

## Preprocessing steps functions

In [1]:
%matplotlib inline
import os
import ants
import SimpleITK as sitk
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import cv2

print(f'AntsPy version = {ants.__version__}')
print(f'SimpleITK version = {sitk.__version__}')

AntsPy version = 0.3.8
SimpleITK version = 2.2.0


In [None]:
mni_T1_path = TEMPLATE_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a.nii'

def load_template_ants() -> ants.ANTsImage:
    template_img_ants = ants.image_read(TEMPLATE_PATH)
    return template_img_ants

def load_img_ants(path: str) -> ants.ANTsImage:
    raw_img_ants = ants.image_read(path)
    return raw_img_ants

def register_to_mni(img: ants.ANTsImage, mask: ants.ANTsImage) -> ants.ANTsImage:
    """register ants mri image and mask to mni space"""
    template_img = load_template_ants()
    transformation = ants.registration(fixed=template_img, moving=img, type_of_transform='SyN')

    img_registered = transformation['warpedmovout']
  
    mask_registered = ants.apply_transforms(fixed=template_img,moving=mask,transformlist=transformation['fwdtransforms'])
    return img_registered, mask_registered

## Register

In [None]:
from glob import glob

xpaths = sorted(glob(f'/data/NATIVE/*/*/*01.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE/*/*/*01_LesionSmooth.nii.gz'))
assert len(xpaths) == len(ypaths)

In [None]:
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

Number of samples: 292
/c0001s0004t01/c0001s0004t01.nii.gz | /c0001s0004t01/c0001s0004t01_LesionSmooth.nii.gz
/c0001s0005t01/c0001s0005t01.nii.gz | /c0001s0005t01/c0001s0005t01_LesionSmooth.nii.gz
/c0001s0006t01/c0001s0006t01.nii.gz | /c0001s0006t01/c0001s0006t01_LesionSmooth.nii.gz
/c0001s0007t01/c0001s0007t01.nii.gz | /c0001s0007t01/c0001s0007t01_LesionSmooth.nii.gz
/c0001s0008t01/c0001s0008t01.nii.gz | /c0001s0008t01/c0001s0008t01_LesionSmooth.nii.gz
/c0001s0012t01/c0001s0012t01.nii.gz | /c0001s0012t01/c0001s0012t01_LesionSmooth.nii.gz
/c0002s0001t01/c0002s0001t01.nii.gz | /c0002s0001t01/c0002s0001t01_LesionSmooth.nii.gz
/c0002s0002t01/c0002s0002t01.nii.gz | /c0002s0002t01/c0002s0002t01_LesionSmooth.nii.gz
/c0002s0003t01/c0002s0003t01.nii.gz | /c0002s0003t01/c0002s0003t01_LesionSmooth.nii.gz
/c0002s0004t01/c0002s0004t01.nii.gz | /c0002s0004t01/c0002s0004t01_LesionSmooth.nii.gz
/c0002s0005t01/c0002s0005t01.nii.gz | /c0002s0005t01/c0002s0005t01_LesionSmooth.nii.gz
/c0002s0007t01/c0002

In [None]:
for i,(xpath, ypath) in enumerate(zip(xpaths, ypaths)):
  folder = xpath[:-20]
  file_name = xpath[:-7][-13:]
  x_registered_path = folder + file_name + '_registered.nii.gz'
  y_registered_path = folder + file_name + '_LesionSmooth_registered.nii.gz'

  x3d = load_img_ants(xpath)
  y3d = load_img_ants(ypath)

  x3d_registered, y3d_registered = register_to_mni(img=x3d,mask=y3d)

  print(i, x_registered_path)
  print(i, y_registered_path)

  x3d_registered.to_file(x_registered_path)
  y3d_registered.to_file(y_registered_path)

  #if i == 0 : break


0 /data/NATIVE/c0001/c0001s0004t01/c0001s0004t01_registered.nii.gz
0 /data/NATIVE/c0001/c0001s0004t01/c0001s0004t01_LesionSmooth_registered.nii.gz
1 /data/NATIVE/c0001/c0001s0005t01/c0001s0005t01_registered.nii.gz
1 /data/NATIVE/c0001/c0001s0005t01/c0001s0005t01_LesionSmooth_registered.nii.gz
2 /data/NATIVE/c0001/c0001s0006t01/c0001s0006t01_registered.nii.gz
2 /data/NATIVE/c0001/c0001s0006t01/c0001s0006t01_LesionSmooth_registered.nii.gz
3 /data/NATIVE/c0001/c0001s0007t01/c0001s0007t01_registered.nii.gz
3 /data/NATIVE/c0001/c0001s0007t01/c0001s0007t01_LesionSmooth_registered.nii.gz
4 /data/NATIVE/c0001/c0001s0008t01/c0001s0008t01_registered.nii.gz
4 /data/NATIVE/c0001/c0001s0008t01/c0001s0008t01_LesionSmooth_registered.nii.gz
5 /data/NATIVE/c0001/c0001s0012t01/c0001s0012t01_registered.nii.gz
5 /data/NATIVE/c0001/c0001s0012t01/c0001s0012t01_LesionSmooth_registered.nii.gz
6 /data/NATIVE/c0002/c0002s0001t01/c0002s0001t01_registered.nii.gz
6 /data/NATIVE/c0002/c0002s0001t01/c0002s0001t01_Le

## Bias Field Correction

In [None]:
xpaths = sorted(glob(f'/data/NATIVE/*/*/*01_registered.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE/*/*/*01_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)

In [None]:
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

Number of samples: 292
t01/c0001s0004t01_registered.nii.gz | t01/c0001s0004t01_LesionSmooth_registered.nii.gz
t01/c0001s0005t01_registered.nii.gz | t01/c0001s0005t01_LesionSmooth_registered.nii.gz
t01/c0001s0006t01_registered.nii.gz | t01/c0001s0006t01_LesionSmooth_registered.nii.gz
t01/c0001s0007t01_registered.nii.gz | t01/c0001s0007t01_LesionSmooth_registered.nii.gz
t01/c0001s0008t01_registered.nii.gz | t01/c0001s0008t01_LesionSmooth_registered.nii.gz
t01/c0001s0012t01_registered.nii.gz | t01/c0001s0012t01_LesionSmooth_registered.nii.gz
t01/c0002s0001t01_registered.nii.gz | t01/c0002s0001t01_LesionSmooth_registered.nii.gz
t01/c0002s0002t01_registered.nii.gz | t01/c0002s0002t01_LesionSmooth_registered.nii.gz
t01/c0002s0003t01_registered.nii.gz | t01/c0002s0003t01_LesionSmooth_registered.nii.gz
t01/c0002s0004t01_registered.nii.gz | t01/c0002s0004t01_LesionSmooth_registered.nii.gz
t01/c0002s0005t01_registered.nii.gz | t01/c0002s0005t01_LesionSmooth_registered.nii.gz
t01/c0002s0007t01_re

In [None]:
def bias_field_correction(img: sitk.Image) -> sitk.Image:
    head_mask = sitk.RescaleIntensity(img, 0, 255)
    head_mask = sitk.LiThreshold(head_mask,0,1)

    shrinkFactor = 4
    inputImage = img
    inputImage = sitk.Shrink( img, [ shrinkFactor ] * inputImage.GetDimension() )
    maskImage = sitk.Shrink( head_mask, [ shrinkFactor ] * inputImage.GetDimension() )

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(inputImage, maskImage)

    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(img)
    result = img / sitk.Exp( log_bias_field ) # corrected img at full resolution

    # output of division has 64 pixel type, we cast it to float32 to keep compatibility
    result = sitk.Cast(result, sitk.sitkFloat32)
    
    return result

def load_img_sitk(path: str) -> sitk.Image:
    raw_img_sitk = sitk.ReadImage(path, sitk.sitkFloat32)
    return raw_img_sitk


In [None]:
for i,(xpath, ypath) in enumerate(zip(xpaths, ypaths)):
  folder = xpath[:-20]
  file_name = xpath[:-7][-13:]
  x_out_path = folder + file_name + '_BF.nii.gz'

  x3d = load_img_sitk(xpath)
  x3d_bf_corrected = bias_field_correction(x3d)

  sitk.WriteImage(x3d_bf_corrected, x_out_path)

  print(i, x_out_path)

  #if i == 0 : break



0 /data/NATIVE/c0001/c0001s0004t01/c0001s0004t01_registered_BF.nii.gz
1 /data/NATIVE/c0001/c0001s0005t01/c0001s0005t01_registered_BF.nii.gz
2 /data/NATIVE/c0001/c0001s0006t01/c0001s0006t01_registered_BF.nii.gz
3 /data/NATIVE/c0001/c0001s0007t01/c0001s0007t01_registered_BF.nii.gz
4 /data/NATIVE/c0001/c0001s0008t01/c0001s0008t01_registered_BF.nii.gz
5 /data/NATIVE/c0001/c0001s0012t01/c0001s0012t01_registered_BF.nii.gz
6 /data/NATIVE/c0002/c0002s0001t01/c0002s0001t01_registered_BF.nii.gz
7 /data/NATIVE/c0002/c0002s0002t01/c0002s0002t01_registered_BF.nii.gz
8 /data/NATIVE/c0002/c0002s0003t01/c0002s0003t01_registered_BF.nii.gz
9 /data/NATIVE/c0002/c0002s0004t01/c0002s0004t01_registered_BF.nii.gz
10 /data/NATIVE/c0002/c0002s0005t01/c0002s0005t01_registered_BF.nii.gz
11 /data/NATIVE/c0002/c0002s0007t01/c0002s0007t01_registered_BF.nii.gz
12 /data/NATIVE/c0002/c0002s0008t01/c0002s0008t01_registered_BF.nii.gz
13 /data/NATIVE/c0002/c0002s0009t01/c0002s0009t01_registered_BF.nii.gz
14 /data/NATIVE/

## Prepare training data

In [1]:
xpaths = sorted(glob(f'/data/NATIVE/*/*/*01_registered_BF.nii.gz') )
ypaths = sorted(glob(f'/data/NATIVE/*/*/*01_LesionSmooth_registered.nii.gz'))
assert len(xpaths) == len(ypaths)

NameError: name 'glob' is not defined

In [2]:
print("Number of samples:", len(xpaths))
for input_path, target_path in zip(xpaths, ypaths):
    print(input_path[-35:], "|", target_path[-48:])

NameError: name 'xpaths' is not defined

In [None]:
# load mni152 brain mask
TEMPLATE_BRAIN_MASK_PATH = '/content/drive/MyDrive/integradora_fiec/datasets/templates/mni_icbm152_t1_tal_nlin_sym_09a_mask.nii'
mni152_brain_mask = sitk.ReadImage(TEMPLATE_BRAIN_MASK_PATH, sitk.sitkFloat32)
mni152_T1 = sitk.ReadImage(TEMPLATE_PATH, sitk.sitkFloat32)


In [None]:
import numpy as np
def preprocess_ximg(ximg: sitk.Image, flipped = False) -> np.ndarray:
  x3d = sitk.HistogramMatching(ximg, mni152_T1)
  x3d = sitk.Multiply(x3d, mni152_brain_mask) # mask brain
  x3d = sitk.CurvatureAnisotropicDiffusion(x3d, conductanceParameter=1, numberOfIterations=1) # denoise a bit
  
  if flipped:
    x3d = sitk.Flip(x3d,(True, False, False))
  
  x3d = sitk.GetArrayFromImage(x3d)
  x3d = x3d[30:160,4:228,14:190] # crop to size -> (130, 224, 176)
  x3d = x3d / 255.0
  x3d = np.expand_dims(x3d,3) # add channel -> (130, 224, 176, 1)
  assert x3d.shape == (130,224,176,1)
  return x3d

def preprocess_yimg(yimg: sitk.Image, flipped=False) -> np.ndarray:
  y3d = yimg

  if flipped:
    y3d = sitk.Flip(y3d,(True, False, False))
  
  y3d = sitk.GetArrayFromImage(y3d)
  y3d = y3d[30:160,4:228,14:190] # crop to size -> (130, 224, 176)
  y3d = y3d / 255.0
  y3d = np.expand_dims(y3d,3) # add channel -> (130, 224, 176, 1)
  assert x3d.shape == (130,224,176,1)
  return y3d


In [None]:
ROW_SIZE = 224 # shapes of model inpput
COL_SIZE = 176

X = np.empty((0,ROW_SIZE,COL_SIZE,1), dtype=np.float32)
Y = np.empty((0,ROW_SIZE,COL_SIZE,1), dtype=np.float32)

for i,(xpath, ypath) in enumerate(zip(xpaths, ypaths)):

    ximg        =   sitk.ReadImage(xpath, sitk.sitkFloat32)
    x3d         =  preprocess_ximg(ximg) 
    #flipped_x3d =  preprocess_ximg(ximg, flipped=True)

    yimg        =   sitk.ReadImage(ypath, sitk.sitkFloat32)
    y3d         =  preprocess_yimg(yimg) 
    #flipped_y3d =  preprocess_yimg(yimg, flipped=True)
    
    #x3d = np.concatenate((x3d, flipped_x3d), axis=0)
    #y3d = np.concatenate((y3d, flipped_y3d), axis=0)

    #assert x3d.shape  == (260,224,176, 1)
    #assert y3d.shape  == (260,224,176, 1)

    X = np.concatenate((X, x3d), axis=0)
    Y = np.concatenate((Y, y3d), axis=0)

    print('.', end='')

In [None]:
print(X.shape, Y.shape)

(37960, 224, 176, 1) (37960, 224, 176, 1)


In [None]:
X[:,:,:,0].shape

(130, 224, 176)

## Double check slices

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_x2d_marked(x2d,y2d):
  dilation_level = 4
  m = (y2d).astype('uint8')
  m = sitk.GetImageFromArray(m)
  m = sitk.BinaryDilate(m,(dilation_level,1,1))
  m = sitk.BinaryContour(m)

  x2d_marked = sitk.GetImageFromArray(x2d)
  x2d_marked = sitk.MaskNegated(x2d_marked, sitk.Cast(m,sitk.sitkFloat32))
  x2d_marked = sitk.GetArrayFromImage(x2d_marked)
  return x2d_marked

def show_slices(slices: list[np.ndarray], cmap: str ='gray'):
  """ 
  Function to display a list of image slices (2D arrays). Optimal quantity is three slices.
  """
  fig, axes = plt.subplots(len(slices), 1, figsize=(15,15))
  for i, slice in enumerate(slices):
    axes[i].imshow(slice, cmap=cmap)

In [None]:
STEPS = 150
c=0
for i in range(0,len(X),STEPS):
  x, y = X[i], Y[i]
  if len(np.unique(y)) == 1:
    continue
  x2d_marked = get_x2d_marked(x[:,:,0],y[:,:,0])
  show_slices([x2d_marked,x[:,:,0]])
  c+=1
  if c==10:
    break

Output hidden; open in https://colab.research.google.com to view.

## Save training dataset as npy

In [None]:
from numpy import save
# contains data processed from 292 native ATLAS imgs trough: register to mni, bias field, histogram matching, brain extraction, denoise
X_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_X.npy'
Y_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_Y.npy'


save(X_output_path, X)
save(Y_output_path, Y)

## Load first train set [JUMP HERE IF DATA AVAILABLE]

In [None]:
from numpy import load
X_input_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_X.npy'
Y_input_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_Y.npy'

X = load(X_input_path)
Y = load(Y_input_path)

In [None]:
print(X.shape, Y.shape)

(37960, 224, 176, 1) (37960, 224, 176, 1)


## Make flip operation to double dataset size

In [None]:
import numpy as np

ROW_SIZE = 224 # shapes of model inpput
COL_SIZE = 176

# need to allocate array before hand for efficient usage of memory
X_flipped = np.zeros((37960,ROW_SIZE,COL_SIZE,1), dtype=np.float32)
Y_flipped = np.zeros((37960,ROW_SIZE,COL_SIZE,1), dtype=np.float32)

for i, idx in enumerate(range(0, len(X))):
  xtemp = np.flip(X[idx,:,:,0], axis=1)
  ytemp = np.flip(Y[idx,:,:,0], axis=1)
  
  #xtemp = np.expand_dims(xtemp,2) # add channel -> (224, 176, 1)
  #ytemp = np.expand_dims(ytemp,2)

  X_flipped[idx] = np.expand_dims(xtemp,2) 
  Y_flipped[idx] = np.expand_dims(ytemp,2) 

  #if i == 100 : break
  if i % 130 == 0:
    print('.', end='')

....................................................................................................................................................................................................................................................................................................

In [None]:
print(X_flipped.shape, Y_flipped.shape)

(37960, 224, 176, 1) (37960, 224, 176, 1)


In [None]:
from numpy import save
# contains data processed from 292 native ATLAS imgs trough: register to mni, bias field, histogram matching, brain extraction, denoise
X_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_FLIPPED_X.npy'
Y_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_FLIPPED_Y.npy'

save(X_output_path, X_flipped)
save(Y_output_path, Y_flipped)

## Double check slices flipped

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
def get_x2d_marked(x2d,y2d):
  dilation_level = 4
  m = (y2d).astype('uint8')
  m = sitk.GetImageFromArray(m)
  m = sitk.BinaryDilate(m,(dilation_level,1,1))
  m = sitk.BinaryContour(m)

  x2d_marked = sitk.GetImageFromArray(x2d)
  x2d_marked = sitk.MaskNegated(x2d_marked, sitk.Cast(m,sitk.sitkFloat32))
  x2d_marked = sitk.GetArrayFromImage(x2d_marked)
  return x2d_marked

def show_slices(slices: list[np.ndarray], cmap: str ='gray'):
  """ 
  Function to display a list of image slices (2D arrays). Optimal quantity is three slices.
  """
  fig, axes = plt.subplots(len(slices), 1, figsize=(15,15))
  for i, slice in enumerate(slices):
    axes[i].imshow(slice, cmap=cmap)

In [None]:
STEPS = 150
c=0
for i in range(0,len(X_flipped),STEPS):
  x, y = X_flipped[i], Y_flipped[i]
  if len(np.unique(y)) == 1:
    continue
  x2d_marked = get_x2d_marked(x[:,:,0],y[:,:,0])
  show_slices([x2d_marked,x[:,:,0]])
  c+=1
  if c==10:
    break

Output hidden; open in https://colab.research.google.com to view.

## Join normal and flipped 

In [None]:
import numpy as np

ROW_SIZE = 224 # shapes of model inpput
COL_SIZE = 176

# need to allocate array before hand for efficient usage of memory
X_doubled = np.zeros((37960*2,ROW_SIZE,COL_SIZE,1), dtype=np.float32)
Y_doubled = np.zeros((37960*2,ROW_SIZE,COL_SIZE,1), dtype=np.float32)

In [None]:
print(X_doubled.shape, Y_doubled.shape)

(75920, 224, 176, 1) (75920, 224, 176, 1)


In [None]:
X_doubled[0:37960] = X
Y_doubled[0:37960] = Y

X_doubled[37960:37960*2] = X_flipped
Y_doubled[37960:37960*2] = Y_flipped


In [None]:
print(X_doubled.shape, Y_doubled.shape)

(75920, 224, 176, 1) (75920, 224, 176, 1)


In [None]:
X_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_DOUBLED_X.npy'
Y_output_path = '/content/drive/MyDrive/integradora_fiec/datasets/paper lesions extended/dataset_clinet_input_processed_ALL_DOUBLED_Y.npy'

save(X_output_path, X_doubled)
save(Y_output_path, Y_doubled)

## Double check doubled

In [None]:
STEPS = 2438
c=0
for i in range(0,len(X_doubled),STEPS):
  x, y = X_doubled[i], Y_doubled[i]
  if len(np.unique(y)) == 1:
    continue
  x2d_marked = get_x2d_marked(x[:,:,0],y[:,:,0])
  show_slices([x2d_marked,x[:,:,0]])
  c+=1
  if c==10:
    break

Output hidden; open in https://colab.research.google.com to view.