In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import math
import itertools
import os
import shutil
import pathlib
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import re
import json

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dipy.viz
from dipy.data import get_sphere
from dipy.viz import window, actor
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import mpl_toolkits
import matplotlib.pyplot as plt
import seaborn as sns

# Data management libraries.
import nibabel as nib
import pandas as pd

import box
from box import Box
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import skimage
import torch

import pitn

# Matplotlib defaults for every figure in this notebook: apply automatic
# (tight) layout and use an opaque white figure facecolor.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact, non-scientific printing for numpy arrays and torch tensors; both
# libraries share the same truncation/width settings.
_print_opts = dict(edgeitems=2, threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, profile="short", **_print_opts)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Command to be run in direnv's context: `/usr/bin/env` prints out all
# environment variables defined in the sub-process as KEY=VALUE lines.
# An argument list (instead of a shell string) keeps the working directory
# intact even when its path contains spaces or shell metacharacters.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
proc_out = ""
try:
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # Without a shell, a missing `direnv` executable raises instead of just
    # failing; keep this cell best-effort and leave the environment unchanged.
    warnings.warn("direnv not found on PATH; environment variables not updated.")
else:
    if proc.returncode != 0:
        # Previously a failed `direnv exec` went unnoticed; surface it so a
        # later missing-variable error is easier to trace back here.
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            "environment variables may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():
    # GPU information
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
        print(gpu_info)
    except NameError:
        print("CUDA Version: ", torch.version.cuda)
else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` is the CapturedIO object created by the `%%capture` cell magic in the
# watermark cell; print the stdout text it collected.
print(cap.stdout)

## File and Folder Selection

In [None]:
def _env_dir(var_name):
    """Return the path stored in environment variable ``var_name``.

    Raises KeyError if the variable is unset, and AssertionError naming the
    variable and its path if that path does not exist.
    """
    dir_path = pathlib.Path(os.environ[var_name])
    assert dir_path.exists(), f"${var_name} points to a missing path: {dir_path}"
    return dir_path


# Set up directories; each location is configured through an environment
# variable so no local paths are hard-coded in the notebook.
data_dir = _env_dir("DATA_DIR")
write_data_dir = _env_dir("WRITE_DATA_DIR")
results_dir = _env_dir("RESULTS_DIR")
tmp_results_dir = _env_dir("TMP_RESULTS_DIR")

In [None]:
# Training runs to compare; each name is a sub-directory of `results_dir`.
# Comment/uncomment entries to change which runs are included.
selected_run_names = [
    "2022-03-01T06_58_21__pitn_dti_mid_net",
    # "2022-03-01T13_19_51__pitn_log_euclid_mid_net",
    # "2022-03-01T00_33_44__tanno_espcn_baseline",
    # "2022-03-04T20_22_21__debug_test_blumberg_revnet_comp",
    "2022-03-06T23_57_13__blumberg_revnet_rn4_no_bn_changing_samples",
    "2022-03-06T15_39_24__blumberg_revnet_rn4_no_bn_50_epoch",
    "2022-03-05T17_31_29__blumberg_revnet_rn4_100_epoch",
    "2022-03-06T01_08_02__blumberg_revnet_rn4_no_bn",
]
# Display name (used as the plot label) for each run's network.
run_model_names = {
    "2022-03-01T06_58_21__pitn_dti_mid_net": "DIQT\nDTI",
    # "2022-03-01T13_19_51__pitn_log_euclid_mid_net": "DIQT\nLE",
    # "2022-03-01T00_33_44__tanno_espcn_baseline": "ESPCN\nBase",
    # "2022-03-04T20_22_21__debug_test_blumberg_revnet_comp": "RevNet-RN4",
    "2022-03-06T23_57_13__blumberg_revnet_rn4_no_bn_changing_samples": "RevNet-RN4-No-BN\nChange-Samples",
    "2022-03-06T15_39_24__blumberg_revnet_rn4_no_bn_50_epoch": "RevNet-RN4-No-BN",
    "2022-03-06T01_08_02__blumberg_revnet_rn4_no_bn": "RevNet-RN4-100-No-BN",
    "2022-03-05T17_31_29__blumberg_revnet_rn4_100_epoch": "RevNet-RN4-100",
}
# Every selected run needs a display name, otherwise its rows would keep a
# generic model label and be pooled with other runs in the plots.
unnamed_runs = [r for r in selected_run_names if r not in run_model_names]
assert not unnamed_runs, f"Runs without a display name: {unnamed_runs}"

selected_dirs = [results_dir / d for d in selected_run_names]
# `Path.exists` must be *called*: the bound method object is always truthy, so
# `all([d.exists ...])` could never fail.
missing_dirs = [d for d in selected_dirs if not d.exists()]
assert not missing_dirs, f"Missing run directories: {missing_dirs}"

In [None]:
# Load relevant result data: each run's test metrics table, with every row
# tagged by the run it came from so rows from several runs can be pooled, plus
# the location of that run's predicted DTI volumes.
runs = dict()
for name, d in zip(selected_run_names, selected_dirs):
    run_info = Box()
    # `assign` adds the tag column directly instead of building a second frame
    # and relying on index alignment inside `pd.concat(..., axis=1)`.
    run_info.test_perf = pd.read_csv(d / "test_loss.csv").assign(run_name=name)
    run_info.preds_dir = d / "predicted_dti"
    runs[name] = run_info

## Test Performance Distributions

In [None]:
# Metric name -> arrow showing which direction is better
# ("↓": lower is better, "↑": higher is better).
perf_comparison_directions = {
    "rmse": "↓",
    "nrmse": "↓",
    "log_euclid_rmse": "↓",
    "log_euclid_nrmse": "↓",
    "scaled_psnr": "↑",
    "ssim_fa": "↑",
    "rmse_fa": "↓",
    "nrmse_fa": "↓",
}

# Metrics to plot, in the mapping's insertion order.
perf_metrics = tuple(perf_comparison_directions.keys())

In [None]:
# Stack every run's test metrics into one table with a fresh 0..N-1 index.
test_results = pd.concat(
    (info.test_perf for info in runs.values()), ignore_index=True
)

In [None]:
# Each run reports its network under the generic model label "diqt"; relabel
# those rows with the run's display name so runs are distinguishable in plots.
# Rows of other models (e.g. the cubic-spline baseline) are left untouched.
for run, model_name in run_model_names.items():
    is_run_diqt = (test_results.run_name == run) & (test_results.model == "diqt")
    test_results.loc[is_run_diqt, "model"] = model_name

# Human-readable baseline label; the line break keeps x tick labels narrow.
test_results["model"] = test_results["model"].replace("cubic_spline", "Cubic\nSpline")

In [None]:
# Mean of every numeric column for each (model, metric) pair.
# `numeric_only=True` skips string columns such as `run_name`: pandas >= 2.0
# raises a TypeError when asked to average them.
test_results.groupby(["model", "metric"]).mean(numeric_only=True)

In [None]:
def _plot_metric_panel(ax, metric_df, title, palette):
    """Draw one metric's per-model distribution of test scores onto ``ax``.

    Each model gets a violin (width scaled by sample count) overlaid with one
    black point per row, plus a colored horizontal line at the model's mean.
    The mean's value is also written just left of the plotted categories.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    metric_df : pandas.DataFrame
        Rows of the results table for a single metric; must contain "model"
        and "value" columns.
    title : str
        Panel title.
    palette : sequence of colors
        Model colors, assigned positionally in x-axis order.
    """
    # Pin the x-axis order explicitly (order of first appearance, which is also
    # seaborn's default for string categories) so violins, points, mean lines
    # and colors are guaranteed to line up. Reading the order back from tick
    # labels is fragile because they may be empty before the figure is drawn.
    model_order = list(metric_df["model"].unique())

    sns.violinplot(
        x="model",
        y="value",
        data=metric_df,
        order=model_order,
        ax=ax,
        scale="count",
        inner=None,
        palette=palette,
    )
    ax.grid(axis="y", alpha=0.5)
    sns.swarmplot(
        x="model",
        y="value",
        data=metric_df,
        order=model_order,
        ax=ax,
        color="black",
        edgecolor="white",
        size=2.0,
        linewidth=0.4,
    )

    # Average only the "value" column so string columns (run_name, metric)
    # never reach `mean`, which raises on them in pandas >= 2.0.
    means = metric_df.groupby("model")["value"].mean().reindex(model_order)

    # One short mean line per category, centered on its tick with a small gap
    # on either side.
    centers = np.arange(len(means))
    mean_lines = ax.hlines(
        y=means.to_numpy(),
        xmin=centers - 0.45,
        xmax=centers + 0.45,
        colors=palette,
        lw=1.5,
    )
    # White halo so the mean lines stand out on top of the points.
    mean_lines.set_path_effects(
        [
            mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
            mpl.patheffects.Normal(),
        ]
    )

    ax.tick_params(axis="x", labelrotation=25)

    # Mean labels sit 40% further out than the lower x-limit (data coords),
    # i.e. to the left of the first category.
    left_xlim = ax.get_xlim()[0]
    label_x = left_xlim + left_xlim * 0.4
    for mean_value, color in zip(means, palette):
        ax.annotate(
            f"{mean_value:.4g}",
            xy=(label_x, mean_value),
            xycoords="data",
            color=color,
            ha="right",
            va="center",
            annotation_clip=False,
            fontweight="bold",
            snap=True,
            bbox=dict(
                boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
            ),
        )
    ax.set_title(title)
    ax.set_ylabel("")
    ax.set_xlabel("")


# One panel per test metric comparing the distribution of per-subject scores
# across models; the title arrow shows which direction is better.
with mpl.rc_context({"font.size": 6.0}):
    fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(12, 10), dpi=130)
    axs = axs.flatten()
    sns.despine(fig=fig, top=True, right=True)

    # One tab10 color per model; seaborn cycles the palette past 10 entries.
    model_colors = sns.color_palette(
        "tab10", n_colors=len(test_results.model.unique())
    )

    assert len(perf_metrics) <= len(axs), "Not enough subplot axes for all metrics."
    for ax, metric in zip(axs, perf_metrics):
        _plot_metric_panel(
            ax,
            test_results.loc[test_results.metric == metric],
            title=f"{metric.replace('_', ' ')} {perf_comparison_directions[metric]}",
            palette=model_colors,
        )

    # Blank out any leftover axes when there are fewer metrics than subplots.
    for unused_ax in axs[len(perf_metrics) :]:
        sns.despine(
            ax=unused_ax, left=True, bottom=True, top=True, right=True, trim=True
        )
        unused_ax.set_yticks([])
        unused_ax.set_xticks([])
        unused_ax.set_yticklabels([])
        unused_ax.set_xticklabels([])
fig.savefig("test_results_violin_with_blumberg.pdf")

## Tensor Ellipsoids

In [None]:
# Load one subject's predicted DTI volume from a chosen run, then derive its
# eigen-decomposition and an FA map (all moved to numpy for dipy).
select_subj_id = "196952"
select_run = "2022-03-01T06_58_21__pitn_dti_mid_net"
run_info = runs[select_run]
pred_dti_file = run_info.preds_dir / f"{select_subj_id}_predicted_dti.nii.gz"
pred_dti = nib.load(pred_dti_file)

# Tensor components are stacked along dim 0 (lower-triangle vector form);
# presumably (6, X, Y, Z) -- TODO confirm against the prediction writer.
pred_dti_vol = torch.as_tensor(pred_dti.get_fdata())
dti_vol_mat = pitn.eig.tril_vec2sym_mat(pred_dti_vol, tril_dim=0)
eigvals, eigvecs = pitn.eig.eigh_workaround(dti_vol_mat)

# True wherever every tensor component is zero, with two leading singleton
# dims prepended.
# NOTE(review): this marks the all-zero voxels; confirm fast_fa expects that
# rather than a foreground mask (i.e. `~mask`).
mask = (pred_dti_vol == 0).all(0)[None, None]
fa_tensor = pitn.metrics.fast_fa(pred_dti_vol[None], mask)[0, 0]
pred_fa = np.clip(fa_tensor.detach().cpu().numpy(), 0, 1)

pred_dti_vol, eigvals, eigvecs = [
    t.detach().cpu().numpy() for t in (pred_dti_vol, eigvals, eigvecs)
]

In [None]:
# Render tensor glyphs (ellipsoids), colored by FA-weighted direction, for a
# small region of a single slice of the predicted DTI volume.
print("Computing tensor ellipsoids in a part of the splenium of the CC")
sphere = get_sphere("repulsion724")
# Enables/disables interactive visualization
interactive = False

# Region of interest as voxel ranges along the three spatial axes (one slice
# thick along the third axis).
# NOTE(review): the message above says this is part of the splenium of the
# corpus callosum -- confirm these indices for the selected subject.
roi = np.s_[13:43, 44:74, 28:29]

scene = window.Scene()
# NOTE(review): dipy's color_fa takes the principal direction from the first
# eigenvector column; confirm pitn.eig.eigh_workaround orders eigenpairs by
# descending eigenvalue.
RGB = dipy.reconst.dti.color_fa(pred_fa, eigvecs)
evals = eigvals[roi]
evecs = eigvecs[roi]
# Normalize colors over the ROI into a *new* array: the old in-place `/=` also
# modified `RGB` through the slice view, and produced 0/0 -> NaN colors when
# the ROI was entirely black.
cfa = RGB[roi]
cfa_max = cfa.max()
cfa = cfa / cfa_max if cfa_max > 0 else cfa.copy()

scene.add(
    actor.tensor_slicer(evals, evecs, scalar_colors=cfa, sphere=sphere, scale=0.3)
)

print("Saving illustration as tensor_ellipsoids.png")
window.record(scene, n_frames=1, out_path="tensor_ellipsoids.png", size=(600, 600))
if interactive:
    window.show(scene)