# Results: WSI Inference Analysis

In [None]:
# Imports
from pandas import read_csv, DataFrame
from glob import glob
from tqdm.notebook import tqdm
import numpy as np
from typing import Tuple, List
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from colorama import Fore, Style
from multiprocessing import Pool
import networkx as nx
from libpysal import weights
import seaborn as sns
from ipywidgets import interact, Dropdown
from matplotlib.patches import Rectangle

from os import makedirs
from os.path import join, isdir, isfile

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import cohen_kappa_score, confusion_matrix

from nft_helpers.utils import get_filename, load_yaml, imread
from nft_helpers.girder_dsa import get_tile_metadata, login
from nft_helpers.roi_utils import read_roi_txt_file

# Prepare parameters.
# Project configuration; the attributes read from it in this notebook are
# dsaURL, user, password, datadir and colors.
cf = load_yaml()
# Authenticated client handle that is passed to get_tile_metadata() below.
gc = login(join(cf.dsaURL, 'api/v1'), username=cf.user, password=cf.password)
# Print small floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

# All figures and CSVs produced by this notebook are written under save_dir.
save_dir = join(cf.datadir, 'results/wsi-inference')
makedirs(save_dir, exist_ok=True)

# Config colors are stored without the leading '#'; add it for matplotlib.
COLORS = [f'#{color}' for color in cf.colors]
# One line style per plotted series; the last entry is a matplotlib
# (offset, (on, off)) dash specification.
LINESTYLES = ['solid', 'dotted', 'dashed', 'dashdot', (5, (10, 3))]


def plot_cm(cm: np.array, labels: List[str], title: str = '', 
            figsize: Tuple[int, int] = (4, 4)):
    """Draw an annotated confusion-matrix heatmap on a new figure.
    
    Args:
        cm: Square confusion matrix, true classes on the rows and predicted
            classes on the columns.
        labels: Class names used for both the row and column tick labels.
        title: Text shown above the heatmap.
        figsize: (width, height) of the new figure.
        
    Returns:
        The matplotlib axis the heatmap was drawn on.
        
    """
    frame = DataFrame(cm, index=labels, columns=labels)
    axis_label_font = {'fontsize': 18, 'fontweight': 'bold'}
    heatmap_options = {
        'cmap': 'viridis', 'annot': True, 'cbar': False, 'fmt': ".0f",
        'square': True, 'linewidths': 1, 'linecolor': 'black',
        'annot_kws': {"size": 18},
    }
    
    plt.figure(figsize=figsize)
    ax = sns.heatmap(frame, **heatmap_options)
    
    # Hide the tick marks on both axes but keep the tick labels horizontal.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticks_position("none")
        
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=14)
    
    ax.set_ylabel('True', **axis_label_font)
    ax.set_xlabel('Predicted', **axis_label_font)
    ax.set_title(title, fontsize=14, fontweight='bold')
    
    return ax


def plot_tri_heatmap(data: np.array, labels: list, 
                     figsize: Tuple[int, int] = (7, 7), 
                     title: str = None, save_fp: str = None, **kwargs
                    ) -> 'plt.Axes':
    """Create a heatmap from a square array, plotting only the diagonal and
    the bottom-left triangle.
    
    Args:
        data: Data to plot, the data is assumed to be symmetrical so the 
            cells above the diagonal are masked out.
        labels: Labels on both axes, ordered from top to bottom and left to 
            right.
        figsize: (width, height) of figure.
        title: Title of figure.
        save_fp: Filepath to save figure to. The figure is saved as it looks
            when this function returns, so anything the caller adds 
            afterwards (e.g. axis labels) is not part of the saved file.
        kwargs: Keyword arguments to pass to seaborn.heatmap()
       
    Returns:
        The matplotlib axis the heatmap was drawn on.
        
    """
    fig, ax = plt.subplots(figsize=figsize)
    
    # k=1 keeps the diagonal visible and hides everything above it.
    mask = np.triu(np.ones_like(data), k=1)
    ax = sns.heatmap(data, annot=True, mask=mask, xticklabels=labels, 
                     yticklabels=labels, ax=ax, **kwargs)
    ax.set_xticks(ax.get_xticks(), labels, size=16)
    ax.set_yticks(ax.get_yticks(), labels, size=16, rotation=360)
    
    # Masked cells show the axis background, paint it black.
    ax.set_facecolor('k')
    
    if title is not None:
        ax.set_title(title, fontsize=18, weight='bold')
        
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=16)
    
    # The parameter was documented but previously ignored.
    if save_fp is not None:
        fig.savefig(save_fp, dpi=300, bbox_inches='tight')
        
    return ax

## Parameters
Parameters to apply in the entire notebook.

In [None]:
# Seed passed as random_state to every estimator and randomized search below,
# so that results are reproducible.
RANDOM_STATE = 64
# Number of imaging features kept by recursive feature elimination
# (n_features_to_select) before the random forest classifiers are tuned.
RFE_FEATURES = 20

## WSI Inference Time

In [None]:
# Map wsi filename to DSA id.
wsis = read_csv('csvs/wsis.csv')
wsi_ids = {get_filename(r.wsi_name): r.wsi_id for _, r in wsis.iterrows()}

# Compile time it took to run WSI inference - use a single cohort.
inf_dir = join(cf.datadir, 'wsi-inference/results/inference-cohort-1')

# Compile the results as dataframe.
# One row per log file; converted to a DataFrame after the loop.
wsi_inf_df = []

for log_fp in tqdm(glob(join(inf_dir, 'logs/*.txt'))):
    # WSI filename.
    # The same name is used to find the tissue mask and the prediction file.
    fn = get_filename(log_fp)
    
    with open(log_fp, 'r') as fh:
        time_logs = fh.readlines()
        
    # Get GPU info and times.
    gpus = None
    gpu_flag = False
    time_flag = False
    times = {}

    # Get GPU info.
    # Small state machine over the log lines:
    #   - the line right after a line starting with 'GPU' is the GPU name,
    #   - every line after a line starting with 'Times' is '<step>: <seconds>'.
    # NOTE(review): time_flag is never cleared, so every remaining line
    # (including a blank one) must contain ': ' or the unpacking fails.
    for ln in time_logs:
        if ln.startswith('GPU'):
            gpu_flag = True
        elif gpu_flag:
            gpus = ln.strip()
            gpu_flag = False
        elif ln.startswith('Times'):
            time_flag = True
        elif time_flag:
            time, seconds = ln.strip().split(': ')
            # Seconds -> whole minutes (fractions of a minute are truncated).
            times[time] = int(float(seconds) / 60)
            
    # Get info on the WSI
    wsi_id = wsi_ids[fn]
    wsi_metadata = get_tile_metadata(gc, wsi_id)
    w, h = wsi_metadata['sizeX'], wsi_metadata['sizeY']
    
    # Read the tissue mask.
    mask = imread(join(
        cf.datadir, 'wsi-inference/tissue-masks/masks', fn + '.png'
    ), grayscale=True)
    
    # Scale factor of the low res tissue mask.
    # Computed from the widths only and applied to both axes below.
    sf = w / mask.shape[1]
    
    # Get the number of tissue pixels as full resolution.
    num_pos = np.count_nonzero(mask) * (sf * sf)
    
    # Convert to area in millimeters squared.
    # mm_x / mm_y are presumably the pixel width / height in mm - TODO confirm
    tissue_area = num_pos * (wsi_metadata['mm_x'] * wsi_metadata['mm_y'])
    wsi_area = (w * h) * (wsi_metadata['mm_x'] * wsi_metadata['mm_y'])
    
    # Read the prediction info - mainly how many annotations there were.
    preds = read_roi_txt_file(join(inf_dir, 'inference', fn + '.txt'))
    
    # Add all the info.
    # The order here must match the column names given to DataFrame() below.
    wsi_inf_df.append([
        fn, wsi_id, gpus, times['Tiling'], times['Predicting'], 
        times['Merging predictions'], times['Cleaning up'], times['Total time'],
        wsi_area, tissue_area, len(preds)
    ])
    
# Note that the log keys 'Cleaning up' / 'Total time' are stored in the
# columns 'Clean up' / 'Total'.
wsi_inf_df = DataFrame(wsi_inf_df, columns=[
    'wsi_name', 'wsi_id', 'GPU', 'Tiling', 'Predicting', 
    'Merging predictions', 'Clean up', 'Total', 'WSI Area (mm x mm)', 
    'Tissue Area (mm x mm)', '# of Predictions'
])

In [None]:
# Report the extremes of the total inference time.
print(f'Slowest WSI inference time: {wsi_inf_df.Total.max()} minutes.')
print(f'Fastest WSI inference time: {wsi_inf_df.Total.min()} minutes.')

# Mean ± standard deviation of the total time, overall and per GPU model.
print('\nAverages of total times:')

gpu_groups = [
    ('all', wsi_inf_df),
    ('A4500', wsi_inf_df[wsi_inf_df.GPU == 'NVIDIA RTX A4500']),
    ('A5000', wsi_inf_df[wsi_inf_df.GPU == 'NVIDIA RTX A5000']),
]

for tag, gpu_df in gpu_groups:
    avg, spread = gpu_df.Total.mean(), gpu_df.Total.std()
    print(f'  ({tag}) {avg:.2f} ± {spread:.2f} minutes.')

In [None]:
# Plot scatter / line plots of the data.
# Sort by total time first: plot_times() draws the rows in dataframe order,
# so the 'Total' curve increases monotonically from left to right.
wsi_inf_df = wsi_inf_df.sort_values(by='Total')


def save_fig(fp: str, close: bool = True):
    """Write the current matplotlib figure to disk.
    
    The figure is saved at 300 dpi with a tight bounding box.
    
    Args:
        fp: Filepath to save figure to.
        close: If True, close the figure afterwards so it does not show.
        
    """
    plt.savefig(fp, dpi=300, bbox_inches='tight')
    
    if not close:
        return
    
    plt.close()


def plot_times(
    df: DataFrame, cols: list, figsize: Tuple[int, int] = (10, 5), 
    ylabel: str = 'Times (minutes)', xlabel: str = 'WSIs', title: str = None
):
    """Plot one line per column against the row position in the dataframe.
    
    Rows are plotted in the order they appear in df, at x = 1..len(df). Colors
    and line styles are taken from the notebook-level COLORS and LINESTYLES
    and are reused cyclically when there are more columns than entries.
    
    Args:
        df: Input dataframe.
        cols: List of columns to use to plot in y-axis.
        figsize: Size of figure.
        ylabel: Label on the y axis.
        xlabel: Label on the x axis.
        title: Figure title.
        
    Return:
        Figure axis.
        
    """    
    fig, ax = plt.subplots(figsize=figsize)
    x = np.arange(1, len(df) + 1)

    for i, col in enumerate(cols):
        # Modulo so that a long column list does not raise an IndexError.
        ax.plot(
            x, df[col].tolist(), c=COLORS[i % len(COLORS)], 
            linestyle=LINESTYLES[i % len(LINESTYLES)]
        )
        
    ax.set_xlim([0, len(df)])
    ax.legend(cols, fontsize=12)
    
    ax.set_ylabel(ylabel, fontsize=18, fontweight='bold')
    ax.set_xlabel(xlabel, fontsize=18, fontweight='bold')
    
    if title is not None:
        ax.set_title(title, fontsize=18, fontweight='bold')
        
    # Thick frame without the top border.
    ax.spines['top'].set_visible(False)
    
    for side in ('bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(3)
        
    ax.tick_params(axis='both', which='both', direction='out', length=10, 
                   width=3)
    plt.yticks(fontweight='bold', fontsize=16)
    plt.xticks(fontweight='bold', fontsize=16)
    
    return ax


# Columns of wsi_inf_df to draw, one line each.
times = ['Total', 'Tiling', 'Predicting', 'Merging predictions', 'Clean up']
ax = plot_times(wsi_inf_df, cols=times)

# Save BEFORE showing: once plt.show() returns the figure is no longer the
# current figure, so saving afterwards writes a new, empty figure.
save_fig(join(save_dir, 'wsi-inference-times.png'), close=False)
plt.show()

In [None]:
# Test whether each timing column is linearly correlated with the tissue area
# and with the number of predictions. Only the p-value of the Pearson test is
# reported.
correlates = (
    ('tissue area', 'Tissue Area (mm x mm)'),
    ('number of predictions', '# of Predictions'),
)

for idx, (description, column) in enumerate(correlates):
    # Blank line between the two reports, none before the first one.
    header = f'Pearson correlation with {description}.'
    print(header if idx == 0 else '\n' + header)
    
    for t in times:
        p = pearsonr(wsi_inf_df[t].tolist(), wsi_inf_df[column])[1]
        print(f'  {t} p-value = {p:.4f}')

## Load Imaging Features for Cases

In [None]:
# For each case - create a feature vector and save as dataframe.
# The features are cached in this CSV and reloaded from it on later runs.
case_fts_fp = join(save_dir, 'inference-features.csv')
# NOTE: rebinds inf_dir, which pointed at a single cohort's results in the
# inference-time cell above.
inf_dir = join(cf.datadir, 'wsi-inference')

cases = read_csv('csvs/cases.csv')
wsis = read_csv('csvs/wsis.csv')

# WSIs in the three cohorts.
wsis = wsis[wsis.cohort.isin((
    'Inference-Cohort-1', 'Inference-Cohort-2', 'External-Cohort'
))]

# Region order defines the order of the per-region feature columns.
regions = ['Hippocampus', 'Amygdala', 'Temporal cortex', 'Occipital cortex']
stages = [0, 1, 2, 3, 4, 5, 6]
# Cohort name -> dataset name used in the 'dataset' column. 'train' is only
# a name here: the classifiers below are FIT on 'Emory test' and evaluated on
# 'train' and 'UC Davis test'.
dataset_map = {
    'Inference-Cohort-1': 'train', 'Inference-Cohort-2': 'Emory test',
    'External-Cohort': 'UC Davis test'
}

# Radius in microns used to calculate average clustering coefficient.
# 150, 200, ..., 550 (the stop value 600 is excluded) -> 9 radii.
radii = np.arange(150, 600, 50)


def extract_case_features(case, fov_dir):
    """Build the feature vector of a single case.
    
    Reads the notebook-level `wsis`, `cases`, `regions`, `radii`, 
    `dataset_map`, `inf_dir` and `gc`.
    
    Args:
        case: Case identifier, matched against the `case` column of `wsis` 
            and `cases`.
        fov_dir: Directory used to cache, per WSI, the predictions that fall
            in the most populated field of view (FOV) of each class as 
            '0.csv' and '1.csv'.
            
    Returns:
        List with [dataset, case, stage, age, sex, ABC] followed, for every 
        region in `regions`, by: the class 0 (Pre-NFT) and class 1 (iNFT) 
        density per mm^2 of tissue, the class 0 and class 1 object count in 
        the most populated FOV, and the average clustering coefficient for 
        every radius in `radii` for class 0 and then class 1.
        
    Raises:
        Exception: If a WSI of the case has no predictions.
        
    """
    # WSIs in this case.
    case_wsis = wsis[wsis.case == case]
    
    # Case metadata.
    case_metadata = cases[cases.case == case].iloc[0]
    
    # Some cases have two hippocampus slides - take only one.
    if len(case_wsis) != 4:
        # I know that these cases all have a right hippocampus slide that we 
        # ignore (e.g. take the other one).
        case_wsis = case_wsis[case_wsis.region != 'Right hippocampus']
        
    # Rename hippocampus slides to a single unified name.
    case_wsis = case_wsis.replace({'Left hippocampus': 'Hippocampus', 
                                   'Right hippocampus': 'Hippocampus'})
    case_wsis = case_wsis.sort_values(by='region')
    
    # Get dataset this case belongs to.
    dataset = dataset_map[case_metadata.cohort]
    
    # Braak stage - ground truth. For cases with intermediate Braak stage 
    # 1-2, convert it to 2.
    stage = case_metadata.Braak_stage
    stage = 2 if stage == '1-2' else int(stage)
        
    # Add age, for calculation bind the 90+ to 90.
    age = case_metadata.age_at_death
    age = 90 if age == '90+' else int(age)
        
    # ABC score, -1 marks a value that cannot be converted to an integer 
    # (the notebook later keeps only rows with ABC >= 0).
    try:
        abc = int(case_metadata.ABC)
    except (TypeError, ValueError, OverflowError):
        abc = -1
        
    # Add demographics.
    case_fts = [
        dataset, case, stage, age, 0 if case_metadata.sex == 'female' else 1,
        abc
    ]
    
    fov_columns = ['label', 'x1', 'y1', 'x2', 'y2', 'geometry']
    
    for region in regions:
        # Get region WSI.
        wsi_row = case_wsis[case_wsis.region == region].iloc[0]
        fn = get_filename(wsi_row.wsi_name)
        
        # Get tile metadata.
        img_metadata = get_tile_metadata(gc, wsi_row.wsi_id)
        
        # FOV size in pixels for this WSI, FOV is 4mm^2 by area and it is a 
        # square FOV.
        fov_w = int(2 / img_metadata['mm_x'])
        fov_h = int(2 / img_metadata['mm_y'])
        
        # Calculate the radii in pixels (microns -> mm -> pixels).
        px_radii = [
            int(radius / 1000 / img_metadata['mm_x']) for radius in radii
        ]
        
        # Get tissue area from the tissue mask.
        # NOTE(review): the inference-time cell reads the same masks with 
        # grayscale=True. If imread returns a multi-channel array here, 
        # count_nonzero below counts every channel - confirm.
        mask = imread(join(inf_dir, f'tissue-masks/masks/{fn}.png'))
        
        # Get the denominator as the amount of tissue in mm x mm.
        h, w = mask.shape[:2]
        
        # Pixel area scale factor (low res -> high res)
        sf = (img_metadata['sizeY'] / h) * (img_metadata['sizeX'] / w)
        
        # Convert to scale factor in mm^2
        sf *= img_metadata['mm_x'] * img_metadata['mm_y']

        # Denominator is the mm^2 area that contains tissue.
        den = np.count_nonzero(mask) * sf
        
        # Read the NFTs detected.
        preds = read_roi_txt_file(join(
            inf_dir, 
            f'results/{wsi_row.cohort.lower()}-additional-rois/inference/'
            f'{fn}.txt'
        ))
        
        if not len(preds):
            raise Exception('No predictions found, logic not found for this.')
        
        # Add density of Pre-NFTs & iNFTs in the tissue.
        case_fts.append(len(preds[preds[:, 0] == 0]) / den)
        case_fts.append(len(preds[preds[:, 0] == 1]) / den)
        
        # Read the FOV info or calculate them.
        wsi_fov_dir = join(fov_dir, fn)
        fov_fps = {cls: join(wsi_fov_dir, f'{cls}.csv') for cls in (0, 1)}
        
        # Require both files, not just the directory, so that an interrupted 
        # earlier run is recomputed instead of failing in read_csv.
        if all(isfile(fp) for fp in fov_fps.values()):
            highest_fov = {cls: read_csv(fp) for cls, fp in fov_fps.items()}
        else:
            makedirs(wsi_fov_dir, exist_ok=True)
            
            # Find the FOV with the highest density of each type NFT
            # Convert preds to a geodataframe
            # NOTE(review): Point, GeoDataFrame and corners_to_polygon are 
            # not imported in the visible import cell of this notebook, so 
            # this branch needs them to be imported elsewhere.
            geopreds = []

            for pred in preds:
                label, x1, y1, x2, y2 = pred[:5]

                geopreds.append([
                    label, x1, y1, x2, y2, Point((x1 + x2) / 2, (y1 + y2) / 2)
                ])

            geopreds = GeoDataFrame(geopreds, columns=fov_columns)

            # Check FOVs with some overlap (half a FOV) to catch highest FOV.
            xys = [
                (x, y)
                for x in range(0, img_metadata['sizeX'], int(fov_w / 2))
                for y in range(0, img_metadata['sizeY'], int(fov_h / 2))
            ]

            # loop for each class
            highest_fov = {0: None, 1: None}

            for cls in (0, 1):
                # The class filter does not depend on the FOV, apply it once.
                cls_preds = geopreds[geopreds.label == cls]
                highest_within = 0
                
                for x, y in xys:
                    # Create the FOV polygon.
                    fov = corners_to_polygon(x, y, x + fov_w, y + fov_h)

                    # Calculate how many points are in the FOV
                    within = cls_preds[cls_preds.within(fov)]

                    if len(within) > highest_within:
                        highest_fov[cls] = within.copy()
                        highest_within = len(within)
                        
                if highest_fov[cls] is None:
                    highest_fov[cls] = DataFrame([], columns=fov_columns)
                    
                highest_fov[cls].to_csv(fov_fps[cls], index=False)
                        
        # Add number of objects in the most populated FOV.
        case_fts.append(len(highest_fov[0]))
        case_fts.append(len(highest_fov[1]))
            
        # Calculate average clustering coefficient for different radii.
        for cls in (0, 1):
            coordinates = np.array([
                [(obj.x1 + obj.x2) / 2, (obj.y1 + obj.y2) / 2]
                for _, obj in highest_fov[cls].iterrows()
            ])
            
            if not len(coordinates):
                # One value per radius, otherwise every feature after this 
                # one ends up under the wrong column name.
                case_fts.extend([0] * len(px_radii))
                continue
            
            for px_radius in px_radii:
                # https://networkx.org/documentation/stable/auto_examples/geospatial/plot_points.html
                dist = weights.DistanceBand.from_array(
                    coordinates, threshold=px_radius, silence_warnings=True
                )

                # Calculate average clustering of the graph.
                case_fts.append(nx.average_clustering(dist.to_networkx()))
        
    return case_fts


# Create location to save FOV files
fov_dir = join(save_dir, 'fovs')
makedirs(fov_dir, exist_ok=True)

# Features list: split imaging features from others.
cols = ['dataset', 'case', 'stage', 'age', 'sex', 'ABC']

# Names of the imaging features, in the exact order extract_case_features()
# appends them: per region the two densities, the two FOV counts and then the
# clustering coefficients for every radius of Pre-NFT followed by iNFT.
img_features = []

for region in regions:
    img_features.append(f'Pre-NFT density ({region})')
    img_features.append(f'iNFT density ({region})')
    img_features.append(f'Pre-NFT FOV count ({region})')
    img_features.append(f'iNFT FOV count ({region})')

    for lbl in ('Pre-NFT', 'iNFT'):
        for radius in radii:
            img_features.append(
                f'{lbl} Clustering Coef (r={radius}, {region})'
            )

# Load data from file or create it.
if isfile(case_fts_fp):
    data = read_csv(case_fts_fp)
else:
    with Pool(20) as pool:
        jobs = [
            pool.apply_async(
                func=extract_case_features,
                args=(case, fov_dir,)
            )
            for case in wsis.case.unique()
        ]

        # Wait for every job and collect the feature vector it returns.
        data = [job.get() for job in tqdm(jobs)]

    data = DataFrame(data, columns=cols + img_features)
    data.to_csv(case_fts_fp, index=False)

# Replace missing values in both branches, so that a first run (computed) and
# a later run (loaded from the CSV) hand the same table to the models.
data = data.fillna(0)

## Recursive Feature Selection
Reduce the number of features used to train the random forest classifier using recursive feature elimination (RFE).

In [None]:
# Recursive Feature Elimination (RFE) with a random forest classifier as the
# ranking estimator.
# The Braak-stage model is fit on the 'Emory test' dataset; the 'train' and
# 'UC Davis test' datasets are only used for evaluation further below.
train_data = data[data.dataset == 'Emory test']
X_train = train_data[img_features].to_numpy()
y_train = train_data.stage.tolist()

model = RandomForestClassifier(random_state=RANDOM_STATE)

# Keep the top RFE_FEATURES features - this number was chosen by manual
# testing.
rfe = RFE(model, n_features_to_select=RFE_FEATURES)

rfe = rfe.fit(X_train, y_train)

# Names of the features RFE kept (rfe.support_ is a boolean mask aligned
# with img_features).
selected_features = np.array(img_features)[rfe.support_].tolist()

# Training matrix restricted to the selected features.
X_train_selected = train_data[selected_features].to_numpy()

## Random Forest Classifier: Hyperparameter Tuning

In [None]:
# Hyperparameter tune Random Forest Classifier.
# Search space of the randomized search; it is reused for the ABC model
# further below.
params = {
    'n_estimators': [
        int(x) for x in np.linspace(start=200, stop=2000, num=10)
    ],
    # 'auto' is no longer accepted by recent scikit-learn releases and was
    # only an alias of 'sqrt' for classifiers, so it added no real
    # alternative to the search. 'log2' gives the search a second option.
    'max_features': ['sqrt', 'log2'],
    'max_depth': [None] + [int(x) for x in np.linspace(10, 110, num=11)],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'bootstrap': [True, False]
}

# Create model (Random forest classifier).
rfc = RandomForestClassifier(random_state=RANDOM_STATE, class_weight=None)

# Randomized search: 100 sampled parameter sets, 3-fold cross-validation,
# scored by balanced accuracy. The best set is refit on all training data.
gs_rfc = RandomizedSearchCV(
    rfc, params, scoring='balanced_accuracy', cv=3, n_jobs=20, verbose=0, 
    random_state=RANDOM_STATE, n_iter=100
)

# Search for best hyperparameters.
gs_rfc = gs_rfc.fit(X_train_selected, y_train)

# Best estimator.
rfc = gs_rfc.best_estimator_

## Best Features

In [None]:
# Plot the best features.
# Sort the (importance, feature) pairs by increasing importance, so the most
# important feature ends up last (top bar of the horizontal bar plot).
importances = rfc.feature_importances_.tolist()
importances, fts = (
    list(t) for t in zip(*sorted(zip(importances, selected_features)))
)

# Save the importance of every selected feature.
# NOTE(review): the file name is misspelled ('impotances'); it is kept as is
# because other code may already read the file under this name.
ft_importances_df = DataFrame({'Feature': fts, 'Importance': importances})
ft_importances_df.to_csv(join(save_dir, 'feature-impotances.csv'), index=False)
    
# only plot the top 10 features (the last 10 after the ascending sort)
n = len(importances)
importances = importances[-10:]
fts = fts[-10:]

fig, ax = plt.subplots(figsize=(4, 5))
y_pos = np.arange(len(fts))
ax.barh(y_pos, importances)
ax.set_yticks(y_pos, labels=fts)
plt.ylabel('Feature', fontweight='bold', fontsize=16)
plt.xlabel('Importance', fontweight='bold', fontsize=16)
ax.tick_params(axis='both', which='both', direction='out', length=10, width=2)
    
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_linewidth(2)
ax.spines['left'].set_linewidth(2)

# Save BEFORE showing: once plt.show() returns the figure is no longer the
# current figure, so saving afterwards writes a new, empty figure.
save_fig(join(save_dir, 'feature-importance.png'), close=False)
plt.show()

## Confusion Matrix (Emory-Train Cohort)

In [None]:
# Confusion matrix for Emory-Train cohort.
# Display names of the Braak stages; the position in the list is the numeric
# stage used in the data.
stages = ['0', 'I', 'II', 'III', 'IV', 'V', 'VI']
stage_ids = list(range(len(stages)))

# Predict stages on the Emory train cohort (test for Braak stages).
test_data = data[data.dataset == 'train']
X_test = test_data[selected_features].to_numpy()
y_test = test_data.stage.tolist()

y_test_pred = rfc.predict(X_test)

# Pass the full list of stages explicitly. Without it the matrix only has
# the stages that occur in the data, which breaks the 7 x 7 plot below, and
# the quadratic kappa weights are computed on the positions of the observed
# stages instead of on the stages themselves.
cm = confusion_matrix(y_test, y_test_pred, labels=stage_ids)
k = cohen_kappa_score(
    y_test, y_test_pred, weights='quadratic', labels=stage_ids
)

ax = plot_cm(cm, stages, title=f"Emory Train Cohort (k={k:.2f})")

# Highlight the diagonal (correct predictions).
for i in stage_ids:
    ax.add_patch(Rectangle((i, i), 1, 1, fill=False, edgecolor='green', lw=3, hatch='/'))

# Save BEFORE showing: once plt.show() returns the figure is no longer the
# current figure, so saving afterwards writes a new, empty figure.
save_fig(join(save_dir, 'Braak-cm-Emory-train.png'), close=False)
plt.show()

## Confusion Matrix (UC-Davis Cohort)

In [None]:
# Confusion matrix for UC-Davis cohort.
davis_data = data[data.dataset == 'UC Davis test']
X_davis = davis_data[selected_features].to_numpy()
y_davis = davis_data.stage.tolist()

y_davis_pred = rfc.predict(X_davis)

# Use the same label set for the matrix and for kappa. With labels=[2..6]
# every case whose true or predicted stage is 0 or 1 was silently left out
# of kappa, so the reported agreement did not describe the plotted matrix.
davis_labels = [0, 1, 2, 3, 4, 5, 6]
cm = confusion_matrix(y_davis, y_davis_pred, labels=davis_labels)
k = cohen_kappa_score(
    y_davis, y_davis_pred, weights='quadratic', labels=davis_labels
)

ax = plot_cm(cm, stages, title=f"UC Davis Cohort (k={k:.2f})")

from matplotlib.patches import Rectangle

# Highlight the diagonal (correct predictions).
for i in range(len(stages)):
    ax.add_patch(Rectangle((i, i), 1, 1, fill=False, edgecolor='green', lw=3, hatch='/'))

# Save BEFORE showing: once plt.show() returns the figure is no longer the
# current figure, so saving afterwards writes a new, empty figure.
save_fig(join(save_dir, 'Braak-cm-UC-Davis.png'), close=False)
plt.show()

In [None]:
# The UC Davis confusion matrix has one extreme error (true stage VI,
# predicted stage 0). Show the feature row of every such case; the index is
# reset so row i lines up with prediction i.
davis_rows = davis_data.reset_index(drop=True)

for i, r in davis_rows.iterrows():
    pred = y_davis_pred[i]
    
    if pred != 0 or r.stage != 6:
        continue
        
    display(r)


## Raters Agreement Heatmap

In [None]:
# Plot paired Braak stage Cohen's kappa heatmap.
wsis_df = read_csv('csvs/wsis.csv')
wsis_df = wsis_df[
    (wsis_df.cohort == 'Annotated-Cohort') & \
    (wsis_df.annotator_experience == 'expert')
]

experts = sorted(list(wsis_df.annotator.unique()))
raters = experts + ['ML']

# Add predicted Braak stage
# test_data is a filtered slice of `data` and keeps the original row labels.
# Reset them to 0..n-1 (this also returns a new frame), because the loop
# below addresses rows by position through .loc.
test_data = test_data.reset_index(drop=True)
test_data['ML'] = y_test_pred

# Add one column per expert with the Braak stage that expert assigned.
for expert in experts:
    expert_rows = wsis_df[wsis_df.annotator == expert]
    
    for i, case in enumerate(test_data.case.tolist()):
        stage = expert_rows[expert_rows.case == case].iloc[0].Braak_stage
        
        test_data.loc[i, expert] = int(stage)

# Build the kappa array.
kappa_hm = np.zeros((len(raters), len(raters)))

for i, r1 in enumerate(raters):
    for j, r2 in enumerate(raters):
        k = cohen_kappa_score(
            test_data[r1].tolist(), 
            test_data[r2].tolist(),
            weights='quadratic',
            labels=[0, 1, 2, 3, 4, 5, 6]
        )
        kappa_hm[i, j] = k
        
# Mask the top half.
# Values strictly below the diagonal: one kappa per pair of raters.
mask = np.triu(np.ones_like(kappa_hm))
kappas = kappa_hm[mask == 0]

kwargs = {'cmap': 'viridis', 'annot_kws': {"size":16}, 'linecolor': 'w', 
          'linewidths': 0}

# Anonymized tick labels (E1, E2, ...) in the same order as `raters`, derived
# from the number of experts instead of assuming there are exactly five.
rater_labels = [f'E{i + 1}' for i in range(len(experts))] + ['ML']

ax = plot_tri_heatmap(
    kappa_hm, 
    rater_labels, 
    figsize=(10,10), 
    title=f"Braak Stage Agreement (k={np.mean(kappas):.2f} " + u"\u00B1" + \
          f' {np.std(kappas):.2f})',
    **kwargs
)

plt.xlabel('Rater', fontsize=18, fontweight='bold')
plt.ylabel('Rater', fontsize=18, fontweight='bold')

# Save BEFORE showing: once plt.show() returns the figure is no longer the
# current figure, so saving afterwards writes a new, empty figure.
save_fig(join(save_dir, 'Braak-stage-agreement-hm.png'), close=False)
plt.show()

In [None]:
# Write a table of every imaging feature with a flag telling whether RFE kept
# it for training the random forest. Selected features are listed first.
used_rows = [[feature, 'Yes'] for feature in selected_features]
unused_rows = [
    [feature, 'No'] 
    for feature in img_features 
    if feature not in selected_features
]

ft_df = DataFrame(
    used_rows + unused_rows, columns=['Feature', 'Used for RF Training']
)
ft_df.to_csv(join(save_dir, 'features.csv'), index=False)

## Confusion Matrix for Braak stages from Raters against Ground Truth Stage

In [None]:
# Confusion matrices of each expert's Braak stage against the original
# (ground truth) Braak stage of the case.
iaa_df = read_csv('csvs/wsis.csv')

is_annotated = iaa_df.cohort == 'Annotated-Cohort'
is_expert = iaa_df.annotator_experience == 'expert'
iaa_df = iaa_df[is_annotated & is_expert]

iaa_cases = iaa_df.case.unique()

# Ground truth stage per case, in the order of iaa_cases. The intermediate
# stage '1-2' is counted as stage 2.
true = []

for case in iaa_cases:
    stage = cases[cases.case == case].iloc[0].Braak_stage
    stage = 2 if stage == '1-2' else int(stage)
    true.append(stage)


def _plot_cm(annotator):
    """Show the confusion matrix of one annotator's Braak stages against the 
    original stages.
    
    Reads the notebook-level `iaa_df`, `iaa_cases` and `true`.
    
    Args:
        annotator: Value of the `annotator` column whose stages are compared.
        
    """
    stage_ids = [0, 1, 2, 3, 4, 5, 6]
    
    # Stage assigned by this annotator to every case, in iaa_cases order.
    annotator_rows = iaa_df[iaa_df.annotator == annotator]
    pred = [
        int(annotator_rows[annotator_rows.case == case].iloc[0].Braak_stage)
        for case in iaa_cases
    ]
        
    k = cohen_kappa_score(true, pred, weights='quadratic', labels=stage_ids)

    # Confusion matrix.
    tick_names = [str(stage_id) for stage_id in stage_ids]
    cm = DataFrame(
        confusion_matrix(true, pred, labels=stage_ids), 
        index=tick_names, columns=tick_names
    )

    # Plot the confusion matrix.
    plt.figure(figsize=(4,4))
    ax = sns.heatmap(
        cm, cmap='viridis', annot=True, cbar=False, fmt=".0f", square=True, 
        linewidths=1, linecolor='black', annot_kws={"size": 18}
    )
    
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticks_position("none")

    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=14)
    plt.ylabel('True', fontsize=18, fontweight='bold')
    plt.xlabel('Pred', fontsize=18, fontweight='bold')
    plt.title(f"Confusion Matrix for Annotator {annotator}\nWeighted Cohen's Kappa: {k:.4f}", fontsize=14, fontweight='bold')
    plt.show()
    
    
    
# Dropdown with the expert annotators (sorted); _plot_cm is called again with
# the chosen annotator every time the selection changes.
_ = interact(
    _plot_cm, 
    annotator=Dropdown(options=sorted(list(iaa_df.annotator.unique())))
)

## Compare Pre-NFT / iNFT Density Between Cohorts

In [None]:
# Reshape the case features to long format for plotting: two rows per case
# (one for Pre-NFT, one for iNFT), each with the density and the FOV count of
# the four regions.
plot_df = []

for _, r in data.iterrows():
    for lbl in ('Pre-NFT', 'iNFT'):
        row = [r.dataset, r.case, r.stage, r.age, r.sex, lbl]
        row.extend(r[f'{lbl} density ({region})'] for region in regions)
        row.extend(r[f'{lbl} FOV count ({region})'] for region in regions)
        plot_df.append(row)
    
plot_columns = [
    'dataset', 'case', 'stage', 'age_at_death', 'sex', 'label', 
    'Density (Hippocampus)', 'Density (Amygdala)', 'Density (Temporal)',
    'Density (Occipital)', 'FOV count (Hippocampus)', 
    'FOV count (Amygdala)', 'FOV count (Temporal)', 'FOV count (occipital)'
]
plot_df = DataFrame(plot_df, columns=plot_columns)


def plot_grouped_plots(feature):
    """Plot a feature against Braak stage as grouped box plots, one panel for
    Pre-NFT and one for iNFT.
    
    Reads the notebook-level `plot_df` and `save_dir`. Boxes are grouped by 
    dataset and both panels share the y axis. The figure is written to 
    '<feature> bars.png' in `save_dir` unless that file already exists.
    
    Args:
        feature: Column of `plot_df` to plot on the y axis.
        
    """
    fig = plt.figure(figsize=(12, 4))
    first_ax = None
    
    for position, lbl in ((121, 'Pre-NFT'), (122, 'iNFT')):
        is_first = first_ax is None
        
        if is_first:
            ax = plt.subplot(position)
        else:
            ax = plt.subplot(position, sharey=first_ax)
            
        sns.boxplot(data=plot_df[plot_df.label == lbl], y=feature, x='stage', 
                    hue='dataset', ax=ax)
        
        plt.xlabel('Braak Stage', fontsize=16, fontweight='bold')
        plt.xticks(fontweight='bold', fontsize=12)
        plt.yticks(fontweight='bold', fontsize=12)
        plt.title(lbl, fontsize=16, fontweight='bold', y=1.15)
        
        if is_first:
            # Only the left panel carries the y label and the legend.
            plt.ylabel(feature, fontsize=16, fontweight='bold')
            plt.legend(ncol=3, fontsize=14, bbox_to_anchor=(0.55, 1.15), 
                       loc='upper center')
            first_ax = ax
        else:
            plt.ylabel(None)
            ax.get_legend().remove()
            
        ax.tick_params(axis='both', which='both', direction='out', length=10, 
                       width=2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['left'].set_linewidth(2)
    
    save_fp = join(save_dir, f'{feature} bars.png')
    
    # Saved before plt.show(); an existing file is never overwritten.
    if not isfile(save_fp):
        plt.savefig(save_fp, dpi=300, bbox_inches='tight')
        
    plt.show()
    
    
# Dropdown with the density / FOV count columns of plot_df;
# plot_grouped_plots is called again every time the selection changes. The
# options must match the plot_df column names exactly (note the lower-case
# 'occipital' in the last one).
_ = interact(
    plot_grouped_plots, 
    feature=Dropdown(options=[
        'Density (Hippocampus)', 'Density (Amygdala)', 'Density (Temporal)',
        'Density (Occipital)', 'FOV count (Hippocampus)', 
        'FOV count (Amygdala)', 'FOV count (Temporal)', 'FOV count (occipital)'
    ]))

In [None]:
# # For each stage group & class calculate statistical significance.
# densities = [
#     'Density (Hippocampus)', 'Density (Amygdala)', 'Density (Temporal)',
#     'Density (Occipital)'
# ]

# results = 'Statistical signficance between cohorts in groups.\n\n'

# cohorts = ['train', 'Emory test', 'UC Davis test']
# stages = [2, 3, 4, 5, 6]

# for density in densities:
#     for stage in stages:
#         results += f'{density} Braak stage {stage}:\n'
#         for lb in ['Pre-NFT', 'iNFT']:
#             lb_df = plot_df[(plot_df.label == lb) & (plot_df.stage == stage)]

#             results += f'  o {lb}\n'

#             # Calculate one-way ANOVA between the three groups.
#             group1 = lb_df[lb_df.dataset == 'train'][density]
#             group2 = lb_df[lb_df.dataset == 'Emory test'][density]
#             group3 = lb_df[lb_df.dataset == 'UC Davis test'][density]

#             s, p = f_oneway(group1, group2, group3)

#             results += f'     - ANOVA: F={s:.4f}, p-value={p:.4f}\n'

#             # If the p-value is less then 0.05 then follow with post-hoc Tukey's
#             # test to identify groups that are significant from each other.
#             if p < 0.05:
#                 m_comp = pairwise_tukeyhsd(endog=lb_df[density], groups=lb_df.dataset, alpha=0.05)
                
#                 # Convert to a dataframe.
#                 m_comp = m_comp.summary().data
#                 m_comp = DataFrame(m_comp[1:], columns=m_comp[0])
                
#                 # Add pairs that where different.
#                 for _, r in m_comp[m_comp.reject].iterrows():
#                     results += f"     - {r.group1} & {r.group2} Pair Tukeys p-value={r['p-adj']:.4f}\n"

# #                 print(density, lb)
# #                 print(m_comp)
# #                 print()
# #                 for pair in combinations(cohorts, 2):
# #                     cohort1, cohort2 = pair

# #                     endog = lb_df[lb_df.dataset.isin((cohort1, cohort2))].reset_index(drop=True)
# #                     m_comp = pairwise_tukeyhsd(endog=endog[density].to_numpy(), groups=endog['dataset'].to_numpy(), alpha=0.05)

        
#         results += '\n'
        
#     results += '\n\n'
#         # Calculate a t-test between groups.
# #         for cohort_pair in combinations(cohorts, 2):
# #             cohort1, cohort2 = cohort_pair
            
# #             s, p = ttest_ind(
# #                 plot_df[
# #                     (plot_df.dataset == cohort1) & (plot_df.label == lb)
# #                 ][density],
# #                 plot_df[
# #                     (plot_df.dataset == cohort2) & (plot_df.label == lb)
# #                 ][density]
# #             )
            
# #             results += f'     - {cohort1} & {cohort2}, statistic = {s:.4f}, p-value = {p:.4f}\n'
                
# print(results)

# plot_df.dataset.unique()

## ML Model for Predicting ABC

In [None]:
# Recursive feature elimination (RFE) for the ABC score model.
# Keep only the cases with a usable ABC score (rows with a negative ABC are
# dropped).
data_abc = data[data.ABC >= 0]

# As for the Braak model, the classifier is fit on the 'Emory test' dataset.
train_data_abc = data_abc[data_abc.dataset == 'Emory test']
X_train_abc = train_data_abc[img_features].to_numpy()
y_train_abc = train_data_abc.ABC.tolist()

model = RandomForestClassifier(random_state=RANDOM_STATE)

# Keep the top RFE_FEATURES features - this number was chosen by manual
# testing.
rfe_abc = RFE(model, n_features_to_select=RFE_FEATURES)

rfe_abc = rfe_abc.fit(X_train_abc, y_train_abc)

# Names of the features RFE kept (support_ is a boolean mask aligned with
# img_features).
selected_features_abc = np.array(img_features)[rfe_abc.support_].tolist()

# Training matrix restricted to the selected features.
X_train_selected_abc = train_data_abc[selected_features_abc].to_numpy()

In [None]:
# Hyperparameter tune Random Forest Classifier.
rfc_abc = RandomForestClassifier(random_state=RANDOM_STATE, class_weight=None)

# Randomized search over the same `params` grid as the Braak stage model:
# 100 sampled parameter sets, 3-fold cross-validation, scored by balanced
# accuracy.
# NOTE: this rebinds gs_rfc, the search object of the Braak stage model.
gs_rfc = RandomizedSearchCV(
    rfc_abc, params, scoring='balanced_accuracy', cv=3, n_jobs=20, verbose=0, 
    random_state=RANDOM_STATE, n_iter=100
)

# Search for best hyperparameters.
gs_rfc  = gs_rfc.fit(X_train_selected_abc, y_train_abc)

# Best estimator.
rfc_abc = gs_rfc.best_estimator_

In [None]:
# Plot the best features.
# Pair the importances with the features the ABC model was trained on
# (selected_features_abc). Pairing them with selected_features, the Braak
# stage selection, put the wrong feature names on the bars.
importances = rfc_abc.feature_importances_.tolist()
importances, fts = (
    list(t) for t in zip(*sorted(zip(importances, selected_features_abc)))
)

ft_importances_df = DataFrame({'Feature': fts, 'Importance': importances})
    
# only plot the top 10 features (the last 10 after the ascending sort)
n = len(importances)
importances = importances[-10:]
fts = fts[-10:]

fig, ax = plt.subplots(figsize=(4, 5))
y_pos = np.arange(len(fts))
ax.barh(y_pos, importances)
ax.set_yticks(y_pos, labels=fts)
plt.ylabel('Feature', fontweight='bold', fontsize=16)
plt.xlabel('Importance', fontweight='bold', fontsize=16)
ax.tick_params(axis='both', which='both', direction='out', length=10, width=2)
    
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_linewidth(2)
ax.spines['left'].set_linewidth(2)
plt.show()

In [None]:
# Confusion matrix of the ABC score for the Emory-Train cohort.
# Display names of the ABC scores; the position in the list is the numeric
# score used in the data.
abcs = ['0', '1', '2', '3']
abc_ids = list(range(len(abcs)))

# Predict ABC scores on the Emory train cohort (test set for the ABC model).
test_data_abc = data_abc[data_abc.dataset == 'train']
X_test_abc = test_data_abc[selected_features_abc].to_numpy()
y_test_abc = test_data_abc.ABC.tolist()

y_test_pred_abc = rfc_abc.predict(X_test_abc)

# Pass the full list of scores explicitly. Without it the matrix only has
# the scores that occur in the data, which breaks the 4 x 4 plot below, and
# the quadratic kappa weights are computed on the positions of the observed
# scores instead of on the scores themselves.
cm = confusion_matrix(y_test_abc, y_test_pred_abc, labels=abc_ids)
k = cohen_kappa_score(
    y_test_abc, y_test_pred_abc, weights='quadratic', labels=abc_ids
)

ax = plot_cm(cm, abcs, title=f"[ABC] Emory Train Cohort (k={k:.2f})")

# Highlight the diagonal (correct predictions).
for i in abc_ids:
    ax.add_patch(Rectangle((i, i), 1, 1, fill=False, edgecolor='green', lw=3, hatch='/'))
plt.show()