# PITN Model Results Analysis

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile
import ast

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import mpl_toolkits
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt
import seaborn as sns

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
from pprint import pprint as ppr
from box import Box
import yaml

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai
import einops
import torchinfo

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Matplotlib defaults: automatic layout and an opaque white figure background.
plt.rcParams.update(
    {"figure.autolayout": True, "figure.facecolor": [1.0, 1.0, 1.0, 1.0]}
)

# Compact, non-scientific printing for both numpy arrays and torch tensors.
_shared_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_shared_print_opts)
torch.set_printoptions(sci_mode=False, **_shared_print_opts)

In [None]:
# Update the notebook's environment variables with direnv.
# Requires the python-dotenv package and a system-wide `direnv`; does not work
# on Windows.
# NOTE: every variable direnv exports for the current directory is loaded into
# os.environ, overriding existing values -- only use with a trusted .envrc.
# Libraries needed on the python side: os, subprocess, io, warnings, dotenv.

# Run `env` inside direnv's context for the cwd; its output lists every
# environment variable visible to that sub-process as KEY=VALUE lines.
# An argument list (no shell) keeps a cwd containing spaces or shell
# metacharacters intact.
try:
    proc = subprocess.run(
        ["direnv", "exec", os.getcwd(), "/usr/bin/env"],
        stdout=subprocess.PIPE,
        cwd=os.getcwd(),
    )
except FileNotFoundError:
    # Best effort, as before: without direnv the environment is left unchanged.
    warnings.warn("`direnv` was not found on PATH; environment not updated.")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"`direnv exec` exited with status {proc.returncode}; "
            "the loaded environment may be incomplete."
        )
    proc_out = proc.stdout.strip().decode("utf-8")

# Use python-dotenv to load the variables, treating the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True)

In [None]:
# Torch device selection. CUDA is preferred (eigendecompositions of these large
# volumes are slow on the CPU); fall back to the CPU otherwise.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")

# Let cuDNN benchmark and pick the fastest convolution algorithms.
if _cuda_available and torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
    print("CuDNN convolution optimization enabled.")

# To force the CPU instead, uncomment:
# device = torch.device('cpu')
print(device)

In [None]:
# Set up directories: each path is read from an environment variable and must
# already exist.
def _existing_dir_from_env(env_var: str) -> pathlib.Path:
    """Return ``os.environ[env_var]`` as a Path, asserting that it exists."""
    dir_path = pathlib.Path(os.environ[env_var])
    assert dir_path.exists()
    return dir_path


data_dir = _existing_dir_from_env("DATA_DIR")
write_data_dir = _existing_dir_from_env("WRITE_DATA_DIR")
results_dir = _existing_dir_from_env("RESULTS_DIR")
tmp_results_dir = _existing_dir_from_env("TMP_RESULTS_DIR")

In [None]:
# Parameter setup.
params = Box(default_box=True)

params.n_channels = 6  # 6 channels for the 6 DTI components
params.n_subjs = 48
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25

# Data params: per-resolution data directories and filename glob patterns.
_scale_dir_name = "scale-{:.2f}mm".format
params.data.fr_dir = data_dir / _scale_dir_name(params.fr_vox_size)
params.data.lr_dir = data_dir / _scale_dir_name(params.lr_vox_size)
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"

# The data were downsampled artificially by this factor; keep it as an int when
# it is a whole number.
_ds_factor = params.lr_vox_size / params.fr_vox_size
params.data.downsampled_by_factor = (
    int(_ds_factor) if int(_ds_factor) == _ds_factor else _ds_factor
)

# Upper bound used when clipping DTI eigenvalues during data loading.
params.data.eigval_clip_cutoff = 0.00332008

# Needed for downsampling shape correction.
params.train.in_patch_size = (24, 24, 24)

In [None]:
# The manuscript template is 122 mm wide; matplotlib figure sizes are in inches.
_MM_PER_INCH = 25.4
FIG_WIDTH_INCHES = 122 / _MM_PER_INCH

## Load & Preprocess Ground Truth & Predicted DTIs

### Subject Selection

In [None]:
# Locate the full-resolution ("fr") and low-resolution ("lr") data directory of
# every selected subject.
subj_dirs: Box = Box()

# Subject IDs used in this analysis (kept in this order as a list).
selected_ids = """
    397154 224022 140117 751348 894774 156637
    227432 303624 185947 810439 753251 644246
    141422 135528 103010 700634 406432 803240
    815247 167238 100408 792867 157437 164030
    103515 118730 198047 189450 203923 108828
    124220 386250 118124 701535 679770 382242
    231928 196952 567961 910241 175035 567759
    978578 150019 690152 297655 307127 634748
""".split()
# Hold-out subject used for visualization; it must never be in the train or
# val sets.
HOLDOUT_SUBJ_ID = "196952"
selected_subjs = natsorted(selected_ids)

for subj_id in selected_subjs:
    # Each directory is found by globbing for the subject ID in the per-resolution
    # data root.
    found = {
        res: pitn.utils.system.get_file_glob_unique(root, f"*{subj_id}*")
        for res, root in (("fr", params.data.fr_dir), ("lr", params.data.lr_dir))
    }
    subj_dirs[subj_id] = Box(found)
    assert subj_dirs[subj_id].fr.exists()
    assert subj_dirs[subj_id].lr.exists()
ppr(subj_dirs)

### Loading & Preprocessing

In [None]:
# This dict will contain all ground truth data, keyed by subject ID; it is
# populated by the data-loading cell further down.
subj_gt: Box

In [None]:
# Prep for Dataset loading.

# NIfTI reader; `as_closest_canonical=True` reorients volumes to the closest
# canonical orientation on read.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# HR -> LR patch coordinate conversion: one shared function object, registered
# under every LR volume key.
_fr2lr_patch_idx = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=params.train.in_patch_size,
)
fr2lr_patch_coords_fn = {
    lr_key: _fr2lr_patch_idx for lr_key in ("lr_dti", "lr_log_euclid", "lr_mask")
}

# HR patch shape = LR patch shape scaled by the downsampling factor (floored).
_fr_patch_shape = np.floor(
    np.asarray(params.train.in_patch_size) * params.data.downsampled_by_factor
).astype(int)
# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes.
patch_kwargs = {
    "patch_shape": tuple(_fr_patch_shape),
    "stride": 1,
    "meta_keys_to_patch_index": {"dti", "log_euclid", "mask"},
    "mask_name": "mask",
}

# Coefficients for the 6-vector (lower-triangle) form of a symmetric 3x3 matrix.
# Entries 1, 3 and 4 are scaled by sqrt(2) so that the Euclidean distance
# between vectors matches the Frobenius distance between matrices (presumably
# the off-diagonals under pitn's tril ordering -- TODO confirm against pitn.eig).
mat_norm_coeffs = torch.ones(6)
mat_norm_coeffs[[1, 3, 4]] = np.sqrt(2)
mat_norm_coeffs = mat_norm_coeffs.view(-1, 1, 1, 1)


def _far_side_heavy_bounds(amounts) -> tuple:
    """Split each per-axis amount into (near, far) with the odd voxel on the far side."""
    bounds = []
    for amount in amounts:
        amount = int(amount)
        near = amount // 2
        bounds.extend([near, amount - near])
    return tuple(bounds)


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Small utility to fix shape differences between LR and FR data.

    Crops and/or zero-pads the spatial dims of ``fr_vol`` to
    ``target_spatial_shape``. When a difference is odd, the extra voxel is taken
    from / added to the far (right/top/furthest) end of that axis. The affine is
    updated to match the near-side shift.

    Args:
        fr_vol: channel-first volume; dims 1..3 are the spatial dims.
        fr_affine: 4x4 voxel-to-world affine of ``fr_vol``.
        target_spatial_shape: desired 3-element spatial shape.

    Returns:
        ``(volume, affine)``. If the shape already matches, the inputs are
        returned unchanged (same objects).
    """
    target_shape = np.asarray(target_spatial_shape)
    if fr_vol.shape[1:] == tuple(target_shape):
        return fr_vol, fr_affine

    # Positive entries must be padded, negative entries cropped.
    diff = target_shape.astype(int) - np.asarray(fr_vol.shape[1:])
    pad_bounds = _far_side_heavy_bounds(np.clip(diff, 0, None))
    crop_bounds = _far_side_heavy_bounds(np.clip(-diff, 0, None))

    # Pad and crop with explicit per-side bounds. A "flip -> CropOrPad -> unflip"
    # trick yields the same voxels, but torchio would then shift the (unflipped)
    # affine by the far-side amount, leaving it off by one voxel on odd
    # differences.
    im = torchio.ScalarImage(tensor=fr_vol, affine=fr_affine)
    if any(pad_bounds):
        im = torchio.transforms.Pad(pad_bounds, padding_mode=0, copy=False)(im)
    if any(crop_bounds):
        im = torchio.transforms.Crop(crop_bounds, copy=False)(im)

    return im["data"], im["affine"]


def orient_to_viz(vol, affine):
    """Rotate a volume (and its affine) into the display orientation used here.

    The data are rotated by 90 degrees in the plane of axes (1, 3), then by 180
    degrees in the plane of axes (2, 3); axis 0 is never rotated (presumably the
    channel axis -- the loading code treats dim 0 as tensor components).
    ``affine`` is left-multiplied by a 4x4 rotation built from a 90 degree
    rotation about the second world axis composed with a 180 degree rotation
    about the first world axis.

    Args:
        vol: torch tensor or numpy array; spatial dims are axes 1..3.
        affine: 4x4 voxel-to-world matrix (assumed a numpy array -- TODO confirm).

    Returns:
        ``(v, new_aff)``. ``v`` is a tensor with ``vol``'s dtype/device when
        ``vol`` is a tensor, otherwise the rotated numpy view.

    NOTE(review): the affine update is a pure rotation applied on the left (world
    space). It does not include the index flip/translation implied by rotating
    the voxel grid, so ``new_aff`` may not map rotated voxel indices back to the
    same world points. Confirm before using it for anything beyond display.
    """
    if torch.is_tensor(vol):
        v = vol.detach().cpu().numpy()
    else:
        v = vol
    v = np.rot90(np.rot90(v, k=1, axes=(1, 3)), k=2, axes=(2, 3))
    if torch.is_tensor(vol):
        # np.rot90 returns a view with negative strides; copy before from_numpy.
        v = torch.from_numpy(np.copy(v)).to(vol)

    # Adjust the affine matrix.
    full_rot_aff = np.zeros_like(affine)
    full_rot_aff[-1, -1] = 1.0
    # 90 degree rot around the second axis.
    q1 = nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0])
    # 180 degree rot around the first axis.
    q2 = nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0])
    new_q = nib.quaternions.mult(q1, q2)
    rot_aff = nib.quaternions.quat2mat(new_q)
    full_rot_aff[:-1, :-1] = rot_aff
    new_aff = full_rot_aff @ affine

    return v, new_aff

In [None]:
# Load every selected subject's ground-truth data into `subj_gt[subj_id]`:
# LR DTI, FR DTI, FR mask plus a cheap LR mask, optional eigenvalue clipping,
# log-Euclidean versions of both DTIs, and an FA map of the FR DTI. Every volume
# read from disk is passed through `orient_to_viz`.
subj_gt = Box(default_box=True)

# Only these entries of each NIfTI's metadata dict are kept.
meta_keys_to_keep = {"affine", "original_affine"}

with torch.no_grad():
    # NOTE(review): `subj_dir` is unused; the body re-indexes subj_dirs[subj_id].
    for subj_id, subj_dir in subj_dirs.items():

        data = dict()
        data["subj_id"] = subj_id
        fr_subj_dir = subj_dirs[subj_id]["fr"]
        lr_subj_dir = subj_dirs[subj_id]["lr"]
        data["fr_subj_dir"] = fr_subj_dir
        data["lr_subj_dir"] = lr_subj_dir

        ####### Low-resolution DTIs/volumes
        lr_dti_f = pitn.utils.system.get_file_glob_unique(
            lr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(lr_dti_f)
        lr_dti, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        lr_dti = torch.from_numpy(lr_dti)
        lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
        data["lr_dti"] = lr_dti
        data["lr_dti_meta_dict"] = meta

        # May need to handle shape errors when re-upscaling back from LR to HR.
        # FR volumes are cropped/padded to floor(LR spatial shape * factor).
        lr_dti_shape = np.asarray(lr_dti.shape[1:])
        target_fr_shape = np.floor(
            lr_dti_shape * params.data.downsampled_by_factor
        ).astype(int)

        ####### Full-resolution images/volumes.
        # DTI.
        dti_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(dti_f)
        dti, meta = nib_reader.get_data(im)
        dti = torch.from_numpy(dti)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        # Shape correction is done in the on-disk orientation, before reorienting.
        dti, meta["affine"] = fix_downsample_shape_errors(
            dti, meta["affine"], target_fr_shape
        )
        dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
        data["dti"] = dti
        data["dti_meta_dict"] = meta

        # Diffusion mask.
        mask_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.mask_fname_pattern
        )
        im = nib_reader.read(mask_f)
        mask, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        mask = torch.from_numpy(mask)
        # Add channel dim if not available.
        if mask.ndim == 3:
            mask = mask[None]
        mask, meta["affine"] = fix_downsample_shape_errors(
            mask, meta["affine"], target_fr_shape
        )
        mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
        mask = mask.bool()
        data["mask"] = mask
        data["mask_meta_dict"] = meta

        # Construct a quick and cheap mask for the LR DTI: nearest-neighbor
        # resample of the FR mask onto the LR spatial grid.
        cheap_lr_mask = F.interpolate(
            data["mask"][None].float(),
            size=data["lr_dti"][0].shape,
            mode="nearest",
        )[0]
        data["lr_mask"] = cheap_lr_mask.bool()

        # Consider this as the "noise correction" step to have more informative, consistent
        # results with minimal biasing. Otherwise, outliers (which are clearly errors) can
        # change loss and performance metrics by orders of magnitude for no good reason!
        # Clipping runs on `device`; results are moved back and zeroed outside the mask.
        if "eigval_clip_cutoff" in params.data and params.data.eigval_clip_cutoff:
            correct_dti = pitn.data.outliers.clip_dti_eigvals(
                data["dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["dti"])
            correct_dti = correct_dti * data["mask"]
            ####
            # sae_fr = (
            #     F.l1_loss(data["dti"], correct_dti, reduction="none") * data["mask"]
            # )
            # sae_fr = sae_fr.view(6, -1).sum(1)
            # mae_fr = sae_fr / torch.count_nonzero(data["mask"])
            # print(f"---Subj {subj_id}---")
            # print(
            #     "MAE of FR DTI after eigenvalue clipping:\n",
            #     mae_fr.tolist(),
            # )
            ####
            data["dti"] = correct_dti

            correct_lr_dti = pitn.data.outliers.clip_dti_eigvals(
                data["lr_dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["lr_dti"])
            lr_mask = data["lr_mask"]
            correct_lr_dti = correct_lr_dti * lr_mask
            ####
            # sae_lr = (
            #     F.l1_loss(data["lr_dti"], correct_lr_dti, reduction="none") * lr_mask
            # )
            # sae_lr = sae_lr.view(6, -1).sum(1)
            # mae_lr = sae_lr / torch.count_nonzero(lr_mask)
            # print(
            #     "MAE of LR DTI after eigenvalue clipping:\n",
            #     mae_lr.tolist(),
            # )
            ####
            data["lr_dti"] = correct_lr_dti

        ####### Log-euclid pre-computed volumes
        # 6-vector -> symmetric matrix -> matrix log -> 6-vector, then weighted by
        # `mat_norm_coeffs`.
        # NOTE(review): when clipping is enabled, voxels outside the mask are zero
        # here, so `log_map` sees all-zero matrices there; confirm it handles
        # singular input (otherwise -inf/NaN appear outside the mask).

        # LR log-euclid volume.
        lr_log_euclid = pitn.eig.tril_vec2sym_mat(data["lr_dti"], tril_dim=0)
        lr_log_euclid = pitn.riemann.log_euclid.log_map(lr_log_euclid)
        lr_log_euclid = pitn.eig.sym_mat2tril_vec(lr_log_euclid, tril_dim=0)
        lr_log_euclid = lr_log_euclid * mat_norm_coeffs
        data["lr_log_euclid"] = lr_log_euclid

        log_euclid = pitn.eig.tril_vec2sym_mat(data["dti"], tril_dim=0)
        log_euclid = pitn.riemann.log_euclid.log_map(log_euclid)
        log_euclid = pitn.eig.sym_mat2tril_vec(log_euclid, tril_dim=0)
        log_euclid = log_euclid * mat_norm_coeffs
        data["log_euclid"] = log_euclid

        # Pre-compute FA maps (a batch dim is added for fast_fa, then removed).
        fa = pitn.metrics.fast_fa(
            data["dti"][None].to(device), foreground_mask=data["mask"][None].to(device)
        )
        fa = fa.to(data["dti"])[0]
        data["fa"] = fa

        # Finalize this subject.
        subj_gt[subj_id] = data
        print("Loaded Subject ", subj_id)
        print("=" * 20)

print("===Data Loaded===")

### Calculate Aggregate Statistics

In [None]:
# Per-channel min/max of the DTI and log-Euclidean component values, taken over
# the foreground (masked) voxels of both the FR and LR volumes.
# NOTE(review): this uses *every* loaded subject in `subj_gt`, including the
# hold-out subject -- not only training/validation subjects.
# The running min/max start at 0, so 0 (the background value of the masked
# volumes) always lies inside the reported range.

subj_agg_stats = Box(default_box=True)
_zero_range = torch.zeros(params.n_channels).to(subj_gt[selected_ids[0]].dti)
subj_agg_stats.dti.min = _zero_range
subj_agg_stats.dti.max = _zero_range.clone()
# Log-Euclidean stats get their own accumulators so they are actually computed
# from the LE volumes rather than left at their initial zeros.
subj_agg_stats.log_euclid.min = _zero_range.clone()
subj_agg_stats.log_euclid.max = _zero_range.clone()


def _fold_channel_range(stats: Box, vol: torch.Tensor, mask: torch.Tensor) -> None:
    """Update ``stats.min``/``stats.max`` with the per-channel range of ``vol[mask]``.

    ``mask`` broadcasts over the channel dim, so ``masked_select`` returns the
    same number of voxels per channel, channel by channel, which can be viewed
    as ``(n_channels, n_voxels)``.
    """
    vals = torch.masked_select(vol, mask)
    if vals.numel() == 0:
        # Empty foreground: nothing to fold in (view/min would raise).
        return
    vals = vals.view(params.n_channels, -1)
    stats.min = torch.minimum(stats.min, vals.min(-1).values)
    stats.max = torch.maximum(stats.max, vals.max(-1).values)


for s in subj_gt.values():
    _fold_channel_range(subj_agg_stats.dti, s.dti, s.mask)
    _fold_channel_range(subj_agg_stats.dti, s.lr_dti, s.lr_mask)
    _fold_channel_range(subj_agg_stats.log_euclid, s.log_euclid, s.mask)
    _fold_channel_range(subj_agg_stats.log_euclid, s.lr_log_euclid, s.lr_mask)


print(subj_agg_stats.dti.min)
print(subj_agg_stats.dti.max)

In [None]:
# Global per-channel data ranges used to scale DTI components for PSNR.

# (C,) -> rank-5 (1, C, 1, 1, 1) so it broadcasts against 5-D volumes.
expander = functools.partial(einops.rearrange, pattern="c -> 1 c 1 1 1")

# Observed DTI range across the loaded subjects.
dti_min = expander(subj_agg_stats.dti.min)
dti_max = expander(subj_agg_stats.dti.max)

# Target feature range [0, 1] for each of the 6 components (integer tensors).
feat_min, feat_max = (
    expander(bound_row) for bound_row in torch.as_tensor([[0] * 6, [1] * 6])
)

# PSNR is calculated on the final output tensor components, so no log-euclidean or
# scaling will occur here.
psnr_range_params = pitn.data.norm.GlobalScaleParams(
    feature_min=feat_min, feature_max=feat_max, data_min=dti_min, data_max=dti_max
)

## Select & Describe Models for Comparison

In [None]:
def run_params_from_config(run_config_file: Path) -> dict:
    """Parse a run's config file, which holds a single Python literal, into a dict.

    The text is parsed with ``ast.literal_eval``, so no code is executed. The
    round trip through ``Box(...).to_dict()`` converts nested mappings back into
    plain dicts.
    """
    literal_text = Path(run_config_file).resolve().read_text()
    return Box(ast.literal_eval(literal_text)).to_dict()

In [None]:
# Result-directory names (under RESULTS_DIR) of the runs being compared.
# Names end in `split_<N>`, which is parsed below as the data split index;
# `_dti_` / `_le_` marks whether the run used DTI or log-Euclidean targets.
baseline_spline_run = "2022-03-27T05_51_39__test_spline"
# No ESPCN baseline runs are included in this comparison.
baseline_espcn_runs = []
baseline_revnet_runs = [
    "2022-03-22T20_40_20__uvers_espcn_revnet_split_1",
    "2022-03-22T20_40_22__uvers_espcn_revnet_split_2",
    "2022-03-23T04_03_59__uvers_espcn_revnet_split_3",
]
# CARN, single (DTI-only) stream.
diqt_carn_single_stream_runs = [
    "2022-03-27T13_24_22__uvers_pitn_single_stream_dti_split_1",
    "2022-03-27T12_46_02__uvers_pitn_single_stream_dti_split_2",
    "2022-03-27T08_29_27__uvers_pitn_single_stream_dti_split_3",
    "2022-03-27T13_24_24__uvers_pitn_single_stream_le_split_1",
    "2022-03-27T20_05_11__uvers_pitn_single_stream_le_split_2",
    "2022-03-27T17_57_13__uvers_pitn_single_stream_le_split_3",
]
# CARN, two streams, with a "fake" anatomical stream.
diqt_carn_fake_anat_stream_runs = [
    "2022-03-25T17_15_37__uvers_pitn_fake_anat_stream_dti_split_1",
    "2022-03-25T17_15_39__uvers_pitn_fake_anat_stream_dti_split_2",
    "2022-03-26T19_11_13__uvers_pitn_fake_anat_stream_dti_split_3",
    "2022-03-26T18_46_05__uvers_pitn_fake_anat_stream_le_split_1",
    "2022-03-26T18_46_07__uvers_pitn_fake_anat_stream_le_split_2",
    "2022-03-27T02_36_43__uvers_pitn_fake_anat_stream_le_split_3",
]
# CARN, two streams, with the anatomical stream.
diqt_carn_anat_stream_runs = [
    "2022-03-22T21_33_49__uvers_pitn_anat_stream_dti_split_1",
    "2022-03-24T03_22_55__uvers_pitn_anat_stream_dti_split_2",
    "2022-03-24T22_47_26__uvers_pitn_anat_stream_dti_split_3",
    "2022-03-23T17_25_43__uvers_pitn_anat_stream_le_split_1",
    "2022-03-24T12_56_41__uvers_pitn_anat_stream_le_split_2",
    "2022-03-25T08_21_45__uvers_pitn_anat_stream_le_split_3",
]

### Load Run Data

In [None]:
# Build `run_table`: one row per run describing its model and configuration.
run_root_dir = results_dir

_RUN_TABLE_COLS = [
    "run_name",
    "model_name",
    "split",
    "use_le",
    "use_anat",
    "streams",
    "use_half_precision",
]


def _split_from_run_name(run_name: str) -> int:
    """Return the split index N encoded as ``..._split_<N>`` in a run name.

    All digits after the first ``split_`` are read, so multi-digit splits work.

    Raises:
        ValueError: if there is no ``split_`` tag followed by digits.
    """
    _, tag, tail = run_name.partition("split_")
    digits = "".join(itertools.takewhile(str.isdigit, tail))
    if not tag or not digits:
        raise ValueError(f"Cannot parse a split index from run name {run_name!r}")
    return int(digits)


def _read_run_params(run_dir: Path) -> Box:
    """Load ``run_params.yaml`` from a run directory into a Box."""
    with open(run_dir / "run_params.yaml", "r") as f:
        return Box(yaml.load(f, Loader=yaml.FullLoader))


def _variant_name(basename: str, use_le) -> str:
    """Append the target-space suffix (" LE" or " DTI") to a model basename."""
    return basename + (" LE" if use_le else " DTI")


_run_rows = [
    dict(
        run_name=baseline_spline_run,
        model_name="Cubic Spline",
        split=0,
        use_le=True,
        use_anat=False,
        streams=1,
        use_half_precision=False,
    )
]

for run in baseline_revnet_runs:
    _run_rows.append(
        dict(
            run_name=run,
            model_name="RevNet4",
            split=_split_from_run_name(run),
            use_le=False,
            use_anat=False,
            streams=1,
            use_half_precision=False,
        )
    )

# Single-stream CARN runs have no params file read here; LE usage is taken from
# the "_le_" tag in the run name.
for run in diqt_carn_single_stream_runs:
    use_le = "_le_" in run
    _run_rows.append(
        dict(
            run_name=run,
            model_name=_variant_name("CARN", use_le),
            split=_split_from_run_name(run),
            use_le=use_le,
            use_anat=False,
            streams=1,
            use_half_precision=True,
        )
    )

# Two-stream CARN runs: configuration comes from each run's run_params.yaml.
for basename, use_anat, runs in (
    ("CARN Fake Anat", False, diqt_carn_fake_anat_stream_runs),
    ("CARN Anat", True, diqt_carn_anat_stream_runs),
):
    for run in runs:
        run_p = _read_run_params(run_root_dir / run)
        _run_rows.append(
            dict(
                run_name=run,
                model_name=_variant_name(basename, run_p.use_log_euclid),
                split=_split_from_run_name(run),
                use_le=run_p.use_log_euclid,
                use_anat=use_anat,
                streams=2,
                use_half_precision=run_p.use_half_precision_float,
            )
        )

run_table = pd.DataFrame(_run_rows, columns=_RUN_TABLE_COLS)

In [None]:
# Display the per-run summary table (one row per run).
run_table

In [None]:
# Collect each run's test metrics, parameters, and prediction-archive paths,
# keyed by run name.
run_results = Box()
run_root_dir = results_dir


def _load_run_result(run_name: str, split, model_name: str) -> Box:
    """Read one run directory's test metrics, params and prediction zip path."""
    run_dir = run_root_dir / run_name
    res = Box(default_box=True)
    res.name = run_name
    res.dir = run_dir.resolve()
    res.split = split
    res.metrics = pd.read_csv(res.dir / "test_loss.csv", index_col=False)
    res.model_name = model_name
    # Relabel the generic "diqt" model name and tag every row with its split.
    res.metrics = res.metrics.replace({"model": "diqt"}, {"model": res.model_name})
    split_col = pd.DataFrame.from_dict({"split": [res.split] * len(res.metrics)})
    res.metrics = pd.concat([res.metrics, split_col], axis=1)

    params_f = run_dir / "run_params.yaml"
    if not params_f.exists():
        res.params = None
    else:
        with open(params_f, "r") as f:
            res.params = dict(yaml.load(f, Loader=yaml.FullLoader))

    res.pred_zip = (run_dir / "predicted_dti.zip").resolve()
    return res


# Cubic spline baseline: the only run given a log-Euclidean prediction archive
# here, and it has no pre-anatomical-stream outputs.
sp_run = baseline_spline_run
print(sp_run)
sp_run_dir = run_root_dir / sp_run
sp_run_res = _load_run_result(sp_run, 0, "Cubic Spline")
sp_run_res.pred_le_zip = (sp_run_dir / "predicted_le.zip").resolve()
sp_run_res.pred_pre_anat_zip = None
run_results[sp_run] = sp_run_res

for run in itertools.chain(
    baseline_espcn_runs,
    baseline_revnet_runs,
    diqt_carn_single_stream_runs,
    diqt_carn_fake_anat_stream_runs,
    diqt_carn_anat_stream_runs,
):
    print(run)
    run_row = run_table.loc[run_table.run_name == run]
    run_res = _load_run_result(
        run, run_row.split.iloc[0], run_row.model_name.iloc[0]
    )
    pre_anat_zip = run_root_dir / run / "predicted_pre_anat_dti.zip"
    run_res.pred_pre_anat_zip = pre_anat_zip if pre_anat_zip.exists() else None
    run_results[run] = run_res

### Unpack Model Predictions

In [None]:
# Extract each run's prediction archives and index the predicted volumes by
# subject ID. Safe to re-run (idempotent): archive members already on disk are
# not extracted again.


def _extract_missing_members(zip_path: Path, dest_dir: Path) -> Path:
    """Extract the members of ``zip_path`` not yet present in ``dest_dir``; return ``dest_dir``."""
    dest_dir.mkdir(exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.namelist():
            if not (dest_dir / member).exists():
                z.extract(member, dest_dir)
    return dest_dir


def _index_vols_by_subj(pred_dir: Path, subj_ids) -> dict:
    """Map each subject ID to the file found by globbing ``<subj_id>*`` in ``pred_dir``."""
    return {
        s: pitn.utils.system.get_file_glob_unique(pred_dir, s + "*") for s in subj_ids
    }


for run in run_results.values():
    run.pred_dir = _extract_missing_members(
        run.pred_zip, run.pred_zip.parent / "predicted_dti"
    )
    # Subject IDs are the filename prefixes before "_pr".
    pred_subjs = [f.name[: f.name.find("_pr")] for f in run.pred_dir.glob("*")]
    run.pred_vols = _index_vols_by_subj(run.pred_dir, pred_subjs)

    # Pre-anatomical-stream predictions exist only for some runs. Both attributes
    # are set here, in each branch, so the log-Euclidean section below cannot
    # overwrite them with its own directory.
    if run.pred_pre_anat_zip is not None:
        run.pred_pre_anat_dir = _extract_missing_members(
            run.pred_pre_anat_zip,
            run.pred_pre_anat_zip.parent / "predicted_pre_anat_dti",
        )
        run.pred_pre_anat_vols = _index_vols_by_subj(run.pred_pre_anat_dir, pred_subjs)
    else:
        run.pred_pre_anat_dir = None
        run.pred_pre_anat_vols = None

    # Log-Euclidean predictions (only runs with a `pred_le_zip` entry).
    if "pred_le_zip" in run:
        run.pred_le_dir = _extract_missing_members(
            run.pred_le_zip, run.pred_le_zip.parent / "predicted_le"
        )
        run.pred_le_vols = _index_vols_by_subj(run.pred_le_dir, pred_subjs)

    print("Finished ", run.name)

In [None]:
# Select metrics to show: metric key -> arrow marking whether lower (↓) or
# higher (↑) values are better. Only these keys are kept in the results table.
perf_metrics_with_directions = {
    "rmse": "↓",
    # "nrmse": "↓",
    "rmse_log_euclid": "↓",
    # "nrmse_log_euclid": "↓",
    "scaled_psnr": "↑",
    "ssim_fa": "↑",
    # "rmse_fa": "↓",
    # "nrmse_fa": "↓",
}


# Metric key -> short display label (RMSE is reported as "RMFD").
metric_rename_map = {
    "rmse": "DTI RMFD ↓",
    "rmse_log_euclid": "LE RMFD ↓",
    "scaled_psnr": "PSNR ↑",
    "ssim_fa": "FA SSIM ↑",
}

# Internal model name (from `run_table.model_name`) -> display name.
model_rename_map = {
    "CARN Anat LE": "CARN LE (Proposed)",
    "CARN Anat DTI": "CARN DTI",
    "CARN LE": "CARN LE No Anat Net",
    "CARN DTI": "CARN DTI No Anat Net",
    "CARN Fake Anat LE": "CARN LE Fake Anat",
    "CARN Fake Anat DTI": "CARN DTI Fake Anat",
}

## Quantify Scores

In [None]:
perf_metrics = tuple(perf_metrics_with_directions)

# Stack every run's per-subject metrics into one long-format table, then keep
# only the metrics selected for reporting.
test_results = pd.concat(
    [run_res.metrics for run_res in run_results.values()], ignore_index=True
)
test_results = test_results[test_results["metric"].isin(perf_metrics)]
test_results

In [None]:
# Sorted, de-duplicated model names present in the merged results.
np.unique(test_results["model"].to_numpy())

In [None]:
# Summarize each model's test scores as "mean ± std" over all subjects and
# splits, then export the table as CSV and LaTeX.

# Aggregate only the score column. Reducing the whole frame also averages
# unrelated columns (subj_id, split, ...), and with pandas >= 2.0 it raises on
# any non-numeric column.
results_table = (
    test_results.groupby(["model", "metric"])["value"]
    .agg(["mean", "std"])
    .reset_index()
)

# Report DTI RMSE x 100 for readability; the mean and std are scaled together.
rmse_mask = results_table.metric == "rmse"
results_table.loc[rmse_mask, ["mean", "std"]] *= 100

float_str_format = "{:.3f}"
results_table["display"] = results_table.apply(
    lambda row: "$"
    + float_str_format.format(np.round(row["mean"], decimals=3))
    + r" \pm "
    + float_str_format.format(np.round(row["std"], decimals=3))
    + "$",
    axis="columns",
)

# One row per model, one column per metric.
results_table = results_table.pivot(index="model", columns="metric", values="display")
# Metric keys -> short labels -> LaTeX headers; the DTI header states the x100
# scaling applied above.
results_table = results_table.rename(columns=metric_rename_map)
results_table = results_table.rename(index=model_rename_map)
results_table = results_table.rename(
    columns={
        "DTI RMFD ↓": r"$\textrm{DTI RMFD} \times 10^{-2} \downarrow$",
        "LE RMFD ↓": r"$\textrm{LE RMFD} \downarrow$",
        "PSNR ↑": r"$\textrm{PSNR} \uparrow$",
        "FA SSIM ↑": r"$\textrm{FA SSIM} \uparrow$",
    }
)

# Reorder row/model order
results_table = results_table.loc[
    [
        "Cubic Spline",
        "RevNet4",
        "CARN LE (Proposed)",
        "CARN DTI",
        "CARN LE No Anat Net",
        "CARN DTI No Anat Net",
        "CARN LE Fake Anat",
        "CARN DTI Fake Anat",
    ]
]


# Save out table of results.
results_table.to_csv("test_results.csv")
latex_tab = results_table.to_latex(escape=False, column_format="l|l|l|l|l")
with open("test_results_table.tex", "w") as f:
    f.write(latex_tab)

print(latex_tab)
results_table

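Manually formatted LaTeX version of the results table above (with row shading, bolded best scores, and line-wrapped model names):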

```
\centering
\begin{tabular}{|l|l|l|l|l|} 
\hline
\diagbox{model}{metric}                                                                                                             & $\textrm{DTI RMFD} \times 10^{-2} \downarrow$ & $\textrm{LE RMFD} \downarrow$              & $\textrm{PSNR} \uparrow$                    & $\textrm{FA SSIM} \uparrow$                 \\ 
\hline
\rowcolor[rgb]{0.933,0.933,0.929} Cubic Spline                                                                                      & $0.050 \pm 0.003$                             & $1.055 \pm 0.030$                          & $43.746 \pm 0.107$                          & $0.881 \pm 0.010$                           \\
RevNet4                                                                                                                             & $0.038 \pm 0.001$                             & $0.825 \pm 0.055$                          & $44.609 \pm 0.121$                          & $0.883 \pm 0.011$                           \\
\rowcolor[rgb]{0.933,0.933,0.929} \begin{tabular}[c]{@{}>{\cellcolor[rgb]{0.933,0.933,0.929}}l@{}}CARN LE\\(Proposed)\end{tabular}  & $0.036 \pm 0.003$                             & \mathversion{tabularbold}$0.684 \pm 0.047$ & $44.595 \pm 0.182$                          & \mathversion{tabularbold}$0.919 \pm 0.008$  \\
CARN DTI                                                                                                                            & \mathversion{tabularbold}$0.033 \pm 0.003$    & $0.764 \pm 0.060$                          & \mathversion{tabularbold}$44.778 \pm 0.190$ & $0.916 \pm 0.009$                           \\
\rowcolor[rgb]{0.933,0.933,0.929} \begin{tabular}[c]{@{}>{\cellcolor[rgb]{0.933,0.933,0.929}}l@{}}CARN LE\\No Anat Net\end{tabular} & $0.042 \pm 0.002$                             & $0.724 \pm 0.041$                          & $44.450 \pm 0.129$                          & $0.907 \pm 0.008$                           \\
\begin{tabular}[c]{@{}l@{}}CARN DTI\\No Anat Net\end{tabular}                                                                       & $0.038 \pm 0.001$                             & $0.820 \pm 0.055$                          & $44.578 \pm 0.125$                          & $0.899 \pm 0.009$                           \\
\rowcolor[rgb]{0.933,0.933,0.929} \begin{tabular}[c]{@{}>{\cellcolor[rgb]{0.933,0.933,0.929}}l@{}}CARN LE\\Fake Anat\end{tabular}   & $0.041 \pm 0.002$                             & $0.722 \pm 0.042$                          & $44.527 \pm 0.123$                          & $0.909 \pm 0.008$                           \\
\begin{tabular}[c]{@{}l@{}}CARN DTI\\Fake Anat\end{tabular}                                                                         & $0.037 \pm 0.001$                             & $0.820 \pm 0.056$                          & $44.611 \pm 0.127$                          & $0.899 \pm 0.009$                           \\
\hline
\end{tabular}
```

## Results Viz and Display

### DTI Eigenvalue Validity Check

In [None]:
# Negative-eigenvalue check for one subject: for every model, pool the
# eigenvalues of the predicted diffusion tensors inside the subject mask over
# all of that model's runs (splits), report how many are negative, and plot
# the distribution of the negative ones per eigenvalue index. A tensor with a
# negative eigenvalue is not positive definite.
select_sub_id = HOLDOUT_SUBJ_ID
m_names = list(run_table.model_name.unique())
# Histogram color per eigenvalue index; `None` keeps seaborn's default color.
eig_colors = (None, "red", "green")

for m in m_names:
    print(m)
    model_hist = list()
    rs = list(run_table.loc[run_table.model_name == m].run_name.unique())
    for r in rs:
        print(r)
        run = run_results[r]
        # A run may not have predicted this subject; skip it instead of
        # failing with a KeyError.
        if select_sub_id not in run.pred_vols.keys():
            print(f"Subject {select_sub_id} not in run {r}, skipping")
            continue
        pred_dti = torch.from_numpy(nib.load(run.pred_vols[select_sub_id]).get_fdata())
        mask = subj_gt[select_sub_id]["mask"]

        # Center-crop the GT mask when its spatial shape differs from the
        # prediction's.
        if mask.shape[1:] != pred_dti.shape[1:]:
            t = monai.transforms.CenterSpatialCrop(pred_dti.shape[1:])
            mask = t(mask)
        dti_mat = pitn.eig.tril_vec2sym_mat(pred_dti, tril_dim=0)
        dti_mat = dti_mat[mask[0]]
        eigvals = pitn.eig.eigvalsh_workaround(dti_mat)
        # Transposed so that row k holds eigenvalue k of every masked voxel
        # (assumes eigvals is (n_voxels, 3) -- TODO confirm).
        model_hist.append(eigvals.T.detach().cpu().numpy())
    if not model_hist:
        # Nothing to concatenate or plot for this model.
        print(f"No predictions for subject {select_sub_id}, skipping {m}")
        continue
    model_hist = np.concatenate(model_hist, axis=1)
    neg_mask = model_hist < 0
    print(f"Eigvals < 0: {neg_mask.sum(axis=1)} out of {model_hist.shape[1]}")
    fig, axs = plt.subplots(ncols=3, sharex=True, sharey=True, figsize=(12, 2), dpi=100)
    # Lambda 1, 2, then 3; a panel stays empty when that index has no
    # negative values.
    for i_eig, color in enumerate(eig_colors):
        if neg_mask[i_eig].any():
            sns.histplot(
                model_hist[i_eig][neg_mask[i_eig]],
                log_scale=False,
                ax=axs[i_eig],
                color=color,
            )
    plt.show()

In [None]:
# Negative-eigenvalue distribution for all subjects, pooled over every run
# (split) of each model. A subject predicted by several runs is counted only
# once, using the first run that predicted it.
m_names = list(run_table.model_name.unique())
# Keys of `sub_model_hist`, one per eigenvalue index (Lambda 1, 2, then 3).
eig_keys = ("l1", "l2", "l3")
# Histogram color per eigenvalue index; `None` keeps seaborn's default color.
eig_colors = (None, "red", "green")

for m in m_names:
    print(m)
    sub_model_hist = {k: list() for k in eig_keys}
    subs_seen = set()
    rs = list(run_table.loc[run_table.model_name == m].run_name.unique())
    for r in rs:
        print(r)
        run = run_results[r]
        for sub_id in run.pred_vols.keys():
            if sub_id in subs_seen:
                continue
            pred_dti = torch.from_numpy(nib.load(run.pred_vols[sub_id]).get_fdata())
            mask = subj_gt[sub_id]["mask"]

            # Center-crop the GT mask when its spatial shape differs from the
            # prediction's.
            if mask.shape[1:] != pred_dti.shape[1:]:
                t = monai.transforms.CenterSpatialCrop(pred_dti.shape[1:])
                mask = t(mask)
            dti_mat = pitn.eig.tril_vec2sym_mat(pred_dti, tril_dim=0)
            dti_mat = dti_mat[mask[0]]
            eigvals = pitn.eig.eigvalsh_workaround(dti_mat)
            # Row k now holds eigenvalue k of every masked voxel.
            eigvals = eigvals.T.detach().cpu().numpy()
            for i_eig, k in enumerate(eig_keys):
                lam = eigvals[i_eig]
                sub_model_hist[k].append(lam[lam < 0])

            subs_seen.add(sub_id)

    # Guard against a model without any predictions (np.concatenate([]) raises).
    for k in eig_keys:
        parts = sub_model_hist[k]
        sub_model_hist[k] = np.concatenate(parts, axis=0) if parts else np.empty(0)

    print(
        f"Eigvals < 0: {len(sub_model_hist['l1'])}, {len(sub_model_hist['l2'])}, {len(sub_model_hist['l3'])}"
    )
    print("Min eigvals ")
    for k in eig_keys:
        print(sub_model_hist[k].min() if len(sub_model_hist[k]) > 0 else "NA")

    with mpl.rc_context({"font.size": 6.0}):
        fig, axs = plt.subplots(
            ncols=3, sharex=True, sharey=True, figsize=(12, 2), dpi=100
        )
        # Lambda 1, 2, then 3; skip empty panels.
        for ax, k, color in zip(axs, eig_keys, eig_colors):
            if len(sub_model_hist[k]) > 0:
                sns.histplot(sub_model_hist[k], ax=ax, color=color)

        plt.show()

### FA Prediction Bias

In [None]:
# FA prediction bias: histograms of per-voxel (predicted FA - ground-truth FA)
# inside the subject mask, one row per model and one column per selected
# subject. Each subject is taken from the first run (split) of a model that
# predicted it. Dashed red line: mean difference; solid green line: median.
import scipy.stats

select_sub_id = HOLDOUT_SUBJ_ID
# A list rather than a set: iteration order of a set of strings changes between
# interpreter sessions (hash randomization), which would shuffle the columns.
select_sub_ids = [
    "141422",
    "135528",
    "103010",
    # HOLDOUT_SUBJ_ID,
    # "567759",
    # "815247",
]
m_names = list(run_table.model_name.unique())
with mpl.rc_context({"font.size": 6.0}):
    fig, axs = plt.subplots(
        ncols=len(select_sub_ids),
        nrows=len(m_names),
        # Rows are models, so the figure height scales with the model count.
        figsize=(9, 2.3 * len(m_names)),
        dpi=180,
        sharex="all",
        sharey="all",
        # Keep `axs` 2D even with a single model or a single subject.
        squeeze=False,
    )
    # Label each column with its subject ID.
    for j_sub, sub_id in enumerate(select_sub_ids):
        axs[0, j_sub].set_title(sub_id)

    for i_m, m in enumerate(m_names):
        print(m)
        subs_seen = set()
        rs = list(run_table.loc[run_table.model_name == m].run_name.unique())
        for r in rs:
            print(r)
            run = run_results[r]
            for j_sub, sub_id in enumerate(select_sub_ids):
                if sub_id in subs_seen or sub_id not in run.pred_vols.keys():
                    continue
                print(sub_id)
                pred_dti = torch.from_numpy(nib.load(run.pred_vols[sub_id]).get_fdata())
                mask = subj_gt[sub_id]["mask"]
                gt_fa = subj_gt[sub_id]["fa"]
                # Center-crop the GT mask and FA when their spatial shape
                # differs from the prediction's.
                if mask.shape[1:] != pred_dti.shape[1:]:
                    t = monai.transforms.CenterSpatialCrop(pred_dti.shape[1:])
                    mask = t(mask)
                    gt_fa = t(gt_fa)

                fa = pitn.metrics.fast_fa(pred_dti[None], mask[None])[0]

                diff = (fa - gt_fa)[mask].flatten()
                diff = diff.detach().cpu().numpy()
                sel_diff = diff  # [np.where(np.abs(diff) > 1e-6)]

                print(f"Diff range: {diff.min()} - {diff.max()}")
                print(f"Skew {scipy.stats.skew(sel_diff)}")
                print(
                    f"Statistically significant skew? {scipy.stats.skewtest(sel_diff)}"
                )
                ax = axs[i_m, j_sub]
                sns.histplot(sel_diff, ax=ax, kde=True, thresh=10)
                # Set after histplot, which overwrites the y-label.
                if j_sub == 0:
                    ax.set_ylabel(m)

                ax.axvline(np.mean(sel_diff), color="red", ls="--", alpha=0.8)
                ax.axvline(np.median(sel_diff), color="green", ls="-", alpha=0.8)

                subs_seen.add(sub_id)

### Metrics Breakdown

In [None]:
# Metrics shown in the figures below, in display order. These are the display
# names (values) of `metric_rename_map`, which is presumably defined in an
# earlier cell. The commented-out dicts are a reference copy of the metric
# selection, optimization direction, and renaming used for the figures.
# perf_metrics_with_directions = {
#     "rmse": "↓",
#     # "nrmse": "↓",
#     "rmse_log_euclid": "↓",
#     # "nrmse_log_euclid": "↓",
#     "scaled_psnr": "↑",
#     "ssim_fa": "↑",
#     # "rmse_fa": "↓",
#     # "nrmse_fa": "↓",
# }

# metric_rename_map = {
#     "rmse": "DTI RMFD",
#     "rmse_log_euclid": "LE RMFD",
#     "scaled_psnr": "PSNR",
#     "ssim_fa": "FA SSIM",
# }

perf_metrics = (*metric_rename_map.values(),)

In [None]:
# Stack the per-run metric tables into one long-format table, then map raw
# metric and model identifiers to display names. Names with no entry in the
# rename maps are kept unchanged.
test_results = pd.concat(
    [run_res.metrics for run_res in run_results.values()], ignore_index=True
)
test_results["metric"] = test_results["metric"].map(
    lambda name: metric_rename_map.get(name, name)
)
test_results["model"] = test_results["model"].map(
    lambda name: model_rename_map.get(name, name)
)
# test_results = test_results.map(model_rename_map)

In [None]:
# Mean of every numeric column per (model, metric, split).
# `numeric_only=True` skips string columns (presumably run names), which
# pandas >= 2.0 refuses to average. `errors="ignore"` keeps the drop from
# failing when `subj_id` is non-numeric and was already excluded.
test_results.groupby(["model", "metric", "split"]).mean(numeric_only=True).drop(
    columns="subj_id", errors="ignore"
)

In [None]:
# Sorted display names of every model present in the merged results.
np.unique(test_results["model"])

In [None]:
# Main result plot: one violin subplot per metric, comparing the baselines with
# the proposed model over all test subjects and splits.
n_splits = len(test_results.split.unique())
n_metrics = len(test_results.metric.unique())

# models_to_display = list(test_results.model.unique())
models_to_display = [
    "Cubic Spline",
    "RevNet4",
    "CARN DTI",
    "CARN LE (Proposed)",
]

# model_rename_mapping = {
#     "Cubic Spline": "Cubic Spline",
#     "RevNet4": "RevNet4",
#     "CARN Anat DTI": "CARN DTI Anat\n(Ours)",
#     "CARN Anat LE": "CARN LE Anat\n(Ours)",
# }
n_models = len(models_to_display)


def _wrap_model_name(name, split=True):
    """Return `name` with a newline inserted after its second word.

    Names of at most two words, or any name when `split` is false, are returned
    unchanged. Keeps long x-axis model labels compact.
    """
    words = name.split()
    if not split or len(words) <= 2:
        return name
    return " ".join(words[:2]) + "\n" + " ".join(words[2:])


# Fix the x-axis order, and therefore each model's palette color, to follow
# `models_to_display` instead of the row order of `test_results`.
model_display_order = [_wrap_model_name(m) for m in models_to_display]

with mpl.rc_context({"font.size": 9.0}):

    ncols = 2
    nrows = 2  # np.ceil(n_metrics / ncols).astype(int)
    figsize = (n_models * ncols, nrows * 3)
    dpi = 160
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        # sharex=True,
        figsize=figsize,
        dpi=dpi,
        # gridspec_kw={"wspace": 1.0, "hspace": 1.0},
    )
    axs = axs.flatten()
    sns.despine(fig=fig, top=True, right=True)

    all_colors = sns.color_palette(
        "tab10",
        n_colors=n_splits + n_models,
    )
    model_colors = all_colors[:n_models]
    # Per-split colors; only used by the commented-out strip plot below.
    run_colors = all_colors[n_models:]
    # run_order = list(test_results.run_name.unique())

    for i, metric_name in enumerate(perf_metrics):

        ax = axs[i]
        # Subfigure label "A", "B", ... from the metric index.
        subfig_section = chr(ord("A") + i)
        # Filter first, then copy, so the column reassignment below operates
        # on an independent frame rather than a slice of `test_results`.
        df = test_results.loc[
            (test_results.metric == metric_name)
            & np.isin(test_results.model, models_to_display)
        ].copy()
        # Include newlines into model names to clean up visually
        df.model = df.model.apply(_wrap_model_name)

        sns.violinplot(
            x="model",
            y="value",
            order=model_display_order,
            data=df,
            ax=ax,
            scale="count",
            inner="box",
            palette=model_colors,
            # lw=0.5,
        )
        ax.grid(axis="y", alpha=0.5)

        # points_plot = sns.stripplot(
        #     x="model",
        #     y="value",
        #     hue="split",
        #     # hue_order=run_order,
        #     jitter=0.17,
        #     data=df,
        #     ax=ax,
        #     palette=run_colors,
        #     # palette=plt.get_cmap('gist_rainbow'),
        #     # color="black",
        #     edgecolor="white",
        #     size=2.8,
        #     linewidth=0.7,
        # )
        # points_plot.get_legend().remove()

        # Mean score per model: average within each split, then across splits.
        # Only "value" is averaged; other non-numeric columns (e.g. the metric
        # name) make a bare `.mean()` raise on pandas >= 2.0.
        means = (
            df.groupby(["model", "split"])[["value"]].mean().groupby("model").mean()
        )
        # Match the x-axis ordering.
        means = means.reindex(model_display_order)

        #         lines = ax.hlines(
        #             y=means.value,
        #             xmin=np.arange(0, len(means)) - 0.5 + 0.05,
        #             xmax=np.arange(1, len(means) + 1) - 0.5 - 0.05,
        #             colors=model_colors,
        #             lw=1.5,
        #             zorder=1000,
        #         )

        #         outline_path_effects = [
        #             mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
        #             mpl.patheffects.Normal(),
        #         ]
        #         lines.set_path_effects(outline_path_effects)

        ax.tick_params(axis="x", labelrotation=25)

        #         for m, c in zip(means.value, model_colors):

        #             ax.annotate(
        #                 f"{m:.4g}",
        #                 xy=(ax.get_xlim()[0] + (ax.get_xlim()[0] * 0.4), m),
        #                 xycoords="data",
        #                 color=c,
        #                 ha="right",
        #                 va="center",
        #                 annotation_clip=False,
        #                 fontweight="bold",
        #                 snap=True,
        #                 bbox=dict(
        #                     boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
        #                 ),
        #             )
        ax.set_title(metric_rename_map.get(metric_name, metric_name))
        ax.set_ylabel("")
        ax.set_xlabel("")
        ax.annotate(
            subfig_section,
            (-0.3, 1.2),
            xycoords="axes fraction",
            fontweight="bold",
            fontstyle="oblique",
            fontsize="x-large",
            horizontalalignment="left",
            verticalalignment="top",
        )

    # Model names only on the bottom row of subplots.
    for ax in axs[: (nrows - 1) * ncols]:
        ax.set_xticklabels([])
# plt.suptitle("DIQT V-Fro DTI vs")

fig.savefig("main_results_violin.pdf")
fig.savefig("main_results_violin.png")

In [None]:
# Ablation result violin plot: one violin subplot per metric, comparing the
# CARN variants (no anatomical network, fake anatomical input, full model).
n_splits = len(test_results.split.unique())
n_metrics = len(test_results.metric.unique())

# models_to_display = list(test_results.model.unique())
models_to_display = [
    "CARN DTI No Anat Net",
    "CARN LE No Anat Net",
    "CARN DTI Fake Anat",
    "CARN LE Fake Anat",
    "CARN DTI",
    "CARN LE (Proposed)",
]

n_models = len(models_to_display)


def _wrap_model_name(name, split=True):
    """Return `name` with a newline inserted after its second word.

    Names of at most two words, or any name when `split` is false, are returned
    unchanged. Keeps long x-axis model labels compact.
    """
    words = name.split()
    if not split or len(words) <= 2:
        return name
    return " ".join(words[:2]) + "\n" + " ".join(words[2:])


# Fix the x-axis order, and therefore each model's palette color, to follow
# `models_to_display` instead of the row order of `test_results`.
model_display_order = [_wrap_model_name(m) for m in models_to_display]

with mpl.rc_context({"font.size": 9.0}):

    ncols = 2
    nrows = 2  # np.ceil(n_metrics / ncols).astype(int)
    figsize = (n_models * ncols, nrows * 3)
    dpi = 160
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        # sharex=True,
        figsize=figsize,
        dpi=dpi,
        # gridspec_kw={"wspace": 1.0, "hspace": 1.0},
    )
    axs = axs.flatten()
    sns.despine(fig=fig, top=True, right=True)

    all_colors = sns.color_palette(
        "tab10",
        n_colors=n_splits + n_models,
    )
    model_colors = all_colors[:n_models]
    # Per-split colors; only used by the commented-out strip plot below.
    run_colors = all_colors[n_models:]
    # run_order = list(test_results.run_name.unique())

    for i, metric_name in enumerate(perf_metrics):

        ax = axs[i]
        # Subfigure label "A", "B", ... from the metric index.
        subfig_section = chr(ord("A") + i)
        # Filter first, then copy, so the column reassignment below operates
        # on an independent frame rather than a slice of `test_results`.
        df = test_results.loc[
            (test_results.metric == metric_name)
            & np.isin(test_results.model, models_to_display)
        ].copy()

        # df.model = df.model.apply(lambda x: model_rename_map.get(x, x))
        # print(np.unique(df.model))
        # Include newlines into model names to clean up visually
        df.model = df.model.apply(_wrap_model_name)

        sns.violinplot(
            x="model",
            y="value",
            order=model_display_order,
            data=df,
            ax=ax,
            scale="count",
            inner="box",
            palette=model_colors,
            # lw=0.5,
        )
        ax.grid(axis="y", alpha=0.5)

        # points_plot = sns.stripplot(
        #     x="model",
        #     y="value",
        #     hue="split",
        #     # hue_order=run_order,
        #     jitter=0.17,
        #     data=df,
        #     ax=ax,
        #     palette=run_colors,
        #     # palette=plt.get_cmap('gist_rainbow'),
        #     # color="black",
        #     edgecolor="white",
        #     size=2.8,
        #     linewidth=0.7,
        # )
        # points_plot.get_legend().remove()

        # Mean score per model: average within each split, then across splits.
        # Only "value" is averaged; other non-numeric columns (e.g. the metric
        # name) make a bare `.mean()` raise on pandas >= 2.0.
        means = (
            df.groupby(["model", "split"])[["value"]].mean().groupby("model").mean()
        )
        # Match the x-axis ordering.
        means = means.reindex(model_display_order)

        #         lines = ax.hlines(
        #             y=means.value,
        #             xmin=np.arange(0, len(means)) - 0.5 + 0.05,
        #             xmax=np.arange(1, len(means) + 1) - 0.5 - 0.05,
        #             colors=model_colors,
        #             lw=1.5,
        #             zorder=1000,
        #         )

        #         outline_path_effects = [
        #             mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
        #             mpl.patheffects.Normal(),
        #         ]
        #         lines.set_path_effects(outline_path_effects)

        ax.tick_params(axis="x", labelrotation=25)

        #         for m, c in zip(means.value, model_colors):

        #             ax.annotate(
        #                 f"{m:.4g}",
        #                 xy=(ax.get_xlim()[0] + (ax.get_xlim()[0] * 0.4), m),
        #                 xycoords="data",
        #                 color=c,
        #                 ha="right",
        #                 va="center",
        #                 annotation_clip=False,
        #                 fontweight="bold",
        #                 snap=True,
        #                 bbox=dict(
        #                     boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
        #                 ),
        #             )

        ax.set_title(metric_rename_map.get(metric_name, metric_name))
        ax.set_ylabel("")
        ax.set_xlabel("")
        ax.annotate(
            subfig_section,
            (-0.3, 1.2),
            xycoords="axes fraction",
            fontweight="bold",
            fontstyle="oblique",
            fontsize="x-large",
            horizontalalignment="left",
            verticalalignment="top",
        )

    # Model names only on the bottom row of subplots.
    for ax in axs[: (nrows - 1) * ncols]:
        ax.set_xticklabels([])
fig.savefig("ablation_results_violin.pdf")
fig.savefig("ablation_results_violin.png")

In [None]:
# All-results violin plot, vertical: baselines, proposed model, and ablations
# side by side, one subplot per metric, sized for a 122 mm wide manuscript.
n_splits = len(test_results.split.unique())
n_metrics = len(test_results.metric.unique())

# models_to_display = list(test_results.model.unique())
models_to_display = [
    "Cubic Spline",
    "RevNet4",
    "CARN LE (Proposed)",
    "CARN DTI",
    "CARN LE No Anat Net",
    "CARN DTI No Anat Net",
    "CARN LE Fake Anat",
    "CARN DTI Fake Anat",
]

n_models = len(models_to_display)


def _wrap_model_name(name, split=True):
    """Return `name` with a newline inserted after its second word.

    Names of at most two words, or any name when `split` is false, are returned
    unchanged. Keeps long x-axis model labels compact.
    """
    words = name.split()
    if not split or len(words) <= 2:
        return name
    return " ".join(words[:2]) + "\n" + " ".join(words[2:])


# Include newlines into model names to clean up visually.
split_model_names_with_newlines = True
# x-axis order (and each model's palette color) follows `models_to_display`.
model_display_order = [
    _wrap_model_name(m, split=split_model_names_with_newlines)
    for m in models_to_display
]

with mpl.rc_context(
    {
        **sns.plotting_context("paper"),
        **{
            "font.size": 8.0,
            "axes.labelsize": 8.0,
            "axes.titlesize": 8.0,
            "xtick.labelsize": 8.0,
            "ytick.labelsize": 8.0,
            "legend.fontsize": 8.0,
            "legend.title_fontsize": 8.0,
            "axes.xmargin": 0.001,
            "axes.ymargin": 0.001,
            "ytick.major.pad": 1.0,
            "axes.labelpad": 1.0,
            "savefig.pad_inches": 0.01,
            "figure.subplot.left": 0.001,
            "figure.subplot.right": 0.999,  # the right side of the subplots of the figure
            "figure.subplot.bottom": 0.05,  # the bottom of the subplots of the figure
            "figure.subplot.top": 0.95,
            # "xtick.alignment": "right",
            # "figure.autolayout": True,
            # "figure.constrained_layout.use": True,
        },
    }
):

    ncols = 2
    nrows = 2  # np.ceil(n_metrics / ncols).astype(int)
    # Manuscript template is 122mm wide, so convert width to inches; the
    # height scales with the number of subplot rows.
    figsize = (122 / 25.4, nrows * 2.5)
    dpi = 200
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        # sharex=True,
        figsize=figsize,
        dpi=dpi,
        gridspec_kw={"wspace": 0.29},
    )
    axs = axs.flatten()
    sns.despine(fig=fig, top=True, right=True)

    all_colors = sns.color_palette(
        "tab10",
        n_colors=n_splits + n_models,
    )
    model_colors = all_colors[:n_models]
    run_colors = all_colors[n_models:]
    # run_order = list(test_results.run_name.unique())

    for i, metric_name in enumerate(perf_metrics):

        ax = axs[i]
        # Subfigure label "A", "B", ... (used by the commented-out annotation).
        subfig_section = chr(ord("A") + i)
        # Filter first, then copy, so the column reassignment below operates
        # on an independent frame rather than a slice of `test_results`.
        df = test_results.loc[
            (test_results.metric == metric_name)
            & np.isin(test_results.model, models_to_display)
        ].copy()
        df.model = df.model.apply(
            lambda name: _wrap_model_name(name, split=split_model_names_with_newlines)
        )

        sns.violinplot(
            x="model",
            y="value",
            order=model_display_order,
            data=df,
            ax=ax,
            scale="width",
            inner="box",
            palette=model_colors,
            orient="v",
            linewidth=0.9,
        )
        ax.grid(axis="y", alpha=0.5)

        # # Calculate mean performance score.
        # means = df.groupby(["model", "split"])[["value"]].mean().groupby("model").mean()
        # means = means.reindex(model_display_order)

        #         lines = ax.hlines(
        #             y=means.value,
        #             xmin=np.arange(0, len(means)) - 0.5 + 0.05,
        #             xmax=np.arange(1, len(means) + 1) - 0.5 - 0.05,
        #             colors=model_colors,
        #             lw=1.5,
        #             zorder=1000,
        #         )

        #         outline_path_effects = [
        #             mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
        #             mpl.patheffects.Normal(),
        #         ]
        #         lines.set_path_effects(outline_path_effects)

        ax.tick_params(axis="x", labelrotation=25)

        ax.set_title(metric_rename_map.get(metric_name, metric_name))
        # DTI RMFD values are small; show them in scientific notation with a
        # shared exponent on the axis.
        if "DTI RMFD" in metric_rename_map.get(metric_name, metric_name):
            ax.ticklabel_format(
                axis="y", useOffset=False, useMathText=True, scilimits=(0, 0)
            )
        ax.set_ylabel("")
        ax.set_xlabel("")
        # ax.annotate(
        #     subfig_section,
        #     (-0.2, 1.1),
        #     xycoords="axes fraction",
        #     fontweight="bold",
        #     fontstyle="oblique",
        #     fontsize="x-large",
        #     horizontalalignment="left",
        #     verticalalignment="top",
        # )

    # Model names only on the bottom row of subplots.
    for ax in axs[: (nrows - 1) * ncols]:
        ax.set_xticklabels([])

# plt.savefig("all_results_violin.pdf")
# plt.savefig("all_results_violin.png")

In [None]:
# All-results violin plot, horizontal: baselines, proposed model, and ablations
# as rows, one subplot per metric, sized for a 122 mm wide manuscript.
n_splits = len(test_results.split.unique())
n_metrics = len(test_results.metric.unique())

# models_to_display = list(test_results.model.unique())
models_to_display = [
    "Cubic Spline",
    "RevNet4",
    "CARN LE (Proposed)",
    "CARN DTI",
    "CARN LE No Anat Net",
    "CARN DTI No Anat Net",
    "CARN LE Fake Anat",
    "CARN DTI Fake Anat",
]

n_models = len(models_to_display)


def _wrap_model_name(name, split=True):
    """Return `name` with a newline inserted after its second word.

    Names of at most two words, or any name when `split` is false, are returned
    unchanged. Keeps long axis model labels compact.
    """
    words = name.split()
    if not split or len(words) <= 2:
        return name
    return " ".join(words[:2]) + "\n" + " ".join(words[2:])


# Horizontal layout has room for single-line model names.
split_model_names_with_newlines = False
# y-axis order (and each model's palette color) follows `models_to_display`.
model_display_order = [
    _wrap_model_name(m, split=split_model_names_with_newlines)
    for m in models_to_display
]

with mpl.rc_context(
    {
        **sns.plotting_context("paper"),
        **{
            "font.size": 6.0,
            "axes.labelsize": 6.0,
            # "axes.titlesize": 8.0,
            "xtick.labelsize": 6.0,
            "ytick.labelsize": 6.0,
            "legend.fontsize": 6.0,
            # "ytick.alignment": "bottom",
            "legend.title_fontsize": 6.0,
            "axes.xmargin": 0.02,
            "axes.ymargin": 0.02,
            # "ytick.major.pad": 1.0,
            "axes.labelpad": 1.0,
            "savefig.pad_inches": 0.0,
            "figure.subplot.left": 0.00,
            "figure.subplot.right": 1.0,  # the right side of the subplots of the figure
            "figure.subplot.bottom": 0.00,  # the bottom of the subplots of the figure
            "figure.subplot.top": 1.0,
            # "xtick.alignment": "right",
            # "figure.autolayout": True,
            # "figure.constrained_layout.use": True,
        },
    }
):

    ncols = 2
    nrows = 2  # np.ceil(n_metrics / ncols).astype(int)
    # Manuscript template is 122mm wide, so convert width to inches
    figsize = (122 / 25.4, nrows * 2.1)
    dpi = 800
    # dpi = 200
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        # sharex=True,
        figsize=figsize,
        dpi=dpi,
        # frameon=False,
        # facecolor='red',
        # gridspec_kw={"wspace": 0.2, "hspace": 0.2},
    )
    axs = axs.flatten()
    sns.despine(fig=fig, top=True, right=True)

    all_colors = sns.color_palette(
        "tab10",
        n_colors=n_splits + n_models,
    )
    model_colors = all_colors[:n_models]
    run_colors = all_colors[n_models:]
    # run_order = list(test_results.run_name.unique())

    for i, metric_name in enumerate(perf_metrics):

        ax = axs[i]
        # Filter first, then copy, so the column reassignment below operates
        # on an independent frame rather than a slice of `test_results`.
        df = test_results.loc[
            (test_results.metric == metric_name)
            & np.isin(test_results.model, models_to_display)
        ].copy()
        df.model = df.model.apply(
            lambda name: _wrap_model_name(name, split=split_model_names_with_newlines)
        )

        sns.violinplot(
            y="model",
            x="value",
            order=model_display_order,
            data=df,
            ax=ax,
            scale="width",
            inner="box",
            palette=model_colors,
            orient="h",
            linewidth=0.9,
        )

        ax.set_title(metric_rename_map.get(metric_name, metric_name))
        ax.set_ylabel("")
        ax.set_xlabel("")

        # Draw the value grid on top of the violins.
        ax.grid(True, axis="x", alpha=0.5)
        ax.set_axisbelow(False)

        # Allow for more major tick lines, if possible.
        # <https://e2eml.school/matplotlib_ticks.html>
        ax.xaxis.set_major_locator(
            mpl.ticker.MaxNLocator(nbins=4, steps=[1, 2, 2.5, 5, 10])
        )

    # Model names only on the left-most column of subplots.
    for ax in axs.reshape(nrows, ncols)[:, 1:].ravel():
        ax.set_yticklabels([])

#         for i_model, model_label in enumerate(ax.get_yticklabels()):
#             # Get location of label bounding box.
#             bbox = model_label.get_window_extent()
#             ax_bbox = ax.transAxes.inverted().transform(bbox)
#             left, bottom = tuple(ax_bbox[0, :])
#             right, top = 0.0, bottom

#             tick_line = ax.get_yticklines(minor=False)[i_model * 2]
#             # Copy properties from the tick line into the under line.
#             label_under_line = mpl.lines.Line2D(
#                 xdata=[left, right],
#                 ydata=[bottom, top],
#                 # color=tick_line.get_color(),
#                 # linestyle='-',
#                 # linewidth=tick_line.get_linewidth(),
#                 # marker=tick_line.get_marker(),
#                 # transform=ax.transAxes,
#                 # clip_on=False,
#             )
#             label_under_line.update_from(tick_line)
#             label_under_line.set(
#                 xdata=[left, right],
#                 ydata=[bottom, top],
#                 linestyle='-',
#                 transform=ax.transAxes,
#                 clip_on=False,
#             )

#             ax.add_line(label_under_line)

fig.savefig("all_results_horiz_violin.pdf", bbox_inches="tight", pad_inches=0.02)
fig.savefig("all_results_horiz_violin.png", bbox_inches="tight", pad_inches=0.02)

In [None]:
# # Create violinplot
# curr_comp = [
#     "CARN DTI",
#     "CARN LE",
#     # "CARN Fake Anat DTI", "CARN Fake Anat LE",
#     "CARN Anat DTI",
#     "CARN Anat LE",
# ]
# tr = test_results.loc[np.isin(test_results.model, curr_comp)]

In [None]:
# n_splits = len(test_results.split.unique())
# n_metrics = len(test_results.metric.unique())
# n_models = len(test_results.model.unique())

# with mpl.rc_context({"font.size": 10.0}):

#     ncols = 2
#     nrows = 2  # np.ceil(n_metrics / ncols).astype(int)
#     # figsize = (n_models * ncols, nrows * 3)
#     figsize = (len(curr_comp) * ncols, nrows * 3)

#     fig, axs = plt.subplots(
#         ncols=ncols,
#         nrows=nrows,
#         # sharex=True,
#         figsize=figsize,
#         dpi=160,
#         # gridspec_kw={"wspace": 1.0, "hspace": 1.0},
#     )
#     axs = axs.flatten()
#     sns.despine(fig=fig, top=True, right=True)

#     all_colors = sns.color_palette(
#         "tab10",
#         n_colors=n_splits + n_models,
#     )
#     model_colors = all_colors[:n_models]
#     run_colors = all_colors[n_models:]
#     # run_order = list(test_results.run_name.unique())

#     ax_count = 0
#     for i, l in enumerate(perf_metrics):

#         ax = axs[i]
#         df = tr.loc[tr.metric == l]

#         vplot = sns.violinplot(
#             x="model",
#             y="value",
#             data=df,
#             ax=ax,
#             scale="count",
#             inner=None,
#             palette=model_colors,
#             # lw=0.5,
#         )
#         ax.grid(axis="y", alpha=0.5)

#         points_plot = sns.stripplot(
#             x="model",
#             y="value",
#             hue="split",
#             # hue_order=run_order,
#             jitter=0.17,
#             data=df,
#             ax=ax,
#             palette=run_colors,
#             # palette=plt.get_cmap('gist_rainbow'),
#             # color="black",
#             edgecolor="white",
#             size=2.8,
#             linewidth=0.7,
#         )
#         points_plot.get_legend().remove()

#         # Calculate mean performance score.
#         means = df.groupby(["model", "split"]).mean().groupby("model").mean()
#         # Make sure the order follows seaborn's x-axis ordering.
#         model_order = list(map(lambda ax: ax.get_text(), axs[i].get_xticklabels()))
#         means = means.reindex(model_order)

#         lines = ax.hlines(
#             y=means.value,
#             xmin=np.arange(0, len(means)) - 0.5 + 0.05,
#             xmax=np.arange(1, len(means) + 1) - 0.5 - 0.05,
#             colors=model_colors,
#             lw=1.5,
#             zorder=1000,
#         )

#         outline_path_effects = [
#             mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
#             mpl.patheffects.Normal(),
#         ]
#         lines.set_path_effects(outline_path_effects)

#         ax.set_xticklabels(ax.get_xticklabels(), rotation=25)

#         fig.canvas.draw()
#         ax_format = ax.get_yaxis().get_major_formatter()

#         for m, c in zip(means.value, model_colors):

#             ax.annotate(
#                 f"{m:.4g}",
#                 xy=(ax.get_xlim()[0] + (ax.get_xlim()[0] * 0.4), m),
#                 xycoords="data",
#                 color=c,
#                 ha="right",
#                 va="center",
#                 annotation_clip=False,
#                 fontweight="bold",
#                 snap=True,
#                 bbox=dict(
#                     boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
#                 ),
#             )
#         ax.set_title(f"{l.replace('_', ' ')} {perf_metrics_with_directions[l]}")
#         ax.set_ylabel("")
#         ax.set_xlabel("")

#     for ax in axs[: (nrows - 1) * ncols]:
#         # sns.despine(ax=ax, left=True, bottom=True, top=True, right=True, trim=True)
#         # ax.set_yticks([])
#         # ax.set_xticks([])
#         # ax.set_yticklabels([])
#         ax.set_xticklabels([])
# # plt.suptitle("DIQT V-Fro DTI vs")
# # plt.savefig("uvers_all_runs_violin.pdf")

In [None]:
# run_results[
#     "2022-03-25T08_21_45__uvers_pitn_anat_stream_le_split_3"
# ].metrics.subj_id.unique()

In [None]:
# runs_to_comp = [
#     "2022-03-27T08_29_27__uvers_pitn_single_stream_dti_split_3",
#     "2022-03-24T22_47_26__uvers_pitn_anat_stream_dti_split_3",
#     "2022-03-25T08_21_45__uvers_pitn_anat_stream_le_split_3",
# ]
# sub = "397154"
# vols_to_show = list()
# for r in runs_to_comp:
#     pred_dti = nib.load(run_results[r].pred_vols[sub]).get_fdata()
#     vols_to_show.append(pred_dti)

# fig = plt.figure(dpi=200, figsize=(8, 5))
# pitn.viz.plot_vol_slices(
#     *vols_to_show, slice_idx=(0.5, None, None), fig=fig, cmap="gray"
# )  # , vol_labels=runs_to_comp)

In [None]:
# subj_gt[sub].dti

In [None]:
# fig = plt.figure(dpi=200, figsize=(5, 5))
# pitn.viz.plot_vol_slices(
#     subj_gt[sub].lr_dti.numpy(),
#     subj_gt[sub].dti.numpy(),
#     slice_idx=(0.5, None, None),
#     fig=fig,
#     cmap="gray",
# )  # , vol_labels=runs_to_comp)

## Visual Comparison Grid

### Dense Comparison Grid

In [None]:
# select_subj_id = HOLDOUT_SUBJ_ID
# rot_k = 0

# row_labels = ("Color FA", r"$D_{x, x}", r"D_{y, z}")

# col_labels = (
#     "Downsample",
#     "Cubic Spline",
#     "RevNet4",
#     # "CARN DTI",
#     "CARN LE\n(Proposed)",
#     "Ground\nTruth",
# )

# runs_to_sample = (
#     "2022-03-27T05_51_39__test_spline",
#     "2022-03-23T04_03_59__uvers_espcn_revnet_split_3",
#     # "2022-03-24T22_47_26__uvers_pitn_anat_stream_dti_split_3",
#     "2022-03-25T08_21_45__uvers_pitn_anat_stream_le_split_3",
# )

# full_vol_cols = list()
# # Low-res input col
# lr_col = [
#     np.rot90(
#         np.moveaxis(pitn.viz.direction_map(subj_gt[select_subj_id].lr_dti), 0, -1),
#         k=rot_k,
#     ),
#     np.rot90(subj_gt[select_subj_id].lr_dti[0].cpu().detach().numpy(), k=rot_k),
#     np.rot90(subj_gt[select_subj_id].lr_dti[4].cpu().detach().numpy(), k=rot_k),
# ]

# # Zoom in with Nearest Neighbor interpolation to fit the size of other images.
# lr_col = [
#     scipy.ndimage.zoom(lr_col[0], (2, 2, 2, 1), order=0),
#     scipy.ndimage.zoom(lr_col[1], (2, 2, 2), order=0),
#     scipy.ndimage.zoom(lr_col[2], (2, 2, 2), order=0),
# ]

# full_vol_cols.append(lr_col)

# for r in runs_to_sample:

#     pred_dti_f = run_results[r].pred_vols[select_subj_id]
#     pred_dti = nib.load(pred_dti_f).get_fdata()

#     if pred_dti.shape[0] != 6:
#         pred_dti = np.moveaxis(pred_dti, -1, 0)

#     if pred_dti.shape == (6, 136, 166, 136):
#         pred_dti = np.pad(pred_dti, ((0, 0), (4, 4), (4, 4), (4, 4)))
#     pred_dti = np.rot90(pred_dti, axes=(1, 2), k=rot_k)

#     color_fa = pitn.viz.direction_map(pred_dti)
#     color_fa = np.moveaxis(color_fa, 0, -1)
#     dxx = pred_dti[0]
#     dyz = pred_dti[4]
#     model_col = [color_fa, dxx, dyz]
#     full_vol_cols.append(model_col)

# hr_col = [
#     np.rot90(
#         np.moveaxis(pitn.viz.direction_map(subj_gt[select_subj_id].dti), 0, -1), k=rot_k
#     ),
#     np.rot90(subj_gt[select_subj_id].dti[0].cpu().detach().numpy(), k=rot_k),
#     np.rot90(subj_gt[select_subj_id].dti[4].cpu().detach().numpy(), k=rot_k),
# ]
# full_vol_cols.append(hr_col)

# # for i in range(len(full_vol_cols)):
# #     for j in range(len(full_vol_cols[i])):
# #         v = full_vol_cols[i][j]
# #         # full_vol_cols[i, j] = np.rot90(np.rot90(v, k=1, axes=(1, 3)), k=2, axes=(2, 3))
# #         full_vol_cols[i][j] = np.rot90(v)

In [None]:
# [v.shape for x in full_vol_cols for v in x]

In [None]:
# zoom_out_grid_coords = (75, slice(20, 154), slice(20, 124))
# zoom_in_grid_coords = (75, slice(67, 107), slice(55, 95))

# with mpl.rc_context(
#     {
#         "font.size": 6.0,
#         "axes.labelpad": 10.0,
#         # "figure.autolayout": False,
#         # "figure.constrained_layout.use": True,
#     }
# ):
#     fig = plt.figure(dpi=250, figsize=(4, 4))

#     grid = ImageGrid(
#         fig,
#         111,  # similar to subplot(111)
#         aspect=False,
#         nrows_ncols=(len(row_labels) * 2, len(col_labels)),
#         axes_pad=0.01,  # pad between axes in inch.
#     )

#     for i_c, ax_col_list in enumerate(grid.axes_column):
#         full_vol_col = full_vol_cols[i_c]
#         for i_ax, j_r in zip(range(0, len(ax_col_list), 2), range(len(ax_col_list))):
#             if i_ax >= 2:
#                 cmap = "gray"
#             else:
#                 cmap = None
#             full_vol_im = full_vol_col[j_r]

#             zoom_out_ax = ax_col_list[i_ax]

#             zoom_out_im = full_vol_im[zoom_out_grid_coords]
#             zoom_out_ax.imshow(zoom_out_im, cmap=cmap)
#             zoom_out_ax.axis("off")

#             zoom_in_ax = ax_col_list[i_ax + 1]
#             zoom_in_im = full_vol_im[zoom_in_grid_coords]

#             if zoom_in_im.shape[-1] == 3:
#                 zoom_in_im = scipy.ndimage.zoom(zoom_in_im, (2.5, 2.5, 1), order=0)
#             else:
#                 zoom_in_im = scipy.ndimage.zoom(zoom_in_im, 2.5, order=0)
#             zoom_in_ax.imshow(zoom_in_im, cmap=cmap)
#             zoom_in_ax.axis("off")
# plt.savefig("dense_qual_results.pdf")
# plt.savefig("dense_qual_results.png")

### Compact Nested Comparison Grid

In [None]:
select_subj_id = HOLDOUT_SUBJ_ID
# Number of 90-degree rotations applied to every displayed volume (0 = none).
rot_k = 0

# One figure row per displayed quantity.
row_labels = ("Color FA", r"$D_{x, x}$", r"$D_{y, z}$")
# Figure columns, left to right: the low-res input, one column per entry of
# `runs_to_sample` (in the same order), then the ground truth. Keep these two
# tuples in sync; a mismatch is caught by the check at the end of this cell.
col_labels = (
    "Downsample",
    "Cubic Spline",
    "RevNet4",
    # "CARN DTI",
    "CARN LE\n(Proposed)",
    "Ground\nTruth",
)

runs_to_sample = (
    "2022-03-27T05_51_39__test_spline",
    "2022-03-23T04_03_59__uvers_espcn_revnet_split_3",
    # "2022-03-24T22_47_26__uvers_pitn_anat_stream_dti_split_3",
    "2022-03-25T08_21_45__uvers_pitn_anat_stream_le_split_3",
)


def _gt_dti_display_col(dti, rot_k=0):
    """Return the [color FA, channel 0, channel 4] display volumes of `dti`.

    `dti` is indexed along its first axis and its channels are converted with
    `.cpu().detach().numpy()`, so it is expected to be a channel-first torch
    tensor (the ground-truth `lr_dti` / `dti` attributes are passed here).
    The first axis of `pitn.viz.direction_map`'s output (presumably the color
    axis) is moved last. Every returned array is rotated by `rot_k` quarter
    turns in its first two axes.
    """
    color_fa = np.moveaxis(pitn.viz.direction_map(dti), 0, -1)
    return [
        np.rot90(color_fa, k=rot_k),
        np.rot90(dti[0].cpu().detach().numpy(), k=rot_k),
        np.rot90(dti[4].cpu().detach().numpy(), k=rot_k),
    ]


full_vol_cols = list()

# Low-res input column. Upsample with nearest-neighbor interpolation so it has
# the same size as the other columns; the trailing axis of the color FA volume
# is left unscaled. NOTE(review): assumes `lr_dti` is exactly 2x downsampled
# relative to `dti` — confirm.
lr_col = _gt_dti_display_col(subj_gt[select_subj_id].lr_dti, rot_k=rot_k)
lr_col = [
    scipy.ndimage.zoom(lr_col[0], (2, 2, 2, 1), order=0),
    scipy.ndimage.zoom(lr_col[1], (2, 2, 2), order=0),
    scipy.ndimage.zoom(lr_col[2], (2, 2, 2), order=0),
]
full_vol_cols.append(lr_col)

# One column per model run, in `runs_to_sample` order.
for r in runs_to_sample:

    pred_dti_f = run_results[r].pred_vols[select_subj_id]
    pred_dti = nib.load(pred_dti_f).get_fdata()

    # If the first axis is not the 6-channel DTI axis, assume the channels are
    # stored last and move them first.
    if pred_dti.shape[0] != 6:
        pred_dti = np.moveaxis(pred_dti, -1, 0)

    # Predictions with spatial shape (136, 166, 136) are zero-padded by 4 voxels
    # on each side. NOTE(review): presumably to match the ground-truth grid —
    # confirm.
    if pred_dti.shape == (6, 136, 166, 136):
        pred_dti = np.pad(pred_dti, ((0, 0), (4, 4), (4, 4), (4, 4)))
    # Rotate the two leading spatial axes (axes 1 and 2; axis 0 is channels).
    pred_dti = np.rot90(pred_dti, axes=(1, 2), k=rot_k)

    color_fa = pitn.viz.direction_map(pred_dti)
    color_fa = np.moveaxis(color_fa, 0, -1)
    dxx = pred_dti[0]
    dyz = pred_dti[4]
    model_col = [color_fa, dxx, dyz]
    full_vol_cols.append(model_col)
    print("Grabbed run ", r)

# Ground-truth column.
hr_col = _gt_dti_display_col(subj_gt[select_subj_id].dti, rot_k=rot_k)
full_vol_cols.append(hr_col)

# The plotting cell pairs `col_labels[j]` with `full_vol_cols[j]`; fail loudly
# here if the label and run tuples drifted apart, instead of mislabeling panels.
if len(full_vol_cols) != len(col_labels):
    raise ValueError(
        f"Assembled {len(full_vol_cols)} image columns but found "
        f"{len(col_labels)} column labels; keep `col_labels` and "
        "`runs_to_sample` in sync."
    )

In [None]:
# Sanity check: shape of every volume to be displayed, column by column.
[vol.shape for vol in itertools.chain.from_iterable(full_vol_cols)]

In [None]:
# Slice/crop settings shared by every panel. Each tuple indexes a displayed
# volume: the integer picks one slice along its first axis, and the two slices
# crop that 2D image.
height_idx = 81
zoom_out_grid_coords = (height_idx, slice(20, 154), slice(20, 124))
# Region of interest shown magnified in each panel's inset; it lies inside the
# zoomed-out crop above.
zoom_in_grid_coords = (height_idx, slice(27, 67), slice(45, 85))

# Nearest-neighbor magnification applied to the ROI before embedding it.
embed_im_zoom_factor = 2.0

# Normalize images within-modality: every panel in a row shares one
# [min, max] display range, taken over the zoomed-out crops of that row.
row_vranges = [[np.inf, -np.inf] for _ in row_labels]
for i_row, row_vrange in enumerate(row_vranges):
    for j_col in range(len(col_labels)):
        zoom_out_im = full_vol_cols[j_col][i_row][zoom_out_grid_coords]
        row_vrange[0] = min(row_vrange[0], zoom_out_im.min())
        row_vrange[1] = max(row_vrange[1], zoom_out_im.max())

# Outline of the ROI in the zoomed-out image's data coordinates. It is the same
# for every panel, so compute it once. The -0.5 shifts move from pixel centers
# to pixel edges, since imshow centers pixels on integer coordinates by default.
rect_orig = (
    zoom_in_grid_coords[2].start - zoom_out_grid_coords[2].start - 0.5,
    zoom_in_grid_coords[1].start - zoom_out_grid_coords[1].start - 0.5,
)
rect_size = (
    zoom_in_grid_coords[2].stop - zoom_in_grid_coords[2].start,
    zoom_in_grid_coords[1].stop - zoom_in_grid_coords[1].start,
)


with mpl.rc_context(
    {
        **sns.plotting_context("paper"),
        "font.size": 8.0,
        "axes.titlesize": 8.0,
        "axes.xmargin": 0.02,
        "axes.ymargin": 0.02,
        "savefig.pad_inches": 0.0,
        # Let the panel grid span the whole figure.
        "figure.subplot.left": 0.00,
        "figure.subplot.right": 1.0,
        "figure.subplot.bottom": 0.00,
        "figure.subplot.top": 1.0,
        "figure.constrained_layout.use": True,
    }
):
    fig = plt.figure(dpi=800, figsize=(FIG_WIDTH_INCHES, 3.1))

    grid = mpl.gridspec.GridSpec(
        nrows=len(row_labels),
        ncols=len(col_labels),
        figure=fig,
        wspace=0.01,
        hspace=0.005,
    )

    for i_row in range(grid.nrows):
        row_vmin, row_vmax = row_vranges[i_row]
        # Row 0 (color FA, presumably RGB) is drawn without a colormap; the
        # remaining rows are scalar images drawn in grayscale.
        cmap = "gray" if i_row >= 1 else None
        for j_col in range(grid.ncols):
            ax = fig.add_subplot(grid[i_row, j_col])
            # Row labels on the left-most column, column labels above the top row.
            if ax.get_subplotspec().is_first_col():
                ax.set_ylabel(row_labels[i_row])
            if ax.get_subplotspec().is_first_row():
                ax.set_xlabel(col_labels[j_col])
                ax.xaxis.set_label_position("top")
            # Image panels only: hide all ticks and tick labels.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect("equal")

            # Grab actual volume to display.
            full_vol_im = full_vol_cols[j_col][i_row]

            # Plot base zoomed-out image.
            zoom_out_im = full_vol_im[zoom_out_grid_coords]
            im_plot = ax.imshow(
                zoom_out_im,
                cmap=cmap,
                interpolation="antialiased",
                origin="upper",
                vmin=row_vmin,
                vmax=row_vmax,
            )

            # Magnify the ROI with nearest-neighbor interpolation. Only the two
            # spatial axes are scaled; a trailing color axis (3-D crop) is kept.
            # Decide by rank, not by `shape[-1] == 3`, so a 2-D ROI that happens
            # to be 3 pixels wide is not mistaken for an RGB image.
            zoom_in_im = full_vol_im[zoom_in_grid_coords]
            zoom_factors = (embed_im_zoom_factor, embed_im_zoom_factor)
            if zoom_in_im.ndim == 3:
                zoom_factors = zoom_factors + (1,)
            zoom_in_im = scipy.ndimage.zoom(zoom_in_im, zoom_factors, order=0)

            # Anchor the inset at the panel's lower-right corner, then shift it
            # so it extends partly past the image's right and bottom edges.
            embed_x0 = zoom_out_im.shape[1] - zoom_in_im.shape[1] - 0.5
            embed_y0 = zoom_out_im.shape[0] - zoom_in_im.shape[0] - 0.5
            embed_x0 += zoom_in_im.shape[1] * 0.3
            embed_y0 += zoom_in_im.shape[0] * 0.15

            inset_ax = ax.inset_axes(
                bounds=[embed_x0, embed_y0, zoom_in_im.shape[1], zoom_in_im.shape[0]],
                transform=ax.transData,
            )

            inset_ax.imshow(
                zoom_in_im,
                cmap=cmap,
                origin="upper",
                interpolation="antialiased",
                vmin=row_vmin,
                vmax=row_vmax,
            )

            inset_ax.set_xticks([])
            inset_ax.set_yticks([])
            inset_ax.set_xticklabels([])
            inset_ax.set_yticklabels([])
            inset_ax.set_aspect("equal")
            # Change frame around inset image to yellow.
            plt.setp(inset_ax.spines.values(), color="yellow")

            # Outline the magnified ROI in yellow on the zoomed-out image.
            rect = mpl.patches.Rectangle(
                xy=rect_orig,
                width=rect_size[0],
                height=rect_size[1],
                edgecolor="yellow",
                lw=0.75,
                fill=False,
                transform=ax.transData,
            )
            ax.add_patch(rect)

fig.savefig("qual_results.pdf", bbox_inches="tight", pad_inches=0.01)
fig.savefig("qual_results.png", bbox_inches="tight", pad_inches=0.01)