# Model Fitting comparison

In [None]:
import sys
import numpy as np
# Seed NumPy's *global* RNG so any randomness drawn from it (dipy's kfold_xval
# appears to pick its fold split this way - confirm) is reproducible.
# np.random.RandomState(seed=...) would only build a throwaway generator and
# leave the global state untouched.
np.random.seed(2014)
import matplotlib.pyplot as plt
import dipy.reconst.cross_validation as xval
import dipy.reconst.dti as dti
import dipy.reconst.dki as dki
import scipy.stats as stats
from dipy.io.image import load_nifti
from dipy.core.gradients import gradient_table

#### Load the raw noisy data and the versions denoised with Patch2Self (P2S) and Patch2Self2 (P2S2)

In [None]:
# Each load_nifti call yields a (volume, affine) pair: the raw noisy data, then
# the Patch2Self and Patch2Self2 denoised versions. `affine` ends up holding the
# P2S file's affine (it overwrites the raw one); the P2S2 affine is discarded.
data, affine = load_nifti('data.nii.gz')
data_p2s, affine = load_nifti('denoised_patch2self_7T.nii.gz')
data_p2s2 = load_nifti('p2s2_denoised_7T_50K.nii.gz')[0]

In [None]:
# b-values and b-vectors for the 7T acquisition, read in that order.
bvals, bvecs = (np.loadtxt(fname) for fname in ('bvals_7T.bval', 'bvecs_7T.bvec'))

# b0_threshold=100 is the b-value cut-off at or below which dipy treats a
# volume as b=0.
gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=100)

#### Select slice index 80 along the third spatial axis (kept as a length-1 axis so the arrays stay 4-D)

In [None]:
# Keep index 80 of the third-from-last... i.e. third spatial axis (assumes 4-D
# (x, y, z, volume) arrays - confirm) as a length-1 axis (80:81 rather than 80)
# so every selection stays 4-D.
slice_sel = slice(80, 81)
data_slice, data_slice_p2s, data_slice_p2s2 = (
    volume[..., slice_sel, :] for volume in (data, data_p2s, data_p2s2)
)

### Mask the data using the Median Otsu algorithm from DIPY

In [None]:
from dipy.segment.mask import median_otsu

# Volumes used to build the mask (presumably the b0 volumes - TODO confirm
# against bvals). Only the mask output of median_otsu is kept.
otsu_volume_indices = [0, 1]
_, mask = median_otsu(data, vol_idx=otsu_volume_indices)

In [None]:
# Zero all volumes of every voxel that lies outside the brain mask on this slice.
# The slices come from basic slicing, so they are presumably NumPy views and the
# parent arrays are zeroed at those voxels too.
outside_brain = mask[..., 80:81] == 0
for sliced in (data_slice, data_slice_p2s, data_slice_p2s2):
    sliced[outside_brain] = 0

#### Set up the DTI and DKI models, whose cross-validated predictions are used to evaluate goodness of fit

In [None]:
# Diffusion tensor (DTI) and diffusion kurtosis (DKI) models sharing one gradient table.
dti_model, dki_model = dti.TensorModel(gtab), dki.DiffusionKurtosisModel(gtab)

#### Perform 2-fold cross-validation on the raw, P2S and P2S2 slices

In [None]:
def _xval_predictions(volume, n_folds=2):
    """Return the k-fold cross-validated (DTI, DKI) predictions for ``volume``.

    DTI is always evaluated before DKI, which keeps the sequence of calls fixed.
    That matters if the fold split draws from NumPy's global RNG.
    """
    return (xval.kfold_xval(dti_model, volume, n_folds),
            xval.kfold_xval(dki_model, volume, n_folds))


dti_slice, dki_slice = _xval_predictions(data_slice)
dti_slice_p2s, dki_slice_p2s = _xval_predictions(data_slice_p2s)
dti_slice_p2s2, dki_slice_p2s2 = _xval_predictions(data_slice_p2s2)

#### Compute the voxel-wise $R^2$ (squared Pearson correlation) between the raw data and each model's cross-validated prediction, skipping voxels where it is `nan`

In [None]:
# Voxel-wise R^2 (squared Pearson correlation) between the raw data slice and the
# cross-validated DTI / DKI predictions, visiting voxels in C order. Voxels whose
# correlation is NaN (e.g. constant, zeroed-out background signal) are dropped.
r2s_dti = np.array([
    stats.pearsonr(data_slice[vox], dti_slice[vox])[0] ** 2
    for vox in np.ndindex(dti_slice.shape[:3])
])
r2s_dti = r2s_dti[~np.isnan(r2s_dti)]

r2s_dki = np.array([
    stats.pearsonr(data_slice[vox], dki_slice[vox])[0] ** 2
    for vox in np.ndindex(dki_slice.shape[:3])
])
r2s_dki = r2s_dki[~np.isnan(r2s_dki)]

In [None]:
# Voxel-wise R^2 (squared Pearson correlation) between the RAW data slice and the
# cross-validated DTI / DKI predictions computed from the Patch2Self2-denoised
# slice. Voxels whose correlation is NaN are dropped.
# Each comprehension iterates over the shape of the prediction array it indexes.
r2s_dti_p2s2 = np.array([
    stats.pearsonr(data_slice[vox], dti_slice_p2s2[vox])[0] ** 2
    for vox in np.ndindex(dti_slice_p2s2.shape[:3])
])
r2s_dti_p2s2 = r2s_dti_p2s2[~np.isnan(r2s_dti_p2s2)]

r2s_dki_p2s2 = np.array([
    stats.pearsonr(data_slice[vox], dki_slice_p2s2[vox])[0] ** 2
    for vox in np.ndindex(dki_slice_p2s2.shape[:3])
])
r2s_dki_p2s2 = r2s_dki_p2s2[~np.isnan(r2s_dki_p2s2)]

In [None]:
# Voxel-wise R^2 (squared Pearson correlation) between the RAW data slice and the
# cross-validated DTI / DKI predictions computed from the Patch2Self-denoised
# slice. Voxels whose correlation is NaN are dropped.
# Each comprehension iterates over the shape of the prediction array it indexes.
r2s_dti_p2s = np.array([
    stats.pearsonr(data_slice[vox], dti_slice_p2s[vox])[0] ** 2
    for vox in np.ndindex(dti_slice_p2s.shape[:3])
])
r2s_dti_p2s = r2s_dti_p2s[~np.isnan(r2s_dti_p2s)]

r2s_dki_p2s = np.array([
    stats.pearsonr(data_slice[vox], dki_slice_p2s[vox])[0] ** 2
    for vox in np.ndindex(dki_slice_p2s.shape[:3])
])
r2s_dki_p2s = r2s_dki_p2s[~np.isnan(r2s_dki_p2s)]

### Collect all scores in a DataFrame for plotting

In [None]:
import pandas as pd

# One column of per-voxel R^2 scores per (data, model) combination; each column
# label names the array it holds.
# Each array had its NaN voxels removed independently, so lengths may differ.
# Wrapping them in Series pads shorter columns with NaN instead of making
# DataFrame construction fail.
df = pd.DataFrame({
    'Raw DTI': pd.Series(r2s_dti),
    'Raw DKI': pd.Series(r2s_dki),
    'Patch2Self2 DTI': pd.Series(r2s_dti_p2s2),
    'Patch2Self2 DKI': pd.Series(r2s_dki_p2s2),
    'Patch2Self DTI': pd.Series(r2s_dti_p2s),
    'Patch2Self DKI': pd.Series(r2s_dki_p2s),
})

#### Subtract the $R^2$ scores of the noisy data from those of P2S and P2S2, for both the DTI and DKI models

In [None]:
# Per-voxel R^2 improvement of each denoiser over the noisy data, per model:
# column label -> (denoised scores, noisy scores).
# NOTE(review): every r2s_* array had its NaNs removed independently, so the
# element-wise difference is only voxel-aligned if all arrays dropped exactly the
# same voxels. Equal shapes are checked here; alignment itself cannot be.
_diff_pairs = {
    '(P2S2 - Noisy) DTI': (r2s_dti_p2s2, r2s_dti),
    '(P2S - Noisy) DTI': (r2s_dti_p2s, r2s_dti),
    '(P2S2 - Noisy) DKI': (r2s_dki_p2s2, r2s_dki),
    '(P2S - Noisy) DKI': (r2s_dki_p2s, r2s_dki),
}
for _label, (_denoised, _noisy) in _diff_pairs.items():
    # Without this check a length mismatch gives an opaque broadcast error, or a
    # silently wrong result when one array happens to have length 1.
    if _denoised.shape != _noisy.shape:
        raise ValueError(
            f"{_label}: R^2 arrays have shapes {_denoised.shape} and "
            f"{_noisy.shape}; different voxels were dropped as NaN, so the "
            "scores cannot be subtracted voxel-by-voxel."
        )

df_diff = pd.DataFrame({
    label: denoised - noisy for label, (denoised, noisy) in _diff_pairs.items()
})

#### Make strip plots of the $R^2$ differences

In [None]:
%matplotlib qt
import seaborn as sns
sns.set(style="white")
ax = sns.stripplot(x="variable", y="value", data=pd.melt(df_diff), palette="Set2")