# Pain in the Net
Application of Deep Image Quality Transfer (DIQT) with domain adaptation.


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

References:

* `R. Tanno et al., “Uncertainty modelling in deep learning for safer neuroimage enhancement: Demonstration in diffusion MRI,” NeuroImage, vol. 225, p. 117366, Jan. 2021, doi: 10.1016/j.neuroimage.2020.117366.`
* `D. C. Alexander et al., “Image quality transfer and applications in diffusion MRI,” NeuroImage, vol. 152, pp. 283–298, May 2017, doi: 10.1016/j.neuroimage.2017.02.089.`
* `S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import matplotlib.pyplot as plt
import seaborn as sns
import ipyplot

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
from addict import Addict
import pprint
from pprint import pprint as ppr
import box
from box import Box

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Global matplotlib defaults, applied in a single update.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Set print options for ndarrays/tensors; both libraries share the same
# summarization threshold and line width.
_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, **_print_opts)

In [None]:
# Update the notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv to be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Form the command to be run in direnv's context. This command will print out
# all environment variables defined in the subprocess. It is passed as an argument
# list (no shell), so a working directory containing spaces or shell metacharacters
# is passed through verbatim instead of being re-parsed by a shell.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    # Run command in a new subprocess.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # Best-effort: without direnv the notebook keeps its current environment.
    warnings.warn("direnv executable not found; environment variables were not updated.")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            + "environment variables may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# torch setup
# Use CUDA when it is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Activate cudnn benchmarking to optimize convolution algorithm speed.
if use_cuda and torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
    print("CuDNN convolution optimization enabled.")
# To keep the device as the cpu regardless of CUDA availability:
# device = torch.device('cpu')
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Everything this cell prints is bound to the name `cap`; later cells print it and
# write it into the experiment log file.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():
    # GPU information
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
        print(gpu_info)
    # NOTE(review): only NameError triggers the fallback below; a missing
    # `get_gpu_specs` attribute would raise AttributeError instead -- confirm intent.
    except NameError:
        print("CUDA Version: ", torch.version.cuda)
else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` is bound by the `%%capture --no-stderr cap` cell magic of the specs cell;
# printing it displays the environment/hardware info that was captured there.
print(cap)

Author: Tyler Spears

Last updated: 2022-02-07T20:28:14.236511+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-91-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: 61be6a52797624cf59758b441ee67f9ab3096d94

torchio          : 0.18.30
pytorch_lightning: 1.5.9
dipy             : 1.4.1
natsort          : 7.1.1
pandas           : 1.2.3
torch            : 1.10.0
seaborn          : 0.11.1
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
scipy            : 1.5.3
nibabel          : 3.2.1
skimage          : 0.18.1
numpy            : 1.20.2
box              : 5.4.1
matplotlib       : 3.4.1
pitn             : 0.0.post1.dev132+g02c0d1a
monai            : 0.8.0
json             : 2.0.9

  id  Name                Driver Version    CUDA Version  Total Memory    uuid
----  ----------------  ----------------  --------------

### Data Variables & Definitions Setup

In [None]:
def _existing_dir_from_env(var_name: str) -> pathlib.Path:
    """Return the path stored in environment variable ``var_name``.

    Raises ``KeyError`` when the variable is unset and ``AssertionError`` when the
    path it names does not exist.
    """
    directory = pathlib.Path(os.environ[var_name])
    assert directory.exists()
    return directory


# Set up directories; each one is read from the environment and must already exist.
data_dir = _existing_dir_from_env("DATA_DIR")
write_data_dir = _existing_dir_from_env("WRITE_DATA_DIR")
results_dir = _existing_dir_from_env("RESULTS_DIR")
tmp_results_dir = _existing_dir_from_env("TMP_RESULTS_DIR")

## Parameter Reading & Experiment Setup

### Parameters

<div class="alert alert-block alert-info"> <b>NOTE</b> Here are all the parameters! This makes it easy to find them! </div>

In [None]:
# `default_box=True` lets nested sections (params.data, params.net, ...) be created
# on first attribute assignment. This behavior is removed again once all params
# have been read in.
params = Box(default_box=True)

# General experiment-wide params
###############################################
# Used as the suffix of the timestamped experiment/run name.
params.experiment_name = "debug_rmse_metrics"
###############################################
# 6 channels for the 6 DTI components
params.n_channels = 6
# Number of subjects sampled from the list of candidate subject IDs.
params.n_subjs = 16
# Low-res and full-res voxel sizes; these also select the `scale-*mm` data directories.
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25
params.use_anat = True

# Data params
params.data.fr_dir = data_dir / f"scale-{params.fr_vox_size:.2f}mm"
params.data.lr_dir = data_dir / f"scale-{params.lr_vox_size:.2f}mm"
# Glob patterns, resolved inside each subject's directory.
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"
# Name of the anatomical modality; also used as a key for the loaded volume.
params.data.anat_type = "t2w"
params.data.anat_fname_pattern = f"sub-*{params.data.anat_type}.nii.gz"
# The data were downsampled artificially by this factor.
params.data.downsampled_by_factor = params.lr_vox_size / params.fr_vox_size
# Store the factor as an int when it is a whole number, otherwise keep the float.
params.data.downsampled_by_factor = (
    int(params.data.downsampled_by_factor)
    if int(params.data.downsampled_by_factor) == params.data.downsampled_by_factor
    else params.data.downsampled_by_factor
)

# This is the number of voxels to remove (read: center crop out) from the network's
# prediction. This allows for an "oversampling" of the low-res voxels to help inform a
# more constrained HR prediction. This value of voxels will be removed from each spatial
# dimension (D, H, and W) starting at the center of the output patches.
# Ex. A size of 1 will remove the 2 outer-most voxels from each dimension in the output,
# while still keeping the corresponding voxels in the LR input.
params.hr_center_crop_per_side = 0

# Quantile clamping to be done on the outer edge of the mask.
# NOTE: This clamping *will* affect the test and validation scores, as those voxels
# clamped by this are considered as "errors"/"noise" and will be discarded in testing.
# 80,000 is ~ the average volume of the entire brain mask.
params.data.edge_correction_max_vox_to_change = 300
# Alternative edge handling, consulted by the data loading cell only when the edge
# correction above is disabled (falsy). It must be defined explicitly, even when
# unused: once params are frozen with `default_box=False`, reading a missing key
# raises instead of returning an empty default.
params.data.mask_edge_quantile_clamp = False
# Second data scaling method, where the training data will be scaled and possibly clipped,
# but the testing data will be compared on the originals.
# Scale input data by the valid values of each channel of the DTI.
# I.e., Dx,x in [0, 1], Dx,y in [-1, 1], Dy,y in [0, 1], Dy,z in [-1, 1], etc.
# Given as (per-channel lower bounds, per-channel upper bounds).
params.data.dti_scale_range = ((0, -1, 0, -1, -1, 0), (1, 1, 1, 1, 1, 1))
params.data.anat_scale_range = (0, 1)
# (low, high) quantiles handed to the scalers, and whether to clip to them.
params.data.scale_to_quantiles = (0.0001, 0.9999)
params.data.clip_to_quantiles = True

# Network params.
# The network's goal is to upsample the input by this factor.
params.net.upscale_factor = params.data.downsampled_by_factor
# Everything under `params.net.kwargs` is grouped as keyword arguments for the network.
params.net.kwargs.n_res_units = 3
params.net.kwargs.n_dense_units = 3
params.net.kwargs.interior_channels = params.n_channels * 2
params.net.kwargs.activate_fn = F.elu
params.net.kwargs.upsample_activate_fn = F.elu
params.net.kwargs.center_crop_output_side_amt = params.hr_center_crop_per_side

# Optimizer name and its kwargs.
params.optim.name = "AdamW"
params.optim.kwargs.lr = 5e-4
params.optim.kwargs.betas = (0.9, 0.999)

# Testing params
# Fraction (0-1) of the subjects that go into the test set.
params.test.dataset_subj_percent = 0.4

# Validation params
# Fraction (0-1) of the subjects that go into the validation set.
params.val.dataset_subj_percent = 0.2

# Training params
# Spatial shape of the low-resolution input patches.
params.train.in_patch_size = (24, 24, 24)
params.train.batch_size = 32
params.train.samples_per_subj_per_epoch = 8000
params.train.max_epochs = 50
params.train.loss_name = "mse"
# Fraction (0-1) of subjs in the dataset that go into the training set: whatever is
# left over after the test and validation fractions.
params.train.dataset_subj_percent = 1 - (
    params.test.dataset_subj_percent + params.val.dataset_subj_percent
)

# Learning rate scheduler config.
params.train.lr_scheduler.name = "OneCycleLR"
params.train.lr_scheduler.kwargs.max_lr = 3e-3
params.train.lr_scheduler.kwargs.epochs = params.train.max_epochs
# Subject counts per split; test and validation are rounded up, training takes the
# remainder.
num_test_subjs = int(np.ceil(params.n_subjs * params.test.dataset_subj_percent))
num_val_subjs = int(np.ceil(params.n_subjs * params.val.dataset_subj_percent))
num_train_subjs = params.n_subjs - (num_test_subjs + num_val_subjs)
# Round *up*: the training DataLoader does not set `drop_last`, so a final partial
# batch still counts as one optimizer step. Floor division would under-count the
# steps whenever the number of samples is not a multiple of the batch size.
# NOTE(review): assumes the sampler yields `samples_per_subj_per_epoch` samples for
# each training subject -- confirm against the sampler implementation.
params.train.lr_scheduler.kwargs.steps_per_epoch = math.ceil(
    params.train.samples_per_subj_per_epoch * num_train_subjs / params.train.batch_size
)

# If a config file exists, override the defaults with those values.
# The file is taken from $PITN_CONFIG when that variable is set; otherwise a unique
# `config.*` match in the current directory is used. This step is best-effort: the
# defaults above stay in place when no usable config file is found.
config_fname = None
try:
    if "PITN_CONFIG" in os.environ:
        config_fname = Path(os.environ["PITN_CONFIG"])
    else:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
except Exception:
    # Presumably raised when there is no unique `config.*` match (TODO confirm the
    # exact exception type); running without a config file is a normal case.
    config_fname = None

if config_fname is not None:
    _config_loaders = {
        "yaml": Box.from_yaml,
        "yml": Box.from_yaml,
        "json": Box.from_json,
        "toml": Box.from_toml,
    }
    # `Path.suffix` keeps the leading dot (".yaml"), so strip it before the lookup.
    f_type = config_fname.suffix.casefold().lstrip(".")
    try:
        if f_type not in _config_loaders:
            raise RuntimeError(f"Unsupported config file type '{config_fname.suffix}'")
        # The first positional parameter of Box.from_* is the document *content*;
        # a path has to be passed through `filename=`.
        f_params = _config_loaders[f_type](filename=config_fname)
        params.merge_update(f_params)
    except Exception as e:
        # Report the problem instead of silently swallowing it, but keep going
        # with the default params.
        warnings.warn(
            f"Could not load config file '{config_fname}', using default params: {e!r}"
        )

# Remove the default_box behavior now that params have been fully read in.
p = Box(default_box=False)
p.merge_update(params)
params = p

In [None]:
# Choose a subset of all params as the hyperparams of the model.
# The anatomical entry is the modality name when anatomical input is enabled,
# and False otherwise.
if params.use_anat:
    anat_hparam = params.data.anat_type
else:
    anat_hparam = False

hyperparams = Box(
    {
        "batch_size": params.train.batch_size,
        "samples_per_subj_epoch": params.train.samples_per_subj_per_epoch,
        "epochs": params.train.max_epochs,
        "loss_fn": params.train.loss_name,
        "lr_scheduler": params.train.lr_scheduler.name,
        "optim": params.optim.name,
        "anat": anat_hparam,
        "n_subjs": params.n_subjs,
    }
).to_dict()

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = params.experiment_name

# Second-resolution ISO timestamp, with the colons replaced because many programs
# don't like having ':' in a filename.
ts = datetime.datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
experiment_name = "__".join((ts, EXPERIMENT_NAME))
run_name = experiment_name

# Results are first written to a temporary directory, in case the experiment does
# not finish. Only keep up to N tmp results (counting the one about to be created).
n_tmp_to_keep = 5
n_old_tmp_to_keep = n_tmp_to_keep - 1
tmp_dirs = list(tmp_results_dir.glob("*"))

if len(tmp_dirs) > n_old_tmp_to_keep:
    print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
    # Natural sort puts the timestamped names oldest-first; drop everything except
    # the newest entries.
    sorted_tmp_names = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])
    tmps_to_delete = sorted_tmp_names[:-n_old_tmp_to_keep]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

# Working directory for this run.
experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# Show the timestamped experiment name used for this run's results directories.
print(experiment_name)

2022-02-07T20_28_14__debug_rmse_metrics


In [None]:
# Pass this object into the pytorchlightning Trainer object, for easier logging within
# the training/testing loops.
_tb_logger_options = dict(
    name=experiment_name,
    version="",
    log_graph=False,
    default_hp_metric=False,
)
pl_logger = pl.loggers.TensorBoardLogger(tmp_results_dir, **_tb_logger_options)
# Use the lower-level logger for logging histograms, images, etc.
logger = pl_logger.experiment

# Create a separate txt file to log streams of events & info besides parameters & results.
log_txt_file = Path(logger.log_dir) / "log.txt"
with open(log_txt_file, "a+") as f:
    log_header = "".join(
        (
            f"Experiment Name: {experiment_name}\n",
            f"Timestamp: {ts}\n",
            # Parameters.
            pprint.pformat(params.to_dict()) + "\n",
            # cap is defined in an ipython magic command
            f"Environment and Hardware Info:\n {cap}\n\n",
        )
    )
    f.write(log_header)

## Data Loading

### Subject ID Selection

In [None]:
# Find data directories for each subject.
subj_dirs: Box = Box()

selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
sampled_subjs = random.sample(selected_ids, params.n_subjs)
is_subset = len(sampled_subjs) < len(selected_ids)
if is_subset:
    warnings.warn(
        f"WARNING: Sub-selecting {len(sampled_subjs)}/{len(selected_ids)} "
        + "participants for dev and debugging. "
        + f"Subj IDs selected: {sampled_subjs}"
    )
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# The sampled subjects are processed in natural sort order.
selected_subjs = natsorted(sampled_subjs)

_find_unique = pitn.utils.system.get_file_glob_unique
for subj_id in selected_subjs:
    # Each subject needs exactly one matching directory per resolution.
    subj_glob = f"*{subj_id}*"
    fr_subj_dir = _find_unique(params.data.fr_dir, subj_glob)
    lr_subj_dir = _find_unique(params.data.lr_dir, subj_glob)
    subj_dirs[subj_id] = Box(fr=fr_subj_dir, lr=lr_subj_dir)
    assert fr_subj_dir.exists()
    assert lr_subj_dir.exists()
ppr(subj_dirs)

{'103010': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-103010'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-103010')},
 '135528': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-135528'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-135528')},
 '140117': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-140117'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-140117')},
 '141422': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-141422'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-141422')},
 '156637': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/su

In [None]:
# Record the subjects actually used in this run (`selected_subjs`), rather than the
# full candidate list (`selected_ids`): the two differ whenever only a sub-sample of
# the candidates was drawn.
with open(log_txt_file, "a+") as f:
    f.write(f"Selected Subjects: {selected_subjs}\n")

logger.add_text("subjs", pprint.pformat(selected_subjs))

### Loading and Preprocessing

In [None]:
# Prep for Dataset loading.

# Data reader object for NIFTI files.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# HR -> LR patch coordinate conversion function, keyed by the name of the secondary
# (low-resolution) volume it applies to.
_lr_dti_coords_fn = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=params.train.in_patch_size,
)
fr2lr_patch_coords_fn = {"lr_dti": _lr_dti_coords_fn}

# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes.
# The HR patch shape is the LR input patch shape scaled by the downsampling factor,
# rounded down to whole voxels.
_lr_patch_shape = np.asarray(params.train.in_patch_size)
_hr_patch_shape = np.floor(
    _lr_patch_shape * params.data.downsampled_by_factor
).astype(int)
patch_kwargs = {
    "patch_shape": tuple(_hr_patch_shape),
    "stride": 1,
    "meta_keys_to_patch_index": {"dti", "mask", params.data.anat_type},
    "mask_name": "mask",
}


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Crop or pad a full-resolution volume to a target spatial shape.

    Args:
        fr_vol: Volume whose dimensions after the first are the spatial dimensions.
        fr_affine: Affine matrix that belongs to ``fr_vol``.
        target_spatial_shape: Desired spatial shape.

    Returns:
        ``(volume, affine)``. When the spatial shape already matches, both inputs are
        returned unchanged; otherwise the cropped/padded volume and the affine
        reported by torchio for it.
    """
    wanted_shape = tuple(np.asarray(target_spatial_shape))
    # Nothing to fix: hand the inputs straight back.
    if not (fr_vol.shape[1:] != wanted_shape):
        return fr_vol, fr_affine

    # Use torchio objects because they fix the affine matrix, too.
    # Flip before the transform to pad on the right/top/furthest side of each
    # dimension first, before the left/bottom/closest.
    spatial_dims = [1, 2, 3]
    flipped_im = torchio.ScalarImage(tensor=fr_vol.flip(spatial_dims), affine=fr_affine)
    crop_or_pad = torchio.transforms.CropOrPad(target_spatial_shape, 0, copy=False)
    fixed_im = crop_or_pad(flipped_im)
    # Unflip the data to restore the original orientation.
    result_vol = fixed_im["data"].flip(spatial_dims)
    result_aff = fixed_im["affine"]
    return result_vol, result_aff


def orient_to_viz(vol, affine):
    """Rotate a channel-first volume for visualization and adjust its affine.

    The volume is rotated by 90 degrees in the plane of axes (1, 3) and then by 180
    degrees in the plane of axes (2, 3). The affine is left-multiplied by a
    homogeneous rotation built from a 90-degree rotation about [0, 1, 0] and a
    180-degree rotation about [1, 0, 0].

    Args:
        vol: Tensor or ndarray with at least 4 dimensions. Tensors are processed
            through numpy and returned as a tensor with ``vol``'s dtype and device.
        affine: Square affine matrix as an ndarray.

    Returns:
        ``(rotated_volume, new_affine)``.
    """
    was_tensor = torch.is_tensor(vol)
    arr = vol.detach().cpu().numpy() if was_tensor else vol

    arr = np.rot90(arr, k=1, axes=(1, 3))
    arr = np.rot90(arr, k=2, axes=(2, 3))
    if was_tensor:
        # Copy first: rot90 returns a view, which `from_numpy` may not accept as-is.
        arr = torch.from_numpy(np.copy(arr)).to(vol)

    # Adjust the affine matrix.
    # 90 degree rot around the second axis, 180 degree rot around the first axis.
    quat_second_axis = nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0])
    quat_first_axis = nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0])
    combined_quat = nib.quaternions.mult(quat_second_axis, quat_first_axis)

    # Embed the 3x3 rotation in a homogeneous matrix shaped like `affine`.
    homog_rot = np.zeros_like(affine)
    homog_rot[:-1, :-1] = nib.quaternions.quat2mat(combined_quat)
    homog_rot[-1, -1] = 1.0

    return arr, homog_rot @ affine

#### Data Loading

In [None]:
# Import and organize all data.
# For every subject: load the LR DTI, then the FR DTI, mask, and anatomical volume;
# fix FR/LR shape mismatches; reorient each volume for visualization; optionally
# correct edge noise and rescale; then wrap everything in a patch dataset that is
# stored in `subj_data` under the subject ID.
subj_data: dict = dict()

meta_keys_to_keep = {"affine", "original_affine"}
for subj_id, subj_dir in subj_dirs.items():

    data = dict()
    data["subj_id"] = subj_id
    fr_subj_dir = subj_dirs[subj_id]["fr"]
    lr_subj_dir = subj_dirs[subj_id]["lr"]
    data["fr_subj_dir"] = fr_subj_dir
    data["lr_subj_dir"] = lr_subj_dir

    # Low-resolution DTI.
    lr_dti_f = pitn.utils.system.get_file_glob_unique(
        lr_subj_dir, params.data.dti_fname_pattern
    )
    im = nib_reader.read(lr_dti_f)
    lr_dti, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    lr_dti = torch.from_numpy(lr_dti)
    lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
    data["lr_dti"] = lr_dti
    # Store raw LR DTI for cubic spline comparisons.
    # NOTE(review): this is the same tensor object as data["lr_dti"] at this point;
    # it only stays "raw" if later steps do not modify it in place -- confirm.
    data["raw_lr_dti"] = lr_dti
    data["lr_dti_meta_dict"] = meta

    # May need to handle shape errors when re-upscaling back from LR to HR.
    lr_dti_shape = np.asarray(lr_dti.shape[1:])
    target_fr_shape = np.floor(lr_dti_shape * params.net.upscale_factor).astype(int)

    # Full-resolution images/volumes.
    # DTI.
    dti_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.dti_fname_pattern
    )
    im = nib_reader.read(dti_f)
    dti, meta = nib_reader.get_data(im)
    dti = torch.from_numpy(dti)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    dti, meta["affine"] = fix_downsample_shape_errors(
        dti, meta["affine"], target_fr_shape
    )
    dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
    data["dti"] = dti
    # Store raw DTI for validation and testing.
    data["raw_dti"] = dti
    data["dti_meta_dict"] = meta

    # Diffusion mask.
    mask_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.mask_fname_pattern
    )
    im = nib_reader.read(mask_f)
    mask, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    mask = torch.from_numpy(mask)
    # Add channel dim if not available.
    if mask.ndim == 3:
        mask = mask[
            None,
        ]
    mask, meta["affine"] = fix_downsample_shape_errors(
        mask, meta["affine"], target_fr_shape
    )
    mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
    mask = mask.bool()
    data["mask"] = mask
    data["mask_meta_dict"] = meta

    # Anatomical/structural volume.
    anat_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.anat_fname_pattern
    )
    im = nib_reader.read(anat_f)
    anat, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    anat = torch.from_numpy(anat)
    # Add channel dim if not available.
    if anat.ndim == 3:
        anat = anat[
            None,
        ]
    anat, meta["affine"] = fix_downsample_shape_errors(
        anat, meta["affine"], target_fr_shape
    )
    anat, meta["affine"] = orient_to_viz(anat, meta["affine"])
    data[params.data.anat_type] = anat
    data[params.data.anat_type + "_meta_dict"] = meta

    # Consider this as the "noise correction" step to have more understandable,
    # stable, and consistent metric results. Otherwise, metrics can change by orders of
    # magnitude for no good reason!
    if params.data.edge_correction_max_vox_to_change:

        correct_dti = pitn.data.norm.correct_edge_noise_with_median(
            data["dti"],
            data["mask"],
            max_num_vox_to_change=params.data.edge_correction_max_vox_to_change,
            erosion_st_elem=skimage.morphology.ball(4),
            median_st_elem=skimage.morphology.cube(2),
        )
        print("DTI Correction Mean Abs. Error Per Tensor Component: ")
        abs_diff = torch.abs(data["dti"] - correct_dti)
        # Per-channel absolute change, averaged over the number of mask voxels.
        # The channel count comes from the params instead of a hard-coded 6.
        mae = abs_diff.view(params.n_channels, -1).sum(1)
        mae = mae / data["mask"].sum()
        mae_str = ""

        for c in range(params.n_channels):
            mae_v = mae.detach().cpu()[c]
            # iqr = torch.quantile(data["dti"][c][data['mask'][0]], torch.as_tensor([0.25, 0.75]))
            # iqr = torch.abs(iqr[1] - iqr[0]).item()
            num_changed = (abs_diff[c] > 1e-13).sum().item()
            med = torch.median(data["dti"][c][data["mask"][0]]).item()
            s = (
                "\t"
                + str(mae_v.item())
                + f" with {num_changed} changes vs. median "
                + str(med)
                + "\n"
            )
            mae_str = mae_str + s

        print(mae_str)
        data["dti"] = correct_dti

    # `.get()` because params no longer create missing keys on access at this
    # point; an undefined clamp setting simply means "disabled".
    elif params.data.get("mask_edge_quantile_clamp", False):
        data["dti"] = pitn.data.norm.mask_constrain_clamp(
            data["dti"],
            data["mask"],
            quantile_clamp=params.data.mask_edge_quantile_clamp,
            selection_st_elem=skimage.morphology.ball(2),
        )

        # cheap_lr_mask = F.interpolate(
        #     data["mask"][
        #         None,
        #     ].float(),
        #     size=tuple(data["lr_dti"].shape[1:]),
        #     mode="nearest",
        # ).bool()[0]
        # data["lr_dti"] = pitn.data.norm.mask_constrain_clamp(
        #     data["lr_dti"],
        #     cheap_lr_mask,
        #     quantile_clamp=params.data.mask_edge_quantile_clamp,
        #     selection_st_elem=skimage.morphology.ball(1),
        # )

    # Perform scaling of input data?
    # The FR and LR DTIs each get their own stateful scaler, which is stored next to
    # the scaled volume.
    if params.data.dti_scale_range:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.dti_scale_range[0],
            params.data.dti_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=params.n_channels,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(data["dti"], stateful=True, keep_orig=False)
        data["dti"] = scaled
        data["dti_scaler"] = scaler

        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.dti_scale_range[0],
            params.data.dti_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=params.n_channels,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(data["lr_dti"], stateful=True, keep_orig=False)
        data["lr_dti"] = scaled
        data["lr_dti_scaler"] = scaler

    # Perform scaling of the anatomical volume?
    if params.data.anat_scale_range:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.anat_scale_range[0],
            params.data.anat_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=1,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(
            data[params.data.anat_type], stateful=True, keep_orig=False
        )
        # *****Disables anat mode refinement*******
        # scaled = (scaled * 0) + 1
        data[params.data.anat_type] = scaled
        data[params.data.anat_type + "_scaler"] = scaler

    # Split the collected entries into patchable volumes and everything else, which
    # is passed along as metadata.
    vol_names = {"dti", "mask", "lr_dti", params.data.anat_type}
    metadata_names = set(data.keys()) - vol_names
    vol_d = {k: data[k] for k in vol_names}
    meta_d = {k: data[k] for k in metadata_names}

    # Create multi-volume dataset for this subj-session.
    subj_dataset = pitn.data.SubjSesDataset(
        vol_d,
        primary_vol_name="dti",
        special_secondary2primary_coords_fns=fr2lr_patch_coords_fn,
        transform=None,
        primary_patch_kwargs=patch_kwargs,
        **meta_d,
    )
    print("Creating patches")
    # Init the patches dataset.
    subj_dataset.patches

    # Finalize this subject.
    subj_data[subj_id] = subj_dataset

    print("=" * 20)

print("===Data Loaded===")

## Model Training

### Set Up Patch-Based Data Loaders

In [None]:
# Data train/validation/test split

num_subjs = len(subj_data)
num_test_subjs = int(np.ceil(params.n_subjs * params.test.dataset_subj_percent))
num_val_subjs = int(np.ceil(params.n_subjs * params.val.dataset_subj_percent))
num_train_subjs = params.n_subjs - (num_test_subjs + num_val_subjs)
subj_list = list(subj_data.keys())
# Randomly shuffle the list of subjects; the shuffled list is then cut into
# consecutive test / validation / training segments.
random.shuffle(subj_list)

if num_subjs == 1:
    # If only 1 subject is available, assume this is a debugging run.
    warnings.warn(
        "DEBUG: Only 1 subject selected, mixing training, validation, and testing sets"
    )
    num_train_subjs = num_val_subjs = num_test_subjs = 1
    test_subjs = val_subjs = train_subjs = subj_list
else:
    # Test subjects come first, validation next, and the remainder is training.
    val_stop = num_test_subjs + num_val_subjs
    test_subjs = subj_list[:num_test_subjs]
    val_subjs = subj_list[num_test_subjs:val_stop]
    train_subjs = subj_list[val_stop:]

# Report the split on screen and in the text log.
split_report = (
    f"{num_train_subjs}/{num_subjs} Training subject ID(s): {train_subjs}",
    f"{num_val_subjs}/{num_subjs} Validation subject ID(s): {val_subjs}",
    f"{num_test_subjs}/{num_subjs} Test subject ID(s): {test_subjs}",
)
for report_line in split_report:
    print(report_line)
with open(log_txt_file, "a+") as f:
    for report_line in split_report:
        f.write(report_line + "\n")

# Same information for tensorboard.
for tag, split_subjs in (
    ("train_subjs", train_subjs),
    ("val_subjs", val_subjs),
    ("test_subjs", test_subjs),
):
    logger.add_text(tag, str(split_subjs))

In [None]:
# Create Datasets and DataLoaders for test, validation, and training steps.
# Every sample field is collated under its own name.
_sample_names = (
    "subj_id",
    "dti",
    "dti_scaler",
    "lr_dti",
    "lr_dti_scaler",
    "mask",
    params.data.anat_type,
    params.data.anat_type + "_scaler",
)
sample_kws = {name: name for name in _sample_names}

# Train
# Training draws patches: one patch dataset per training subject, concatenated and
# sampled in a balanced way across subjects.
train_ds = [subj_data[subj_id].patches for subj_id in train_subjs]
train_dataset = torch.utils.data.ConcatDataset(train_ds)
train_sampler = pitn.samplers.ConcatDatasetBalancedRandomSampler(
    train_dataset.datasets,
    max_samples_per_dataset=params.train.samples_per_subj_per_epoch,
)
train_collate_fn = functools.partial(pitn.samplers.collate_dicts, **sample_kws)
_train_loader_options = dict(
    sampler=train_sampler,
    batch_size=params.train.batch_size,
    collate_fn=train_collate_fn,
    pin_memory=True,
    num_workers=7,
    persistent_workers=True,
)
train_loader = monai.data.DataLoader(train_dataset, **_train_loader_options)

# Test & Validation
# Only need raw DTIs for testing and validation, not training.
tv_sample_kws = dict(sample_kws, raw_dti="raw_dti", raw_lr_dti="raw_lr_dti")
test_val_collate_fn = functools.partial(pitn.samplers.collate_dicts, **tv_sample_kws)


def _subject_level_loader(subj_ids):
    """Build the per-subject dataset list, its concatenation, and a batch-1 loader.

    Unlike training, the subject datasets themselves (not their patches) are used.
    """
    per_subj = [subj_data[sid] for sid in subj_ids]
    concat = torch.utils.data.ConcatDataset(per_subj)
    loader = monai.data.DataLoader(
        concat, collate_fn=test_val_collate_fn, batch_size=1
    )
    return per_subj, concat, loader


test_ds, test_dataset, test_loader = _subject_level_loader(test_subjs)
val_ds, val_dataset, val_loader = _subject_level_loader(val_subjs)

### Model Definition

In [None]:
# Full pytorch-lightning module for contained training, validation, and testing.
# Negative reciprocal of the number of training batches per epoch.
# NOTE(review): presumably a negative value disables probabilistic debug logging --
# confirm against where `debug_prob` is consumed.
_train_batches_per_epoch = (
    params.train.samples_per_subj_per_epoch * len(train_subjs) / params.train.batch_size
)
debug_prob = -1 / _train_batches_per_epoch


class DIQTCascadeSystem(pl.LightningModule):
    """Lightning system that trains, validates, and tests the cascade DTI upsampler.

    Wraps ``pitn.nn.sr.CascadeUpsampleModeRefine`` and provides:

    * a training step with a loss restricted to voxels inside the brain mask,
    * validation/test steps that compare de-scaled predictions against the raw DTI
      (both steps assume a batch size of 1),
    * a spline-interpolation baseline that is evaluated next to the network on the
      test set,
    * ``self.plain_log``, a plain in-memory record of losses and predictions that is
      read by the evaluation cells of the notebook.
    """

    # Specify training loss methods with mappings to their names as strings.
    loss_methods = {
        "MSE".casefold(): torch.nn.MSELoss(reduction="mean"),
        "SSE".casefold(): torch.nn.MSELoss(reduction="sum"),
        "RMSE".casefold(): lambda y_hat, y: torch.sqrt(
            F.mse_loss(y_hat, y, reduction="mean")
        ),
        "L1".casefold(): torch.nn.L1Loss(reduction="mean"),
    }

    def __init__(
        self,
        channels: int,
        batch_size: int,
        in_patch_shape: tuple,
        upscale_factor: int,
        anat_batch_key: str,
        train_loss_method: str,
        val_subj_ids: tuple,
        opt_params: dict,
        lr_scheduler_kwargs: dict,
        net_kwargs: typing.Optional[dict] = None,
        **hyper_parameters,
    ):
        """Build the network, select the training loss, and set up the plain log.

        Args:
            channels: Channel count passed to the network.
            batch_size: Training batch size; only used when logging the train loss.
            in_patch_shape: Shape of the input patches (stored, not otherwise used
                by this class).
            upscale_factor: Upsampling factor of the network and of the spline
                baseline.
            anat_batch_key: Key in each batch that holds the anatomical image given
                to the network as its second input.
            train_loss_method: Name of one of ``loss_methods`` (case-insensitive),
                or a callable ``loss(y_pred, y)``.
            val_subj_ids: Validation subject ids; one is chosen at random to be
                visualized during validation.
            opt_params: Keyword arguments for ``torch.optim.AdamW``.
            lr_scheduler_kwargs: Keyword arguments for
                ``torch.optim.lr_scheduler.OneCycleLR``.
            net_kwargs: Extra keyword arguments for the network constructor.
                ``None`` (the default) means no extra arguments.
            **hyper_parameters: Additional values that are only recorded as
                hyperparameters.

        Raises:
            ValueError: If ``train_loss_method`` is neither a known loss name nor a
                callable.
        """
        super().__init__()
        self.save_hyperparameters(
            "channels",
            "batch_size",
            "train_loss_method",
            *list(hyper_parameters.keys()),
        )
        self._channels = channels
        self._batch_size = batch_size
        self._in_patch_shape = in_patch_shape
        self._anat_batch_key = anat_batch_key
        self._upscale_factor = upscale_factor
        self._val_viz_subj_id = random.choice(val_subj_ids)

        # Parameters
        # Network parameters
        # A `None` default avoids sharing one mutable dict between all instances.
        if net_kwargs is None:
            net_kwargs = dict()
        self.net = pitn.nn.sr.CascadeUpsampleModeRefine(
            self._channels, upscale_factor=self._upscale_factor, **net_kwargs
        )

        ## Training parameters
        self.opt_params = opt_params
        self.lr_scheduler_kwargs = lr_scheduler_kwargs

        # Select loss method as either one of the pre-selected methods, or a custom
        # callable.
        try:
            self._loss_fn = self.loss_methods[train_loss_method.casefold()]
        except (AttributeError, KeyError):
            # AttributeError: not a string (no `.casefold()`); KeyError: unknown name.
            if callable(train_loss_method):
                self._loss_fn = train_loss_method
            else:
                raise ValueError(
                    f"ERROR: Invalid loss function specification {train_loss_method}, "
                    + f"expected one of {self.loss_methods.keys()} or a callable."
                )

        # Index tuples ("fr" and "lr") of the sub-volume plotted during validation;
        # filled in lazily by `validation_step()`.
        self._val_subvol_range = dict()

        # My own dinky logging object.
        self.plain_log = Box(
            {
                "train_loss": list(),
                "val_loss": {"rmse": list(), "nrmse": list()},
                "test_loss": {
                    "spline": {
                        "rmse": dict(),
                        "nrmse": dict(),
                        "scaled_psnr": dict(),
                        "scaled_ssim": dict(),
                        "scaled_ms_ssim": dict(),
                    },
                    "rmse": dict(),
                    "nrmse": dict(),
                    "scaled_psnr": dict(),
                    "scaled_ssim": dict(),
                    "scaled_ms_ssim": dict(),
                },
                "viz": {
                    "spline_preds": dict(),
                    "test_preds": dict(),
                },
            },
        )

    def forward(self, x, x_mode_refine):
        """Run the network on the low-res input ``x`` and the anatomical input."""
        y = self.net(x, x_mode_refine)
        return y

    def training_step(self, batch, batch_idx):
        """Compute the masked training loss for one batch of patches.

        The target and mask are cropped with ``self.net.crop_full_output`` before
        being compared to the network output. The loss is logged to the Lightning
        logger and appended to ``self.plain_log["train_loss"]``.
        """
        batch = Box(batch)
        x = batch.lr_dti
        y = batch.dti
        y = self.net.crop_full_output(y)
        x_mode_refine = batch[self._anat_batch_key]
        mask = batch.mask.bool()
        mask = self.net.crop_full_output(mask)

        y_pred = self.net(x, x_mode_refine)

        # Only calculate loss on voxels in the brain mask.
        loss = self._loss_fn(
            torch.masked_select(y_pred, mask), torch.masked_select(y, mask)
        )

        self.log("train_loss", loss, batch_size=self._batch_size)
        self.plain_log["train_loss"].append(float(loss.cpu()))
        return loss

    def validation_step(self, batch, batch_idx):
        """Score one validation subject and, for one chosen subject, plot slices.

        RMSE and min-max normalized RMSE are computed inside the brain mask between
        the de-scaled prediction and the raw DTI, then logged. For the subject
        selected in ``__init__``, the central portion of the target, prediction,
        and input is plotted and sent to the logger as a figure.

        Returns:
            The masked RMSE for this batch.
        """

        batch = Box(batch)
        # Assume batch size of 1 for the validation set.
        subj_id = batch.subj_id[0]
        x = batch.lr_dti
        y = batch.raw_dti
        y = self.net.crop_full_output(y)
        mask = batch.mask.bool()
        mask = self.net.crop_full_output(mask)
        x_mode_refine = batch[self._anat_batch_key]

        y_pred = self.net(x, x_mode_refine)

        # Calculate val loss on un-scaled output.
        if batch.lr_dti_scaler:
            # De-scale the single volume in the batch, then restore the batch dim.
            y_pred = batch.lr_dti_scaler[0].descale(y_pred[0])[
                None,
            ]

        rmse_loss = torch.sqrt(
            F.mse_loss(
                torch.masked_select(y_pred, mask),
                torch.masked_select(y, mask),
                reduction="mean",
            )
        )
        nrmse_loss = pitn.metrics.minmax_normalized_rmse(
            torch.masked_select(y_pred, mask).view(y.shape[0], y.shape[1], -1),
            torch.masked_select(y, mask).view(y.shape[0], y.shape[1], -1),
            reduction="mean",
        )
        self.log("val_loss/nrmse", nrmse_loss)
        self.log("val_loss/rmse", rmse_loss)
        self.plain_log.val_loss.nrmse.append(float(nrmse_loss.detach().cpu()))
        self.plain_log.val_loss.rmse.append(float(rmse_loss.detach().cpu()))

        # Plot visual of validation volume.
        # The sub-volume index ranges are computed once, the first time the chosen
        # subject is seen, and re-used for every later validation run.
        if subj_id == self._val_viz_subj_id and not self._val_subvol_range:
            # Take range between 1/4 to 3/4 the size of each dimension.
            fr_space = np.asarray(y.shape[2:])
            fr_low = np.floor(fr_space * 1 / 4).astype(int)
            fr_high = np.floor(fr_space * 3 / 4).astype(int)
            subvol_slice = np.s_[
                fr_low[0] : fr_high[0], fr_low[1] : fr_high[1], fr_low[2] : fr_high[2]
            ]
            # Leading `0` selects the first batch element, `...` keeps the channels.
            self._val_subvol_range["fr"] = (
                0,
                ...,
            ) + subvol_slice

            lr_space = np.asarray(x.shape[2:])
            lr_low = np.floor(lr_space * 1 / 4).astype(int)
            lr_high = np.floor(lr_space * 3 / 4).astype(int)
            subvol_slice = np.s_[
                lr_low[0] : lr_high[0], lr_low[1] : lr_high[1], lr_low[2] : lr_high[2]
            ]
            self._val_subvol_range["lr"] = (
                0,
                ...,
            ) + subvol_slice

        if subj_id == self._val_viz_subj_id:
            x_subvol = x.detach()[self._val_subvol_range["lr"]].float()
            y_subvol = y.detach()[self._val_subvol_range["fr"]].float()
            pred_subvol = y_pred.detach()[self._val_subvol_range["fr"]].float()

            # Create grid plot
            # Plot settings to propagate into the figure.
            with mpl.rc_context(
                {
                    "font.size": 8.0,
                    "axes.labelpad": 10.0,
                    "figure.autolayout": False,
                    "figure.constrained_layout.use": True,
                }
            ):
                fig = plt.figure(dpi=130, figsize=(6, 4))
                channel_names = [
                    r"$D_{x,x}$",
                    r"$D_{x,y}$",
                    r"$D_{y,y}$",
                    r"$D_{x,z}$",
                    r"$D_{y,z}$",
                    r"$D_{z,z}$",
                ]
                slice_labels = [
                    "\nAxial",
                    "\nCoronal",
                    "\nSagg.",
                ]
                img_labels = ["GT", "Pred", "Input"]
                fig = pitn.viz.plot_vol_slices(
                    y_subvol,
                    pred_subvol,
                    x_subvol,
                    slice_idx=(0.55, None, None),
                    title=f"Subj {subj_id} Step {self.global_step}",
                    vol_labels=img_labels,
                    channel_labels=channel_names,
                    slice_labels=slice_labels,
                    colorbars="each",
                    fig=fig,
                    interpolation="antialiased",
                    cmap="gray",
                )
                self.logger.experiment.add_figure("val_slice", fig, self.global_step)

        return rmse_loss

    def on_test_start(self):
        """Register the test metrics as hyperparameter metrics before testing."""
        # Initialize the metrics as hyperparams so they appear under the tensorboard
        # hparams tab. From:
        # <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-hyperparameters>
        self.logger.log_hyperparams(
            self.hparams,
            {
                "hp/rmse": 0,
                "hp/nrmse": 0,
                "hp/scaled_psnr": 0,
                "hp/scaled_ssim": 0,
                "hp/scaled_ms_ssim": 0,
            },
        )

    def test_step(self, batch: dict, batch_idx):
        """Evaluate the network and the spline baseline on one test subject.

        Computes RMSE, normalized RMSE, range-scaled PSNR, SSIM, and MS-SSIM for
        both the network prediction and a spline interpolation of the raw low-res
        input. Network metrics are logged through Lightning; all metrics and both
        predictions are stored per subject in ``self.plain_log``.

        Returns:
            Dict with the network's ``rmse``, ``nrmse``, ``scaled_psnr``,
            ``scaled_ssim``, and ``scaled_ms_ssim`` for this subject.
        """

        batch = Box(batch)
        # Assume a batch size of 1.
        subj_id = batch.subj_id[0]
        x = batch.lr_dti
        y = batch.raw_dti
        y = self.net.crop_full_output(y)
        mask = batch.mask.bool()
        mask = self.net.crop_full_output(mask)
        x_mode_refine = batch[self._anat_batch_key]

        # Run the actual network. (A debugging stub previously replaced this with
        # random noise, which made every reported test metric meaningless.)
        y_pred = self.net(x, x_mode_refine).float()

        # Run spline over the raw LR input.
        spline_x = batch.raw_lr_dti
        spline_y_pred = self.pred_spline(spline_x).float()

        # Calculate test loss on un-scaled inputs & outputs.
        if batch.lr_dti_scaler:
            # Unscale the network prediction *without* using any distribution information
            # from the ground truth volumes.
            y_pred = batch.lr_dti_scaler[0].descale(y_pred[0])[
                None,
            ]

        # Mask select the target and prediction(s)
        y_select = torch.masked_select(y, mask).view(y.shape[0], y.shape[1], -1)
        y_pred_select = torch.masked_select(y_pred, mask).view(
            y_pred.shape[0], y_pred.shape[1], -1
        )
        spline_y_pred_select = torch.masked_select(spline_y_pred, mask).view(
            spline_y_pred.shape[0], spline_y_pred.shape[1], -1
        )

        ###### Calculate network performance metrics.
        rmse_loss = torch.sqrt(
            F.mse_loss(
                y_pred_select,
                y_select,
                reduction="mean",
            )
        )
        nrmse_loss = pitn.metrics.minmax_normalized_rmse(
            y_pred_select,
            y_select,
            reduction="mean",
        )
        scaled_psnr_loss = pitn.metrics.range_scaled_psnr(
            y_pred_select,
            y_select,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )
        # The SSIM variants are given the full (un-masked) volumes.
        scaled_ssim_loss = pitn.metrics.range_scaled_ssim(
            y_pred,
            y,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )
        scaled_ms_ssim_loss = pitn.metrics.range_scaled_ms_ssim(
            y_pred,
            y,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )

        # on_epoch and reduce_fx gather the individual test epoch values and aggregate
        # them at the end of the testing epoch.
        self.log_dict(
            {
                "test_loss/rmse": rmse_loss,
                "test_loss/nrmse": nrmse_loss,
                "test_loss/scaled_psnr": scaled_psnr_loss,
                "test_loss/scaled_ssim": scaled_ssim_loss,
                "test_loss/scaled_ms_ssim": scaled_ms_ssim_loss,
            },
            on_epoch=True,
            reduce_fx=torch.mean,
        )
        # Log loss as an hparam metric to be shown alongside the experiment's hparams.
        self.log_dict(
            {
                "hp/rmse": rmse_loss,
                "hp/nrmse": nrmse_loss,
                "hp/scaled_psnr": scaled_psnr_loss,
                "hp/scaled_ssim": scaled_ssim_loss,
                "hp/scaled_ms_ssim": scaled_ms_ssim_loss,
            },
            on_epoch=True,
            reduce_fx=torch.mean,
        )
        self.plain_log.test_loss.rmse[subj_id] = rmse_loss.detach().cpu().item()
        self.plain_log.test_loss.nrmse[subj_id] = nrmse_loss.detach().cpu().item()
        self.plain_log.test_loss.scaled_psnr[subj_id] = (
            scaled_psnr_loss.detach().cpu().item()
        )
        self.plain_log.test_loss.scaled_ssim[subj_id] = (
            scaled_ssim_loss.detach().cpu().item()
        )
        self.plain_log.test_loss.scaled_ms_ssim[subj_id] = (
            scaled_ms_ssim_loss.detach().cpu().item()
        )

        ###### Handle spline losses and logging.
        spline_rmse = torch.sqrt(
            F.mse_loss(
                spline_y_pred_select,
                y_select,
                reduction="mean",
            )
        )
        spline_nrmse = pitn.metrics.minmax_normalized_rmse(
            spline_y_pred_select,
            y_select,
            reduction="mean",
        )
        spline_scaled_psnr = pitn.metrics.range_scaled_psnr(
            spline_y_pred_select,
            y_select,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )
        spline_scaled_ssim = pitn.metrics.range_scaled_ssim(
            spline_y_pred,
            y,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )
        spline_scaled_ms_ssim = pitn.metrics.range_scaled_ms_ssim(
            spline_y_pred,
            y,
            feature_range=(0.0, 1.0),
            reduction="mean",
        )

        self.plain_log.test_loss.spline.rmse[subj_id] = (
            spline_rmse.detach().cpu().item()
        )
        self.plain_log.test_loss.spline.nrmse[subj_id] = (
            spline_nrmse.detach().cpu().item()
        )
        self.plain_log.test_loss.spline.scaled_psnr[subj_id] = (
            spline_scaled_psnr.detach().cpu().item()
        )
        self.plain_log.test_loss.spline.scaled_ssim[subj_id] = (
            spline_scaled_ssim.detach().cpu().item()
        )
        self.plain_log.test_loss.spline.scaled_ms_ssim[subj_id] = (
            spline_scaled_ms_ssim.detach().cpu().item()
        )

        # Store results for visualization later.
        self.plain_log.viz.test_preds[subj_id] = y_pred[0].detach().cpu()
        self.plain_log.viz.spline_preds[subj_id] = spline_y_pred[0].detach().cpu()

        return {
            "rmse": rmse_loss,
            "nrmse": nrmse_loss,
            "scaled_psnr": scaled_psnr_loss,
            "scaled_ssim": scaled_ssim_loss,
            "scaled_ms_ssim": scaled_ms_ssim_loss,
        }

    def pred_spline(self, x, order=3):
        """Upsample ``x`` with spline interpolation as a non-learned baseline.

        Each batch element is zoomed by ``upscale_factor`` along its last three
        axes (the first axis of each element is left at zoom 1) with
        ``scipy.ndimage.zoom``.

        Args:
            x: Batched tensor; each element must be 4-dimensional.
            order: Spline order given to ``scipy.ndimage.zoom``.

        Returns:
            Tensor of the upsampled batch with the same dtype and device as ``x``.
        """
        x_np = x.detach().cpu().numpy()
        ys = list()
        for b in x_np:
            y_b = scipy.ndimage.zoom(
                b,
                zoom=(
                    1,
                    self._upscale_factor,
                    self._upscale_factor,
                    self._upscale_factor,
                ),
                order=order,
            )

            ys.append(torch.from_numpy(y_b))

        y = torch.stack(ys, dim=0)
        # Match the dtype and device of the input.
        y = y.to(x)

        return y

    def configure_optimizers(self):
        """Create the AdamW optimizer and OneCycleLR scheduler for the network."""
        optimizer = torch.optim.AdamW(self.net.parameters(), **self.opt_params)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, **self.lr_scheduler_kwargs
        )
        # NOTE(review): a scheduler returned in this form is presumably stepped by
        # Lightning once per epoch, while OneCycleLR is usually stepped every batch;
        # confirm that `lr_scheduler_kwargs` is set up for the intended interval.
        opt_system = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        return opt_system

### Training Loop

In [None]:
# Time the whole setup + fit so the duration can be reported afterwards.
train_start_timestamp = datetime.datetime.now().replace(microsecond=0)

# Constructor arguments for the Lightning system, taken from the experiment params.
model_kwargs = {
    "channels": params.n_channels,
    "batch_size": params.train.batch_size,
    "in_patch_shape": params.train.in_patch_size,
    "upscale_factor": params.net.upscale_factor,
    "anat_batch_key": params.data.anat_type,
    "train_loss_method": params.train.loss_name,
    "val_subj_ids": val_subjs,
    "opt_params": params.optim.kwargs,
    "net_kwargs": params.net.kwargs,
    "lr_scheduler_kwargs": params.train.lr_scheduler.kwargs,
}
# Hyperparams are merged last, so they win over any overlapping names above.
model_kwargs.update(**hyperparams)

# Instantiate the model from the merged kwargs.
model = DIQTCascadeSystem(**model_kwargs)


# Append a description of the model to the experiment's text log.
with open(log_txt_file, "a+") as log_file:
    log_file.write(f"Model overview: {model}\n")

# Set up the trainer.
# NOTE(review): `fast_dev_run=20` looks like a leftover debug setting that cuts the
# run down to a few batches -- confirm before launching a real experiment.
trainer = pl.Trainer(
    fast_dev_run=20,
    gpus=1,
    logger=pl_logger,
    log_every_n_steps=50,
    max_epochs=params.train.max_epochs,
    # Validate twice per training epoch.
    val_check_interval=0.5,
    # Mixed-precision floating point ops.
    precision=16,
    amp_backend="native",
)

# Fitting emits a large number of warnings; they are deliberately silenced here, for
# the duration of the fit only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

train_end_timestamp = datetime.datetime.now().replace(microsecond=0)
train_duration = train_end_timestamp - train_start_timestamp
print(f"Train duration: {train_duration}")

In [None]:
# Write the trained model's checkpoint into the experiment results directory.
model_ckpt_path = experiment_results_dir / "model.ckpt"
trainer.save_checkpoint(str(model_ckpt_path))

In [None]:
# Append the training duration to the experiment's text log, both as the raw
# timedelta and broken down into days/hours/minutes/seconds.
with open(log_txt_file, "a+") as f:
    f.write("\n")
    f.write(f"Training time: {train_duration}\n")
    # `timedelta.seconds` only holds the part of the duration below one day, so the
    # days are reported separately.
    f.write(
        f"\t{train_duration.days} Days, "
        + f"{train_duration.seconds // 3600} Hours, "
        + f"{(train_duration.seconds // 60) % 60} Minutes, "
        + f"{train_duration.seconds % 60} Seconds\n"
    )

In [None]:
# Plot a rolling-mean window over the training loss, which is recorded once per
# training step.
train_losses = np.asarray(model.plain_log["train_loss"], dtype=float)

plt.figure(dpi=110)
# Never use a window longer than the loss history: in "valid" mode `np.convolve`
# would swap its operands and return values that are not rolling means at all.
window = max(1, min(1000, len(train_losses)))
if len(train_losses) > 0:
    rolling_mean = np.convolve(train_losses, np.ones(window), "valid") / window
else:
    rolling_mean = np.zeros(0)
# Skip the first rolling means in the plot, but always keep at least one point.
rolling_start = min(100, max(len(rolling_mean) - 1, 0))
plt.plot(
    np.arange(
        window + rolling_start,
        window + rolling_start + len(rolling_mean[rolling_start:]),
    ),
    rolling_mean[rolling_start:],
)
plt.title("Training Loss " + params.train.loss_name + f"\nRolling Mean {window}")
# One loss value is appended per `training_step()` call, so the x-axis counts steps.
plt.xlabel("Training Step")
plt.ylabel("Loss")
# plt.ylim(0, 1)
if len(train_losses) > 0:
    print(np.median(rolling_mean))
    # Summary statistics of the losses at the very start of training.
    initial_losses = train_losses[: window + rolling_start]
    print(
        np.mean(initial_losses),
        np.var(initial_losses),
        np.max(initial_losses),
    )
else:
    print("No training losses were recorded.")

plt.savefig(experiment_results_dir / "train_loss.png")

## Model Testing/Evaluation

### Testing Loop

In [None]:
# mod_state = model.state_dict()

# test_mod = DIQTCascadeSystem(
#     channels=params.n_channels,
#     batch_size=params.train.batch_size,
#     in_patch_shape=params.train.in_patch_size,
#     upscale_factor=params.net.upscale_factor,
#     anat_batch_key=params.data.anat_type,
#     train_loss_method=params.train.loss_name,
#     opt_params=params.optim.kwargs,
#     hparams=hyperparams,
#     **params.net.kwargs,
# )
# test_mod.load_state_dict(mod_state)
# trainer.test(test_mod, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Run the test loop on the in-memory `model` over the test subjects. `ckpt_path=None`
# is passed explicitly; presumably this keeps the current weights rather than
# reloading a checkpoint -- verify against the installed Lightning version.
trainer.test(model, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Whether a lower (↓) or a higher (↑) value of each test metric is better.
loss_comparison_directions = {
    "rmse": "↓",
    "nrmse": "↓",
    "scaled_psnr": "↑",
    "scaled_ssim": "↑",
    "scaled_ms_ssim": "↑",
}
# Ordered tuple rather than a set: a set of strings iterates in a different order
# from one kernel session to the next, which would shuffle the plot panels and the
# logged metric list between runs.
test_losses = tuple(loss_comparison_directions.keys())

# Gather one record per (subject, metric, model) from the model's plain log.
test_records = list()
for subj_id in test_subjs:
    for metric in test_losses:
        for model_name, metric_log in (
            # DIQT model
            ("diqt", model.plain_log.test_loss),
            # Cubic spline
            ("cubic_spline", model.plain_log.test_loss.spline),
        ):
            test_records.append(
                {
                    "subj_id": subj_id,
                    "model": model_name,
                    "metric": metric,
                    "value": metric_log[metric][subj_id],
                }
            )
# Long-format table of all test scores; the explicit columns keep the layout stable
# even when there are no records.
test_results = pd.DataFrame(
    test_records, columns=["subj_id", "model", "metric", "value"]
)

with open(log_txt_file, "a+") as f:
    f.write(f"Test loss functions: {list(test_losses)}\n")

test_loss_log_file = experiment_results_dir / "test_loss.csv"
test_results.to_csv(test_loss_log_file, index=False)

### Evaluation Visualization

#### Comparison within experiment

In [None]:
# One panel per test metric, comparing the models' per-subject scores with a violin
# plot, the individual subjects as points, and a labelled line at each model's mean.
with mpl.rc_context(
    {
        "font.size": 8.0,
    }
):
    fig, axs = plt.subplots(
        ncols=len(test_losses),
        sharex=True,
        figsize=(11, 4),
        dpi=130,
        gridspec_kw={"wspace": 1.0, "hspace": 0},
    )
    sns.despine(fig=fig, top=True, right=True)

    for i, l in enumerate(test_losses):

        ax = axs[i]
        df = test_results.loc[test_results.metric == l]
        sns.violinplot(x="model", y="value", data=df, ax=ax, scale="count", inner=None)
        ax.grid(axis="y", alpha=0.5)
        points_plot = sns.swarmplot(
            x="model",
            y="value",
            hue="subj_id",
            data=df,
            ax=ax,
            # color="white",
            edgecolor="black",
            size=4,
            linewidth=0.8,
        )
        points_plot.get_legend().remove()

        # Calculate mean performance score. Only the numeric "value" column is
        # averaged; the remaining columns hold labels, and averaging those is an
        # error in recent pandas versions.
        means = df.groupby("model")["value"].mean()
        # Make sure the order follows seaborn's x-axis ordering.
        model_order = [tick.get_text() for tick in ax.get_xticklabels()]
        means = means.reindex(model_order)

        # Grab colors corresponding to each model.
        colors = sns.color_palette(None, n_colors=len(df.model.unique()))

        # Horizontal line at each model's mean, spanning that model's x slot.
        lines = ax.hlines(
            y=means.to_numpy(),
            xmin=np.arange(0, len(colors)) - 0.5 + 0.05,
            xmax=np.arange(1, len(colors) + 1) - 0.5 - 0.05,
            colors=colors,
            lw=1.5,
        )

        # White outline so the mean lines stay visible on top of the violins.
        outline_path_effects = [
            mpl.patheffects.Stroke(linewidth=5, foreground="white", alpha=0.9),
            mpl.patheffects.Normal(),
        ]
        lines.set_path_effects(outline_path_effects)

        ax.set_xticklabels(ax.get_xticklabels(), rotation=25)

        fig.canvas.draw()

        # Write each mean value next to the left edge of the panel.
        for m, c in zip(means, colors):

            ax.annotate(
                f"{m:.4g}",
                xy=(ax.get_xlim()[0] + (ax.get_xlim()[0] * 0.4), m),
                xycoords="data",
                color=c,
                ha="right",
                va="center",
                annotation_clip=False,
                fontweight="bold",
                snap=True,
                bbox=dict(
                    boxstyle="square,pad=0.3", fc="white", lw=0, snap=True, alpha=0.75
                ),
            )
        ax.set_title(f"{l.replace('_', ' ')} {loss_comparison_directions[l]}")
        ax.set_ylabel("")
        ax.set_xlabel("")
plt.savefig(experiment_results_dir / "test_result_within_experiment.png")

#### Comparison with other works

In [None]:
# Histogram of the per-subject test RMSE for the current model and the spline
# baseline, with vertical lines marking mean scores, including published ones.
fig, ax_prob = plt.subplots(figsize=(8, 4), dpi=120)
log_scale = False

diqt_rmse_values = list(model.plain_log["test_loss"].rmse.values())
spline_rmse_values = np.asarray(list(model.plain_log.test_loss.spline.rmse.values()))

# Settings shared by both histograms.
hist_kwargs = dict(
    alpha=0.5, stat="count", log_scale=log_scale, ax=ax_prob, legend=False
)

hist = sns.histplot(diqt_rmse_values, hatch="\\\\", ec="blue", **hist_kwargs)
hist.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
ax_prob.set_xlabel("Loss in $mm^2/second$")

# Vertical lines at the mean score of each model being compared.
comparison_kwargs = {"ls": "-", "lw": 2.5}
# Mean of the current DNN model.
ax_prob.axvline(
    np.asarray(diqt_rmse_values).mean(),
    label="Current Model Mean",
    color="blue",
    **comparison_kwargs,
)
# Mean of our own spline baseline.
ax_prob.axvline(
    spline_rmse_values.mean(),
    label="(Ours) Spline Mean Order 3",
    color="black",
    **comparison_kwargs,
)
sns.histplot(spline_rmse_values, color="black", hatch="//", **hist_kwargs)

# Scores reported by other works, as (RMSE, legend label, line color); the last
# entry is the best performing model of the Blumberg, et. al., 2018 paper.
literature_scores = (
    (
        9.72e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
        "green",
    ),
    (
        9.76e-4,
        "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nBest ESPCN\n[but not really]",
        "purple",
    ),
    (
        8.46e-4,
        "(Blumberg etal, 2018)\nBest Overall",
        "pink",
    ),
)
for ref_rmse, ref_label, ref_color in literature_scores:
    ax_prob.axvline(ref_rmse, label=ref_label, color=ref_color, **comparison_kwargs)

ax_prob.legend(fontsize="small")
ax_prob.set_title("Test Loss Histogram Over All Subjects with Test Metric RMSE")
fig.savefig(experiment_results_dir / "test_rmse_hist.png")

In [None]:
# Plot testing loss values over all subjects.
# Bar chart of mean test RMSE: solid bars for our two models, followed by floating
# bars (spanning a reported range) with a marked score for the published models.
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)

# x-axis categories; the first two entries are evaluated in this notebook, the rest
# are hard-coded numbers attributed to the cited works.
models = (
    "(Ours)\nCurrent Model",
    "(Ours)\nSpline Order 3",
    "(Tanno etal, 2021)\nC-Spline\nMean Approx.",
    "(Tanno etal, 2021)\nRand. Forest\nMean Approx.",
    "(Tanno etal, 2017 &\n Blumberg etal, 2018)\nESPCN Baseline",
    '(Tanno etal, 2017 &\n Blumberg etal, 2018)\n"Best" ESPCN',
    "(Blumberg etal, 2018)\nBest Overall",
)

# One row per entry of `models`. Column 0 is used as the top of the floating bar and
# column 1 as its bottom; the rows for our own models are placeholders.
rmse_bounds = np.asarray(
    [
        [0, 0],
        [0, 0],
        [31.738e-4, 10.069e-4],
        [23.139e-4, 6.974e-4],
        [13.609e-4, 6.212e-4],
        [13.82e-4, 6.29e-4],
        [12.13e-4, 5.58e-4],
    ]
)

# Score shown for each model: measured means for ours, the midpoint of the bounds for
# the two "Mean Approx." entries, and fixed values for the rest.
rmse_scores = (
    np.asarray(list(model.plain_log["test_loss"].rmse.values())).mean(),
    np.asarray(list(model.plain_log.test_loss.spline.rmse.values())).mean(),
    rmse_bounds[2].mean(),
    rmse_bounds[3].mean(),
    9.72e-4,
    9.76e-4,
    8.46e-4,
)

# Offsets between each score and its bounds (zeroed for our models). Not used by the
# rest of this cell.
rmse_score_ranges = np.asarray(rmse_scores)[:, None] - rmse_bounds
rmse_score_ranges[:2] = rmse_score_ranges[:2] * 0
rmse_score_ranges = rmse_score_ranges.T

colors = sns.color_palette("deep", n_colors=len(rmse_scores))

ax.grid(True, axis="y", zorder=1000)
ax.set_axisbelow(True)
# Plot our evaluation scores.
end_idx = 2
ax.bar(
    models[:end_idx],
    rmse_scores[:end_idx],
    color=colors[:end_idx],
    edgecolor="black",
    lw=0.75,
)
# Only our own bars exist at this point, so only they receive value labels.
for container in ax.containers:
    if isinstance(container, mpl.container.BarContainer):
        ax.bar_label(container, fmt="%.3e")


start_idx = end_idx
# Plot the crazy evaluation scores.
# Floating bars that span from the lower to the upper bound of each published model.
ax.bar(
    models[start_idx:],
    height=rmse_bounds[start_idx:, 0] - rmse_bounds[start_idx:, 1],
    bottom=rmse_bounds[start_idx:, 1],
    color=colors[start_idx:],
    edgecolor="black",
    lw=0.75,
    alpha=0.8,
)
end_idx = start_idx + 2
# Width of the first bar drawn; re-used to size the score markers below.
bar_width = ax.patches[0].get_width()
# Dotted lines for the "approximate" actual rmse score.
ax.hlines(
    rmse_scores[start_idx:end_idx],
    xmin=np.arange(start_idx, end_idx) - bar_width / 2,
    xmax=np.arange(start_idx, end_idx) + bar_width / 2,
    color="black",
    ls="--",
)
# Numeric label slightly above each dashed score line.
for score, x in zip(
    rmse_scores[start_idx:end_idx], np.arange(start_idx, end_idx) - bar_width / 2
):
    ax.annotate(f"{score:.3e}", (x, score + 0.03 * np.asarray(rmse_scores).max()))

# The remaining models get solid score lines instead of dashed ones.
start_idx = end_idx
end_idx = len(models)
ax.hlines(
    rmse_scores[start_idx:],
    xmin=np.arange(start_idx, end_idx) - bar_width / 2,
    xmax=np.arange(start_idx, end_idx) + bar_width / 2,
    color="black",
    ls="-",
)

for score, x in zip(
    rmse_scores[start_idx:end_idx], np.arange(start_idx, end_idx) - bar_width / 2
):
    ax.annotate(f"{score:.3e}", (x, score + 0.03 * np.asarray(rmse_scores).max()))

# Leave 10% headroom above the tallest element.
ax.set_ylim(bottom=0, top=ax.get_ylim()[1] * 1.1)
ax.set_xlabel("Model")
ax.set_ylabel("Loss in $mm^2/second$")

ax.set_title("Mean Over Subjects Test Loss RMSE")
ax.set_xticks(models)
ax.set_xticklabels(
    models, fontsize="x-small", rotation=25, ha="right", rotation_mode="anchor"
)
plt.savefig(experiment_results_dir / "test_rmse_bar.png")

In [None]:
# Print the per-subject test RMSE from best (lowest) to worst, followed by the mean.
rmse_by_subj = model.plain_log.test_loss.rmse
rmse_items = list(rmse_by_subj.items())

sorted_test_idx = np.argsort(np.asarray([rmse for _, rmse in rmse_items]))
sorted_test_results = {rmse_items[i][0]: rmse_items[i][1] for i in sorted_test_idx}

# `sort_dicts=False` keeps the RMSE ordering instead of re-sorting by subject id.
ppr(sorted_test_results, sort_dicts=False)
print(np.mean(list(rmse_by_subj.values())))

In [None]:
# Send the distribution of every test metric of the DIQT model to the logger as a
# histogram named "test/<metric>_dist".
diqt_results = test_results.loc[test_results.model == "diqt"]

for metric_name in ("rmse", "nrmse", "scaled_psnr", "scaled_ssim", "scaled_ms_ssim"):
    metric_values = diqt_results.loc[diqt_results.metric == metric_name].value
    logger.add_histogram(f"test/{metric_name}_dist", np.asarray(metric_values))

## Whole-Volume Visualization

### Setup

In [None]:
# Debug flag(s)
# Set to True to skip the cell that writes the per-subject NIfTI files and their zip
# archive (NOTE(review): it may gate other save steps outside this section as well).
disable_fig_save = False

In [None]:
# For every test subject, gather the full-res ground truth, the low-res input, the
# full-res mask, the anatomical image, and the full-res predictions into one
# container of numpy volumes, keyed by subject id.

results_viz = Box(default_box=True)
with torch.no_grad():

    for subj in test_dataset:
        subj_id = subj["subj_id"]
        print(f"Starting subject {subj_id}")

        # Container for every variant of this subject's volume.
        s = Box(default_box=True)
        # Volumes at full resolution are cropped to the network's output size.
        s.mask = model.net.crop_full_output(subj["mask"])
        s.dti = model.net.crop_full_output(subj["raw_dti"])
        s.affine = subj["dti_meta_dict"]["affine"]
        s.lr_dti = subj["lr_dti"]
        s[params.data.anat_type] = model.net.crop_full_output(
            subj[params.data.anat_type]
        )

        # Predictions recorded by the model during the test loop.
        s.pred = model.plain_log.viz.test_preds[subj_id]
        s.spline_pred = model.plain_log.viz.spline_preds[subj_id]

        # Undo the input scaling. The ground truth comes from the raw DTI and the
        # prediction was already descaled in test_step(), so only the low-res input
        # and the anatomical image need it.
        if params.data.dti_scale_range:
            s.lr_dti = subj["lr_dti_scaler"].descale(s.lr_dti)
        if params.data.anat_scale_range:
            anat_scaler = subj[params.data.anat_type + "_scaler"]
            s[params.data.anat_type] = anat_scaler.descale(s[params.data.anat_type])

        # Zero out everything outside of the mask before comparing volumes.
        s.dti = s.dti * s.mask
        s.pred = s.pred * s.mask
        s.spline_pred = s.spline_pred * s.mask
        s.metrics.rmse = model.plain_log.test_loss.rmse[subj_id]
        s.abs_error = torch.abs(s.pred - s.dti)

        # Convert every volume to numpy: boolean for the mask, float otherwise.
        volume_keys = (
            "mask",
            "dti",
            "lr_dti",
            params.data.anat_type,
            "pred",
            "spline_pred",
            "abs_error",
        )
        for k in volume_keys:
            v = s[k]
            v = v.detach().cpu().numpy() if torch.is_tensor(v) else v
            s[k] = v.astype(bool if k == "mask" else float)

        results_viz[subj_id] = s
        print(f"Finished subject {subj_id}")

In [None]:
# pitn.viz.plot_vol_slices(
#     np.rot90(results_viz['303624'].dti[0], k=1)
#     , colorbars='each')

In [None]:
# Save out all network predictions to Nifti2 files and compress them into a zip archive.
if not disable_fig_save:
    img_names = list()
    for subj_id, viz in results_viz.items():
        # Write the network prediction; the ground truth (`viz.dti`) must not end
        # up in a file that is named "predicted".
        # NOTE(review): the volume is written with whatever axis order `viz.pred`
        # has -- confirm that this is the layout downstream NIfTI readers expect.
        nib_img = nib.Nifti2Image(viz.pred, viz.affine)

        filename = experiment_results_dir / f"{subj_id}_predicted_dti.nii.gz"
        nib.save(nib_img, str(filename))
        img_names.append(filename)

    with zipfile.ZipFile(experiment_results_dir / "predicted_dti.zip", "w") as fzip:
        for filename in img_names:
            fzip.write(
                filename,
                arcname=filename.name,
                compress_type=zipfile.ZIP_DEFLATED,
                compresslevel=6,
            )
            # The loose file is no longer needed once it has been archived.
            os.remove(filename)
    # Make sure we exit the 'with' statement above.
    print("Done with files")

In [None]:
# Rank the "diqt" model's per-subject RMSE scores in ascending order and choose one of
# the poorest-scoring test subjects to visualize.
is_diqt_rmse = (test_results.model == "diqt") & (test_results.metric == "rmse")
sel_rmse = test_results.loc[is_diqt_rmse, ["subj_id", "value"]].sort_values("value")
# The second-to-last row is taken, i.e. the 2nd highest RMSE rather than the highest.
bad_rmse = sel_rmse.iloc[-2]
viz_subj_id = bad_rmse["subj_id"]
print(viz_subj_id, bad_rmse["value"])
viz_subj = results_viz[viz_subj_id]

In [None]:
# Choose the slice indices used by the visualizations below.
# The leading dimension of each volume is dropped; only the remaining 3 sizes are used.
dti_shape = np.asarray(viz_subj.dti.shape[1:])
lr_dti_shape = np.asarray(viz_subj.lr_dti.shape[1:])

# Begin at the center of each volume.
viz_idx = dti_shape // 2
viz_lr_idx = lr_dti_shape // 2
# Last dimension (sagittal) shouldn't be exactly centered, as the longitudinal fissure
# doesn't have many fibers beyond the corpus callosum. The low-res offset is the
# full-res offset of 6 floor-divided by the downsampling factor.
viz_idx[2] += 6
viz_lr_idx[2] += 6 // params.data.downsampled_by_factor

# One 4-element index per non-leading axis: everything along the leading dimension,
# and a single position along exactly one of the other three axes.
whole = slice(None)
viz_slice_idx = [
    (whole, viz_idx[0], whole, whole),
    (whole, whole, viz_idx[1], whole),
    (whole, whole, whole, viz_idx[2]),
]

viz_lr_slice_idx = [
    (whole, viz_lr_idx[0], whole, whole),
    (whole, whole, viz_lr_idx[1], whole),
    (whole, whole, whole, viz_lr_idx[2]),
]

### FA-Weighted Direction Maps

In [None]:
# Row/column headers for the image grid; rows are volumes, columns are slice views.
vol_labels = ["Ground Truth", "Pred", "Spline Interp.", "LR Input"]
dim_labels = ["Axial", "Coronal", "Saggital"]

# Pair every volume with the slice indices that match its resolution. The low-res
# input uses its own set of indices.
vols_and_slices = [
    (viz_subj.dti, viz_slice_idx),
    (viz_subj.pred, viz_slice_idx),
    (viz_subj.spline_pred, viz_slice_idx),
    (viz_subj.lr_dti, viz_lr_slice_idx),
]

# Build one direction map per (volume, slice), in volume-major order. Each map has
# its axes reordered to (1, 2, 0), presumably channels-first -> channels-last for
# display (confirm against `pitn.viz.direction_map`).
ims = [
    pitn.viz.direction_map(vol[sl], channels_first=True).transpose(1, 2, 0)
    for vol, slices in vols_and_slices
    for sl in slices
]

dir_map_rc = {
    "font.size": 12.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
with mpl.rc_context(dir_map_rc):
    # Create the figure inside the context so the temporary rc settings apply to it.
    fig = plt.figure(dpi=130, figsize=(7, 7))
    fig = pitn.viz.plot_im_grid(
        ims,
        nrows=len(vol_labels),
        title=f"Subj {viz_subj_id} Prediction Results",
        row_headers=vol_labels,
        col_headers=dim_labels,
        colorbars=None,
        fig=fig,
        interpolation="antialiased",
    )

# Write the figure to the results directory unless saving is disabled.
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "dir_map_pred.png");

### DTI Channel-Wise Visualization

In [None]:
# LaTeX-formatted names of the 6 DTI channels, in the order xx, xy, yy, xz, yz, zz.
channel_names = [
    f"$D_{{{pair[0]},{pair[1]}}}$" for pair in ("xx", "xy", "yy", "xz", "yz", "zz")
]

# Abbreviated slice-view labels, each prefixed with a newline.
dim_labels = ["\n" + view for view in ("Axial", "Cor.", "Sagg.")]

# Short names of the volumes shown in the channel-wise figures.
dti_names = ["FR", "Pred", "Spline", "LR", "Abs Err"]

#### Global Normalization

In [None]:
# Plot every DTI channel of five volumes for the selected subject: the `dti` target,
# the network prediction, the spline prediction, the low-res input, and the absolute
# error between prediction and target.
title = f"DTI Subj {viz_subj_id} Summary"
cmap = "gray"

summary_rc = {
    "font.size": 6.0,
    "axes.labelpad": 10.0,
    "figure.autolayout": False,
    "figure.constrained_layout.use": True,
}
# Only the last axis gets an explicit position, expressed as the chosen index divided
# by that axis' size. NOTE(review): presumably `None` lets `plot_vol_slices` pick its
# default slice for the other axes -- confirm against `pitn.viz.plot_vol_slices`.
summary_slice_idx = (None, None, viz_idx[2] / dti_shape[2])

with mpl.rc_context(summary_rc):
    # Create the figure inside the context so the temporary rc settings apply to it.
    fig = plt.figure(dpi=150, figsize=(8, 12))
    fig = pitn.viz.plot_vol_slices(
        viz_subj.dti,
        viz_subj.pred,
        viz_subj.spline_pred,
        viz_subj.lr_dti,
        viz_subj.abs_error,
        slice_idx=summary_slice_idx,
        title=title,
        vol_labels=dti_names,
        slice_labels=dim_labels,
        channel_labels=channel_names,
        colorbars="col",
        fig=fig,
        cmap=cmap,
        interpolation="antialiased",
    )


# Write the figure to the results directory unless saving is disabled.
if not disable_fig_save:
    plt.savefig(experiment_results_dir / f"DIQT_DTI_sub-{viz_subj_id}_pred_result.png");

#### Channel-Wise Normalization

In [None]:
# # Display all 6 DTIs for ground truth, predicted, low-res input, and root squared error
# # Normalize by index in the DTI coefficients.
# # Reshape and concatenate the dtis in order to compute the quantiles of images with
# # different shapes (e.g., the low-res input patch).

# cmap = "coolwarm"
# dtis = [
#     test_vol_viz[viz_subj_id].dti.fr[(slice(None), *slice_idx)],
#     test_vol_viz[viz_subj_id].dti.lr[(slice(None), *low_res_slice_idx)],
#     test_vol_viz[viz_subj_id].dti.spline[(slice(None), *slice_idx)],
#     test_vol_viz[viz_subj_id].dti.diqt[(slice(None), *slice_idx)],
#     test_vol_viz[viz_subj_id].dti.abs_error[(slice(None), *slice_idx)],
# ]

# nrows = len(dtis)
# ncols = len(channel_names)

# # Don't take the absolute max and min values, as there exist some extreme (e.g., > 3
# # orders of magnitude) outliers. Instead, clip to the 5% and 95% quantiles.
# max_dtis = np.quantile(
#     np.concatenate([di.reshape(6, -1) for di in dtis], axis=1), 0.95, axis=1
# )
# min_dtis = np.quantile(
#     np.concatenate([di.reshape(6, -1) for di in dtis], axis=1), 0.05, axis=1
# )

# max_dtis = np.max(np.abs([max_dtis, min_dtis]), axis=0)
# min_dtis = -1 * max_dtis

# nrows = len(dtis)
# ncols = len(channel_names)

# fig = plt.figure(figsize=(12, 7), dpi=160)

# grid = mpl.gridspec.GridSpec(
#     nrows,
#     ncols,
#     figure=fig,
#     hspace=0.05,
#     wspace=0.05,
# )

# axs = list()
# max_subplot_height = 0
# for i_row in range(nrows):
#     dti = dtis[i_row]
#     axs_cols = list()

#     for j_col in range(ncols):
#         ax = fig.add_subplot(grid[i_row, j_col])
#         ax.imshow(
#             np.rot90(dti[j_col]),
#             cmap=cmap,
#             interpolation=None,
#             vmin=min_dtis[j_col],
#             vmax=max_dtis[j_col],
#         )
#         if ax.get_subplotspec().is_first_col():
#             ax.set_ylabel(dti_names[i_row], size="small")
#         if ax.get_subplotspec().is_last_row():
#             ax.set_xlabel(channel_names[j_col])

#         # Update highest subplot to put the `suptitle` later on.
#         max_subplot_height = max(
#             max_subplot_height, ax.get_position(original=False).get_points()[1, 1]
#         )
#         ax.set_xticks([])
#         ax.set_yticks([])
#         ax.set_xticklabels([])
#         ax.set_yticklabels([])

#         axs_cols.append(ax)

#     axs.append(axs_cols)

# # Place colorbars on each column.
# for j_col in range(ncols):

#     full_col_ax = [axs[i][j_col] for i in range(nrows)]

#     color_norm = mpl.colors.Normalize(vmin=min_dtis[j_col], vmax=max_dtis[j_col])

#     color_mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
#     cbar = fig.colorbar(
#         color_mappable,
#         ax=full_col_ax,
#         location="top",
#         orientation="horizontal",
#         pad=0.01,
#         shrink=0.85,
#     )
#     cbar.ax.tick_params(labelsize=8, rotation=35)
#     cbar.ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter("{x:g}"))
# #     cbar.ax.ticklabel_format(scilimits=(3, -3), useOffset=False)

# plt.suptitle(
#     "DTI Channel Breakdown, Channel-Wise Normalization",
#     y=max_subplot_height + 0.01,
#     verticalalignment="bottom",
# )
# if not disable_fig_save:
#     plt.savefig(experiment_results_dir / "DTI_channel_sample_channel_wise_norm.png");

---

## End Experiment

In [None]:
# Wrap up the experiment: flush the logger and, for non-debug runs, finalize it and
# move the results directory to its final location.
pl_logger.experiment.flush()
# Close tensorboard logger.
# Don't finalize if the experiment was for debugging.
# The check is a case-insensitive substring match on the experiment name.
if "debug" not in EXPERIMENT_NAME.casefold():
    pl_logger.finalize("success")
    # Experiment is complete, move the results directory to its final location.
    if experiment_results_dir != final_experiment_results_dir:
        print("Moving out of tmp location")
        # Assumes `experiment_results_dir` is a `pathlib.Path` on Python >= 3.8, where
        # `rename()` returns the new path -- TODO confirm.
        # NOTE(review): `Path.rename` can fail if the tmp and final locations are on
        # different filesystems; verify that both share a filesystem.
        experiment_results_dir = experiment_results_dir.rename(
            final_experiment_results_dir
        )
        # Re-point the log file path at the moved directory, keeping its file name.
        log_txt_file = experiment_results_dir / log_txt_file.name