# Prepare diffusion-weighted images for model input - NCANDA

### Load libraries and image paths

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from nilearn.plotting import plot_anat
from scipy.ndimage import zoom

# Training DWI file paths, one per line. ndmin=1 keeps a 1-D array even when
# the list holds a single path; otherwise loadtxt returns a 0-d array and the
# later dwi_files[0] / dwi_files.shape[0] lookups fail. Paths must not contain
# whitespace, because loadtxt splits each line on it.
dwi_files = np.loadtxt('ncanda_train_files.txt', dtype=str, ndmin=1)

# Load images
def loadb0(img_file, as_type='nib', normalize=True, resize=True):
    """Load an image file, optionally min-max normalize and downsample it.

    Parameters
    ----------
    img_file : str
        Path to an image readable by ``nib.load``.
    as_type : {'nib', 'np'}
        'np' returns the float numpy array; 'nib' wraps it in a
        ``Nifti1Image`` with an identity affine.
    normalize : bool
        Rescale intensities to [0, 1] using the min/max over the first three
        axes (per trailing index for inputs with more than three axes).
        A constant volume (max == min) maps to all zeros instead of NaN.
    resize : bool
        If the first axis is not 96 voxels long, zoom the first two axes by
        0.75 (e.g. 128 -> 96); every other axis keeps its size.

    Raises
    ------
    TypeError
        If ``as_type`` is neither 'nib' nor 'np'.
    """
    if as_type not in ('nib', 'np'):
        raise TypeError('Incorrect option for as_type. Use either nib or np.')
    img_fixed = nib.load(img_file).get_fdata()

    if normalize:
        lo = np.min(img_fixed, axis=(0, 1, 2))
        span = np.max(img_fixed, axis=(0, 1, 2)) - lo
        # Skip the division where the range is zero (constant volume), which
        # previously produced NaNs; those voxels are left at 0.
        img_fixed = np.divide(img_fixed - lo, span,
                              out=np.zeros_like(img_fixed), where=span != 0)

    if resize and img_fixed.shape[0] != 96:
        # Downsample in-plane only; one factor per array dimension so that
        # 4D input works too (3D input gets exactly (0.75, 0.75, 1)).
        factors = (0.75, 0.75) + (1,) * (img_fixed.ndim - 2)
        img_fixed = zoom(img_fixed, factors)

    if as_type == 'np':
        return img_fixed
    # Identity affine: world coordinates equal voxel indices; the file's own
    # affine is not carried over.
    return nib.Nifti1Image(img_fixed, np.eye(4))

### Visualize a sample image (native vs. resized resolution)

In [None]:
# Load the first listed volume twice (native resolution, then resized) and
# stack the two views vertically for comparison.
dwi1, dwi2 = (loadb0(dwi_files[0], resize=flag) for flag in (False, True))
fig, ax = plt.subplots(2, 1, figsize=(16, 8))
for panel, volume in zip(ax, (dwi1, dwi2)):
    plot_anat(volume, axes=panel, vmin=-0.1, vmax=0.3)
plt.show()

### Data augmentation

We need data augmentation to expand our training set. Options include small rotations, small translations, added noise, or changes in brightness and contrast. Below we implement rotation, translation, and Gaussian-noise functions and test each on a sample image.

In [None]:
from scipy.ndimage import rotate, shift
from skimage.util import random_noise

def add_rotation(image3d, amount, axis):
    """Rotate a volume by ``amount`` degrees about one of its first three axes.

    The rotation is applied in the plane spanned by the other two spatial
    axes via ``scipy.ndimage.rotate(..., reshape=False)``, so the output has
    the same shape as the input.

    Parameters
    ----------
    image3d : nibabel image or array_like
        Input volume. Objects exposing ``get_fdata`` (nibabel images) are read
        as float arrays; anything else is converted with ``np.array``.
    amount : float
        Rotation angle in degrees.
    axis : int
        Spatial axis (0, 1 or 2) to rotate about.

    Returns
    -------
    numpy.ndarray or image
        A numpy array for array input; otherwise a new image of the input's
        class carrying the input's affine.

    Raises
    ------
    ValueError
        If ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError('axis must be 0, 1 or 2, got %r' % (axis,))
    is_image = hasattr(image3d, 'get_fdata')
    data = image3d.get_fdata() if is_image else np.array(image3d)

    # Rotate within the plane of the two axes that are not the rotation axis.
    plane = tuple(ax for ax in (0, 1, 2) if ax != axis)
    rotated = rotate(data, amount, axes=plane, reshape=False)

    if is_image:
        # Keep the input's spatial metadata instead of resetting it to identity.
        return type(image3d)(rotated, image3d.affine)
    return rotated

def add_translation(image3d, amount, axis):
    """Shift a volume by ``amount`` voxels along one spatial axis.

    Uses ``scipy.ndimage.shift`` with its defaults (cubic spline; voxels that
    enter from outside the volume are filled with 0). Arrays with more than
    three axes (e.g. 4D) are supported; only the chosen spatial axis moves.

    Parameters
    ----------
    image3d : nibabel image or array_like
        Input volume. Objects exposing ``get_fdata`` (nibabel images) are read
        as float arrays; anything else is converted with ``np.array``.
    amount : float
        Shift in voxels (may be negative or fractional).
    axis : int
        Spatial axis 0, 1 or 2; -3..-1 index the three spatial axes from the
        end.

    Returns
    -------
    numpy.ndarray or image
        A numpy array for array input; otherwise a new image of the input's
        class carrying the input's affine.

    Raises
    ------
    IndexError
        If ``axis`` is outside -3..2.
    """
    if not -3 <= axis <= 2:
        raise IndexError('axis must be in the range -3..2, got %r' % (axis,))
    is_image = hasattr(image3d, 'get_fdata')
    data = image3d.get_fdata() if is_image else np.array(image3d)

    # shift() needs one offset per array dimension; only the chosen spatial
    # axis is non-zero, so trailing axes (e.g. 4D volumes) stay in place.
    offsets = np.zeros(data.ndim)
    offsets[axis % 3] = amount
    translated = shift(data, offsets)

    if is_image:
        # Keep the input's spatial metadata instead of resetting it to identity.
        return type(image3d)(translated, image3d.affine)
    return translated

def add_noise(image3d, amount=0.0001):
    """Add Gaussian noise with variance ``amount`` to a volume.

    Delegates to ``skimage.util.random_noise(mode='gaussian', clip=True)``,
    which clips the result to [0, 1] for non-negative input and to [-1, 1]
    otherwise. The noise is not seeded, so repeated calls differ.

    Parameters
    ----------
    image3d : nibabel image or array_like
        Input volume. Objects exposing ``get_fdata`` (nibabel images) are read
        as float arrays; anything else is converted with ``np.array``.
    amount : float
        Noise variance.

    Returns
    -------
    numpy.ndarray or image
        A numpy array for array input; otherwise a new image of the input's
        class carrying the input's affine.
    """
    is_image = hasattr(image3d, 'get_fdata')
    data = image3d.get_fdata() if is_image else np.array(image3d)

    noisy_img = random_noise(data, mode='gaussian', clip=True, var=amount)

    if is_image:
        # Keep the input's spatial metadata instead of resetting it to identity.
        return type(image3d)(noisy_img, image3d.affine)
    return noisy_img

# Pick one training volume at random and apply each augmentation once.
idx = np.random.choice(dwi_files.shape[0], 1)[0]
dwi = loadb0(dwi_files[idx])
print('Adding rotation')
dwi_rotated = add_rotation(dwi, 10, 0)        # 10 degrees about axis 0
print('Adding translation')
dwi_translated = add_translation(dwi, 10, 2)  # 10 voxels along axis 2
print('Adding noise')
dwi_noisy = add_noise(dwi)

# Four stacked panels, all cut at the same coordinates for comparison.
plt.close('all')
fig, ax = plt.subplots(4, 1, figsize=(16, 14))
panels = (
    (dwi, 'original'),
    (dwi_rotated, 'rotated 10 degrees'),
    (dwi_translated, 'translated 10 voxels'),
    (dwi_noisy, 'noisy'),
)
for panel_ax, (volume, title) in zip(ax, panels):
    plot_anat(volume, axes=panel_ax, vmin=-0.1, vmax=0.3, cut_coords=(50, 70, 40))
    panel_ax.set_title(title)
plt.show()

In [None]:
# Second augmentation preview: translate along axis 0 and view a different cut.
# loadb0 returns a single 3D volume, so it is augmented and plotted directly.
idx = np.random.choice(dwi_files.shape[0], 1)[0]
dwi = loadb0(dwi_files[idx], resize=True)
print('Adding rotation')
dwi_rotated = add_rotation(dwi, 10, 0)
print('Adding translation')
dwi_translated = add_translation(dwi, 10, 0)
print('Adding noise')
dwi_noisy = add_noise(dwi)

plt.close('all')
fig, ax = plt.subplots(4, 1, figsize=(16, 14))
panels = (
    (dwi, 'original'),
    (dwi_rotated, 'rotated 10 degrees'),
    (dwi_translated, 'translated 10 voxels'),
    (dwi_noisy, 'noisy'),
)
for panel_ax, (volume, title) in zip(ax, panels):
    plot_anat(volume, axes=panel_ax, vmin=-0.1, vmax=0.3, cut_coords=(80, 70, 60))
    panel_ax.set_title(title)
plt.show()

In [None]:
import os

# Build the augmented training set.
# Half of the training volumes are drawn at random WITHOUT replacement: drawing
# with replacement re-processed the same volume, overwrote its files and listed
# them more than once in X_files.txt. For each drawn volume the original plus
# 6 rotations (+/-10 degrees about each axis) and 4 translations (+/-10 voxels
# along axes 0 and 1) are saved as .npy under ncanda/traindata/<subject>/, and
# every saved path is recorded in file_list.
file_list = []
num_files = dwi_files.shape[0] // 2
file_idx = np.random.choice(dwi_files.shape[0], size=num_files, replace=False)
selected_files = dwi_files[file_idx]
for ind, file_path in enumerate(selected_files):
    print('Processing %d of %d files' % ((ind + 1), selected_files.shape[0]))
    # Subject ID: text after the last '_' of the 4th-from-last path component.
    # Volume number: text after the last '-' of the file-name stem.
    subj = file_path.split('/')[-4].split('_')[-1]
    filenum = file_path.split('/')[-1].split('.')[0].split('-')[-1]
    dwi = loadb0(file_path, as_type='np')

    subj_dir = 'ncanda/traindata/' + subj
    os.makedirs(subj_dir, exist_ok=True)
    # Volumes numbered 1 and 2 get the 'b0vol' prefix. Compare the parsed
    # number instead of substring-matching the path: "'1.nii.gz' in path" is
    # also true for volumes 11, 12, 21, 22, ...
    # NOTE(review): assumes names follow the '<name>-<number>.nii.gz' pattern
    # implied by the filenum parsing above; leading zeros are ignored.
    if filenum.lstrip('0') in ('1', '2'):
        prefix = subj_dir + '/b0vol' + filenum
    else:
        prefix = subj_dir + '/nonb0vol' + filenum

    # Original (unaugmented) volume.
    savename = prefix + '_original.npy'
    np.save(savename, dwi)
    file_list.append(savename)

    # Rotations about every axis; translations only along axes 0 and 1.
    # The suffix counter i runs over (amount, axis) pairs, so translation
    # files are numbered 1, 2, 4, 5 (axis 2 is skipped).
    i = 1
    for amount in [-10, 10]:
        for axis in [0, 1, 2]:
            dwi_preprocess = add_rotation(dwi, amount, axis)
            savename = prefix + '_rotate' + str(i) + '.npy'
            np.save(savename, dwi_preprocess)
            file_list.append(savename)
            if axis != 2:
                dwi_preprocess = add_translation(dwi, amount, axis)
                savename = prefix + '_translate' + str(i) + '.npy'
                np.save(savename, dwi_preprocess)
                file_list.append(savename)
            i += 1

    # Noise augmentation is currently disabled; uncomment to enable.
    # dwi_preprocess = add_noise(dwi)
    # savename = prefix + '_noisy.npy'
    # np.save(savename, dwi_preprocess)
    # file_list.append(savename)
print('Number of training samples: ' + str(len(file_list)))

np.savetxt('ncanda/X_files.txt', file_list, fmt='%s')