# Spline Baseline


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

References:

* `R. Tanno et al., “Uncertainty modelling in deep learning for safer neuroimage enhancement: Demonstration in diffusion MRI,” NeuroImage, vol. 225, p. 117366, Jan. 2021, doi: 10.1016/j.neuroimage.2020.117366.`
* `D. C. Alexander et al., “Image quality transfer and applications in diffusion MRI,” NeuroImage, vol. 152, pp. 283–298, May 2017, doi: 10.1016/j.neuroimage.2017.02.089.`
* `S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import matplotlib.pyplot as plt
import seaborn as sns

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
import pprint
from pprint import pprint as ppr
import box
from box import Box

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai
import einops
import torchinfo

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Figure defaults: automatic layout and an opaque white figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact, non-scientific printing for ndarrays/tensors.
np.set_printoptions(threshold=100, linewidth=88, suppress=True)
torch.set_printoptions(threshold=100, linewidth=88, sci_mode=False)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv
# - warnings

# Form the command to be run in direnv's context. This command will print out all
# environment variables defined in the subprocess. It is given as an argument list and
# run WITHOUT a shell, so a working directory that contains spaces or shell
# metacharacters cannot split or alter the command.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
# Default to an empty "dummy .env" so that a failed direnv call stays best-effort.
proc_out = ""
try:
    # Run command in a new subprocess.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except OSError as direnv_err:
    # E.g. direnv is not installed or not executable.
    warnings.warn(f"direnv could not be run, environment not updated: {direnv_err}")
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            "environment variables may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# torch setup
# Use CUDA when it is available, otherwise run on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    # The GPU index may be chosen through the environment; default to device 0.
    dev_idx = int(os.environ.get("CUDA_PYTORCH_DEVICE_IDX", 0))
    device = torch.device(f"cuda:{dev_idx}")
    print("CUDA Device IDX ", dev_idx)
    torch.cuda.set_device(device)
    print("CUDA Current Device ", torch.cuda.current_device())
    print("CUDA Device properties: ", torch.cuda.get_device_properties(device))

    # Activate cudnn benchmarking to optimize convolution algorithm speed.
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True
        print("CuDNN convolution optimization enabled.")
# keep device as the cpu
# device = torch.device('cpu')
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# The captured output of this cell is bound to the name `cap`.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():
    # GPU information
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
        print(gpu_info)
    # Fallback: print the CUDA version reported by torch instead.
    # NOTE(review): only NameError is handled here; any other failure inside
    # get_gpu_specs() will propagate -- confirm this is the intended fallback trigger.
    except NameError:
        print("CUDA Version: ", torch.version.cuda)
else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` is defined by the `%%capture --no-stderr cap` cell magic in the previous cell;
# printing it shows the output that was captured there.
print(cap)

Author: Tyler Spears

Last updated: 2022-03-27T05:51:35.597651+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-100-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: d8162cca518d2e6c2f05523feb712a52b5ea3f05

seaborn          : 0.11.1
torch            : 1.10.2
pytorch_lightning: 1.5.10
scipy            : 1.5.3
natsort          : 7.1.1
einops           : 0.3.0
nibabel          : 3.2.1
dipy             : 1.4.1
box              : 5.4.1
torchio          : 0.18.30
json             : 2.0.9
matplotlib       : 3.4.1
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
monai            : 0.8.0
ipywidgets       : 7.6.3
pitn             : 0.0.post1.dev132+g02c0d1a
torchinfo        : 1.6.2
pandas           : 1.2.3
skimage          : 0.18.1
numpy            : 1.20.2

  id  Name              Driver Version      CUDA Version  

### Data Variables & Definitions Setup

In [None]:
def _require_env_dir(env_var: str) -> pathlib.Path:
    """Return the path stored in environment variable ``env_var``.

    Args:
        env_var: Name of the environment variable that holds the path.

    Returns:
        The path as a ``pathlib.Path``.

    Raises:
        KeyError: If ``env_var`` is not set.
        AssertionError: If the path does not exist on disk.
    """
    dir_path = pathlib.Path(os.environ[env_var])
    # Fail early, and say which variable/path is wrong.
    assert dir_path.exists(), f"${env_var} points to a missing path: {dir_path}"
    return dir_path


# Set up directories
data_dir = _require_env_dir("DATA_DIR")
write_data_dir = _require_env_dir("WRITE_DATA_DIR")
results_dir = _require_env_dir("RESULTS_DIR")
tmp_results_dir = _require_env_dir("TMP_RESULTS_DIR")

## Parameter Reading & Experiment Setup

### Parameters

<div class="alert alert-block alert-info"> <b>NOTE</b> Here are all the parameters! This makes it easy to find them! </div>

In [None]:
# `default_box=True` lets nested sections (e.g. `params.data`) be created on first
# attribute access while the defaults are being filled in.
params = Box(default_box=True)

# General experiment-wide params
###############################################
params.experiment_name = "test_spline"
params.override_experiment_name = False
###############################################
# 6 channels for the 6 DTI components
params.n_channels = 6
params.n_subjs = 48
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25
params.use_anat = False
params.use_half_precision_float = False
params.num_workers = 8
params.use_log_euclid = True

# Data params
params.data.fr_dir = data_dir / f"scale-{params.fr_vox_size:.2f}mm"
params.data.lr_dir = data_dir / f"scale-{params.lr_vox_size:.2f}mm"
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"

params.data.eigval_clip_cutoff = 0.00332008

# The data were downsampled artificially by this factor.
params.data.downsampled_by_factor = params.lr_vox_size / params.fr_vox_size
# Store the factor as an int when it is a whole number, otherwise keep the float.
params.data.downsampled_by_factor = (
    int(params.data.downsampled_by_factor)
    if int(params.data.downsampled_by_factor) == params.data.downsampled_by_factor
    else params.data.downsampled_by_factor
)

# If a config file exists, override the defaults with those values.
# NOTE(review): values derived above (data dirs, downsample factor) are NOT recomputed
# if the config file overrides the voxel sizes -- confirm this is intended.
try:
    if "PITN_CONFIG" in os.environ:
        config_fname = Path(os.environ["PITN_CONFIG"])
    else:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
    f_type = config_fname.suffix.casefold()
    if f_type in {".yaml", ".yml"}:
        f_params = Box.from_yaml(filename=config_fname)
    elif f_type == ".json":
        f_params = Box.from_json(filename=config_fname)
    elif f_type == ".toml":
        f_params = Box.from_toml(filename=config_fname)
    else:
        raise RuntimeError(f"Unsupported config file type '{f_type}': {config_fname}")

    params.merge_update(f_params)

# Loading a config is best-effort, so keep going with the defaults. Only catch
# `Exception` (not KeyboardInterrupt/SystemExit) and report *why* loading failed.
except Exception as config_err:
    print(
        "WARNING: Config file not loaded "
        f"({type(config_err).__name__}: {config_err})"
    )

# Remove the default_box behavior now that params have been fully read in.
p = Box(default_box=False)
p.merge_update(params)
params = p

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = params.experiment_name

ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
# Prefix the name with the timestamp unless the user asked for the exact name.
if not params.override_experiment_name:
    experiment_name = ts + "__" + EXPERIMENT_NAME
else:
    experiment_name = EXPERIMENT_NAME
run_name = experiment_name
# experiment_results_dir = results_dir / experiment_name

# Create temporary directory for results directory, in case experiment does not finish.
# Only grab directories that are timestamped starting with a year. The glob alone also
# matches plain files, which `shutil.rmtree` cannot remove, so filter to directories.
tmp_dirs = [
    tmp_dir
    for tmp_dir in tmp_results_dir.glob("[0-9][0-9][0-9][0-9]*")
    if tmp_dir.is_dir()
]

# Only keep up to N tmp results (counting the one this run is about to create).
n_tmp_to_keep = 5
# Number of old directories to remove so that (n_tmp_to_keep - 1) remain. An explicit
# count avoids the `[: -0]` slice, which would select nothing when n_tmp_to_keep == 1.
n_tmp_to_delete = len(tmp_dirs) - (n_tmp_to_keep - 1)
if n_tmp_to_delete > 0:
    print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
    # Natural sort of the timestamped names puts the oldest directories first.
    tmps_to_delete = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])[
        :n_tmp_to_delete
    ]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# Show the run name that the results directories are created under.
print(experiment_name)

2022-03-27T05_51_39__test_spline


In [None]:
# TensorBoard logger that can be handed to the pytorch-lightning Trainer, so that the
# training/testing loops log into this experiment's temporary results directory.
pl_logger = pl.loggers.TensorBoardLogger(
    tmp_results_dir,
    name=experiment_name,
    version="",
    log_graph=False,
    default_hp_metric=False,
)
# The lower-level writer is used directly for histograms, images, etc.
logger = pl_logger.experiment

# Plain-text log for streams of events & info besides parameters & results.
log_txt_file = Path(logger.log_dir) / "log.txt"
with log_txt_file.open("a+") as log_file:
    log_file.writelines(
        [
            f"Experiment Name: {experiment_name}\n",
            f"Timestamp: {ts}\n",
            # Parameters.
            pprint.pformat(params.to_dict()) + "\n",
        ]
    )

## Data Loading

### Subject ID Selection

In [None]:
# Map each selected subject ID to its full-res ("fr") and low-res ("lr") directories.
subj_dirs: Box = Box()

selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
    "406432",
    "803240",
    "815247",
    "167238",
    "100408",
    "792867",
    "157437",
    "164030",
    "103515",
    "118730",
    "198047",
    "189450",
    "203923",
    "108828",
    "124220",
    "386250",
    "118124",
    "701535",
    "679770",
    "382242",
    "231928",
    "196952",  # Hold-out subject; for visualization, ensure never in the train or val sets
    "567961",
    "910241",
    "175035",
    "567759",
    "978578",
    "150019",
    "690152",
    "297655",
    "307127",
    "634748",
]
HOLDOUT_SUBJ_ID = "196952"
## Sub-set the chosen participants for dev and debugging!
sampled_subjs = random.sample(selected_ids, params.n_subjs)
# Warn loudly whenever only part of the candidate list is used.
if len(selected_ids) > len(sampled_subjs):
    warnings.warn(
        f"WARNING: Sub-selecting {len(sampled_subjs)}/{len(selected_ids)} "
        "participants for dev and debugging. "
        f"Subj IDs selected: {sampled_subjs}"
    )
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# Natural sort so subjects are processed in a stable order regardless of sampling.
selected_subjs = natsorted(sampled_subjs)

for subj_id in selected_subjs:
    # Each lookup globs for the subject ID inside the resolution-specific data dir.
    _fr_dir = pitn.utils.system.get_file_glob_unique(
        params.data.fr_dir, f"*{subj_id}*"
    )
    _lr_dir = pitn.utils.system.get_file_glob_unique(
        params.data.lr_dir, f"*{subj_id}*"
    )
    subj_dirs[subj_id] = Box(fr=_fr_dir, lr=_lr_dir)
    assert subj_dirs[subj_id].fr.exists()
    assert subj_dirs[subj_id].lr.exists()
ppr(subj_dirs)

{'100408': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-100408'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-100408')},
 '103010': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-103010'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-103010')},
 '103515': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-103515'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-103515')},
 '108828': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-108828'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-108828')},
 '118124': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/su

In [None]:
# Record the subjects that were actually selected for this run (`selected_subjs`),
# not the full candidate list (`selected_ids`), which differs whenever
# `params.n_subjs` sub-samples the candidates.
with open(log_txt_file, "a+") as f:
    f.write(f"Selected Subjects: {selected_subjs}\n")

logger.add_text("subjs", pprint.pformat(selected_subjs))

### Loading and Preprocessing

In [None]:
# Prep for Dataset loading.

# Data reader object for NIFTI files.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# Spatial shape of one low-res patch; the full-res patch shape is derived from it.
_lr_patch_shape = (11, 11, 11)

# HR -> LR patch coordinate conversion function, shared by every low-res volume.
_fr2lr_coords = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=_lr_patch_shape,
)
fr2lr_patch_coords_fn = {
    vol_name: _fr2lr_coords for vol_name in ("lr_dti", "lr_log_euclid", "lr_mask")
}

# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes.
_fr_patch_shape = np.floor(
    np.asarray(_lr_patch_shape) * params.data.downsampled_by_factor
).astype(int)
patch_kwargs = dict(
    patch_shape=tuple(_fr_patch_shape),
    stride=1,
    meta_keys_to_patch_index={"dti", "log_euclid", "mask"},
    mask_name="mask",
)

# Coefficients to the log-euclidean lower triangle/6D vector that properly scales
# the Euclidean distance under the log-euclidean metrics: entries 1, 3 and 4 are
# weighted by sqrt(2), the remaining entries by 1.
mat_norm_coeffs = torch.ones(6)
mat_norm_coeffs[[1, 3, 4]] = np.sqrt(2)
# Shape (6, 1, 1, 1), so it broadcasts over channel-first 3D volumes.
mat_norm_coeffs = mat_norm_coeffs[:, None, None, None]


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Crop or pad a channel-first volume so its spatial shape matches a target.

    Args:
        fr_vol: Volume whose dimensions after the first are the spatial dimensions.
        fr_affine: Affine matrix that belongs to ``fr_vol``.
        target_spatial_shape: Desired spatial shape.

    Returns:
        ``(volume, affine)``. The inputs are returned untouched when the spatial
        shape already matches; otherwise the cropped/padded volume and the affine
        reported by torchio for it.
    """
    wanted_shape = tuple(np.asarray(target_spatial_shape))
    # Nothing to fix when the spatial shape already agrees.
    if fr_vol.shape[1:] == wanted_shape:
        return fr_vol, fr_affine

    # Use torchio objects because they fix the affine matrix, too.
    # The volume is flipped along all spatial axes before the transform (and flipped
    # back afterwards) so padding goes on the right/top/furthest side of each
    # dimension first, before the left/bottom/closest.
    spatial_axes = [1, 2, 3]
    flipped_im = torchio.ScalarImage(
        tensor=fr_vol.flip(spatial_axes), affine=fr_affine
    )
    crop_or_pad = torchio.transforms.CropOrPad(target_spatial_shape, 0, copy=False)
    resized_im = crop_or_pad(flipped_im)

    return resized_im["data"].flip(spatial_axes), resized_im["affine"]


def orient_to_viz(vol, affine):
    """Rotate a channel-first volume for visualization and rotate its affine to match.

    The array is rotated by 90 degrees in the plane of axes (1, 3), then by 180
    degrees in the plane of axes (2, 3). The affine is left-multiplied by a
    homogeneous rotation built from a 90 degree rotation about ``[0, 1, 0]``
    composed with a 180 degree rotation about ``[1, 0, 0]``.

    Args:
        vol: ``torch.Tensor`` or array-like volume accepted by ``np.rot90``.
        affine: Square homogeneous affine matrix for ``vol``.

    Returns:
        ``(rotated_vol, new_affine)``. A tensor input yields a tensor with the dtype
        and device of ``vol``; any other input yields the ``np.rot90`` result.
    """
    was_tensor = torch.is_tensor(vol)
    arr = vol.detach().cpu().numpy() if was_tensor else vol

    arr = np.rot90(arr, k=1, axes=(1, 3))
    arr = np.rot90(arr, k=2, axes=(2, 3))
    if was_tensor:
        # Copy first: rot90 returns a view, which may have negative strides.
        arr = torch.from_numpy(np.copy(arr)).to(vol)

    # Adjust the affine matrix.
    # 90 degree rot around the second axis.
    quat_y_90 = nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0])
    # 180 degree rot around the first axis.
    quat_x_180 = nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0])
    combined_quat = nib.quaternions.mult(quat_y_90, quat_x_180)

    # Embed the 3x3 rotation in a homogeneous matrix shaped like the affine.
    homog_rot = np.zeros_like(affine)
    homog_rot[-1, -1] = 1.0
    homog_rot[:-1, :-1] = nib.quaternions.quat2mat(combined_quat)

    return arr, homog_rot @ affine

#### Data Loading

In [None]:
# Import and organize all data.
# For every subject in `subj_dirs`: read the LR DTI, FR DTI and FR mask, re-orient
# them, optionally clip DTI eigenvalues, compute log-euclidean volumes and per-channel
# statistics, then wrap everything in one `pitn.data.SubjSesDataset` per subject.
subj_data: dict = dict()

# Only these entries of the reader's metadata dict are kept.
meta_keys_to_keep = {"affine", "original_affine"}

# Pure preprocessing; no gradients are needed.
with torch.no_grad():
    for subj_id, subj_dir in subj_dirs.items():

        # Collects both volumes and metadata for this subject; split apart below.
        data = dict()
        data["subj_id"] = subj_id
        fr_subj_dir = subj_dirs[subj_id]["fr"]
        lr_subj_dir = subj_dirs[subj_id]["lr"]
        data["fr_subj_dir"] = fr_subj_dir
        data["lr_subj_dir"] = lr_subj_dir

        ####### Low-resolution DTIs/volumes
        lr_dti_f = pitn.utils.system.get_file_glob_unique(
            lr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(lr_dti_f)
        lr_dti, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        lr_dti = torch.from_numpy(lr_dti)
        lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
        data["lr_dti"] = lr_dti
        data["lr_dti_meta_dict"] = meta

        # May need to handle shape errors when re-upscaling back from LR to HR.
        # Target FR spatial shape = floor(LR spatial shape * downsample factor).
        lr_dti_shape = np.asarray(lr_dti.shape[1:])
        target_fr_shape = np.floor(
            lr_dti_shape * params.data.downsampled_by_factor
        ).astype(int)

        ####### Full-resolution images/volumes.
        # DTI.
        dti_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(dti_f)
        dti, meta = nib_reader.get_data(im)
        dti = torch.from_numpy(dti)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        # Crop/pad the FR volume to the target shape *before* re-orienting it.
        dti, meta["affine"] = fix_downsample_shape_errors(
            dti, meta["affine"], target_fr_shape
        )
        dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
        data["dti"] = dti
        data["dti_meta_dict"] = meta

        # Diffusion mask.
        mask_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.mask_fname_pattern
        )
        im = nib_reader.read(mask_f)
        mask, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        mask = torch.from_numpy(mask)
        # Add channel dim if not available.
        if mask.ndim == 3:
            mask = mask[None]
        # Same shape fix and re-orientation as the FR DTI, so the two stay aligned.
        mask, meta["affine"] = fix_downsample_shape_errors(
            mask, meta["affine"], target_fr_shape
        )
        mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
        mask = mask.bool()
        data["mask"] = mask
        data["mask_meta_dict"] = meta

        # Consider this as the "noise correction" step to have more informative, consistent
        # results with minimal biasing. Otherwise, outliers (which are clearly errors) can
        # change loss and performance metrics by orders of magnitude for no good reason!
        # Skipped when the cutoff is missing from the params or is falsy.
        if "eigval_clip_cutoff" in params.data and params.data.eigval_clip_cutoff:
            # Clip on `device`, then move back to the dtype/device of the stored volume.
            correct_dti = pitn.data.outliers.clip_dti_eigvals(
                data["dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["dti"])
            # Zero out everything outside of the FR mask.
            correct_dti = correct_dti * data["mask"]
            ####
            # Diagnostic only: per-component mean absolute change caused by clipping,
            # averaged over the voxels inside the mask.
            sae_fr = (
                F.l1_loss(data["dti"], correct_dti, reduction="none") * data["mask"]
            )
            sae_fr = sae_fr.view(6, -1).sum(1)
            mae_fr = sae_fr / torch.count_nonzero(data["mask"])
            print(f"---Subj {subj_id}---")
            print(
                "MAE of FR DTI after eigenvalue clipping:\n",
                mae_fr.tolist(),
            )
            ####
            data["dti"] = correct_dti

            correct_lr_dti = pitn.data.outliers.clip_dti_eigvals(
                data["lr_dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["lr_dti"])
            # Local LR mask: voxels where at least one tensor component is non-zero.
            # This is not the `data["lr_mask"]` that is built further below.
            lr_mask = (data["lr_dti"] != 0).max(0, keepdim=True).values
            correct_lr_dti = correct_lr_dti * lr_mask
            ####
            # Same diagnostic as above, for the LR volume.
            sae_lr = (
                F.l1_loss(data["lr_dti"], correct_lr_dti, reduction="none") * lr_mask
            )
            sae_lr = sae_lr.view(6, -1).sum(1)
            mae_lr = sae_lr / torch.count_nonzero(lr_mask)
            print(
                "MAE of LR DTI after eigenvalue clipping:\n",
                mae_lr.tolist(),
            )
            ####
            data["lr_dti"] = correct_lr_dti

        ####### Log-euclid pre-computed volumes
        # Construct a quick and cheap mask for the LR DTI
        # (nearest-neighbor resize of the FR mask to the LR spatial shape).
        cheap_lr_mask = F.interpolate(
            data["mask"][None].float(),
            size=data["lr_dti"][0].shape,
            mode="nearest",
        )[0]
        data["lr_mask"] = cheap_lr_mask.bool()

        # LR log-euclid volume.
        # Lower-triangle vector -> symmetric matrix -> log map -> lower-triangle
        # vector, then scaled by `mat_norm_coeffs`.
        lr_log_euclid = pitn.eig.tril_vec2sym_mat(data["lr_dti"], tril_dim=0)
        lr_log_euclid = pitn.riemann.log_euclid.log_map(lr_log_euclid)
        lr_log_euclid = pitn.eig.sym_mat2tril_vec(lr_log_euclid, tril_dim=0)
        lr_log_euclid = lr_log_euclid * mat_norm_coeffs
        data["lr_log_euclid"] = lr_log_euclid

        # FR log-euclid volume, computed the same way.
        log_euclid = pitn.eig.tril_vec2sym_mat(data["dti"], tril_dim=0)
        log_euclid = pitn.riemann.log_euclid.log_map(log_euclid)
        log_euclid = pitn.eig.sym_mat2tril_vec(log_euclid, tril_dim=0)
        log_euclid = log_euclid * mat_norm_coeffs
        data["log_euclid"] = log_euclid

        ######## Normalized Subj Volumes
        # Per-channel mean/std over the masked voxels, reshaped to (6, 1, 1, 1).
        fr_means = (
            torch.masked_select(data["dti"], data["mask"])
            .view(6, -1)
            .mean(1)
            .reshape(6, 1, 1, 1)
        )
        fr_stds = torch.std(
            torch.masked_select(data["dti"], data["mask"]).view(6, -1), 1
        ).reshape(6, 1, 1, 1)
        lr_means = (
            torch.masked_select(data["lr_dti"], data["lr_mask"])
            .view(6, -1)
            .mean(1)
            .reshape(6, 1, 1, 1)
        )
        lr_stds = torch.std(
            torch.masked_select(data["lr_dti"], data["lr_mask"]).view(6, -1), 1
        ).reshape(6, 1, 1, 1)

        data["dti_means"] = fr_means
        data["dti_stds"] = fr_stds
        data["lr_dti_means"] = lr_means
        data["lr_dti_stds"] = lr_stds

        # Entries of `data` that are volumes; every other entry is treated as metadata
        # and forwarded to the dataset as keyword arguments.
        vol_names = {
            "dti",
            "mask",
            "lr_dti",
            "lr_mask",
            "log_euclid",
            "lr_log_euclid",
        }
        metadata_names = set(data.keys()) - vol_names
        vol_d = {k: data[k] for k in vol_names}
        meta_d = {k: data[k] for k in metadata_names}

        # Create multi-volume dataset for this subj-session.
        subj_dataset = pitn.data.SubjSesDataset(
            vol_d,
            primary_vol_name="dti",
            special_secondary2primary_coords_fns=fr2lr_patch_coords_fn,
            transform=None,
            primary_patch_kwargs=patch_kwargs,
            **meta_d,
        )

        # Finalize this subject.
        subj_data[subj_id] = subj_dataset

        print("=" * 20)

print("===Data Loaded===")

### Calculate Aggregate Statistics for Normalization

## Model Training

### Set Up Patch-Based Data Loaders

In [None]:
# Data train/validation/test split
# The subject list is cut into consecutive test, validation and training slices.

subj_list = list(subj_data.keys())
num_subjs = len(subj_data)

num_test_subjs = int(params.n_subjs)
num_val_subjs = 0
num_train_subjs = 0

# Slice boundaries: [0, val_start) test, [val_start, train_start) val, rest train.
_val_start = num_test_subjs
_train_start = num_test_subjs + num_val_subjs

test_subjs = subj_list[:_val_start]
val_subjs = subj_list[_val_start:_train_start]
train_subjs = subj_list[_train_start:]

# print(f"{num_train_subjs}/{num_subjs} Training subject ID(s): {train_subjs}")
# print(f"{num_val_subjs}/{num_subjs} Validation subject ID(s): {val_subjs}")
# print(f"{num_test_subjs}/{num_subjs} Test subject ID(s): {test_subjs}")
# with open(log_txt_file, "a+") as f:
#     f.write(f"{num_train_subjs}/{num_subjs} Training subject ID(s): {train_subjs}\n")
#     f.write(f"{num_val_subjs}/{num_subjs} Validation subject ID(s): {val_subjs}\n")
#     f.write(f"{num_test_subjs}/{num_subjs} Test subject ID(s): {test_subjs}\n")

# logger.add_text("train_subjs", str(train_subjs))
# logger.add_text("val_subjs", str(val_subjs))
# logger.add_text("test_subjs", str(test_subjs))

In [None]:
# Create Datasets and DataLoaders for test, validation, and training steps.
# Each sample field is collated under its own name.
sample_kws = {
    field: field
    for field in (
        "subj_id",
        "dti",
        "lr_dti",
        "mask",
        "lr_mask",
        "log_euclid",
        "lr_log_euclid",
    )
}

# Test & Validation
# Only need raw DTIs for testing and validation, not training.
tv_sample_kws = sample_kws
test_val_collate_fn = functools.partial(pitn.samplers.collate_dicts, **tv_sample_kws)

# One dataset per test subject, concatenated and served one sample at a time.
test_ds = [subj_data[subj_id] for subj_id in test_subjs]
test_dataset = torch.utils.data.ConcatDataset(test_ds)
test_loader = monai.data.DataLoader(
    test_dataset, batch_size=1, collate_fn=test_val_collate_fn
)

### Calculate Aggregate Statistics for Normalization

In [None]:
# Per-channel min/max over the masked voxels of the selected subjects, for both the
# DTI components and their log-euclidean representation (FR and LR volumes combined).
# NOTE: this notebook has no train/val subjects, so the statistics are collected over
# `test_subjs`.

subj_agg_stats = Box(default_box=True)
# Running extrema start at 0, on the same dtype/device as the stored DTI volumes.
subj_agg_stats.dti.min = torch.zeros(params.n_channels).to(
    list(subj_data.values())[0][0]["dti"]
)
subj_agg_stats.dti.max = torch.zeros(params.n_channels).to(subj_agg_stats.dti.min)

# Sharing the initial tensors is safe: they are only ever replaced by the results of
# `torch.minimum`/`torch.maximum`, never modified in-place.
subj_agg_stats.log_euclid.min = subj_agg_stats.dti.min
subj_agg_stats.log_euclid.max = subj_agg_stats.dti.max

for subj_id in set(test_subjs):
    s = subj_data[subj_id]
    fr_mask = s[0]["mask"]
    # Masked values come back flat, so they are viewed as (channels, voxels) below.
    dti = torch.masked_select(s[0]["dti"], fr_mask)
    subj_agg_stats.dti.min = torch.minimum(
        subj_agg_stats.dti.min, dti.view(params.n_channels, -1).min(-1).values
    )
    subj_agg_stats.dti.max = torch.maximum(
        subj_agg_stats.dti.max, dti.view(params.n_channels, -1).max(-1).values
    )

    lr_dti = s[0]["lr_dti"]
    # LR foreground = voxels where at least one tensor component is non-zero. The
    # previous `(lr_dti == 0).all(0)` selected the all-zero *background* instead, so
    # the LR statistics were computed over empty voxels.
    lr_mask = (lr_dti != 0).any(0)[
        None,
    ]
    lr_dti = torch.masked_select(lr_dti, lr_mask)
    subj_agg_stats.dti.min = torch.minimum(
        subj_agg_stats.dti.min, lr_dti.view(params.n_channels, -1).min(-1).values
    )
    subj_agg_stats.dti.max = torch.maximum(
        subj_agg_stats.dti.max, lr_dti.view(params.n_channels, -1).max(-1).values
    )

    log_euclid = torch.masked_select(s[0]["log_euclid"], fr_mask).view(
        params.n_channels, -1
    )
    subj_agg_stats.log_euclid.min = torch.minimum(
        subj_agg_stats.log_euclid.min,
        log_euclid.view(params.n_channels, -1).min(-1).values,
    )
    subj_agg_stats.log_euclid.max = torch.maximum(
        subj_agg_stats.log_euclid.max,
        log_euclid.view(params.n_channels, -1).max(-1).values,
    )
    lr_log_euclid = torch.masked_select(s[0]["lr_log_euclid"], lr_mask).view(
        params.n_channels, -1
    )
    # lr_log_euclid = s[0]["lr_log_euclid"][lr_mask]
    subj_agg_stats.log_euclid.min = torch.minimum(
        subj_agg_stats.log_euclid.min,
        lr_log_euclid.view(params.n_channels, -1).min(-1).values,
    )
    subj_agg_stats.log_euclid.max = torch.maximum(
        subj_agg_stats.log_euclid.max,
        lr_log_euclid.view(params.n_channels, -1).max(-1).values,
    )

print(subj_agg_stats.dti.min)
print(subj_agg_stats.dti.max)
print(subj_agg_stats.log_euclid.min)
print(subj_agg_stats.log_euclid.max)

In [None]:
# Reshape a per-channel vector so it broadcasts over (batch, channel, x, y, z) tensors.
expander = functools.partial(einops.rearrange, pattern="c -> 1 c 1 1 1")

# Collect DTI global data features.
dti_min = expander(subj_agg_stats.dti.min)
dti_max = expander(subj_agg_stats.dti.max)

# Target feature range for PSNR scaling: [0, 1] for each of the 6 channels.
feat_min: torch.Tensor = torch.as_tensor([0] * 6)
feat_max: torch.Tensor = torch.as_tensor([1] * 6)
print("PSNR will be scaled to ", feat_min, feat_max)
feat_min = expander(feat_min)
feat_max = expander(feat_max)

# No input scaling / output de-scaling here: both are identity functions.
net_scalers = {
    "input_scaler": lambda o: o,
    "output_descaler": lambda o: o,
}

# PSNR is calculated on the final output tensor components, so no log-euclidean or
# scaling will occur here.
psnr_range_params = pitn.data.norm.GlobalScaleParams(
    data_min=dti_min, data_max=dti_max, feature_min=feat_min, feature_max=feat_max
)

### Model Definition

In [None]:
class DIQTESPCNSystem(pl.LightningModule):
    """Spline-interpolation baseline evaluated through the Lightning test loop.

    The module defines no network layers: predictions come from
    ``scipy.ndimage.zoom`` applied to the low-resolution input, and the training
    and validation steps are no-ops. ``test_step`` computes the evaluation
    metrics and records them, per subject, in ``plain_log``.

    Args:
        channels: Channel count; stored on the instance.
        upscale_factor: Zoom factor applied to each of the three spatial axes.
        psnr_range_params: Scaling parameters whose ``data_min`` and ``data_max``
            give the value range used by the PSNR metric.
    """

    def __init__(
        self,
        channels: int,
        upscale_factor: int,
        psnr_range_params: pitn.data.norm.GlobalScaleParams,
    ):
        super().__init__()

        self._channels = channels
        self._upscale_factor = upscale_factor
        self._psnr_range_params = psnr_range_params
        # Coefficients to the log-euclidean lower triangle/6D vector that properly scales
        # the Euclidean distance under the log-euclidean metrics. Indices 1, 3 and 4
        # (presumably the off-diagonal terms -- confirm against the channel layout)
        # are weighted by sqrt(2).
        # Build the vector locally first; the previous code reshaped an undefined
        # name and raised a NameError on construction.
        mat_norm_coeffs = torch.ones(6).float()
        mat_norm_coeffs[torch.as_tensor([1, 3, 4])] = np.sqrt(2)
        self.mat_norm_coeffs = mat_norm_coeffs.reshape(-1, 1, 1, 1)

        # My own dinky logging object: per-subject metric values and predictions,
        # keyed by subject id.
        self.plain_log = Box(
            {
                "test_loss": {
                    "rmse": dict(),
                    "nrmse": dict(),
                    "rmse_log_euclid": dict(),
                    "nrmse_log_euclid": dict(),
                    "scaled_psnr": dict(),
                    "ssim_fa": dict(),
                    "rmse_fa": dict(),
                    "nrmse_fa": dict(),
                },
                "viz": {
                    "test_preds_dti": dict(),
                    "test_preds_le": dict(),
                },
            },
        )

    def training_step(self, batch, batch_idx):
        """No-op; the spline baseline is never trained."""
        return None

    def validation_step(self, batch, batch_idx):
        """No-op; the spline baseline is never validated."""
        return None

    def pred_spline(self, x, order=3):
        """Upsample each batch item of ``x`` with spline interpolation.

        Args:
            x: Tensor whose batch items each have four axes; the first axis of an
                item is left at its original size and the remaining three are
                zoomed by ``self._upscale_factor``.
            order: Spline order passed to ``scipy.ndimage.zoom``.

        Returns:
            The stacked upsampled items, cast to the dtype and device of ``x``.
        """
        x_np = x.detach().cpu().numpy()
        zoom_factors = (
            1,
            self._upscale_factor,
            self._upscale_factor,
            self._upscale_factor,
        )
        # scipy works on NumPy arrays, so each batch item is zoomed on the CPU.
        ys = [
            torch.from_numpy(scipy.ndimage.zoom(b, zoom=zoom_factors, order=order))
            for b in x_np
        ]

        y = torch.stack(ys, dim=0)
        y = y.to(x)

        return y

    def on_test_start(self):
        """Register the test metrics as hparam metrics with placeholder values."""
        # Initialize the metrics as hyperparams so they appear under the tensorboard
        # hparams tab. From:
        # <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-hyperparameters>
        # The keys must match the "hp/..." keys logged in `test_step`.
        self.logger.log_hyperparams(
            self.hparams,
            {
                "hp/rmse": 0,
                "hp/nrmse": 0,
                "hp/scaled_psnr": 0,
                "hp/ssim_fa": 0,
                "hp/rmse_log_euclid": 0,
                "hp/nrmse_log_euclid": 0,
                "hp/rmse_fa": 0,
                "hp/nrmse_fa": 0,
            },
        )

    def test_step(self, batch: dict, batch_idx):
        """Predict one test subject with spline interpolation and score it.

        Args:
            batch: Mapping with the keys ``subj_id``, ``lr_dti``, ``lr_log_euclid``,
                ``dti``, ``log_euclid`` and ``mask``.
            batch_idx: Index of the batch; unused.

        Returns:
            Dict mapping each metric name to its value for this batch.
        """
        batch = Box(batch)
        # Assume batch size of 1 for the test set.
        subj_id = batch.subj_id[0]
        x = batch.lr_dti.float()
        log_euclid_x = batch.lr_log_euclid.float()
        y = batch.dti.float()
        # Cast like the other inputs so the target dtype matches the prediction.
        log_euclid_y = batch.log_euclid.float()
        mask = batch.mask.bool()

        ### Baseline predictions
        y_pred = self.pred_spline(x)
        # The log-euclidean prediction is interpolated directly from the
        # low-resolution log-euclidean input, not mapped from `y_pred`.
        log_euclid_y_pred = self.pred_spline(log_euclid_x)

        # Mask select the prediction and the target into (batch, channel, -1).
        y_pred_select = torch.masked_select(y_pred, mask).view(
            y_pred.shape[0], y_pred.shape[1], -1
        )
        y_select = torch.masked_select(y, mask).view(*y_pred_select.shape)

        ###### Calculate performance metrics.
        # MSE metrics.
        rmse_loss = pitn.nn.loss.dti_root_vec_fro_norm_loss(
            y_pred, y, mask=mask, scale_off_diags=True, reduction="mean"
        )
        nrmse_loss = pitn.metrics.minmax_normalized_dti_root_vec_fro_norm(
            y_pred,
            y,
            mask=mask,
            scale_off_diags=True,
            reduction="mean",
        )

        rmse_log_euclid_loss = pitn.nn.loss.dti_root_vec_fro_norm_loss(
            log_euclid_y_pred,
            log_euclid_y,
            mask=mask,
            scale_off_diags=False,
            reduction="mean",
        )
        nrmse_log_euclid_loss = pitn.metrics.minmax_normalized_dti_root_vec_fro_norm(
            log_euclid_y_pred,
            log_euclid_y,
            mask=mask,
            scale_off_diags=False,
            reduction="mean",
        )

        # Need to reshape the mins and maxes to play nice with the mask-selected tensors.
        data_min = self._psnr_range_params.data_min
        data_max = self._psnr_range_params.data_max
        scaled_psnr_loss = pitn.metrics.psnr_batch_channel_regularized(
            y_pred_select,
            y_select,
            range_min=data_min.view(*tuple(data_min.shape[:2]), -1),
            range_max=data_max.view(*tuple(data_max.shape[:2]), -1),
        )

        #### FA metrics
        y_fa = pitn.metrics.fast_fa(y, foreground_mask=mask)
        y_pred_fa = pitn.metrics.fast_fa(y_pred, foreground_mask=mask)
        y_fa_select = torch.masked_select(y_fa, mask).view(
            y_fa.shape[0], y_fa.shape[1], -1
        )
        y_pred_fa_select = torch.masked_select(y_pred_fa, mask).view(
            y_pred_fa.shape[0], y_pred_fa.shape[1], -1
        )

        # SSIM is computed on the full FA volumes, not the mask-selected voxels.
        ssim_fa_loss = pitn.metrics.ssim_y_range(
            y_pred_fa,
            y_fa,
        )
        rmse_fa_loss = torch.sqrt(
            F.mse_loss(
                y_pred_fa_select,
                y_fa_select,
                reduction="mean",
            )
        )
        nrmse_fa_loss = pitn.metrics.minmax_normalized_rmse(
            y_pred_fa_select,
            y_fa_select,
            reduction="mean",
        )

        metrics = {
            "rmse": rmse_loss,
            "nrmse": nrmse_loss,
            "nrmse_log_euclid": nrmse_log_euclid_loss,
            "rmse_log_euclid": rmse_log_euclid_loss,
            "scaled_psnr": scaled_psnr_loss,
            "ssim_fa": ssim_fa_loss,
            "rmse_fa": rmse_fa_loss,
            "nrmse_fa": nrmse_fa_loss,
        }

        # on_epoch and reduce_fx gather the individual test epoch values and aggregate
        # them at the end of the testing epoch.
        self.log_dict(
            {f"test_loss/{name}": value for name, value in metrics.items()},
            on_epoch=True,
            reduce_fx=torch.mean,
        )
        # Log loss as an hparam metric to be shown alongside the experiment's hparams.
        self.log_dict(
            {f"hp/{name}": value for name, value in metrics.items()},
            on_epoch=True,
            reduce_fx=torch.mean,
        )
        # Keep plain Python floats per subject for the result tables and plots.
        for name, value in metrics.items():
            self.plain_log.test_loss[name][subj_id] = value.detach().cpu().item()

        # Store entire predicted DTI for saving & visualization.
        self.plain_log.viz.test_preds_dti[subj_id] = y_pred[0].detach().cpu()
        self.plain_log.viz.test_preds_le[subj_id] = log_euclid_y_pred[0].detach().cpu()

        return metrics

### Training Loop

In [None]:
# Constructor arguments for the baseline system.
model_kwargs = {
    "channels": params.n_channels,
    "upscale_factor": params.data.downsampled_by_factor,
    "psnr_range_params": psnr_range_params,
}
model = DIQTESPCNSystem(**model_kwargs)

# Trainer used for the test loop: CPU only, no checkpoints written.
trainer_kwargs = {
    # "fast_dev_run": 12,
    "accelerator": "cpu",
    "enable_checkpointing": False,
    "logger": pl_logger,
}
trainer = pl.Trainer(**trainer_kwargs)

# # Many warnings are produced here, so it's better for my sanity (i.e., worse in every other
# # way) to just filter and ignore them...
# # with warnings.catch_warnings():
# #     warnings.simplefilter("ignore")
# try:
#     trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# except RuntimeError as e:
#     with open(log_txt_file, "a+") as f:
#         f.write("\n")
#         f.write("!!!!! Fatal ERROR !!!!!!!\n")
#         f.write("Traceback:\n")
#         f.write(str(e) + "\n")
#         f.write("!!!!!!!!!!!!!!!!!!!!!!!!!\n")
#         raise e

# train_duration = datetime.datetime.now().replace(microsecond=0) - train_start_timestamp
# print(f"Train duration: {train_duration}")

In [None]:
# # Save out trained model
# trainer.save_checkpoint(str(experiment_results_dir / "model.ckpt"))

In [None]:
# torch.save(model.state_dict(), experiment_results_dir / "model_state_dict.pt")

In [None]:
# with open(log_txt_file, "a+") as f:
#     f.write("\n")
#     f.write(f"Training time: {train_duration}\n")
#     f.write(
#         f"\t{train_duration.days} Days, "
#         + f"{train_duration.seconds // 3600} Hours,"
#         + f"{(train_duration.seconds // 60) % 60} Minutes,"
#         + f'{train_duration.seconds % 60} Seconds"\n'
#     )

In [None]:
# # Plot rolling average window of training loss values.
# plt.figure(dpi=110)
# window = 1000
# rolling_mean = (
#     np.convolve(model.plain_log["train_loss"], np.ones(window), "valid") / window
# )
# rolling_start = 100
# plt.plot(
#     np.arange(
#         window + rolling_start,
#         window + rolling_start + len(rolling_mean[rolling_start:]),
#     ),
#     rolling_mean[rolling_start:],
# )
# plt.title("Training Loss " + params.train.loss_name + f"\nRolling Mean {window}")
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# # plt.ylim(0, 1)
# print(np.median(rolling_mean))
# print(
#     np.mean(model.plain_log["train_loss"][: window + rolling_start]),
#     np.var(model.plain_log["train_loss"][: window + rolling_start]),
#     np.max(model.plain_log["train_loss"][: window + rolling_start]),
# )

# plt.savefig(experiment_results_dir / "train_loss.png")

## Model Testing/Evaluation

### Testing Loop

In [None]:
# mod_state = model.state_dict()

# test_mod = DIQTCascadeSystem(
#     channels=params.n_channels,
#     batch_size=params.train.batch_size,
#     in_patch_shape=params.train.in_patch_size,
#     upscale_factor=params.net.upscale_factor,
#     anat_batch_key=params.data.anat_type,
#     train_loss_method=params.train.loss_name,
#     opt_params=params.optim.kwargs,
#     hparams=hyperparams,
#     **params.net.kwargs,
# )
# test_mod.load_state_dict(mod_state)
# trainer.test(test_mod, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Run the Lightning test loop on `model` over `test_loader`; `ckpt_path=None`, so
# no checkpoint path is given.
trainer.test(model, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Whether a lower or a higher value is better for each test metric.
loss_comparison_directions = {
    "rmse": "↓",
    "nrmse": "↓",
    "rmse_log_euclid": "↓",
    "nrmse_log_euclid": "↓",
    "scaled_psnr": "↑",
    "ssim_fa": "↑",
    "rmse_fa": "↓",
    "nrmse_fa": "↓",
}
# Metric names, in the insertion order of the dict above.
test_losses = tuple(loss_comparison_directions)

# Long-format table: one row per (subject, metric) pair.
result_columns = ["subj_id", "model", "metric", "value"]
result_rows = list()
for subj_id in test_subjs:
    for metric in test_losses:
        result_rows.append(
            {
                "subj_id": subj_id,
                "model": "diqt",
                "metric": metric,
                "value": model.plain_log.test_loss[metric][subj_id],
            }
        )
test_results = pd.DataFrame(result_rows, columns=result_columns)

with open(log_txt_file, "a+") as f:
    f.write(f"Test loss functions: {list(test_losses)}\n")

test_loss_log_file = experiment_results_dir / "test_loss.csv"
test_results.to_csv(test_loss_log_file, index=False)

### Evaluation Visualization

#### Comparison within experiment

In [None]:
# with mpl.rc_context(
#     {
#         "font.size": 8.0,
#     }
# ):
#     fig, axs = plt.subplots(
#         ncols=len(test_losses),
#         sharex=True,
#         figsize=(11, 4),
#         dpi=130,
#         gridspec_kw={"wspace": 1.0, "hspace": 0},
#     )
#     sns.despine(fig=fig, top=True, right=True)

#     for i, l in enumerate(test_losses):

#         ax = axs[i]
#         df = test_results.loc[test_results.metric == l]
#         vplot = sns.violinplot(
#             x="model", y="value", data=df, ax=ax, scale="count", inner=None
#         )
#         axs[i].grid(axis="y", alpha=0.5)
#         points_plot = sns.swarmplot(
#             x="model",
#             y="value",
#             hue="subj_id",
#             data=df,
#             ax=ax,
#             # color="white",
#             edgecolor="black",
#             size=4,
#             linewidth=0.8,
#         )
#         points_plot.get_legend().remove()

#         # Calculate mean performance score.
#         means = df.groupby("model").mean()
#         # Make sure the order follows seaborn's x-axis ordering.
#         model_order = list(map(lambda ax: ax.get_text(), axs[i].get_xticklabels()))
#         means = means.reindex(model_order)

#         # Grab colors corresponding to each model.
#         colors = sns.color_palette(None, n_colors=len(df.model.unique()))

#         lines = ax.hlines(
#             y=means.value,
#             xmin=np.arange(0, len(colors)) - 0.5 + 0.05,
#             xmax=np.arange(1, len(colors) + 1) - 0.5 - 0.05,
#             colors=colors,
#             lw=1.5,
#         )

#         outline_path_effects = [
#             mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
#             mpl.patheffects.Normal(),
#         ]
#         lines.set_path_effects(outline_path_effects)

#         ax.set_xticklabels(ax.get_xticklabels(), rotation=25)

#         fig.canvas.draw()
#         ax_format = ax.get_yaxis().get_major_formatter()

#         for m, c in zip(means.value, colors):

#             ax.annotate(
#                 f"{m:.4g}",
#                 xy=(ax.get_xlim()[0] + (ax.get_xlim()[0] * 0.4), m),
#                 xycoords="data",
#                 color=c,
#                 ha="right",
#                 va="center",
#                 annotation_clip=False,
#                 fontweight="bold",
#                 snap=True,
#                 bbox=dict(
#                     boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
#                 ),
#             )
#         ax.set_title(f"{l.replace('_', ' ')} {loss_comparison_directions[l]}")
#         ax.set_ylabel("")
#         ax.set_xlabel("")
# plt.savefig(experiment_results_dir / "test_result_within_experiment.png")

#### Comparison with other works

In [None]:
# Histogram of the per-subject test RMSE, with reference values drawn as vertical lines.
fig, ax_prob = plt.subplots(figsize=(8, 4), dpi=120)
log_scale = False

subj_rmse_values = list(model.plain_log["test_loss"].rmse.values())
hist = sns.histplot(
    subj_rmse_values,
    alpha=0.5,
    stat="count",
    log_scale=log_scale,
    ax=ax_prob,
    legend=False,
    hatch="\\\\",
    ec="blue",
)
# Counts are whole subjects, so keep the y ticks on integers.
hist.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
plt.xlabel("Loss in $mm^2/second$")

# (x position, legend label, color) of each vertical line: the mean of the current
# results first, then the comparison values.
comparison_kwargs = {"ls": "-", "lw": 2.5}
reference_lines = (
    (np.asarray(subj_rmse_values).mean(), "Current Model Mean", "blue"),
    (
        9.72e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
        "green",
    ),
    (
        9.76e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nBest ESPCN\n[but not really]",
        "purple",
    ),
    # Best performing Blumberg, et. al., 2018 paper model.
    (8.46e-4, "(Blumberg etal, 2018)\nBest Overall", "pink"),
)
for line_x, line_label, line_color in reference_lines:
    plt.axvline(line_x, label=line_label, color=line_color, **comparison_kwargs)

plt.legend(fontsize="small")
plt.title("Test Loss Histogram Over All Subjects with Test Metric RMSE")
plt.savefig(experiment_results_dir / "test_rmse_hist.png")

In [None]:
# Bar plot of the mean test RMSE over subjects next to values from other works.
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)

# Row `i` of `rmse_bounds` and entry `i` of `rmse_scores` both describe `models[i]`.
models = (
    "(Ours)\nCurrent Model",
    "(Tanno etal, 2021)\nC-Spline\nMean Approx.",
    "(Tanno etal, 2021)\nRand. Forest\nMean Approx.",
    "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
    '(Tanno etal, 2017 &\n Blumberg etal, 2018)\n"Best" ESPCN',
    "(Blumberg etal, 2018)\nBest Overall",
)

# One pair of RMSE bounds per model; the first row (ours) is a placeholder.
rmse_bounds = np.asarray(
    [
        [0, 0],
        [31.738e-4, 10.069e-4],
        [23.139e-4, 6.974e-4],
        [13.609e-4, 6.212e-4],
        [13.82e-4, 6.29e-4],
        [12.13e-4, 5.58e-4],
    ]
)

# Layout of `models`: our own result first, then the "Mean Approx." entries whose
# score is taken as the mean of their bounds, then the entries with a fixed score.
n_own_models = 1
n_approx_models = 2

# The "Mean Approx." scores use the bounds row of the same model (rows 1 and 2);
# rows 2 and 3 were used before, which belong to the next model along.
rmse_scores = (
    np.asarray(list(model.plain_log["test_loss"].rmse.values())).mean(),
    rmse_bounds[1].mean(),
    rmse_bounds[2].mean(),
    9.72e-4,
    9.76e-4,
    8.46e-4,
)

# Distance of each score from its bounds; not used by the plotting code in this cell.
rmse_score_ranges = np.asarray(rmse_scores)[:, None] - rmse_bounds
rmse_score_ranges[:1] = rmse_score_ranges[:1] * 0
rmse_score_ranges = rmse_score_ranges.T

colors = sns.color_palette("deep", n_colors=len(rmse_scores))

ax.grid(True, axis="y", zorder=1000)
ax.set_axisbelow(True)
# Plot our evaluation scores as plain, labelled bars.
end_idx = n_own_models
ax.bar(
    models[:end_idx],
    rmse_scores[:end_idx],
    color=colors[:end_idx],
    edgecolor="black",
    lw=0.75,
)
for container in ax.containers:
    if isinstance(container, mpl.container.BarContainer):
        ax.bar_label(container, fmt="%.3e")


start_idx = end_idx
# Plot the other works as floating bars spanning their RMSE bounds.
ax.bar(
    models[start_idx:],
    height=rmse_bounds[start_idx:, 0] - rmse_bounds[start_idx:, 1],
    bottom=rmse_bounds[start_idx:, 1],
    color=colors[start_idx:],
    edgecolor="black",
    lw=0.75,
    alpha=0.8,
)
bar_width = ax.patches[0].get_width()
label_offset = 0.03 * np.asarray(rmse_scores).max()

# Dashed lines mark the approximated scores, solid lines the fixed scores.
approx_end_idx = start_idx + n_approx_models
score_line_groups = (
    (start_idx, approx_end_idx, "--"),
    (approx_end_idx, len(models), "-"),
)
for group_start, group_end, line_style in score_line_groups:
    bar_positions = np.arange(group_start, group_end)
    group_scores = rmse_scores[group_start:group_end]
    ax.hlines(
        group_scores,
        xmin=bar_positions - bar_width / 2,
        xmax=bar_positions + bar_width / 2,
        color="black",
        ls=line_style,
    )
    for score, x in zip(group_scores, bar_positions - bar_width / 2):
        ax.annotate(f"{score:.3e}", (x, score + label_offset))

ax.set_ylim(bottom=0, top=ax.get_ylim()[1] * 1.1)
ax.set_xlabel("Model")
ax.set_ylabel("Loss in $mm^2/second$")

ax.set_title("Mean Over Subjects Test Loss RMSE")
ax.set_xticks(models)
ax.set_xticklabels(
    models, fontsize="x-small", rotation=25, ha="right", rotation_mode="anchor"
)
plt.savefig(experiment_results_dir / "test_rmse_bar.png")

In [None]:
# Per-subject RMSE in ascending order, followed by the mean.
subj_rmse = model.plain_log.test_loss.rmse
subj_rmse_items = list(subj_rmse.items())
sorted_test_idx = np.argsort(np.asarray(list(subj_rmse.values())))
sorted_test_results = dict(subj_rmse_items[i] for i in sorted_test_idx)
ppr(sorted_test_results, sort_dicts=False)
print(np.mean(list(subj_rmse.values())))

In [None]:
diqt_results = test_results.loc[test_results.model == "diqt"]

# (histogram tag, metric name) pairs; one histogram of per-subject values each.
histogram_tags = (
    ("test/rmse_dist", "rmse"),
    ("test/nrmse_dist", "nrmse"),
    ("test/rmse_log_euclid_dist", "rmse_log_euclid"),
    ("test/nrmse_log_euclid_dist", "nrmse_log_euclid"),
    ("test/scaled_psnr_dist", "scaled_psnr"),
    ("test/ssim_fa_dist", "ssim_fa"),
    ("test/rmse_fa_dist", "rmse_fa"),
    ("test/nrmse_fa_dist", "nrmse_fa"),
)
for hist_tag, metric_name in histogram_tags:
    metric_values = diqt_results.loc[diqt_results.metric == metric_name].value
    logger.add_histogram(hist_tag, np.asarray(metric_values))

## Whole-Volume Visualization

### Setup

In [None]:
# Debug flag(s)
# When True, the cells guarded by `if not disable_fig_save` skip their file writes.
disable_fig_save = False

In [None]:
# Collect, per test subject, the ground truth, low-res input, mask, affine and the
# masked predictions as NumPy arrays for saving and visualization.
results_viz = Box(default_box=True)
# Entries converted to NumPy below; only the mask is stored as boolean.
volume_keys = ("mask", "dti", "lr_dti", "pred", "le_pred", "abs_error")
with torch.no_grad():

    for subj in test_dataset:
        subj_id = subj["subj_id"]
        print(f"Starting subject {subj_id}")

        s = Box(default_box=True)
        s.mask = subj["mask"]
        s.affine = subj["dti_meta_dict"]["affine"]
        s.lr_dti = subj["lr_dti"]

        # Zero everything outside the mask in the target and both predictions.
        s.dti = subj["dti"] * s.mask
        s.pred = model.plain_log.viz.test_preds_dti[subj_id] * s.mask
        s.le_pred = model.plain_log.viz.test_preds_le[subj_id] * s.mask
        s.metrics.rmse = model.plain_log.test_loss.rmse[subj_id]
        s.abs_error = torch.abs(s.pred - s.dti)

        for k in volume_keys:
            v = s[k]
            if torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            s[k] = v.astype(bool if k == "mask" else float)

        results_viz[subj_id] = s
        print(f"Finished subject {subj_id}")

In [None]:
# Save out all predictions to Nifti2 files and compress them into zip archives:
# one archive for the DTI predictions and one for the log-euclidean predictions.
# Each temporary .nii.gz file is deleted once it has been added to its archive.
if not disable_fig_save:
    # (key of the volume in each `results_viz` entry, file/archive stem, deflate level)
    export_specs = (
        ("pred", "predicted_dti", 6),
        ("le_pred", "predicted_le", 7),
    )
    for vol_key, stem, deflate_level in export_specs:
        img_names = list()
        for subj_id, viz in results_viz.items():
            nib_img = nib.Nifti2Image(viz[vol_key], viz.affine)

            filename = experiment_results_dir / f"{subj_id}_{stem}.nii.gz"
            nib.save(nib_img, str(filename))
            img_names.append(filename)

        with zipfile.ZipFile(experiment_results_dir / f"{stem}.zip", "w") as fzip:
            for filename in img_names:
                fzip.write(
                    filename,
                    arcname=filename.name,
                    compress_type=zipfile.ZIP_DEFLATED,
                    compresslevel=deflate_level,
                )
                os.remove(filename)
    # Make sure we exit the 'with' statement above.
    print("Done with files")

In [None]:
# Rank the test subjects by RMSE in ascending order, so the worst performers come last.
sel_rmse = test_results.loc[
    (test_results.model == "diqt") & (test_results.metric == "rmse")
][["subj_id", "value"]]
sel_rmse = sel_rmse.sort_values("value")
# Visualize the 2nd worst performing subject. With a single test subject there is
# no 2nd worst, so fall back to that subject instead of raising an IndexError.
bad_rmse = sel_rmse.iloc[-min(2, len(sel_rmse))]
viz_subj_id = bad_rmse.subj_id
print(viz_subj_id, bad_rmse.value)
viz_subj = results_viz[viz_subj_id]

In [None]:
# Choose one slice index per spatial axis, for the full-res and low-res volumes.
dti_shape = np.asarray(viz_subj.dti.shape[1:])
lr_dti_shape = np.asarray(viz_subj.lr_dti.shape[1:])

# Last dimension (sagittal) shouldn't be exactly centered, as the longitudinal fissure
# doesn't have many fibers outside the corpus callosum.
sagittal_offset = 6
viz_idx = dti_shape // 2
viz_idx[2] += sagittal_offset
viz_lr_idx = lr_dti_shape // 2
viz_lr_idx[2] += sagittal_offset // params.data.downsampled_by_factor

# Each index tuple keeps the first axis whole and fixes one of the other three.
whole_axis = slice(None)
viz_slice_idx = [
    (whole_axis, viz_idx[0], whole_axis, whole_axis),
    (whole_axis, whole_axis, viz_idx[1], whole_axis),
    (whole_axis, whole_axis, whole_axis, viz_idx[2]),
]

viz_lr_slice_idx = [
    (whole_axis, viz_lr_idx[0], whole_axis, whole_axis),
    (whole_axis, whole_axis, viz_lr_idx[1], whole_axis),
    (whole_axis, whole_axis, whole_axis, viz_lr_idx[2]),
]

### FA-Weighted Direction Maps

In [None]:
# One direction map per (volume, slice) pair, ordered volume by volume, with the
# channel axis moved last.
volumes_and_slices = (
    (viz_subj.dti, viz_slice_idx),
    (viz_subj.pred, viz_slice_idx),
    (viz_subj.lr_dti, viz_lr_slice_idx),
)
ims = [
    pitn.viz.direction_map(vol[sl], channels_first=True).transpose(1, 2, 0)
    for vol, vol_slices in volumes_and_slices
    for sl in vol_slices
]

# Column headers (one per slice) and row headers (one per volume).
dim_labels = ["Axial", "Coronal", "Saggital"]
vol_labels = ["Ground Truth", "Pred", "LR Input"]

grid_rc_params = {
    "font.size": 12.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
with mpl.rc_context(grid_rc_params):
    fig = plt.figure(dpi=130, figsize=(7, 7))
    fig = pitn.viz.plot_im_grid(
        ims,
        nrows=len(vol_labels),
        title=f"Subj {viz_subj_id} Prediction Results",
        row_headers=vol_labels,
        col_headers=dim_labels,
        colorbars=None,
        fig=fig,
        interpolation="antialiased",
    )
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "dir_map_pred.png");

### DTI Channel-Wise Visualization

In [None]:
# LaTeX labels for the 6 DTI channels, built from each channel's pair of axes.
channel_axis_pairs = ("xx", "xy", "yy", "xz", "yz", "zz")
channel_names = [f"$D_{{{row},{col}}}$" for row, col in channel_axis_pairs]

# Slice labels; the leading newline is part of each label.
dim_labels = ["\nAxial", "\nCor.", "\nSagg."]

# Labels for the volumes shown in the channel-wise figure.
dti_names = ["FR", "Pred", "LR", "Abs Err"]

#### Global Normalization

In [None]:
# Show every DTI channel of the ground truth, prediction, low-res input and
# absolute error volumes.
cmap = "gray"
title = f"DTI Subj {viz_subj_id} Summary"

summary_rc_params = {
    "font.size": 6.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
# The third slice position is the chosen index as a fraction of that axis' length.
summary_slice_idx = (None, None, viz_idx[2] / dti_shape[2])

with mpl.rc_context(summary_rc_params):
    fig = plt.figure(dpi=150, figsize=(8, 12))
    fig = pitn.viz.plot_vol_slices(
        viz_subj.dti,
        viz_subj.pred,
        viz_subj.lr_dti,
        viz_subj.abs_error,
        slice_idx=summary_slice_idx,
        title=title,
        vol_labels=dti_names,
        slice_labels=dim_labels,
        channel_labels=channel_names,
        colorbars="col",
        fig=fig,
        cmap=cmap,
        interpolation="antialiased",
    )


if not disable_fig_save:
    summary_fig_path = (
        experiment_results_dir / f"Spline_DTI_sub-{viz_subj_id}_pred_result.png"
    )
    plt.savefig(summary_fig_path);

---

## End Experiment

In [None]:
pl_logger.experiment.flush()
# Debug experiments are neither finalized nor moved out of the temporary location.
is_debug_run = "debug" in EXPERIMENT_NAME.casefold()
if not is_debug_run:
    # Close tensorboard logger.
    pl_logger.finalize("success")
    # Experiment is complete, move the results directory to its final location.
    if experiment_results_dir != final_experiment_results_dir:
        print("Moving out of tmp location")
        moved_results_dir = experiment_results_dir.rename(final_experiment_results_dir)
        experiment_results_dir = moved_results_dir
        log_txt_file = experiment_results_dir / log_txt_file.name