In [12]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""

from mriqc_learn.datasets import load_dataset
from mriqc_learn.models.preprocess import SiteRobustScaler

(train_x, train_y), _ = load_dataset(split_strategy="none")
train_x = train_x.drop(columns=[
    "size_x",
    "size_y",
    "size_z",
    "spacing_x",
    "spacing_y",
    "spacing_z",
])
numeric_columns = train_x.columns.tolist()
train_x["site"] = train_y.site

# Harmonize between sites
scaled_x = SiteRobustScaler(unit_variance=True).fit_transform(train_x)

train_y = train_y[["rater_3"]].values.squeeze().astype(int)
print(f"Discard={100 * (train_y == -1).sum() / len(train_y):.2f}%.")
print(f"Doubtful={100 * (train_y == 0).sum() / len(train_y):.2f}%.")
print(f"Accept={100 * (train_y == 1).sum() / len(train_y):.2f}%.")
train_y += 1
train_y = (~train_y.astype(bool)).astype(int)
print(100 * train_y.sum() / len(train_y))  # Let's double check we still have 15% discard

Discard=14.17%.
Doubtful=1.54%.
Accept=84.29%.
14.168937329700272


In [13]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier as RFC
from mriqc_learn.models.preprocess import DropColumns

model = Pipeline([
    ("drop_site", DropColumns(drop=["site"])),
    ("pca", PCA(n_components=4)),
    (
        "rfc", 
        RFC(
          bootstrap=True,
          class_weight=None,
          criterion="gini",
          max_depth=10,
          max_features="sqrt",
          max_leaf_nodes=None,
          min_impurity_decrease=0.0,
          min_samples_leaf=10,
          min_samples_split=10,
          min_weight_fraction_leaf=0.0,
          n_estimators=400,
          oob_score=True,
      ),
    ),
])

In [14]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""
from mriqc_learn.model_selection import split
from sklearn.model_selection import cross_val_score

# Define a splitting strategy
outer_cv = split.LeavePSitesOut(1, robust=True)

cv_score = cross_val_score(
    model,
    X=scaled_x,
    y=train_y,
    cv=outer_cv,
    scoring="roc_auc",
)
print(f"AUC (cross-validated, {len(cv_score)} folds): {cv_score.mean()} ± {cv_score.std()}")

AUC (cross-validated, 15 folds): 0.6248064225804376 ± 0.22580627083999089


In [15]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""
(test_x, test_y), _ = load_dataset(
    "ds030",
    split_strategy="none",
)
test_x = test_x.drop(columns=[
    "size_x",
    "size_y",
    "size_z",
    "spacing_x",
    "spacing_y",
    "spacing_z",
])
test_x["site"] = test_y.site

# Discard datasets with ghost
has_ghost = test_y.has_ghost.values.astype(bool)
test_y = test_y[~has_ghost]

# Harmonize between sites
scaled_test_x = SiteRobustScaler(unit_variance=True).fit_transform(
    test_x[~has_ghost]
)

test_y = test_y[["rater_1"]].values.squeeze().astype(int)
print(f"Discard={100 * (test_y == -1).sum() / len(test_y):.2f}%.")
print(f"Doubtful={100 * (test_y == 0).sum() / len(test_y):.2f}%.")
print(f"Accept={100 * (test_y == 1).sum() / len(test_y):.2f}%.")
test_y += 1
test_y = (~test_y.astype(bool)).astype(int)
print(100 * test_y.sum() / len(test_y))

Discard=12.90%.
Doubtful=66.36%.
Accept=20.74%.
12.903225806451612


In [16]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""
classifier = model.fit(X=train_x, y=train_y)
predicted_y = classifier.predict(scaled_test_x)

from sklearn.metrics import roc_auc_score as auc

print(f"AUC on DS030 is {auc(test_y, predicted_y)}.")

AUC on DS030 is 0.5.


In [17]:
"""
Taken from
https://www.nipreps.org/qc-book/auto-qc/classification.html
"""
from mriqc_learn.models.production import init_pipeline
from sklearn.metrics import classification_report

# Reload the datasets as MRIQC's model will want to see the removed columns
(train_x, sites), _ = load_dataset(split_strategy="none")
train_x["site"] = sites.site
(test_x, sites), _ = load_dataset("ds030", split_strategy="none")
test_x["site"] = sites.site

#print(test_x)
#print(test_y)

cv_score_3 = cross_val_score(
    init_pipeline(),
    X=train_x,
    y=train_y,
    cv=outer_cv,
    scoring="roc_auc",
    n_jobs=16,
)

print(f"AUC (cross-validated, {len(cv_score_3)} folds): {cv_score_3.mean()} ± {cv_score_3.std()}")

pred_y_3 = init_pipeline().fit(X=train_x, y=train_y).predict(test_x[~has_ghost])
print(f"AUC on DS030 is {auc(test_y, pred_y_3)}.")
print(classification_report(test_y, pred_y_3, zero_division=0))



AUC (cross-validated, 15 folds): 0.8653745780432359 ± 0.1755699843561357




AUC on DS030 is 0.6527777777777778.
              precision    recall  f1-score   support

           0       0.91      0.98      0.94       189
           1       0.75      0.32      0.45        28

    accuracy                           0.90       217
   macro avg       0.83      0.65      0.70       217
weighted avg       0.89      0.90      0.88       217



In [18]:
print(test_y)

[1 0 0 1 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 1 0 1 0 0 0 0
 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1
 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 1 0 0 1 0 0
 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [47]:
#load in the wrap_x and wrap_y csvs for testing
import pandas as pd
from mriqc_learn.models.production import init_pipeline
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)


wrap_x = "/fs5/p_masi/kimm58/MRIQC_experiments/wrap_anat_mriqc_iqms.csv"
wrap_y = "/fs5/p_masi/kimm58/MRIQC_experiments/wrap_y.csv"
wrap_x = pd.read_csv(wrap_x)
wrap_y = pd.read_csv(wrap_y)
wrap_x['site'] = 'wrap'
wrap_y['site'] = 'wrap'
#remove rows from y if they do not appear in x
wrap_y = wrap_y[wrap_y["bids_name"].isin(wrap_x["bids_name"])]
#rename the bids_name column to subject_id
#wrap_x = wrap_x.rename(columns={"bids_name": "subject_id"})
#wrap_y = wrap_y.rename(columns={"bids_name": "subject_id"})

#make sure the columns are in the same order (ordered by bids name))
wrap_x = wrap_x.sort_values(by="bids_name")
wrap_y = wrap_y.sort_values(by="bids_name")
og_wrap_x = wrap_x.copy()
og_wrap_y = wrap_y.copy()
#reset the index
wrap_x = wrap_x.reset_index(drop=True)
wrap_y = wrap_y.reset_index(drop=True)

#drop the bids_name column
wrap_x = wrap_x.drop(columns=["bids_name"])

wrap_y = wrap_y[["rater_1"]].values.squeeze().astype(int)
print(f"Discard={100 * (wrap_y == -1).sum() / len(wrap_y):.2f}%.")
print(f"Doubtful={100 * (wrap_y == 0).sum() / len(wrap_y):.2f}%.")
print(f"Accept={100 * (wrap_y == 1).sum() / len(wrap_y):.2f}%.")
wrap_y += 1
wrap_y = (~wrap_y.astype(bool)).astype(int)
print(100 * wrap_y.sum() / len(wrap_y))

#do they prediction on WRAP
pipeline = init_pipeline().fit(X=train_x, y=train_y)
#pred_y_wrap = init_pipeline().fit(X=train_x, y=train_y).predict(wrap_x)
pred_y_wrap = pipeline.predict(wrap_x)
print(f"AUC on WRAP is {auc(wrap_y, pred_y_wrap)}.")
print(classification_report(wrap_y, pred_y_wrap, zero_division=0))

Discard=0.13%.
Doubtful=2.50%.
Accept=97.37%.
0.1282051282051282
AUC on WRAP is 0.7487163029525031.
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      1558
           1       0.20      0.50      0.29         2

    accuracy                           1.00      1560
   macro avg       0.60      0.75      0.64      1560
weighted avg       1.00      1.00      1.00      1560



In [48]:
#print the confusion matrix
from sklearn.metrics import confusion_matrix
print(confusion_matrix(wrap_y, pred_y_wrap))

[[1554    4]
 [   1    1]]


In [49]:
#print the indices of the false positives, the false negatives, and the true negatives
fp = []
fn = []
tn = []
for i in range(len(wrap_y)):
    if wrap_y[i] == 1 and pred_y_wrap[i] == 0: #manual is labeled bad, model labes as good
        fp.append(i)
    elif wrap_y[i] == 0 and pred_y_wrap[i] == 1: #manual is labeled good, model labels as bad
        fn.append(i)
    elif wrap_y[i] == 1 and pred_y_wrap[i] == 1: #both labeled as bad
        tn.append(i)

print("False Positives:")
for i in fp:
    print(og_wrap_x.iloc[i]['bids_name'])
print("\nFalse Negatives:")
for i in fn:
    print(og_wrap_x.iloc[i]['bids_name'])
print("\nTrue Negatives:")
for i in tn:
    print(og_wrap_x.iloc[i]['bids_name'])

False Positives:
sub-wrap0860_ses-year4_acq-unknown_T1w

False Negatives:
sub-wrap0224_ses-baseline_acq-MPRAGE_T1w
sub-wrapL0121_ses-year2_acq-unknown_T1w
sub-wrapL0248_ses-baseline_acq-unknown_T1w
sub-wrapL0248_ses-year2_acq-SPGR_T1w

True Negatives:
sub-wrap0732_ses-year2_acq-SPGR_T1w


In [70]:
def get_slice(img, dim, slicenum):
    if dim==0:
        return img[slicenum,:,:]
    elif dim==1:
        return img[:,slicenum,:]
    else:
        return img[:,:,slicenum]

def get_aspect_ratio(dim, vox_dim):
    if dim == 0:
        vox_ratio = vox_dim[2]/vox_dim[1]
    elif dim == 1:
        vox_ratio = vox_dim[2]/vox_dim[0]
    elif dim == 2:
        vox_ratio = vox_dim[1]/vox_dim[0]
    return vox_ratio

def make_LAS(nii):
    """
    Given a nibabel image, return the LAS orientated image.
    """
    # Get the data
    data = nii.get_fdata()
    # Get the affine
    aff = nii.affine
    # Get the orientation
    orient = nib.aff2axcodes(aff)
    #print(orient)
    
    # Reorient to LAS
    if orient != ('L', 'A', 'S'):
        # Get the current orientation
        ornt = nib.orientations.io_orientation(aff)
        # Define the target orientation (LAS)
        target_ornt = nib.orientations.axcodes2ornt(('L', 'A', 'S'))
        # Get the transformation matrix
        transform = nib.orientations.ornt_transform(ornt, target_ornt)
        # Reorient the data
        data = nib.orientations.apply_orientation(data, transform)
        # Update the affine
        aff = nib.orientations.inv_ornt_aff(transform, data.shape)
    
    # Create a new Nifti image
    new_nii = nib.Nifti1Image(data, aff)
    return new_nii

#for SLANT segmentation
def setup_and_make_png_seg(imgfile, atlas_name, outfile, axis=-1, slice=None, return_slice=False, save=False):
    #load in image file
    img = make_LAS(nib.load(imgfile))
    imgdata = np.squeeze(img.get_fdata()[:,:,:])
    if len(imgdata.shape) == 4:
        imgdata = imgdata[:,:,:,0]

    #create the png
    if axis == -1:
        create_seg_png(imgdata, outfile, atlas_name, img.header)
    else:
        img_slice = create_seg_png_slice(imgdata, outfile, atlas_name, img.header, save=save, axis=axis, slice=slice, return_slice=return_slice)
        if return_slice:
            #print("Returning slice", img_slice.shape)
            return img_slice

def create_seg_png(imgdata, outfile, atlas_name, imghd, save=False):
    #create the plt figure
    f, ax = plt.subplots(3,3,figsize=(10.5, 8), constrained_layout=True)
    imgdata = np.clip(imgdata, 0, np.percentile(imgdata, 99))
    #loop through sag, coronal, axial slices
    for dim in range(3):
        #get the center slice
        slice=imgdata.shape[dim]//2
        for i,slice_offput in enumerate(range(-10, 20, 10)):
            #get the aspect ratio for plotting purposes
            vox_dims = imghd.get_zooms()
            ratio = get_aspect_ratio(dim, vox_dims)
            #get the slices we want to show
            try:
                img_slice = np.rot90(get_slice(imgdata, dim, slice+slice_offput), k=1)
                #seg_slice = np.rot90(get_slice(segdata, dim, slice+slice_offput), k=1)
            #if this doesnt work, get the maximum slice in that dimension
            except:
                if i == 0:
                    img_slice = np.rot90(get_slice(imgdata, dim, 0), k=1)
                    #seg_slice = np.rot90(get_slice(segdata, dim, 0), k=1)
                elif i == 2:
                    img_slice = np.rot90(get_slice(imgdata, dim, -1), k=1)
                    #seg_slice = np.rot90(get_slice(segdata, dim, -1), k=1)
            #plot the slices
            ax[dim,i].imshow(img_slice, cmap='gray', aspect=ratio)
            #ax[dim,i].imshow(seg_slice, alpha=0.6, cmap=cmap, aspect=ratio, interpolation='nearest')
            ax[dim,i].axis('off')
    
    if save:
        #save the slices
        #print(outdir)
        f.suptitle(atlas_name)
        f.savefig(outfile, bbox_inches='tight')
        #close the figure
        plt.close(f)

def create_seg_png_slice(imgdata, outfile, atlas_name, imghd, axis=0, save=False, slice=None, return_slice=False):
    #create the plt figure
    f, ax = plt.subplots(1,1,figsize=(5, 5), constrained_layout=True)
    imgdata = np.clip(imgdata, 0, np.percentile(imgdata, 99))
    
    if not slice:
        slice=imgdata.shape[axis]//2
    else:
        slice = slice
    vox_dims = imghd.get_zooms()
    ratio = get_aspect_ratio(axis, vox_dims)
    img_slice = np.rot90(get_slice(imgdata, axis, slice), k=1)
    ax.imshow(img_slice, cmap='gray', aspect=ratio)
    ax.axis('off')

    if return_slice and not save:
        return img_slice
    if save:
        #save the slices
        #print(outdir)
        f.suptitle(atlas_name)
        f.savefig(outfile, bbox_inches='tight')
        #close the figure
        plt.close(f)

#create the confusion matrix/table

#first plot slices of the brains that you want to show (a false positive, a false negative, and a true negative)
    #maybe have all the false negatives as a supplementary figure

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib

# true negative
true_negative = "/nfs2/harmonization/BIDS/WRAP/sub-wrap0732/ses-year2/anat/sub-wrap0732_ses-year2_acq-SPGR_T1w.nii.gz"
tn1_slice = setup_and_make_png_seg(true_negative, "sub-wrap0732_ses-year2_acq-SPGR_T1w.nii.gz", "true_negative.png", axis=0, slice=135, return_slice=True, save=True)

# false positive
false_positive = "/nfs2/harmonization/BIDS/WRAP/sub-wrap0860/ses-year4/anat/sub-wrap0860_ses-year4_acq-unknown_T1w.nii.gz"
fp1_slice = setup_and_make_png_seg(false_positive, "sub-wrap0860_ses-year4_acq-unknown_T1w.nii.gz", "false_positive.png", axis=2, slice=135, return_slice=True, save=True)

#false negative examples
fn1 = "/nfs2/harmonization/BIDS/WRAP/sub-wrap0224/ses-baseline/anat/sub-wrap0224_ses-baseline_acq-MPRAGE_T1w.nii.gz"
fn2 = "/nfs2/harmonization/BIDS/WRAP/sub-wrapL0121/ses-year2/anat/sub-wrapL0121_ses-year2_acq-unknown_T1w.nii.gz"
fn3 = "/nfs2/harmonization/BIDS/WRAP/sub-wrapL0248/ses-baseline/anat/sub-wrapL0248_ses-baseline_acq-unknown_T1w.nii.gz"
#fn4 = "/nfs2/harmonization/BIDS/WRAP/sub-wrapL02/ses-year2/anat/sub-wrapL0260_ses-year2_acq-unknown_T1w.nii.gz"
fn1_slice = setup_and_make_png_seg(fn1, "sub-wrap0224_ses-baseline_acq-MPRAGE_T1w.nii.gz", "false_negative1.png", axis=2, slice=160, return_slice=True, save=True)
fn2_slice = setup_and_make_png_seg(fn2, "sub-wrapL0121_ses-year2_acq-unknown_T1w.nii.gz", "false_negative2.png", axis=2, slice=135, return_slice=True, save=True)
fn3_slice = setup_and_make_png_seg(fn3, "sub-wrapL0248_ses-baseline_acq-unknown_T1w.nii.gz", "false_negative3.png", axis=2, slice=150, return_slice=True, save=True)



# f,ax = plt.subplots(1,5,figsize=(20,5))
# for i, slice in enumerate([tn1_slice, fp1_slice, fn1_slice, fn2_slice, fn3_slice]):
#     print(slice.shape)
#     ax[i].imshow(slice)
#     ax[i].axis('off')
