In [None]:
import os
import numpy as np
from pandas import DataFrame
import nibabel
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('white')
from ia_mri_tools.ia_mri_tools import coil_correction, signal_likelihood, textures
from sklearn.covariance import EmpiricalCovariance
from sklearn.linear_model import LogisticRegression

In [None]:
# Load the HCP subject's T1w and T2w images and the FreeSurfer aparc+aseg volume.
# `get_data()` is deprecated (removed in nibabel 5.0). Intensities are read with
# get_fdata as float32, which also keeps the later `t1 + t2` sums from overflowing
# if the files are stored as integers.
data_path = 'data/HCP/100307'
t1 = nibabel.load(os.path.join(data_path, 'T1w_acpc_dc.nii.gz')).get_fdata(dtype=np.float32)
t2 = nibabel.load(os.path.join(data_path, 'T2w_acpc_dc.nii.gz')).get_fdata(dtype=np.float32)
nx, ny, nz = t1.shape

# Load the Freesurfer segmentation. It is read through `dataobj` so the integer
# label values keep their stored dtype instead of being cast to float.
seg = np.asanyarray(nibabel.load(os.path.join(data_path, 'aparc+aseg.nii.gz')).dataobj)

# and label mapping: FreeSurfer label name -> integer index, read from the colour LUT.
# Data lines start with "<index> <name> ..."; blank lines and lines whose first token
# is not an integer (e.g. '#' comments) are skipped.
with open('../ia_mri_tools/freesurfer/FreeSurferColorLUT.txt', 'r') as f:
    labels = {}
    for line in f:
        bits = line.split()
        if len(bits) < 2:
            # Blank line or a line without a name column.
            continue
        try:
            index = int(bits[0])
        except ValueError:
            # Not a data line (comment/header); only this specific failure is skipped
            # so genuine errors are no longer silently swallowed by a bare except.
            continue
        labels[bits[1]] = index
            
# Make some ROI masks from the freesurfer labels.
# Every mask is a boolean array with the same shape as `seg`, so it can be used
# directly for boolean indexing.

# NOTE(review): 1000-1035 is presumably the left-hemisphere cortical (ctx-lh-*)
# parcel range of aparc+aseg, matching the Left-* structures below; confirm against the LUT.
gray = np.logical_and(seg >= 1000, seg <= 1035)
white = seg == labels['Left-Cerebral-White-Matter']
thalamus = seg == labels['Left-Thalamus-Proper']
caudate = seg == labels['Left-Caudate']
putamen = seg == labels['Left-Putamen']
pallidum = seg == labels['Left-Pallidum']
csf = seg == labels['CSF']

# Union of the ventricle labels. np.isin yields a boolean mask; the previous sum of
# int16 terms produced an integer 0/1 array, which indexes as positions, not as a mask.
ventricle = np.isin(seg, [labels['Left-Lateral-Ventricle'],
                          labels['Left-Inf-Lat-Vent'],
                          labels['3rd-Ventricle'],
                          labels['4th-Ventricle']])


In [None]:
# Conservative background mask, computed from the sum of the uncorrected T1 and T2 data:
# voxels whose signal likelihood is below 0.7 are treated as background.
background = signal_likelihood(t1 + t2) < 0.7

# Brain = logical complement of background, kept boolean so it is a valid mask.
# (`1 - background` silently produced an integer array instead.)
brain = ~background

In [None]:
# Estimate one correction map from the combined T1 + T2 signal, then apply the same
# map multiplicatively to both contrasts so they share a single intensity correction.
reference = t1 + t2
c = coil_correction(reference)

ct1, ct2 = c * t1, c * t2

In [None]:
# Multi-scale texture feature maps for the intensity-corrected T1 and T2 images.
# Both contrasts use identical options: whitening on, and the brain mask.
scales = [1, 2, 4, 8]
texture_options = dict(whiten=True, mask=brain)

t1_textures, t1_labels = textures(ct1, scales, basename='t1', **texture_options)
t2_textures, t2_labels = textures(ct2, scales, basename='t2', **texture_options)

# Number of texture channels per contrast (size of the last axis of the feature maps).
ns = t1_textures.shape[-1]

In [None]:
# Build a labelled training set of texture feature vectors, one row per voxel,
# for each tissue class. The class index n corresponds to class_names[n].
#
# NOTE(review): `select` is neither imported nor defined in this notebook; the import
# cell only brings in coil_correction, signal_likelihood and textures. Presumably it
# comes from ia_mri_tools and returns the T1 and T2 features of the masked voxels side
# by side. TODO: confirm and add it to the import cell, otherwise this raises NameError
# on a fresh kernel.
feature_maps = [t1_textures, t2_textures]

# Background and gray/white matter are very large, so they are subsampled
# (every 100th / 50th voxel) to keep the classes roughly balanced and the fit fast.
bkg_signal = select(feature_maps, background)[::100, :]
csf_signal = select(feature_maps, csf)
gray_signal = select(feature_maps, gray)[::50, :]
white_signal = select(feature_maps, white)[::50, :]
ventricle_signal = select(feature_maps, ventricle)

class_names = ['Background', 'CSF', 'Gray', 'White', 'Ventricle']
nclasses = len(class_names)

# Same order as class_names: the position in this list is the class label.
class_signals = [bkg_signal, csf_signal, gray_signal, white_signal, ventricle_signal]

Xtrain = np.vstack(class_signals)
# 1-D label vector (one label per row of Xtrain), the shape scikit-learn expects for y.
ytrain = np.concatenate([np.full(signal.shape[0], label, dtype=np.int16)
                         for label, signal in enumerate(class_signals)])


In [None]:
# Multinomial logistic regression on the concatenated T1/T2 texture features.
# `multi_class='multinomial'` is omitted because it is deprecated in scikit-learn;
# with more than two classes, the 'sag' solver already fits a multinomial model.
# random_state seeds the shuffling done by the stochastic 'sag' solver, for reproducibility.
csf_classifier = LogisticRegression(penalty='l2', solver='sag', max_iter=100, tol=1e-2,
                                    random_state=0)
# scikit-learn expects a 1-D label vector; ravel also accepts an (n, 1) column.
csf_classifier.fit(Xtrain, np.ravel(ytrain))

In [None]:
# Classify every voxel of one axial slice and show the result beside the
# intensity-corrected T1 and T2 images.
# NOTE(review): relies on `select`, which is not imported in this notebook (see the training cell).
z = 120
slice_signal = select([t1_textures[:, :, z, :], t2_textures[:, :, z, :]])

# One row per in-plane voxel. The column count is inferred (-1) so it matches however
# many texture features were selected (presumably ns per contrast), i.e. the feature
# count the classifier was trained on. A hard-coded 2 only works when ns == 1.
csf_pred = csf_classifier.predict(slice_signal.reshape([nx * ny, -1]))

fig, axes = plt.subplots(1, 3, figsize=[10, 5])
panels = [
    (ct1[:, ::-1, z].transpose(), 'T1 image', dict(cmap='gray')),
    (ct2[:, ::-1, z].transpose(), 'T2 image', dict(cmap='gray')),
    # Pin the colour range to the class indices so a class keeps the same colour
    # even when other classes are absent from this slice.
    (csf_pred.reshape([nx, ny])[:, ::-1].transpose(), 'Classification',
     dict(cmap='jet', vmin=0, vmax=nclasses - 1)),
]
for ax, (image, title, style) in zip(axes, panels):
    ax.imshow(image, **style)
    ax.axis('off')
    ax.set_title(title)
plt.show()

In [None]:
# Per-class probability maps for the same slice, one panel per class.
# predict_proba columns follow csf_classifier.classes_ (the labels 0..nclasses-1 used
# in training), so column n corresponds to class_names[n].
csf_proba = (csf_classifier
             .predict_proba(slice_signal.reshape([nx * ny, -1]))
             .reshape([nx, ny, nclasses]))

# Smallest square grid that holds one panel per class.
ndraw = int(np.ceil(np.sqrt(nclasses)))
plt.figure(figsize=[10, 10])
for n, name in enumerate(class_names):
    plt.subplot(ndraw, ndraw, n + 1)
    plt.imshow(csf_proba[:, ::-1, n].transpose(), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.title('P({})'.format(name))
plt.show()