# Pain in the Net
Replication of *Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images*


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Source work:
`S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


# Imports & Environment Setup

## Imports

In [1]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import dataclasses
from dataclasses import dataclass
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
from typing import Generator

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown

# Data management libraries.
import nibabel as nib
import nilearn
import nilearn.plotting
import natsort
from natsort import natsorted
import addict
from addict import Addict
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
import torchio
import pytorch_lightning as pl

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy
import einops
import einops.layers
import einops.layers.torch

# Global matplotlib figure defaults, applied in a single update: automatic layout
# and a fully opaque white (RGBA) figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Set print options for ndarrays/tensors: no scientific notation, summarize anything
# over 100 elements, and wrap at 88 characters.
np.set_printoptions(suppress=True, threshold=100, linewidth=88)
torch.set_printoptions(sci_mode=False, threshold=100, linewidth=88)


Fetchers from the nilearn.datasets module will be updated in version 0.9 to return python strings instead of bytes and Pandas dataframes instead of Numpy arrays.



In [2]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - shlex
# - dotenv
import shlex

# Form command to be run in direnv's context. This command will print out
# all environment variables defined in the subprocess/sub-shell.
# The working directory is shell-quoted because the command string is handed to a
# shell (shell=True); an unquoted path containing spaces or shell metacharacters
# would otherwise be split or interpreted by the shell.
command = "direnv exec {} /usr/bin/env".format(shlex.quote(os.getcwd()))
# Run command in a new subprocess.
proc = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True, cwd=os.getcwd())
# Store and format the subprocess' output.
proc_out = proc.communicate()[0].strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [3]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../").resolve())
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [4]:
# torch setup
# Use CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To force CPU execution regardless of availability, uncomment:
# device = torch.device('cpu')
print(device)

cuda


## Specs Recording

In [5]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    except NameError:
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [6]:
# `cap` is defined by the `%%capture --no-stderr cap` cell magic of the specs
# recording cell; printing it shows the output captured there (watermark and GPU
# information).
print(cap)

Author: Tyler Spears

Last updated: 2021-05-24T21:19:42.635833+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.22.0

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-72-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: a2b50181b9ffcf98c7a835a7ae2cf1293e4ea738

numpy            : 1.20.2
einops           : 0.3.0
ipywidgets       : 7.6.3
addict           : 2.4.0
ants             : 0.2.7
torchvision      : 0.9.1
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
nibabel          : 3.2.1
nilearn          : 0.7.1
pytorch_lightning: 1.3.2
IPython          : 7.22.0
skimage          : 0.18.1
matplotlib       : 3.4.1
json             : 2.0.9
pandas           : 1.2.3
torch            : 1.8.1
seaborn          : 0.11.1
natsort          : 7.1.1
GPUtil           : 1.4.0
scipy            : 1.5.3
dipy             : 1.4.0
torchio          : 0.18.37

  id  Name       Driver Versio

## Data Variables & Definitions Setup

In [7]:
# Set up directories
# Each location comes from an environment variable and must already exist; the
# notebook stops with an AssertionError otherwise.
# Input data: the "hcp" sub-directory of $DATA_DIR.
data_dir = Path(os.environ["DATA_DIR"]).joinpath("hcp")
assert data_dir.exists()
# Writable data: the "hcp" sub-directory of $WRITE_DATA_DIR.
write_data_dir = Path(os.environ["WRITE_DATA_DIR"]).joinpath("hcp")
assert write_data_dir.exists()
# Root directory for experiment results: $RESULTS_DIR.
results_dir = Path(os.environ["RESULTS_DIR"])
assert results_dir.exists()

## Experiment Logging Setup

In [8]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = "dev_tb"

# Timestamp (second resolution) that prefixes the experiment name.
ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
experiment_name = ts + "__" + EXPERIMENT_NAME
run_name = experiment_name
# experiment_results_dir = results_dir / experiment_name

# Create temporary directory for results directory, in case experiment does not finish.
tmp_results_dir = results_dir / "tmp"
# Only directories are candidate run results. Stray files are neither counted towards
# the limit nor handed to shutil.rmtree, which raises on non-directories.
tmp_dirs = [entry for entry in tmp_results_dir.glob("*") if entry.is_dir()]

# Only keep up to N tmp results.
n_tmp_to_keep = 3
# One slot is reserved for the run that is about to be created, so at most
# (n_tmp_to_keep - 1) existing runs survive.
n_existing_to_keep = max(n_tmp_to_keep - 1, 0)
if len(tmp_dirs) > n_existing_to_keep:
    print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
    # Names start with the timestamp, so natural sort order is oldest-first.
    sorted_tmp_dirs = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])
    # Explicit end index instead of a negative slice: `[: -0]` would select nothing
    # when no existing run should be kept.
    tmps_to_delete = sorted_tmp_dirs[: len(sorted_tmp_dirs) - n_existing_to_keep]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

More than 3 temporary results, culling to the most recent
Deleted temporary results directory  /home/jovyan/work/pitn/results/tmp/2021-05-20T16_21_43__debug_norm


In [9]:
# Lightning-level logger: hand this to the pytorchlightning Trainer so that logging
# is available inside the training/testing loops.
pl_logger = pl.loggers.TensorBoardLogger(
    tmp_results_dir,
    name=experiment_name,
    version="",
    log_graph=False,
)
# The lower-level logger object is used for logging histograms, images, etc.
logger = pl_logger.experiment

# A separate txt file collects streams of events & info besides parameters & results.
# It lives in the logger's log directory and is opened in append mode.
log_txt_file = Path(logger.log_dir) / "log.txt"
with open(log_txt_file, "a+") as f:
    f.writelines(
        (
            f"Experiment Name: {experiment_name}\n",
            f"Timestamp: {ts}\n",
            # cap is defined in an ipython magic command
            f"Environment and Hardware Info:\n {cap}\n\n",
        )
    )

In [10]:
# # Experiment logging setup
# EXPERIMENT_NAME = "debug_norm"

# ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# # Break ISO format because many programs don't like having colons ':' in a filename.
# ts = ts.replace(":", "_")
# experiment_name = ts + "__" + EXPERIMENT_NAME

# # Create temporary directory for results directory, in case experiment does not finish.
# tmp_results_dir = results_dir / "tmp"
# tmp_dirs = list(tmp_results_dir.glob("*"))

# # Only keep up to N tmp results.
# n_tmp_to_keep = 3
# if len(tmp_dirs) > (n_tmp_to_keep - 1):
#     print(f"More than {n_tmp_to_keep} temporary results, culling to the most recent")
#     tmps_to_delete = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])[
#         : -(n_tmp_to_keep - 1)
#     ]
#     for tmp_dir in tmps_to_delete:
#         shutil.rmtree(tmp_dir)
#         print("Deleted temporary results directory ", tmp_dir)

# experiment_results_dir = tmp_results_dir / experiment_name
# # Final target directory, to be made when experiment is complete.
# final_experiment_results_dir = results_dir / experiment_name

# (experiment_results_dir).mkdir(parents=True, exist_ok=True)
# print(
#     "Experiment results directory: ",
#     experiment_results_dir,
# )
# assert experiment_results_dir.exists()

# experiment_results_log = experiment_results_dir / "log.txt"
# with open(experiment_results_log, "a+") as f:
#     f.write(f"Experiment Name: {experiment_name}\n")
#     f.write(f"Timestamp: {ts}\n")
#     # cap is defined in an ipython magic command
#     f.write(f"Environment and Hardware Info:\n {cap}\n\n")

# Parameters and Function Definitions

## Parameters

In [11]:
# Dict to keep track of experiment configuration parameters. Will not be logged to
# tensorboard.
exp_params = Addict()
# Dict to keep track of tensorboard hparams that we *specifically* want to compare
# between runs. It is created with two sub-dicts, keyed `hparam` and `metric`.
# NOTE(review): presumably attribute access on a missing Addict key creates a new
# empty entry instead of raising, so a misspelled key (e.g. `hparams`) would go
# unnoticed -- confirm against the addict documentation.
compare_hparams = Addict(hparam=Addict(), metric=Addict())

In [12]:
# Factor between the full-resolution and the low-resolution data.
downsample_factor = 2
# Include b=0 shells and b=1000 shells for DTI fitting.
bval_range = (0, 1500)
# Fit method name passed on to the DTI fitting transforms.
dti_fit_method = "WLS"
# Record the settings above in the experiment configuration.
exp_params.update(
    downsample_factor=downsample_factor,
    bval_range=bval_range,
    dti_fit_method=dti_fit_method,
)

### Patch Parameters

In [13]:
# Patch parameters
batch_size = 32
# 6 channels for the 6 DTI components
channels = 6

# Output patch shapes (the same size along every spatial dim).
h_out = w_out = d_out = 14

# This is the factor that determines how over-extended the input patch should be
# relative to the size of the full-res patch.
# $low_res_patch_dim = \frac{full_res_patch_dim}{downsample_factor} \times low_res_sample_extension$
# A value of 1 indicates that the input patch dims will be exactly divided by the
# downsample factor. A dilation > 1 increases the "spatial extent" of the input
# patch, providing information outside of the target HR patch.
low_res_sample_extension = 1.57

# Output shape after shuffling: channels first, then the three spatial dims.
output_spatial_patch_shape = (h_out, w_out, d_out)
output_patch_shape = (channels, *output_spatial_patch_shape)

# Input patch parameters: each spatial dim is the output dim divided by the
# downsample factor, scaled by the extension factor, and rounded.
h_in, w_in, d_in = (
    round(full_res_dim / downsample_factor * low_res_sample_extension)
    for full_res_dim in output_spatial_patch_shape
)
input_spatial_patch_shape = (h_in, w_in, d_in)
input_patch_shape = (channels, *input_spatial_patch_shape)

# Pre-shuffle output patch sizes.
unshuffled_channels_out = channels * downsample_factor ** 3
# Output before shuffling
unshuffled_output_patch_shape = (unshuffled_channels_out, *input_spatial_patch_shape)

# Patch size in FR-space when accounting for the low-res over-extension/over-sampling
# factor.
fr_extension_patch_size = tuple(
    np.multiply(input_spatial_patch_shape, downsample_factor)
)
# Per-dim difference between the extended FR patch and the target output patch.
fr_extension_amount = tuple(
    np.subtract(fr_extension_patch_size, output_spatial_patch_shape)
)

In [14]:
# Record the patch configuration under the `patch` entry of the experiment params.
exp_params.patch.update(
    {
        "batch_size": batch_size,
        "channels": channels,
        "low_res_sample_extension": low_res_sample_extension,
        "input_shape": input_patch_shape,
        "output_shape": output_patch_shape,
    }
)

In [15]:
# Data parameters.
# Number of subject IDs randomly drawn (via `random.sample`) from the candidate list
# in the subject-selection cell.
num_subject_samples = 4
# Should the data be normalized as a pre-processing step?
# The loading cell applies per-subject, per-channel normalization when this is a
# string containing "channel" (case-insensitive).
data_norm_method = "channels"

In [16]:
# Record the data settings under the `data` entry of the experiment params.
exp_params.data.update(
    num_subject=num_subject_samples, data_norm_method=data_norm_method
)

# Track the normalization method among the hyper-parameters compared between runs.
# The sub-dict is named `hparam` (singular), matching how `compare_hparams` is
# created and how it is updated elsewhere in this notebook; writing to `hparams`
# would store the value under a separate, unused key.
compare_hparams.hparam.update(data_norm=data_norm_method)

### Training and Testing Setup

In [17]:
# Training/testing parameters.
# Percentages will be rounded off to the nearest subject, with the test and validation
# sizes rounded *up*, ensuring at least 1 subject in each.
test_percent, val_percent = 0.2, 0.1
# Whatever is not held out for testing/validation is used for training.
train_percent = 1 - (test_percent + val_percent)

# NN parameters.
max_epochs = 50
network_norm_method = None
train_loss_name = "MSE"

# Optimization parameters.
opt_name = "Adam"
opt_params = dict(lr=10e-3, betas=(0.9, 0.999))

# Spline interpolation baseline parameters.
spline_interp_order = 3

In [18]:
# Number of voxels to dilate the mask in FR space.
# NOTE: `dilation_size` is assigned three times in a row; only the last assignment
# (0) takes effect. The first two are overridden and remain as a record of the values
# that were considered too large.
# Dilate to allow the outer-most perimeter to be completely outside the brain, and no
# more.
dilation_size = math.ceil((max(fr_extension_patch_size) + 1) / 2)
# 12 is too much, nearly doubles the volume of the brain mask. Lower it some more...
dilation_size = dilation_size // 3
# Even 4 is a chunky bit! Just make it 0.
dilation_size = 0

In [19]:
# Record the train/test, network, optimizer, baseline and preprocessing settings in
# the experiment configuration.
exp_params.update({"test_percent": test_percent})
exp_params.train.update(
    {
        "train_percent": train_percent,
        "max_epochs": max_epochs,
        "train_loss_name": train_loss_name,
    }
)
exp_params.nn.update({"network_norm_method": network_norm_method})
# Optimizer keyword arguments, plus the optimizer's name under `name`.
exp_params.opt.update(opt_params)
exp_params.opt.name = opt_name
exp_params.spline.update({"order": spline_interp_order})
exp_params.preproc.update({"dilation_size": dilation_size})

# Save out hyper params that we want to compare between runs.
compare_hparams.hparam.update(
    {
        "network_norm_method": network_norm_method,
        "train_loss": train_loss_name,
    }
)

In [20]:
# Append the pretty-printed experiment configuration to the text log.
with log_txt_file.open("a+") as f:
    f.write(f"{pprint.pformat(exp_params)}\n")

## Function Definitions

In [21]:
# Per-row boxplot summary. Fields are ordered from the lower end of the
# distribution to the upper end.
BoxplotStats = collections.namedtuple(
    "BoxplotStats",
    ["low_outliers", "low", "q1", "median", "q3", "high", "high_outliers"],
)


def batch_boxplot_stats(batch):
    """Compute boxplot statistics for every row of a 2D batch of values.

    Quartiles are taken along axis 1. The `low`/`high` fences sit 1.5 inter-quartile
    ranges below `q1` and above `q3`; values strictly outside the fences are outliers.

    Args:
        batch: 2D array-like, one set of values per row.

    Returns:
        BoxplotStats with one entry per row. `low`, `q1`, `median`, `q3` and `high`
        are arrays; `low_outliers` and `high_outliers` are lists holding one
        sequence of outlier values per row.
    """
    q1, median, q3 = np.quantile(batch, q=[0.25, 0.5, 0.75], axis=1)
    fence_offset = 1.5 * (q3 - q1)
    low = q1 - fence_offset
    high = q3 + fence_offset

    # The number of outliers can differ from row to row, so they are gathered into
    # lists of per-row arrays rather than a single rectangular array.
    low_outliers, high_outliers = [], []
    for row, row_low, row_high in zip(batch, low, high):
        low_outliers.append(row[np.where(row < row_low)])
        high_outliers.append(row[np.where(row > row_high)])

    return BoxplotStats(
        low_outliers=low_outliers,
        low=low,
        q1=q1,
        median=median,
        q3=q3,
        high=high,
        high_outliers=high_outliers,
    )

In [22]:
# Quick check on full volume/batch distributions.


def desc_channel_dists(vols, mask=None):
    """Describe the per-channel mean and variance of one or more volumes as text.

    Args:
        vols: Tensor-like volume(s). A 4D input is treated as a single volume with
            channels on the first dim and gets a leading batch dim added; otherwise
            the first dim is iterated over as the batch.
        mask: Optional tensor-like mask selecting the voxels to include. It is
            converted to a boolean tensor, and a 4D mask has its leading dim
            stripped (index 0 is kept). Defaults to selecting every voxel.

    Returns:
        A string starting with a "means | vars" header, followed by one
        "<mean> | <var>" line per channel and a separator line of "=" per volume.
    """
    t_vols = torch.as_tensor(vols)

    if t_vols.ndim == 4:
        t_vols = t_vols[None, ...]

    if mask is not None:
        # masked_select needs a boolean mask; convert so integer/array-like masks
        # work as well.
        t_mask = torch.as_tensor(mask).bool()
        # Strip the leading dim from the tensor that is actually used below (the
        # original stripped the unused `mask` argument instead).
        if t_mask.ndim == 4:
            t_mask = t_mask[0]
    else:
        t_mask = torch.ones_like(t_vols[0, 0]).bool()

    results = "means | vars\n"
    for t_vol_i in t_vols:
        # The spatial mask broadcasts over the channel dim, so every channel keeps
        # the same number of voxels and the selection reshapes to (channels, -1).
        masked_vol = torch.masked_select(t_vol_i, t_mask).reshape(t_vol_i.shape[0], -1)
        mean_i = torch.mean(masked_vol, dim=1)
        var_i = torch.var(masked_vol, dim=1)
        mvs = [
            (f"{mv[0]} | {mv[1]}\n")
            for mv in torch.stack([mean_i, var_i], dim=-1).tolist()
        ]
        results = results + "".join(mvs)
        # Separator as wide as the last stats line (excluding its newline).
        results = results + ("=" * (len(mvs[-1]) - 1)) + "\n"

    return results

# Data Loading

## Subject ID Selection

In [23]:
# Find data directories for each subject.
# Maps integer subject ID -> that subject's diffusion data directory.
subj_dirs: dict = {}

# Candidate subject IDs.
selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(selected_ids, num_subject_samples)
warnings.warn(
    "WARNING: Sub-selecting participants for dev and debugging. "
    f"Subj IDs selected: {selected_ids}"
)
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# Convert the IDs to ints and put them in sorted order.
selected_ids = natsorted([int(subj_id_str) for subj_id_str in selected_ids])

# Every selected subject must have a `T1w/Diffusion` directory under the data dir.
for subj_id in selected_ids:
    subj_dirs[subj_id] = data_dir / str(subj_id) / "T1w" / "Diffusion"
    assert subj_dirs[subj_id].exists()
subj_dirs





{224022: PosixPath('/mnt/storage/data/pitn/hcp/224022/T1w/Diffusion'),
 227432: PosixPath('/mnt/storage/data/pitn/hcp/227432/T1w/Diffusion'),
 751348: PosixPath('/mnt/storage/data/pitn/hcp/751348/T1w/Diffusion'),
 810439: PosixPath('/mnt/storage/data/pitn/hcp/810439/T1w/Diffusion')}

The 90 scans are taken from the $b=1000 \ s/mm^2$ shell. However, the $b=0$ volumes are still required for fitting the diffusion tensors (DTIs), so those need to be kept, too.

To find both, sub-select the volumes with $0 \leq bvals < 1500$, or roughly thereabouts. A b-value of $995$ or $1005$ still counts as $b=1000$.

In [24]:
# Note the selected subject IDs in the text log...
with log_txt_file.open("a+") as f:
    print(f"Selected Subjects: {selected_ids}", file=f)

# ...and in tensorboard, under the "subjs" tag.
logger.add_text("subjs", pprint.pformat(selected_ids))

## Loading and Preprocessing

In [25]:
# Set up the transformation pipeline.
# The transforms run in order on each subject. They start from the images "dwi" and
# "brain_mask" and end with the full-res ("fr_") and low-res ("lr_") DTI volumes and
# brain masks.

preproc_transforms = torchio.Compose(
    [
        torchio.transforms.ToCanonical(include=("dwi", "brain_mask"), copy=False),
        # Keep only the DWI volumes whose b-value falls in `bval_range`.
        # NOTE(review): `include` is a bare string here, unlike the tuples used by the
        # other transforms -- confirm this transform accepts it.
        pitn.transforms.BValSelectionTransform(
            bval_range=bval_range,
            bval_key="bvals",
            bvec_key="bvecs",
            include="dwi",
            copy=False,
        ),
        # Pad by the dilation factor, then dilate the mask.
        torchio.transforms.Pad(
            dilation_size,
            padding_mode=0,
            include=("dwi", "brain_mask"),
            copy=False,
        ),
        pitn.transforms.DilateMaskTransform(
            dilation_size=dilation_size, include=("brain_mask",), copy=False
        ),
        # Pad by the amount of extension voxels in FR space, so LR indices cannot
        # go out of bounds.
        torchio.transforms.Pad(
            fr_extension_amount,
            padding_mode=0,
            include=("dwi", "brain_mask"),
            copy=False,
        ),
        # Ensure FR dims are divisible by the downsample factor, to more reliably
        # convert between FR indices and LR indices.
        torchio.transforms.EnsureShapeMultiple(
            downsample_factor, method="pad", include=("dwi", "brain_mask"), copy=False
        ),
        # Downsample both images; `keep` names the entries under which the
        # pre-downsampling images are preserved.
        pitn.transforms.MeanDownsampleTransform(
            downsample_factor,
            include=("dwi", "brain_mask"),
            keep={"dwi": "fr_dwi", "brain_mask": "fr_brain_mask"},
            copy=False,
        ),
        pitn.transforms.RenameImageTransform(
            {"dwi": "lr_dwi", "brain_mask": "lr_brain_mask"}, copy=False
        ),
        # `include` is a sequence of image names. The trailing comma matters:
        # `("fr_dwi")` is just the string "fr_dwi", not a 1-tuple.
        pitn.transforms.FitDTITransform(
            "bvals",
            "bvecs",
            "fr_brain_mask",
            fit_method=dti_fit_method,
            include=("fr_dwi",),
            copy=False,
        ),
        pitn.transforms.FitDTITransform(
            "bvals",
            "bvecs",
            "lr_brain_mask",
            fit_method=dti_fit_method,
            include=("lr_dwi",),
            copy=False,
        ),
        # After fitting, the "*_dwi" entries hold DTI volumes; rename them to match.
        pitn.transforms.RenameImageTransform(
            {"fr_dwi": "fr_dti", "lr_dwi": "lr_dti"}, copy=False
        ),
        pitn.transforms.ImageToDictTransform(
            include=("lr_dti", "lr_brain_mask"), copy=False
        ),
    ]
)

In [26]:
# Import all image data into a sequence of `torchio.Subject` objects.
# Maps subject ID -> preprocessed subject.
subj_data: dict = dict()
summary_stats_fmt = "github"
summary_stats_header = [
    "Subj ID",
    "Resolution",
    "Channel Index",
    "Mean",
    "Var",
    "Num Outliers (Lower)",
    "Low",
    "25th Percentile",
    "Median",
    "75th Percentile",
    "High",
    "Num Outliers (Upper)",
]
dti_channel_names = ["Dxx", "Dxy", "Dyy", "Dxz", "Dyz", "Dzz"]
# Column-oriented summary table of the data before any normalization.
subj_stats = dict()
# Dictionary to hold the subject's summary statistics if image-level or global
# normalization is used.
norm_subj_stats = dict()
for k in summary_stats_header:
    subj_stats[k] = list()
    norm_subj_stats[k] = list()


def _masked_channel_stats(vol, mask):
    """Select the masked voxels of `vol` and compute per-channel mean and variance.

    Args:
        vol: Tensor with channels on the first dim.
        mask: Boolean tensor selecting voxels; assumed to select the same number of
            voxels in every channel (e.g. a spatial mask broadcast over channels).

    Returns:
        Tuple `(masked_vol, channel_means, channel_vars)`: `masked_vol` is reshaped
        to (channels, -1); the means and variances are reshaped to (-1, 1, 1, 1).
    """
    masked_vol = torch.masked_select(vol, mask).reshape(vol.shape[0], -1)
    channel_means = masked_vol.mean(dim=1).reshape(-1, 1, 1, 1)
    channel_vars = masked_vol.var(dim=1).reshape(-1, 1, 1, 1)
    return masked_vol, channel_means, channel_vars


def _append_summary_stats(
    stats_table, subj_id, vol, masked_vol, channel_means, channel_vars
):
    """Append one summary row per channel of `vol` to `stats_table`, in place.

    Returns:
        The `BoxplotStats` computed from `masked_vol`.
    """
    vol_stats = batch_boxplot_stats(masked_vol)
    n_channels = len(vol_stats.median)
    stats_table["Subj ID"].extend([subj_id] * n_channels)
    stats_table["Resolution"].extend([tuple(vol.shape[1:])] * n_channels)
    stats_table["Channel Index"].extend(dti_channel_names)
    stats_table["Mean"].extend(channel_means.cpu().flatten().tolist())
    stats_table["Var"].extend(channel_vars.cpu().flatten().tolist())
    # The columns after "Var" line up, in order, with the fields of the boxplot
    # stats. The "Num Outliers" columns store a count rather than the values.
    boxplot_fields = summary_stats_header[(summary_stats_header.index("Var") + 1) :]
    for i, field in enumerate(boxplot_fields):
        if "Num" in field:
            stats_table[field].extend(list(map(len, vol_stats[i])))
        else:
            stats_table[field].extend(vol_stats[i])
    return vol_stats


for subj_id, subj_dir in subj_dirs.items():
    # Load the subject's b-values and b-vectors. Sub-selection by b-value range is
    # done later, by the BValSelectionTransform in the preprocessing pipeline.
    bvals = torch.as_tensor(np.loadtxt(subj_dir / "bvals").astype(int))
    bvecs = torch.as_tensor(np.loadtxt(subj_dir / "bvecs"))
    # Reshape to be N x 3
    if bvecs.shape[0] == 3:
        bvecs = bvecs.T

    brain_mask = torchio.LabelMap(
        subj_dir / "nodif_brain_mask.nii.gz",
        type=torchio.LABEL,
        channels_last=False,
    )
    brain_mask.set_data(brain_mask.data.bool())
    mask_volume = brain_mask["data"].sum()
    print(f"Brain mask volume before dilation: {mask_volume}")
    dwi = torchio.ScalarImage(
        subj_dir / "data.nii.gz",
        type=torchio.INTENSITY,
        bvals=bvals,
        bvecs=bvecs,
        reader=pitn.io.nifti_reader,
        channels_last=True,
    )

    subject_dict = torchio.Subject(subj_id=subj_id, dwi=dwi, brain_mask=brain_mask)
    preproced_subj = preproc_transforms(subject_dict)
    dilated_mask_volume = preproced_subj["fr_brain_mask"]["data"].sum()
    print(f"Dilated mask volume: {dilated_mask_volume}")
    print(f"Mask volume difference: {dilated_mask_volume - mask_volume}")

    # Subject-and-channel-wise standardization/normalization of both the LR and FR vols.
    # Note that LR and FR images should have the same means, but *not* the same variances.
    fr_vol = preproced_subj.fr_dti.tensor
    fr_mask = preproced_subj.fr_brain_mask.tensor.bool()
    masked_fr_vol, fr_channel_means, fr_channel_vars = _masked_channel_stats(
        fr_vol, fr_mask
    )

    # Store subject-and-channel-wise means and vars in order to reverse the normalization
    # for the final visualization/output.
    preproced_subj["fr_means"] = fr_channel_means.detach().cpu().numpy()
    preproced_subj["fr_vars"] = fr_channel_vars.detach().cpu().numpy()

    lr_vol = preproced_subj.lr_dti["data"]
    lr_mask = preproced_subj.lr_brain_mask["data"].bool()
    masked_lr_vol, lr_channel_means, lr_channel_vars = _masked_channel_stats(
        lr_vol, lr_mask
    )

    # Store subject-and-channel-wise means and vars in order to reverse the normalization
    # for the final visualization/output.
    preproced_subj["lr_means"] = lr_channel_means.detach().cpu().numpy()
    preproced_subj["lr_vars"] = lr_channel_vars.detach().cpu().numpy()

    # Add the FR rows, then the LR rows, to the un-normalized summary table.
    fr_vol_stats = _append_summary_stats(
        subj_stats, subj_id, fr_vol, masked_fr_vol, fr_channel_means, fr_channel_vars
    )
    lr_vol_stats = _append_summary_stats(
        subj_stats, subj_id, lr_vol, masked_lr_vol, lr_channel_means, lr_channel_vars
    )

    # Optionally apply image-level or global normalization.
    if isinstance(data_norm_method, str) and "channel" in data_norm_method.casefold():
        # Standardize the volumes with the pre-normalization means and variances.
        preproced_subj.fr_dti.set_data(
            pitn.data.norm.normalize_dti(fr_vol, fr_channel_means, fr_channel_vars)
        )
        preproced_subj.lr_dti["data"] = pitn.data.norm.normalize_dti(
            lr_vol, lr_channel_means, lr_channel_vars
        )

        # Re-calculate the same statistics post-normalization and add them to the
        # normalized summary table, FR rows first.
        fr_vol = preproced_subj.fr_dti["data"]
        fr_mask = preproced_subj.fr_brain_mask["data"].bool()
        masked_fr_vol, fr_channel_means, fr_channel_vars = _masked_channel_stats(
            fr_vol, fr_mask
        )
        fr_vol_stats = _append_summary_stats(
            norm_subj_stats,
            subj_id,
            fr_vol,
            masked_fr_vol,
            fr_channel_means,
            fr_channel_vars,
        )

        lr_vol = preproced_subj.lr_dti["data"]
        lr_mask = preproced_subj.lr_brain_mask["data"].bool()
        masked_lr_vol, lr_channel_means, lr_channel_vars = _masked_channel_stats(
            lr_vol, lr_mask
        )
        lr_vol_stats = _append_summary_stats(
            norm_subj_stats,
            subj_id,
            lr_vol,
            masked_lr_vol,
            lr_channel_means,
            lr_channel_vars,
        )

    #     with open(log_txt_file, "a+") as f:
    #         f.write(f"Subject ID {subj_id} masked voxel distribution by channel:\n")
    #         f.write(
    #             f"{desc_channel_dists(preproced_subj['fr_dti']['data'], preproced_subj['fr_brain_mask']['data'].bool())}\n"
    #         )

    subj_data[subj_id] = preproced_subj
    print("=" * 20)
#     breakpoint()

print("===Data Loaded & Transformed===")

subj_stats_str = tabulate(
    subj_stats, headers=summary_stats_header, tablefmt=summary_stats_fmt
)

# The normalized table is only rendered when normalization was applied to at least
# one subject; otherwise the string is left empty.
if norm_subj_stats["Subj ID"]:
    norm_subj_stats_str = tabulate(
        norm_subj_stats, headers=summary_stats_header, tablefmt=summary_stats_fmt
    )
else:
    norm_subj_stats_str = ""

Brain mask volume before dilation: 652877
Loading NIFTI image: /mnt/storage/data/pitn/hcp/224022/T1w/Diffusion/data.nii.gz
	Loaded NIFTI image
Selecting with bvals: Subject 224022...Selected
Downsampling: Subject 224022...Downsampled
Fitting to DTI: Subject 224022......DWI shape: torch.Size([108, 162, 190, 162])......DTI shape: (6, 162, 190, 162)...Fitted DTI model: torch.Size([6, 162, 190, 162])
Fitting to DTI: Subject 224022......DWI shape: torch.Size([108, 81, 95, 81])......DTI shape: (6, 81, 95, 81)...Fitted DTI model: torch.Size([6, 81, 95, 81])
Dilated mask volume: 652877
Mask volume difference: 0
Brain mask volume before dilation: 847825
Loading NIFTI image: /mnt/storage/data/pitn/hcp/227432/T1w/Diffusion/data.nii.gz
	Loaded NIFTI image
Selecting with bvals: Subject 227432...Selected
Downsampling: Subject 227432...Downsampled
Fitting to DTI: Subject 227432......DWI shape: torch.Size([108, 162, 190, 162])......DTI shape: (6, 162, 190, 162)...Fitted DTI model: torch.Size([6, 162, 

In [27]:
# g = [vol[None, ...] for vol in fr_vol[:, :, :, 100]]
# g.extend([vol[None, ...] for vol in subj_data[810439].fr_dti.data[:, :, :, 100]])
# g = torchvision.utils.make_grid(g, nrow=6)
# plt.figure(dpi=120, figsize=(4, 9))
# plt.imshow(np.rot90(g.cpu().numpy()[0]))
# plt.colorbar()
# # plt.imshow().cpu().numpy())
# # subj_data[810439].fr_dti.data

In [28]:
# Append the preprocessing pipeline description and the un-normalized summary table to
# the plain-text log.
with open(log_txt_file, "a+") as f:
    f.writelines(
        (
            f"Preprocessing transformation pipeline: {str(preproc_transforms)}\n",
            f"Data Summary Statistics, no normalization:\n {subj_stats_str}\n\n",
        )
    )

logger.add_text("data_dists", subj_stats_str)
# The normalized table string is empty when no normalized stats were recorded, in which
# case there is nothing further to log.
if norm_subj_stats_str:
    logger.add_text("normalized_data_dists", norm_subj_stats_str)
    with open(log_txt_file, "a+") as f:
        f.write(
            f"Data Summary Statistics, after normalization:\n {norm_subj_stats_str}\n\n"
        )

In [29]:
# Render the un-normalized summary statistics table as Markdown in the cell output.
display(Markdown(subj_stats_str))

|   Subj ID | Resolution      | Channel Index   |         Mean |         Var |   Num Outliers (Lower) |          Low |   25th Percentile |       Median |   75th Percentile |        High |   Num Outliers (Upper) |
|-----------|-----------------|-----------------|--------------|-------------|------------------------|--------------|-------------------|--------------|-------------------|-------------|------------------------|
|    224022 | (162, 190, 162) | Dxx             |  0.000968793 | 1.28181e-06 |                   5858 |  3.70389e-05 |       0.000655324 |  0.000804229 |       0.00106751  | 0.0016858   |                  65959 |
|    224022 | (162, 190, 162) | Dxy             |  1.74187e-06 | 1.2433e-08  |                  22446 | -0.000205006 |      -4.98258e-05 |  1.93845e-06 |       5.36278e-05 | 0.000208808 |                  21382 |
|    224022 | (162, 190, 162) | Dyy             |  0.00100758  | 1.29972e-06 |                   6074 |  4.87033e-05 |       0.000689919 |  0.000849358 |       0.0011174   | 0.00175861  |                  63695 |
|    224022 | (162, 190, 162) | Dxz             |  3.57424e-07 | 1.21273e-08 |                  25400 | -0.00019435  |      -4.8414e-05  | -2.58494e-26 |       4.88766e-05 | 0.000194812 |                  25510 |
|    224022 | (162, 190, 162) | Dyz             | -2.02446e-05 | 1.27969e-08 |                  29788 | -0.000222973 |      -6.88305e-05 | -1.47795e-05 |       3.39312e-05 | 0.000188074 |                  21120 |
|    224022 | (162, 190, 162) | Dzz             |  0.000970013 | 1.2748e-06  |                   5680 |  3.9299e-05  |       0.000655782 |  0.000808685 |       0.00106677  | 0.00168325  |                  66460 |
|    224022 | (81, 95, 81)    | Dxx             |  0.000952374 | 1.93482e-07 |                    118 |  0.000146015 |       0.000688964 |  0.000820619 |       0.00105093  | 0.00159388  |                   7195 |
|    224022 | (81, 95, 81)    | Dxy             |  2.04059e-06 | 9.67466e-09 |                   3045 | -0.000169714 |      -4.1001e-05  |  2.39488e-06 |       4.48075e-05 | 0.00017352  |                   2866 |
|    224022 | (81, 95, 81)    | Dyy             |  0.000990003 | 2.00876e-07 |                    121 |  0.000156822 |       0.000722034 |  0.000866042 |       0.00109884  | 0.00166405  |                   6861 |
|    224022 | (81, 95, 81)    | Dxz             |  4.13201e-07 | 9.49909e-09 |                   3503 | -0.000159895 |      -3.96367e-05 |  2.36004e-08 |       4.05356e-05 | 0.000160794 |                   3490 |
|    224022 | (81, 95, 81)    | Dyz             | -2.0055e-05  | 9.74726e-09 |                   4152 | -0.000185353 |      -5.92571e-05 | -1.50244e-05 |       2.48072e-05 | 0.000150903 |                   3008 |
|    224022 | (81, 95, 81)    | Dzz             |  0.000952976 | 1.98921e-07 |                    116 |  0.000144802 |       0.000689172 |  0.000823553 |       0.00105209  | 0.00159646  |                   7176 |
|    227432 | (162, 190, 162) | Dxx             |  0.000984648 | 3.45175e-07 |                   5904 |  2.32554e-05 |       0.000664256 |  0.000816109 |       0.00109159  | 0.00173259  |                  89385 |
|    227432 | (162, 190, 162) | Dxy             | -1.01113e-07 | 1.20774e-08 |                  27975 | -0.000209048 |      -5.20028e-05 |  4.03518e-07 |       5.26937e-05 | 0.000209738 |                  25991 |
|    227432 | (162, 190, 162) | Dyy             |  0.00101754  | 3.61814e-07 |                   5263 |  1.54631e-05 |       0.000688089 |  0.000852339 |       0.00113651  | 0.00180913  |                  86267 |
|    227432 | (162, 190, 162) | Dxz             | -6.86858e-07 | 1.21932e-08 |                  32886 | -0.000203925 |      -5.0969e-05  | -3.87741e-26 |       5.10015e-05 | 0.000203957 |                  30489 |
|    227432 | (162, 190, 162) | Dyz             | -1.80661e-05 | 1.25058e-08 |                  36771 | -0.000227709 |      -6.84023e-05 | -1.24124e-05 |       3.7802e-05  | 0.000197108 |                  24848 |
|    227432 | (162, 190, 162) | Dzz             |  0.000991539 | 3.57884e-07 |                   5772 |  2.60831e-05 |       0.00066697  |  0.000828955 |       0.00109423  | 0.00173511  |                  90613 |
|    227432 | (81, 95, 81)    | Dxx             |  0.000969405 | 2.04789e-07 |                    211 |  0.000131195 |       0.000695745 |  0.000826491 |       0.00107211  | 0.00163666  |                   9866 |
|    227432 | (81, 95, 81)    | Dxy             |  8.92681e-08 | 9.51749e-09 |                   3929 | -0.000173046 |      -4.2736e-05  |  6.18157e-07 |       4.41375e-05 | 0.000174448 |                   3632 |
|    227432 | (81, 95, 81)    | Dyy             |  0.00100112  | 2.17044e-07 |                    166 |  0.000114092 |       0.000716238 |  0.00086404  |       0.00111767  | 0.00171981  |                   9274 |
|    227432 | (81, 95, 81)    | Dxz             | -8.37028e-07 | 9.7574e-09  |                   4733 | -0.000168494 |      -4.20955e-05 | -2.09517e-07 |       4.21698e-05 | 0.000168568 |                   4352 |
|    227432 | (81, 95, 81)    | Dyz             | -1.77039e-05 | 9.76359e-09 |                   5480 | -0.000188402 |      -5.80634e-05 | -1.24238e-05 |       2.8829e-05  | 0.000159168 |                   3654 |
|    227432 | (81, 95, 81)    | Dzz             |  0.000975871 | 2.12803e-07 |                    192 |  0.000131256 |       0.000698542 |  0.000839655 |       0.00107673  | 0.00164402  |                   9928 |
|    751348 | (162, 190, 162) | Dxx             |  0.00115159  | 0.00334894  |                      0 | -3.2913e-05  |       0.000645597 |  0.000808113 |       0.00109794  | 0.00177645  |                  83682 |
|    751348 | (162, 190, 162) | Dxy             | -3.07335e-05 | 0.000219731 |                  31193 | -0.000191636 |      -4.53449e-05 |  3.33987e-06 |       5.21827e-05 | 0.000198474 |                  30128 |
|    751348 | (162, 190, 162) | Dyy             |  0.00105523  | 6.95758e-05 |                   5524 |  2.31201e-05 |       0.000701383 |  0.000868353 |       0.00115356  | 0.00183182  |                  83871 |
|    751348 | (162, 190, 162) | Dxz             |  1.58034e-05 | 0.000229047 |                  33475 | -0.000197179 |      -5.13982e-05 | -2.20941e-06 |       4.57889e-05 | 0.00019157  |                  30658 |
|    751348 | (162, 190, 162) | Dyz             | -2.40098e-05 | 2.98391e-05 |                  39522 | -0.000218504 |      -6.63371e-05 | -1.34872e-05 |       3.51074e-05 | 0.000187274 |                  27488 |
|    751348 | (162, 190, 162) | Dzz             |  0.0010288   | 9.7893e-05  |                      0 | -3.4755e-06  |       0.00066241  |  0.000828836 |       0.00110633  | 0.00177222  |                  86522 |
|    751348 | (81, 95, 81)    | Dxx             |  0.000964376 | 2.18475e-07 |                    107 |  8.31927e-05 |       0.000680416 |  0.000824313 |       0.00107856  | 0.00167579  |                   9060 |
|    751348 | (81, 95, 81)    | Dxy             |  3.11061e-06 | 9.09316e-09 |                   4370 | -0.000153023 |      -3.56788e-05 |  3.58017e-06 |       4.25504e-05 | 0.000159894 |                   4187 |
|    751348 | (81, 95, 81)    | Dyy             |  0.00101802  | 2.2087e-07  |                    170 |  0.000136353 |       0.000734591 |  0.000883039 |       0.00113342  | 0.00173165  |                   9022 |
|    751348 | (81, 95, 81)    | Dxz             | -3.55426e-06 | 8.60068e-09 |                   4823 | -0.000158517 |      -4.17383e-05 | -2.3893e-06  |       3.61139e-05 | 0.000152892 |                   4371 |
|    751348 | (81, 95, 81)    | Dyz             | -1.79516e-05 | 9.46246e-09 |                   5661 | -0.00017714  |      -5.57567e-05 | -1.36525e-05 |       2.51657e-05 | 0.000146549 |                   3920 |
|    751348 | (81, 95, 81)    | Dzz             |  0.000980089 | 2.21417e-07 |                    148 |  0.000104238 |       0.00069643  |  0.000844    |       0.00109122  | 0.00168342  |                   9206 |
|    810439 | (162, 190, 162) | Dxx             |  0.000848012 | 7.29606e-07 |                  13451 |  0.00012461  |       0.000619153 |  0.000756712 |       0.000948848 | 0.00144339  |                  65435 |
|    810439 | (162, 190, 162) | Dxy             |  2.95504e-06 | 1.25336e-08 |                  25158 | -0.000206608 |      -4.93718e-05 |  2.49486e-06 |       5.54523e-05 | 0.000212688 |                  24979 |
|    810439 | (162, 190, 162) | Dyy             |  0.00090458  | 7.33282e-07 |                  13229 |  0.000143766 |       0.000669096 |  0.000817772 |       0.00101932  | 0.00154465  |                  60639 |
|    810439 | (162, 190, 162) | Dxz             |  1.11232e-06 | 1.55135e-08 |                  32085 | -0.000205537 |      -4.97297e-05 |  1.93392e-06 |       5.4142e-05  | 0.00020995  |                  28626 |
|    810439 | (162, 190, 162) | Dyz             | -9.89005e-07 | 1.40721e-08 |                  35567 | -0.000212691 |      -5.21562e-05 |  1.38691e-06 |       5.48668e-05 | 0.000215401 |                  26687 |
|    810439 | (162, 190, 162) | Dzz             |  0.000865005 | 7.42959e-07 |                  12967 |  0.000131746 |       0.000633765 |  0.000777136 |       0.000968445 | 0.00147046  |                  63934 |
|    810439 | (81, 95, 81)    | Dxx             |  0.000847206 | 1.17391e-07 |                    520 |  0.000216988 |       0.000651275 |  0.000771675 |       0.0009408   | 0.00137509  |                   7326 |
|    810439 | (81, 95, 81)    | Dxy             |  2.90232e-06 | 8.59156e-09 |                   3579 | -0.000164648 |      -3.87826e-05 |  3.08994e-06 |       4.51276e-05 | 0.000170993 |                   3575 |
|    810439 | (81, 95, 81)    | Dyy             |  0.000903604 | 1.2138e-07  |                    480 |  0.000235507 |       0.000700525 |  0.000831182 |       0.00101054  | 0.00147556  |                   6527 |
|    810439 | (81, 95, 81)    | Dxz             |  7.62964e-07 | 9.90522e-09 |                   4587 | -0.000164668 |      -3.97978e-05 |  2.25124e-06 |       4.3449e-05  | 0.000168319 |                   4074 |
|    810439 | (81, 95, 81)    | Dyz             | -9.83321e-07 | 9.87725e-09 |                   5214 | -0.00016868  |      -4.11125e-05 |  1.96332e-06 |       4.39325e-05 | 0.0001715   |                   3802 |
|    810439 | (81, 95, 81)    | Dzz             |  0.000863985 | 1.1933e-07  |                    532 |  0.000221628 |       0.00066593  |  0.000790969 |       0.000962131 | 0.00140643  |                   6957 |

In [30]:
# Render the post-normalization summary table as Markdown. The string is empty when no
# normalized statistics were recorded, in which case nothing is shown.
display(Markdown(norm_subj_stats_str))

|   Subj ID | Resolution      | Channel Index   |         Mean |      Var |   Num Outliers (Lower) |        Low |   25th Percentile |      Median |   75th Percentile |      High |   Num Outliers (Upper) |
|-----------|-----------------|-----------------|--------------|----------|------------------------|------------|-------------------|-------------|-------------------|-----------|------------------------|
|    224022 | (162, 190, 162) | Dxx             |  1.07042e-08 | 0.999922 |                   5858 | -0.82295   |      -0.276864    | -0.145347   |       0.0871931   | 0.633279  |                  65959 |
|    224022 | (162, 190, 162) | Dxy             |  5.5157e-09  | 0.992021 |                  22446 | -1.84678   |      -0.460629    |  0.00175596 |       0.463471    | 1.84962   |                  21382 |
|    224022 | (162, 190, 162) | Dyy             | -1.29853e-07 | 0.999923 |                   6074 | -0.841051  |      -0.278628    | -0.13878    |       0.0963206   | 0.658743  |                  63695 |
|    224022 | (162, 190, 162) | Dxz             |  4.67432e-10 | 0.991822 |                  25400 | -1.76082   |      -0.441062    | -0.00323235 |       0.43878     | 1.75854   |                  25510 |
|    224022 | (162, 190, 162) | Dyz             |  5.81953e-09 | 0.992246 |                  29788 | -1.78514   |      -0.427827    |  0.0481238  |       0.477049    | 1.83436   |                  21120 |
|    224022 | (162, 190, 162) | Dzz             |  1.18728e-08 | 0.999922 |                   5680 | -0.824286  |      -0.278298    | -0.142881   |       0.0856934   | 0.631681  |                  66460 |
|    224022 | (81, 95, 81)    | Dxx             |  7.8684e-08  | 0.999483 |                    118 | -1.83272   |      -0.598688    | -0.299458   |       0.224002    | 1.45804   |                   7195 |
|    224022 | (81, 95, 81)    | Dxy             |  7.22909e-09 | 0.989769 |                   3045 | -1.73723   |      -0.435349    |  0.00358346 |       0.432571    | 1.73445   |                   2866 |
|    224022 | (81, 95, 81)    | Dyy             | -2.8582e-07  | 0.999502 |                    121 | -1.85852   |      -0.597739    | -0.27651    |       0.24278     | 1.50356   |                   6861 |
|    224022 | (81, 95, 81)    | Dxz             |  3.14736e-09 | 0.989582 |                   3503 | -1.63622   |      -0.408777    | -0.00397654 |       0.409517    | 1.63696   |                   3490 |
|    224022 | (81, 95, 81)    | Dyz             | -9.8355e-10  | 0.989845 |                   4152 | -1.66575   |      -0.395049    |  0.0506945  |       0.452088    | 1.72279   |                   3008 |
|    224022 | (81, 95, 81)    | Dzz             |  8.42902e-08 | 0.999498 |                    116 | -1.81157   |      -0.591332    | -0.29011    |       0.22216     | 1.4424    |                   7176 |
|    227432 | (162, 190, 162) | Dxx             | -4.16464e-08 | 0.99971  |                   5904 | -1.63613   |      -0.545255    | -0.286825   |       0.181997    | 1.27288   |                  89385 |
|    227432 | (162, 190, 162) | Dxy             | -9.37673e-09 | 0.991788 |                  27975 | -1.89347   |      -0.470332    |  0.00457296 |       0.478426    | 1.90156   |                  25991 |
|    227432 | (162, 190, 162) | Dyy             |  8.69283e-08 | 0.999724 |                   5263 | -1.66571   |      -0.547639    | -0.274614   |       0.197745    | 1.31582   |                  86267 |
|    227432 | (162, 190, 162) | Dxz             |  3.92347e-09 | 0.991865 |                  32886 | -1.83304   |      -0.453504    |  0.00619489 |       0.466186    | 1.84572   |                  30489 |
|    227432 | (162, 190, 162) | Dyz             |  5.45326e-09 | 0.992067 |                  36771 | -1.86722   |      -0.448328    |  0.0503564  |       0.497599    | 1.91649   |                  24848 |
|    227432 | (162, 190, 162) | Dzz             | -9.74028e-08 | 0.999721 |                   5772 | -1.61362   |      -0.542469    | -0.271735   |       0.171629    | 1.24278   |                  90613 |
|    227432 | (81, 95, 81)    | Dxx             |  1.96997e-07 | 0.999512 |                    211 | -1.8518    |      -0.604579    | -0.31573    |       0.226901    | 1.47412   |                   9866 |
|    227432 | (81, 95, 81)    | Dxy             |  3.70693e-09 | 0.989602 |                   3929 | -1.76545   |      -0.436686    |  0.00539304 |       0.449156    | 1.77792   |                   3632 |
|    227432 | (81, 95, 81)    | Dyy             | -1.76268e-08 | 0.999539 |                    166 | -1.90355   |      -0.611354    | -0.294173   |       0.250109    | 1.5423    |                   9274 |
|    227432 | (81, 95, 81)    | Dxz             | -8.69994e-09 | 0.989855 |                   4733 | -1.68865   |      -0.415559    |  0.00632033 |       0.433168    | 1.70626   |                   4352 |
|    227432 | (81, 95, 81)    | Dyz             |  1.92912e-08 | 0.989862 |                   5480 | -1.71874   |      -0.406376    |  0.0531644  |       0.468535    | 1.7809    |                   3654 |
|    227432 | (81, 95, 81)    | Dzz             |  7.34956e-08 | 0.99953  |                    192 | -1.83049   |      -0.601043    | -0.295215   |       0.218591    | 1.44804   |                   9928 |
|    751348 | (162, 190, 162) | Dxx             |  1.18452e-09 | 1        |                      0 | -0.0204684 |      -0.00874363  | -0.00593535 |      -0.000927149 | 0.0107976 |                  83682 |
|    751348 | (162, 190, 162) | Dxy             | -3.58946e-11 | 1        |                  31193 | -0.0108547 |      -0.000985707 |  0.00229863 |       0.00559362  | 0.0154626 |                  30128 |
|    751348 | (162, 190, 162) | Dyy             | -1.65115e-09 | 0.999999 |                   5524 | -0.123736  |      -0.042421    | -0.0224035  |       0.0117888   | 0.0931034 |                  83871 |
|    751348 | (162, 190, 162) | Dxz             |  1.61526e-10 | 1        |                  33475 | -0.0140728 |      -0.00444035  | -0.0011902  |       0.00198129  | 0.0116138 |                  30658 |
|    751348 | (162, 190, 162) | Dyz             | -2.42288e-10 | 0.999997 |                  39522 | -0.0356052 |      -0.00774867  |  0.00192631 |       0.0108223   | 0.0386788 |                  27488 |
|    751348 | (162, 190, 162) | Dzz             | -2.27931e-09 | 0.999999 |                      0 | -0.104332  |      -0.0370308   | -0.02021    |       0.00783676  | 0.0751381 |                  86522 |
|    751348 | (81, 95, 81)    | Dxx             | -2.81124e-08 | 0.999542 |                    107 | -1.8848    |      -0.607376    | -0.299588   |       0.244242    | 1.52167   |                   9060 |
|    751348 | (81, 95, 81)    | Dxy             |  6.89621e-09 | 0.989122 |                   4370 | -1.62841   |      -0.404558    |  0.00489739 |       0.411341    | 1.63519   |                   4187 |
|    751348 | (81, 95, 81)    | Dyy             | -1.40524e-07 | 0.999547 |                    170 | -1.8756    |      -0.602954    | -0.287155   |       0.245477    | 1.51812   |                   9022 |
|    751348 | (81, 95, 81)    | Dxz             | -4.16411e-09 | 0.988507 |                   4823 | -1.66131   |      -0.40936     |  0.0124892  |       0.425271    | 1.67722   |                   4371 |
|    751348 | (81, 95, 81)    | Dyz             | -7.53684e-10 | 0.989542 |                   5661 | -1.6279    |      -0.386604    |  0.0439635  |       0.440927    | 1.68222   |                   3920 |
|    751348 | (81, 95, 81)    | Dzz             |  1.08154e-07 | 0.999549 |                    148 | -1.86091   |      -0.602688    | -0.289147   |       0.23613     | 1.49436   |                   9206 |
|    810439 | (162, 190, 162) | Dxx             |  8.2258e-08  | 0.999863 |                  13451 | -0.846848  |      -0.267913    | -0.10688    |       0.118043    | 0.696978  |                  65435 |
|    810439 | (162, 190, 162) | Dxy             |  4.26836e-09 | 0.992085 |                  25158 | -1.86445   |      -0.465544    | -0.00409413 |       0.467061    | 1.86597   |                  24979 |
|    810439 | (162, 190, 162) | Dyy             |  7.6205e-08  | 0.999864 |                  13229 | -0.88841   |      -0.274977    | -0.101367   |       0.133978    | 0.747411  |                  60639 |
|    810439 | (162, 190, 162) | Dxz             |  7.54323e-10 | 0.993595 |                  32085 | -1.65381   |      -0.406886    |  0.00657525 |       0.424395    | 1.67132   |                  28626 |
|    810439 | (162, 190, 162) | Dyz             |  5.07788e-09 | 0.992944 |                  35567 | -1.77831   |      -0.429808    |  0.0199579  |       0.469193    | 1.81769   |                  26687 |
|    810439 | (162, 190, 162) | Dzz             |  1.01098e-07 | 0.999865 |                  12967 | -0.85064   |      -0.268257    | -0.101936   |       0.119998    | 0.702381  |                  63934 |
|    810439 | (81, 95, 81)    | Dxx             |  1.78978e-07 | 0.999149 |                    520 | -1.83861   |      -0.571612    | -0.220355   |       0.273051    | 1.54005   |                   7326 |
|    810439 | (81, 95, 81)    | Dxy             | -5.74345e-09 | 0.988495 |                   3579 | -1.7972    |      -0.447126    |  0.00201244 |       0.452922    | 1.80299   |                   3575 |
|    810439 | (81, 95, 81)    | Dyy             | -1.62058e-07 | 0.999177 |                    480 | -1.91684   |      -0.582656    | -0.207788   |       0.306803    | 1.64099   |                   6527 |
|    810439 | (81, 95, 81)    | Dxz             |  8.2271e-09  | 0.990005 |                   4587 | -1.65388   |      -0.405502    |  0.0148789  |       0.426749    | 1.67513   |                   4074 |
|    810439 | (81, 95, 81)    | Dyz             |  1.47467e-09 | 0.989977 |                   5214 | -1.67888   |      -0.401749    |  0.0295     |       0.44967     | 1.7268    |                   3802 |
|    810439 | (81, 95, 81)    | Dzz             |  8.47547e-08 | 0.999163 |                    532 | -1.85874   |      -0.573098    | -0.211282   |       0.283998    | 1.56964   |                   6957 |

In [31]:
# Gather every preprocessed subject into a single torchio dataset.
# NOTE(review): `load_getitem=False` presumably stops the dataset from eagerly loading
# image data on indexing — confirm against the installed torchio version.
subj_dataset = torchio.SubjectsDataset(list(subj_data.values()), load_getitem=False)

# Model Training

## Set Up Patch-Based Data Loaders

In [56]:
# Data train/validation/test split

num_subjs = len(subj_dataset)
# Round up so that any non-zero split percentage reserves at least one subject.
num_test_subjs = int(np.ceil(num_subjs * test_percent))
num_val_subjs = int(np.ceil(num_subjs * val_percent))
# Whatever is left after the test and validation subjects is used for training.
num_train_subjs = num_subjs - (num_test_subjs + num_val_subjs)
# `dry_iter()` returns the dataset's own internal subject list, so take a shallow copy
# before shuffling; otherwise `random.shuffle` would reorder `subj_dataset` in place.
subj_list = list(subj_dataset.dry_iter())
# Randomly shuffle the list of subjects, then choose the first `num_test_subjs` subjects
# for testing.
random.shuffle(subj_list)

# Create partial function to collect list of samples and form a tuple of tensors.
collate_fn = functools.partial(
    pitn.samplers.collate_subj, full_res_key="fr_dti", low_res_key="lr_dti"
)

test_dataset = torchio.SubjectsDataset(subj_list[:num_test_subjs], load_getitem=False)
# Choose the remaining for training/validation.
# If only 1 subject is available, assume this is a debugging run.
if num_subjs == 1 and num_train_subjs == 0:
    print("DEBUG: Only 1 subject with no training subjects, mixing train and test set")
    # Keep the single (test) subject in the pool so it is also used for training.
    subj_list = subj_list[:]
    num_train_subjs = num_test_subjs
else:
    # Drop the subjects already assigned to the test set.
    subj_list = subj_list[num_test_subjs:]

# The next `num_val_subjs` subjects form the validation set...
val_dataset = torchio.SubjectsDataset(subj_list[:num_val_subjs], load_getitem=False)
subj_list = subj_list[num_val_subjs:]

# ...and everything remaining is used for training.
train_dataset = torchio.SubjectsDataset(subj_list, load_getitem=False)

# Number of patches drawn from each training subject.
patches_per_subj = 8000
# Hold enough for 2 epochs at a time.
queue_max_length = 2 * patches_per_subj * num_train_subjs

# Training patch sampler, random across all patches of all volumes.
# NOTE(review): the label probabilities give label 0 zero weight and label 1 full weight,
# which presumably restricts sampling to the brain mask — confirm in `pitn.samplers`.
train_sampler = pitn.samplers.MultiresSampler(
    source_img_key="fr_dti",
    source_spatial_patch_size=output_spatial_patch_shape,
    low_res_key="lr_dti",
    low_res_spatial_patch_size=input_spatial_patch_shape,
    downsample_factor=downsample_factor,
    label_name="fr_brain_mask",
    label_probabilities={0: 0, 1: 1},
)

# Set up a torchio.Queue to act as a sampler proxy for the torch DataLoader
train_queue = torchio.Queue(
    train_dataset,
    sampler=train_sampler,
    samples_per_volume=patches_per_subj,
    max_length=queue_max_length,
    shuffle_subjects=True,
    shuffle_patches=True,
    num_workers=0,
)

# Batches are drawn from the queue and assembled by the shared collate function.
train_loader = torch.utils.data.DataLoader(
    train_queue,
    collate_fn=collate_fn,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
)

# Set up validation and testing objects.
def _half_volume_overlap(patch_shape):
    """Return a length-3 array holding the per-axis overlap for ~50% volume overlap.

    The overlap is the (floored) side length of a cube with half the patch's volume,
    which is not the same as dividing each dimension by 2. Odd results are bumped up
    by one because torchio requires an even-numbered overlap.
    """
    half_volume = np.prod(patch_shape) / 2
    side = int(np.floor(np.power(half_volume, 1 / 3)))
    # torchio requires an even-numbered overlap.
    if side % 2 == 1:
        side += 1
    return np.repeat(side, 3)


# Overlap for the (low-res) input patches.
input_vol_half_overlap = _half_volume_overlap(input_spatial_patch_shape)

# Repeat for the output.
output_vol_half_overlap = _half_volume_overlap(output_spatial_patch_shape)

# Batch size shared by validation and testing.
# No backward pass is performed here, so we can increase the batch size to lower
# inference time.
tv_batch_size = 128


def _make_grid_sampler(subj, patch_overlap):
    """Build a grid sampler over one subject's paired "fr_dti"/"lr_dti" volumes."""
    return pitn.samplers.MultiresGridSampler(
        subject=subj,
        source_img_key="fr_dti",
        low_res_key="lr_dti",
        downsample_factor=downsample_factor,
        source_spatial_patch_size=output_spatial_patch_shape,
        low_res_spatial_patch_size=input_spatial_patch_shape,
        patch_overlap=patch_overlap,
        # Due to the oversampling in the LR->HR patch mapping, a mask is necessary
        # to avoid sampling from out of bounds.
        source_mask=subj["fr_brain_mask"].tensor[0].bool(),
    )


# Validation samplers: one grid sampler per validation subject, without patch overlap.
val_samplers = []
for subj in val_dataset.dry_iter():
    val_samplers.append(_make_grid_sampler(subj, patch_overlap=0))

# Alter the collate function to include locations in batches for visualization of sub-
# regions during validation.
val_collate_locs = functools.partial(
    pitn.viz.collate_locations, full_res_key="fr_dti", low_res_key="lr_dti"
)
concat_val_dataset = torch.utils.data.ConcatDataset(val_samplers)
val_loader = torch.utils.data.DataLoader(
    concat_val_dataset,
    collate_fn=val_collate_locs,
    batch_size=tv_batch_size,
    num_workers=0,
    pin_memory=True,
)

# Test samplers: one grid sampler per test subject, using the output half-volume overlap.
test_samplers = []
for subj in test_dataset.dry_iter():
    test_samplers.append(
        _make_grid_sampler(subj, patch_overlap=output_vol_half_overlap)
    )

concat_test_dataset = torch.utils.data.ConcatDataset(test_samplers)
test_loader = torch.utils.data.DataLoader(
    concat_test_dataset,
    collate_fn=collate_fn,
    batch_size=tv_batch_size,
    num_workers=6,
    pin_memory=True,
)

# Report which subject IDs landed in each split; `dry_iter()` is used to walk the
# subjects of each dataset directly.
print("Training subject ID(s): ", [s.subj_id for s in train_dataset.dry_iter()])
print("Validation subject ID(s): ", [s.subj_id for s in val_dataset.dry_iter()])
print("Test subject ID(s): ", [s.subj_id for s in test_dataset.dry_iter()])

Training subject ID(s):  [227432, 751348]
Validation subject ID(s):  [224022]
Test subject ID(s):  [810439]


In [33]:
# Subject IDs of each split, collected once and reused for both logging targets.
_train_subj_ids = [s.subj_id for s in train_dataset.dry_iter()]
_val_subj_ids = [s.subj_id for s in val_dataset.dry_iter()]
_test_subj_ids = [s.subj_id for s in test_dataset.dry_iter()]

# Append the split assignment to the plain-text log.
with open(log_txt_file, "a+") as f:
    f.write(f"Training Set Subjects: {_train_subj_ids}\n")
    f.write(f"Validation Set Subjects: {_val_subj_ids}\n")
    f.write(f"Test Set Subjects: {_test_subj_ids}\n")

# Mirror the same information as text entries in the experiment logger.
logger.add_text("train_subjs", str(_train_subj_ids))
logger.add_text("val_subjs", str(_val_subj_ids))
logger.add_text("test_subjs", str(_test_subj_ids))

## Model Definition

In [78]:
# Full pytorch-lightning module for contained training, validation, and testing.
# NOTE(review): module-level debugging threshold; it is not referenced in the part of
# this cell visible here — confirm where (and whether) it is still used.
debug_prob_threshold = -0.0005


class DIQTSystem(pl.LightningModule):

    loss_methods = {
        "MSE".casefold(): torch.nn.MSELoss(reduction="mean"),
        "SSE".casefold(): torch.nn.MSELoss(reduction="sum"),
        "RMSE".casefold(): lambda y_hat, y: torch.sqrt(
            F.mse_loss(y_hat, y, reduction="mean")
        ),
        "L1".casefold(): torch.nn.L1Loss(reduction="mean"),
    }

    def __init__(
        self,
        channels,
        downsample_factor,
        train_loss_method: str,
        opt_params: dict,
        val_viz_bboxes=None,
        val_patch_overlap=(0, 0, 0),
        norm_method=None,
    ):
        super().__init__()

        self._channels = channels
        self._downsample_factor = downsample_factor

        # Parameters
        # Network parameters
        self.net = pitn.nn.models.ThreeConv(
            self._channels, self._downsample_factor, norm_method=norm_method
        )

        ## Training parameters
        self.opt_params = opt_params

        try:
            self._loss_fn = self.loss_methods[train_loss_method.casefold()]
        except (AttributeError, KeyError) as e:
            if callable(train_loss_method):
                self._loss_fn = train_loss_method
            else:
                raise e
        # Sub-regions of the volume that should be logged in validation.
        if val_viz_bboxes is None:
            self.val_bboxes = torch.zeros(0, 6)
        else:
            self.val_bboxes = val_viz_bboxes
        self.val_patch_overlap = val_patch_overlap
        #         self.val_vmin = -torch.ones(self.val_bboxes.shape[0])
        #         self.val_vmax = torch.ones(self.val_bboxes.shape[0])
        # My own dinky logging object.
        self.plain_log = {"train_loss": list(), "val_loss": list(), "test_loss": list()}

    @staticmethod
    def normalize_with_layer(norm_layer, y):

        if norm_layer is not None:
            if not norm_layer.track_running_stats:
                y = norm_layer(y)
            else:
                if isinstance(norm_layer, torch.nn.InstanceNorm3d):
                    y = F.instance_norm(
                        y,
                        eps=norm_layer.eps,
                    )
                elif isinstance(norm_layer, torch.nn.BatchNorm3d):
                    y = F.batch_norm(
                        y,
                        eps=norm_layer.eps,
                    )

        return y

    def forward(self, x):
        y = self.net(x)
        return y

    def training_step(self, batch, batch_idx):
        x, y = batch

        # We need both prediction and ground truth to be a standard normal distribution
        # for a fair calculation of the loss.
        y_pred = self.net(x, norm_output=False)
        if self.net.norm is not None:
            y = self.normalize_with_layer(self.net.norm, y)

        loss = self._loss_fn(y_pred, y)
        if random.random() < debug_prob_threshold:
            #             breakpoint()
            print("Target Patch Dists")
            print(desc_channel_dists(y.detach()))
            print("Predicted Patch Dists")
            print(desc_channel_dists(y_pred.detach()))

        self.log("train_loss", loss, prog_bar=True)
        self.plain_log["train_loss"].append(float(loss.cpu()))
        return loss

    def validation_step(self, batch, batch_idx):
        # The validation samples also include the voxel coordinates within the volume(s).
        x, y, y_locs = batch

        y_pred = self.net(x, norm_output=False)

        val_loss = torch.sqrt(F.mse_loss(y_pred, y, reduction="mean"))
        self.log("val_loss", val_loss, prog_bar=True)
        self.plain_log["val_loss"].append(float(val_loss.cpu()))

        bbox_patches = list()
        y_locs = y_locs.detach().cpu()
        for bbox in self.val_bboxes:

            bbox = bbox.detach().cpu()
            locs_to_keep = torch.prod(
                y_locs[:, :3] >= bbox[None, :3], dim=1
            ) & torch.prod(y_locs[:, 3:] <= bbox[None, 3:], dim=1)
            locs_to_keep = locs_to_keep.to(y_pred).bool()
            bbox_patches.append(
                {
                    "full_res": y[locs_to_keep].detach().cpu(),
                    "pred": y_pred[locs_to_keep].detach().cpu(),
                    "locations": y_locs[locs_to_keep].detach().cpu(),
                }
            )
            assert (bbox_patches[-1]["locations"] >= 0).all()
        return {"loss": val_loss, "bbox": bbox_patches}

    def validation_epoch_end(self, outputs):
        batched_bbox_patches = [b["bbox"] for b in outputs]

        sub_regions = list()
        for (
            i,
            bbox,
        ) in enumerate(self.val_bboxes):

            bbox = bbox.detach().cpu()
            bbox_offset = bbox[:3]

            fr_patches = torch.cat(
                [batch[i]["full_res"] for batch in batched_bbox_patches]
            )

            pred_patches = torch.cat(
                [batch[i]["pred"] for batch in batched_bbox_patches]
            )
            locs = torch.cat([batch[i]["locations"] for batch in batched_bbox_patches])

            fr_agg = pitn.viz.SubGridAggregator(
                bbox[3:] - bbox[:3],
                location_offset=bbox_offset,
                patch_overlap=self.val_patch_overlap,
            )
            fr_agg.add_batch(fr_patches, locs)
            fr_sub_vol = fr_agg.get_output_tensor()

            pred_agg = pitn.viz.SubGridAggregator(
                bbox[3:] - bbox[:3],
                location_offset=bbox_offset,
                patch_overlap=self.val_patch_overlap,
            )
            pred_agg.add_batch(pred_patches, locs)
            pred_sub_vol = pred_agg.get_output_tensor()

            sub_regions.append({"full_res": fr_sub_vol, "pred": pred_sub_vol})

        for i, reg in enumerate(sub_regions):
            # Slice into full volume and just grab a B x C x H x W slice.
            slice_idx = (
                slice(None),
                None,
                slice(None),
                reg["full_res"].shape[1] // 2,
                slice(None),
            )

            row_fr = torchvision.utils.make_grid(
                reg["full_res"][slice_idx], pad_value=4, nrow=1
            )
            row_pred = torchvision.utils.make_grid(
                reg["pred"][slice_idx], pad_value=4, nrow=1
            )
            # Calculate voxel-wise root squared-error
            rse = torch.sqrt(F.mse_loss(reg["pred"], reg["full_res"], reduction="none"))
            row_rse = torchvision.utils.make_grid(rse[slice_idx], pad_value=4, nrow=1)

            reg_grid = torchvision.utils.make_grid(
                [row_fr, row_pred, row_rse], nrow=3, pad_value=4
            )[0]

            fig = plt.figure(figsize=(4.5, 8), clear=True)
            #             max_dti = np.quantile(np.concatenate([di.reshape(6, -1) for di in dtis], axis=1), 0.95)
            #             min_dti = np.quantile(np.concatenate([di.reshape(6, -1) for di in dtis], axis=1), 0.05)
            #             figsize=(12, 6), dpi=160
            #             color_norm = mpl.colors.Normalize(vmin=min_dti, vmax=max_dti)
            #             fig.colorbar(
            #                 mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap),
            #                 ax=axs,
            #                 location="right",
            #                 fraction=0.1,
            #                 pad=0.03,
            #             )
            plt.imshow(reg_grid.cpu().numpy(), cmap="jet")

            plt.colorbar()
            self.logger.experiment.add_figure(f"val_samples_{i}", fig, self.global_step)

    def test_step(self, batch, batch_idx):

        x, y = batch

        # We can't normalize the outputs or ground truth because it would change
        # our units out of $mm^2/s$.
        y_pred = self.net(x, norm_output=False)
        #         if self.net.norm is not None:
        #             y = self.normalize_with_layer(self.net.norm, y)

        test_loss = torch.sqrt(F.mse_loss(y_pred, y, reduction="mean"))
        self.log("test_loss", test_loss)
        self.plain_log["test_loss"].append(float(test_loss.cpu()))

        return test_loss

    def viz_step(self, x, norm_output=True):
        """Step for running inference for the purpose of a visualization.

        Mainly deals with normalization."""

        with torch.no_grad():
            if self.net.training:
                was_training = True
            else:
                was_training = False

            self.net.eval()

            y_pred = self.net(x, norm_output=norm_output)

            if isinstance(self.net.norm, torch.nn.InstanceNorm3d):
                target_mean = torch.mean(x, dim=(2, 3, 4), keepdim=True)
                target_var = torch.var(x, dim=(2, 3, 4), keepdim=True)
                eps = self.net.norm.eps

            elif isinstance(self.net.norm, torch.nn.BatchNorm3d):
                target_mean = torch.mean(x, dim=(0, 2, 3, 4), keepdim=True)
                target_var = torch.var(x, dim=(0, 2, 3, 4), keepdim=True)
                eps = self.net.norm.eps
            else:
                target_mean = torch.zeros(1, 1, 1, 1, 1).to(x)
                target_var = torch.ones(1, 1, 1, 1, 1).to(x)
                eps = 1e-10

            y_pred = pitn.data.norm.denormalize_batch(
                y_pred, target_mean, target_var, eps=eps
            )

            if was_training:
                self.net.train()

        return y_pred

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.net.parameters(), **self.opt_params)
        return optimizer

## Training Loop

In [58]:
# Build the single bounding box used to visualize aggregated validation patches.
# The box spans 3 output patches along every spatial axis.
region_size = torch.as_tensor(output_spatial_patch_shape) * 3
vol_shape = torch.as_tensor(val_dataset.dry_iter()[0]["fr_dti"].shape[1:])

# Candidate start coordinates: multiples of the output patch size on each axis.
possible_bbox_ini = [
    torch.arange(0, vol_shape[axis], output_spatial_patch_shape[axis])
    for axis in range(3)
]

# On each axis, start at the candidate 40% of the way through the list.
bbox_idx_ini = torch.as_tensor(
    [candidates[round(len(candidates) * 0.4)] for candidates in possible_bbox_ini]
)

# Rows are laid out as (start_0, start_1, start_2, end_0, end_1, end_2).
bbox_coord = torch.cat([bbox_idx_ini, bbox_idx_ini + region_size])
bbox_coords = bbox_coord[None, :]
bbox_coords, bbox_coords.shape

(tensor([[ 70,  84,  70, 112, 126, 112]]), torch.Size([1, 6]))

In [79]:
# Wall-clock start of training, truncated to whole seconds.
train_start_timestamp = datetime.datetime.now().replace(microsecond=0)

model_kwargs = dict(
    channels=channels,
    downsample_factor=downsample_factor,
    norm_method=network_norm_method,
    train_loss_method=train_loss_name,
    opt_params=opt_params,
    val_viz_bboxes=bbox_coords,
    val_patch_overlap=val_samplers[0].patch_overlap,
)
model = DIQTSystem(**model_kwargs)

# Keep a textual description of the model in the experiment log.
with open(log_txt_file, "a+") as f:
    f.write(f"Model overview: {model}\n")

# Trainer configuration: one GPU, two epochs, validation after every epoch.
trainer_kwargs = dict(
    gpus=1,
    max_epochs=2,
    logger=pl_logger,
    log_every_n_steps=50,
    check_val_every_n_epoch=1,
    progress_bar_refresh_rate=20,
    terminate_on_nan=True,
)
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

train_end_timestamp = datetime.datetime.now().replace(microsecond=0)
train_duration = train_end_timestamp - train_start_timestamp
print(f"Train duration: {train_duration}")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type      | Params
---------------------------------------
0 | net      | ThreeConv | 142 K 
1 | _loss_fn | MSELoss   | 0     
---------------------------------------
142 K     Trainable params
0         Non-trainable params
142 K     Total params
0.572     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

> [0;32m<ipython-input-78-28d19ad37dd4>[0m(139)[0;36mvalidation_epoch_end[0;34m()[0m
[0;32m    138 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 139 [0;31m            [0mbbox[0m [0;34m=[0m [0mbbox[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m            [0mbbox_offset[0m [0;34m=[0m [0mbbox[0m[0;34m[[0m[0;34m:[0m[0;36m3[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c



The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

> [0;32m<ipython-input-78-28d19ad37dd4>[0m(139)[0;36mvalidation_epoch_end[0;34m()[0m
[0;32m    138 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 139 [0;31m            [0mbbox[0m [0;34m=[0m [0mbbox[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m            [0mbbox_offset[0m [0;34m=[0m [0mbbox[0m[0;34m[[0m[0;34m:[0m[0;36m3[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


Validating: 0it [00:00, ?it/s]

> [0;32m<ipython-input-78-28d19ad37dd4>[0m(139)[0;36mvalidation_epoch_end[0;34m()[0m
[0;32m    138 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 139 [0;31m            [0mbbox[0m [0;34m=[0m [0mbbox[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m            [0mbbox_offset[0m [0;34m=[0m [0mbbox[0m[0;34m[[0m[0;34m:[0m[0;36m3[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


Train duration: 0:00:25


In [None]:
# Append the total training time to the plain-text experiment log, first as the
# raw `timedelta` string and then broken down into days/hours/minutes/seconds.
with open(log_txt_file, "a+") as f:
    f.write("\n")
    f.write(f"Training time: {train_duration}\n")
    # `timedelta.seconds` excludes whole days, so it is split into h/m/s here.
    f.write(
        f"\t{train_duration.days} Days, "
        f"{train_duration.seconds // 3600} Hours, "
        f"{(train_duration.seconds // 60) % 60} Minutes, "
        f"{train_duration.seconds % 60} Seconds\n"
    )

In [None]:
# Plot rolling average window of training loss values.
# `plain_log["train_loss"]` receives one entry per training step (batch), so the
# x-axis counts training steps, not epochs.
train_losses = np.asarray(model.plain_log["train_loss"], dtype=float)

plt.figure(dpi=110)
# With a window longer than the log, `np.convolve(..., "valid")` swaps its
# arguments and no longer yields a rolling mean; clamp the window for short runs.
window = max(1, min(1000, len(train_losses)))
rolling_mean = np.convolve(train_losses, np.ones(window), "valid") / window
# Skip the first rolling means, but always leave at least one point to plot.
rolling_start = min(100, max(len(rolling_mean) - 1, 0))
plt.plot(
    np.arange(
        window + rolling_start,
        window + rolling_start + len(rolling_mean[rolling_start:]),
    ),
    rolling_mean[rolling_start:],
)
plt.title("Training Loss " + train_loss_name + f"\nRolling Mean {window}")
plt.xlabel("Training Step")
plt.ylabel("Loss")
print(np.median(rolling_mean))
# Mean, variance, and max of the losses that precede the plotted range.
print(
    np.mean(train_losses[: window + rolling_start]),
    np.var(train_losses[: window + rolling_start]),
    np.max(train_losses[: window + rolling_start]),
)

plt.savefig(experiment_results_dir / "train_loss.png")

# Model Testing/Evaluation

## Testing Loop

In [None]:
# Run the test loop over the test patches. `DIQTSystem.test_step` logs each
# batch's RMSE and appends it to the module's `plain_log["test_loss"]`.
# NOTE(review): no model is passed, so the trainer decides which weights are
# evaluated (presumably those of the model it just fit) — confirm.
trainer.test(test_dataloaders=test_loader)

In [None]:
# Name of the metric computed in the test step, and where its values are saved.
test_loss_name = "RMSE"
test_loss_log_file = experiment_results_dir / "test_loss.csv"

with open(experiment_results_log, "a+") as f:
    f.write(f"Test loss function: {test_loss_name}\n")

# One CSV row per test batch, preceded by a header line.
test_loss_rows = ["batch_idx, loss\n"]
test_loss_rows.extend(
    f"{batch_idx}, {loss}\n"
    for batch_idx, loss in enumerate(model.plain_log["test_loss"])
)

with open(test_loss_log_file, "a+") as f:
    f.writelines(test_loss_rows)

## Cubic Spline Baseline

In [None]:
# Spline-interpolated upsampling of every test subject's low-res volume, in the
# same order as `test_dataset.dry_iter()`.
cubic_spline_test_log = list()

for subj in test_dataset.dry_iter():
    print("---")
    target_shape = subj["fr_dti"]["data"].cpu().numpy().shape
    # Only the 3 spatial axes are upsampled; the channel axis keeps its size.
    interp_cubic_spline = scipy.ndimage.zoom(
        subj["lr_dti"]["data"].cpu().numpy(),
        zoom=(1, downsample_factor, downsample_factor, downsample_factor),
        order=spline_interp_order,
    )
    if interp_cubic_spline.shape != target_shape:
        # Crop off the end few voxels to account for the lack of padding used in full-
        # volume inference.
        interp_cubic_spline = interp_cubic_spline[
            :, : target_shape[1], : target_shape[2], : target_shape[3]
        ]
    print(f"Subj {subj['subj_id']} done")

    cubic_spline_test_log.append(interp_cubic_spline)

cspline_loss = list()

# Small constant added to the variances when de-normalizing. The ground truth
# and the spline prediction must use the same value to be comparable.
denorm_eps = 10 ** (-10)

# Calculate spline loss for test images.
for subj, cspline_pred in zip(test_dataset.dry_iter(), cubic_spline_test_log):
    # De-normalize the ground truth volume.
    # NOTE(review): the volumes are scaled by the variance, not its square root;
    # confirm that this matches how the data was normalized.
    gt = subj["fr_dti"]["data"]
    gt = gt.detach().cpu().numpy()
    gt = (gt * (subj["fr_vars"] + denorm_eps)) + subj["fr_means"]
    # De-normalize the spline prediction to match the channel-wise distribution of the
    # ground truth.
    cspline_pred = (cspline_pred * (subj["fr_vars"] + denorm_eps)) + subj["fr_means"]
    # Calculate the RMSE of just the values found in the mask.
    se = (gt - cspline_pred) ** 2
    se = se[:, subj["fr_brain_mask"]["data"].bool().cpu().numpy()[0]]
    loss = np.sqrt(se.mean())
    cspline_loss.append(loss)

# Find the grand mean of the spline RMSE's
cspline_loss_mean = np.mean(cspline_loss)
print(cspline_loss_mean)

## Evaluation Visualization

In [None]:
# Histogram of the per-batch test losses, annotated with reference results.
test_losses = np.asarray(model.plain_log["test_loss"])

fig, ax_prob = plt.subplots(figsize=(8, 4), dpi=120)
log_scale = False

sns.histplot(
    test_losses,
    kde=True,
    stat="probability",
    log_scale=log_scale,
    ax=ax_prob,
    label="Current Model",
    legend=False,
)
plt.xlabel("Loss in $mm^2/second$")

# A second, fully transparent histogram on a twin axis supplies a count-scaled
# y-axis next to the probability-scaled one.
ax_count = ax_prob.twinx()
sns.histplot(
    test_losses,
    alpha=0,
    ax=ax_count,
    stat="count",
    log_scale=log_scale,
    zorder=1.1,
)

# Vertical reference lines as (x position, legend label, color), in draw order.
comparison_kwargs = {"ls": "--", "alpha": 0.8, "lw": 2.5}
comparison_lines = [
    # The current DNN model performance.
    (test_losses.mean(), "Current Model Mean", "blue"),
    # Our spline mean performance.
    (cspline_loss_mean, "(Ours) C-spline Mean", "black"),
    # Tanno et al., 2021 model comparisons, taken from Table 2, HCP exterior.
    (31.738e-4, "(Tanno etal, 2021) C-spline Mean", "red"),
    (23.139e-4, "(Tanno etal, 2021) RF", "orange"),
    (13.609e-4, "(Tanno etal, 2021) ESPCN Baseline", "green"),
    (13.412e-4, "(Tanno etal, 2021) Best", "purple"),
    # Best performing Blumberg et al., 2018 paper model.
    (12.78e-4, "(Blumberg etal, 2018) Best", "pink"),
]
for line_x, line_label, line_color in comparison_lines:
    plt.axvline(line_x, label=line_label, color=line_color, **comparison_kwargs)

plt.legend(fontsize="small")
plt.title(f"Test Loss Histogram with Test Metric {test_loss_name}")
plt.savefig(experiment_results_dir / "test_loss_hist.png")

In [None]:
# Worst and average per-batch test loss.
logged_test_losses = np.asarray(model.plain_log["test_loss"])
print(logged_test_losses.max())
print(logged_test_losses.mean())

# Whole-Volume Visualization

In [None]:
# Create full 3D volume of full-res ground truth, low-res downsample, and high-res
# inferences.
@dataclass
class SubjResult:
    """All volume variants produced for one test subject."""

    # Value of the subject's "subj_id" entry.
    subj_id: int
    # Ground truth volume with voxels outside the full-res brain mask zeroed.
    full_res: torch.Tensor
    # Network input volume with voxels outside the low-res brain mask zeroed.
    low_res: torch.Tensor
    # Network prediction aggregated from all patches (not masked).
    full_res_predicted: torch.Tensor
    # Spline-upsampled low-res volume, cropped to the ground truth and masked.
    full_res_cubic_spline: torch.Tensor


test_vol_results = list()

with torch.no_grad():

    for subj in test_dataset.dry_iter():

        # Create a grid sampler for this subject.
        # NOTE(review): the test-loader samplers pass `.tensor[0].bool()` as the
        # mask while the full mask tensor is passed here; confirm which one
        # `MultiresGridSampler` expects.
        subj_sampler = pitn.samplers.MultiresGridSampler(
            subject=subj,
            source_img_key="fr_dti",
            low_res_key="lr_dti",
            downsample_factor=downsample_factor,
            source_spatial_patch_size=output_spatial_patch_shape,
            low_res_spatial_patch_size=input_spatial_patch_shape,
            source_mask=subj["fr_brain_mask"].tensor,
            patch_overlap=0,
        )

        loader = torch.utils.data.DataLoader(
            subj_sampler, batch_size=256, pin_memory=True
        )
        aggregator = torchio.GridAggregator(subj_sampler)

        # Iterate over all batches of patches.
        for batch in loader:

            x = batch["lr_dti"]["data"]
            # Locations are in reference to the full-res ground truth.
            locations = batch["location"]
            predictions = (
                model.viz_step(x.to(model.device), norm_output=False).detach().cpu()
            )

            aggregator.add_batch(predictions, locations)

        # Collect all variants of the volume and aggregate into one container object.

        # The ground truth must be bound before the spline is built, because the
        # spline tensor is cast to the ground truth's dtype and device.
        fr_vol = subj["fr_dti"]["data"]
        lr_vol = subj["lr_dti"]["data"]

        # Calculate cubic spline as a baseline.
        full_res_cubic_spline = scipy.ndimage.zoom(
            subj["lr_dti"]["data"].cpu().numpy(),
            zoom=((1,) + (downsample_factor,) * 3),
            order=3,
        )
        full_res_cubic_spline = torch.from_numpy(full_res_cubic_spline).to(fr_vol)
        # Crop any trailing voxels so the spline lines up with the ground truth
        # and its mask, as is done when computing the spline baseline loss.
        full_res_cubic_spline = full_res_cubic_spline[
            :, : fr_vol.shape[1], : fr_vol.shape[2], : fr_vol.shape[3]
        ]

        full_res_predicted = aggregator.get_output_tensor()

        if data_norm_method is not None and "channel" in data_norm_method.casefold():
            # Undo the channel-wise normalization. Everything at full resolution
            # is de-normalized with the full-res statistics.
            fr_means = torch.as_tensor(subj["fr_means"]).to(fr_vol)
            fr_vars = torch.as_tensor(subj["fr_vars"]).to(fr_vol)
            fr_vol = pitn.data.norm.denormalize_dti(fr_vol, fr_means, fr_vars)
            lr_means = torch.as_tensor(subj["lr_means"]).to(lr_vol)
            lr_vars = torch.as_tensor(subj["lr_vars"]).to(lr_vol)
            lr_vol = pitn.data.norm.denormalize_dti(lr_vol, lr_means, lr_vars)

            full_res_cubic_spline = pitn.data.norm.denormalize_dti(
                full_res_cubic_spline, fr_means, fr_vars
            )
            full_res_predicted = pitn.data.norm.denormalize_dti(
                full_res_predicted, fr_means, fr_vars
            )
        # Zero-out all voxels outside the mask.
        fr_mask = subj["fr_brain_mask"]["data"]
        full_res_cubic_spline = full_res_cubic_spline * fr_mask.to(
            full_res_cubic_spline
        )
        fr_vol = fr_vol * fr_mask.to(fr_vol)
        lr_vol = lr_vol * subj["lr_brain_mask"]["data"].to(lr_vol)

        subj_result = SubjResult(
            subj_id=subj["subj_id"],
            full_res=fr_vol,
            low_res=lr_vol,
            full_res_predicted=full_res_predicted,
            full_res_cubic_spline=full_res_cubic_spline,
        )

        test_vol_results.append(subj_result)

In [None]:
# Index into `test_vol_results` of the subject shown in the visualizations below.
vis_subj_idx = 0

In [None]:
# FA-weighted diffusion direction map of the network prediction.
tensor_key = "full_res_predicted"
pred_volume = getattr(test_vol_results[vis_subj_idx], tensor_key)
pred_dir_map = pitn.viz.direction_map(pred_volume.data.cpu().numpy())
# Move the leading (channel) axis to the end, as matplotlib expects.
pred_dir_map = pred_dir_map.transpose(1, 2, 3, 0)

In [None]:
# FA-weighted diffusion direction map of the cubic spline interpolation.
tensor_key = "full_res_cubic_spline"
cspline_volume = getattr(test_vol_results[vis_subj_idx], tensor_key)
cspline_dir_map = pitn.viz.direction_map(cspline_volume.data.cpu().numpy())
# Move the leading (channel) axis to the end, as matplotlib expects.
cspline_dir_map = cspline_dir_map.transpose(1, 2, 3, 0)

In [None]:
# FA-weighted diffusion direction map of the full-res ground truth.
tensor_key = "full_res"
fr_volume = getattr(test_vol_results[vis_subj_idx], tensor_key)
fr_dir_map = pitn.viz.direction_map(fr_volume.data.cpu().numpy())
# Move the leading (channel) axis to the end, as matplotlib expects.
fr_dir_map = fr_dir_map.transpose(1, 2, 3, 0)

In [None]:
# FA-weighted diffusion direction map of the low-res network input.
tensor_key = "low_res"
lr_volume = getattr(test_vol_results[vis_subj_idx], tensor_key)
lr_dir_map = pitn.viz.direction_map(lr_volume.data.cpu().numpy())
# Move the leading (channel) axis to the end, as matplotlib expects.
lr_dir_map = lr_dir_map.transpose(1, 2, 3, 0)

In [None]:
# Spatial index of the slice to display: everything along the first two spatial
# axes, at position 100 along the third.
slice_idx = (slice(None, None, None), slice(None, None, None), 100)
# The matching low-res slice sits at the full-res position divided by the
# downsampling factor (instead of a hard-coded 2); slices are kept as they are.
low_res_slice_idx = tuple(
    int(s // downsample_factor) if isinstance(s, int) else s for s in slice_idx
)
print(slice_idx)
print(low_res_slice_idx)

In [None]:
# Show both shapes; a notebook only displays the value of a cell's last
# expression, so they are combined into one tuple.
cspline_dir_map.shape, fr_dir_map.shape

In [None]:
# Save the selected slice of the predicted direction map, without axes.
pred_dir_slice = np.rot90(pred_dir_map[slice_idx])
pred_dir_png = experiment_results_dir / "pred_dir_map_sample.png"

plt.figure(dpi=150)
plt.imshow(pred_dir_slice)
plt.axis("off")
plt.savefig(pred_dir_png);

In [None]:
# Save the selected slice of the cubic spline direction map, without axes.
cspline_dir_slice = np.rot90(cspline_dir_map[slice_idx])
cspline_dir_png = experiment_results_dir / "cubic_spline_dir_map_sample.png"

plt.figure(dpi=150)
plt.imshow(cspline_dir_slice)
plt.axis("off")
plt.savefig(cspline_dir_png);

In [None]:
# Save the selected slice of the ground truth direction map, without axes.
fr_dir_slice = np.rot90(fr_dir_map[slice_idx])
fr_dir_png = experiment_results_dir / "ground_truth_dir_map_sample.png"

plt.figure(dpi=150)
plt.imshow(fr_dir_slice)
plt.axis("off")
plt.savefig(fr_dir_png)

In [None]:
# Save the matching low-res slice of the input direction map, without axes.
lr_dir_slice = np.rot90(lr_dir_map[low_res_slice_idx])
lr_dir_png = experiment_results_dir / "low_res_map_sample.png"

plt.figure(dpi=150)
plt.imshow(lr_dir_slice)
plt.axis("off")
plt.savefig(lr_dir_png);

## Per-Image Normalization

In [None]:
# Grid of all 6 DTI channels (columns) for the ground truth, low-res input, cubic
# spline, prediction, and root squared error (rows). No `vmin`/`vmax` is given,
# so every panel is colored on its own scale.

channel_names = ["Dxx", "Dxy", "Dyy", "Dxz", "Dyz", "Dzz"]
dti_names = [
    "Full-Res",
    "Low-Res Input",
    "Cubic Spline",
    "Predicted",
    "Root Squared Error\nFR vs. Prediction",
]
cmap = "jet"

vis_result = test_vol_results[vis_subj_idx]
# Keep every channel, then apply the spatial slice.
full_res_idx = (slice(None), *slice_idx)
low_res_idx = (slice(None), *low_res_slice_idx)

dtis = [
    vis_result.full_res.data[full_res_idx].cpu().numpy(),
    vis_result.low_res[low_res_idx].cpu().numpy(),
    vis_result.full_res_cubic_spline[full_res_idx],
    vis_result.full_res_predicted[full_res_idx].cpu().numpy(),
]

# Add root squared error between FR (first entry) and prediction (last entry).
fr_channels, pred_channels = dtis[0], dtis[-1]
dtis.append(np.sqrt((fr_channels - pred_channels) ** 2))

nrows = len(dtis)
ncols = len(channel_names)

fig = plt.figure(figsize=(12, 6), dpi=160)
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.05, wspace=0.05)

max_subplot_height = 0
for i_row, (dti, row_name) in enumerate(zip(dtis, dti_names)):
    for j_col, channel_name in enumerate(channel_names):
        ax = fig.add_subplot(grid[i_row, j_col])
        ax.imshow(np.rot90(dti[j_col]), cmap=cmap, interpolation=None)
        # Row labels go on the first column, channel labels on the last row.
        if j_col == 0:
            ax.set_ylabel(row_name, size="small")
        if i_row == nrows - 1:
            ax.set_xlabel(channel_name)

        # Track the top edge of the highest subplot to place the `suptitle` later.
        subplot_top = ax.get_position(original=False).get_points()[1, 1]
        max_subplot_height = max(max_subplot_height, subplot_top)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.suptitle(
    "DTI Channel Breakdown, Per-Image Normalization",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
plt.savefig(experiment_results_dir / "DTI_channel_sample_per_img_norm.png");

## Global Normalization

In [None]:
# Grid of all 6 DTI channels (columns) for the ground truth, low-res input, cubic
# spline, prediction, and root squared error (rows), all drawn on one shared
# color scale.

dti_names = [
    "Full-Res",
    "Low-Res Input",
    "Cubic Spline",
    "Predicted",
    "Root Squared Error\nFR vs. Prediction",
]
cmap = "jet"

vis_result = test_vol_results[vis_subj_idx]
# Keep every channel, then apply the spatial slice.
full_res_idx = (slice(None), *slice_idx)
low_res_idx = (slice(None), *low_res_slice_idx)

dtis = [
    vis_result.full_res.data[full_res_idx].cpu().numpy(),
    vis_result.low_res[low_res_idx].cpu().numpy(),
    vis_result.full_res_cubic_spline[full_res_idx],
    vis_result.full_res_predicted[full_res_idx].cpu().numpy(),
]

# Add root squared error between FR (first entry) and prediction (last entry).
fr_channels, pred_channels = dtis[0], dtis[-1]
dtis.append(np.sqrt((fr_channels - pred_channels) ** 2))

# The absolute max and min are avoided because some outliers are extreme (e.g.,
# > 3 orders of magnitude); the 5% and 95% quantiles are used instead.
# Each image is flattened to (6, -1) and concatenated so that images of
# different shapes (e.g., the low-res input) can be pooled.
pooled_values = np.concatenate([di.reshape(6, -1) for di in dtis], axis=1)
max_dti = np.quantile(pooled_values, 0.95)
min_dti = np.quantile(pooled_values, 0.05)

nrows = len(dtis)
ncols = len(channel_names)

fig = plt.figure(figsize=(12, 6), dpi=160)
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.05, wspace=0.05)

axs = list()
max_subplot_height = 0
for i_row, (dti, row_name) in enumerate(zip(dtis, dti_names)):
    for j_col, channel_name in enumerate(channel_names):
        ax = fig.add_subplot(grid[i_row, j_col])
        ax.imshow(
            np.rot90(dti[j_col]),
            cmap=cmap,
            interpolation=None,
            vmin=min_dti,
            vmax=max_dti,
        )
        # Row labels go on the first column, channel labels on the last row.
        if j_col == 0:
            ax.set_ylabel(row_name, size="small")
        if i_row == nrows - 1:
            ax.set_xlabel(channel_name)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        # Track the top edge of the highest subplot to place the `suptitle` later.
        subplot_top = ax.get_position(original=False).get_points()[1, 1]
        max_subplot_height = max(max_subplot_height, subplot_top)
        axs.append(ax)

# One colorbar for the whole grid, on the same scale as every panel.
color_norm = mpl.colors.Normalize(vmin=min_dti, vmax=max_dti)
shared_mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
fig.colorbar(shared_mappable, ax=axs, location="right", fraction=0.1, pad=0.03)

plt.suptitle(
    "DTI Channel Breakdown, Normalized over All Images",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
plt.savefig(experiment_results_dir / "DTI_channel_sample_global_norm.png");

## Channel-Wise Normalization

In [None]:
# Grid of all 6 DTI channels (columns) for the ground truth, low-res input, cubic
# spline, prediction, and root squared error (rows). Every column (DTI
# coefficient) gets its own symmetric color scale and colorbar.

channel_names = ["Dxx", "Dxy", "Dyy", "Dxz", "Dyz", "Dzz"]
dti_names = [
    "Full-Res",
    "Low-Res Input",
    "Cubic Spline",
    "Predicted",
    "Root Squared Error\nFR vs. Prediction",
]
cmap = "coolwarm"

vis_result = test_vol_results[vis_subj_idx]
# Keep every channel, then apply the spatial slice.
full_res_idx = (slice(None), *slice_idx)
low_res_idx = (slice(None), *low_res_slice_idx)

dtis = [
    vis_result.full_res.data[full_res_idx].cpu().numpy(),
    vis_result.low_res[low_res_idx].cpu().numpy(),
    vis_result.full_res_cubic_spline[full_res_idx],
    vis_result.full_res_predicted[full_res_idx].cpu().numpy(),
]

# Add root squared error between FR (first entry) and prediction (last entry).
fr_channels, pred_channels = dtis[0], dtis[-1]
dtis.append(np.sqrt((fr_channels - pred_channels) ** 2))

# The absolute max and min are avoided because some outliers are extreme (e.g.,
# > 3 orders of magnitude); per-channel 5% and 95% quantiles are used instead.
# Each image is flattened to (6, -1) and concatenated so that images of
# different shapes (e.g., the low-res input) can be pooled.
pooled_values = np.concatenate([di.reshape(6, -1) for di in dtis], axis=1)
max_dtis = np.quantile(pooled_values, 0.95, axis=1)
min_dtis = np.quantile(pooled_values, 0.05, axis=1)

# Make each channel's range symmetric about zero for the diverging colormap.
max_dtis = np.maximum(np.abs(max_dtis), np.abs(min_dtis))
min_dtis = -max_dtis

nrows = len(dtis)
ncols = len(channel_names)

fig = plt.figure(figsize=(12, 7), dpi=160)
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.05, wspace=0.05)

# `axs[i_row][j_col]` holds the subplot of image `i_row`, channel `j_col`.
axs = list()
max_subplot_height = 0
for i_row, (dti, row_name) in enumerate(zip(dtis, dti_names)):
    axs_cols = list()

    for j_col, channel_name in enumerate(channel_names):
        ax = fig.add_subplot(grid[i_row, j_col])
        ax.imshow(
            np.rot90(dti[j_col]),
            cmap=cmap,
            interpolation=None,
            vmin=min_dtis[j_col],
            vmax=max_dtis[j_col],
        )
        # Row labels go on the first column, channel labels on the last row.
        if j_col == 0:
            ax.set_ylabel(row_name, size="small")
        if i_row == nrows - 1:
            ax.set_xlabel(channel_name)

        # Track the top edge of the highest subplot to place the `suptitle` later.
        subplot_top = ax.get_position(original=False).get_points()[1, 1]
        max_subplot_height = max(max_subplot_height, subplot_top)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        axs_cols.append(ax)

    axs.append(axs_cols)

# Place colorbars on each column.
for j_col in range(ncols):
    full_col_ax = [row_axes[j_col] for row_axes in axs]

    color_norm = mpl.colors.Normalize(vmin=min_dtis[j_col], vmax=max_dtis[j_col])
    color_mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)

    cbar = fig.colorbar(
        color_mappable,
        ax=full_col_ax,
        location="top",
        orientation="horizontal",
        pad=0.01,
        shrink=0.85,
    )
    cbar.ax.tick_params(labelsize=8, rotation=35)
    cbar.ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter("{x:g}"))

plt.suptitle(
    "DTI Channel Breakdown, Channel-Wise Normalization",
    y=max_subplot_height + 0.01,
    verticalalignment="bottom",
)
plt.savefig(experiment_results_dir / "DTI_channel_sample_channel_wise_norm.png");

In [None]:
# Close tensorboard logger. `finalize` takes the run's final status as a
# required argument; reaching this cell means the experiment completed.
pl_logger.finalize("success")
# Experiment is complete, move the results directory to its final location.
if experiment_results_dir != final_experiment_results_dir:
    print("Moving out of tmp location")
    # `Path.rename` returns the new path, so the variable keeps pointing at the
    # results after the move.
    experiment_results_dir = experiment_results_dir.rename(final_experiment_results_dir)