# Pain in the Net - DIQT with Anatomical Refinement
Application of Deep Image Quality Transfer (DIQT) with domain adaptation.


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

References:

* `R. Tanno et al., “Uncertainty modelling in deep learning for safer neuroimage enhancement: Demonstration in diffusion MRI,” NeuroImage, vol. 225, p. 117366, Jan. 2021, doi: 10.1016/j.neuroimage.2020.117366.`
* `D. C. Alexander et al., “Image quality transfer and applications in diffusion MRI,” NeuroImage, vol. 152, pp. 283–298, May 2017, doi: 10.1016/j.neuroimage.2017.02.089.`
* `S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import matplotlib.pyplot as plt
import seaborn as sns

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
import pprint
from pprint import pprint as ppr
import box
from box import Box

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai
import einops
import torchinfo

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Global plotting defaults: automatic (tight) figure layout and an opaque white figure
# background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Print options shared by numpy ndarrays and torch tensors: summarize arrays with more
# than 100 elements and wrap lines at 88 characters. Scientific notation is disabled
# for both libraries.
_shared_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_shared_print_opts)
torch.set_printoptions(sci_mode=False, **_shared_print_opts)

In [None]:
# Update the notebook's environment variables with direnv.
# Requires the python-dotenv package and the `direnv` executable on PATH; this will not
# work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side: os, subprocess, io, warnings, dotenv.

# Run `/usr/bin/env` inside direnv's context; it prints all environment variables
# defined in that sub-process. The command is given as an argument list (no shell), so
# a working directory containing spaces or shell metacharacters is passed through
# intact instead of being split/interpreted by the shell.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    proc = subprocess.run(
        command, stdout=subprocess.PIPE, cwd=os.getcwd(), check=False
    )
except FileNotFoundError:
    # `direnv` is not installed / not on PATH: load nothing, but say so.
    warnings.warn("direnv executable not found; environment variables not updated")
    proc = None
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with code {proc.returncode}; "
            + "environment variables may be incomplete"
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# torch setup: use a CUDA device when one is available, otherwise the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    # The GPU index can be chosen with the CUDA_PYTORCH_DEVICE_IDX environment
    # variable; it defaults to GPU 0.
    dev_idx = int(os.environ.get("CUDA_PYTORCH_DEVICE_IDX", 0))
    device = torch.device(f"cuda:{dev_idx}")
    print("CUDA Device IDX ", dev_idx)
    torch.cuda.set_device(device)
    print("CUDA Current Device ", torch.cuda.current_device())
    print("CUDA Device properties: ", torch.cuda.get_device_properties(device))
    # cuDNN benchmark mode lets cuDNN pick the fastest convolution algorithms.
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True
        print("CuDNN convolution optimization enabled.")
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture this cell's stdout into `cap` (stderr is not captured, per `--no-stderr`) so
# it can be printed and written to the experiment's log file. The `%%capture` magic
# needs to be at the *very first* line of the cell.
# Watermark: record the author, last-updated timestamp (ISO 8601), Python and machine
# info, versions of the imported packages, and the current git hash.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():
    # GPU information. This is best-effort logging only: if the detailed specs cannot
    # be queried (missing helper -> AttributeError, failing system query, ...), fall
    # back to reporting the CUDA version instead of aborting the notebook.
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
    except Exception as e:
        print(f"Could not query GPU specs ({type(e).__name__}: {e})")
        print("CUDA Version: ", torch.version.cuda)
    else:
        print(gpu_info)
else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` is defined by the `%%capture` magic in the specs-recording cell; printing it
# shows the captured environment and hardware info.
print(cap)

### Data Variables & Definitions Setup

In [None]:
# Set up directories. Each one is named by an environment variable (e.g., set through
# direnv above) and must already exist. Explicit exceptions are raised instead of using
# `assert`, which is skipped entirely under `python -O`.
def _require_env_dir(env_var: str) -> pathlib.Path:
    """Return the existing directory named by the environment variable `env_var`.

    Raises
    ------
    KeyError
        If `env_var` is not set.
    FileNotFoundError
        If the path does not exist or is not a directory.
    """
    try:
        dir_path = pathlib.Path(os.environ[env_var])
    except KeyError:
        raise KeyError(f"Environment variable '{env_var}' is not set") from None
    if not dir_path.is_dir():
        raise FileNotFoundError(f"{env_var}='{dir_path}' is not an existing directory")
    return dir_path


data_dir = _require_env_dir("DATA_DIR")
write_data_dir = _require_env_dir("WRITE_DATA_DIR")
results_dir = _require_env_dir("RESULTS_DIR")
tmp_results_dir = _require_env_dir("TMP_RESULTS_DIR")

## Parameter Reading & Experiment Setup

### Parameters

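<div class="alert alert-block alert-info"> <b>NOTE</b> All experiment parameters are defined in the cell below, so they are easy to find and change. Values from a config file (if one is found) override these defaults. </div>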

In [None]:
# `default_box=True` lets nested sections (e.g. `params.data`, `params.net.kwargs`) be
# created implicitly on first assignment below. This is switched off after the config
# file (if any) is merged in.
params = Box(default_box=True)

# General experiment-wide params
###############################################
params.experiment_name = "uvers_pitn_anat_stream_dti_split_1"
params.override_experiment_name = False
###############################################
# 6 channels for the 6 DTI components
params.n_channels = 6
# Number of subjects randomly sampled from the candidate subject ID list.
params.n_subjs = 48
# Voxel sizes in mm (they name the `scale-<size>mm` data directories below).
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25
params.use_anat = True
# If True, train on the log-Euclidean-mapped DTIs instead of the raw DTI components.
params.use_log_euclid = False
params.use_half_precision_float = True
params.progress_bar = True
# Number of DataLoader worker processes for the training loader.
params.num_workers = 8

# Data params
# NOTE: these derived values (and `downsampled_by_factor` / `net.upscale_factor`) are
# computed from the defaults above, *before* any config file is merged in. A config
# that changes the voxel sizes must also set the derived values itself.
params.data.fr_dir = data_dir / f"scale-{params.fr_vox_size:.2f}mm"
params.data.lr_dir = data_dir / f"scale-{params.lr_vox_size:.2f}mm"
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"
params.data.anat_descr = "t2w"
params.data.anat_fname_patterns = [
    f"sub-*t2w.nii.gz",
]
# The data were downsampled artificially by this factor.
# Stored as an int when the ratio is integral (e.g. 2.5 / 1.25 -> 2), otherwise as a
# float.
params.data.downsampled_by_factor = params.lr_vox_size / params.fr_vox_size
params.data.downsampled_by_factor = (
    int(params.data.downsampled_by_factor)
    if int(params.data.downsampled_by_factor) == params.data.downsampled_by_factor
    else params.data.downsampled_by_factor
)

# This is the number of voxels to remove (read: center crop out) from the network's
# prediction. This allows for an "oversampling" of the low-res voxels to help inform a
# more constrained HR prediction. This value of voxels will be removed from each spatial
# dimension (D, H, and W) starting at the center of the output patches.
# Ex. A size of 1 will remove the 2 outer-most voxels from each dimension in the output,
# while still keeping the corresponding voxels in the LR input.
params.hr_center_crop_per_side = 0

# Maximum allowed eigenvalue for *all* DTIs. This was calculated as the median of the
# eigenvalue thresholds found in the "notebooks/data/dti_thresholding.ipynb" notebook.
# Actual computed value is 0.0033200803422369068, rounded here
# **This counts as outlier removal and will change both the training and test data**
params.data.eigval_clip_cutoff = 0.00332008

# Second data scaling method, where the training data will be scaled and possibly clipped,
# but the testing data will be compared on the originals.
# Scale input data by the valid values of each channel of the DTI.
# I.e., Dx,x in [0, 1], Dx,y in [-1, 1], Dy,y in [0, 1], Dy,z in [-1, 1], etc.
# NOTE(review): the description above reads like range-based scaling, but the value is
# "standard"; the actual method is whatever `pitn.nn.norm.norm_method_lookup["standard"]`
# implements -- confirm the comment still matches.
params.data.scale_method = "standard"

# Network params.
# The network's goal is to upsample the input by this factor.
params.net.upscale_factor = params.data.downsampled_by_factor
params.net.kwargs.n_res_units = 3
params.net.kwargs.n_dense_units = 3
params.net.kwargs.interior_channels = 24
params.net.kwargs.anat_in_channels = 1
params.net.kwargs.anat_interior_channels = 14
params.net.kwargs.anat_n_res_units = 2
params.net.kwargs.anat_n_dense_units = 2

params.net.kwargs.activate_fn = "elu"
params.net.kwargs.upsample_activate_fn = "elu"
params.net.kwargs.center_crop_output_side_amt = params.hr_center_crop_per_side

# Optimizer (AdamW, per `params.optim.name`) kwargs
params.optim.name = "AdamW"
params.optim.kwargs.lr = 2.5e-4
# params.optim.kwargs.lr = 1e-3
params.optim.kwargs.betas = (0.9, 0.999)
# With half precision, eps is the smallest positive normal float16 value
# (`torch.finfo(torch.float16).tiny`), presumably because 1e-8 rounds to 0 in float16.
params.optim.kwargs.eps = (
    1e-8 if not params.use_half_precision_float else torch.finfo(torch.float16).tiny
)

# Force subject split to be read in from config file.
# The split definitions below are kept (commented out) for reference only.
# # Testing params
# params.test.dataset_subj_percent = 0.4

# # Validation params
# params.val.dataset_subj_percent = 0.2

# # Testing params
# params.test.subjs = [
#     "701535",
#     "978578",
#     "118124",
#     "894774",
#     "185947",
#     "297655",
#     "135528",
#     "679770",
#     "792867",
#     "567961",
#     "189450",
#     "227432",
#     "108828",
#     "307127",
#     "156637",
#     "803240",
#     "164030",
#     "196952",
#     "753251",
#     "140117",
#     "103515",
#     "198047",
#     "124220",
#     "118730",
#     "303624",
#     "103010",
#     "397154",
#     "700634",
#     "810439",
#     "382242",
#     "203923",
#     "224022",
#     "175035",
#     "167238",
# ]
# params.test.dataset_n_subjs = len(params.test.subjs)

# # Validation params
# params.val.subjs = ["644246", "567759", "231928", "157437"]
# params.val.dataset_n_subjs = len(params.val.subjs)

# # Training params
# params.train.subjs = [
#     "634748",
#     "386250",
#     "751348",
#     "150019",
#     "910241",
#     "406432",
#     "815247",
#     "690152",
#     "141422",
#     "100408",
# ]
# params.train.dataset_n_subjs = len(params.train.subjs)

# Spatial size of the LR input patches; the FR patch shape is this times the
# downsampling factor.
params.train.in_patch_size = (24, 24, 24)
params.train.batch_size = 32
params.train.samples_per_subj_per_epoch = 4000
params.train.max_epochs = 50
params.train.loss_name = "vfro"
params.train.lambda_dti_stream_loss = 0.35
# Percentage of subjs in dataset that go into the training set.
# params.train.dataset_subj_percent = 1 - (
#     params.test.dataset_subj_percent + params.val.dataset_subj_percent
# )
params.train.grad_2norm_clip_val = 0.25
params.train.accumulate_grad_batches = None
# Learning rate scheduler config. A falsy value (None) means no scheduler is recorded
# in the hyperparams.
params.train.lr_scheduler = None
# num_test_subjs = int(np.ceil(params.n_subjs * params.test.dataset_subj_percent))
# num_val_subjs = int(np.ceil(params.n_subjs * params.val.dataset_subj_percent))
# num_train_subjs = params.n_subjs - (num_test_subjs + num_val_subjs)
# params.train.lr_scheduler.kwargs.steps_per_epoch = (
#     params.train.samples_per_subj_per_epoch * num_train_subjs // params.train.batch_size
# )

# If a config file exists, override the defaults with those values.
# Loading is best-effort: when no config file is found (or it cannot be read/parsed),
# the defaults above are kept and a warning with the reason is printed. Only
# `Exception`s are caught, so KeyboardInterrupt/SystemExit still stop the notebook.
_config_loaders = {
    ".yaml": Box.from_yaml,
    ".yml": Box.from_yaml,
    ".json": Box.from_json,
    ".toml": Box.from_toml,
}
try:
    if "PITN_CONFIG" in os.environ:
        config_fname = Path(os.environ["PITN_CONFIG"])
    else:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
    f_type = config_fname.suffix.casefold()
    if f_type not in _config_loaders:
        raise RuntimeError(
            f"Unsupported config file type '{f_type}' for config file {config_fname}"
        )
    f_params = _config_loaders[f_type](filename=config_fname)
except Exception as e:
    print("WARNING: Config file not loaded")
    print(f"    Reason: {type(e).__name__}: {e}")
else:
    # Merged outside the `try`, so an error while merging a config that *was* loaded
    # is not mistaken for a missing config file (and params are never left
    # silently half-merged).
    params.merge_update(f_params)

# Remove the default_box behavior now that params have been fully read in.
p = Box(default_box=False)
p.merge_update(params)
params = p

In [None]:
# Record a subset of all params as the model's hyperparameters.
_train_params = params.train
hyperparams = Box(
    batch_size=_train_params.batch_size,
    samples_per_subj_epoch=_train_params.samples_per_subj_per_epoch,
    epochs=_train_params.max_epochs,
    loss_fn=_train_params.loss_name,
    optim=params.optim.name,
    # The anatomical image description when anatomical refinement is used, else False.
    anat=(params.data.anat_descr if params.use_anat else False),
    lambda_dti_loss=_train_params.lambda_dti_stream_loss,
    n_subjs=params.n_subjs,
).to_dict()

# Only record the LR scheduler's name when a scheduler is actually configured.
_lr_sched_cfg = _train_params.lr_scheduler if "lr_scheduler" in _train_params else None
if _lr_sched_cfg:
    hyperparams["lr_scheduler"] = _lr_sched_cfg.name

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = params.experiment_name

# Local-time ISO 8601 timestamp with second precision.
ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
if not params.override_experiment_name:
    experiment_name = ts + "__" + EXPERIMENT_NAME
else:
    experiment_name = EXPERIMENT_NAME
run_name = experiment_name

# Results are first written to a temporary directory, in case the experiment does not
# finish. Only consider *directories* whose names start with a 4-digit year (i.e.
# timestamped runs); plain files matching the pattern are ignored.
tmp_dirs = [
    d for d in tmp_results_dir.glob("[0-9][0-9][0-9][0-9]*") if d.is_dir()
]

# Keep at most `n_tmp_to_keep` temporary results *including* the one this run is about
# to create, i.e. keep the `n_tmp_to_keep - 1` most recent existing ones.
n_tmp_to_keep = 5
n_existing_to_keep = max(n_tmp_to_keep - 1, 0)
if len(tmp_dirs) > n_existing_to_keep:
    print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
    # Timestamp-prefixed names natural-sort chronologically (oldest first). An explicit
    # end index is used because a `[:-0]` slice would select nothing when
    # `n_existing_to_keep == 0`.
    sorted_tmps = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])
    tmps_to_delete = sorted_tmps[: len(sorted_tmps) - n_existing_to_keep]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# Show this run's name (timestamp-prefixed unless `params.override_experiment_name`).
print(experiment_name)

In [None]:
# Logger handed to the pytorch-lightning Trainer, so the training/testing loops can log
# without extra plumbing.
pl_logger = pl.loggers.TensorBoardLogger(
    tmp_results_dir,
    name=experiment_name,
    version="",
    log_graph=False,
    default_hp_metric=False,
)
# The underlying experiment object, for directly logging histograms, images, text, etc.
logger = pl_logger.experiment
_run_log_dir = Path(logger.log_dir)

# Plain-text log for streams of events & info besides parameters & results. The
# `fr_dir`/`lr_dir` path values are recast to `str` so the params can be pretty-printed
# and dumped to YAML.
log_txt_file = _run_log_dir / "log.txt"
print_safe_params = Box(params.copy(), box_recast={"fr_dir": str, "lr_dir": str})
with open(log_txt_file, "a+") as f:
    print(f"Experiment Name: {experiment_name}", file=f)
    print(f"Timestamp: {ts}", file=f)
    # Parameters.
    print(pprint.pformat(print_safe_params.to_dict()), file=f)
    # `cap` holds the output captured by the `%%capture` magic in the specs cell.
    print(f"Environment and Hardware Info:\n {cap}\n", file=f)
params_file = _run_log_dir / "run_params.yaml"
print_safe_params.to_yaml(filename=params_file)

## Data Loading

### Subject ID Selection

In [None]:
# Find data directories for each subject.
# Maps subject ID -> Box with the subject's `fr` and `lr` data directories.
subj_dirs: Box = Box()

# Candidate pool of subject IDs. The order matters for reproducibility, because
# `random.sample` in the next step draws from this list.
selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
    "406432",
    "803240",
    "815247",
    "167238",
    "100408",
    "792867",
    "157437",
    "164030",
    "103515",
    "118730",
    "198047",
    "189450",
    "203923",
    "108828",
    "124220",
    "386250",
    "118124",
    "701535",
    "679770",
    "382242",
    "231928",
    "196952",  # Hold-out subject; for visualization, ensure never in the train or val sets
    "567961",
    "910241",
    "175035",
    "567759",
    "978578",
    "150019",
    "690152",
    "297655",
    "307127",
    "634748",
]
# Must match the hold-out entry marked in `selected_ids` above; under a random split the
# train/val/test split cell moves it to the front of the subject list so it lands in
# the test set.
HOLDOUT_SUBJ_ID = "196952"
## Sub-set the chosen participants for dev and debugging!
# `random.sample` draws from Python's global `random` state. When `params.n_subjs`
# equals the pool size this only reorders the IDs, which are natural-sorted below anyway.
sampled_subjs = random.sample(selected_ids, params.n_subjs)
if len(selected_ids) > len(sampled_subjs):
    warnings.warn(
        f"WARNING: Sub-selecting {len(sampled_subjs)}/{len(selected_ids)} "
        "participants for dev and debugging. "
        f"Subj IDs selected: {sampled_subjs}"
    )
# Natural-sort for a stable, human-readable processing order.
selected_subjs = natsorted(sampled_subjs)

# Locate each subject's full-res and low-res data directories by globbing for the
# subject ID, then check that both exist.
for subj_id in selected_subjs:
    _fr_path = pitn.utils.system.get_file_glob_unique(
        params.data.fr_dir, f"*{subj_id}*"
    )
    _lr_path = pitn.utils.system.get_file_glob_unique(
        params.data.lr_dir, f"*{subj_id}*"
    )
    subj_dirs[subj_id] = Box(fr=_fr_path, lr=_lr_path)
    assert subj_dirs[subj_id].fr.exists()
    assert subj_dirs[subj_id].lr.exists()
ppr(subj_dirs)

In [None]:
# Record the subjects that were actually selected and loaded (after any sub-selection
# for dev/debugging), not the full candidate ID pool.
with open(log_txt_file, "a+") as f:
    f.write(f"Selected Subjects: {selected_subjs}\n")

logger.add_text("subjs", pprint.pformat(selected_subjs))

### Loading and Preprocessing

In [None]:
# Prep for Dataset loading.

# Reader for NIfTI files, configured with `as_closest_canonical=True`.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# HR -> LR patch coordinate conversion. All LR volumes (DTI, log-Euclidean, and mask)
# reuse the same conversion function.
_fr2lr_fn = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=params.train.in_patch_size,
)
fr2lr_patch_coords_fn = {
    lr_key: _fr2lr_fn for lr_key in ("lr_dti", "lr_log_euclid", "lr_mask")
}

# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes. The HR
# patch shape is the LR input patch shape scaled by the downsampling factor.
_hr_patch_shape = np.floor(
    np.asarray(params.train.in_patch_size) * params.data.downsampled_by_factor
).astype(int)
patch_kwargs = {
    "patch_shape": tuple(_hr_patch_shape),
    "stride": 1,
    "meta_keys_to_patch_index": {"dti", "log_euclid", "anat", "mask"},
    "mask_name": "mask",
}

data_scaler_cls = pitn.nn.norm.norm_method_lookup[params.data.scale_method]
# Names of the per-volume scaling-kwargs entries stored with each subject.
data_scaler_kws = [
    f"{vol_name}_scale_kwargs"
    for vol_name in ("dti", "log_euclid", "anat", "lr_dti", "lr_log_euclid")
]

# Coefficients to the log-euclidean lower triangle/6D vector that properly scales
# the Euclidean distance under the log-euclidean metrics: entries 1, 3, and 4 are
# weighted by sqrt(2), the rest by 1. Shaped (6, 1, 1, 1) to broadcast over volumes.
mat_norm_coeffs = torch.ones(6)
mat_norm_coeffs[[1, 3, 4]] = np.sqrt(2)
mat_norm_coeffs = mat_norm_coeffs[:, None, None, None]


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Small utility to fix shape differences between LR and FR data.

    Crops or zero-pads the spatial dims of a channel-first FR volume so they match
    `target_spatial_shape`, and updates the affine so that every kept voxel still maps
    to the same world coordinate.

    Each spatial axis is either cropped or padded, split as evenly as possible between
    its two ends. Any odd remainder is removed from / added to the *far* end of the
    axis (the highest indices), so the near end is affected by `floor(diff / 2)`
    voxels.

    Parameters
    ----------
    fr_vol : torch.Tensor
        Channel-first volume of shape (C, *spatial).
    fr_affine : array-like
        Voxel-to-world affine of `fr_vol`, of shape (S + 1, S + 1) for S spatial dims.
    target_spatial_shape : sequence of int
        Desired spatial shape (length S).

    Returns
    -------
    tuple
        `(result_vol, result_aff)`. The volume keeps `fr_vol`'s dtype and device, and
        the affine is a float64 numpy array. When no change is needed, `fr_vol` and
        `fr_affine` are returned unchanged.

    Raises
    ------
    ValueError
        If the number of spatial dims differs from `len(target_spatial_shape)`.
    """
    target_shape = tuple(int(s) for s in np.asarray(target_spatial_shape).reshape(-1))
    src_shape = tuple(int(s) for s in fr_vol.shape[1:])
    if src_shape == target_shape:
        return fr_vol, fr_affine
    if len(src_shape) != len(target_shape):
        raise ValueError(
            f"Spatial rank mismatch: volume spatial shape {src_shape}, "
            + f"target shape {target_shape}"
        )

    diff = np.asarray(target_shape) - np.asarray(src_shape)
    pad = np.maximum(diff, 0)
    crop = np.maximum(-diff, 0)
    # Near-end amounts take the floor; the (possibly larger) remainder goes far-end.
    pad_lo = pad // 2
    crop_lo = crop // 2

    # Region of the source that is kept, and where it lands in the (zero) output.
    kept = np.asarray(src_shape) - crop
    src_idx = tuple(slice(int(lo), int(lo + n)) for lo, n in zip(crop_lo, kept))
    dst_idx = tuple(slice(int(lo), int(lo + n)) for lo, n in zip(pad_lo, kept))
    result_vol = fr_vol.new_zeros((fr_vol.shape[0],) + target_shape)
    result_vol[(slice(None),) + dst_idx] = fr_vol[(slice(None),) + src_idx]

    # Output index 0 corresponds to input index (crop_lo - pad_lo), so move the
    # affine's origin there; the linear (rotation/scaling) part is unchanged.
    result_aff = np.array(fr_affine, dtype=np.float64, copy=True)
    offset = (crop_lo - pad_lo).astype(np.float64)
    result_aff[:-1, -1] = result_aff[:-1, :-1] @ offset + result_aff[:-1, -1]

    return result_vol, result_aff


def orient_to_viz(vol, affine):
    """Rotate a volume (and its affine) into the orientation used for visualization.

    Spatial axes 1-3 of `vol` are rotated with two `np.rot90` calls: first 90 degrees in
    the (1, 3) plane, then 180 degrees in the (2, 3) plane. A matching rotation, built
    from two quaternions (90 degrees about the 2nd world axis, 180 degrees about the
    1st), is left-multiplied onto `affine`.

    Parameters
    ----------
    vol : torch.Tensor or np.ndarray
        Volume with at least 4 dims, presumably channel-first (C, X, Y, Z) -- TODO
        confirm. A tensor input yields a tensor with the same dtype and device; other
        inputs yield the (possibly non-contiguous) numpy view from `np.rot90`.
    affine : np.ndarray
        4x4 homogeneous affine (its upper-left 3x3 block receives the rotation).

    Returns
    -------
    tuple
        `(v, new_aff)`: the rotated volume and `rotation @ affine`.

    NOTE(review): the rotation is applied in world space (left-multiplied) with no
    translation term, while `np.rot90` also reverses index order along some axes. So
    `new_aff` may not map each rotated voxel back to its original world position. It is
    at least consistent across all volumes passed through here; verify before using
    these affines for resampling or for saving NIfTI outputs.
    """

    # Work in numpy; tensors are detached and moved to the CPU first.
    if torch.is_tensor(vol):
        v = vol.detach().cpu().numpy()
    else:
        v = vol
    v = np.rot90(np.rot90(v, k=1, axes=(1, 3)), k=2, axes=(2, 3))
    # Copy the (negative-stride) rotated view, then restore the original dtype/device.
    if torch.is_tensor(vol):
        v = torch.from_numpy(np.copy(v)).to(vol)

    # Adjust the affine matrix.
    full_rot_aff = np.zeros_like(affine)
    full_rot_aff[-1, -1] = 1.0
    # 90 degree rot around the second axis.
    q1 = nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0])
    # 180 degree rot around the first axis.
    q2 = nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0])
    new_q = nib.quaternions.mult(q1, q2)
    rot_aff = nib.quaternions.quat2mat(new_q)
    full_rot_aff[:-1, :-1] = rot_aff
    new_aff = full_rot_aff @ affine

    return v, new_aff

#### Data Loading

In [None]:
# Import and organize all data.
# For every subject: read the LR DTI, FR DTI, FR diffusion mask, and FR anatomical
# volume(s); optionally clip DTI eigenvalues; compute log-Euclidean versions of both
# DTIs and (for the "standard" scale method) per-volume scaling kwargs. Everything is
# then wrapped into a `pitn.data.SubjSesDataset`, stored as `subj_data[subj_id]`.
subj_data: dict = dict()

# Only the affine-related metadata returned by the NIfTI reader is kept.
meta_keys_to_keep = {"affine", "original_affine"}

# Preprocessing only; no autograd graph is needed.
with torch.no_grad():
    # NOTE(review): `subj_dir` is unused; the dirs are re-read from `subj_dirs` below.
    for subj_id, subj_dir in subj_dirs.items():

        data = dict()
        data["subj_id"] = subj_id
        fr_subj_dir = subj_dirs[subj_id]["fr"]
        lr_subj_dir = subj_dirs[subj_id]["lr"]
        data["fr_subj_dir"] = fr_subj_dir
        data["lr_subj_dir"] = lr_subj_dir

        ####### Low-resolution DTIs/volumes
        lr_dti_f = pitn.utils.system.get_file_glob_unique(
            lr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(lr_dti_f)
        lr_dti, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        lr_dti = torch.from_numpy(lr_dti)
        lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
        data["lr_dti"] = lr_dti
        data["lr_dti_meta_dict"] = meta

        # May need to handle shape errors when re-upscaling back from LR to HR.
        # Every FR volume below is cropped/padded to floor(LR spatial shape * upscale).
        lr_dti_shape = np.asarray(lr_dti.shape[1:])
        target_fr_shape = np.floor(lr_dti_shape * params.net.upscale_factor).astype(int)

        ####### Full-resolution images/volumes.
        # DTI.
        dti_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.dti_fname_pattern
        )
        im = nib_reader.read(dti_f)
        dti, meta = nib_reader.get_data(im)
        dti = torch.from_numpy(dti)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        dti, meta["affine"] = fix_downsample_shape_errors(
            dti, meta["affine"], target_fr_shape
        )
        dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
        data["dti"] = dti
        data["dti_meta_dict"] = meta

        # Diffusion mask.
        mask_f = pitn.utils.system.get_file_glob_unique(
            fr_subj_dir, params.data.mask_fname_pattern
        )
        im = nib_reader.read(mask_f)
        mask, meta = nib_reader.get_data(im)
        meta = {k: meta[k] for k in meta_keys_to_keep}
        mask = torch.from_numpy(mask)
        # Add channel dim if not available.
        if mask.ndim == 3:
            mask = mask[None]
        mask, meta["affine"] = fix_downsample_shape_errors(
            mask, meta["affine"], target_fr_shape
        )
        mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
        mask = mask.bool()
        data["mask"] = mask
        data["mask_meta_dict"] = meta

        # Anatomical/structural volume.
        # May have multiple types of anatomical refinement volumes, so concat them
        # channel-wise and just call it 'anat'
        anat_vols = list()
        anat_meta = None
        for anat_f_pattern in params.data.anat_fname_patterns:
            anat_f = pitn.utils.system.get_file_glob_unique(fr_subj_dir, anat_f_pattern)
            im = nib_reader.read(anat_f)
            anat, meta = nib_reader.get_data(im)
            meta = {k: meta[k] for k in meta_keys_to_keep}
            anat = torch.from_numpy(anat)
            # Add channel dim if not available.
            if anat.ndim == 3:
                anat = anat[
                    None,
                ]
            anat, meta["affine"] = fix_downsample_shape_errors(
                anat, meta["affine"], target_fr_shape
            )
            anat, meta["affine"] = orient_to_viz(anat, meta["affine"])
            # Just choose the first meta dict that is found, they should all be the
            # same...
            if anat_meta is None:
                anat_meta = meta
            anat_vols.append(anat)

        anat_vols = torch.concat(anat_vols, dim=0)
        data["anat"] = anat_vols
        data["anat_meta_dict"] = anat_meta

        # Construct a quick and cheap mask for the LR DTI: nearest-neighbour resize of
        # the FR mask to the LR spatial shape.
        cheap_lr_mask = F.interpolate(
            data["mask"][None].float(),
            size=data["lr_dti"][0].shape,
            mode="nearest",
        )[0]
        data["lr_mask"] = cheap_lr_mask.bool()

        # Consider this as the "noise correction" step to have more informative, consistent
        # results with minimal biasing. Otherwise, outliers (which are clearly errors) can
        # change loss and performance metrics by orders of magnitude for no good reason!
        if "eigval_clip_cutoff" in params.data and params.data.eigval_clip_cutoff:
            # Clipping runs on `device`, and the result is moved back to the original
            # tensor's device/dtype; voxels outside the FR mask are then zeroed.
            correct_dti = pitn.data.outliers.clip_dti_eigvals(
                data["dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["dti"])
            correct_dti = correct_dti * data["mask"]
            #### Report the per-channel mean absolute change inside the mask.
            sae_fr = (
                F.l1_loss(data["dti"], correct_dti, reduction="none") * data["mask"]
            )
            sae_fr = sae_fr.view(6, -1).sum(1)
            # NOTE: an empty mask divides by zero here (inf/nan MAE, no exception).
            mae_fr = sae_fr / torch.count_nonzero(data["mask"])
            print(f"---Subj {subj_id}---")
            print(
                "MAE of FR DTI after eigenvalue clipping:\n",
                mae_fr.tolist(),
            )
            ####
            data["dti"] = correct_dti

            correct_lr_dti = pitn.data.outliers.clip_dti_eigvals(
                data["lr_dti"].to(device),
                tensor_components_dim=0,
                eigval_max=params.data.eigval_clip_cutoff,
            ).to(data["lr_dti"])
            # NOTE(review): this LR mask marks voxels where any DTI component is
            # nonzero, which differs from `data["lr_mask"]` (the nearest-neighbour
            # resize of the FR mask) computed above. Confirm this is intended.
            lr_mask = (data["lr_dti"] != 0).max(0, keepdim=True).values
            correct_lr_dti = correct_lr_dti * lr_mask
            ####
            sae_lr = (
                F.l1_loss(data["lr_dti"], correct_lr_dti, reduction="none") * lr_mask
            )
            sae_lr = sae_lr.view(6, -1).sum(1)
            mae_lr = sae_lr / torch.count_nonzero(lr_mask)
            print(
                "MAE of LR DTI after eigenvalue clipping:\n",
                mae_lr.tolist(),
            )
            ####
            data["lr_dti"] = correct_lr_dti

        ####### Log-euclid pre-computed volumes
        # 6-vector lower triangle -> symmetric matrix -> matrix log -> 6-vector, then
        # weighted by `mat_norm_coeffs`.
        # NOTE(review): the matrix log is undefined for all-zero/singular tensors (e.g.
        # background voxels, or voxels zeroed by the masking above); confirm how
        # `pitn.riemann.log_euclid.log_map` handles them.

        # LR log-euclid volume.
        lr_log_euclid = pitn.eig.tril_vec2sym_mat(data["lr_dti"], tril_dim=0)
        lr_log_euclid = pitn.riemann.log_euclid.log_map(lr_log_euclid)
        lr_log_euclid = pitn.eig.sym_mat2tril_vec(lr_log_euclid, tril_dim=0)
        lr_log_euclid = lr_log_euclid * mat_norm_coeffs
        data["lr_log_euclid"] = lr_log_euclid

        log_euclid = pitn.eig.tril_vec2sym_mat(data["dti"], tril_dim=0)
        log_euclid = pitn.riemann.log_euclid.log_map(log_euclid)
        log_euclid = pitn.eig.sym_mat2tril_vec(log_euclid, tril_dim=0)
        log_euclid = log_euclid * mat_norm_coeffs
        data["log_euclid"] = log_euclid

        ######## Normalized Subj Volumes
        # Per-subject scaling kwargs, computed within the matching LR/FR mask and stored
        # under "<vol>_scale_kwargs".
        if params.data.scale_method == "standard":
            # LR params.
            mask = data["lr_mask"]
            for vol in ("lr_dti", "lr_log_euclid"):
                scale_kwargs = data_scaler_cls.compute_scale_kwargs(
                    data[vol], mask=mask, batched=False
                )
                data[vol + "_scale_kwargs"] = scale_kwargs.copy()

            # FR params
            mask = data["mask"]
            for vol in ("dti", "log_euclid", "anat"):
                scale_kwargs = data_scaler_cls.compute_scale_kwargs(
                    data[vol], mask=mask, batched=False
                )
                data[vol + "_scale_kwargs"] = scale_kwargs.copy()

        # Volumes are passed to the dataset positionally; everything else in `data`
        # (IDs, dirs, meta dicts, scale kwargs) is passed as extra keyword metadata.
        vol_names = {
            "dti",
            "anat",
            "mask",
            "lr_dti",
            "lr_mask",
            "log_euclid",
            "lr_log_euclid",
        }

        metadata_names = set(data.keys()) - vol_names
        vol_d = {k: data[k] for k in vol_names}
        meta_d = {k: data[k] for k in metadata_names}

        # Create multi-volume dataset for this subj-session.
        subj_dataset = pitn.data.SubjSesDataset(
            vol_d,
            primary_vol_name="dti",
            special_secondary2primary_coords_fns=fr2lr_patch_coords_fn,
            transform=None,
            primary_patch_kwargs=patch_kwargs,
            **meta_d,
        )

        # Finalize this subject.
        subj_data[subj_id] = subj_dataset

        print("=" * 20)

print("===Data Loaded===")

## Model Training

### Set Up Patch-Based Data Loaders

In [None]:
# Data train/validation/test split


def _n_split_subjs(split_params, split_name: str) -> int:
    """Return the number of subjects for one split of a random train/val/test split.

    An explicit, non-zero `dataset_n_subjs` takes priority. Otherwise a non-zero
    `dataset_subj_percent` is used as a fraction of `params.n_subjs`, rounded up.
    Raises RuntimeError if neither is given.
    """
    if "dataset_n_subjs" in split_params and split_params.dataset_n_subjs:
        return int(split_params.dataset_n_subjs)
    if "dataset_subj_percent" in split_params and split_params.dataset_subj_percent:
        return int(np.ceil(params.n_subjs * split_params.dataset_subj_percent))
    raise RuntimeError(f"ERROR: Number of {split_name} subjects not given")


num_subjs = len(subj_data)
subj_list = list(subj_data.keys())
# Randomly shuffle the list of subjects; a random split then takes the first
# `num_test_subjs` subjects for testing. The shuffle always runs, so the global RNG state
# afterwards does not depend on which split branch is taken.
random.shuffle(subj_list)

# make sure the pre-set holdout subject is in the *test set*, not in the training or
# validation sets.
if HOLDOUT_SUBJ_ID in subj_list:
    subj_list.remove(HOLDOUT_SUBJ_ID)
    subj_list.insert(0, HOLDOUT_SUBJ_ID)

if num_subjs == 1:
    # If only 1 subject is available, assume this is a debugging run. Checked first,
    # because every split needs at least one subject, so the over-subscription check
    # below could never pass with a single subject.
    warnings.warn(
        "DEBUG: Only 1 subject selected, mixing training, validation, and testing sets"
    )
    num_train_subjs = num_val_subjs = num_test_subjs = 1

    test_subjs = subj_list
    val_subjs = subj_list
    train_subjs = subj_list
elif "subjs" in params.test and "subjs" in params.val and "subjs" in params.train:
    # If the TVS split is given in the parameters, take that split instead of a random
    # one. The reported counts come from the given lists themselves.
    test_subjs = params.test.subjs
    val_subjs = params.val.subjs
    train_subjs = params.train.subjs
    num_test_subjs = len(test_subjs)
    num_val_subjs = len(val_subjs)
    num_train_subjs = len(train_subjs)

    # Fail early (with a clear message) on IDs that were not loaded, e.g. subjects
    # dropped by sub-sampling or IDs given as numbers instead of strings in the config.
    missing_subjs = sorted(
        set(itertools.chain(test_subjs, val_subjs, train_subjs))
        - set(subj_data.keys()),
        key=str,
    )
    if missing_subjs:
        raise RuntimeError(
            "ERROR: Subject(s) in the given train/val/test split were not loaded: "
            + f"{missing_subjs}"
        )
    if HOLDOUT_SUBJ_ID in train_subjs or HOLDOUT_SUBJ_ID in val_subjs:
        warnings.warn(
            f"WARNING: Holdout subject {HOLDOUT_SUBJ_ID} is in the given training or "
            + "validation set"
        )
else:
    num_test_subjs = _n_split_subjs(params.test, "test")
    num_val_subjs = _n_split_subjs(params.val, "validation")
    if "dataset_n_subjs" in params.train and params.train.dataset_n_subjs:
        num_train_subjs = int(params.train.dataset_n_subjs)
    else:
        num_train_subjs = max(1, params.n_subjs - (num_test_subjs + num_val_subjs))

    if num_test_subjs + num_val_subjs + num_train_subjs > num_subjs:
        raise RuntimeError(
            "ERROR: Too many subjects selected train/val/test split "
            + f"{(num_test_subjs, num_val_subjs, num_train_subjs)}, "
            + f"only {num_subjs} available."
        )
    n_test_val = num_test_subjs + num_val_subjs
    test_subjs = subj_list[:num_test_subjs]
    val_subjs = subj_list[num_test_subjs:n_test_val]
    # Take exactly `num_train_subjs` of the remaining subjects. By default this is all
    # of them, but an explicit `train.dataset_n_subjs` is honored.
    train_subjs = subj_list[n_test_val : n_test_val + num_train_subjs]

print(f"{num_train_subjs}/{num_subjs} Training subject ID(s): {train_subjs}")
print(f"{num_val_subjs}/{num_subjs} Validation subject ID(s): {val_subjs}")
print(f"{num_test_subjs}/{num_subjs} Test subject ID(s): {test_subjs}")
with open(log_txt_file, "a+") as f:
    f.write(f"{num_train_subjs}/{num_subjs} Training subject ID(s): {train_subjs}\n")
    f.write(f"{num_val_subjs}/{num_subjs} Validation subject ID(s): {val_subjs}\n")
    f.write(f"{num_test_subjs}/{num_subjs} Test subject ID(s): {test_subjs}\n")

logger.add_text("train_subjs", str(train_subjs))
logger.add_text("val_subjs", str(val_subjs))
logger.add_text("test_subjs", str(test_subjs))

In [None]:
# Create Datasets and DataLoaders for test, validation, and training steps.
# Every sample key maps to itself; one extra entry per data-scaling kwargs key.
sample_kws = {
    **{
        "subj_id": "subj_id",
        "dti": "dti",
        "anat": "anat",
        "lr_dti": "lr_dti",
        "mask": "mask",
        "lr_mask": "lr_mask",
        "log_euclid": "log_euclid",
        "lr_log_euclid": "lr_log_euclid",
    },
    **{k: k for k in data_scaler_kws},
}

# Select only the keys that are needed during training, to save on processing time
# spent indexing and transferring data to the GPU.
if params.use_log_euclid:
    train_sample_primary_key = "log_euclid"
    train_sample_vol_keys = {"log_euclid", "lr_log_euclid"}
else:
    train_sample_primary_key = "dti"
    train_sample_vol_keys = {"dti", "lr_dti"}
# Scaling kwargs whose names contain the primary key (so "dti" also matches the
# "lr_dti..." entries), plus the anatomical scaling kwargs.
train_sample_meta_keys = {k for k in data_scaler_kws if train_sample_primary_key in k}
train_sample_meta_keys.add("anat_scale_kwargs")
train_sample_vol_keys.update({"anat", "mask", "lr_mask"})

# Train
# Set iteration order of strings changes between Python processes (hash
# randomization); sort the secondary keys so every run requests them in the same order.
_train_secondary_keys = sorted(train_sample_vol_keys - {train_sample_primary_key})
train_ds = list()
for subj_id in train_subjs:
    subj_data[subj_id].set_patch_sample_keys(
        train_sample_primary_key, *_train_secondary_keys
    )
    train_ds.append(subj_data[subj_id].patches)
train_dataset = torch.utils.data.ConcatDataset(train_ds)
train_sampler = pitn.samplers.ConcatDatasetBalancedRandomSampler(
    train_dataset.datasets,
    max_samples_per_dataset=params.train.samples_per_subj_per_epoch,
    # Make sure we resample, performance sucks otherwise.
    resample_after_empty=True,
)
train_collate_fn = functools.partial(
    pitn.samplers.collate_dicts,
    **{k: k for k in train_sample_vol_keys | train_sample_meta_keys},
)
# `persistent_workers` and a non-default `prefetch_factor` are only valid with worker
# processes; torch's DataLoader raises ValueError if they are given while
# num_workers == 0 (e.g. when debugging in the main process).
_train_worker_kws = (
    dict(persistent_workers=True, prefetch_factor=10)
    if params.num_workers > 0
    else dict()
)
train_loader = monai.data.DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=params.train.batch_size,
    collate_fn=train_collate_fn,
    pin_memory=True,
    num_workers=params.num_workers,
    **_train_worker_kws,
)

# Test & Validation
# Whole-subject samples collated with *every* sample key (not only the training subset
# selected above), one subject per batch.
tv_sample_kws = sample_kws
test_val_collate_fn = functools.partial(pitn.samplers.collate_dicts, **tv_sample_kws)

test_ds = [subj_data[subj_id] for subj_id in test_subjs]
test_dataset = torch.utils.data.ConcatDataset(test_ds)
test_loader = monai.data.DataLoader(
    test_dataset, collate_fn=test_val_collate_fn, batch_size=1
)

val_ds = [subj_data[subj_id] for subj_id in val_subjs]
val_dataset = torch.utils.data.ConcatDataset(val_ds)
val_loader = monai.data.DataLoader(
    val_dataset, collate_fn=test_val_collate_fn, batch_size=1
)

### Calculate Aggregate Statistics for Normalization

In [None]:
# Only choose subjects in the training and validation datasets.
# Per-channel min/max over the foreground voxels of both the full-resolution and the
# low-resolution volumes, for the raw DTI and the log-euclidean representations.


def _update_channel_extrema(stats, vol, mask, n_channels):
    """Fold the per-channel min/max of `vol` over `mask` into `stats.min`/`stats.max`.

    `mask` must broadcast against `vol` across the channel dimension (e.g. a
    1 x D x H x W mask for a C x D x H x W volume), so every channel selects the same
    number of voxels and the flat selection reshapes to (n_channels, -1). An empty
    mask leaves `stats` unchanged instead of failing on an empty reshape/reduction.
    """
    selected = torch.masked_select(vol, mask.bool())
    if selected.numel() == 0:
        return
    selected = selected.view(n_channels, -1)
    stats.min = torch.minimum(stats.min, selected.min(-1).values)
    stats.max = torch.maximum(stats.max, selected.max(-1).values)


# Reference tensor for the accumulators' dtype/device (assumed floating point).
_ref_dti = list(subj_data.values())[0][0]["dti"]
subj_agg_stats = Box(default_box=True)
# Start the accumulators at +/-inf: starting at 0 would force every channel's range to
# include 0 even when all foreground values share one sign. Separate tensors, so the
# DTI and log-euclidean statistics never alias each other.
subj_agg_stats.dti.min = torch.full((params.n_channels,), math.inf).to(_ref_dti)
subj_agg_stats.dti.max = torch.full((params.n_channels,), -math.inf).to(_ref_dti)
subj_agg_stats.log_euclid.min = subj_agg_stats.dti.min.clone()
subj_agg_stats.log_euclid.max = subj_agg_stats.dti.max.clone()

for subj_id in set(train_subjs).union(set(val_subjs)):
    s = subj_data[subj_id]
    # Fetch the whole-volume sample once rather than re-indexing for every modality.
    sample = s[0]
    fr_mask = sample["mask"]
    # LR foreground: voxels where at least one tensor component is non-zero.
    lr_mask = (sample["lr_dti"] != 0).any(0, keepdim=True)
    for stats, vol_key, vol_mask in (
        (subj_agg_stats.dti, "dti", fr_mask),
        (subj_agg_stats.dti, "lr_dti", lr_mask),
        (subj_agg_stats.log_euclid, "log_euclid", fr_mask),
        (subj_agg_stats.log_euclid, "lr_log_euclid", lr_mask),
    ):
        _update_channel_extrema(stats, sample[vol_key], vol_mask, params.n_channels)


print(subj_agg_stats.dti.min)
print(subj_agg_stats.dti.max)
print(subj_agg_stats.log_euclid.min)
print(subj_agg_stats.log_euclid.max)

#### Calculate Ranges for PSNR

In [None]:
# Reshape a per-channel vector of shape (C,) to (1, C, 1, 1, 1).
expander = functools.partial(einops.rearrange, pattern="c -> 1 c 1 1 1")

# Global per-channel DTI component range from the aggregate statistics.
dti_min, dti_max = map(expander, (subj_agg_stats.dti.min, subj_agg_stats.dti.max))

# Identity feature range [0, 1] for each of the 6 tensor components.
feat_min = expander(torch.zeros(6, dtype=torch.int64))
feat_max = expander(torch.ones(6, dtype=torch.int64))

# PSNR is calculated on the final output tensor components, so no log-euclidean or
# scaling will occur here.
psnr_range_params = pitn.data.norm.GlobalScaleParams(
    feature_min=feat_min,
    feature_max=feat_max,
    data_min=dti_min,
    data_max=dti_max,
)

In [None]:
# net_scalers = dict()

# expander = functools.partial(einops.rearrange, pattern="c -> 1 c 1 1 1")
# # Collect DTI global data features.

# feat_min: torch.Tensor
# feat_max: torch.Tensor
# dti_min = expander(subj_agg_stats.dti.min)
# dti_max = expander(subj_agg_stats.dti.max)

# if params.use_log_euclid and (
#     "dti_scale_range" in params.data and params.data.dti_scale_range
# ):
#     print("Using log-euclid min/max for tensor input/output scaling.")
#     leu_min = expander(subj_agg_stats.log_euclid.min)
#     leu_max = expander(subj_agg_stats.log_euclid.max)

#     feat_min, feat_max = torch.as_tensor(params.data.dti_scale_range)
#     feat_min = expander(feat_min)
#     feat_max = expander(feat_max)

#     # Add these stats to the network kwargs in a functional form.
#     leu_scaler = pitn.data.norm.MinMaxScaler(
#         feature_min=feat_min,
#         feature_max=feat_max,
#         data_min=leu_min,
#         data_max=leu_max,
#     )
#     net_scalers["input_scaler"] = leu_scaler.scale_to
#     net_scalers["output_descaler"] = leu_scaler.unscale_from

# elif "dti_scale_range" in params.data and params.data.dti_scale_range:
#     print("Using regular tensor component values for input/output scaling.")

#     feat_min, feat_max = torch.as_tensor(params.data.dti_scale_range)
#     feat_min = expander(feat_min)
#     feat_max = expander(feat_max)

#     dti_global_scale_params = pitn.data.norm.GlobalScaleParams(
#         feature_min=feat_min, feature_max=feat_max, data_min=dti_min, data_max=dti_max
#     )

#     # Add these stats to the network kwargs in a functional form.
#     dti_scaler = pitn.data.norm.MinMaxScaler(
#         feature_min=dti_global_scale_params.feature_min,
#         feature_max=dti_global_scale_params.feature_max,
#         data_min=dti_global_scale_params.data_min,
#         data_max=dti_global_scale_params.data_max,
#     )
#     net_scalers["input_scaler"] = dti_scaler.scale_to
#     net_scalers["output_descaler"] = dti_scaler.unscale_from
# else:
#     print("No scaling")

#     feat_min, feat_max = torch.as_tensor(
#         [
#             [0] * 6,
#             [1] * 6,
#         ]
#     )
#     feat_min = expander(feat_min)
#     feat_max = expander(feat_max)

#     net_scalers["input_scaler"] = lambda o: o
#     net_scalers["output_descaler"] = lambda o: o

# # PSNR is calculated on the final output tensor components, so no log-euclidean or
# # scaling will occur here.
# psnr_range_params = pitn.data.norm.GlobalScaleParams(
#     feature_min=feat_min, feature_max=feat_max, data_min=dti_min, data_max=dti_max
# )

# ### Anatomical input scaling.
# if params.use_anat:
#     if "anat_scale_range" in params.data and params.data.anat_scale_range:
#         # Collect anat global data features.
#         anat_min = expander(subj_agg_stats.anat.min)
#         anat_max = expander(subj_agg_stats.anat.max)

#         if params.use_log_euclid and (
#             "dti_scale_range" in params.data and params.data.dti_scale_range
#         ):
#             # If the log-euclid metrics are being used, then the anatomical needs to be log-
#             # transformed to hit the range we want. So, combine two scalings with a log in the
#             # middle.
#             print("Scaling anat to log-euclidean, roughly")
#             # Pre-log scaling to match the 1st dti component range.
#             pre_log_anat_scaler = pitn.data.norm.MinMaxScaler(
#                 feature_min=expander(subj_agg_stats.dti.min)[:, 0][:, None],
#                 feature_max=expander(subj_agg_stats.dti.max)[:, 0][:, None],
#                 data_min=anat_min,
#                 data_max=anat_max,
#             )

#             leu_feat_min, leu_feat_max = torch.as_tensor(params.data.anat_scale_range)
#             leu_feat_min = expander(torch.atleast_1d(leu_feat_min))
#             leu_feat_max = expander(torch.atleast_1d(leu_feat_max))

#             leu_data_min = expander(subj_agg_stats.log_euclid.min)[:, 0][:, None]
#             leu_data_max = expander(subj_agg_stats.log_euclid.max)[:, 0][:, None]
#             post_log_anat_scaler = pitn.data.norm.MinMaxScaler(
#                 feature_min=leu_feat_min,
#                 feature_max=leu_feat_max,
#                 data_min=leu_data_min,
#                 data_max=leu_data_max,
#             )

#             # Compose them together for the multi-mode input scaling callable.
#             net_scalers[
#                 "multi_modal_input_scaler"
#             ] = lambda anat: post_log_anat_scaler.scale_to(
#                 torch.log(pre_log_anat_scaler.scale_to(anat).clamp_min(1e-6))
#             )

#         else:
#             feat_min, feat_max = torch.as_tensor(params.data.anat_scale_range)
#             feat_min = expander(torch.atleast_1d(feat_min))
#             feat_max = expander(torch.atleast_1d(feat_max))
#             anat_global_scale_params = pitn.data.norm.GlobalScaleParams(
#                 feature_min=feat_min,
#                 feature_max=feat_max,
#                 data_min=anat_min,
#                 data_max=anat_max,
#             )
#             anat_scaler = pitn.data.norm.MinMaxScaler(
#                 feature_min=anat_global_scale_params.feature_min,
#                 feature_max=anat_global_scale_params.feature_max,
#                 data_min=anat_global_scale_params.data_min,
#                 data_max=anat_global_scale_params.data_max,
#             )

#             net_scalers["multi_modal_input_scaler"] = anat_scaler.scale_to

### Model Definition

In [None]:
# Full pytorch-lightning module for contained training, validation, and testing.
# Approximate training steps per epoch: the sampler draws
# `samples_per_subj_per_epoch` patches from each training subject.
_train_steps_per_epoch = (
    params.train.samples_per_subj_per_epoch * len(train_subjs) / params.train.batch_size
)
# NOTE(review): a negative "probability" presumably disables whatever debug hook reads
# this value (no random draw in [0, 1) falls below it); confirm at the use site.
debug_prob = -1 / _train_steps_per_epoch


class DIQTCascadeSystem(pl.LightningModule):
    def __init__(
        self,
        channels: int,
        batch_size: int,
        in_patch_shape: tuple,
        upscale_factor: int,
        train_loss_method: typing.Union[str, typing.Callable],
        val_subj_ids: tuple,
        lambda_dti_stream_loss: float,
        psnr_range_params: pitn.data.norm.GlobalScaleParams,
        opt_params: dict,
        use_log_euclid: bool,
        lr_scheduler_name: typing.Optional[str] = None,
        lr_scheduler_kwargs: typing.Optional[dict] = None,
        net_kwargs: typing.Optional[dict] = None,
        **hyper_parameters,
    ):
        """Build the anatomically-refined cascade upsampling network and its config.

        Parameters
        ----------
        channels
            Number of tensor channels given to the network constructor.
        batch_size
            Training batch size, passed explicitly when logging training losses.
        in_patch_shape
            Input patch shape; stored on the module.
        upscale_factor
            Spatial upsampling factor given to the network constructor.
        train_loss_method
            Case-insensitive key into ``pitn.utils.torch_lookups.loss_fn``, or a
            callable loss ``fn(pred, target, mask=...)``.
        val_subj_ids
            Validation subject IDs; one is chosen at random for the validation figure.
        lambda_dti_stream_loss
            Weight of the stream-1 (pre-anatomical) loss; ``1 - lambda`` weighs the
            final prediction's loss. 0 skips the stream-1 prediction in training.
        psnr_range_params
            Data range used for the test-time PSNR metric.
        opt_params
            Keyword arguments for the AdamW optimizer.
        use_log_euclid
            Use the log-euclidean network variant and log-euclidean training targets.
        lr_scheduler_name, lr_scheduler_kwargs
            Optional scheduler key into ``pitn.utils.torch_lookups.lr_scheduler`` and
            its constructor kwargs.
        net_kwargs
            Extra keyword arguments for the network constructor (default: none).
        **hyper_parameters
            Additional values recorded with ``save_hyperparameters``.

        Raises
        ------
        ValueError
            If ``train_loss_method`` is neither a known loss name nor a callable.
        """
        super().__init__()
        self.save_hyperparameters(
            "channels",
            "batch_size",
            "train_loss_method",
            *list(hyper_parameters.keys()),
        )
        # Avoid a shared mutable default argument.
        if net_kwargs is None:
            net_kwargs = dict()
        self._channels = channels
        self._batch_size = batch_size
        self._in_patch_shape = in_patch_shape
        self._upscale_factor = upscale_factor
        self.lambda_dti_stream = lambda_dti_stream_loss
        self.use_le = use_log_euclid
        self._val_viz_subj_id = random.choice(val_subj_ids)
        self._psnr_range_params = psnr_range_params
        # Coefficients to the log-euclidean lower triangle/6D vector that properly scales
        # the Euclidean distance under the log-euclidean metrics. Indices 1, 3, 4 are the
        # off-diagonal components in the (xx, xy, yy, xz, yz, zz) ordering; each appears
        # twice in the full symmetric matrix, hence the sqrt(2) factor.
        mat_norm_coeffs = torch.ones(6).float()
        mat_norm_coeffs[torch.as_tensor([1, 3, 4])] = np.sqrt(2)
        # (6, 1, 1, 1) so it broadcasts over the spatial dims of a B x 6 x ... tensor.
        self.mat_norm_coeffs = mat_norm_coeffs.reshape(-1, 1, 1, 1)
        # Network parameters
        if self.use_le:
            self.net = pitn.nn.sr.CascadeUpsampleAnatRefineLogEuclid(
                self._channels, upscale_factor=self._upscale_factor, **net_kwargs
            )
        else:
            self.net = pitn.nn.sr.CascadeUpsampleAnatRefine(
                self._channels, upscale_factor=self._upscale_factor, **net_kwargs
            )

        ## Training parameters
        self.opt_params = opt_params
        self._lr_scheduler_name = lr_scheduler_name
        self.lr_scheduler_kwargs = lr_scheduler_kwargs

        # Select loss method as either one of the pre-selected methods, or a custom
        # callable. A callable has no `casefold` (AttributeError); an unknown name is
        # a KeyError.
        try:
            self._loss_fn = pitn.utils.torch_lookups.loss_fn[
                train_loss_method.casefold()
            ]()
        except (AttributeError, KeyError):
            if callable(train_loss_method):
                self._loss_fn = train_loss_method
            else:
                raise ValueError(
                    f"ERROR: Invalid loss function specification {train_loss_method}, "
                    + f"expected one of {pitn.utils.torch_lookups.loss_fn.keys()} "
                    + "or a callable."
                )

        # Indices of the validation visualization sub-volume, filled on first use.
        self._val_subvol_range = dict()

        # My own dinky logging object.
        self.plain_log = Box(
            {
                "train_loss": list(),
                "val_loss": {"rmse": list(), "nrmse": list()},
                "test_loss": {
                    "rmse": dict(),
                    "nrmse": dict(),
                    "rmse_log_euclid": dict(),
                    "nrmse_log_euclid": dict(),
                    "scaled_psnr": dict(),
                    "ssim_fa": dict(),
                    "rmse_fa": dict(),
                    "nrmse_fa": dict(),
                },
                "viz": {
                    "test_preds": dict(),
                    "test_preds_pre_anat": dict(),
                },
            },
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped network, forwarding all arguments."""
        return self.net(*args, **kwargs)

    def on_train_epoch_start(self):
        """Print a tab-indented marker naming the epoch that is starting."""
        epoch = self.current_epoch
        self.print(f"\tStart Epoch {epoch}")

    def training_step(self, batch, batch_idx):
        """Compute the weighted two-stream loss for one batch of training patches.

        ``loss = lambda * loss(stream-1 pred) + (1 - lambda) * loss(final pred)``,
        computed against the ground truth transformed into the network's output space
        and restricted to the cropped mask. With ``lambda == 0`` the stream-1
        prediction is not computed and its logged loss is 0.

        Raises
        ------
        RuntimeError
            If the loss is NaN.
        """
        batch = Box(batch)
        if self.use_le:
            x = batch.lr_log_euclid
            y = batch.log_euclid
            x_scale_kwargs = batch.lr_log_euclid_scale_kwargs
            y_scale_kwargs = batch.log_euclid_scale_kwargs
        else:
            x = batch.lr_dti
            y = batch.dti
            x_scale_kwargs = batch.lr_dti_scale_kwargs
            y_scale_kwargs = batch.dti_scale_kwargs
        x_anat = batch.anat
        x_anat_scale_kwargs = batch.anat_scale_kwargs
        lr_mask = batch.lr_mask.bool()
        uncropped_mask = batch.mask.bool()
        mask = self.net.crop_full_output(uncropped_mask)

        # Log-euclidean inputs are scaled here, outside the autograd graph, and the
        # network is told not to transform them again; otherwise the network
        # transforms the input itself (transform_x=True). Decided by this module's
        # own configuration, not by notebook-global state.
        transform_x = not self.use_le
        if self.use_le:
            with torch.no_grad():
                x = self.net.input_scale_fn(
                    x.float(), mask=lr_mask, **x_scale_kwargs
                ).float()
        # If the stream-1 loss coefficient is 0, the stream-1 prediction should not
        # even be calculated.
        return_two_stream = self.lambda_dti_stream != 0

        y_pred = self.net(
            x,
            x_anat,
            transform_x=transform_x,
            transform_x_anat=True,
            transform_y=False,
            scale_x_kwargs={**x_scale_kwargs, "mask": lr_mask},
            scale_x_anat_kwargs={**x_anat_scale_kwargs, "mask": uncropped_mask},
            return_two_stream=return_two_stream,
        )

        # Calculate loss in the same space as the network prediction for numerical
        # stability and convergence during training.
        with torch.no_grad():
            y = self.net.transform_ground_truth_for_training(
                y.float(), crop=True, **y_scale_kwargs, mask=mask
            )

        if return_two_stream:
            y_pred_1, y_pred_2 = y_pred
            loss_1 = self._loss_fn(y_pred_1, y, mask=mask)
            loss_2 = self._loss_fn(y_pred_2, y, mask=mask)
        else:
            loss_1 = 0
            loss_2 = self._loss_fn(y_pred, y, mask=mask)

        loss = (self.lambda_dti_stream * loss_1) + (
            (1 - self.lambda_dti_stream) * loss_2
        )
        self.log("train_loss", loss, batch_size=self._batch_size)
        self.log("train/loss_pre_anat", loss_1, batch_size=self._batch_size)
        self.log("train/loss_post_anat", loss_2, batch_size=self._batch_size)

        self.plain_log["train_loss"].append(float(loss.detach().cpu()))
        if torch.isnan(loss).any().item():
            raise RuntimeError("ERROR: Nan loss")
        return loss

    def validation_step(self, batch, batch_idx):
        """Predict one whole validation subject and log its RMSE and NRMSE.

        Assumes a batch size of 1 (one full subject). For the randomly chosen
        visualization subject, a figure of the central sub-volume is also logged.

        Returns
        -------
        dict
            ``{"rmse": ..., "nrmse": ...}`` for this subject.
        """
        batch = Box(batch)
        # Assume batch size of 1 for the validation set.
        subj_id = batch.subj_id[0]
        x = batch.lr_dti.float()
        y = batch.dti.float()
        if self.use_le:
            x_scale_kwargs = batch.lr_log_euclid_scale_kwargs
            y_scale_kwargs = batch.log_euclid_scale_kwargs
        else:
            x_scale_kwargs = batch.lr_dti_scale_kwargs
            y_scale_kwargs = batch.dti_scale_kwargs

        x_anat = batch.anat
        x_anat_scale_kwargs = batch.anat_scale_kwargs

        y = self.net.crop_full_output(y)
        uncropped_mask = batch.mask.bool()
        mask = self.net.crop_full_output(uncropped_mask)

        x_scale_kwargs["mask"] = batch.lr_mask.bool()
        x_anat_scale_kwargs["mask"] = uncropped_mask
        y_scale_kwargs["mask"] = mask

        y_pred_pre_anat, y_pred = self.net(
            x,
            x_anat,
            scale_x_kwargs=x_scale_kwargs,
            scale_x_anat_kwargs=x_anat_scale_kwargs,
            scale_y_kwargs=y_scale_kwargs,
            return_two_stream=True,
        )

        y_pred_pre_anat = y_pred_pre_anat.float()
        y_pred = y_pred.float()

        rmse_loss = pitn.nn.loss.dti_root_vec_fro_norm_loss(
            y_pred, y, mask=mask, scale_off_diags=True, reduction="mean"
        )

        nrmse_loss = pitn.metrics.minmax_normalized_dti_root_vec_fro_norm(
            y_pred,
            y,
            mask=mask,
            scale_off_diags=True,
            reduction="mean",
        )
        # One subject per batch: pass the batch size explicitly (as training_step
        # does) so Lightning does not have to infer it from the mixed batch dict.
        self.log("val_loss/nrmse", nrmse_loss, batch_size=1)
        self.log("val_loss/rmse", rmse_loss, batch_size=1)
        self.plain_log.val_loss.nrmse.append(float(nrmse_loss.detach().cpu()))
        self.plain_log.val_loss.rmse.append(float(rmse_loss.detach().cpu()))

        # Plot visual of validation volume.
        if subj_id == self._val_viz_subj_id:
            # Fix the plotted region on first use so every step shows the same area.
            if not self._val_subvol_range:
                self._val_subvol_range["fr"] = self._central_subvol_index(y.shape)
                self._val_subvol_range["lr"] = self._central_subvol_index(x.shape)
            self._log_val_figure(subj_id, x, y, y_pred_pre_anat, y_pred)

        return {
            "rmse": rmse_loss,
            "nrmse": nrmse_loss,
        }

    @staticmethod
    def _central_subvol_index(shape) -> tuple:
        """Index of batch item 0's central sub-volume, keeping every channel.

        For a (batch, channel, *spatial) shape, each spatial dim is restricted to the
        range [floor(size / 4), floor(3 * size / 4)).
        """
        space = np.asarray(shape[2:])
        low = np.floor(space * 1 / 4).astype(int)
        high = np.floor(space * 3 / 4).astype(int)
        return (0, ...) + tuple(slice(lo, hi) for lo, hi in zip(low, high))

    def _log_val_figure(self, subj_id, x, y, y_pred_pre_anat, y_pred):
        """Log a slice grid of the validation sub-volumes under "val_slice"."""
        x_subvol = x.detach()[self._val_subvol_range["lr"]].float()
        y_subvol = y.detach()[self._val_subvol_range["fr"]].float()
        pred_pre_anat_subvol = y_pred_pre_anat.detach()[
            self._val_subvol_range["fr"]
        ].float()
        pred_subvol = y_pred.detach()[self._val_subvol_range["fr"]].float()
        pre_anat_vs_post_subvol = torch.abs(pred_pre_anat_subvol - pred_subvol)

        # Create grid plot
        # Plot settings to propagate into the figure.
        with mpl.rc_context(
            {
                "font.size": 6.0,
                "axes.labelpad": 10.0,
                "figure.autolayout": False,
                "figure.constrained_layout.use": True,
                "ytick.color": "red",
            }
        ):
            fig = plt.figure(dpi=130, figsize=(6, 4))
            channel_names = [
                r"$D_{x,x}$",
                r"$D_{x,y}$",
                r"$D_{y,y}$",
                r"$D_{x,z}$",
                r"$D_{y,z}$",
                r"$D_{z,z}$",
            ]
            slice_labels = [
                "\nAxial",
                "\nCoronal",
                "\nSagg.",
            ]
            img_labels = [
                "GT",
                "Pred",
                "Pred Pre Anat",
                "Input",
                "Pre/Post Anat Diff",
            ]
            fig = pitn.viz.plot_vol_slices(
                y_subvol,
                pred_subvol,
                pred_pre_anat_subvol,
                x_subvol,
                pre_anat_vs_post_subvol,
                slice_idx=(0.55, None, None),
                title=f"Subj {subj_id} Step {self.global_step}",
                vol_labels=img_labels,
                channel_labels=channel_names,
                slice_labels=slice_labels,
                colorbars="cols",
                fig=fig,
                interpolation="antialiased",
                cmap="gray",
            )
            self.logger.experiment.add_figure("val_slice", fig, self.global_step)

    def on_test_start(self):
        """Register every test metric as an hparams metric, initialized to 0.

        Registering them up front makes them appear under the TensorBoard hparams
        tab. From:
        <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-hyperparameters>
        """
        hp_metric_keys = (
            "hp/rmse",
            "hp/nrmse",
            "hp/scaled_psnr",
            "hp/ssim_fa",
            "hp/rmse_log_euclid",
            "hp/nrmse_log_euclid",
            "hp/rmse_fa",
            "hp/nrmse_fa",
        )
        initial_metrics = dict.fromkeys(hp_metric_keys, 0)
        self.logger.log_hyperparams(self.hparams, initial_metrics)

    def test_step(self, batch: dict, batch_idx):
        """Predict one whole test subject and compute, log, and store all metrics.

        Assumes a batch size of 1. The metrics are:
        - DTI RMSE/NRMSE and log-euclidean RMSE/NRMSE over the cropped mask.
        - FA RMSE/NRMSE over the masked FA voxels.
        - SSIM of the FA maps.
        - PSNR using the configured global DTI range.

        Each metric is logged under ``test_loss/`` and ``hp/``, stored per subject in
        ``self.plain_log.test_loss``, and returned. The full predictions are kept in
        ``self.plain_log.viz``.
        """
        batch = Box(batch)
        # Assume batch size of 1.
        subj_id = batch.subj_id[0]

        # Input. Cast to float32 like training does, so the network never sees a
        # different dtype than its weights.
        x = batch.lr_dti.float()
        log_euclid_x = batch.lr_log_euclid.float()
        x_anat = batch.anat

        # Mask
        uncropped_mask = batch.mask.bool()
        mask = self.net.crop_full_output(uncropped_mask)

        # Ground truth.
        y = self.net.crop_full_output(batch.dti.float())
        log_euclid_y = self.net.crop_full_output(batch.log_euclid.float())
        y_fa = pitn.metrics.fast_fa(y, foreground_mask=mask)

        # Scale kwargs.
        if self.use_le:
            x_scale_kwargs = batch.lr_log_euclid_scale_kwargs
            y_scale_kwargs = batch.log_euclid_scale_kwargs
        else:
            x_scale_kwargs = batch.lr_dti_scale_kwargs
            y_scale_kwargs = batch.dti_scale_kwargs
        x_anat_scale_kwargs = batch.anat_scale_kwargs
        x_scale_kwargs["mask"] = batch.lr_mask.bool()
        x_anat_scale_kwargs["mask"] = uncropped_mask
        y_scale_kwargs["mask"] = mask

        # Network predictions
        y_pred_pre_anat, y_pred = self.net(
            x,
            x_anat,
            scale_x_kwargs=x_scale_kwargs,
            scale_x_anat_kwargs=x_anat_scale_kwargs,
            scale_y_kwargs=y_scale_kwargs,
            return_two_stream=True,
        )
        y_pred = y_pred.float()
        y_pred_pre_anat = y_pred_pre_anat.float()
        # If the network operates on log domain natively, then run another forward pass.
        if self.use_le:
            log_euclid_y_pred = self.net(
                self.net.input_scale_fn(log_euclid_x, **x_scale_kwargs).float(),
                x_anat,
                transform_x=False,
                transform_x_anat=True,
                transform_y=False,
                scale_x_anat_kwargs=x_anat_scale_kwargs,
                return_two_stream=False,
            )
            log_euclid_y_pred = self.net.output_descale_fn(
                log_euclid_y_pred, **y_scale_kwargs
            )
        # If the network operates in euclidean space natively, then take the original
        # prediction and log map it.
        else:
            log_euclid_y_pred = pitn.eig.tril_vec2sym_mat(y_pred, tril_dim=1)
            log_euclid_y_pred = pitn.riemann.log_euclid.log_map(log_euclid_y_pred)
            log_euclid_y_pred = pitn.eig.sym_mat2tril_vec(
                log_euclid_y_pred, tril_dim=1, dim1=-2, dim2=-1
            )
            log_euclid_y_pred = log_euclid_y_pred / self.mat_norm_coeffs.to(
                log_euclid_y_pred
            )

        y_pred_fa = pitn.metrics.fast_fa(y_pred, foreground_mask=mask)

        # Mask-select ground truths and predictions into (batch, channel, voxels).
        y_select = self._masked_flatten(y, mask)
        y_fa_select = self._masked_flatten(y_fa, mask)
        y_pred_select = self._masked_flatten(y_pred, mask)
        y_pred_fa_select = self._masked_flatten(y_pred_fa, mask)

        ###### Calculate network performance metrics.
        # MSE metrics.
        rmse_loss = pitn.nn.loss.dti_root_vec_fro_norm_loss(
            y_pred, y, mask=mask, scale_off_diags=True, reduction="mean"
        )
        nrmse_loss = pitn.metrics.minmax_normalized_dti_root_vec_fro_norm(
            y_pred,
            y,
            mask=mask,
            scale_off_diags=True,
            reduction="mean",
        )

        rmse_log_euclid_loss = pitn.nn.loss.dti_root_vec_fro_norm_loss(
            log_euclid_y_pred,
            log_euclid_y,
            mask=mask,
            scale_off_diags=False,
            reduction="mean",
        )
        nrmse_log_euclid_loss = pitn.metrics.minmax_normalized_dti_root_vec_fro_norm(
            log_euclid_y_pred,
            log_euclid_y,
            mask=mask,
            scale_off_diags=False,
            reduction="mean",
        )

        # We want to find the *actual* RMSE and NRMSE of the FA maps, because they are
        # no longer DTIs.
        rmse_fa_loss = torch.sqrt(
            F.mse_loss(
                y_pred_fa_select,
                y_fa_select,
                reduction="mean",
            )
        )
        nrmse_fa_loss = pitn.metrics.minmax_normalized_rmse(
            y_pred_fa_select,
            y_fa_select,
            reduction="mean",
        )

        # PSNR metric
        # Need to reshape the mins and maxes to play nice with the mask-selected tensors.
        min_shape_pre = tuple(self._psnr_range_params.data_min.shape[:2])
        max_shape_pre = tuple(self._psnr_range_params.data_max.shape[:2])
        scaled_psnr_loss = pitn.metrics.psnr_batch_channel_regularized(
            y_pred_select,
            y_select,
            range_min=self._psnr_range_params.data_min.view(*min_shape_pre, -1),
            range_max=self._psnr_range_params.data_max.view(*max_shape_pre, -1),
        )

        # Perceptual metrics
        ssim_fa_loss = pitn.metrics.ssim_y_range(
            y_pred_fa,
            y_fa,
        )

        # Single source of truth for the metric names used in every log below.
        metrics = {
            "rmse": rmse_loss,
            "nrmse": nrmse_loss,
            "nrmse_log_euclid": nrmse_log_euclid_loss,
            "rmse_log_euclid": rmse_log_euclid_loss,
            "scaled_psnr": scaled_psnr_loss,
            "ssim_fa": ssim_fa_loss,
            "rmse_fa": rmse_fa_loss,
            "nrmse_fa": nrmse_fa_loss,
        }

        # on_epoch and reduce_fx gather the individual test epoch values and aggregate
        # them at the end of the testing epoch. Each batch is one subject, so the batch
        # size is given explicitly instead of being inferred from the batch dict.
        self.log_dict(
            {f"test_loss/{name}": value for name, value in metrics.items()},
            on_epoch=True,
            reduce_fx=torch.mean,
            batch_size=1,
        )
        # Log loss as an hparam metric to be shown alongside the experiment's hparams.
        self.log_dict(
            {f"hp/{name}": value for name, value in metrics.items()},
            on_epoch=True,
            reduce_fx=torch.mean,
            batch_size=1,
        )
        for name, value in metrics.items():
            self.plain_log.test_loss[name][subj_id] = value.detach().cpu().item()

        # Store entire predicted DTI for saving & visualization.
        self.plain_log.viz.test_preds[subj_id] = y_pred[0].detach().cpu()
        # Also store the pre-anat prediction, just in case.
        self.plain_log.viz.test_preds_pre_anat[subj_id] = (
            y_pred_pre_anat[0].detach().cpu()
        )

        return metrics

    @staticmethod
    def _masked_flatten(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Select the voxels of ``t`` under ``mask`` as a (batch, channel, -1) tensor."""
        return torch.masked_select(t, mask).view(t.shape[0], t.shape[1], -1)

    def configure_optimizers(self):
        """Build the AdamW optimizer and, if configured, its learning-rate scheduler.

        The optimizer covers ``self.net.parameters()`` and receives
        ``self.opt_params`` as keyword arguments. When ``self._lr_scheduler_name``
        is truthy, the scheduler class is looked up by that name in
        ``pitn.utils.torch_lookups.lr_scheduler`` and configured to step on every
        optimizer step.

        Returns:
            dict: ``{"optimizer": optimizer}``, plus an ``"lr_scheduler"`` config
            dict (keys ``name``, ``scheduler``, ``interval``) when a scheduler is
            configured.
        """
        optimizer = torch.optim.AdamW(self.net.parameters(), **self.opt_params)
        sched_name = self._lr_scheduler_name
        if not sched_name:
            return {"optimizer": optimizer}

        sched_cls = pitn.utils.torch_lookups.lr_scheduler[sched_name]
        if sched_name.casefold() != "sequential":
            scheduler = sched_cls(optimizer, **self.lr_scheduler_kwargs)
        else:
            # A "sequential" scheduler wraps child schedulers. The config stores
            # callables that take the optimizer, so the children can only be built
            # now that the optimizer exists.
            sched_kwargs = self.lr_scheduler_kwargs.copy()
            sched_kwargs["schedulers"] = [
                make_child(optimizer)
                for make_child in self.lr_scheduler_kwargs["schedulers"]
            ]
            scheduler = sched_cls(optimizer, **sched_kwargs)
            # NOTE(review): explicit re-attachment of the optimizer; presumably a
            # workaround for scheduler versions lacking this attribute — confirm.
            scheduler.optimizer = optimizer

        # The LR stepping interval must be configured here. See
        # <https://github.com/PyTorchLightning/pytorch-lightning/issues/4576#issuecomment-723648061>
        # and
        # <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers>
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "name": "lr " + sched_name,
                "scheduler": scheduler,
                "interval": "step",
            },
        }

### Training Loop

In [None]:
# Shallow-copy the configured network kwargs so adding the scaler hooks below does not
# mutate `params`.
net_kwargs: dict = dict(params.net.kwargs)
if params.data.scale_method == "standard":
    # Let the network scale its inputs and de-scale its outputs itself.
    scaler = data_scaler_cls()
    net_kwargs.update(
        input_scale_fn=scaler.scale_to,
        output_descale_fn=scaler.unscale_from,
    )

In [None]:
import traceback

# Build the model, log its summary, train it, and record the training duration.
train_start_timestamp = datetime.datetime.now().replace(microsecond=0)

model_kwargs = dict(
    channels=params.n_channels,
    batch_size=params.train.batch_size,
    in_patch_shape=params.train.in_patch_size,
    lambda_dti_stream_loss=params.train.lambda_dti_stream_loss,
    upscale_factor=params.net.upscale_factor,
    train_loss_method=params.train.loss_name,
    use_log_euclid=params.use_log_euclid,
    val_subj_ids=val_subjs,
    psnr_range_params=psnr_range_params,
    opt_params=params.optim.kwargs,
    net_kwargs=net_kwargs,
)

if "lr_scheduler" in params.train and params.train.lr_scheduler:
    model_kwargs["lr_scheduler_name"] = params.train.lr_scheduler.name
    model_kwargs["lr_scheduler_kwargs"] = params.train.lr_scheduler.kwargs
# Update init kwargs with hyperparams, in case there are overlapping names.
model_kwargs.update(**hyperparams)

# Create model from given kwargs.
model = DIQTCascadeSystem(**model_kwargs)

# Build up model summary by feeding in some random input. Also initializes the lazy conv
# channel sizes, makes it easier to read the logs.
# Fetch one training sample only once (indexing the dataset may be expensive); only
# its tensors' shapes/dtypes/devices are used, through `randn_like`.
summary_sample = train_dataset[0]
lr_input_key = "lr_log_euclid" if params.use_log_euclid else "lr_dti"
amp_enabled = params.use_half_precision_float
with torch.no_grad():
    with torch.cuda.amp.autocast(enabled=amp_enabled):
        rx = torch.randn_like(summary_sample[lr_input_key]).repeat(
            params.train.batch_size, 1, 1, 1, 1
        )
        rx_anat = torch.randn_like(summary_sample["anat"]).repeat(
            params.train.batch_size, 1, 1, 1, 1
        )

        model.net(
            rx, rx_anat, transform_x=False, transform_x_anat=False, transform_y=False
        )
        model_summary = torchinfo.summary(
            model.net,
            input_data=[rx, rx_anat],
            # batch_dim=0,
            col_names=("output_size", "num_params", "kernel_size"),
            col_width=30,
            depth=10,
            row_settings=("depth", "var_names"),
            device=device,
            verbose=0,
            transform_x=False,
            transform_x_anat=False,
            transform_y=False,
        )
del summary_sample

with open(log_txt_file, "a+") as f:
    f.write(f"Model overview: \n{model}\n\n")
    f.write("torchinfo Model Summary: \n\n")
    f.write(str(model_summary))
    f.write("\n\n")

lr_monitor = pl.callbacks.LearningRateMonitor("step")
device_monitor = pl.callbacks.DeviceStatsMonitor()

half_precision_kwargs = dict()
if params.use_half_precision_float:
    half_precision_kwargs["precision"] = 16
    half_precision_kwargs["amp_backend"] = "native"

# Create trainer object.
trainer = pl.Trainer(
    # fast_dev_run=10,
    gpus=[dev_idx],
    accelerator="gpu",
    enable_checkpointing=False,
    max_epochs=params.train.max_epochs,
    logger=pl_logger,
    log_every_n_steps=50,
    # run validation every 0.5 epochs
    val_check_interval=0.5,
    # max_time={"hours": 4, "minutes": 30},
    # callbacks=[device_monitor],
    benchmark=True,
    enable_progress_bar=params.progress_bar,
    # track_grad_norm=2,
    gradient_clip_val=params.train.grad_2norm_clip_val,
    accumulate_grad_batches=params.train.accumulate_grad_batches,
    **half_precision_kwargs,
)

# Many warnings are produced here, so it's better for my sanity (i.e., worse in every other
# way) to just filter and ignore them...
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
try:
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
except Exception:
    # Record any fatal training error with its full traceback (not only the message)
    # in the experiment log, then re-raise it unchanged.
    with open(log_txt_file, "a+") as f:
        f.write("\n")
        f.write("!!!!! Fatal ERROR !!!!!!!\n")
        f.write("Traceback:\n")
        f.write(traceback.format_exc())
        f.write("!!!!!!!!!!!!!!!!!!!!!!!!!\n")
    raise

train_duration = datetime.datetime.now().replace(microsecond=0) - train_start_timestamp
print(f"Train duration: {train_duration}")

In [None]:
# Save out a full Lightning checkpoint of the trained model.
checkpoint_file = experiment_results_dir / "model.ckpt"
trainer.save_checkpoint(str(checkpoint_file))

In [None]:
# Also save the bare PyTorch state dict of the trained model.
state_dict_file = experiment_results_dir / "model_state_dict.pt"
torch.save(model.state_dict(), state_dict_file)

In [None]:
# Append the total training time to the experiment log, both as the raw timedelta and
# broken down into days/hours/minutes/seconds.
with open(log_txt_file, "a+") as f:
    f.write("\n")
    f.write(f"Training time: {train_duration}\n")
    # `timedelta` only stores whole days plus the remaining (sub-day) seconds, so the
    # hours/minutes/seconds are derived from `.seconds`.
    train_hours, train_rem_seconds = divmod(train_duration.seconds, 3600)
    train_minutes, train_seconds = divmod(train_rem_seconds, 60)
    f.write(
        f"\t{train_duration.days} Days, "
        + f"{train_hours} Hours, "
        + f"{train_minutes} Minutes, "
        + f"{train_seconds} Seconds\n"
    )

In [None]:
# Plot a rolling (moving-window) mean of the recorded training loss values.
train_losses = np.asarray(model.plain_log["train_loss"], dtype=float)
# Clamp the window to the number of recorded losses: when the signal is shorter than
# the kernel, `np.convolve(..., "valid")` swaps its operands and returns a
# meaningless curve.
window = min(1000, train_losses.size)
rolling_start = 100

plt.figure(dpi=110)
if window == 0:
    print("No training losses were recorded; nothing to plot.")
else:
    rolling_mean = np.convolve(train_losses, np.ones(window), "valid") / window
    # rolling_mean[k] averages train_losses[k : k + window]; plot it at the index of
    # the last loss inside that window.
    plot_idx = np.arange(rolling_start, rolling_mean.size)
    plt.plot(plot_idx + window - 1, rolling_mean[rolling_start:])
    plt.title("Training Loss " + params.train.loss_name + f"\nRolling Mean {window}")
    # NOTE(review): assumes one `train_loss` entry is logged per training step.
    plt.xlabel("Step")
    plt.ylabel("Loss")
    # plt.ylim(0, 1)
    print(np.median(rolling_mean))
    # Summary statistics of the earliest losses (before the plotted region).
    initial_losses = train_losses[: window + rolling_start]
    print(np.mean(initial_losses), np.var(initial_losses), np.max(initial_losses))

plt.savefig(experiment_results_dir / "train_loss.png")

## Model Testing/Evaluation

### Testing Loop

In [None]:
# mod_state = model.state_dict()

# test_mod = DIQTCascadeSystem(
#     channels=params.n_channels,
#     batch_size=params.train.batch_size,
#     in_patch_shape=params.train.in_patch_size,
#     upscale_factor=params.net.upscale_factor,
#     anat_batch_key=params.data.anat_type,
#     train_loss_method=params.train.loss_name,
#     opt_params=params.optim.kwargs,
#     hparams=hyperparams,
#     **params.net.kwargs,
# )
# test_mod.load_state_dict(mod_state)
# trainer.test(test_mod, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Evaluate the in-memory trained model (not a saved checkpoint) on the test subjects.
trainer.test(model=model, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Direction in which each test metric improves (↓ lower is better, ↑ higher is
# better). Its keys, in order, are also the test metrics collected below.
loss_comparison_directions = {
    "rmse": "↓",
    "nrmse": "↓",
    "rmse_log_euclid": "↓",
    "nrmse_log_euclid": "↓",
    "scaled_psnr": "↑",
    "ssim_fa": "↑",
    "rmse_fa": "↓",
    "nrmse_fa": "↓",
}
test_losses = tuple(loss_comparison_directions)

# Long-format results table: one row per (subject, model, metric).
result_columns = {"subj_id": [], "model": [], "metric": [], "value": []}
for subj_id, metric in itertools.product(test_subjs, test_losses):
    result_columns["subj_id"].append(subj_id)
    # DIQT model
    result_columns["model"].append("diqt")
    result_columns["metric"].append(metric)
    result_columns["value"].append(model.plain_log.test_loss[metric][subj_id])

# Convert to a pandas dataframe.
test_results = pd.DataFrame(result_columns)

with open(log_txt_file, "a+") as f:
    f.write(f"Test loss functions: {list(test_losses)}\n")

test_loss_log_file = experiment_results_dir / "test_loss.csv"
test_results.to_csv(test_loss_log_file, index=False)

### Evaluation Visualization

#### Comparison within experiment

In [None]:
# One subplot per test metric: a violin + per-subject swarm plot for each model, with
# each model's mean marked by a horizontal line and a text label.
with mpl.rc_context(
    {
        "font.size": 8.0,
    }
):
    fig, axs = plt.subplots(
        ncols=len(test_losses),
        sharex=True,
        figsize=(11, 4),
        dpi=130,
        gridspec_kw={"wspace": 1.0, "hspace": 0},
    )
    # `plt.subplots` returns a bare Axes (not an array) when ncols == 1.
    axs = np.atleast_1d(axs)
    sns.despine(fig=fig, top=True, right=True)

    outline_path_effects = [
        mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
        mpl.patheffects.Normal(),
    ]

    for ax, metric_name in zip(axs, test_losses):
        df = test_results.loc[test_results.metric == metric_name]
        sns.violinplot(
            x="model", y="value", data=df, ax=ax, scale="count", inner=None
        )
        ax.grid(axis="y", alpha=0.5)
        points_plot = sns.swarmplot(
            x="model",
            y="value",
            hue="subj_id",
            data=df,
            ax=ax,
            # color="white",
            edgecolor="black",
            size=4,
            linewidth=0.8,
        )
        # Hide the per-subject hue legend (if one was created).
        legend = points_plot.get_legend()
        if legend is not None:
            legend.remove()

        # Mean score per model. Average only the numeric "value" column: calling
        # `.mean()` on the whole frame raises on pandas >= 2.0 because of the
        # string columns.
        means = df.groupby("model")["value"].mean()
        # Make sure the order follows seaborn's x-axis ordering.
        model_order = [tick.get_text() for tick in ax.get_xticklabels()]
        means = means.reindex(model_order)

        # Grab colors corresponding to each model.
        colors = sns.color_palette(None, n_colors=len(df.model.unique()))

        lines = ax.hlines(
            y=means.to_numpy(),
            xmin=np.arange(0, len(colors)) - 0.5 + 0.05,
            xmax=np.arange(1, len(colors) + 1) - 0.5 - 0.05,
            colors=colors,
            lw=1.5,
        )
        lines.set_path_effects(outline_path_effects)

        ax.set_xticklabels(ax.get_xticklabels(), rotation=25)

        # Force a draw before reading the x-limits used to place the mean labels.
        fig.canvas.draw()
        x_left = ax.get_xlim()[0]

        for m, c in zip(means.to_numpy(), colors):
            ax.annotate(
                f"{m:.4g}",
                xy=(x_left + (x_left * 0.4), m),
                xycoords="data",
                color=c,
                ha="right",
                va="center",
                annotation_clip=False,
                fontweight="bold",
                snap=True,
                bbox=dict(
                    boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
                ),
            )
        ax.set_title(
            f"{metric_name.replace('_', ' ')} {loss_comparison_directions[metric_name]}"
        )
        ax.set_ylabel("")
        ax.set_xlabel("")
plt.savefig(experiment_results_dir / "test_result_within_experiment.png")

#### Comparison with other works

In [None]:
# Histogram of the per-subject test RMSE values, with vertical lines marking this
# model's mean RMSE and the reported scores of published comparison models.
subj_rmse_values = list(model.plain_log["test_loss"].rmse.values())

fig, ax_prob = plt.subplots(figsize=(8, 4), dpi=120)
log_scale = False

hist = sns.histplot(
    subj_rmse_values,
    alpha=0.5,
    stat="count",
    log_scale=log_scale,
    ax=ax_prob,
    legend=False,
    hatch="\\\\",
    ec="blue",
)
# Subject counts are whole numbers.
hist.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
plt.xlabel("Loss in $mm^2/second$")

# (x position, legend label, color) of each vertical reference line: this model's
# mean first, then the published baselines.
comparison_kwargs = {"ls": "-", "lw": 2.5}
reference_lines = (
    (np.asarray(subj_rmse_values).mean(), "Current Model Mean", "blue"),
    (
        9.72e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
        "green",
    ),
    (
        9.76e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nBest ESPCN\n[but not really]",
        "purple",
    ),
    # Best performing Blumberg, et. al., 2018 paper model.
    (8.46e-4, "(Blumberg etal, 2018)\nBest Overall", "pink"),
)
for x_pos, line_label, line_color in reference_lines:
    plt.axvline(x_pos, label=line_label, color=line_color, **comparison_kwargs)

plt.legend(fontsize="small")
plt.title("Test Loss Histogram Over All Subjects with Test Metric RMSE")
plt.savefig(experiment_results_dir / "test_rmse_hist.png")

In [None]:
# Bar chart of this model's mean test RMSE over all subjects next to published results.
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)

models = (
    "(Ours)\nCurrent Model",
    "(Tanno etal, 2021)\nC-Spline\nMean Approx.",
    "(Tanno etal, 2021)\nRand. Forest\nMean Approx.",
    "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
    '(Tanno etal, 2017 &\n Blumberg etal, 2018)\n"Best" ESPCN',
    "(Blumberg etal, 2018)\nBest Overall",
)

# Published (upper, lower) RMSE bounds; row i belongs to models[i]. Row 0 is a
# placeholder for our model, whose score is computed directly.
rmse_bounds = np.asarray(
    [
        [0, 0],
        [31.738e-4, 10.069e-4],
        [23.139e-4, 6.974e-4],
        [13.609e-4, 6.212e-4],
        [13.82e-4, 6.29e-4],
        [12.13e-4, 5.58e-4],
    ]
)

rmse_scores = (
    np.asarray(list(model.plain_log["test_loss"].rmse.values())).mean(),
    # Only bounds are known for the two "Mean Approx." entries, so approximate each
    # score by the midpoint of that model's own row of `rmse_bounds`.
    rmse_bounds[1].mean(),
    rmse_bounds[2].mean(),
    9.72e-4,
    9.76e-4,
    8.46e-4,
)

# Bar groups, as index ranges into `models`:
#   [0, ours_end)            our model(s): solid bars with value labels;
#   [ours_end, approx_end)   approximated scores: range bars + dashed score lines;
#   [approx_end, n_models)   reported scores: range bars + solid score lines.
ours_end = 1
approx_end = 3
n_models = len(models)
label_offset = 0.03 * max(rmse_scores)

colors = sns.color_palette("deep", n_colors=len(rmse_scores))

ax.grid(True, axis="y", zorder=1000)
ax.set_axisbelow(True)
# Plot our evaluation scores.
our_bars = ax.bar(
    models[:ours_end],
    rmse_scores[:ours_end],
    color=colors[:ours_end],
    edgecolor="black",
    lw=0.75,
)
ax.bar_label(our_bars, fmt="%.3e")

# Floating bars spanning each prior work's reported (lower, upper) RMSE range.
ax.bar(
    models[ours_end:],
    height=rmse_bounds[ours_end:, 0] - rmse_bounds[ours_end:, 1],
    bottom=rmse_bounds[ours_end:, 1],
    color=colors[ours_end:],
    edgecolor="black",
    lw=0.75,
    alpha=0.8,
)
bar_width = ax.patches[0].get_width()


def _mark_scores(start, end, linestyle):
    """Draw a black score line and text label across the bars of models[start:end]."""
    bar_x = np.arange(start, end)
    scores = rmse_scores[start:end]
    ax.hlines(
        scores,
        xmin=bar_x - bar_width / 2,
        xmax=bar_x + bar_width / 2,
        color="black",
        ls=linestyle,
    )
    for score, x in zip(scores, bar_x - bar_width / 2):
        ax.annotate(f"{score:.3e}", (x, score + label_offset))


# Dashed lines for the "approximate" actual rmse scores.
_mark_scores(ours_end, approx_end, "--")
# Solid lines for scores reported directly by the publications.
_mark_scores(approx_end, n_models, "-")

ax.set_ylim(bottom=0, top=ax.get_ylim()[1] * 1.1)
ax.set_xlabel("Model")
ax.set_ylabel("Loss in $mm^2/second$")

ax.set_title("Mean Over Subjects Test Loss RMSE")
ax.set_xticks(models)
ax.set_xticklabels(
    models, fontsize="x-small", rotation=25, ha="right", rotation_mode="anchor"
)
plt.savefig(experiment_results_dir / "test_rmse_bar.png")

In [None]:
# Print every test subject's RMSE from best (lowest) to worst, then the mean RMSE.
# Sorting the (subj_id, rmse) pairs directly avoids rebuilding the item list for every
# index, which made the previous argsort-based version quadratic.
test_rmse = model.plain_log.test_loss.rmse
sorted_test_results = dict(sorted(test_rmse.items(), key=lambda item: item[1]))
ppr(sorted_test_results, sort_dicts=False)
print(np.mean(list(test_rmse.values())))

In [None]:
diqt_results = test_results.loc[test_results.model == "diqt"]

# Log the per-subject distribution of each DIQT test metric as a histogram named
# "test/<metric>_dist".
for hist_metric in (
    "rmse",
    "nrmse",
    "rmse_log_euclid",
    "nrmse_log_euclid",
    "scaled_psnr",
    "ssim_fa",
    "rmse_fa",
    "nrmse_fa",
):
    metric_rows = diqt_results.loc[diqt_results.metric == hist_metric]
    logger.add_histogram(f"test/{hist_metric}_dist", np.asarray(metric_rows.value))

## Whole-Volume Visualization

### Setup

In [None]:
# Debug flag(s)
# When True, the cells below skip writing prediction NIfTI/zip files and saving the
# visualization figures into `experiment_results_dir`.
disable_fig_save: bool = False

In [None]:
# Collect, per test subject, the full 3D volumes used for visualization: the masked
# full-res ground truth, low-res input, full-res mask, the masked prediction, the
# pre-anat prediction, and the absolute prediction error.


def _to_viz_array(name, vol):
    """Return `vol` as a NumPy array: bool for the "mask" volume, float otherwise."""
    if torch.is_tensor(vol):
        vol = vol.detach().cpu().numpy()
    return vol.astype(bool) if name == "mask" else vol.astype(float)


results_viz = Box(default_box=True)
with torch.no_grad():
    for subj in test_dataset:
        subj_id = subj["subj_id"]
        s = Box(default_box=True)
        print(f"Starting subject {subj_id}")

        # Ground-truth volumes, cropped by the network's `crop_full_output`
        # (presumably to the prediction's spatial extent).
        s.mask = model.net.crop_full_output(subj["mask"])
        s.dti = model.net.crop_full_output(subj["dti"])
        s.affine = subj["dti_meta_dict"]["affine"]
        s.lr_dti = subj["lr_dti"]

        # Predictions stored by the test loop.
        s.pred = model.plain_log.viz.test_preds[subj_id]
        s.pred_pre_anat = model.plain_log.viz.test_preds_pre_anat[subj_id]

        # Zero out voxels outside the mask before computing the error.
        s.dti = s.dti * s.mask
        s.pred = s.pred * s.mask
        s.metrics.rmse = model.plain_log.test_loss.rmse[subj_id]
        s.abs_error = torch.abs(s.pred - s.dti)

        for vol_name in ("mask", "dti", "lr_dti", "pred", "pred_pre_anat", "abs_error"):
            s[vol_name] = _to_viz_array(vol_name, s[vol_name])

        results_viz[subj_id] = s
        print(f"Finished subject {subj_id}")

In [None]:
# Save out all network predictions (and the pre-anat predictions) to NIfTI-2 files and
# move each set into its own zip archive.


def _zip_nifti_volumes(viz_results, out_dir, volume_key, file_suffix, zip_name):
    """Write each subject's `volume_key` volume as NIfTI-2, then move them into a zip.

    Each volume is saved with its subject's affine to
    ``out_dir / f"{subj_id}_{file_suffix}.nii.gz"``. It is then added to
    ``out_dir / zip_name`` (overwritten if present), and the loose file is deleted.
    """
    img_files = list()
    for subj_id, viz in viz_results.items():
        img_file = out_dir / f"{subj_id}_{file_suffix}.nii.gz"
        nib.save(nib.Nifti2Image(viz[volume_key], viz.affine), str(img_file))
        img_files.append(img_file)

    with zipfile.ZipFile(out_dir / zip_name, "w") as fzip:
        for img_file in img_files:
            fzip.write(
                img_file,
                arcname=img_file.name,
                compress_type=zipfile.ZIP_DEFLATED,
                compresslevel=7,
            )
            # The archive now holds a copy, so remove the loose file.
            os.remove(img_file)


if not disable_fig_save:
    _zip_nifti_volumes(
        results_viz,
        experiment_results_dir,
        "pred",
        "predicted_dti",
        "predicted_dti.zip",
    )
    # And the pre-anat predictions.
    _zip_nifti_volumes(
        results_viz,
        experiment_results_dir,
        "pred_pre_anat",
        "predicted_pre_anat_dti",
        "predicted_pre_anat_dti.zip",
    )
    print("Done with files")

In [None]:
# Pick a poorly performing test subject to visualize: the 2nd-worst by RMSE, or the
# worst one when the test set contains a single subject.
sel_rmse = test_results.loc[
    (test_results.model == "diqt") & (test_results.metric == "rmse")
][["subj_id", "value"]]
sel_rmse = sel_rmse.sort_values("value")
# A bare `iloc[-2]` would raise IndexError with fewer than two test subjects.
bad_rmse = sel_rmse.iloc[-min(2, len(sel_rmse))]
viz_subj_id = bad_rmse.subj_id
print(viz_subj_id, bad_rmse.value)
viz_subj = results_viz[viz_subj_id]

In [None]:
# Select the voxel indices at which each channel-first volume is sliced for plotting.
dti_shape = np.asarray(viz_subj.dti.shape[1:])
lr_dti_shape = np.asarray(viz_subj.lr_dti.shape[1:])

# The last dimension (sagittal) is sliced slightly off-center: the longitudinal
# fissure at the exact center has few fibers outside the corpus callosum. The offset
# is given in full-res voxels and scaled down for the low-res volume.
last_dim_offset = 6
viz_idx = dti_shape // 2
viz_idx[2] += last_dim_offset
viz_lr_idx = lr_dti_shape // 2
viz_lr_idx[2] += last_dim_offset // params.data.downsampled_by_factor


def _slices_through(center):
    """Return channel-first index tuples slicing through `center` on each spatial axis."""
    return [
        np.s_[:, center[0], :, :],
        np.s_[:, :, center[1], :],
        np.s_[:, :, :, center[2]],
    ]


viz_slice_idx = _slices_through(viz_idx)
viz_lr_slice_idx = _slices_through(viz_lr_idx)

### FA-Weighted Direction Maps

In [None]:
# FA-weighted direction maps of the ground truth, prediction, and low-res input,
# each sliced along all three spatial axes and converted to channels-last images.
volumes_and_slices = (
    (viz_subj.dti, viz_slice_idx),
    (viz_subj.pred, viz_slice_idx),
    (viz_subj.lr_dti, viz_lr_slice_idx),
)
ims = [
    pitn.viz.direction_map(vol[sl], channels_first=True).transpose(1, 2, 0)
    for vol, vol_slices in volumes_and_slices
    for sl in vol_slices
]

dim_labels = ["Axial", "Coronal", "Saggital"]
vol_labels = ["Ground Truth", "Pred", "LR Input"]

grid_rc = {
    "font.size": 12.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
with mpl.rc_context(grid_rc):
    fig = plt.figure(dpi=130, figsize=(7, 7))
    # One row per volume, one column per slicing axis.
    fig = pitn.viz.plot_im_grid(
        ims,
        nrows=len(vol_labels),
        title=f"Subj {viz_subj_id} Prediction Results",
        row_headers=vol_labels,
        col_headers=dim_labels,
        colorbars=None,
        fig=fig,
        interpolation="antialiased",
    )
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "dir_map_pred.png")

### DTI Channel-Wise Visualization

In [None]:
# LaTeX labels for the 6 unique DTI tensor components, in the order xx, xy, yy, xz,
# yz, zz (assumed to match the DTI channel order — confirm against the data loader).
channel_names = [
    f"$D_{{{row},{col}}}$"
    for row, col in (
        ("x", "x"),
        ("x", "y"),
        ("y", "y"),
        ("x", "z"),
        ("y", "z"),
        ("z", "z"),
    )
]

# Slicing-axis labels; the leading newline adds spacing in the plot.
dim_labels = ["\n" + axis_name for axis_name in ("Axial", "Cor.", "Sagg.")]

# Row labels: full-res ground truth, prediction, low-res input, absolute error.
dti_names = ["FR", "Pred", "LR", "Abs Err"]

#### Global Normalization

In [None]:
# Display all 6 DTI channels of the ground truth, prediction, low-res input, and
# absolute error, with one shared colorbar per column.
cmap = "gray"
title = f"DTI Subj {viz_subj_id} Summary"
summary_rc = {
    "font.size": 6.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
# The slice along the last axis is given as a fraction of that axis' length.
last_axis_slice_frac = viz_idx[2] / dti_shape[2]

with mpl.rc_context(summary_rc):
    fig = plt.figure(dpi=150, figsize=(8, 12))
    fig = pitn.viz.plot_vol_slices(
        viz_subj.dti,
        viz_subj.pred,
        viz_subj.lr_dti,
        viz_subj.abs_error,
        slice_idx=(None, None, last_axis_slice_frac),
        title=title,
        vol_labels=dti_names,
        slice_labels=dim_labels,
        channel_labels=channel_names,
        colorbars="col",
        fig=fig,
        cmap=cmap,
        interpolation="antialiased",
    )

if not disable_fig_save:
    plt.savefig(experiment_results_dir / f"DIQT_DTI_sub-{viz_subj_id}_pred_result.png")

#### Channel-Wise Normalization

---

## End Experiment

In [None]:
pl_logger.experiment.flush()
# Debugging experiments are neither finalized nor moved out of their tmp location.
is_debug_experiment = "debug" in EXPERIMENT_NAME.casefold()
if not is_debug_experiment:
    # Close tensorboard logger.
    pl_logger.finalize("success")
    # Experiment is complete; move the results directory to its final location and
    # re-point the log file path into it.
    if experiment_results_dir != final_experiment_results_dir:
        print("Moving out of tmp location")
        experiment_results_dir = experiment_results_dir.rename(
            final_experiment_results_dir
        )
        log_txt_file = experiment_results_dir / log_txt_file.name

In [None]:
# for subj_id in test_subjs[:5]:
#     gt_dti = subj_data[subj_id][0]['dti'][None,]
#     mask = subj_data[subj_id][0]['mask'][None,]
#     pred_dti = model.plain_log.viz.test_preds[subj_id][None]

#     gt_fa = pitn.metrics.fast_fa(gt_dti, mask)
#     pred_fa = pitn.metrics.fast_fa(pred_dti, mask)

#     gt_select = torch.masked_select(gt_fa, mask)
#     pred_select = torch.masked_select(pred_fa, mask)

#     diff = gt_select - pred_select
#     diff = diff.detach().cpu().numpy().flatten()
#     sns.histplot(diff)
#     print((diff == 0).sum())
#     print(diff.mean())
#     print(np.std(diff))

#     plt.title(subj_id)
#     plt.show()

In [None]:
# run_dir = experiment_results_dir.parent / "2022-03-01T06_58_21__pitn_dti_mid_net"
# pred_dir = run_dir / "predicted_dti"
# dti_only_test_subj_files = list(pred_dir.glob('[0-9]*'))
# dti_only_ids = [d.name[:6] for d in dti_only_test_subj_files]
# print(dti_only_ids)

# for subj_id in dti_only_ids[:5]:
#     gt_dti = subj_data[subj_id][0]['dti'][None]
#     mask = subj_data[subj_id][0]['mask'][None]

#     pred_dti = nib.load(list(pred_dir.glob(subj_id + "*"))[0])
#     pred_dti = pred_dti.get_fdata()[None]
#     pred_dti = torch.from_numpy(pred_dti)
#     gt_fa = pitn.metrics.fast_fa(gt_dti, mask)
#     pred_fa = pitn.metrics.fast_fa(pred_dti, mask)

#     gt_select = torch.masked_select(gt_fa, mask)
#     pred_select = torch.masked_select(pred_fa, mask)
#     plt.imshow(gt_fa[0, 0, :, 80] - pred_fa[0, 0, :, 80], cmap='gray')
#     plt.colorbar()
#     plt.show()
#     diff = gt_select - pred_select
#     diff = diff.detach().cpu().numpy().flatten()
#     print((diff == 0).sum())
#     sns.histplot(diff)
#     print(diff.mean())
#     print(np.std(diff))
#     plt.title(subj_id)
#     plt.show()

In [None]:
# gt_fa.shape