# Pain in the Net - LPTN GAN on T1w Images
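Application of the Laplacian Pyramid Translation Network (LPTN) to domain adaptation of MRI. The broader project targets diffusion MRI; this notebook works on T1-weighted (T1w) images, translating HCP scans toward clinical-quality (OASIS3) scans.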


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Based on the following work(s):

* J. Liang, H. Zeng, and L. Zhang, “High-Resolution Photorealistic Image Translation in Real-Time: A Laplacian Pyramid Translation Network,” in *Proc. CVPR*, 2021, pp. 9392–9400. Accessed: Aug. 26, 2021. [Online]. Available: https://openaccess.thecvf.com/content/CVPR2021/html/Liang_High-Resolution_Photorealistic_Image_Translation_in_Real-Time_A_Laplacian_Pyramid_Translation_CVPR_2021_paper.html


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dipy.viz
import dipy.viz.regtools
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import mpl_toolkits
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown
import ipyplot

# Data management libraries.
import nibabel as nib
import natsort
from natsort import natsorted
import addict
from addict import Addict
import box
from box import Box
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

# Global matplotlib defaults: automatic tight layout and an opaque white figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact ndarray/tensor printing: no scientific notation, summarize beyond 100
# elements while showing 2 edge items, and wrap lines at 88 columns.
np.set_printoptions(linewidth=88, threshold=100, edgeitems=2, suppress=True)
torch.set_printoptions(
    profile="short", linewidth=88, threshold=100, edgeitems=2, sci_mode=False
)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Run `/usr/bin/env` inside direnv's context; it prints every environment variable
# defined in that sub-process. The command is an argument list executed without a
# shell (instead of a path interpolated into a shell string), so a working directory
# containing spaces or shell metacharacters is passed through intact.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    proc = subprocess.run(
        command, stdout=subprocess.PIPE, cwd=os.getcwd(), check=False
    )
except FileNotFoundError:
    # Best effort: without a direnv executable, nothing is loaded.
    warnings.warn("WARNING: direnv executable not found; environment not updated")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"WARNING: direnv exited with code {proc.returncode}; "
            + "environment may be incomplete"
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../").resolve())
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [None]:
# torch setup: use CUDA when a GPU is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Uncomment to force CPU execution regardless of GPU availability.
# device = torch.device('cpu')
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    except NameError:
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# Echo the output captured by the `%%capture --no-stderr cap` magic in the specs cell.
print(str(cap))

### Data Variables & Definitions Setup

In [None]:
# Set up directories. Every root location comes from an environment variable; the
# required input/output directories must already exist.
data_dir = Path(os.environ["DATA_DIR"])

processed_data_dir = Path(os.environ["WRITE_DATA_DIR"])
hcp_processed_data_dir = processed_data_dir.joinpath(
    "hcp/derivatives/mean-downsample/scale-1.00mm"
)
clinic_processed_data_dir = processed_data_dir.joinpath(
    "oasis3/derivatives/mean-downsample/scale-orig"
)
assert hcp_processed_data_dir.exists()
assert clinic_processed_data_dir.exists()

results_dir = Path(os.environ["RESULTS_DIR"])
assert results_dir.exists()
tmp_results_dir = Path(os.environ["TMP_RESULTS_DIR"])
assert tmp_results_dir.exists()

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = "debug_t1w_64_cube_patches"

ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
experiment_name = ts + "__" + EXPERIMENT_NAME
run_name = experiment_name
print(experiment_name)
# experiment_results_dir = results_dir / experiment_name

# Results go to a temporary directory first, in case the experiment does not finish.
# Collect previous runs' temporary directories, skipping plain files and hidden entries
# (e.g. ".gitkeep"). The hidden check must use the entry's *name*: testing the full
# path never matches under an absolute directory, and matches everything under a
# relative one such as "./tmp".
tmp_dirs = [
    p for p in tmp_results_dir.glob("*") if p.is_dir() and not p.name.startswith(".")
]

# Only keep up to N tmp results, counting this run's own directory
# (experiment_results_dir), so at most N - 1 existing ones survive.
n_tmp_to_keep = 3
if n_tmp_to_keep < 1:
    raise ValueError(f"n_tmp_to_keep must be >= 1, got {n_tmp_to_keep}")
if len(tmp_dirs) > (n_tmp_to_keep - 1):
    print(f"At least {n_tmp_to_keep} temporary results, culling to the most recent")
    # Run directory names presumably start with the run timestamp, so natural sort
    # order is oldest-first. Slice by count: a negative stop of `-(n_tmp_to_keep - 1)`
    # becomes `[:0]` (deleting nothing) when n_tmp_to_keep == 1.
    tmps_to_delete = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])[
        : len(tmp_dirs) - (n_tmp_to_keep - 1)
    ]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# Hand this logger to the pytorch-lightning Trainer so the training/testing loops can
# log through it directly.
pl_logger = pl.loggers.TensorBoardLogger(
    save_dir=tmp_results_dir,
    name=experiment_name,
    version="",
    log_graph=False,
    default_hp_metric=False,
)
# The underlying experiment writer handles lower-level logging (histograms, images, ...).
logger = pl_logger.experiment

# Plain-text log for event streams & info beyond parameters & results, kept in the
# writer's log directory.
log_txt_file = Path(logger.log_dir).joinpath("log.txt")
with open(log_txt_file, "a+") as log_file:
    # `cap` holds the output captured by the %%capture magic in the specs cell.
    log_file.writelines(
        [
            f"Experiment Name: {experiment_name}\n",
            f"Timestamp: {ts}\n",
            f"Environment and Hardware Info:\n {cap}\n\n",
        ]
    )

### Experiment Parameters

In [None]:
# Parameters
params = Box(default_box=True)

# Data params.
params.num_channels = 1
params.hcp.num_subjects = 13
params.clinic.num_subjects = 9
# Percentiles (0-100); divided by 100 and passed to the scaler as quantile bounds.
# NOTE(review): 0.05 vs. 95.0 is asymmetric (quantiles 0.0005 and 0.95) -- confirm
# this is intended.
params.clamp_percentiles = (0.05, 95.0)
# params.data_scale_range = None
# Range each subject's masked T1w volume is min-max scaled into; None skips scaling.
params.data_scale_range = (-1, 1)

# Network params.
params.num_laplace_high_freq = 3
params.discriminator_downscale_factors = [1, 2, 4]
params.lambda_adversary_loss = 1
params.lambda_reconst_loss_weight = 10
# NOTE(review): ClinicMatchGAN.__init__ enables the penalty whenever
# lambda_grad_penalty is not None, so this flag only matters if the model is built
# with lambda_grad_penalty=None when it is False -- confirm at construction.
params.use_grad_penalty = False
params.lambda_grad_penalty = 100
# Set the init function to None to change to pytorch default initialization.
# params.net_init.f = None
params.net_init.mean = 0.0
params.net_init.std = 0.02

# Adam optimizer kwargs for each network.
params.optim.gen_kwargs.lr = 2e-4
params.optim.gen_kwargs.betas = (0.5, 0.99)
params.optim.discriminator_kwargs.lr = 2e-4
params.optim.discriminator_kwargs.betas = (0.5, 0.99)

# Training, validation, & testing params
# Each patch dimension must be divisible by 2**num_laplace_high_freq.
params.train.patch_size = (64, 64, 64)
params.batch_size = 8
params.samples_per_subj_per_epoch = 300
params.max_epochs = 50
params.train.hcp_num_subjects = 12
params.val.hcp_num_subjects = 1
# Data augmentation parameters
# params.train.aug.input_noise_dist = None
params.train.aug.input_noise_dist = torch.distributions.normal.Normal(0, 0.05)
# params.train.aug.label_noise_dist = None
params.train.aug.label_noise_dist = torch.distributions.uniform.Uniform(-0.3, 0.3)

# Create these assert statements because having an invalid number of train/val subjects
# may not be caught in the loading below and cause a silent runtime error.
assert params.train.hcp_num_subjects <= params.hcp.num_subjects
assert params.val.hcp_num_subjects <= params.hcp.num_subjects
# Train and validation subjects should be disjoint, so together they must fit within
# the loaded HCP subjects.
assert (
    params.train.hcp_num_subjects + params.val.hcp_num_subjects
    <= params.hcp.num_subjects
), "Not enough HCP subjects for disjoint train/val splits"
# The Laplacian pyramid needs inputs divisible by 2 once per high-frequency level.
assert all(
    s % (2**params.num_laplace_high_freq) == 0 for s in params.train.patch_size
), "Patch size must be divisible by 2**num_laplace_high_freq"

with open(log_txt_file, "a+") as f:
    f.write(pprint.pformat(params) + "\n")

In [None]:
# Optional weight & bias initialization of conv layers.
@torch.no_grad()
def conv_init_normal(m, mean, std):
    """Re-initialize a 3D (transposed) convolution's parameters from a normal distribution.

    Meant for ``torch.nn.Module.apply``: any module that is not a ``Conv3d`` or
    ``ConvTranspose3d`` is left untouched.

    Parameters
    ----------
    m : torch.nn.Module
        Module to (possibly) re-initialize in place.
    mean : float
        Mean of the normal distribution.
    std : float
        Standard deviation of the normal distribution.

    The weight, and the bias when the layer has one, are both drawn from
    N(mean, std**2). Layers built with ``bias=False`` have ``m.bias is None``; their
    bias is skipped instead of crashing inside ``torch.nn.init.normal_``.
    """
    if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        torch.nn.init.normal_(m.weight, mean=mean, std=std)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias, mean=mean, std=std)


# Register the normal initializer only when both of its distribution parameters are
# configured on params.net_init.
if all(key in params.net_init for key in ("mean", "std")):
    f = functools.partial(
        conv_init_normal,
        mean=params.net_init.mean,
        std=params.net_init.std,
    )
    params.net_init.f = f

## Data Loading

In [None]:
# Transformation pipeline.
# Each spatial dim fed to the Laplacian pyramid must be divisible by 2 once per
# high-frequency level, i.e. by 2**num_laplace_high_freq.
laplace_pyramid_divisible_by_shape = 2**params.num_laplace_high_freq

# Crop both images to the foreground of "mask" (margin of 3), pad them to a
# pyramid-compatible shape, then convert each to a float tensor.
pre_process_pipeline = monai.transforms.Compose(
    transforms=[
        monai.transforms.CropForegroundd(
            keys=["t1w", "mask"], source_key="mask", margin=3
        ),
        monai.transforms.DivisiblePadd(
            keys=["t1w", "mask"], k=laplace_pyramid_divisible_by_shape
        ),
        monai.transforms.ToTensord(keys="t1w", dtype=torch.float),
        monai.transforms.ToTensord(keys="mask", dtype=torch.float),
    ]
)

### Load and Pre-Process HCP Data

In [None]:
# Find data directories for each subject.
hcp_subj_dirs: dict = {}

# Pool of HCP subject IDs to draw from.
possible_ids = [
    "sub-397154",
    "sub-224022",
    "sub-140117",
    "sub-751348",
    "sub-894774",
    "sub-156637",
    "sub-227432",
    "sub-303624",
    "sub-185947",
    "sub-810439",
    "sub-753251",
    "sub-644246",
    "sub-141422",
    "sub-135528",
    "sub-103010",
    "sub-700634",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(possible_ids, params.hcp.num_subjects)
if len(possible_ids) > params.hcp.num_subjects:
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        + f"Subj IDs selected: {selected_ids}"
    )
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# Natural sort so later per-subject loops run in a stable order.
selected_ids = natsorted(selected_ids)

# Map each ID to its processed-data directory, failing fast if one is missing.
for subj_id in selected_ids:
    hcp_subj_dirs[subj_id] = hcp_processed_data_dir.joinpath(subj_id)
    assert hcp_subj_dirs[subj_id].exists()
ppr(hcp_subj_dirs)

In [None]:
# Record the selected HCP subjects in the plain-text log, then in TensorBoard.
with open(log_txt_file, "a+") as log_file:
    log_file.write(f"Selected HCP Subjects: {selected_ids}\n")
logger.add_text("hcp_subjs", pprint.pformat(selected_ids))

In [None]:
# Data loading and processing loop.
# For each selected HCP subject: read the T1w volume and mask, run
# pre_process_pipeline, and optionally rescale the masked T1w intensities.
hcp_subj_data = []
# Data reader object for NIFTI files; reorients to the closest canonical orientation.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# Directory prefixes for each image to be read.
t1w_file_prefix = "anat"
mask_file_prefix = "mask"

for subj_id, subj_dir in hcp_subj_dirs.items():
    subj_data = {"subj_id": subj_id}

    # Read the T1w image, then the mask. Each volume and its metadata are stored under
    # monai's default "<key>_meta_dict" naming.
    for img_key, meta_key, img_dir, glob_pattern in (
        ("t1w", "t1w_meta_dict", subj_dir / t1w_file_prefix, f"{subj_id}*T1w*.nii.gz"),
        (
            "mask",
            "mask_meta_dict",
            subj_dir / mask_file_prefix,
            f"{subj_id}*mask*.nii.gz",
        ),
    ):
        img_filename = list(img_dir.glob(glob_pattern))
        # The glob pattern must match exactly one file.
        assert len(img_filename) == 1
        img_filename = img_filename[0]
        nib_img = nib_reader.read(img_filename)
        img, metadata = nib_reader.get_data(nib_img)
        # Prepend a channel axis to bare 3D volumes.
        if img.ndim == 3:
            img = img[None]
        subj_data[img_key] = img
        subj_data[meta_key] = metadata

    # Pre-process subject vols.
    subj_data = pre_process_pipeline(subj_data)

    # Optionally rescale the masked T1w volume; the scaler (applied with
    # stateful=True) is stored with the subject's data.
    if params.data_scale_range is not None:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data_scale_range[0],
            params.data_scale_range[1],
            quantile_low=params.clamp_percentiles[0] / 100,
            quantile_high=params.clamp_percentiles[1] / 100,
            dim=(1, 2, 3),
            channel_size=params.num_channels,
            clip=True,
        )
        masked_t1w = subj_data["t1w"] * subj_data["mask"]
        # Multiply by the mask again after scaling (zeroes background for a 0/1 mask).
        subj_data["t1w"] = scaler.scale(masked_t1w, stateful=True) * subj_data["mask"]
        subj_data["scaler"] = scaler

    hcp_subj_data.append(subj_data)

# Create dataset with all HCP subjects included.
hcp_subj_dataset = monai.data.Dataset(data=hcp_subj_data)

### Load & Pre-Process Clinical Data

In [None]:
# Find data directories for each subject.
clinic_subj_dirs: dict = {}

# Pool of clinical (OASIS3) subject/session IDs to draw from.
possible_ids = [
    "sub-OAS30188_MR_d3844",
    "sub-OAS30375_MR_d5792",
    "sub-OAS30558_MR_d2148",
    "sub-OAS30643_MR_d0280",
    "sub-OAS30685_MR_d0032",
    "sub-OAS30762_MR_d0043",
    "sub-OAS30770_MR_d1201",
    "sub-OAS30944_MR_d0089",
    "sub-OAS31018_MR_d0041",
    "sub-OAS31157_MR_d4924",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(possible_ids, params.clinic.num_subjects)
if len(possible_ids) > params.clinic.num_subjects:
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        + f"Subj IDs selected: {selected_ids}"
    )
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# Natural sort so later per-subject loops run in a stable order.
selected_ids = natsorted(selected_ids)

# Map each ID to its processed-data directory, failing fast if one is missing.
for subj_id in selected_ids:
    clinic_subj_dirs[subj_id] = clinic_processed_data_dir.joinpath(subj_id)
    assert clinic_subj_dirs[subj_id].exists()
ppr(clinic_subj_dirs)

In [None]:
# OASIS3 Dataset

# # Find data directories for each subject.
# clinic_subj_dirs: dict = dict()

# possible_ids = [
#     "OAS30052",
#     "OAS30058",
#     "OAS30132",
#     "OAS30176",
#     "OAS30180",
#     "OAS30203",
#     "OAS30233",
#     "OAS30257",
#     "OAS30261",
#     "OAS30719",
#     "OAS30733",
#     "OAS30748",
#     "OAS30959",
#     "OAS31084",
#     "OAS31088",
#     "OAS31111",
# ]

# ## Sub-set the chosen participants for dev and debugging!
# selected_ids = random.sample(possible_ids, params.clinic.num_subjects)
# if params.clinic.num_subjects < len(possible_ids):
#     warnings.warn(
#         "WARNING: Sub-selecting participants for dev and debugging. "
#         + f"Subj IDs selected: {selected_ids}"
#     )
# # ### A nested warning! For debugging only.
# # warnings.warn("WARNING: Mixing training and testing subjects")
# # selected_ids.append(selected_ids[0])
# # ###
# ##

# selected_ids = natsorted(selected_ids)

# for subj_id in selected_ids:
#     clinic_subj_dirs[subj_id] = (
#         clinic_processed_data_dir / f"derivatives/diffusion/sub-{subj_id}/ses-01"
#     )
#     assert clinic_subj_dirs[subj_id].exists()
# ppr(clinic_subj_dirs)

In [None]:
# Record the selected clinical subjects in the plain-text log, then in TensorBoard.
with open(log_txt_file, "a+") as log_file:
    log_file.write(f"Selected Clinically-Scanned Subjects: {selected_ids}\n")
logger.add_text("clinic_data_subjs", pprint.pformat(selected_ids))

In [None]:
# Data loading and processing loop.
# For each selected clinical subject: read the T1w volume and mask, run
# pre_process_pipeline, and optionally rescale the masked T1w intensities.
clinic_subj_data = []
# Data reader object for NIFTI files; reorients to the closest canonical orientation.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# Directory prefixes for each image to be read.
t1w_file_prefix = "anat"
mask_file_prefix = "mask"

for subj_id, subj_dir in clinic_subj_dirs.items():
    subj_data = {"subj_id": subj_id}

    # Read the T1w image, then the mask. Each volume and its metadata are stored under
    # monai's default "<key>_meta_dict" naming.
    for img_key, meta_key, img_dir, glob_pattern in (
        ("t1w", "t1w_meta_dict", subj_dir / t1w_file_prefix, f"{subj_id}*T1w*.nii.gz"),
        (
            "mask",
            "mask_meta_dict",
            subj_dir / mask_file_prefix,
            f"{subj_id}*mask*.nii.gz",
        ),
    ):
        img_filename = list(img_dir.glob(glob_pattern))
        # The glob pattern must match exactly one file.
        assert len(img_filename) == 1
        img_filename = img_filename[0]
        nib_img = nib_reader.read(img_filename)
        img, metadata = nib_reader.get_data(nib_img)
        # Prepend a channel axis to bare 3D volumes.
        if img.ndim == 3:
            img = img[None]
        subj_data[img_key] = img
        subj_data[meta_key] = metadata

    # Pre-process subject vols.
    subj_data = pre_process_pipeline(subj_data)

    # Optionally rescale the masked T1w volume; the scaler (applied with
    # stateful=True) is stored with the subject's data.
    if params.data_scale_range is not None:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data_scale_range[0],
            params.data_scale_range[1],
            quantile_low=params.clamp_percentiles[0] / 100,
            quantile_high=params.clamp_percentiles[1] / 100,
            dim=(1, 2, 3),
            channel_size=params.num_channels,
            clip=True,
        )
        masked_t1w = subj_data["t1w"] * subj_data["mask"]
        # Multiply by the mask again after scaling (zeroes background for a 0/1 mask).
        subj_data["t1w"] = scaler.scale(masked_t1w, stateful=True) * subj_data["mask"]
        subj_data["scaler"] = scaler

    clinic_subj_data.append(subj_data)

# Create dataset with all "clinical quality" subjects included.
clinic_subj_dataset = monai.data.Dataset(data=clinic_subj_data)

## Setup of Training Objects

In [None]:
# Designate HCP subjects for training and validation.
# The shuffled IDs are split into *disjoint* partitions: validation subjects are taken
# after the training ones, so no validation volume is also trained on.
hcp_ids = [s["subj_id"] for s in hcp_subj_data]
random.shuffle(hcp_ids)
if params.train.hcp_num_subjects + params.val.hcp_num_subjects > len(hcp_ids):
    raise ValueError(
        f"Requested {params.train.hcp_num_subjects} train + "
        + f"{params.val.hcp_num_subjects} val HCP subjects, but only "
        + f"{len(hcp_ids)} were loaded"
    )
hcp_train_ids = hcp_ids[: params.train.hcp_num_subjects]
hcp_val_ids = hcp_ids[
    params.train.hcp_num_subjects : (
        params.train.hcp_num_subjects + params.val.hcp_num_subjects
    )
]

# Designate clinic subject IDs for training.
clinic_ids = [s["subj_id"] for s in clinic_subj_data]
random.shuffle(clinic_ids)
# Just select all clinic IDs.
clinic_train_ids = clinic_ids[: params.clinic.num_subjects]

In [None]:
# Set up dataset and data loading objects.
# ! The samplers created here will cause the source domain patches and the target domain
# patches to *not* be aligned in any way; this is intentional for unpaired I2I.

# Set up HCP scan data.
# Training set: one mask-filtered patch dataset per training subject.
source_patch_ds = [
    pitn.data.MaskFilteredPatchDataset3d(
        subj_dict["t1w"], mask=subj_dict["mask"], patch_size=params.train.patch_size
    )
    for subj_dict in hcp_subj_data
    if subj_dict["subj_id"] in hcp_train_ids
]
source_train_dataset = torch.utils.data.ConcatDataset(source_patch_ds)
# Per-subject sampling capped at samples_per_subj_per_epoch (presumably balanced across
# subjects, per the sampler's name -- see pitn.samplers).
source_train_sampler = pitn.samplers.ConcatDatasetBalancedRandomSampler(
    source_train_dataset.datasets,
    max_samples_per_dataset=params.samples_per_subj_per_epoch,
)
source_train_loader = monai.data.DataLoader(
    source_train_dataset,
    batch_size=params.batch_size,
    sampler=source_train_sampler,
    num_workers=7,
    pin_memory=True,
    persistent_workers=True,
)

# Validation set: whole T1w volumes, each given an extra leading axis.
source_vol_ds = [
    subj_dict["t1w"][None]
    for subj_dict in hcp_subj_data
    if subj_dict["subj_id"] in hcp_val_ids
]
source_val_dataset = torch.utils.data.ConcatDataset(source_vol_ds)
source_val_loader = monai.data.DataLoader(
    source_val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    # persistent_workers requires num_workers > 0.
    #     persistent_workers=True,
)

In [None]:
# Set up clinic scan data.
# Target-domain training set: one mask-filtered patch dataset per clinic subject.
target_patch_ds = [
    pitn.data.MaskFilteredPatchDataset3d(
        subj_dict["t1w"], mask=subj_dict["mask"], patch_size=params.train.patch_size
    )
    for subj_dict in clinic_subj_data
    if subj_dict["subj_id"] in clinic_train_ids
]
target_train_dataset = torch.utils.data.ConcatDataset(target_patch_ds)

# Spread the source domain's total sample count (subjects x samples per subject) over
# the clinic subjects, rounding down, so both domains yield about as many samples per
# epoch.
num_clinic_samples_per_img = np.floor(
    params.samples_per_subj_per_epoch
    * len(source_train_dataset.datasets)
    / len(target_train_dataset.datasets)
).astype(int)

target_train_sampler = pitn.samplers.ConcatDatasetBalancedRandomSampler(
    target_train_dataset.datasets,
    max_samples_per_dataset=num_clinic_samples_per_img,
)
target_train_loader = monai.data.DataLoader(
    target_train_dataset,
    batch_size=params.batch_size,
    sampler=target_train_sampler,
    num_workers=7,
    pin_memory=True,
    persistent_workers=True,
)

## Model Definition

In [None]:
class ClinicMatchGAN(pl.LightningModule):
    def __init__(
        self,
        num_channels: int,
        gen_num_high_freq: int = 3,
        discriminator_downsample_factors=[1, 2, 4],
        lambda_adversary_loss: float = 1,
        lambda_grad_penalty: float = 1,
        lambda_reconst_loss_weight=1,
        gen_optim_kwargs=dict(),
        discriminator_optim_kwargs=dict(),
        weight_init_fn=None,
        input_noise_dist=None,
        label_noise_dist=None,
    ):
        """Build the LPTN generator, the multi-scale discriminator, and logging state.

        Every constructor argument is recorded with ``save_hyperparameters`` and read
        back through ``self.hparams`` by the other methods.

        Parameters
        ----------
        num_channels : int
            Number of image channels, passed to both the generator and discriminator.
        gen_num_high_freq : int
            Number of high-frequency levels in the generator's Laplacian pyramid.
        discriminator_downsample_factors : list
            Scale factors passed to ``MultiDiscriminator`` (presumably one
            sub-discriminator per factor -- confirm in pitn.nn.gan.adversarial).
        lambda_adversary_loss : float
            Weight of the generator's adversarial loss term.
        lambda_grad_penalty : float or None
            Weight of the discriminator gradient penalty. ``None`` sets
            ``hparams.use_grad_penalty = False``; any other value (including the
            default 1) turns the penalty on.
        lambda_reconst_loss_weight : float
            Weight of the generator's reconstruction (MSE) loss term.
        gen_optim_kwargs, discriminator_optim_kwargs : dict
            Optimizer keyword arguments; not used in ``__init__`` itself (presumably
            consumed by ``configure_optimizers``, not visible here).
            NOTE(review): the ``[1, 2, 4]`` / ``dict()`` defaults are shared mutable
            objects; they are not mutated here, so don't mutate them elsewhere.
        weight_init_fn : callable or None
            If given, applied to every sub-module of both networks via ``Module.apply``.
        input_noise_dist : torch.distributions.Distribution or None
            Additive noise distribution sampled onto generator inputs.
        label_noise_dist : torch.distributions.Distribution or None
            Noise distribution added to adversarial target labels.
        """
        super().__init__()

        # Record the __init__ arguments into self.hparams.
        self.save_hyperparameters()
        # The gradient penalty is enabled unless lambda_grad_penalty is explicitly None.
        if self.hparams.lambda_grad_penalty is None:
            self.hparams.use_grad_penalty = False
        else:
            self.hparams.use_grad_penalty = True
        self.generator = pitn.nn.gan.generative.LPTN(
            num_channels, num_high_freq_levels=self.hparams.gen_num_high_freq
        )

        self.discriminator = pitn.nn.gan.adversarial.MultiDiscriminator(
            num_channels, self.hparams.discriminator_downsample_factors
        )

        # Optional custom (re-)initialization, e.g. conv_init_normal.
        if weight_init_fn is not None:
            self.generator = self.generator.apply(weight_init_fn)
            self.discriminator = self.discriminator.apply(weight_init_fn)

        # NOTE(review): max_val=1.0, but the notebook scales inputs into
        # params.data_scale_range ((-1, 1) by default, a range of 2) -- confirm the
        # intended PSNR peak value.
        self.val_psnr_metric = monai.metrics.PSNRMetric(max_val=1.0)
        self.val_viz_slice = None
        self.val_viz_range = None

        # Plain-Python logs keyed by global step (see training_step), kept alongside
        # TensorBoard for later plotting.
        self.plain_log = Box(default_box=True, loss_gen=dict(), loss_discrim=dict())
        self.plain_log.discrim_preds.real = dict()
        self.plain_log.discrim_preds.fake = dict()
        self.plain_log.val.discrim_preds.fake = dict()

    def forward(self, x, input_noise=True):
        """Translate ``x`` with the generator, optionally adding input noise first.

        Parameters
        ----------
        x : torch.Tensor
            Batch to translate.
        input_noise : bool
            If True and ``hparams.input_noise_dist`` is set, a noise sample with the
            shape of ``x`` (moved to ``x``'s dtype/device) is added before translation.

        Returns
        -------
        The generator's output for the (possibly noisy) input.
        """
        if not input_noise or self.hparams.input_noise_dist is None:
            return self.generator(x)
        noise = self.hparams.input_noise_dist.sample(x.shape).to(x)
        return self.generator(x + noise)

    def reconstruction_loss(self, y_source, y_pred):
        """Return the mean squared error between the source batch and its translation."""
        return F.mse_loss(input=y_source, target=y_pred, reduction="mean")

    def ls_adversarial_loss(self, sample, label: int, label_noise_dist=None):
        """Least-squares adversarial loss of the discriminator's output on ``sample``.

        The target tensor has the discriminator output's shape and is filled with
        ``label``. When ``label_noise_dist`` is given, a sample of it is added to the
        target. Returns the mean squared error between output and target.
        """
        pred = self.discriminator(sample)
        target = torch.full_like(pred, label)
        if label_noise_dist is not None:
            target = target + label_noise_dist.sample(target.shape).to(target)
        return F.mse_loss(pred, target, reduction="mean")

    def _gradient_norm_penalty(self, inputs, target_norm=1):
        """Penalize deviation of the discriminator's input-gradient norm from a target.

        Parameters
        ----------
        inputs : torch.Tensor
            Batch fed to the discriminator; must already require grad. The first
            dimension is treated as the batch dimension.
        target_norm : float
            Desired per-sample L2 norm of d(discriminator output)/d(inputs).

        Returns
        -------
        torch.Tensor
            Scalar ``mean_i (||grad_i||_2 - target_norm) ** 2``. The gradient graph is
            kept (``create_graph=True``) so the penalty itself can be back-propagated.
        """
        preds = self.discriminator(inputs)
        # `only_inputs` is omitted: torch.autograd.grad ignores it (deprecated).
        grad = torch.autograd.grad(
            outputs=preds,
            inputs=inputs,
            grad_outputs=torch.ones_like(preds),
            create_graph=True,
            retain_graph=True,
        )[0]
        # Flatten everything but the batch dim. Use reshape() rather than view(): the
        # gradient need not be contiguous (e.g. for channels-last inputs), and view()
        # raises in that case.
        grad = grad.reshape(grad.shape[0], -1)
        # Calculate L2 norm manually so a small epsilon can be used to avoid NaNs.
        eps = 1e-7
        grad_norm = torch.sqrt(torch.sum(grad**2, dim=1) + eps)
        return torch.mean((grad_norm - target_norm) ** 2)

    def grad_penalty(
        self,
        real_samples: torch.Tensor,
        fake_samples: torch.Tensor,
    ):
        """WGAN-GP style gradient penalty on random real/fake interpolations.

        For each sample in the batch, draw ``w ~ U[0, 1)`` and evaluate the
        discriminator on ``w * real + (1 - w) * fake``. The penalty is the mean squared
        deviation of that input-gradient's per-sample L2 norm from 1.

        NOTE(review): ``fake_samples`` is not detached here. If it carries generator
        gradients, the penalty also back-propagates into the generator; detach at the
        call site if that is not intended.
        """
        batch_size = real_samples.shape[0]
        # One interpolation weight per sample, broadcast over all remaining dims.
        avg_weight_rand = torch.rand(batch_size, *((1,) * (real_samples.ndim - 1))).to(
            real_samples
        )
        weighted_interpolate = (avg_weight_rand * real_samples) + (
            (1 - avg_weight_rand) * fake_samples
        )
        # Need to require grad for the gradient calculation.
        return self._gradient_norm_penalty(
            weighted_interpolate.requires_grad_(True), target_norm=1
        )

    def ls_gan_grad_penalty(self, real_samples, noise_scale: float, k=1):
        """Implements another form of grad penalty from Kodali, et. al., 2017, used in
        Mao, et. al., 2018 (2nd LS-GAN paper).

        The discriminator is evaluated on ``real_samples`` perturbed by Gaussian noise
        with standard deviation ``noise_scale``. The penalty is the mean squared
        deviation of the input-gradient's per-sample L2 norm from ``k``.
        """
        # Technically, the original paper specified the noise as a multi-variate Gaussian
        # with a diagonal covariance matrix filled with the same value. So, the
        # 'c' value in that formulation would scale up the *variance*, while the
        # equivalent 1D Normal distribution here specifies the *standard deviation*.
        # It probably doesn't matter.
        noise_dist = torch.distributions.Normal(0.0, noise_scale)
        noise = noise_dist.sample(real_samples.shape).to(real_samples)
        # Need to require grad for the gradient calculation.
        noisy_samples = (real_samples + noise).requires_grad_(True)
        return self._gradient_norm_penalty(noisy_samples, target_norm=k)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator or discriminator update, selected by ``optimizer_idx``.

        Parameters
        ----------
        batch : mapping
            Must provide ``batch["source"]`` and ``batch["target"]`` tensors --
            presumably unpaired source-domain (HCP) and target-domain (clinic) patch
            batches; confirm against the dataloader feeding this module.
        batch_idx : int
            Unused.
        optimizer_idx : int
            Compared with ``self._GENERATOR_OPTIMIZER_IDX`` and
            ``self._DISCRIMINATOR_OPTIMIZER_IDX``. NOTE(review): neither is defined in
            the visible part of the class; presumably defined further down.

        Returns
        -------
        collections.OrderedDict
            ``"loss"`` (the tensor to back-propagate) plus ``"progress_bar"`` and
            ``"log"`` dicts holding its detached value.

        Raises
        ------
        RuntimeError
            If ``optimizer_idx`` matches neither optimizer index.

        Notes
        -----
        Least-squares targets: the discriminator regresses real samples to 1 and
        translated samples to -1, while the generator pushes translations toward 0
        (the a=-1, b=1, c=0 coding of LSGAN). Each adversarial term is halved, as in
        the LSGAN objective.
        """

        source_samples = batch["source"]
        target_samples = batch["target"]
        # Noise is added to the source batch once, before either branch, so the
        # generator's reconstruction target below is the *noisy* source batch.
        if self.hparams.input_noise_dist is not None:
            source_samples = source_samples + self.hparams.input_noise_dist.sample(
                source_samples.shape
            ).to(source_samples)
        # Optimizer index decides whether this step updates the generator or discriminator.
        # Update generator network.
        if optimizer_idx == self._GENERATOR_OPTIMIZER_IDX:

            translated_samples = self.generator(source_samples)

            # Identity-style term: keep translations close to their source inputs.
            l_g_reconstruct = self.reconstruction_loss(
                source_samples, translated_samples
            )
            l_g_reconstruct *= self.hparams.lambda_reconst_loss_weight
            self.log(
                "train_loss_terms/gen_reconstruct",
                l_g_reconstruct.detach(),
            )

            # The generator wants its translations scored as 0 (LSGAN target "c").
            l_g_adversarial = self.ls_adversarial_loss(
                translated_samples,
                label=0,
                label_noise_dist=self.hparams.label_noise_dist,
            )
            l_g_adversarial *= self.hparams.lambda_adversary_loss * 1 / 2
            self.log(
                "train_loss_terms/gen_adversarial",
                l_g_adversarial.detach(),
            )

            # Combine terms into final loss.
            loss_gen = l_g_reconstruct + l_g_adversarial
            self.log("train/gen_loss", loss_gen.detach())
            # Log loss and set up return dictionary.
            self.plain_log.loss_gen[self.global_step] = float(
                loss_gen.detach().cpu().item()
            )

            tqdm_dict = {"loss_gen": loss_gen.detach()}
            output = collections.OrderedDict(
                {"loss": loss_gen, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )

        ### Update discriminator network.
        elif optimizer_idx == self._DISCRIMINATOR_OPTIMIZER_IDX:

            # Real images.
            loss_real = (
                self.ls_adversarial_loss(
                    target_samples,
                    label=1,
                    label_noise_dist=self.hparams.label_noise_dist,
                )
                / 2
            )
            self.log("train_loss_terms/discrim_real_loss", loss_real.detach())

            # Translated (i.e., fake) images
            # We aren't updating the generator weights here, so there's no need to
            # keep track of the generator's gradients.
            with torch.no_grad():
                translated_samples = self.generator(source_samples)
            loss_fake = (
                self.ls_adversarial_loss(
                    translated_samples,
                    label=-1,
                    label_noise_dist=self.hparams.label_noise_dist,
                )
                / 2
            )
            self.log("train_loss_terms/discrim_fake_loss", loss_fake.detach())

            # Noise scaling found by taking `~ 0.1176 x abs diff between max and min`
            # (of values of the input tensors, here the samples from the target domain).
            if self.hparams.use_grad_penalty:
                grad_penalty = self.ls_gan_grad_penalty(
                    target_samples, noise_scale=0.2352
                )
                grad_penalty *= self.hparams.lambda_grad_penalty
                self.log(
                    "train_loss_terms/discrim_grad_penalty",
                    grad_penalty.detach(),
                )
            else:
                # Zero placeholder so the sum below has the same form either way.
                grad_penalty = torch.zeros_like(loss_fake)

            # Combine terms into final loss value.
            loss_discrim = loss_fake + loss_real + grad_penalty
            self.log("train/discrim_loss", loss_discrim.detach())

            # Record loss and set up return dictionary.
            self.plain_log.loss_discrim[self.global_step] = float(
                loss_discrim.detach().cpu().item()
            )
            tqdm_dict = {"loss_discrim": loss_discrim.detach()}
            output = collections.OrderedDict(
                {"loss": loss_discrim, "progress_bar": tqdm_dict, "log": tqdm_dict}
            )

        else:
            raise RuntimeError(f"ERROR: Invalid optimizer index {optimizer_idx}")
        # Record discriminator predictions for later plotting.
        # Runs in whichever branch executed; if both optimizer passes share a
        # global_step (depends on how Lightning advances it), the entry is computed
        # and overwritten once per optimizer.
        if self.global_step % 50 == 0:
            with torch.no_grad():
                real_preds = self.discriminator(target_samples)
                fake_preds = self.discriminator(translated_samples)
                self.plain_log.discrim_preds.real[self.global_step] = torch.clone(
                    real_preds.detach().cpu()
                )
                self.plain_log.discrim_preds.fake[self.global_step] = torch.clone(
                    fake_preds.detach().cpu()
                )
        return output

    def validation_step(self, batch, batch_idx):
        """Evaluate the generator on one batch of source-domain validation samples.

        Translates `batch` with the generator and logs, under the `val/` prefix, the
        weighted reconstruction loss, the weighted generator adversarial loss, and the
        PSNR between the (possibly noised) input and its translation. When the batch
        holds exactly one sample, it also stores the discriminator's predictions on the
        translated volume in `self.plain_log.val.discrim_preds.fake` and logs an
        animated GIF of the translation plus a source / translated / abs.-error slice
        figure to `self.logger.experiment`.

        Args:
            batch: Source-domain samples, used directly as the generator input. Judging
                by the `[0, 0]` indexing below it is a 5-D tensor, presumably
                (batch, channel, x, y, z) -- TODO confirm.
            batch_idx: Index of the batch within the validation epoch; unused.

        Returns:
            The value returned by `self.val_psnr_metric` for this batch.
        """

        source_sample = batch
        # Optionally perturb the input with noise sampled from `input_noise_dist`;
        # `.to(source_sample)` matches the noise's dtype/device to the input.
        # NOTE(review): every metric below compares against this (possibly noised)
        # `source_sample`, not against the clean `batch`.
        if self.hparams.input_noise_dist is not None:
            source_sample = source_sample + self.hparams.input_noise_dist.sample(
                source_sample.shape
            ).to(source_sample)
        source_translate = self.generator(source_sample)
        # Weighted reconstruction term between the input and its translation.
        reconstruction_loss = (
            self.hparams.lambda_reconst_loss_weight
            * self.reconstruction_loss(source_sample, source_translate)
        )
        # Weighted, halved least-squares adversarial term for the generator.
        # NOTE(review): the target label here is 0, while the discriminator is trained
        # with real=1 / fake=-1 in `training_step`; confirm this matches the label the
        # generator's adversarial loss uses during training.
        adv_loss = (
            self.hparams.lambda_adversary_loss
            * self.ls_adversarial_loss(source_translate, label=0)
            / 2
        )

        self.log("val/reconstruct_loss", reconstruction_loss.detach())
        self.log("val/adversarial_loss", adv_loss.detach())

        psnr_loss = self.val_psnr_metric(y_pred=source_translate, y=source_sample)
        self.log("val/psnr", psnr_loss.detach())

        # Only plot subject translation if batch size of the validation step is 1.
        if source_sample.shape[0] == 1:
            # Save the discriminator's prediction over the entire translated volume.
            d_pred = self.discriminator(source_translate)
            # Should only have 1 set of predictions over all spatial scales.
            d_pred = d_pred.flatten().detach().cpu()
            # Keyed by `self.global_step`: if several batch-size-1 validation batches
            # run at the same step, later ones overwrite earlier ones.
            self.plain_log.val.discrim_preds.fake[self.global_step] = d_pred

            # Animated GIF of the translated volume: clip intensities to the
            # [0.1%, 99.9%] quantiles to suppress outliers, then rescale to [0, 255].
            plot_vol = source_translate[0].cpu()
            plot_vol = torch.clip(
                plot_vol,
                *tuple(torch.quantile(plot_vol, q=torch.as_tensor([0.001, 0.999]))),
            )
            plot_vol = monai.transforms.utils.rescale_array(
                plot_vol, minv=0.0, maxv=255.0
            )
            monai.visualize.img2tensorboard.add_animated_gif(
                image_tensor=plot_vol,
                writer=self.logger.experiment,
                tag="val_subj",
                max_out=1,
                scale_factor=1.0,
                global_step=self.global_step,
            )

            # Log a slice of the source, translated, and the abs. error.
            fig = plt.figure(dpi=100)

            # The slice (a fixed index near the middle of the last axis) is chosen on
            # first use, while `self.val_viz_slice` is still None, and reused afterwards
            # so figures from different validation runs show the same slice.
            if self.val_viz_slice is None:
                self.val_viz_slice = (
                    slice(None),
                    slice(None),
                    (source_translate.shape[-1] // 2) + 2,
                )
            vol_to_plot = [
                source_sample[0, 0].cpu().numpy(),
                source_translate[0, 0].cpu().numpy(),
                torch.abs(source_sample[0, 0] - source_translate[0, 0]).cpu().numpy(),
            ]
            vol_to_plot = list(map(lambda v: v[self.val_viz_slice], vol_to_plot))

            # Likewise, the display range is fixed from the first plotted slices and
            # reused, so colors stay comparable across validation runs.
            if self.val_viz_range is None:
                vmin = np.min(np.stack(vol_to_plot))
                vmax = np.max(np.stack(vol_to_plot))
                self.val_viz_range = (vmin, vmax)
            else:
                vmin, vmax = self.val_viz_range
            cmap = "gray"
            grid = mpl_toolkits.axes_grid1.ImageGrid(
                fig,
                111,
                nrows_ncols=(1, 3),
                axes_pad=0.1,
                share_all=True,
                cbar_mode="single",
                cbar_location="bottom",
                cbar_pad=0.1,
            )

            map_names = ["Source", "Translated", "Abs. Error"]
            for ax, label, vol in zip(grid, map_names, vol_to_plot):
                im = ax.imshow(
                    np.rot90(vol), interpolation=None, cmap=cmap, vmin=vmin, vmax=vmax
                )
                ax.set_xlabel(label)
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_xticklabels([])
                ax.set_yticklabels([])

            # All panels share vmin/vmax, so the last image's mappable is a valid
            # source for the single shared colorbar.
            grid.cbar_axes[0].colorbar(im)

            # NOTE(review): relies on the logger's `add_figure` closing `fig` (TensorBoard's
            # `SummaryWriter.add_figure` does so by default); otherwise figures accumulate.
            self.logger.experiment.add_figure("val_slice", fig, self.global_step)

        return psnr_loss

    def configure_optimizers(self):
        """Build one Adam optimizer for each of the two sub-networks.

        The optimizer keyword arguments come from `self.hparams.gen_optim_kwargs` and
        `self.hparams.discriminator_optim_kwargs`. The position of each optimizer in the
        returned list is recorded on the instance (`_GENERATOR_OPTIMIZER_IDX`,
        `_DISCRIMINATOR_OPTIMIZER_IDX`) so `training_step` can dispatch on
        `optimizer_idx`.

        Returns:
            `(optimizers, schedulers)`: the generator and discriminator optimizers, in
            that order, and an empty list of LR schedulers.
        """
        generator_params = self.generator.parameters()
        discriminator_params = self.discriminator.parameters()
        optimizers = [
            torch.optim.Adam(generator_params, **self.hparams.gen_optim_kwargs),
            torch.optim.Adam(
                discriminator_params, **self.hparams.discriminator_optim_kwargs
            ),
        ]

        # These indices must agree with the order of `optimizers` above.
        self._GENERATOR_OPTIMIZER_IDX = 0
        self._DISCRIMINATOR_OPTIMIZER_IDX = 1
        schedulers = []
        return optimizers, schedulers

## Model Training

In [None]:
# Training loop: build the GAN and the Lightning trainer, then fit.
train_start_timestamp = datetime.datetime.now().replace(microsecond=0)

# The grad-penalty weight is only forwarded when the penalty is enabled; otherwise the
# model explicitly receives `None`.
lambda_grad_penalty = None
if params.use_grad_penalty:
    lambda_grad_penalty = params.lambda_grad_penalty

gan_kwargs = dict(
    gen_num_high_freq=params.num_laplace_high_freq,
    discriminator_downsample_factors=params.discriminator_downscale_factors,
    lambda_adversary_loss=params.lambda_adversary_loss,
    lambda_grad_penalty=lambda_grad_penalty,
    lambda_reconst_loss_weight=params.lambda_reconst_loss_weight,
    gen_optim_kwargs=params.optim.gen_kwargs,
    discriminator_optim_kwargs=params.optim.discriminator_kwargs,
    weight_init_fn=params.net_init.f,
    input_noise_dist=params.train.aug.input_noise_dist,
    label_noise_dist=params.train.aug.label_noise_dist,
)
model = ClinicMatchGAN(params.num_channels, **gan_kwargs)

with open(log_txt_file, "a+") as f:
    f.write(f"Model overview: {model}\n")

# Log at most every 50 steps, but never less often than once per epoch of the shorter
# training loader.
trainer_kwargs = dict(
    gpus=1,
    max_epochs=params.max_epochs,
    logger=pl_logger,
    multiple_trainloader_mode="max_size_cycle",
    log_every_n_steps=min(50, len(source_train_loader), len(target_train_loader)),
    check_val_every_n_epoch=3,
    terminate_on_nan=True,
)
trainer = pl.Trainer(**trainer_kwargs)

# Both domains are fed together as a dict of loaders; validation only uses the source
# domain. Fitting emits many warnings, which are deliberately silenced here.
train_loaders = {
    "source": source_train_loader,
    "target": target_train_loader,
}
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    trainer.fit(
        model,
        train_dataloaders=train_loaders,
        val_dataloaders=source_val_loader,
    )

# Report the wall-clock training time, both to stdout and to the experiment text log.
train_duration = datetime.datetime.now().replace(microsecond=0) - train_start_timestamp
print(f"Train duration: {train_duration}")
# `timedelta.seconds` holds only the sub-day remainder; whole days live in `.days`.
train_hours, train_remainder = divmod(train_duration.seconds, 3600)
train_minutes, train_seconds = divmod(train_remainder, 60)
with open(log_txt_file, "a+") as f:
    f.write("\n")
    f.write(f"Training time: {train_duration}\n")
    # Previously this line lacked separating spaces and ended with a stray `"`.
    f.write(
        f"\t{train_duration.days} Days, "
        f"{train_hours} Hours, "
        f"{train_minutes} Minutes, "
        f"{train_seconds} Seconds\n"
    )

In [None]:
# Persist the trained model as a Lightning checkpoint in the experiment results dir.
checkpoint_path = experiment_results_dir / "model.ckpt"
trainer.save_checkpoint(str(checkpoint_path))

## Result Visualization

In [None]:
# Master switch for writing result artifacts (figures, NIfTI predictions) to disk.
enable_fig_save: bool = True

### Metrics During Training

In [None]:
# Discriminator predictions logged during training, for real samples (target 1) and
# translated/fake samples (target -1), averaged over each logged batch.
plt.figure(dpi=120)


def _batch_mean_pred_history(pred_log):
    """Return `(steps, mean_preds)` from a `{global_step: predictions}` log.

    Each logged prediction tensor is averaged over its first (batch) axis *before* the
    steps are stacked. This gives the same values as stacking first and averaging over
    axis 1, but it also works when logged batches have different sizes (e.g. a short
    final batch), which would make a direct `torch.stack` fail.
    """
    steps = np.asarray(list(pred_log.keys()))
    mean_preds = torch.stack(
        [torch.mean(p.cpu(), dim=0) for p in pred_log.values()], dim=0
    )
    return steps, mean_preds.numpy()


real_discrim_steps, real_discrim_preds = _batch_mean_pred_history(
    model.plain_log.discrim_preds.real
)
fake_discrim_steps, fake_discrim_preds = _batch_mean_pred_history(
    model.plain_log.discrim_preds.fake
)
last_logged_step = max(fake_discrim_steps.max(), real_discrim_steps.max())

plt.plot(
    real_discrim_steps,
    real_discrim_preds,
    ls="-",
    label=[f"Real = 1, Scale {f}" for f in params.discriminator_downscale_factors],
    lw=0.9,
    alpha=0.85,
)
plt.plot(
    fake_discrim_steps,
    fake_discrim_preds,
    label=[f"Fakes = -1, Scale {f}" for f in params.discriminator_downscale_factors],
    lw=0.9,
    alpha=0.85,
)
plt.title("Prediction of Discriminator during Training")
plt.xlabel("Training Step")
plt.ylabel("Mean (Over Batch) Discriminator Prediction")
plt.legend(ncol=2)
# Reference lines at the two target labels.
plt.hlines(
    [-1, 1],
    -1,
    last_logged_step,
    color="black",
    ls="--",
    lw=0.7,
    alpha=0.8,
    zorder=999,
)
plt.xlim(-1, last_logged_step)
plt.ylim(-2.5, 2.5)
if enable_fig_save:
    plt.savefig(experiment_results_dir / "discrim_pred_in_training.png")

In [None]:
# Discriminator predictions during validation, i.e., what does the discriminator
# predict given entire (translated) volumes?

plt.figure(dpi=120)

# Only fake images are available: validation only translates source-domain volumes.
val_fake_log = model.plain_log.val.discrim_preds.fake
fake_discrim_steps = np.asarray(list(val_fake_log.keys()))
fake_discrim_preds = torch.stack(list(val_fake_log.values()), dim=0).cpu().numpy()
scale_labels = [f"Scale {f}" for f in params.discriminator_downscale_factors]

plt.plot(fake_discrim_steps, fake_discrim_preds, label=scale_labels)
plt.title(
    "Prediction of Discriminator in Validation Set During Training\nAll Fake(=-1)"
)
plt.xlabel("Training Step")
plt.ylabel("Discriminator Prediction")
last_val_step = fake_discrim_steps.max()
# Reference lines at the two target labels.
plt.hlines([-1, 1], -1, last_val_step, color="black", lw=0.75, ls="--", zorder=999)
plt.xlim(-1, last_val_step)
plt.legend()

if enable_fig_save:
    plt.savefig(experiment_results_dir / "discrim_pred_in_validation.png")

### Final Inference Images

In [None]:
# Set up visualization objects.
# Every subject is brought to one common spatial shape: the per-axis maximum of the
# last three (spatial) dimensions over all HCP and clinic subjects.
spatial_shapes = [
    tuple(subj["t1w"].shape[-3:])
    for subj in itertools.chain(hcp_subj_dataset, clinic_subj_dataset)
]
target_spatial_shape = tuple(np.asarray(spatial_shapes).max(axis=0))
target_size = torch.Size(target_spatial_shape)
# Symmetric, edge-replicating padding up to the target shape, then a center crop to it.
padder = monai.transforms.SpatialPad(target_size, method="symmetric", mode="replicate")
cropper = monai.transforms.CenterSpatialCrop(target_size)

In [None]:
# Generate viz objects/volumes.
viz_data = Box(default_box=True)

with torch.no_grad():

    # Grab HCP (source-domain) data for viz: the input volume, the generator's
    # translation of it, and the discriminator's prediction on that translation.
    for subj in hcp_subj_dataset:
        data = Box(default_box=True)

        # Add a leading batch axis for the networks.
        vol = subj["t1w"][None].to(model.device)
        # ! Make sure to perform model inference *before* descaling volume.
        translated = model.forward(vol, input_noise=True)
        pred_class = model.discriminator(translated)

        # Pad and crop vols to be the same shape.
        # There is a bug in monai that makes the cropping output as an ndarray, so every
        # crop result is normalized with `np.asarray` before further use.
        vol = torch.from_numpy(np.asarray(cropper(padder(vol[0].cpu()))))
        data.vol = subj["scaler"].descale(vol).numpy()

        data.mask = np.asarray(cropper(padder(subj["mask"].float()))).astype(bool)

        translated = torch.from_numpy(np.asarray(cropper(padder(translated[0].cpu()))))
        data.translated = subj["scaler"].descale(translated.cpu()).numpy()
        data.pred_class = pred_class[0].cpu().numpy()
        data.affine = subj["t1w_meta_dict"]["affine"]

        viz_data.hcp[str(subj["subj_id"])] = data

    # Grab clinic (target-domain) data for viz: the real volume and the discriminator's
    # prediction on it. These subjects are not translated.
    for subj in clinic_subj_dataset:
        # Start a fresh record for every subject. Reusing `data` from the loop above
        # would mutate the last HCP subject's record, carry over its stale `translated`
        # volume, and could make clinic entries alias one another.
        data = Box(default_box=True)

        vol = subj["t1w"][None].to(model.device)
        # ! Make sure to predict real/fake *before* descaling volume.
        pred_class = model.discriminator(vol)

        # Pad and crop vols to be the same shape.
        # There is a bug in monai that makes the cropping output as an ndarray
        vol = torch.from_numpy(np.asarray(cropper(padder(vol[0].cpu()))))
        data.vol = subj["scaler"].descale(vol).numpy()

        data.mask = np.asarray(cropper(padder(subj["mask"].float()))).astype(bool)
        data.pred_class = pred_class[0].cpu().numpy()
        data.affine = subj["t1w_meta_dict"]["affine"]

        viz_data.clinic[str(subj["subj_id"])] = data

In [None]:
# Save out all network predictions to Nifti2 files and compress them into a zip archive.
if enable_fig_save:
    img_names = list()
    for subj_id, viz in viz_data.hcp.items():
        # `viz.translated` is channel-first, (C, X, Y, Z): it comes from `[0]`-indexing
        # the batched generator output. NIfTI treats the first three axes as spatial,
        # so move the channel axis last and drop it when there is a single channel;
        # otherwise the channel axis would be read as the x axis of the affine.
        pred_vol = np.asarray(viz.translated)
        if pred_vol.ndim == 4:
            pred_vol = np.moveaxis(pred_vol, 0, -1)
            if pred_vol.shape[-1] == 1:
                pred_vol = pred_vol[..., 0]
        # NOTE(review): the volume was padded/cropped to `target_spatial_shape`, which can
        # shift the voxel grid relative to the subject's original affine; confirm the
        # affine still lines up (or offset it by the padding).
        affine = np.asarray(viz.affine)
        nib_img = nib.Nifti2Image(pred_vol, affine)

        filename = experiment_results_dir / f"{subj_id}_translated_t1w.nii.gz"
        nib.save(nib_img, str(filename))
        img_names.append(filename)

    # Bundle all predictions into one archive and delete the loose files.
    with zipfile.ZipFile(experiment_results_dir / "translated_t1w.zip", "w") as fzip:
        for filename in img_names:
            fzip.write(
                filename,
                arcname=filename.name,
                compress_type=zipfile.ZIP_DEFLATED,
                compresslevel=6,
            )
            os.remove(filename)
    # Make sure we exit the 'with' statement above.
    print("Done with files")

In [None]:
# Append the discriminator's whole-volume predictions to the text log: translated HCP
# volumes (target label -1) first, then real clinic volumes (target label 1).
pred_sections = (
    ("Fake: -1\n", viz_data.hcp),
    ("\n\nReal: 1\n", viz_data.clinic),
)
with open(log_txt_file, "a+") as f:
    for header, subject_group in pred_sections:
        f.write(header)
        f.write(str([entry.pred_class for entry in subject_group.values()]))
    f.write("\n")

In [None]:
# Whole-volume discriminator predictions, one row per subject.
# "Real" is the target (clinic) domain, whose label is 1. "Fake" is the translated HCP
# (source) domain, whose label is -1. These were previously assigned the other way
# round, which made the matrix plot below title each domain with the other's label.
viz_real_d_preds = np.stack([v.pred_class for v in viz_data.clinic.values()])
viz_fake_d_preds = np.stack([v.pred_class for v in viz_data.hcp.values()])

print(f"Fake = -1, scales {params.discriminator_downscale_factors}")
print(viz_fake_d_preds)
print(f"Real = 1, scales {params.discriminator_downscale_factors}")
print(viz_real_d_preds)

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, sharey=True, dpi=100)
cmap = "Spectral"

# One color scale, symmetric about 0, shared by both prediction matrices and the
# colorbar. Its limit is the largest prediction magnitude, so neither large positive
# nor large negative values are clipped.
vmax = max(viz_real_d_preds.max(), viz_fake_d_preds.max())
vmin = min(viz_real_d_preds.min(), viz_fake_d_preds.min())
vmax = max(np.abs(vmax), np.abs(vmin))
vmin = -1 * vmax
color_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

# Both panels use the shared norm so their colors match the colorbar; without it each
# `matshow` would autoscale to its own data range.
axs[0, 0].matshow(viz_real_d_preds.T, cmap=cmap, norm=color_norm)
axs[0, 0].set_title("Pred. on Real(=1)")
axs[0, 0].set_ylabel("Downscale Factor")
axs[0, 0].set_xticks([])

axs[1, 0].matshow(viz_fake_d_preds.T, cmap=cmap, norm=color_norm)
axs[1, 0].set_title("Pred. on Fake(=-1)")
axs[1, 0].set_xticks([])

# Loop over data dimensions and create text annotations.
for ax, preds in zip([axs[0, 0], axs[1, 0]], [viz_real_d_preds.T, viz_fake_d_preds.T]):
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            ax.text(
                j,
                i,
                f"{preds[i, j]:.2f}",
                ha="center",
                va="center",
                color="black",
                size="xx-small",
            )

plt.yticks(
    list(range(len(params.discriminator_downscale_factors))),
    params.discriminator_downscale_factors,
)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap),
    ax=axs[1, :],
    location="top",
)
axs[1, 0].axis("off")
axs[1, 1].axis("off")
# if enable_fig_save:
#     plt.savefig(experiment_results_dir / "whole_vol_discrim_preds.png")
# if enable_fig_save:
#     plt.savefig(experiment_results_dir / "whole_vol_discrim_preds.png")

In [None]:
# Pick one subject from each domain to visualize.
hcp_viz_subj_idx = 1
hcp_viz_subj_id = list(viz_data.hcp)[hcp_viz_subj_idx]
hcp_viz_subj = viz_data.hcp[hcp_viz_subj_id]
clinic_viz_subj_idx = 0
clinic_viz_subj_id = list(viz_data.clinic)[clinic_viz_subj_idx]
clinic_viz_subj = viz_data.clinic[clinic_viz_subj_id]

# One slice along each spatial axis, near the volume center and offset by a few voxels.
# The leading 0 selects the first entry of the volume's leading (channel) axis.
_viz_center = [dim // 2 for dim in hcp_viz_subj.vol.shape]
viz_slices = [
    # Horizontal slice
    (0, slice(None), slice(None), _viz_center[3] - 2),
    # Coronal slice
    (0, slice(None), _viz_center[2] + 3, slice(None)),
    # Sagittal slice
    (0, _viz_center[1] + 4, slice(None), slice(None)),
]


def abs_error_map(y, y_pred):
    """Return the element-wise absolute error ``|y - y_pred|`` as a NumPy array.

    Both inputs go through ``torch.as_tensor``, so tensors, arrays, and nested
    sequences are accepted; the result is moved to the CPU before conversion.
    """
    difference = torch.as_tensor(y) - torch.as_tensor(y_pred)
    return difference.abs().cpu().numpy()

### T1 Comparisons

In [None]:
# Display vols for the following:
# Source domain
# Translated
# Target domain
# Absolute error between source and translated (inside the source subject's mask)

cmap = "gray"

col_names = ["Horizontal", "Coronal", "Saggital"]

row_names = [
    "Source HCP",
    "Translated",
    "Target Clinic",
    "Abs Error\nSource-Translated",
]

vol_rows = [
    hcp_viz_subj.vol,
    hcp_viz_subj.translated,
    clinic_viz_subj.vol,
    abs_error_map(hcp_viz_subj.vol, hcp_viz_subj.translated) * hcp_viz_subj.mask,
]

# vols_to_plot[i][j] is view `viz_slices[j]` of row volume `vol_rows[i]`.
vols_to_plot = [[vol[view] for view in viz_slices] for vol in vol_rows]

nrows = len(row_names)
ncols = len(col_names)

# Shared gray-scale limits for every panel, taken as quantiles of all displayed pixels.
# Slices along different axes have different shapes, so they are flattened and
# concatenated first. With the levels at 0.0 / 1.0 these are exactly the global min and
# max; if extreme outliers wash out the contrast, use e.g. 0.001 / 0.999 to clip them.
intensity_q_low, intensity_q_high = 0.0, 1.0
all_pixels = np.concatenate(
    [np.asarray(a).flatten() for a in itertools.chain.from_iterable(vols_to_plot)]
)
min_vol, max_vol = np.quantile(all_pixels, [intensity_q_low, intensity_q_high])


fig = plt.figure(figsize=(4.5, 6), dpi=180)

grid = mpl.gridspec.GridSpec(
    nrows,
    ncols,
    figure=fig,
    hspace=0.05,
    wspace=0.05,
)
axs = list()
max_subplot_height = 0
for i_row in range(nrows):
    vol = vols_to_plot[i_row]

    for j_col in range(ncols):
        ax = fig.add_subplot(grid[i_row, j_col])
        ax.imshow(
            np.rot90(vol[j_col]),
            cmap=cmap,
            interpolation=None,
            vmin=min_vol,
            vmax=max_vol,
        )
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.set_ylabel(row_names[i_row], size="xx-small")
        if spec.is_last_row():
            ax.set_xlabel(col_names[j_col])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        # Track the top edge of the highest subplot to place the `suptitle` later on.
        max_subplot_height = max(
            max_subplot_height, ax.get_position(original=False).get_points()[1, 1]
        )
        axs.append(ax)

color_norm = mpl.colors.Normalize(vmin=min_vol, vmax=max_vol)
fig.colorbar(
    mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap),
    ax=axs,
    location="right",
    fraction=0.1,
    pad=0.03,
)
plt.suptitle(
    "Vol and Abs. Error, Normalized over All Images",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
if enable_fig_save:
    plt.savefig(experiment_results_dir / "Vol_w_Abs_Err.png")

In [None]:
# No normalization: every panel gets its own automatic gray-scale range.
# Rows: source domain, translated, target domain, and |source - translated| inside the
# source subject's mask. Columns: one view per entry of `viz_slices`.
cmap = "gray"

col_names = ["Horizontal", "Coronal", "Saggital"]
row_names = [
    "Source HCP",
    "Translated",
    "Target Clinic",
    "Abs Error\nSource-Translated",
]
vol_rows = [
    hcp_viz_subj.vol,
    hcp_viz_subj.translated,
    clinic_viz_subj.vol,
    abs_error_map(hcp_viz_subj.vol, hcp_viz_subj.translated) * hcp_viz_subj.mask,
]
vols_to_plot = [[vol[view] for view in viz_slices] for vol in vol_rows]

nrows, ncols = len(row_names), len(col_names)

fig = plt.figure(figsize=(4.5, 6), dpi=180)
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, hspace=0.05, wspace=0.05)

axs = list()
max_subplot_height = 0
for i_row, (row_label, row_views) in enumerate(zip(row_names, vols_to_plot)):
    for j_col, (col_label, view_img) in enumerate(zip(col_names, row_views)):
        ax = fig.add_subplot(grid[i_row, j_col])
        ax.imshow(np.rot90(view_img), cmap=cmap, interpolation=None)

        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.set_ylabel(row_label, size="xx-small")
        if spec.is_last_row():
            ax.set_xlabel(col_label)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        # Remember the highest subplot top edge so the title sits just above the grid.
        subplot_top = ax.get_position(original=False).get_points()[1, 1]
        max_subplot_height = max(max_subplot_height, subplot_top)
        axs.append(ax)

plt.suptitle(
    "Vol and Abs. Error, No normalization over All Images",
    y=max_subplot_height + 0.015,
    verticalalignment="bottom",
)
if enable_fig_save:
    plt.savefig(experiment_results_dir / "No_Norm_Vol_slices.png")

---

## End Experiment

In [None]:
# Flush pending TensorBoard events before the logger is (possibly) finalized.
pl_logger.experiment.flush()

# Debugging runs are neither finalized nor moved out of their temporary location.
is_debug_run = "debug" in EXPERIMENT_NAME.casefold()
if not is_debug_run:
    pl_logger.finalize("success")
    # Experiment is complete, move the results directory to its final location.
    needs_move = experiment_results_dir != final_experiment_results_dir
    if needs_move:
        print("Moving out of tmp location")
        experiment_results_dir = experiment_results_dir.rename(
            final_experiment_results_dir
        )
        # Re-point the text log at its new home inside the relocated directory.
        log_txt_file = experiment_results_dir / log_txt_file.name

---