# Prepare diffusion-weighted images for input into models

### Load libraries and image paths

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from nilearn.plotting import plot_anat
from nilearn.image import index_img

# Load the list of subject DWI file paths. ndmin=1 keeps a one-line list as a
# 1-element array; otherwise loadtxt returns a 0-d array, which cannot be
# iterated or indexed via .shape[0] by the cells below.
dwi_files = np.loadtxt('dwi_file_list.txt', dtype=str, ndmin=1)


def _bval_path(dwi_path):
    """Return the .bval path that sits next to a .nii.gz or .nii DWI file.

    Raises
    ------
    ValueError
        If ``dwi_path`` does not end in .nii.gz or .nii.
    """
    for ext in ('.nii.gz', '.nii'):
        if dwi_path.endswith(ext):
            return dwi_path[:-len(ext)] + '.bval'
    raise ValueError('Expected a .nii or .nii.gz file, got ' + str(dwi_path))


# Matching list of b-value files, in the same order as dwi_files
bvals = [_bval_path(path) for path in dwi_files]

# Extract indices of image corresponding to b0 volumes
def getb0(bval_file, nonb0 = False, b0_values = (5, 0), shell_values = (995, 1000, 1005)):
    """Return the indices of volumes whose b-value is in a given set.

    Parameters
    ----------
    bval_file : str
        Path to a whitespace-separated text file of b-values, one per volume.
    nonb0 : bool
        If False, select the b0 volumes (b-values in ``b0_values``). If True,
        select diffusion-weighted volumes with b-values in ``shell_values``.
    b0_values : sequence of float
        b-values treated as b0. The b0 volumes in HCP data are stored with a
        b-value of 5, so both 5 and 0 are checked by default.
    shell_values : sequence of float
        b-values selected when ``nonb0`` is True. The default keeps only the
        b~1000 shell; higher shells are left out because their SNR is very low.

    Returns
    -------
    numpy.ndarray of int
        Volume indices grouped by b-value in the order listed in the selected
        set (not sorted), e.g. every b=5 index comes before every b=0 index.
    """
    # atleast_1d: loadtxt yields a 0-d array when the file holds a single value
    b = np.atleast_1d(np.loadtxt(bval_file))
    targets = shell_values if nonb0 else b0_values
    # Exact comparison: b-values must be stored exactly as listed in `targets`
    matches = [np.where(b == value)[0] for value in targets]
    if not matches:
        return np.array([], dtype=np.intp)
    return np.concatenate(matches)

# Map each DWI file path to the indices of the volumes that loadb0 will keep.
# NOTE(review): nonb0=True, so despite the "b0" naming used throughout this
# notebook these are the b~1000 diffusion-weighted volumes, not the b0 volumes.
# Confirm this is intended.
dwi2b0 = {dwi_file: getb0(bval_file, nonb0 = True) for dwi_file, bval_file in zip(dwi_files, bvals)}

### Visualize a random sample image

In [None]:
from scipy.ndimage import zoom

# Takes an image file and the map from filename to indices of b0 volumes to load the right image data
def loadb0(img_file, file2b_ind, as_type = 'nib', normalize = True, resize = False,
           target_shape = (96, 96, 64)):
    """Load selected volumes of a 4D image, optionally normalized and resized.

    Parameters
    ----------
    img_file : str
        Path to the 4D image; also the key looked up in ``file2b_ind``.
    file2b_ind : dict
        Maps each image path to an index array of volumes (last axis) to keep.
    as_type : {'nib', 'np'}
        Return a ``nib.Nifti1Image`` ('nib') or a plain numpy array ('np').
    normalize : bool
        Min-max scale each kept volume independently to [0, 1].
    resize : bool
        Resample the three spatial axes to ``target_shape`` when they differ.
    target_shape : tuple of 3 ints
        Spatial shape used when ``resize`` is True.

    Returns
    -------
    nib.Nifti1Image or numpy.ndarray
        Data of shape (x, y, z, n_selected). Images get an identity affine, so
        world coordinates equal voxel indices.

    Raises
    ------
    TypeError
        If ``as_type`` is not 'nib' or 'np'.
    KeyError
        If ``img_file`` is not a key of ``file2b_ind``.
    """
    if as_type not in ('nib', 'np'):
        raise TypeError('Incorrect option for as_type. Use either nib or np.')
    img = nib.load(img_file)
    # Keep only the volumes listed for this file
    data = img.get_fdata()[:, :, :, file2b_ind[img_file]]
    if normalize:
        vol_min = data.min(axis = (0, 1, 2))
        vol_range = data.max(axis = (0, 1, 2)) - vol_min
        # A constant volume has zero range; divide by 1 so it maps to zeros
        # instead of NaN
        vol_range[vol_range == 0] = 1
        data = (data - vol_min) / vol_range
    # Compare every spatial dim, not just the first, so partially mismatched
    # volumes are resized too
    if resize and data.shape[:3] != tuple(target_shape):
        factors = [t / s for t, s in zip(target_shape, data.shape[:3])] + [1]
        # Spline resampling (cubic by default) may overshoot [0, 1] slightly
        data = zoom(data, factors)
    if as_type == 'np':
        return data
    return nib.Nifti1Image(data, np.eye(4))

In [None]:
# Display every selected volume of one randomly chosen subject
idx = np.random.choice(dwi_files.shape[0], 1)[0]
vol_data = loadb0(dwi_files[idx], dwi2b0, resize = False).get_fdata()
# NOTE(review): this adds Gaussian noise (mean 0.1, std 0.01) to the last 15
# slices along the third axis before plotting. The reason is not stated here,
# so confirm it is intentional and not leftover debugging.
tail = vol_data[:, :, -15:, :]
vol_data[:, :, -15:, :] = tail + np.random.normal(0.1, 0.01, size = tail.shape)
dwi = nib.Nifti1Image(vol_data, np.eye(4))
num_vol = dwi.shape[-1]
for vol in range(num_vol):
    plt.close('all')
    fig, ax = plt.subplots(1, 1, figsize = (16, 4))
    # Narrow intensity window so the skull is clearly visible
    plot_anat(index_img(dwi, vol), axes = ax, vmin = -0.1, vmax = 0.3)
    plt.show()

### Data augmentation

We need data augmentation to expand our training set. Common options include small rotations, small translations, added noise, and changes in brightness and contrast. Below we define functions for rotation, translation, and noise, and test them on a sample image.

In [None]:
from scipy.ndimage import rotate, shift
from skimage.util import random_noise

def _as_array(image3d):
    """Return the voxel data of ``image3d`` as a numpy array.

    Objects exposing ``get_fdata`` (nibabel images) are read as floating point;
    anything else is converted with ``np.array``, which makes a copy.
    """
    # Duck-typed so any nibabel spatial image, not only Nifti1Image, is accepted
    if hasattr(image3d, 'get_fdata'):
        return image3d.get_fdata()
    return np.array(image3d)


def _like_input(image3d, data):
    """Return ``data`` wrapped the same way ``image3d`` was.

    Image inputs yield a new image of the same class carrying the input's
    affine; the augmentations below resample within the same voxel grid, so
    the affine stays valid. Array inputs yield ``data`` unchanged.
    """
    if hasattr(image3d, 'get_fdata'):
        return type(image3d)(data, image3d.affine)
    return data


def add_rotation(image3d, amount, axis):
    """Rotate a volume by ``amount`` degrees about one spatial axis.

    Parameters
    ----------
    image3d : nibabel image or array_like
        Volume to rotate.
    amount : float
        Rotation angle in degrees.
    axis : int
        Axis to rotate about: 0, 1 or 2 (or -3..-1). The rotation is applied in
        the plane of the other two spatial axes.

    Returns
    -------
    Same kind as ``image3d`` and the same shape (``reshape=False``). Content
    rotated outside the grid is lost, and uncovered voxels are filled with 0.

    Raises
    ------
    ValueError
        If ``axis`` is not a valid spatial axis.
    """
    if axis not in (0, 1, 2, -1, -2, -3):
        raise ValueError('axis must be 0, 1 or 2 (or -3..-1), got %r' % (axis,))
    img = _as_array(image3d)
    plane = tuple(a for a in (0, 1, 2) if a != axis % 3)
    rotated = rotate(img, amount, axes = plane, reshape = False)
    return _like_input(image3d, rotated)


def add_translation(image3d, amount, axis):
    """Translate a volume by ``amount`` voxels along ``axis``.

    Parameters
    ----------
    image3d : nibabel image or array_like
        Volume to translate; any number of dimensions.
    amount : float
        Shift in voxels; fractional values are spline-interpolated.
    axis : int
        Axis to shift along (negative values count from the end).

    Returns
    -------
    Same kind and shape as ``image3d``; voxels shifted in from outside are 0.

    Raises
    ------
    IndexError
        If ``axis`` is out of range for the input's dimensions.
    """
    img = _as_array(image3d)
    # One offset per array dimension so inputs with more than three dims also
    # work; only `axis` moves
    offsets = np.zeros(img.ndim)
    offsets[axis] = amount
    translated = shift(img, offsets)
    return _like_input(image3d, translated)


def add_noise(image3d, amount = 0.0001):
    """Add Gaussian noise with variance ``amount`` (not standard deviation).

    With ``clip=True``, skimage's ``random_noise`` clips the result to [0, 1]
    (or [-1, 1] when the input has negative values). The input is therefore
    expected to be intensity-normalized, e.g. by ``loadb0(normalize=True)``.
    """
    noisy = random_noise(_as_array(image3d), mode = 'gaussian', clip = True, var = amount)
    return _like_input(image3d, noisy)

In [None]:
# Apply each augmentation to the first selected volume of a random subject
# and show the results stacked under the original
idx = np.random.choice(dwi_files.shape[0], 1)[0]
dwi = loadb0(dwi_files[idx], dwi2b0, resize = True)
first_vol = index_img(dwi, 0)
print('Adding rotation')
dwi_rotated = add_rotation(first_vol, 10, 0)
print('Adding translation')
dwi_translated = add_translation(first_vol, 10, 0)
print('Adding noise')
dwi_noisy = add_noise(first_vol)

panels = [
    (first_vol, 'original'),
    (dwi_rotated, 'rotated 10 degrees'),
    (dwi_translated, 'translated 10 voxels'),
    (dwi_noisy, 'noisy'),
]
plt.close('all')
fig, ax = plt.subplots(len(panels), 1, figsize = (16, 14))
for panel_ax, (image, title) in zip(ax, panels):
    # Same intensity window and cut position for every panel so they compare directly
    plot_anat(image, axes = panel_ax, vmin = -0.1, vmax = 0.3, cut_coords = (80, 70, 60))
    panel_ax.set_title(title)
plt.show()

In [None]:
import os
from itertools import product

file_list = []


def _save_sample(path, volume):
    """Save one volume as .npy and record its path in file_list."""
    np.save(path, volume)
    file_list.append(path)


def _strip_nifti_ext(filename):
    """Drop a trailing .nii.gz or .nii from a file name; other names are returned unchanged."""
    for ext in ('.nii.gz', '.nii'):
        if filename.endswith(ext):
            return filename[:-len(ext)]
    return filename


n_files = dwi_files.shape[0]
for ind, file_path in enumerate(dwi_files):
    print('Processing %d of %d files' % ((ind + 1), n_files))
    # NOTE(review): assumes the subject ID is the second '/'-separated
    # component of each path
    subj = file_path.split('/')[1]
    dwi = loadb0(file_path, dwi2b0, as_type = 'np', resize = True)
    out_dir = 'hcp_resized/' + subj
    # Created once per file rather than once per volume; exist_ok replaces the
    # separate existence check
    os.makedirs(out_dir, exist_ok = True)
    stem = _strip_nifti_ext(file_path.split('/')[-1])
    # Each selected volume is saved as an independent training sample
    for b0 in range(dwi.shape[-1]):
        vol = dwi[:, :, :, b0]
        prefix = out_dir + '/' + stem + '_b0vol' + str(b0 + 1)
        _save_sample(prefix + '_original.npy', vol)
        # Rotations are numbered 1-6 over (amount, axis). Translations reuse the
        # same counter but skip axis 2, so they are numbered 1, 2, 4 and 5.
        for i, (amount, axis) in enumerate(product((-10, 10), (0, 1, 2)), start = 1):
            _save_sample(prefix + '_rotate' + str(i) + '.npy', add_rotation(vol, amount, axis))
            if axis != 2:
                _save_sample(prefix + '_translate' + str(i) + '.npy',
                             add_translation(vol, amount, axis))
        # Noise augmentation (add_noise) is currently disabled
print('Number of training samples: ' + str(len(file_list)))

np.savetxt('X_resized_files.txt', file_list, fmt = '%s')