# Pain in the Net
Application of Deep Image Quality Transfer (DIQT) with domain adaptation.


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Source work:
`S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown
import ipyplot

# Data management libraries.
import nibabel as nib
import natsort
from natsort import natsorted
from addict import Addict
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

# Global matplotlib defaults: automatic tight layout and an opaque white figure
# background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact, non-scientific printing of ndarrays/tensors.
np.set_printoptions(suppress=True, threshold=100, linewidth=88)
torch.set_printoptions(sci_mode=False, threshold=100, linewidth=88)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv
# - warnings

# Form command to be run in direnv's context. This command will print out
# all environment variables defined in the subprocess/sub-shell.
# The command is an argument list (no shell), so a working directory containing
# spaces or shell metacharacters is passed through verbatim.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    # Run command in a new subprocess.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
except FileNotFoundError:
    # Keep this step best-effort: without direnv there is simply nothing to load.
    warnings.warn("WARNING: direnv not found, environment variables not updated")
    proc_out = ""
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../").resolve())
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [None]:
# torch setup
# Use CUDA when it is available, otherwise fall back to the CPU.
# To force CPU execution, replace the expression below with torch.device("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Everything printed in this cell is stored in the variable `cap` instead of being
# displayed.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    # A failed GPUtil import leaves the name `GPUtil` undefined, which raises the
    # NameError caught at the bottom of this block.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        # One (id, name, driver, CUDA version, memory, uuid) row per GPU.
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            # CUDA version is the one reported by torch, identical for every row.
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    # NOTE(review): this also swallows any other NameError raised inside the try
    # body, not only a missing GPUtil.
    except NameError:
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# cap is defined in an ipython magic command
# (the `%%capture ... cap` cell above); printing it shows the captured
# watermark/GPU output.
print(cap)

Author: Tyler Spears

Last updated: 2021-10-05T14:13:43.997845+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-88-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: 5123101df9112c32921025ab162ce3c3ee751aa1

seaborn          : 0.11.1
GPUtil           : 1.4.0
json             : 2.0.9
dipy             : 1.4.1
IPython          : 7.23.1
ipywidgets       : 7.6.3
matplotlib       : 3.4.1
pandas           : 1.2.3
scipy            : 1.5.3
natsort          : 7.1.1
torch            : 1.9.0
skimage          : 0.18.1
monai            : 0.7.dev2138
pytorch_lightning: 1.4.5
kornia           : 0.5.8
torchio          : 0.18.37
numpy            : 1.20.2
addict           : 2.4.0
ants             : 0.2.7
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
nibabel          : 3.2.1

  id  Name       Driver Version      CUDA Version

### Data Variables & Definitions Setup

In [None]:
# Set up directories
# Each location comes from an environment variable and must already exist. The
# checks raise explicit errors (instead of bare asserts) so that the offending path
# is reported and the check is not stripped when running with `python -O`.
data_dir = pathlib.Path(os.environ["DATA_DIR"]) / "hcp"
if not data_dir.exists():
    raise FileNotFoundError(f"Data directory does not exist: {data_dir}")
write_data_dir = pathlib.Path(os.environ["WRITE_DATA_DIR"]) / "hcp"
if not write_data_dir.exists():
    raise FileNotFoundError(f"Write data directory does not exist: {write_data_dir}")
results_dir = pathlib.Path(os.environ["RESULTS_DIR"])
if not results_dir.exists():
    raise FileNotFoundError(f"Results directory does not exist: {results_dir}")
tmp_results_dir = pathlib.Path(os.environ["TMP_RESULTS_DIR"])
if not tmp_results_dir.exists():
    raise FileNotFoundError(
        f"Temporary results directory does not exist: {tmp_results_dir}"
    )

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = "debug_dti_nifti_saving"

ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
experiment_name = ts + "__" + EXPERIMENT_NAME
run_name = experiment_name
# experiment_results_dir = results_dir / experiment_name

# Create temporary directory for results directory, in case experiment does not finish.
# Only directories count as previous temporary results; stray files in the
# temporary results folder are ignored so that rmtree is never called on a file.
tmp_dirs = [p for p in tmp_results_dir.glob("*") if p.is_dir()]

# Only keep up to N tmp results.
n_tmp_to_keep = 3
# The current run takes one slot, so at most N - 1 previous results are kept.
n_prev_to_keep = max(n_tmp_to_keep - 1, 0)
if len(tmp_dirs) > n_prev_to_keep:
    print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
    # Directory names start with a timestamp, so natural sort order puts the oldest
    # first. An explicit count (instead of a negative slice bound) stays correct
    # when no previous results are to be kept.
    n_to_delete = len(tmp_dirs) - n_prev_to_keep
    tmps_to_delete = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])[:n_to_delete]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# Show the timestamped experiment name used for this run's result directories.
print(experiment_name)

In [None]:
# Lightning-level logger: handed to the pytorch-lightning Trainer so that the
# training/testing loops can log through it.
pl_logger = pl.loggers.TensorBoardLogger(
    tmp_results_dir,
    name=experiment_name,
    version="",
    log_graph=False,
    default_hp_metric=False,
)
# Underlying lower-level logger, used for histograms, images, etc.
logger = pl_logger.experiment

# Plain-text log kept next to the tensorboard files for free-form events & info
# that are neither parameters nor results.
log_txt_file = Path(logger.log_dir) / "log.txt"
with log_txt_file.open("a+") as f:
    # cap is defined in an ipython magic command
    f.write(
        f"Experiment Name: {experiment_name}\n"
        f"Timestamp: {ts}\n"
        f"Environment and Hardware Info:\n {cap}\n\n"
    )

## Parameters and Function Definitions

### Parameters

In [None]:
# Dict to keep track of experiment configuration parameters. Will not be logged to
# tensorboard.
# Addict allows nested attribute-style access, e.g. `exp_params.patch.update(...)`.
exp_params = Addict()
# Dict to keep track of tensorboard hparams that we *specifically* want to compare
# between runs.
# Split into an `hparam` group and a `metric` group.
compare_hparams = Addict(hparam=Addict(), metric=Addict())

In [None]:
# Voxel sizes used by the mean downsampling: source (full-res) and target (low-res).
source_vox_size = 1.25
target_vox_size = 2.0

# Ratio of the target voxel size to the source voxel size.
downsample_factor = target_vox_size / source_vox_size
# b-value range kept for DTI fitting; covers the b=0 and b=1000 shells.
bval_range = (0, 1500)
dti_fit_method = "WLS"
# Record the preprocessing settings in the experiment configuration.
exp_params.update(
    downsample_factor=downsample_factor,
    source_vox_size=source_vox_size,
    target_vox_size=target_vox_size,
    bval_range=bval_range,
    dti_fit_method=dti_fit_method,
)

#### Patch Parameters

In [None]:
# Patch parameters
# Derives the network input (low-res) patch shape from the desired output (full-res)
# patch shape, the downsample factor, and the low-res over-extension factor.
batch_size = 12
# 6 channels for the 6 DTI components
channels = 6

# Output patch shapes
h_out = 14
w_out = 14
d_out = 14

# This is the factor that determines how over-extended the input patch should be
# relative to the size of the full-res patch.
# $low_res_patch_dim = \frac{full_res_patch_dim}{downsample_factor} \times low_res_sample_extension$
# A value of 1 indicates that the input patch dims will be exactly divided by the
# downsample factor. A dilation > 1 increases the "spatial extent" of the input
# patch, providing information outside of the target HR patch.
low_res_sample_extension = 1.556

# Output shape after shuffling.
output_patch_shape = (channels, h_out, w_out, d_out)
output_spatial_patch_shape = output_patch_shape[1:]

# Two-step upsampling?
# When enabled, the upsampling is split into an integer-factor step
# (`diqt_downsample_factor`, the downsample factor rounded up) and a non-integer
# interpolation step that accounts for the remainder.
two_step_upsample = True

if two_step_upsample:
    diqt_downsample_factor = np.ceil(downsample_factor).astype(int)
    interp_out_vox_size = source_vox_size
    # Voxel size after the integer-factor step, i.e. the input to the interpolation.
    interp_in_vox_size = target_vox_size / diqt_downsample_factor

    interp_downsample_factor = interp_in_vox_size / interp_out_vox_size

    # Input size to the non-integer interpolated downsampling.
    interp_h_in = np.ceil(h_out / interp_downsample_factor)
    interp_w_in = np.ceil(w_out / interp_downsample_factor)
    interp_d_in = np.ceil(d_out / interp_downsample_factor)

    # Input shape to ESPCN shuffling.
    shuffle_h_in = interp_h_in / diqt_downsample_factor
    shuffle_w_in = interp_w_in / diqt_downsample_factor
    shuffle_d_in = interp_d_in / diqt_downsample_factor

    # DIQT network input patch shapes
    h_in = round(shuffle_h_in * low_res_sample_extension)
    w_in = round(shuffle_w_in * low_res_sample_extension)
    d_in = round(shuffle_d_in * low_res_sample_extension)
    input_patch_shape = (channels, h_in, w_in, d_in)
    input_spatial_patch_shape = input_patch_shape[1:]

    # Patch size in FR-space when accounting for the low-res over-extension/over-sampling
    # factor.
    # Unlike the single-step branch below, the patch size is left fractional here and
    # only the extension amount is rounded up to whole voxels.
    fr_extension_patch_size = tuple(
        np.asarray(input_spatial_patch_shape) * downsample_factor
    )
    fr_extension_amount = tuple(
        np.ceil(
            np.asarray(fr_extension_patch_size) - np.asarray(output_spatial_patch_shape)
        ).astype(int)
    )
else:
    # Input patch parameters
    h_in = round(h_out / (downsample_factor) * low_res_sample_extension)
    w_in = round(w_out / (downsample_factor) * low_res_sample_extension)
    d_in = round(d_out / (downsample_factor) * low_res_sample_extension)
    input_patch_shape = (channels, h_in, w_in, d_in)
    input_spatial_patch_shape = input_patch_shape[1:]

    # Pre-shuffle output patch sizes.
    # NOTE: `rounded_downsample_factor`, `unshuffled_channels_out` and
    # `unshuffled_output_patch_shape` are only defined in this (single-step) branch.
    rounded_downsample_factor = int(np.ceil(downsample_factor))
    unshuffled_channels_out = int(channels * rounded_downsample_factor**3)
    # Output before shuffling
    unshuffled_output_patch_shape = (
        unshuffled_channels_out,
        h_out // rounded_downsample_factor,
        w_out // rounded_downsample_factor,
        d_out // rounded_downsample_factor,
    )

    # Patch size in FR-space when accounting for the low-res over-extension/over-sampling
    # factor.
    fr_extension_patch_size = tuple(
        np.ceil(np.asarray(input_spatial_patch_shape) * downsample_factor).astype(int)
    )
    fr_extension_amount = tuple(
        np.asarray(fr_extension_patch_size) - np.asarray(output_spatial_patch_shape)
    )

In [None]:
# Record the patch configuration under the `patch` group of the experiment
# parameters.
exp_params.patch.update(
    batch_size=batch_size,
    channels=channels,
    low_res_sample_extension=low_res_sample_extension,
    input_shape=input_patch_shape,
    output_shape=output_patch_shape,
    two_step_upsample=two_step_upsample,
)

#### Data Parameters

In [None]:
# Data parameters.
# Number of subjects drawn from the list of selected subject IDs.
num_subject_samples = 16
# Should the data be normalized as a pre-processing step?
# Can be:
# { None, "channels" }
data_norm_method = "channels"
# Quantile range of DTI channels that will have voxel intensities clamped.
# Only has an effect if using data normalization.

# In other words, for each subject, for each channel, for each vol in {FR, LR}, any voxel
# values <= the first quantile value will be clamped to that quantile, and any voxels
# >= the second quantile value will be clamped to that quantile.
clamp_quantiles = (0.0001, 0.9999)

In [None]:
# Record the data configuration under the `data` group of the experiment parameters.
# Note that the subject count is stored under the key `num_subject`.
exp_params.data.update(
    num_subject=num_subject_samples,
    data_norm_method=data_norm_method,
    clamp_quantiles=clamp_quantiles,
)

#### Training and Testing Setup

In [None]:
# Training/testing parameters.
# Percentages will be rounded off to the nearest subject, with the test and validation
# sizes rounded *up*, ensuring at least 1 subject in each.
test_percent = 0.4
val_percent = 0.2
# The training set takes whatever fraction is left over.
train_percent = 1 - (test_percent + val_percent)

# NN parameters.
max_epochs = 200
network_norm_method = None
train_loss_name = "mse"

# Optimization parameters.
opt_name = "Adam"
opt_params = {"lr": 7e-4, "betas": (0.9, 0.999)}

# Spline interpolation baseline parameters.
spline_interp_order = 3

In [None]:
# Number of voxels to dilate the mask in FR space.
# Just make it 0...
# This value is also used as the padding amount of the first Pad transform in the
# preprocessing pipeline.
dilation_size = 0

In [None]:
# Record the split, training, network, optimizer, baseline and preprocessing settings
# in the experiment parameters.
# NOTE(review): `val_percent` is not recorded here, unlike `test_percent` and
# `train_percent` - confirm whether that is intentional.
exp_params.update(test_percent=test_percent)
exp_params.train.update(
    train_percent=train_percent, max_epochs=max_epochs, train_loss_name=train_loss_name
)
exp_params.nn.update(network_norm_method=network_norm_method)
exp_params.opt.update(opt_params)
exp_params.opt.name = opt_name
exp_params.spline.update(order=spline_interp_order)
exp_params.preproc.update(dilation_size=dilation_size)

In [None]:
# Append a pretty-printed dump of the experiment parameters to the text log.
with open(log_txt_file, "a+") as f:
    f.write(pprint.pformat(exp_params) + "\n")

### Function Definitions

In [None]:
# Utility functions
def patch_center(
    patch: torch.Tensor, sub_sample_strategy="lower", keepdim=False
) -> torch.Tensor:
    """Extract 3D multi-channel patch center.

    Expects patch of shape '[B x C x] W x H x D'; the last three dimensions are
    treated as the spatial dimensions.

    An odd-sized dimension has a unique center voxel. An even-sized dimension of
    size 2k has its center between voxels k - 1 and k; `sub_sample_strategy`
    decides how those two candidates are handled.

    patch: torch.Tensor
        Patch (or batch of patches) to take the center of.

    sub_sample_strategy: str
        Strategy for handling center coordinates of even-sized dimensions.
        Matching is case-insensitive.
        Options:
            Strategies over indices:
                'lower': Take the voxel to the left of the center (index k - 1).
                'upper': Take the voxel to the right of the center (index k).

            Strategies over multi-dimensional voxels:
                'max': Take the max of all center voxels.
                'min': Take the minimum of all center voxels.
                'mean': Take the average of all center voxels.
                'agg': Don't reduce at all, and return the center voxels.

    keepdim: bool
        If False (default), all size-1 dimensions of the result are squeezed out,
        including any size-1 batch/channel dimensions.

    Returns the center voxel(s) of the patch.

    Raises ValueError if `sub_sample_strategy` is not one of the options above.
    """
    idx_strategies = ("lower", "upper")
    vox_fns = {
        "max": lambda p: torch.amax(p, dim=(-3, -2, -1), keepdim=True),
        "min": lambda p: torch.amin(p, dim=(-3, -2, -1), keepdim=True),
        "mean": lambda p: p.mean(dim=(-3, -2, -1), keepdim=True),
        "agg": lambda p: p,
    }

    strat = sub_sample_strategy.casefold()
    if strat not in idx_strategies and strat not in vox_fns:
        raise ValueError(
            f"ERROR: Invalid strategy; got {sub_sample_strategy}, expected one of "
            + f"{list(idx_strategies) + list(vox_fns)}"
        )

    center = []
    for size in patch.shape[-3:]:
        half = size // 2
        if size % 2 == 1:
            # Odd size: a single, unambiguous center voxel.
            dim_slice = slice(half, half + 1)
        elif strat == "lower":
            # Even size: the center lies between indices half - 1 and half.
            dim_slice = slice(half - 1, half)
        elif strat == "upper":
            dim_slice = slice(half, half + 1)
        else:
            # Voxel strategies keep both voxels adjacent to the center.
            dim_slice = slice(half - 1, half + 1)
        center.append(dim_slice)

    center_patches = patch[..., center[0], center[1], center[2]]

    # Only reduce when there is more than one candidate center voxel.
    if strat in vox_fns and tuple(center_patches.shape[-3:]) != (1, 1, 1):
        center_patches = vox_fns[strat](center_patches)

    if not keepdim:
        center_patches = center_patches.squeeze()

    return center_patches

In [None]:
BoxplotStats = collections.namedtuple(
    "BoxplotStats",
    ["low_outliers", "low", "q1", "median", "q3", "high", "high_outliers"],
)


def batch_boxplot_stats(batch):
    """Compute boxplot statistics for every row of a batch of 1D samples.

    The quartiles are taken along axis 1. The whisker bounds `low`/`high` lie
    1.5 * IQR below q1 and above q3, and values strictly outside the bounds are
    reported as outliers.

    Returns a BoxplotStats tuple. The outlier fields are lists with one entry per
    row, since each row may have a different number of outliers; all other fields
    hold one value per row.
    """
    q1, median, q3 = np.quantile(batch, q=[0.25, 0.5, 0.75], axis=1)
    whisker_reach = 1.5 * (q3 - q1)
    low = q1 - whisker_reach
    high = q3 + whisker_reach

    low_outliers = [row[np.where(row < bound)] for row, bound in zip(batch, low)]
    high_outliers = [row[np.where(row > bound)] for row, bound in zip(batch, high)]

    return BoxplotStats(
        low_outliers=low_outliers,
        low=low,
        q1=q1,
        median=median,
        q3=q3,
        high=high,
        high_outliers=high_outliers,
    )

In [None]:
# Quick check on full volume/batch distributions.


def desc_channel_dists(vols, mask=None):
    """Describe per-channel means and variances of one or more volumes.

    vols:
        A single 4D volume (channels first) or a 5D batch of such volumes; anything
        accepted by `torch.as_tensor`.
    mask:
        Optional mask selecting the voxels that enter the statistics. It is cast to
        bool, and a 4D mask has its leading dimension dropped (first entry taken).
        When omitted, every voxel is used.

    Returns a string with a "means | vars" header, followed by one "mean | var" line
    per channel for each volume and a separator line of '=' after each volume.
    """
    t_vols = torch.as_tensor(vols)

    # Promote a single volume to a batch of one.
    if t_vols.ndim == 4:
        t_vols = t_vols[None, ...]

    if mask is not None:
        # masked_select needs a boolean mask.
        t_mask = torch.as_tensor(mask).bool()
        # Drop the leading channel dimension so the mask broadcasts over channels.
        if t_mask.ndim == 4:
            t_mask = t_mask[0]
    else:
        t_mask = torch.ones_like(t_vols[0, 0]).bool()

    results = "means | vars\n"
    for t_vol_i in t_vols:
        # One row of selected voxels per channel.
        masked_vol = torch.masked_select(t_vol_i, t_mask).reshape(t_vol_i.shape[0], -1)
        mean_i = torch.mean(masked_vol, dim=1)
        var_i = torch.var(masked_vol, dim=1)
        mvs = [
            (f"{mv[0]} | {mv[1]}\n")
            for mv in torch.stack([mean_i, var_i], dim=-1).tolist()
        ]
        results = results + "".join(mvs)
        # Separator as wide as the last printed line (minus its newline).
        results = results + ("=" * (len(mvs[-1]) - 1)) + "\n"

    return results

In [None]:
def plot_dti_box_row(
    fig,
    grid,
    row_idx: int,
    subj_id: int,
    shared_axs_rows: list,
    shared_axs_cols: list,
    fr_vol: np.ndarray,
    lr_vol: np.ndarray,
    colors: list = list(sns.color_palette("Set2", n_colors=2)),
):
    """Draw one row of FR-vs-LR boxen plots, one plot per DTI channel.

    fig, grid:
        Figure to draw on and the grid whose row `row_idx` receives the plots.
    row_idx: int
        Row of `grid` to fill.
    subj_id: int
        Subject identifier; used as the y-label of the first column and as the key
        into `shared_axs_rows`.
    shared_axs_rows, shared_axs_cols:
        Containers of axes to share the y-axis (looked up by `subj_id`) and the
        x-axis (looked up by channel name) with. An entry that is None is filled
        with the first axes created for that row/column. Both are modified in place.
    fr_vol, lr_vol:
        Full-res and low-res data, indexed by channel along the first axis.
    colors: list
        Palette handed to the boxen plot.

    Returns the tuple (fig, shared_axs_rows, shared_axs_cols).
    """
    channel_titles = (
        "$D_{xx}$",
        "$D_{xy}$",
        "$D_{yy}$",
        "$D_{xz}$",
        "$D_{yz}$",
        "$D_{zz}$",
    )

    for chan_idx, chan_title in enumerate(channel_titles):
        cell = grid[row_idx, chan_idx]

        ax = fig.add_subplot(
            cell,
            sharex=shared_axs_cols[chan_title],
            sharey=shared_axs_rows[subj_id],
        )
        # The first axes created for a column/row becomes the shared reference.
        if shared_axs_cols[chan_title] is None:
            shared_axs_cols[chan_title] = ax
        if shared_axs_rows[subj_id] is None:
            shared_axs_rows[subj_id] = ax

        fr_values = fr_vol[chan_idx]
        lr_values = lr_vol[chan_idx]

        # Long-format table: every value is tagged with the resolution it came from.
        labels = ["FR"] * len(fr_values) + ["LR"] * len(lr_values)
        plot_df = pd.DataFrame(
            {
                "data": np.concatenate([fr_values, lr_values]),
                "resolution": labels,
            }
        )

        sns.boxenplot(
            data=plot_df,
            y="resolution",
            x="data",
            orient="h",
            ax=ax,
            palette=colors,
            k_depth="proportion",
            outlier_prop=0.11,
            showfliers=False,
        )

        # x tick labels are only shown on the bottom row.
        if cell.is_last_row():
            plt.setp(ax.get_xticklabels(), fontsize="x-small", rotation=25)
        else:
            plt.setp(ax.get_xticklabels(), visible=False)

        # y tick labels and the subject label are only shown on the first column.
        if cell.is_first_col():
            ax.set_ylabel(subj_id)
        else:
            plt.setp(ax.get_yticklabels(), visible=False)
            ax.set_ylabel("")

        ax.set_xlabel("")
        # Channel names title the top row only.
        if cell.is_first_row():
            ax.set_title(chan_title)

    return fig, shared_axs_rows, shared_axs_cols

## Data Loading

### Subject ID Selection

In [None]:
# Map every selected subject ID to its diffusion data directory.
subj_dirs: dict = dict()

selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(selected_ids, num_subject_samples)
warnings.warn(
    "WARNING: Sub-selecting participants for dev and debugging. "
    + f"Subj IDs selected: {selected_ids}"
)
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# IDs are handled as integers from here on, in sorted order.
selected_ids = natsorted([int(sid) for sid in selected_ids])

for subj_id in selected_ids:
    subj_dirs[subj_id] = data_dir / f"{subj_id}/T1w/Diffusion"
    # Every selected subject must have its diffusion directory on disk.
    assert subj_dirs[subj_id].exists()
subj_dirs

The 90 scans are taken from the $b=1000 \ s/mm^2$ shell. However, the $b=0$ volumes are still required for fitting the diffusion tensors (DTIs), so those need to be kept, too.

To find both, sub-select the volumes with $0 \leq bvals < 1500$, or roughly thereabout. A b-val of $995$ or $1005$ still counts as $b=1000$.

In [None]:
# Record the selected subject IDs in both the text log and tensorboard.
with open(log_txt_file, "a+") as f:
    f.write(f"Selected Subjects: {selected_ids}\n")

logger.add_text("subjs", pprint.pformat(selected_ids))

### Loading and Preprocessing

In [None]:
# Set up the transformation pipeline.
# Applied to each subject in order: reorient, select DWI volumes by b-value, pad,
# mean-downsample (keeping the full-res originals), fit DTIs to both resolutions,
# then rename the fitted volumes to their final names.

preproc_transforms = torchio.Compose(
    [
        torchio.transforms.ToCanonical(include=("dwi", "brain_mask"), copy=False),
        # NOTE(review): `include` is a bare string here; confirm whether this
        # project transform expects a string or a sequence of image names.
        pitn.transforms.BValSelectionTransform(
            bval_range=bval_range,
            bval_key="bvals",
            bvec_key="bvecs",
            include="dwi",
            copy=False,
        ),
        # Pad by the dilation size. No dilation transform follows in this pipeline;
        # with dilation_size = 0 this padding is a no-op.
        torchio.transforms.Pad(
            dilation_size,
            padding_mode=0,
            include=("dwi", "brain_mask"),
            copy=False,
        ),
        # Pad by the amount of extension voxels in FR space, so LR indices cannot
        # go out of bounds.
        torchio.transforms.Pad(
            fr_extension_amount,
            padding_mode=0,
            include=("dwi", "brain_mask"),
            copy=False,
        ),
        # Downsample "dwi"/"brain_mask"; the originals are kept under the "fr_*" names.
        pitn.transforms.FractionalMeanDownsampleTransform(
            source_vox_size=exp_params.source_vox_size,
            target_vox_size=exp_params.target_vox_size,
            include=("dwi", "brain_mask"),
            keep={"dwi": "fr_dwi", "brain_mask": "fr_brain_mask"},
            copy=False,
        ),
        pitn.transforms.RenameImageTransform(
            {"dwi": "lr_dwi", "brain_mask": "lr_brain_mask"}, copy=False
        ),
        # `include` must be a sequence of image names: the trailing comma makes these
        # real 1-tuples. A parenthesized bare string only works through substring
        # matching of image names.
        pitn.transforms.FitDTITransform(
            "bvals",
            "bvecs",
            "fr_brain_mask",
            fit_method=dti_fit_method,
            include=("fr_dwi",),
            #             cache_dir="./.cache",
            copy=False,
        ),
        pitn.transforms.FitDTITransform(
            "bvals",
            "bvecs",
            "lr_brain_mask",
            fit_method=dti_fit_method,
            include=("lr_dwi",),
            #             cache_dir="./.cache",
            copy=False,
        ),
        # The fitted volumes replace the DWIs, so rename them accordingly.
        pitn.transforms.RenameImageTransform(
            {"fr_dwi": "gt_dti", "lr_dwi": "lr_dti"}, copy=False
        ),
        pitn.transforms.ImageToDictTransform(
            include=("lr_dti", "lr_brain_mask"), copy=False
        ),
    ]
)

#### Data Loading & Feature Creation

In [None]:
# ## DEBUG fractional downsampling


# c = torchio.Compose(preproc_transforms.transforms[:7])
# t = list()
# for subj_id, subj_dir in subj_dirs.items():

#     # Sub-select volumes with only bvals in a certain range. E.x. bvals <= 1100 mm/s^2,
#     # a.k.a. only the b=0 and b=1000 shells.
#     bvals = torch.as_tensor(np.loadtxt(subj_dir / "bvals").astype(int))
#     bvecs = torch.as_tensor(np.loadtxt(subj_dir / "bvecs"))
#     # Reshape to be N x 3
#     if bvecs.shape[0] == 3:
#         bvecs = bvecs.T

#     brain_mask = torchio.LabelMap(
#         subj_dir / "nodif_brain_mask.nii.gz",
#         type=torchio.LABEL,
#         channels_last=False,
#     )
#     brain_mask.set_data(brain_mask.data.bool())
#     mask_volume = brain_mask["data"].sum()

#     dwi = torchio.ScalarImage(
#         subj_dir / "data.nii.gz",
#         type=torchio.INTENSITY,
#         bvals=bvals,
#         bvecs=bvecs,
#         reader=pitn.io.nifti_reader,
#         channels_last=True,
#     )

#     subject_dict = torchio.Subject(subj_id=subj_id, dwi=dwi, brain_mask=brain_mask)

#     preproc_subj = c(subject_dict)
#     t.append(preproc_subj)

In [None]:
# Table format passed to tabulate for the per-subject summary statistics.
summary_stats_fmt = "github"
# Column names of the summary table. The columns after "Var" line up, in order,
# with the fields of BoxplotStats.
summary_stats_header = [
    "Subj ID",
    "Resolution",
    "Channel Index",
    "Mean",
    "Var",
    "Num Outliers (Lower)",
    "Low",
    "25th Percentile",
    "Median",
    "75th Percentile",
    "High",
    "Num Outliers (Upper)",
]
# Plain-text names of the six DTI channels, in channel order.
dti_channel_names = ["D xx", "D xy", "D yy", "D xz", "D yz", "D zz"]
n_subj = len(subj_dirs.keys())
# Two-color palette: one color for each of the two resolutions (FR, LR).
colors = list(sns.color_palette("Set2", n_colors=2))

In [None]:
# Import all image data into a sequence of `torchio.Subject` objects.
subj_data: dict = dict()

subj_stats = dict()
# Grid plot for displaying enhanced boxplots of subject DTI image intensities.
plt.clf()

fig_non_norm = plt.figure(dpi=120, figsize=(6 * 2.5, 1.5 * n_subj))
fig_non_norm.suptitle("Non-Normalized Subject DTI Distributions")
grid_non_norm = mpl.gridspec.GridSpec(
    nrows=len(subj_dirs.keys()), ncols=len(dti_channel_names), figure=fig_non_norm
)

share_axs_non_norm = collections.defaultdict(
    lambda: None,
    row=collections.defaultdict(lambda: None),
    col=collections.defaultdict(lambda: None),
)

for k in summary_stats_header:
    subj_stats[k] = list()

# Load, preprocess and summarize every subject: build the torchio Subject, run the
# preprocessing pipeline, append FR and LR summary statistics, and draw the subject's
# row of boxen plots.
for subj_idx, (subj_id, subj_dir) in enumerate(subj_dirs.items()):

    # Load the b-values and b-vectors. Sub-selecting the volumes by b-value range
    # (the b=0 and b=1000 shells) is done by the preprocessing pipeline.
    bvals = torch.as_tensor(np.loadtxt(subj_dir / "bvals").astype(int))
    bvecs = torch.as_tensor(np.loadtxt(subj_dir / "bvecs"))
    # Reshape to be N x 3
    if bvecs.shape[0] == 3:
        bvecs = bvecs.T

    brain_mask = torchio.LabelMap(
        subj_dir / "nodif_brain_mask.nii.gz",
        type=torchio.LABEL,
        channels_last=False,
    )
    brain_mask.set_data(brain_mask.data.bool())
    mask_volume = brain_mask["data"].sum()

    dwi = torchio.ScalarImage(
        subj_dir / "data.nii.gz",
        type=torchio.INTENSITY,
        bvals=bvals,
        bvecs=bvecs,
        reader=pitn.io.nifti_reader,
        channels_last=True,
    )

    subject_dict = torchio.Subject(subj_id=subj_id, dwi=dwi, brain_mask=brain_mask)
    preproc_subj = preproc_transforms(subject_dict)

    # Print and log some statistics of the subject data.
    # Process FR (ground truth) volumes.
    # Grab FR voxels within the mask.
    # The normalized target is named "fr_dti", whereas the source/ground truth DTI
    # (without normalization or clamping) is named "gt_dti".
    fr_vol = preproc_subj.gt_dti.tensor
    fr_mask = preproc_subj.fr_brain_mask.tensor.bool()
    # One row of in-mask voxels per channel.
    masked_fr_vol = torch.masked_select(fr_vol, fr_mask).reshape(fr_vol.shape[0], -1)
    # Estimate FR means and vars.
    fr_channel_means = masked_fr_vol.mean(dim=1)
    fr_channel_means = fr_channel_means.reshape(-1, 1, 1, 1)
    fr_channel_vars = masked_fr_vol.var(dim=1)
    fr_channel_vars = fr_channel_vars.reshape(-1, 1, 1, 1)

    # Add FR stats to the summary table.
    fr_vol_stats = batch_boxplot_stats(masked_fr_vol)
    subj_stats["Subj ID"].extend(
        list(itertools.repeat(subj_id, len(fr_vol_stats.median)))
    )
    subj_stats["Resolution"].extend(
        list(itertools.repeat(tuple(fr_vol.shape[1:]), len(fr_vol_stats.median)))
    )
    subj_stats["Channel Index"].extend(dti_channel_names)
    subj_stats["Mean"].extend(fr_channel_means.cpu().flatten().tolist())
    subj_stats["Var"].extend(fr_channel_vars.cpu().flatten().tolist())
    # Append FR boxplot stats to their corresponding fields. In other words, all columns
    # after the var column.
    for i, field in enumerate(
        summary_stats_header[(summary_stats_header.index("Var") + 1) :]
    ):
        # Outlier fields are per-channel arrays; only their counts are tabulated.
        if "Num" in field:
            subj_stats[field].extend(list(map(len, fr_vol_stats[i])))
        else:
            subj_stats[field].extend(fr_vol_stats[i])

    # Process LR volumes
    # Grab LR voxels within the mask.
    lr_vol = preproc_subj.lr_dti["data"]
    lr_mask = preproc_subj.lr_brain_mask["data"].bool()
    masked_lr_vol = torch.masked_select(lr_vol, lr_mask).reshape(lr_vol.shape[0], -1)

    lr_channel_means = masked_lr_vol.mean(dim=1)
    lr_channel_means = lr_channel_means.reshape(-1, 1, 1, 1)
    lr_channel_vars = masked_lr_vol.var(dim=1)
    lr_channel_vars = lr_channel_vars.reshape(-1, 1, 1, 1)

    # Add LR stats to the summary table.
    lr_vol_stats = batch_boxplot_stats(masked_lr_vol)
    # The row count comes from the LR stats themselves, not from the FR stats.
    subj_stats["Subj ID"].extend(
        list(itertools.repeat(subj_id, len(lr_vol_stats.median)))
    )
    subj_stats["Resolution"].extend(
        list(itertools.repeat(tuple(lr_vol.shape[1:]), len(lr_vol_stats.median)))
    )
    subj_stats["Channel Index"].extend(dti_channel_names)
    subj_stats["Mean"].extend(lr_channel_means.cpu().flatten().tolist())
    subj_stats["Var"].extend(lr_channel_vars.cpu().flatten().tolist())
    # Append LR boxplot stats to their corresponding fields. In other words, all columns
    # after the var column.
    for i, field in enumerate(
        summary_stats_header[(summary_stats_header.index("Var") + 1) :]
    ):
        if "Num" in field:
            subj_stats[field].extend(list(map(len, lr_vol_stats[i])))
        else:
            subj_stats[field].extend(lr_vol_stats[i])

    # Generate row of boxplots for the non-normalized volumes. The subject's position
    # in `subj_dirs` selects the grid row.
    fig_non_norm, share_axs_rows, share_axs_cols = plot_dti_box_row(
        fig_non_norm,
        grid_non_norm,
        row_idx=subj_idx,
        subj_id=subj_id,
        shared_axs_rows=share_axs_non_norm["row"],
        shared_axs_cols=share_axs_non_norm["col"],
        fr_vol=masked_fr_vol.detach().cpu().numpy(),
        lr_vol=masked_lr_vol.detach().cpu().numpy(),
    )
    share_axs_non_norm["row"] = share_axs_rows
    share_axs_non_norm["col"] = share_axs_cols

    # Finalize this subject.
    subj_data[subj_id] = preproc_subj

    print("=" * 20)


print("===Data Loaded & Transformed===")

# Render the accumulated (pre-normalization) summary statistics as a text table.
subj_stats_str = tabulate(
    subj_stats,
    headers=summary_stats_header,
    tablefmt=summary_stats_fmt,
)

In [None]:
# fig_non_norm

In [None]:
# ### !Temporary! Will be removed after fitting UVA DTI.
# # Import all image data into a sequence of `torchio.Subject` objects.
# subj_data: dict = dict()

# subj_stats = dict()
# # Grid plot for displaying enhanced boxplots of subject DTI image intensities.
# plt.clf()

# fig_non_norm = plt.figure(dpi=120, figsize=(6 * 2.5, 1.5 * n_subj))
# fig_non_norm.suptitle("Non-Normalized Subject DTI Distributions")
# grid_non_norm = mpl.gridspec.GridSpec(
#     nrows=len(subj_dirs.keys()), ncols=len(dti_channel_names), figure=fig_non_norm
# )

# share_axs_non_norm = collections.defaultdict(
#     lambda: None,
#     row=collections.defaultdict(lambda: None),
#     col=collections.defaultdict(lambda: None),
# )

# for k in summary_stats_header:
#     subj_stats[k] = list()

# subj_dirs = {"001": pathlib.Path(os.environ["DATA_DIR"]) / "uva/001"}

# for subj_id, subj_dir in subj_dirs.items():

#     # Sub-select volumes with only bvals in a certain range. E.x. bvals <= 1100 mm/s^2,
#     # a.k.a. only the b=0 and b=1000 shells.
#     bvals = torch.as_tensor(np.loadtxt(subj_dir / "sub-001_ses-01_run-2_dwi.bval").astype(int))
#     bvecs = torch.as_tensor(np.loadtxt(subj_dir / "sub-001_ses-01_run-2_dwi.bvec"))
#     # Reshape to be N x 3
#     if bvecs.shape[0] == 3:
#         bvecs = bvecs.T

#     brain_mask = torchio.LabelMap(
#         Path(
#             "/srv/tmp/pitn/uva/chronic_pain_head_and_neck/derivatives/diffusion/sub-001/ses-01/mask/sub-001_space-orig_ses-01_run-2_mask.nii.gz"
#         ),
#         #         subj_dir / "nodif_brain_mask.nii.gz",
#         type=torchio.LABEL,
#         channels_last=False,
#     )
#     brain_mask.set_data(brain_mask.data.bool())
#     mask_volume = brain_mask["data"].sum()

#     dwi = torchio.ScalarImage(
#         subj_dir / "sub-001_ses-01_run-2_dwi_epi.nii.gz",
#         type=torchio.INTENSITY,
#         bvals=bvals,
#         bvecs=bvecs,
#         reader=pitn.io.nifti_reader,
#         channels_last=True,
#     )

#     subject_dict = torchio.Subject(subj_id=subj_id, dwi=dwi, brain_mask=brain_mask)
#     preproc_subj = preproc_transforms(subject_dict)

#     # Print and log some statistics of the subject data.
#     # Process FR (ground truth) volumes.
#     # Grab FR voxels within the mask.
#     # The normalized target is named "fr_dti", whereas the source/ground truth DTI
#     # (without normalization or clamping) is named "gt_dti".
#     fr_vol = preproc_subj.gt_dti.tensor
#     fr_mask = preproc_subj.fr_brain_mask.tensor.bool()
#     masked_fr_vol = torch.masked_select(fr_vol, fr_mask).reshape(fr_vol.shape[0], -1)
#     # Estimate FR means and vars.
#     fr_channel_means = masked_fr_vol.mean(dim=1)
#     fr_channel_means = fr_channel_means.reshape(-1, 1, 1, 1)
#     fr_channel_vars = masked_fr_vol.var(dim=1)
#     fr_channel_vars = fr_channel_vars.reshape(-1, 1, 1, 1)

#     # Add FR stats to the summary table.
#     fr_vol_stats = batch_boxplot_stats(masked_fr_vol)
#     subj_stats["Subj ID"].extend(
#         list(itertools.repeat(subj_id, len(fr_vol_stats.median)))
#     )
#     subj_stats["Resolution"].extend(
#         list(itertools.repeat(tuple(fr_vol.shape[1:]), len(fr_vol_stats.median)))
#     )
#     subj_stats["Channel Index"].extend(dti_channel_names)
#     subj_stats["Mean"].extend(fr_channel_means.cpu().flatten().tolist())
#     subj_stats["Var"].extend(fr_channel_vars.cpu().flatten().tolist())
#     # Append FR boxplot stats to their corresponding fields. In other words, all columns
#     # after the var column.
#     for i, field in enumerate(
#         summary_stats_header[(summary_stats_header.index("Var") + 1) :]
#     ):
#         if "Num" in field:
#             subj_stats[field].extend(list(map(len, fr_vol_stats[i])))
#         else:
#             subj_stats[field].extend(fr_vol_stats[i])

# #     # Process LR volumes
# #     # Grab LR voxels within the mask.
# #     lr_vol = preproc_subj.lr_dti["data"]
# #     lr_mask = preproc_subj.lr_brain_mask["data"].bool()
# #     masked_lr_vol = torch.masked_select(lr_vol, lr_mask).reshape(lr_vol.shape[0], -1)

# #     lr_channel_means = masked_lr_vol.mean(dim=1)
# #     lr_channel_means = lr_channel_means.reshape(-1, 1, 1, 1)
# #     lr_channel_vars = masked_lr_vol.var(dim=1)
# #     lr_channel_vars = lr_channel_vars.reshape(-1, 1, 1, 1)

# #     # Add LR stats to the summary table.
# #     lr_vol_stats = batch_boxplot_stats(masked_lr_vol)
# #     subj_stats["Subj ID"].extend(
# #         list(itertools.repeat(subj_id, len(fr_vol_stats.median)))
# #     )
# #     subj_stats["Resolution"].extend(
# #         list(itertools.repeat(tuple(lr_vol.shape[1:]), len(lr_vol_stats.median)))
# #     )
# #     subj_stats["Channel Index"].extend(dti_channel_names)
# #     subj_stats["Mean"].extend(lr_channel_means.cpu().flatten().tolist())
# #     subj_stats["Var"].extend(lr_channel_vars.cpu().flatten().tolist())
# #     # Append LR boxplot stats to their corresponding fields. In other words, all columns
# #     # after the var column.
# #     for i, field in enumerate(
# #         summary_stats_header[(summary_stats_header.index("Var") + 1) :]
# #     ):
# #         if "Num" in field:
# #             subj_stats[field].extend(list(map(len, lr_vol_stats[i])))
# #         else:
# #             subj_stats[field].extend(lr_vol_stats[i])

# #     subj_idx = list(subj_dirs.keys()).index(subj_id)
# #     # Generate row of boxplots for the non-normalized volumes.
# #     fig_non_norm, share_axs_rows, share_axs_cols = plot_dti_box_row(
# #         fig_non_norm,
# #         grid_non_norm,
# #         row_idx=subj_idx,
# #         subj_id=subj_id,
# #         shared_axs_rows=share_axs_non_norm["row"],
# #         shared_axs_cols=share_axs_non_norm["col"],
# #         fr_vol=masked_fr_vol.detach().cpu().numpy(),
# #         lr_vol=masked_lr_vol.detach().cpu().numpy(),
# #     )
# #     share_axs_non_norm["row"] = share_axs_rows
# #     share_axs_non_norm["col"] = share_axs_cols

#     # Finalize this subject.
#     subj_data[subj_id] = preproc_subj

#     print("=" * 20)


# print("===Data Loaded & Transformed===")

# # subj_stats_str = tabulate(
# #     subj_stats, headers=summary_stats_header, tablefmt=summary_stats_fmt
# # )

In [None]:
# # Save out DTIs to NIFTI files
# # Save out full resolution DTI's and masks.

# # File path relative to the `write_data_dir` where the subjects are located.
# # Trying to follow BIDS format as much as possible...
# prefix_dir = 'uva/derivatives/diffusion/'
# save_dir = write_data_dir / prefix_dir
# # Names that further define sub-directories under each subject.
# dti_type_name = 'dti'
# mask_type_name = 'mask'
# for subj_id, subj_dti in subj_data.items():
#     print(subj_id)
#     sub_dir = save_dir / f'sub-{subj_id}'

#     # Save out DTI
#     nifti_img = nib.Nifti2Image(subj_dti['gt_dti'].data.detach().cpu().numpy(), affine=subj_dti['gt_dti'].affine)
#     fname = f'sub-{subj_id}_dti.nii.gz'
#     img_save_dir = sub_dir / dti_type_name
#     img_save_dir.mkdir(parents=True, exist_ok=True)
#     nib.save(nifti_img, img_save_dir / fname)

In [None]:
# display(Markdown("**Stats for Ground Truth and LR DTI's, before normalization**"))
# display(Markdown(subj_stats_str))

In [None]:
# Append the preprocessing pipeline description and the un-normalized summary
# statistics table to the experiment's text log.
with open(log_txt_file, "a+") as log_file:
    log_file.write(
        f"Preprocessing transformation pipeline: {preproc_transforms!s}\n"
        f"Data Summary Statistics, no normalization:\n {subj_stats_str}\n\n"
    )

---

(Optional) Save out the fitted DTIs for both the full resolution and mean downsampled versions.

In [None]:
# # Save out DTIs to NIFTI files
# # Save out full resolution DTI's and masks.

# # File path relative to the `write_data_dir` where the subjects are located.
# # Trying to follow BIDS format as much as possible...
# prefix_dir = 'derivatives/diffusion/mean_downsample/scale-orig'
# save_dir = write_data_dir / prefix_dir
# # Names that further define sub-directories under each subject.
# dti_type_name = 'dti'
# mask_type_name = 'mask'
# for subj_id, subj_dti in subj_data.items():
#     print(subj_id)
#     sub_dir = save_dir / f'sub-{subj_id}'

#     # Save out DTI
#     nifti_img = nib.Nifti2Image(subj_dti['gt_dti'].data.detach().cpu().numpy(), affine=subj_dti['gt_dti'].affine)
#     fname = f'sub-{subj_id}_dti.nii.gz'
#     img_save_dir = sub_dir / dti_type_name
#     img_save_dir.mkdir(parents=True, exist_ok=True)
#     nib.save(nifti_img, img_save_dir / fname)

#     # Save out mask
#     nifti_img = nib.Nifti2Image(subj_dti['fr_brain_mask'].data.detach().cpu().numpy(), affine=subj_dti['fr_brain_mask'].affine)
#     fname = f'sub-{subj_id}_mask.nii.gz'
#     img_save_dir = sub_dir / mask_type_name
#     img_save_dir.mkdir(parents=True, exist_ok=True)
#     nib.save(nifti_img, img_save_dir / fname)

In [None]:
# # Save out downsampled DTIs
# prefix_dir = 'derivatives/diffusion/mean_downsample/scale-2.00mm'
# save_dir = write_data_dir / prefix_dir
# # Names that further define sub-directories under each subject.
# dti_type_name = 'dti'
# mask_type_name = 'mask'
# for subj_id, subj_dti in subj_data.items():
#     print(subj_id)
#     sub_dir = save_dir / f'sub-{subj_id}'

#     # Save out DTI
#     nifti_img = nib.Nifti2Image(subj_dti['lr_dti']['data'].detach().cpu().numpy(), affine=subj_dti['lr_dti']['affine'])
#     fname = f'sub-{subj_id}_meandownsample-2.00mm_dti.nii.gz'
#     img_save_dir = sub_dir / dti_type_name
#     img_save_dir.mkdir(parents=True, exist_ok=True)
#     nib.save(nifti_img, img_save_dir / fname)

#     # Save out mask
#     nifti_img = nib.Nifti2Image(subj_dti['lr_brain_mask']['data'].detach().cpu().numpy(), affine=subj_dti['lr_brain_mask']['affine'])
#     fname = f'sub-{subj_id}_meandownsample-2.00mm_mask.nii.gz'
#     img_save_dir = sub_dir / mask_type_name
#     img_save_dir.mkdir(parents=True, exist_ok=True)
#     nib.save(nifti_img, img_save_dir / fname)

---

#### Volume Normalization

In [None]:
# Normalization is enabled only when the configured method is a string that
# mentions "channel" (case-insensitive).
normalize = (
    isinstance(data_norm_method, str) and "channel" in data_norm_method.casefold()
)

In [None]:
# Collect global statistics
# `digests` temporarily buffers every subject's in-mask voxels for one channel, and
# `quantiles` stores the resulting per-channel clamp bounds (unbounded by default).
digests = Addict()
quantiles = Addict()
for i in range(exp_params.patch.channels):
    digests.lr[i] = list()
    digests.fr[i] = list()
    quantiles.lr[i] = (-np.inf, np.inf)
    quantiles.fr[i] = (-np.inf, np.inf)


def _rounded_clamp_quantiles(voxel_chunks):
    """Return the configured clamp quantiles over all concatenated voxel chunks.

    The result is rounded to 7 decimals to suppress floating-point round-off; the
    rounding goes through numpy and is then cast back to the dtype/device of the
    raw quantile tensor.
    """
    voxels = torch.cat(voxel_chunks)
    probs = torch.as_tensor(exp_params.data.clamp_quantiles).to(voxels)
    bounds = torch.quantile(voxels, probs)
    rounded = np.round(bounds.detach().cpu().numpy(), decimals=7)
    return torch.from_numpy(rounded).to(bounds)


for i_channel in range(exp_params.patch.channels):
    # Full-resolution pass: pool this channel's masked voxels across all subjects.
    for subj in subj_data.values():
        in_mask = subj["fr_brain_mask"].data[0].bool()
        digests.fr[i_channel].append(
            torch.masked_select(subj["gt_dti"].data[i_channel], in_mask).view(-1)
        )
    quantiles.fr[i_channel] = _rounded_clamp_quantiles(digests.fr[i_channel])
    # Drop the buffered voxels once the quantiles are stored.
    digests.fr[i_channel].clear()

    # Low-resolution pass: identical procedure on the LR volumes and masks.
    for subj in subj_data.values():
        in_mask = subj["lr_brain_mask"]["data"][0].bool()
        digests.lr[i_channel].append(
            torch.masked_select(subj["lr_dti"]["data"][i_channel], in_mask).view(-1)
        )
    quantiles.lr[i_channel] = _rounded_clamp_quantiles(digests.lr[i_channel])
    digests.lr[i_channel].clear()

In [None]:
# Pretty-print the per-channel FR/LR clamp quantiles computed in the previous cell.
ppr(quantiles)

In [None]:
# Optional data normalization and normalized stats reporting.
#
# For every subject (only when `normalize` is set) this cell:
#   1. clamps the FR and LR DTIs to the global per-channel quantile bounds,
#   2. standardizes each channel with means/vars estimated from in-mask voxels,
#   3. zeroes voxels outside the brain mask,
#   4. stores the results back on the subject ("fr_dti", "lr_dti", plus the stats),
#   5. records post-normalization summary statistics and a row of boxplots.

# Dictionary to hold the subject's summary statistics if image-level or global
# normalization is used.
norm_subj_stats = dict()
# Grid plot for displaying enhanced boxplots of subject DTI image intensities.
plt.clf()
colors = list(sns.color_palette("Set2", n_colors=2))
fig_norm = plt.figure(dpi=120, figsize=(6 * 2.5, 1.5 * n_subj))
fig_norm.suptitle("Normalized Subject DTI Distributions")
grid_norm = mpl.gridspec.GridSpec(
    nrows=len(subj_dirs.keys()), ncols=len(dti_channel_names), figure=fig_norm
)
share_axs_norm = collections.defaultdict(
    lambda: None,
    row=collections.defaultdict(lambda: None),
    col=collections.defaultdict(lambda: None),
)

for k in summary_stats_header:
    norm_subj_stats[k] = list()

# Optionally apply image-level or global normalization.
if normalize:
    for subj_id, subj in subj_data.items():

        # Subject-and-channel-wise standardization/normalization of both the LR and FR vols.
        # Note that LR and FR images should have the same means, but *not* the same variances.
        fr_vol = subj.gt_dti.tensor
        # Perform clamping of quantile ranges.
        # Stacked bounds are reshaped to (2, channels, 1, 1, 1): index 0 holds the
        # lower bounds, index 1 the upper bounds. The channel count is inferred
        # rather than hard-coded.
        fr_clamp_quantiles = (
            torch.stack([q for q in quantiles.fr.values()], dim=1)
            .to(fr_vol)
            .view(2, -1, 1, 1, 1)
        )
        fr_mask = subj.fr_brain_mask.tensor.bool()
        # Report how many voxels will be clamped. Voxels below the lower bound and
        # voxels above the upper bound are both counted.
        masked_fr_vol = torch.masked_select(fr_vol, fr_mask).reshape(
            fr_vol.shape[0], -1
        )
        num_clamps = (
            (masked_fr_vol < fr_clamp_quantiles[0, :, 0, 0])
            | (masked_fr_vol > fr_clamp_quantiles[1, :, 0, 0])
        ).sum(dim=1)

        print(
            f"     Subj {subj_id} FR, number of clamped voxels per-channel: ",
            num_clamps.tolist(),
        )
        print(
            f"    Subj {subj_id} FR, percentage of clamped voxels per-channel: ",
            (
                num_clamps / torch.masked_select(fr_vol[0], fr_mask).numel() * 100
            ).tolist(),
        )
        # Clamp upper and lower quantiles in the volume.
        fr_vol = torch.where(
            fr_vol < fr_clamp_quantiles[0], fr_clamp_quantiles[0], fr_vol
        )
        fr_vol = torch.where(
            fr_vol > fr_clamp_quantiles[1], fr_clamp_quantiles[1], fr_vol
        )

        # Estimate means and vars from the masked voxels.
        masked_fr_vol = torch.masked_select(fr_vol, fr_mask).reshape(
            fr_vol.shape[0], -1
        )
        fr_channel_means = masked_fr_vol.mean(dim=1)
        fr_channel_means = fr_channel_means.reshape(-1, 1, 1, 1)
        fr_channel_vars = masked_fr_vol.var(dim=1)
        fr_channel_vars = fr_channel_vars.reshape(-1, 1, 1, 1)

        # Process LR volumes.
        lr_vol = subj["lr_dti"]["data"]
        # Keep the un-normalized LR volume available under "source_lr_dti".
        subj["source_lr_dti"] = dict()
        subj["source_lr_dti"]["data"] = lr_vol
        subj["source_lr_dti"]["affine"] = subj["lr_dti"]["affine"]
        # Perform clamping of quantile ranges
        lr_clamp_quantiles = (
            torch.stack([q for q in quantiles.lr.values()], dim=1)
            .to(lr_vol)
            .view(2, -1, 1, 1, 1)
        )
        lr_mask = subj["lr_brain_mask"]["data"].bool()

        # Report how many voxels will be clamped (below the lower bound or above the
        # upper bound).
        masked_lr_vol = torch.masked_select(lr_vol, lr_mask).reshape(
            lr_vol.shape[0], -1
        )
        num_clamps = (
            (masked_lr_vol < lr_clamp_quantiles[0, :, 0, 0])
            | (masked_lr_vol > lr_clamp_quantiles[1, :, 0, 0])
        ).sum(dim=1)

        print(
            f"    Subj {subj_id} LR, number of clamped voxels per-channel: ",
            num_clamps.tolist(),
        )
        print(
            f"    Subj {subj_id} LR, percentage of clamped voxels per-channel: ",
            (
                num_clamps / torch.masked_select(lr_vol[0], lr_mask).numel() * 100
            ).tolist(),
        )
        lr_vol = torch.where(
            lr_vol <= lr_clamp_quantiles[0], lr_clamp_quantiles[0], lr_vol
        )
        lr_vol = torch.where(
            lr_vol >= lr_clamp_quantiles[1], lr_clamp_quantiles[1], lr_vol
        )

        # Estimate means and vars from the masked voxels.
        masked_lr_vol = torch.masked_select(lr_vol, lr_mask).reshape(
            lr_vol.shape[0], -1
        )
        lr_channel_means = masked_lr_vol.mean(dim=1)
        lr_channel_means = lr_channel_means.reshape(-1, 1, 1, 1)
        lr_channel_vars = masked_lr_vol.var(dim=1)
        lr_channel_vars = lr_channel_vars.reshape(-1, 1, 1, 1)

        # Normalize the volumes.
        fr_vol = pitn.data.norm.normalize_dti(fr_vol, fr_channel_means, fr_channel_vars)
        # Zero out voxels outside the mask.
        fr_vol = fr_vol * fr_mask

        lr_vol = pitn.data.norm.normalize_dti(lr_vol, lr_channel_means, lr_channel_vars)
        # Zero out voxels outside the mask.
        lr_vol = lr_vol * lr_mask

        # Store new volumes back into the Subject object.
        fr_dti_img = torchio.ScalarImage(
            tensor=fr_vol,
            affine=subj["gt_dti"]["affine"],
        )
        # The normalized target is named "fr_dti", whereas the source/ground truth DTI
        # (without normalization or clamping) is named "gt_dti".
        subj.add_image(fr_dti_img, "fr_dti")
        # Store subject-and-channel-wise means and vars.
        subj["fr_means"] = fr_channel_means.detach().cpu().numpy()
        subj["fr_vars"] = fr_channel_vars.detach().cpu().numpy()

        subj["lr_dti"]["data"] = lr_vol
        # Store subject-and-channel-wise means and vars.
        subj["lr_means"] = lr_channel_means.detach().cpu().numpy()
        subj["lr_vars"] = lr_channel_vars.detach().cpu().numpy()

        # ============= Visualization and tracking procedures =============
        # Re-calculate the same statistics post-normalization.
        # Only consider the voxels within the mask.
        masked_fr_vol = torch.masked_select(fr_vol, fr_mask).reshape(
            fr_vol.shape[0], -1
        )

        masked_fr_channel_means = masked_fr_vol.mean(dim=1)
        masked_fr_channel_means = masked_fr_channel_means.reshape(-1, 1, 1, 1)
        masked_fr_channel_vars = masked_fr_vol.var(dim=1)
        masked_fr_channel_vars = masked_fr_channel_vars.reshape(-1, 1, 1, 1)
        # Print and log some statistics of the subject data.
        # Add FR stats to the summary table.
        masked_fr_vol_stats = batch_boxplot_stats(masked_fr_vol.detach().cpu().numpy())
        norm_subj_stats["Subj ID"].extend(
            list(itertools.repeat(subj_id, len(masked_fr_vol_stats.median)))
        )
        norm_subj_stats["Resolution"].extend(
            list(
                itertools.repeat(
                    tuple(fr_vol.shape[1:]), len(masked_fr_vol_stats.median)
                )
            )
        )
        norm_subj_stats["Channel Index"].extend(dti_channel_names)
        norm_subj_stats["Mean"].extend(masked_fr_channel_means.cpu().flatten().tolist())
        norm_subj_stats["Var"].extend(masked_fr_channel_vars.cpu().flatten().tolist())
        # Append FR boxplot stats to their corresponding fields. In other words, all
        # columns after the var column.
        for i, field in enumerate(
            summary_stats_header[(summary_stats_header.index("Var") + 1) :]
        ):
            if "Num" in field:
                norm_subj_stats[field].extend(list(map(len, masked_fr_vol_stats[i])))
            else:
                norm_subj_stats[field].extend(masked_fr_vol_stats[i])

        # Add LR stats to the summary table.
        masked_lr_vol = torch.masked_select(lr_vol, lr_mask).reshape(
            lr_vol.shape[0], -1
        )
        masked_lr_channel_means = masked_lr_vol.mean(dim=1)
        masked_lr_channel_means = masked_lr_channel_means.reshape(-1, 1, 1, 1)
        masked_lr_channel_vars = masked_lr_vol.var(dim=1)
        masked_lr_channel_vars = masked_lr_channel_vars.reshape(-1, 1, 1, 1)

        masked_lr_vol_stats = batch_boxplot_stats(masked_lr_vol.detach().cpu().numpy())
        norm_subj_stats["Subj ID"].extend(
            list(itertools.repeat(subj_id, len(masked_lr_vol_stats.median)))
        )
        norm_subj_stats["Resolution"].extend(
            list(
                itertools.repeat(
                    tuple(lr_vol.shape[1:]), len(masked_lr_vol_stats.median)
                )
            )
        )
        norm_subj_stats["Channel Index"].extend(dti_channel_names)
        norm_subj_stats["Mean"].extend(masked_lr_channel_means.cpu().flatten().tolist())
        norm_subj_stats["Var"].extend(masked_lr_channel_vars.cpu().flatten().tolist())
        # Append LR boxplot stats to their corresponding fields. In other words, all columns
        # after the var column.
        for i, field in enumerate(
            summary_stats_header[(summary_stats_header.index("Var") + 1) :]
        ):
            if "Num" in field:
                norm_subj_stats[field].extend(list(map(len, masked_lr_vol_stats[i])))
            else:
                norm_subj_stats[field].extend(masked_lr_vol_stats[i])

        # The boxplot row index follows the subject's position in `subj_dirs`.
        subj_idx = list(subj_dirs.keys()).index(subj_id)

        # Generate row of boxplots for the normalized volumes.
        fig_norm, share_axs_rows, share_axs_cols = plot_dti_box_row(
            fig_norm,
            grid_norm,
            row_idx=subj_idx,
            subj_id=subj_id,
            shared_axs_rows=share_axs_norm["row"],
            shared_axs_cols=share_axs_norm["col"],
            fr_vol=masked_fr_vol.detach().cpu().numpy(),
            lr_vol=masked_lr_vol.detach().cpu().numpy(),
        )
        share_axs_norm["row"] = share_axs_rows
        share_axs_norm["col"] = share_axs_cols

        print("=" * 20)

    print("===Data Normalized===")

    # Only render the table when at least one subject contributed rows.
    if norm_subj_stats["Subj ID"]:
        norm_subj_stats_str = tabulate(
            norm_subj_stats, headers=summary_stats_header, tablefmt=summary_stats_fmt
        )
    else:
        norm_subj_stats_str = ""

In [None]:
# When normalization ran, append the post-normalization summary table to the text log.
if normalize:
    norm_summary_entry = (
        f"Data Summary Statistics, after normalization:\n {norm_subj_stats_str}\n\n"
    )
    with open(log_txt_file, "a+") as log_file:
        log_file.write(norm_summary_entry)

In [None]:
# if normalize:
#     display(fig_norm)

In [None]:
# if normalize:
#     display(Markdown(norm_subj_stats_str))

In [None]:
# Wrap every loaded subject into a single torchio dataset.
subj_dataset = torchio.SubjectsDataset(
    [*subj_data.values()],
    load_getitem=False,
)

## Model Training

### Set Up Patch-Based Data Loaders

In [None]:
# Data train/validation/test split

# Subject counts per split; test and validation counts are rounded up, and training
# receives whatever remains.
num_subjs = len(subj_dataset)
num_test_subjs = int(np.ceil(num_subjs * test_percent))
num_val_subjs = int(np.ceil(num_subjs * val_percent))
num_train_subjs = num_subjs - num_test_subjs - num_val_subjs
# Shuffle the subjects so that the split below is random: the leading subjects go to
# testing, the next ones to validation, and the rest to training.
subj_list = subj_dataset.dry_iter()
random.shuffle(subj_list)

if num_subjs != 1:
    val_end = num_test_subjs + num_val_subjs
    test_subjs = subj_list[:num_test_subjs]
    val_subjs = subj_list[num_test_subjs:val_end]
    train_subjs = subj_list[val_end:]
else:
    # A single subject is treated as a debugging run: it is reused for every split.
    warnings.warn(
        "DEBUG: Only 1 subject selected, mixing training, validation, and testing sets"
    )
    num_train_subjs = num_val_subjs = num_test_subjs = 1
    test_subjs = val_subjs = train_subjs = subj_list

# Test/validation subjects are wrapped as `Addict`s in monai datasets, while the
# training subjects stay in a torchio dataset for patch sampling.
test_dataset = monai.data.Dataset(list(map(Addict, test_subjs)))
val_dataset = monai.data.Dataset(list(map(Addict, val_subjs)))
train_dataset = torchio.SubjectsDataset(train_subjs, load_getitem=False)

# Training patch sampler, random across all patches of all volumes.
# Training targets are the normalized, transformed "fr_dti" volumes, which are *not*
# exactly equivalent to the ground truth; testing and validation targets are "gt_dti"s.
train_sampler = pitn.samplers.MultiresSampler(
    source_img_key="fr_dti",
    low_res_key="lr_dti",
    source_mask_key="fr_brain_mask",
    label_name="fr_brain_mask",
    label_probabilities={0: 0, 1: 1},
    downsample_factor=exp_params.downsample_factor,
    low_res_sample_extension=exp_params.patch.low_res_sample_extension,
    source_spatial_patch_size=output_spatial_patch_shape,
    low_res_spatial_patch_size=input_spatial_patch_shape,
)

patches_per_subj = 8000
# The queue is sized to hold every patch sampled from every training subject.
queue_max_length = patches_per_subj * num_train_subjs

# A torchio.Queue acts as a sampler proxy for the torch DataLoader.
train_queue = torchio.Queue(
    train_dataset,
    sampler=train_sampler,
    samples_per_volume=patches_per_subj,
    max_length=queue_max_length,
    shuffle_subjects=True,
    shuffle_patches=True,
    num_workers=7,
    #     verbose=True,
)

# Partial function that collects a list of samples and forms a tuple of tensors.
collate_fn = functools.partial(
    pitn.samplers.collate_subj_mask,
    full_res_key="fr_dti",
    low_res_key="lr_dti",
    full_res_mask_key="fr_brain_mask",
)

# Worker processes are configured on the queue above, so the loader itself uses none.
train_loader = torch.utils.data.DataLoader(
    train_queue,
    batch_size=exp_params.patch.batch_size,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# # Statically set training patches
# static_train_patches = [train_queue[i] for i in range(train_queue.iterations_per_epoch)]
# static_dataset = torchio.SubjectsDataset(static_train_patches)
# train_loader = torch.utils.data.DataLoader(
#     static_dataset,
#     batch_size=exp_params.patch.batch_size,
#     shuffle=True,
#     collate_fn=collate_fn,
#     pin_memory=True,
#     num_workers=0,
# )

In [None]:
# Set up validation and testing objects.

# Validation samplers/datasets
# The collate functions below bundle locations, masks, and subject ids into batches
# for visualization, reconstruction, analysis, etc. They determine the contents of the
# `batch` parameters in the PytorchLightning system's `test_step` and
# `validation_step` methods.
collate_meta = functools.partial(
    pitn.viz.collate_locs_and_keys,
    full_res_key="gt_dti",
    low_res_key="lr_dti",
    full_res_mask_key="source_mask",
    subj_id="subj_id",
    fr_mean="fr_means",
    fr_var="fr_vars",
)


def collate_keys(samples):
    """Collate whole-volume samples, renaming the DTI and stats keys for the model."""
    return pitn.samplers.collate_dicts(
        samples,
        "subj_id",
        "fr_brain_mask",
        low_res="lr_dti",
        full_res="gt_dti",
        fr_mean="fr_means",
        fr_var="fr_vars",
    )


# Val and test DataLoaders operate on one full volume per batch and share all
# settings except the dataset.
full_volume_loader_kwargs = dict(
    batch_size=1,
    collate_fn=collate_keys,
    pin_memory=True,
    num_workers=4,
)

val_loader = torch.utils.data.DataLoader(val_dataset, **full_volume_loader_kwargs)

# Test samplers/datasets
test_loader = torch.utils.data.DataLoader(test_dataset, **full_volume_loader_kwargs)

print("Training subject ID(s): ", [s.subj_id for s in train_dataset.dry_iter()])
print("Validation subject ID(s): ", [s.subj_id for s in val_dataset])
print("Test subject ID(s): ", [s.subj_id for s in test_dataset])

### Bounding Box Selection for Visualization During Training

In [None]:
# Pick a single random validation subject to visualize during validation, and
# remember its position in the validation dataset for easier indexing.
val_subj_ids = [s.subj_id for s in val_dataset]
val_viz_subj_id = random.choice(val_subj_ids)
val_viz_dataset_idx = val_subj_ids.index(val_viz_subj_id)
print(val_viz_subj_id)

In [None]:
# Create bbox coordinates for visualizing validation aggregate patches.
# NOTE: This presumes that there are no overlap patches in the validation step, and that
# each validation volume is divisible by the patch shape in each spatial dimension.
bbox_coords = list()
# Each visualized region spans three output patches along every spatial axis.
region_size = torch.as_tensor(output_spatial_patch_shape) * 3
vol_shape = torch.as_tensor(val_dataset[val_viz_dataset_idx].fr_dti.data.shape[1:])
# Candidate bbox start indices per spatial axis, stepping by the output patch size.
possible_bbox_ini = [
    torch.arange(0, vol_shape[axis], output_spatial_patch_shape[axis])
    for axis in range(3)
]

# Create bbox that spans roughly the center of the volume: along every axis, take
# the candidate start that sits 40% of the way through the list of candidates.
bbox_idx_ini = torch.as_tensor(
    [axis_starts[round(len(axis_starts) * 0.4)] for axis_starts in possible_bbox_ini]
)

# A bbox is stored as its start indices followed by its end indices.
bbox_coord = torch.cat([bbox_idx_ini, bbox_idx_ini + region_size])
bbox_coords.append(bbox_coord)

In [None]:
# Create bbox that covers an edge.
# Every candidate patch position is scored by how much of the brain mask falls
# inside the patch's middle slice along the second spatial axis; the position whose
# coverage is closest to a partial-coverage target is chosen, i.e. a patch that
# straddles the mask boundary.
vol_mask = val_dataset[val_viz_dataset_idx].fr_brain_mask.data[0]
bbox_mask_coverage = list()
possible_bbox_coords_ini = list(itertools.product(*possible_bbox_ini))
for bbox_w, bbox_h, bbox_d in possible_bbox_coords_ini:
    bbox_start = (bbox_w, bbox_h, bbox_d)
    bbox_end = (
        bbox_start[0] + output_spatial_patch_shape[0],
        bbox_start[1] + output_spatial_patch_shape[1],
        bbox_start[2] + output_spatial_patch_shape[2],
    )
    patch = vol_mask[
        bbox_start[0] : bbox_end[0],
        bbox_start[1] : bbox_end[1],
        bbox_start[2] : bbox_end[2],
    ]
    # Validation visuals are shown over the Superior-Anterior axis, so the edge should be
    # located there.
    bbox_mask_coverage.append(
        patch[:, output_spatial_patch_shape[1] // 2, :].sum().item()
    )

# Now search all coordinates' mask coverages for the one closest to the target
# coverage of the middle slice. `patch_vol` is the area of that slice and
# `target_mask_vol` is a quarter of it. Comparing against the full `patch_vol`
# would select a fully covered (interior) patch instead of an edge patch.
# NOTE(review): an earlier comment stated a 50% target while the value here is 25%
# of the slice area -- confirm the intended fraction.
patch_vol = output_spatial_patch_shape[0] * output_spatial_patch_shape[2]
target_mask_vol = patch_vol // (2**2)

bbox_coord_idx = np.argmin(
    np.abs(np.asarray(bbox_mask_coverage) - target_mask_vol)
)
bbox_idx_ini = torch.tensor(possible_bbox_coords_ini[bbox_coord_idx])
bbox_coord = torch.cat([bbox_idx_ini, bbox_idx_ini + region_size])
bbox_coords.append(bbox_coord)

In [None]:
# Combine the list of bbox coordinate tensors into a single tensor, one row per bbox.
bbox_coords = torch.stack(bbox_coords)

# Bare expression so the notebook displays the resulting tensor.
bbox_coords

In [None]:
# Record which subjects landed in each split, both in the text log and in the
# experiment logger.
train_subj_ids = [s.subj_id for s in train_dataset.dry_iter()]
val_subj_ids = [s.subj_id for s in val_dataset]
test_subj_ids = [s.subj_id for s in test_dataset]

with open(log_txt_file, "a+") as log_file:
    log_file.write(f"Training Set Subjects: {train_subj_ids}\n")
    log_file.write(f"Validation Set Subjects: {val_subj_ids}\n")
    log_file.write(f"Test Set Subjects: {test_subj_ids}\n")

for tag, split_ids in (
    ("train_subjs", train_subj_ids),
    ("val_subjs", val_subj_ids),
    ("test_subjs", test_subj_ids),
):
    logger.add_text(tag, str(split_ids))

### Model Definition

In [None]:
# Full pytorch-lightning module for contained training, validation, and testing.
# `debug_prob` is the negative reciprocal of (total training patches / batch size).
# NOTE(review): a negative value presumably disables probability-gated debugging --
# confirm against where `debug_prob` is consumed.
train_batches_per_epoch = (
    patches_per_subj * num_train_subjs / exp_params.patch.batch_size
)
debug_prob = -1 / train_batches_per_epoch


class DIQTSystem(pl.LightningModule):
    """Lightning module bundling the DIQT network with its train/val/test logic.

    Training consumes batches of ``(x, y, masks)`` and optimizes the selected training
    loss over the masked voxels. Validation and testing consume one sample at a time
    (a batch size of 1 is assumed), de-normalize the prediction, and report the RMSE
    over the masked voxels. Every loss value is additionally stored in
    ``self.plain_log`` so that it can be inspected after training.
    """

    # Specify training loss methods with mappings to their names as strings.
    loss_methods = {
        "MSE".casefold(): torch.nn.MSELoss(reduction="mean"),
        "SSE".casefold(): torch.nn.MSELoss(reduction="sum"),
        "RMSE".casefold(): lambda y_hat, y: torch.sqrt(
            F.mse_loss(y_hat, y, reduction="mean")
        ),
        "L1".casefold(): torch.nn.L1Loss(reduction="mean"),
    }

    def __init__(
        self,
        channels,
        downsample_factor,
        source_vox_size: float,
        target_vox_size: float,
        train_loss_method: str,
        opt_params: dict,
        norm_method=None,
        val_viz_bboxes=None,
        val_viz_subj_id=None,
        #         val_patch_overlap=(0, 0, 0),
        val_viz_every_n_epochs=1,
    ):
        """Build the network and set up loss selection and logging state.

        Args:
            channels: Forwarded to ``pitn.nn.models.FractDownReduceBy5Conv``.
            downsample_factor: Forwarded to ``pitn.nn.models.FractDownReduceBy5Conv``.
            source_vox_size: Accepted for interface compatibility; not used here.
            target_vox_size: Accepted for interface compatibility; not used here.
            train_loss_method: Either a (case-insensitive) key of ``loss_methods`` or
                a callable taking ``(prediction, target)``.
            opt_params: Keyword arguments handed to ``torch.optim.Adam``.
            norm_method: Accepted for interface compatibility; not used here.
            val_viz_bboxes: Regions to visualize during validation, one row per
                region holding 3 start indices followed by 3 end indices. When
                ``None``, no region is visualized.
            val_viz_subj_id: ID of the single validation subject to visualize. It is
                ignored when ``val_viz_bboxes`` is ``None``.
            val_viz_every_n_epochs: Epoch interval between validation visualizations.

        Raises:
            ValueError: If ``train_loss_method`` is neither a known loss name nor a
                callable.
        """
        super().__init__()

        self._channels = channels
        self._downsample_factor = downsample_factor

        # Parameters
        # Network parameters
        self.net = pitn.nn.models.FractDownReduceBy5Conv(
            self._channels, self._downsample_factor
        )
        #         self.net = pitn.nn.models.DebugFC(
        #             input_patch_shape, output_patch_shape
        #         )
        self._norm_eps = 1e-10
        ## Training parameters
        self.opt_params = opt_params

        # Select loss method as either one of the pre-selected methods, or a custom
        # callable. A non-string argument has no `casefold` (AttributeError) and an
        # unknown name misses the table (KeyError); both fall back to the callable check.
        try:
            self._loss_fn = self.loss_methods[train_loss_method.casefold()]
        except (AttributeError, KeyError):
            if callable(train_loss_method):
                self._loss_fn = train_loss_method
            else:
                raise ValueError(
                    f"ERROR: Invalid loss function specification {train_loss_method}, "
                    + f"expected one of {self.loss_methods.keys()} or a callable."
                )

        # Sub-regions of the volume that should be logged in validation.
        if val_viz_bboxes is None:
            self.val_bboxes = torch.zeros(0, 6)
            self._val_subj_id = None
        else:
            self.val_bboxes = val_viz_bboxes
            self._val_subj_id = val_viz_subj_id
        #         self.val_patch_overlap = val_patch_overlap

        # Store the validation set min and max for each bounding box, to keep a consistent
        # color scale on the color bar.
        n_bboxes = self.val_bboxes.shape[0]
        self.val_vmin = [None] * n_bboxes
        self.val_vmax = [None] * n_bboxes
        self.val_viz_every_n_epochs = val_viz_every_n_epochs
        self._last_val_viz_epoch = -1

        # My own dinky logging object.
        self.plain_log = Addict(
            {
                "train_loss": list(),
                "val_loss": list(),
                "test_loss": dict(),
                "spline_loss": list(),
                "viz": Addict(
                    {
                        "test_preds": Addict(),
                        "test_squared_error": Addict(),
                    }
                ),
            }
        )

    def forward(self, x):
        """Run the wrapped network on ``x`` with the network's default arguments."""
        y = self.net(x)
        return y

    def training_step(self, batch, batch_idx):
        """Compute the training loss over the masked voxels of one batch.

        Args:
            batch: Tuple of ``(x, y, masks)``; ``masks`` is converted to boolean and
                selects the voxels that contribute to the loss.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor produced by the selected training loss function.
        """
        x, y, masks = batch
        masks = masks.bool()

        y_pred = self.net(x, norm_output=False)

        loss = self._loss_fn(
            torch.masked_select(y_pred, masks), torch.masked_select(y, masks)
        )

        # Random debug scheduling, driven by the module-level `debug_prob`.
        if random.random() <= debug_prob:
            self._show_debug_batch(x, y, y_pred, masks)

        self.log("train_loss", loss)
        self.plain_log["train_loss"].append(float(loss.cpu()))
        return loss

    def _show_debug_batch(self, x, y, y_pred, masks):
        """Display middle slices of the inputs, targets, predictions, and masks."""
        # `skimage.exposure` is not among the notebook's explicit imports, so import it
        # here instead of relying on another submodule having loaded it.
        from skimage import exposure

        images = list()
        tab_labels = list()
        image_labels = list()
        channel_labels = [
            "D x,x",
            "D x,y",
            "D y,y",
            "D x,z",
            "D y,z",
            "D z,z",
        ]

        def add_image(tab, vol, slice_i, label, rescale=True):
            # ipyplot will use PIL.Image when given a numpy array to display, so
            # float ndarrays must be scaled between 0 and 1. The scaling uses the
            # whole volume, and only afterwards is the 2D slice taken.
            if rescale:
                vol = exposure.rescale_intensity(vol, out_range=(0.0, 1.0))
            images.append(np.rot90(vol[:, slice_i, :]))
            tab_labels.append(str(tab))
            image_labels.append(label)

        # Construct the images to display, their names, and their batch index
        # as the "class" tab label.
        # ipyplot only takes flat lists as parameters, so no fancy multi-dimensional
        # lists or anything like that.
        # Each batch is its own "class" (UI tab).
        for batch_i in range(x.shape[0]):
            # Create a new set of images for every channel.
            for channel_i, channel_name in enumerate(channel_labels):
                input_vol = x[batch_i, channel_i].detach().cpu().numpy()
                add_image(
                    batch_i,
                    input_vol,
                    input_vol.shape[1] // 2,
                    f"Input {channel_name}",
                )

                target_vol = y[batch_i, channel_i].detach().cpu().numpy()
                hr_slice_idx = target_vol.shape[1] // 2
                add_image(batch_i, target_vol, hr_slice_idx, f"Target {channel_name}")

                # The prediction is sliced at the same index as the target.
                pred_vol = y_pred[batch_i, channel_i].detach().cpu().numpy()
                add_image(
                    batch_i, pred_vol, hr_slice_idx, f"Prediction {channel_name}"
                )

            # Collect the mask image (only 1 per batch).
            mask_vol = masks[batch_i][0].detach().cpu().numpy().astype(bool)
            add_image(batch_i, mask_vol, hr_slice_idx, "Mask", rescale=False)

        ipyplot.plot_class_tabs(
            images, tab_labels, custom_texts=image_labels, show_url=False
        )

    def log_bbox_figure(
        self,
        fr,
        pred,
        abs_error,
        vmin: float,
        vmax: float,
        fig_name: str,
        slice_idx=(0, slice(None), slice(None), 86, slice(None)),
    ):
        """Log a figure comparing target, prediction, and absolute error.

        Args:
            fr: Target (full-resolution) region.
            pred: Predicted region.
            abs_error: Absolute error region.
            vmin: Lower limit of the shared color scale.
            vmax: Upper limit of the shared color scale.
            fig_name: Tag under which the figure is added to the experiment logger.
            slice_idx: Index applied to each of the three regions; every element of
                the indexed result is rotated by 90 degrees and placed in a grid row.
        """
        # Slice into full volume and just grab a B x C x H x W slice.
        rows = [
            pitn.viz.make_grid(
                [torch.rot90(t) for t in vol[slice_idx]],
                padding=2,
                pad_value=-2,
                nrow=1,
            )
            for vol in (fr, pred, abs_error)
        ]
        reg_grid = pitn.viz.make_grid(rows, nrow=3, pad_value=-2)

        fig = plt.figure(figsize=(8, 4.5), dpi=110, clear=True)

        plt.imshow(
            reg_grid.cpu().numpy(),
            vmin=vmin,
            vmax=vmax,
            cmap="jet",
        )
        plt.xticks([], [])
        plt.yticks([], [])
        plt.colorbar(location="bottom")
        self.logger.experiment.add_figure(fig_name, fig, self.global_step)

    def _predict_full_volume(self, batch):
        """Predict one whole-volume sample and de-normalize the prediction.

        Shared by ``validation_step`` and ``test_step``.
        NOTE: This assumes a batch size of 1!

        Returns:
            Tuple of ``(y, y_mask, y_pred, subj_id)``.
        """
        # The samples include extra metadata regarding the patches.
        batch = Addict(batch)
        x = batch.low_res.data
        y = batch.full_res.data
        y_mask = batch.fr_brain_mask.data.bool()
        fr_mean = torch.as_tensor(batch.fr_mean).to(x)
        fr_var = torch.as_tensor(batch.fr_var).to(fr_mean)
        subj_id = batch.subj_id[0]

        y_pred = self.net(
            x,
            norm_output=False,
            pad_reduced_shape=True,
            interp_to_spatial_shape=y.shape[-3:],
        )

        # Denormalize the predictions so the RMSE is calculated on intensity values in
        # the original data's scale. De-normalization may only be skipped when it is
        # the identity, i.e. when the mean is all 0 AND the variance is all 1.
        is_identity_norm = bool((fr_mean == 0).all()) and bool((fr_var == 1).all())
        if not is_identity_norm:
            y_pred = pitn.data.norm.denormalize_batch(
                y_pred, mean=fr_mean, var=fr_var, eps=self._norm_eps
            )

        return y, y_mask, y_pred, subj_id

    @staticmethod
    def _masked_rmse(y_pred, y, mask):
        """Return the RMSE between ``y_pred`` and ``y`` over the voxels in ``mask``."""
        return torch.sqrt(
            F.mse_loss(
                torch.masked_select(y_pred, mask),
                torch.masked_select(y, mask),
                reduction="mean",
            )
        )

    def validation_step(self, batch: dict, batch_idx):
        """Compute the masked RMSE of one validation sample and optionally plot it.

        Region figures are only created for the subject selected at construction, and
        only on epochs that satisfy the visualization interval.

        Args:
            batch: Mapping with ``low_res``, ``full_res``, ``fr_brain_mask``,
                ``fr_mean``, ``fr_var``, and ``subj_id`` entries.
            batch_idx: Index of the batch (unused).

        Returns:
            The masked RMSE tensor.
        """
        y, y_mask, y_pred, subj_id = self._predict_full_volume(batch)
        loss = self._masked_rmse(y_pred, y, y_mask)

        self.log("val_loss", loss)
        self.plain_log["val_loss"].append(float(loss.cpu()))

        # Create validation set visualizations, if necessary.
        # Only create viz for one subject in the validation set.
        if subj_id != self._val_subj_id:
            return loss

        # Only create a validation visualization if the current global epoch has reached a
        # threshold.
        viz_due = (
            self.current_epoch % self.val_viz_every_n_epochs == 0
            or (self.current_epoch - self._last_val_viz_epoch)
            > self.val_viz_every_n_epochs
        )
        if not viz_due:
            return loss
        self._last_val_viz_epoch = self.current_epoch

        # The whole-volume absolute error does not depend on the region, so compute it
        # once instead of once per bounding box.
        abs_error = torch.sqrt(
            F.mse_loss(y_pred * y_mask, y * y_mask, reduction="none")
        ).detach()

        # Create a plot for each of the given bounding boxes/regions of interest.
        for i_bbox, bbox_range in enumerate(self.val_bboxes):
            # Grab the range of the bbox region and create the index object.
            bbox_range = bbox_range.detach().cpu().int()
            b_start = bbox_range[:3].tolist()
            b_end = bbox_range[3:].tolist()
            region = (
                ...,
                slice(b_start[0], b_end[0]),
                slice(b_start[1], b_end[1]),
                slice(b_start[2], b_end[2]),
            )

            # Select the target FR region, the predicted region, and the absolute
            # error between the two.
            fr_patch = y[region].detach().cpu()
            pred_patch = y_pred[region].detach().cpu()
            abs_error_patch = abs_error[region].cpu()

            # Choose which 2D plane to plot.
            slice_idx = (
                0,
                slice(None),
                fr_patch.shape[-3] // 2,
                slice(None),
                slice(None),
            )

            # Select vmin and vmax values to maintain throughout training.
            if self.val_vmin[i_bbox] is None:
                vmin = np.quantile(fr_patch[slice_idx].numpy().flatten(), 0.05)
                self.val_vmin[i_bbox] = vmin
            if self.val_vmax[i_bbox] is None:
                vmax = np.quantile(fr_patch[slice_idx].numpy().flatten(), 0.95)
                self.val_vmax[i_bbox] = vmax

            self.log_bbox_figure(
                fr_patch,
                pred_patch,
                abs_error_patch,
                vmin=self.val_vmin[i_bbox],
                vmax=self.val_vmax[i_bbox],
                fig_name=f"val_region_{i_bbox}",
                slice_idx=slice_idx,
            )

        return loss

    def test_step(self, batch: dict, batch_idx):
        """Compute the masked RMSE of one test sample and keep its prediction.

        The loss is stored in ``plain_log["test_loss"]`` and the first prediction of
        the batch in ``plain_log.viz.test_preds``, both keyed by subject ID.

        Args:
            batch: Mapping with the same entries as in ``validation_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            The masked RMSE tensor.
        """
        y, y_mask, y_pred, subj_id = self._predict_full_volume(batch)
        loss = self._masked_rmse(y_pred, y, y_mask)

        self.log("test_loss", loss)
        self.plain_log["test_loss"][subj_id] = loss.detach().cpu().item()

        self.plain_log.viz.test_preds[subj_id] = y_pred[0].detach().cpu()

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the network parameters, set by ``opt_params``."""
        optimizer = torch.optim.Adam(self.net.parameters(), **self.opt_params)
        return optimizer

### Training Loop

In [None]:
# Timestamp (truncated to whole seconds) taken right before the model is built; used to
# report the total training duration.
train_start_timestamp = datetime.datetime.now().replace(microsecond=0)

model = DIQTSystem(
    channels=channels,
    downsample_factor=exp_params.downsample_factor,
    source_vox_size=exp_params.source_vox_size,
    target_vox_size=exp_params.target_vox_size,
    train_loss_method=train_loss_name,
    opt_params=opt_params,
    norm_method=network_norm_method,
    val_viz_bboxes=bbox_coords,
    val_viz_subj_id=val_viz_subj_id,
    val_viz_every_n_epochs=5,
)

# Keep a textual description of the model next to the other experiment notes.
with open(log_txt_file, "a+") as f:
    f.write(f"Model overview: {model}\n")

# Create trainer object.
trainer_kwargs = dict(
    #     fast_dev_run=10,
    gpus=1,
    max_epochs=max_epochs,
    logger=pl_logger,
    log_every_n_steps=50,
    check_val_every_n_epoch=1,
    progress_bar_refresh_rate=10,
    terminate_on_nan=True,
)
trainer = pl.Trainer(**trainer_kwargs)

# Fitting emits a large number of warnings; they are all silenced, but only for the
# duration of the fit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

train_end_timestamp = datetime.datetime.now().replace(microsecond=0)
train_duration = train_end_timestamp - train_start_timestamp
print(f"Train duration: {train_duration}")

In [None]:
# debug
# %autoreload

In [None]:
# Write the trained model's checkpoint into the experiment results directory.
model_ckpt_file = experiment_results_dir / "model.ckpt"
trainer.save_checkpoint(str(model_ckpt_file))

In [None]:
# Append the training duration to the text log, first in `timedelta`'s own format and
# then broken down into days, hours, minutes, and seconds.
# `timedelta.seconds` only holds the part of the duration below one day.
duration_minutes, duration_secs = divmod(train_duration.seconds, 60)
duration_hours, duration_minutes = divmod(duration_minutes, 60)
with open(log_txt_file, "a+") as f:
    f.write("\n")
    f.write(f"Training time: {train_duration}\n")
    f.write(
        f"\t{train_duration.days} Days, "
        + f"{duration_hours} Hours, "
        + f"{duration_minutes} Minutes, "
        + f"{duration_secs} Seconds\n"
    )

In [None]:
# Plot rolling average window of training loss values.
# One loss value is logged per training step, so the x axis counts steps, not epochs.
train_losses = np.asarray(model.plain_log["train_loss"])
plt.figure(dpi=110)
# Clamp the window to the number of logged losses. With a window longer than the data,
# a "valid" convolution swaps the roles of its inputs and no longer yields a rolling
# mean of the losses.
window = max(1, min(1000, len(train_losses)))
rolling_mean = np.convolve(train_losses, np.ones(window), "valid") / window
# Leave the first rolling values out of the plot, unless that would leave nothing to plot.
rolling_start = 100 if len(rolling_mean) > 100 else 0
plt.plot(
    np.arange(
        window + rolling_start,
        window + rolling_start + len(rolling_mean[rolling_start:]),
    ),
    rolling_mean[rolling_start:],
)
plt.title("Training Loss " + train_loss_name + f"\nRolling Mean {window}")
plt.xlabel("Training Step")
plt.ylabel("Loss")
# plt.ylim(0, 1)
print(np.median(rolling_mean))
# Summary statistics of the losses that precede the plotted range.
head_losses = train_losses[: window + rolling_start]
print(
    np.mean(head_losses),
    np.var(head_losses),
    np.max(head_losses),
)

plt.savefig(experiment_results_dir / "train_loss.png")

## Model Testing/Evaluation

### Testing Loop

In [None]:
# Store test reconstructions along the way for later visualization.
test_vol_viz = Addict()

# Structure is as follows:
# {subj_id_1: {
#     fr_mask: np.ndarray,
#     dti: {
#         diqt: np.ndarray,
#         spline: np.ndarray,
#         fr: np.ndarray,
#         lr: np.ndarray,
#         ...
#     },
#     fa: {
#         diqt: np.ndarray,
#         ...
#     },
#  },
#  subj_id_2: {
#     ...
#  },
# }

In [None]:
# # Debug code for model editing after training has been completed.

# # If the model needs to be created, a.k.a. the training loop cell wasn't executed.
# model = DIQTSystem(
#     channels=channels,
#     downsample_factor=exp_params.downsample_factor,
#     source_vox_size=exp_params.source_vox_size,
#     target_vox_size=exp_params.target_vox_size,
#     norm_method=network_norm_method,
#     train_loss_method=train_loss_name,
#     opt_params=opt_params,
#     val_viz_subj_id=val_viz_subj_id,
#     val_viz_bboxes=bbox_coords,
#     val_viz_every_n_epochs=5,
# )

# model.load_from_checkpoint(
#     experiment_results_dir / "model.ckpt",
#     channels=channels,
#     downsample_factor=exp_params.downsample_factor,
#     source_vox_size=exp_params.source_vox_size,
#     target_vox_size=exp_params.target_vox_size,
#     norm_method=network_norm_method,
#     train_loss_method=train_loss_name,
#     opt_params=opt_params,
#     val_viz_subj_id=val_viz_subj_id,
#     val_viz_bboxes=bbox_coords,
#     val_viz_every_n_epochs=5,
# )

# trainer.test(model, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Evaluate the in-memory model (no checkpoint is loaded) on the test set.
# `dataloaders` is the current keyword for this argument; `test_dataloaders` is its
# deprecated alias. This also matches the `*_dataloaders` keywords used for
# `trainer.fit` in this notebook.
trainer.test(model, dataloaders=test_loader, ckpt_path=None, verbose=True)

In [None]:
# Copy the network's test predictions (kept by the model during `test_step`) into the
# visualization container as numpy arrays, keyed by subject ID.
for subj in test_dataset:
    subj_id = subj.subj_id
    diqt_pred = model.plain_log.viz.test_preds[subj_id]
    test_vol_viz[subj_id].dti.diqt = diqt_pred.cpu().numpy()

In [None]:
# Display the subject IDs that currently have stored test reconstructions.
test_vol_viz.keys()

In [None]:
# Name of the metric reported for the test set, and the CSV file receiving the values.
test_loss_name = "RMSE"
test_loss_log_file = experiment_results_dir / "test_loss.csv"

with open(log_txt_file, "a+") as f:
    f.write(f"Test loss function: {test_loss_name}\n")

# One CSV row per stored test loss. The first column is the position of the loss in the
# log (insertion order), not the subject ID.
csv_rows = [
    f"{batch_idx}, {loss}\n"
    for batch_idx, loss in enumerate(model.plain_log["test_loss"].values())
]
test_loss_tabular = "".join(csv_rows)

# The file is opened for appending, so the header line is written on every run.
with open(test_loss_log_file, "a+") as f:
    f.write("batch_idx, loss\n")
    f.write(test_loss_tabular)

### Spline Baseline

In [None]:
# Spline-interpolated reconstruction of every test subject, keyed by subject ID.
spline_test_log = dict()

for subj in test_dataset:
    print("---")
    subj_id = subj["subj_id"]

    # Spatial shapes (everything after the leading channel axis) of the full-res target
    # and of the low-res input.
    target_shape = np.asarray(tuple(subj["gt_dti"]["data"].shape[1:]))
    lr_shape = np.asarray(tuple(subj["lr_dti"]["data"].shape[1:]))
    # Fractional downsampling can't have a perfect correspondance between source shape and
    # target shape, so calculate the "actual" downsample done by the NN and have the
    # spline match that downsample factor.
    downsample_adjusted = target_shape / lr_shape

    # The leading (channel) axis keeps its size; only the spatial axes are zoomed.
    lr_source = subj["source_lr_dti"]["data"].cpu().numpy()
    interp_spline = scipy.ndimage.zoom(
        lr_source,
        zoom=(1,) + tuple(downsample_adjusted),
        order=exp_params.spline.order,
    )
    # 0-out everything not in the mask, for both visualization and quantification.
    fr_mask_np = subj["fr_brain_mask"]["data"].bool().cpu().numpy()
    interp_spline = interp_spline * fr_mask_np

    print(f"Subj {subj_id} done")

    spline_test_log[subj_id] = interp_spline

In [None]:
# Per-subject RMSE of the spline baseline, in test-set iteration order.
spline_loss = list()

# Calculate spline loss for test images.
# NOTE(review): no de-normalization is applied to the spline prediction; presumably the
# `source_lr_dti` volume it was interpolated from is already in the original intensity
# scale -- confirm.
for subj in test_dataset:
    subj_id = subj["subj_id"]
    spline_pred = spline_test_log[subj_id]

    gt = subj["gt_dti"]["data"].detach().cpu().numpy()
    # Drop the mask's leading axis so that it indexes the spatial axes only.
    brain_mask = subj["fr_brain_mask"]["data"].bool().cpu().numpy()[0]

    # Calculate the RMSE of just the values found in the mask.
    se = (gt - spline_pred) ** 2
    se = se[:, brain_mask]
    loss = np.sqrt(se.mean())
    spline_loss.append(loss)

    # Store spline test reconstructions for visualization.
    test_vol_viz[subj_id].dti.spline = spline_pred.copy()

# Find the grand mean of the spline RMSE's
spline_loss_mean = np.mean(spline_loss)
print(spline_loss_mean)

### Evaluation Visualization

In [None]:
# Histogram of the per-subject test losses of the current model and of the spline
# baseline, with vertical lines marking mean scores of the compared models.
fig, ax_prob = plt.subplots(figsize=(8, 4), dpi=120)
log_scale = False
# Keyword arguments shared by both histograms.
hist_kwargs = {
    "alpha": 0.5,
    "stat": "count",
    "log_scale": log_scale,
    "ax": ax_prob,
    "legend": False,
}

hist = sns.histplot(
    list(model.plain_log["test_loss"].values()),
    hatch="\\\\",
    ec="blue",
    **hist_kwargs,
)
hist.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
plt.xlabel("Loss in $mm^2/second$")

# Draw means of different comparison models.
comparison_kwargs = {"ls": "-", "lw": 2.5}
# Plot the current DNN model performance.
model_loss_mean = np.asarray(list(model.plain_log["test_loss"].values())).mean()
plt.axvline(
    model_loss_mean, label="Current Model Mean", color="blue", **comparison_kwargs
)
# Our spline mean performance.
plt.axvline(
    spline_loss_mean,
    label=f"(Ours) Spline Mean Order {exp_params.spline.order}",
    color="black",
    **comparison_kwargs,
)
sns.histplot(spline_loss, color="black", hatch="//", **hist_kwargs)

# Scores reported in the literature, as (value, legend label, line color).
# NOTE(review): the Blumberg value here (12.78e-4) differs from the one used in the bar
# chart of this notebook (12.13e-4) -- confirm which one is correct.
published_scores = (
    # Tanno, et. al., 2021 model comparisons.
    # Taken from Table 2, HCP exterior.
    (31.738e-4, "(Tanno etal, 2021) C-spline Mean", "red"),
    (23.139e-4, "(Tanno etal, 2021) RF", "orange"),
    (13.609e-4, "(Tanno etal, 2021) ESPCN Baseline", "green"),
    (13.412e-4, "(Tanno etal, 2021) Best", "purple"),
    # Best performing Blumberg, et. al., 2018 paper model.
    (12.78e-4, "(Blumberg etal, 2018) Best", "pink"),
)
for ref_score, ref_label, ref_color in published_scores:
    plt.axvline(ref_score, label=ref_label, color=ref_color, **comparison_kwargs)

plt.legend(fontsize="small")
plt.title(f"Test Loss Histogram Over All Subjects with Test Metric {test_loss_name}")
plt.savefig(experiment_results_dir / "test_loss_hist.png")

In [None]:
# Bar chart comparing the mean test loss of the current model with the spline baseline
# and with scores reported in the literature.
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)

# One entry per bar: (tick label, RMSE score, standard error drawn as the error bar).
# NOTE(review): the Blumberg score here (12.13e-4) differs from the value drawn in the
# histogram of this notebook (12.78e-4) -- confirm which one is correct.
bar_entries = (
    (
        "(Ours)\nCurrent Model",
        np.asarray(list(model.plain_log["test_loss"].values())).mean(),
        0,
    ),
    (f"(Ours)\nSpline Order {exp_params.spline.order}", spline_loss_mean, 0),
    ("(Tanno etal, 2021)\nC-spline Mean", 31.738e-4, 0),
    ("(Tanno etal, 2021)\nRF", 23.139e-4, 0.351e-4),
    ("(Tanno etal, 2021)\nESPCN Baseline", 13.609e-4, 0.084e-4),
    ("(Tanno etal, 2021)\nBest", 13.412e-4, 0.041e-4),
    ("(Blumberg etal, 2018)\nBest", 12.13e-4, 1.24e-4),
)
models, rmse_scores, rmse_std_error = zip(*bar_entries)
rmse_std_error = np.asarray(rmse_std_error)

ax.grid(True, axis="y", zorder=1000)
ax.set_axisbelow(True)
ax.bar(
    models,
    rmse_scores,
    yerr=rmse_std_error,
    color=sns.color_palette("deep", n_colors=len(rmse_scores)),
    edgecolor="black",
    lw=0.75,
)

# Print the score on top of each bar; containers that are not bar containers (such as
# the error bars) are left unlabeled.
bar_containers = [
    c for c in ax.containers if isinstance(c, mpl.container.BarContainer)
]
for container in bar_containers:
    ax.bar_label(container, fmt="%.3e")

# Add head room above the tallest bar.
y_top = ax.get_ylim()[1] * 1.1
ax.set_ylim(bottom=0, top=y_top)
ax.set_xlabel("Model")
ax.set_ylabel("Loss in $mm^2/second$")

ax.set_title(f"Mean Over Subjects Test Loss {test_loss_name}")
ax.set_xticks(models)
ax.set_xticklabels(
    models, fontsize="x-small", rotation=25, ha="right", rotation_mode="anchor"
)
plt.savefig(experiment_results_dir / "test_loss_bar.png")

In [None]:
# Print the per-subject test losses, followed by their mean.
print(model.plain_log["test_loss"])
print(np.mean(list(model.plain_log["test_loss"].values())))

In [None]:
# Re-order the per-subject test losses from the lowest to the highest loss, so that the
# last entry belongs to the worst performing subject.
test_loss_items = list(model.plain_log["test_loss"].items())
sorted_test_idx = np.argsort(np.asarray([loss for _, loss in test_loss_items]))
sorted_test_results = {
    test_loss_items[i][0]: test_loss_items[i][1] for i in sorted_test_idx
}
# Keep the sorted-by-loss order when printing.
ppr(sorted_test_results, sort_dicts=False)

In [None]:
# Print the per-subject spline RMSE values, followed by their mean.
print(spline_loss)
print(spline_loss_mean)

In [None]:
# Log the distribution of the per-subject test losses to the experiment logger.
logger.add_histogram(
    "test/rmse_dist", np.asarray(list(model.plain_log["test_loss"].values()))
)

### Save out metrics and Hyperparameters

In [None]:
# Mean of the per-subject test losses; reported to the hyperparameter comparison, the
# text log, and the experiment logger, together with the spline baseline's mean.
mean_test_rmse = np.mean(list(model.plain_log["test_loss"].values()))

compare_hparams.metric["hparam/rmse"] = mean_test_rmse
with open(log_txt_file, "a+") as f:
    f.write(f"Mean RMSE Testing Value: {mean_test_rmse}\n")
    f.write(f"Mean RMSE Spline Value: {spline_loss_mean}\n")

logger.add_scalar("metric/rmse", mean_test_rmse)
logger.add_scalar("metric/spline", spline_loss_mean)

## Whole-Volume Visualization

### Setup

In [None]:
# Debug flag(s)
# When True, the cells below skip writing figures and predicted volumes to disk.
disable_fig_save = False

In [None]:
# Create full 3D volume of full-res ground truth, low-res downsample, full-res mask, and
# full-res predictions.

with torch.no_grad():

    for subj_i in test_dataset:
        # Shallow copy, so that re-binding entries here leaves the dataset sample alone.
        subj = copy.copy(subj_i)
        subj_id = subj["subj_id"]
        print(f"Starting subject {subj_id}")

        # Collect all variants of the volume and aggregate into one container object.
        fr_vol = subj["gt_dti"]["data"].clone()
        lr_vol = subj["lr_dti"]["data"].clone()

        full_res_predicted = torch.from_numpy(test_vol_viz[subj_id].dti.diqt.copy())
        full_res_spline = torch.from_numpy(test_vol_viz[subj_id].dti.spline.copy())

        # Only the low-res input is de-normalized here, and only when a "channel"
        # normalization method is in use.
        if data_norm_method is not None and "channel" in data_norm_method.casefold():
            print("De-normalizing")

            lr_means = torch.as_tensor(subj["lr_means"]).to(lr_vol).clone()
            lr_vars = torch.as_tensor(subj["lr_vars"]).to(lr_vol).clone()

            lr_vol = pitn.data.norm.denormalize_dti(
                lr_vol, mean=lr_means, var=lr_vars
            ).clone()

        # Zero-out all voxels outside the mask.
        fr_mask = subj["fr_brain_mask"]["data"].bool().clone()
        full_res_spline = full_res_spline * fr_mask.to(full_res_spline).bool()
        fr_vol = fr_vol * fr_mask.to(fr_vol).bool()
        full_res_predicted = full_res_predicted * fr_mask.to(full_res_predicted).bool()
        lr_vol = lr_vol * subj["lr_brain_mask"]["data"].to(lr_vol).bool()
        abs_error = torch.abs(full_res_predicted - fr_vol)

        test_vol_viz[subj_id].dti.update(
            fr=fr_vol.cpu().numpy(),
            lr=lr_vol.cpu().numpy(),
            diqt=full_res_predicted.cpu().numpy(),
            spline=full_res_spline.cpu().numpy(),
        )

        # One FA map per DTI variant stored so far. The absolute error is only added
        # to `dti` afterwards, so it does not get an FA map from this pass.
        test_vol_viz[subj_id].fa.update(
            {
                variant: pitn.viz.fa_map(dti_vol)
                for variant, dti_vol in test_vol_viz[subj_id].dti.items()
            }
        )
        test_vol_viz[subj_id].fr_mask = fr_mask.cpu().numpy()
        test_vol_viz[subj_id].dti.abs_error = abs_error.cpu().numpy()

        print(f"Finished subject {subj_id}")

In [None]:
# Write every subject's DIQT prediction to a Nifti2 file, then gather the files into one
# zip archive and delete the individual files.
if not disable_fig_save:
    img_names = list()
    for subj_id, viz in test_vol_viz.items():
        filename = experiment_results_dir / f"{subj_id}_predicted_dti.nii.gz"

        pred_vol = viz.dti.diqt
        # The prediction is saved with the affine of the subject's ground truth DTI.
        affine = subj_data[subj_id]["gt_dti"].affine
        nib.save(nib.Nifti2Image(pred_vol, affine), str(filename))
        img_names.append(filename)

    archive_file = experiment_results_dir / "predicted_dti.zip"
    with zipfile.ZipFile(archive_file, "w") as fzip:
        for filename in img_names:
            # Archive members are stored under their bare file name.
            fzip.write(
                filename,
                arcname=filename.name,
                compress_type=zipfile.ZIP_DEFLATED,
                compresslevel=6,
            )
            # The file is only removed after it has been added to the archive.
            filename.unlink()
    # Printed once the archive's 'with' block has been exited.
    print("Done with files")

In [None]:
# Random position within the test subjects; only needed by the commented-out
# alternative selection on the next line.
viz_subj_idx = random.choice(list(range(len(test_vol_viz.keys()))))
# viz_subj_id = list(test_vol_viz.keys())[viz_subj_idx]
# Pick the worst performing subject from the test set.
# `sorted_test_results` is ordered by ascending test loss, so its last key belongs to
# the subject with the highest loss.
viz_subj_id = list(sorted_test_results.keys())[-1]
print(list(test_vol_viz.keys()))
print(viz_subj_id)

### FA-Weighted Direction Maps

In [None]:
# Generate the FA-weighted diffusion direction maps of the selected subject for the
# prediction, the spline baseline, the full-res ground truth, and the low-res input.
for dti_variant in ("diqt", "spline", "fr", "lr"):
    variant_dir_map = pitn.viz.direction_map(
        test_vol_viz[viz_subj_id].dti[dti_variant]
    )
    # Set channels last for matplotlib
    test_vol_viz[viz_subj_id].color_fa[dti_variant] = variant_dir_map.transpose(
        1, 2, 3, 0
    )

# Short names for the individual maps, used by the plotting cells.
pred_dir_map = test_vol_viz[viz_subj_id].color_fa.diqt
spline_dir_map = test_vol_viz[viz_subj_id].color_fa.spline
fr_dir_map = test_vol_viz[viz_subj_id].color_fa.fr
lr_dir_map = test_vol_viz[viz_subj_id].color_fa.lr

In [None]:
# Index over three axes for the full-res volumes: the first two axes are kept whole and
# the third axis is fixed at index 86.
slice_idx = (slice(None, None, None), slice(None, None, None), 86)
# Matching index for the low-res volume: integer indices are divided by the downsample
# factor and rounded, while slice objects are passed through unchanged.
low_res_slice_idx = tuple(
    int(np.round(s / downsample_factor)) if isinstance(s, int) else s for s in slice_idx
)
print(slice_idx)
print(low_res_slice_idx)

In [None]:
# Show the selected slice of the predicted direction map (rotated by 90 degrees for
# display) and save the figure unless figure saving is disabled.
plt.figure(dpi=150)
plt.imshow(np.rot90(pred_dir_map[slice_idx]), interpolation="none")
plt.axis("off")
plt.title("Predicted with DIQT Net")
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "pred_dir_map_sample.png");

In [None]:
# Render the selected slice of the full-resolution (ground truth) direction map.
fr_fig = plt.figure(dpi=150)
fr_ax = fr_fig.gca()
# Rotate the slice by 90 degrees for display; draw pixels without interpolation.
fr_slice = np.rot90(fr_dir_map[slice_idx])
fr_ax.imshow(fr_slice, interpolation="none")
fr_ax.axis("off")
fr_ax.set_title("Ground Truth")
if not disable_fig_save:
    fr_fig.savefig(experiment_results_dir / "ground_truth_dir_map_sample.png")

In [None]:
# Render the selected slice of the spline-interpolated direction map.
spline_fig = plt.figure(dpi=150)
spline_ax = spline_fig.gca()
# Rotate the slice by 90 degrees for display; draw pixels without interpolation.
spline_slice = np.rot90(spline_dir_map[slice_idx])
spline_ax.imshow(spline_slice, interpolation="none")
spline_ax.axis("off")
spline_ax.set_title(f"Spline Interpolation Order {exp_params.spline.order}")
# NOTE(review): the file name says "cubic" regardless of the configured spline order.
if not disable_fig_save:
    spline_fig.savefig(experiment_results_dir / "cubic_spline_dir_map_sample.png")

In [None]:
# Render the low-resolution input's direction map, indexed with the low-res slice.
lr_fig = plt.figure(dpi=150)
lr_ax = lr_fig.gca()
# Rotate the slice by 90 degrees for display; draw pixels without interpolation.
lr_slice = np.rot90(lr_dir_map[low_res_slice_idx])
lr_ax.imshow(lr_slice, interpolation="none")
lr_ax.axis("off")
lr_ax.set_title("LR Input")
if not disable_fig_save:
    lr_fig.savefig(experiment_results_dir / "low_res_map_sample.png")

### DTI Channel-Wise Visualization

In [None]:
# LaTeX column labels for the six DTI coefficients, in the order the channels are
# indexed by the plotting cells below.
channel_names = [
    f"$D_{{{first_axis},{second_axis}}}$"
    for first_axis, second_axis in ("xx", "xy", "yy", "xz", "yz", "zz")
]

# Row labels for the channel-wise DTI figures, one per displayed volume.
spline_row_label = f"Spline Order {exp_params.spline.order}"
dti_names = [
    "Full-Res",
    "Low-Res Input",
    spline_row_label,
    "Predicted",
    "Absolute Error\nFR vs. Predicted",
]

#### Global Normalization

In [None]:
# Display all 6 DTI channels for the full-res, low-res input, spline, and predicted
# volumes, plus the absolute error, with one color scale shared by every panel.

cmap = "jet"

# One entry per figure row; the low-res volume is indexed with the low-res slice.
dtis = [
    test_vol_viz[viz_subj_id].dti.fr[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.lr[(slice(None), *low_res_slice_idx)],
    test_vol_viz[viz_subj_id].dti.spline[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.diqt[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.abs_error[(slice(None), *slice_idx)],
]

nrows = len(dtis)
ncols = len(channel_names)

# Don't take the absolute max and min values, as there exist some extreme (e.g., > 3
# orders of magnitude) outliers. Instead, clip the color range to the 5%-95% quantiles.
# Reshape and concatenate the dtis in order to compute the quantiles of images with
# different shapes (e.g., the low-res input patch). Both quantiles are taken from a
# single concatenated array instead of rebuilding it for each bound.
all_dti_values = np.concatenate([di.reshape(6, -1) for di in dtis], axis=1)
min_dti, max_dti = np.quantile(all_dti_values, [0.05, 0.95])

fig = plt.figure(figsize=(12, 6), dpi=160)

grid = mpl.gridspec.GridSpec(
    nrows,
    ncols,
    figure=fig,
    hspace=0.05,
    wspace=0.05,
)
axs = list()
max_subplot_height = 0
for i_row in range(nrows):
    dti = dtis[i_row]

    for j_col in range(ncols):
        ax = fig.add_subplot(grid[i_row, j_col])
        # `interpolation="none"` draws raw pixels. Passing `None` would instead fall
        # back to matplotlib's default (smoothed) interpolation, which blurs the
        # resolution differences this figure is meant to show.
        ax.imshow(
            np.rot90(dti[j_col]),
            cmap=cmap,
            interpolation="none",
            vmin=min_dti,
            vmax=max_dti,
        )
        # Only label the outer edge of the grid: volume names on the left, channel
        # names along the bottom.
        if ax.get_subplotspec().is_first_col():
            ax.set_ylabel(dti_names[i_row], size="small")
        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel(channel_names[j_col])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        # Update highest subplot to put the `suptitle` later on.
        max_subplot_height = max(
            max_subplot_height, ax.get_position(original=False).get_points()[1, 1]
        )
        axs.append(ax)

# A single colorbar for the whole grid, using the same limits as every panel.
color_norm = mpl.colors.Normalize(vmin=min_dti, vmax=max_dti)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap),
    ax=axs,
    location="right",
    fraction=0.1,
    pad=0.03,
)
plt.suptitle(
    "DTI Channel Breakdown, Normalized over All Images",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "DTI_channel_sample_global_norm.png")

#### Channel-Wise Normalization

In [None]:
# Display all 6 DTI channels for the full-res, low-res input, spline, and predicted
# volumes, plus the absolute error.
# Normalize by index in the DTI coefficients: every column gets its own color range.

cmap = "coolwarm"
# One entry per figure row; the low-res volume is indexed with the low-res slice.
dtis = [
    test_vol_viz[viz_subj_id].dti.fr[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.lr[(slice(None), *low_res_slice_idx)],
    test_vol_viz[viz_subj_id].dti.spline[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.diqt[(slice(None), *slice_idx)],
    test_vol_viz[viz_subj_id].dti.abs_error[(slice(None), *slice_idx)],
]

nrows = len(dtis)
ncols = len(channel_names)

# Don't take the absolute max and min values, as there exist some extreme (e.g., > 3
# orders of magnitude) outliers. Instead, take the 5% and 95% quantiles per channel.
# Reshape and concatenate the dtis in order to compute the quantiles of images with
# different shapes (e.g., the low-res input patch). Both quantiles are taken from a
# single concatenated array instead of rebuilding it for each bound.
all_dti_values = np.concatenate([di.reshape(6, -1) for di in dtis], axis=1)
min_dtis, max_dtis = np.quantile(all_dti_values, [0.05, 0.95], axis=1)

# Make each channel's range symmetric about zero so the diverging colormap's midpoint
# corresponds to a value of zero.
max_dtis = np.max(np.abs([max_dtis, min_dtis]), axis=0)
min_dtis = -1 * max_dtis

fig = plt.figure(figsize=(12, 7), dpi=160)

grid = mpl.gridspec.GridSpec(
    nrows,
    ncols,
    figure=fig,
    hspace=0.05,
    wspace=0.05,
)

# `axs[i_row][j_col]` holds the axes of each panel, so that whole columns can be
# handed to the per-column colorbars below.
axs = list()
max_subplot_height = 0
for i_row in range(nrows):
    dti = dtis[i_row]
    axs_cols = list()

    for j_col in range(ncols):
        ax = fig.add_subplot(grid[i_row, j_col])
        # `interpolation="none"` draws raw pixels. Passing `None` would instead fall
        # back to matplotlib's default (smoothed) interpolation, which blurs the
        # resolution differences this figure is meant to show.
        ax.imshow(
            np.rot90(dti[j_col]),
            cmap=cmap,
            interpolation="none",
            vmin=min_dtis[j_col],
            vmax=max_dtis[j_col],
        )
        # Only label the outer edge of the grid: volume names on the left, channel
        # names along the bottom.
        if ax.get_subplotspec().is_first_col():
            ax.set_ylabel(dti_names[i_row], size="small")
        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel(channel_names[j_col])

        # Update highest subplot to put the `suptitle` later on.
        max_subplot_height = max(
            max_subplot_height, ax.get_position(original=False).get_points()[1, 1]
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        axs_cols.append(ax)

    axs.append(axs_cols)

# Place colorbars on each column.
for j_col in range(ncols):

    full_col_ax = [axs[i][j_col] for i in range(nrows)]

    color_norm = mpl.colors.Normalize(vmin=min_dtis[j_col], vmax=max_dtis[j_col])

    color_mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
    cbar = fig.colorbar(
        color_mappable,
        ax=full_col_ax,
        location="top",
        orientation="horizontal",
        pad=0.01,
        shrink=0.85,
    )
    # Small, rotated, general-format tick labels keep neighbouring colorbars legible.
    cbar.ax.tick_params(labelsize=8, rotation=35)
    cbar.ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter("{x:g}"))
#     cbar.ax.ticklabel_format(scilimits=(3, -3), useOffset=False)

plt.suptitle(
    "DTI Channel Breakdown, Channel-Wise Normalization",
    y=max_subplot_height + 0.01,
    verticalalignment="bottom",
)
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "DTI_channel_sample_channel_wise_norm.png")

### FA Map

In [None]:
# Slice locations for 2D visualization

# Midpoint index along each axis of the full-res FA volume. Floor division already
# yields integers, so no extra floor/cast is needed before converting to Python ints.
half_fr_space_shape = (
    np.asarray(test_vol_viz[viz_subj_id].fa.fr.shape) // 2
).tolist()

# One index tuple per plane: fix the midpoint on one axis and take the other two whole.
slice_indices = [
    (half_fr_space_shape[0], slice(None, None, None), slice(None, None, None)),
    (slice(None, None, None), half_fr_space_shape[1], slice(None, None, None)),
    (slice(None, None, None), slice(None, None, None), half_fr_space_shape[2]),
]

# The same planes in low-res voxel coordinates: integer indices are divided by the
# downsample factor and floored, while full-axis slices pass through unchanged.
low_res_slice_indices = [
    tuple(
        int(np.floor(s / downsample_factor)) if isinstance(s, int) else s
        for s in slice_idx_i
    )
    for slice_idx_i in slice_indices
]

print(slice_indices)
print()
print(low_res_slice_indices)

# Figure row labels, in the same order as `slice_indices`.
row_names = [
    "Sagittal",
    "Coronal",
    "Horizontal",
]

# Figure column labels, one per displayed volume.
model_names = [
    "Full-Res",
    "Low-Res Input",
    f"Spline Order {exp_params.spline.order}",
    "Predicted",
    #     "Absolute Error\nFR vs. Predicted",
]

nrows = len(row_names)
ncols = len(model_names)

In [None]:
# Gather the 2D FA slices to plot as `imgs[row][col]`: one row per entry of
# `slice_indices` and one column per volume (full-res, low-res, spline, prediction).
viz_subj_fa = test_vol_viz[viz_subj_id].fa
imgs = {}
for i_row, slice_i in enumerate(slice_indices):
    # The low-res volume needs the matching index in low-res coordinates.
    low_res_slice_i = low_res_slice_indices[i_row]
    full_res_key = tuple(slice_i)
    low_res_key = tuple(low_res_slice_i)
    col_imgs = [
        viz_subj_fa.fr[full_res_key],
        viz_subj_fa.lr[low_res_key],
        viz_subj_fa.spline[full_res_key],
        viz_subj_fa.diqt[full_res_key],
        #         viz_subj_fa.abs_error[full_res_key],
    ]

    # Key each image by its column position.
    imgs[i_row] = dict(enumerate(col_imgs))

#### Global Normalization

In [None]:
# Display the FA maps of the full-res, low-res input, spline, and predicted volumes
# for each slicing plane, with one gray scale shared by every panel.

cmap = "gist_gray"

# Don't take the absolute max and min values, as there may exist some extreme
# outliers. Instead, clip the intensity range to the 5%-95% quantiles.
# Concatenate the images in order to compute the quantiles of images with
# different shapes (e.g., the low-res input patch). `axis=None` flattens every image,
# and both quantiles are taken from that single array.
all_intensities = np.concatenate(
    [imgs[i][j] for (i, j) in itertools.product(range(nrows), range(ncols))],
    axis=None,
)
min_intensity, max_intensity = np.quantile(all_intensities, [0.05, 0.95])

fig = plt.figure(figsize=(5 * 1.5, 3 * 1.5), dpi=160)

grid = mpl.gridspec.GridSpec(
    nrows,
    ncols,
    figure=fig,
    hspace=0.05,
    wspace=0.05,
)
axs = list()
max_subplot_height = 0
for i_row in range(nrows):
    img_row = imgs[i_row]

    for j_col in range(ncols):
        ax = fig.add_subplot(grid[i_row, j_col])
        # `interpolation="none"` draws raw pixels. Passing `None` would instead fall
        # back to matplotlib's default (smoothed) interpolation, which blurs the
        # resolution differences this figure is meant to show.
        ax.imshow(
            np.rot90(img_row[j_col]),
            cmap=cmap,
            interpolation="none",
            vmin=min_intensity,
            vmax=max_intensity,
        )
        # Only label the outer edge of the grid: plane names on the left, volume
        # names along the bottom.
        if ax.get_subplotspec().is_first_col():
            ax.set_ylabel(row_names[i_row], size="small")
        if ax.get_subplotspec().is_last_row():
            ax.set_xlabel(model_names[j_col], size="small")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        # Update highest subplot to put the `suptitle` later on.
        max_subplot_height = max(
            max_subplot_height, ax.get_position(original=False).get_points()[1, 1]
        )
        axs.append(ax)

# A single colorbar for the whole grid, using the same limits as every panel.
color_norm = mpl.colors.Normalize(vmin=min_intensity, vmax=max_intensity)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap),
    ax=axs,
    location="right",
    fraction=0.1,
    pad=0.03,
)
plt.suptitle(
    "FA Maps, Normalized over All Images",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
if not disable_fig_save:
    plt.savefig(experiment_results_dir / "fa_sample_global_norm.png")

---

## End Experiment

In [None]:
pl_logger.experiment.flush()
# Close tensorboard logger.
# Don't finalize if the experiment was for debugging.
if "debug" not in EXPERIMENT_NAME.casefold():
    pl_logger.finalize("success")
    # Experiment is complete, move the results directory to its final location.
    if experiment_results_dir != final_experiment_results_dir:
        print("Moving out of tmp location")
        experiment_results_dir = experiment_results_dir.rename(
            final_experiment_results_dir
        )
        log_txt_file = experiment_results_dir / log_txt_file.name