# PITN Model Results Analysis

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile
import ast

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.patheffects
import matplotlib.pyplot as plt
import seaborn as sns

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
from pprint import pprint as ppr
from box import Box

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai
import einops
import torchinfo

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Matplotlib defaults: automatic layout and an opaque white figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Shared print options for ndarrays/tensors (no scientific notation, truncated output).
_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, **_print_opts)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and that direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Command run in direnv's context; it prints every environment variable defined in
# the sub-process. Arguments are passed as a list (no shell), so a working directory
# containing spaces or shell metacharacters is neither split nor interpreted.
cwd = os.getcwd()
command = ["direnv", "exec", cwd, "/usr/bin/env"]
try:
    proc = subprocess.Popen(command, stdout=subprocess.PIPE, cwd=cwd)
except FileNotFoundError:
    # Best-effort: without direnv, load nothing but make the omission visible.
    warnings.warn("direnv executable not found; notebook environment not updated.")
    proc_out = ""
else:
    # Store and format the subprocess' output.
    proc_out = proc.communicate()[0].strip().decode("utf-8")
    if proc.returncode != 0:
        warnings.warn(
            f"{' '.join(command)!r} exited with code {proc.returncode}; "
            "environment variables may be incomplete."
        )
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True)


In [None]:
# torch setup
# CUDA is wanted for eigendecomposing these large volumes; fall back to the CPU when
# no GPU is available. To force the CPU, set: device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda" and torch.backends.cudnn.enabled:
    # Let cuDNN benchmark and select the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
    print("CuDNN convolution optimization enabled.")
print(device)

In [None]:
# Set up directories.


def _env_dir(var_name: str) -> pathlib.Path:
    """Return the existing directory named by environment variable `var_name`.

    Raises:
        KeyError: if `var_name` is not set in the environment.
        FileNotFoundError: if the path it names does not exist. Unlike a bare
            `assert`, this check survives `python -O` and names the culprit.
    """
    env_path = pathlib.Path(os.environ[var_name])
    if not env_path.exists():
        raise FileNotFoundError(f"{var_name}={env_path} does not exist")
    return env_path


data_dir = _env_dir("DATA_DIR")
write_data_dir = _env_dir("WRITE_DATA_DIR")
results_dir = _env_dir("RESULTS_DIR")
tmp_results_dir = _env_dir("TMP_RESULTS_DIR")

In [None]:
# Parameter setup.

params = Box(default_box=True)

params.n_channels = 6  # One channel per unique DTI tensor component.
params.n_subjs = 48
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25

# Data params: per-resolution data roots and filename glob patterns.
params.data.fr_dir = data_dir / f"scale-{params.fr_vox_size:.2f}mm"
params.data.lr_dir = data_dir / f"scale-{params.lr_vox_size:.2f}mm"
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"

# Factor by which the data were artificially downsampled (LR voxel size / FR voxel
# size). Stored as an int when the ratio is a whole number, otherwise as a float.
_ds_factor = params.lr_vox_size / params.fr_vox_size
params.data.downsampled_by_factor = (
    int(_ds_factor) if _ds_factor.is_integer() else _ds_factor
)

# Upper bound applied to DTI eigenvalues when clipping outliers at load time.
params.data.eigval_clip_cutoff = 0.00332008

# LR input patch size; also needed for downsampling shape correction.
params.train.in_patch_size = (24, 24, 24)

## Load & Preprocess Ground Truth DTIs

### Subject Selection

In [None]:
# Map each selected subject ID to its full-resolution ("fr") and low-resolution ("lr")
# data directories.
subj_dirs: Box = Box()

# Hold-out subject; for visualization, ensure never in the train or val sets.
HOLDOUT_SUBJ_ID = "196952"
selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
    "406432",
    "803240",
    "815247",
    "167238",
    "100408",
    "792867",
    "157437",
    "164030",
    "103515",
    "118730",
    "198047",
    "189450",
    "203923",
    "108828",
    "124220",
    "386250",
    "118124",
    "701535",
    "679770",
    "382242",
    "231928",
    HOLDOUT_SUBJ_ID,
    "567961",
    "910241",
    "175035",
    "567759",
    "978578",
    "150019",
    "690152",
    "297655",
    "307127",
    "634748",
]
selected_subjs = natsorted(selected_ids)

for subj_id in selected_subjs:
    # Each subject must match exactly one directory under each resolution root.
    fr_path = pitn.utils.system.get_file_glob_unique(
        params.data.fr_dir, f"*{subj_id}*"
    )
    lr_path = pitn.utils.system.get_file_glob_unique(
        params.data.lr_dir, f"*{subj_id}*"
    )
    subj_dirs[subj_id] = Box(fr=fr_path, lr=lr_path)
    assert subj_dirs[subj_id].fr.exists()
    assert subj_dirs[subj_id].lr.exists()
ppr(subj_dirs)

### Loading & Preprocessing

In [None]:
# This dict will contain all ground truth data.
# NOTE: this line is only a type annotation and binds no value; `subj_gt` is
# assigned a `Box(default_box=True)` in the data-loading cell.
subj_gt: Box

In [None]:
# Prep for Dataset loading.

# Reader for NIFTI files; volumes are reoriented to the closest canonical orientation.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# HR -> LR patch coordinate conversion. The same function object is used for every
# LR volume key.
_fr2lr_patch_idx = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=params.train.in_patch_size,
)
fr2lr_patch_coords_fn = {
    lr_key: _fr2lr_patch_idx for lr_key in ("lr_dti", "lr_log_euclid", "lr_mask")
}

# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes.
# The HR patch shape is the LR input patch shape scaled by the downsampling factor.
_fr_patch_shape = np.floor(
    np.asarray(params.train.in_patch_size) * params.data.downsampled_by_factor
).astype(int)
patch_kwargs = {
    "patch_shape": tuple(_fr_patch_shape),
    "stride": 1,
    "meta_keys_to_patch_index": {"dti", "log_euclid", "mask"},
    "mask_name": "mask",
}


# Coefficients to the log-euclidean lower triangle/6D vector that properly scale
# the Euclidean distance under the log-euclidean metric. Entries 1, 3, and 4 get
# sqrt(2) (presumably the off-diagonal elements under the row-major lower-triangle
# ordering, which appear twice in the full symmetric matrix; confirm against pitn.eig).
_off_diag_idx = [1, 3, 4]
mat_norm_coeffs = torch.ones(6)
mat_norm_coeffs[_off_diag_idx] = np.sqrt(2)
# Shape (6, 1, 1, 1) so it broadcasts over a channel-first 3D volume.
mat_norm_coeffs = mat_norm_coeffs.view(6, 1, 1, 1)


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Crop/pad a channel-first volume so its spatial dims match a target shape.

    Small utility to fix shape differences between LR and FR data.

    Args:
        fr_vol: Channel-first volume; dims 1-3 are the spatial dims.
        fr_affine: Affine for `fr_vol`, handed to torchio together with the data.
        target_spatial_shape: Desired sizes of dims 1-3.

    Returns:
        (volume, affine). If the spatial shape already matches, the inputs are
        returned unchanged (the same objects). Otherwise the result is torchio's
        `CropOrPad` output (constant padding value 0) and the affine torchio updated.
    """
    spatial_dims = [1, 2, 3]
    if fr_vol.shape[1:] == tuple(np.asarray(target_spatial_shape)):
        return fr_vol, fr_affine

    # torchio objects are used because they fix the affine matrix, too.
    # The volume is flipped before the transform and un-flipped afterwards, so padding
    # goes to the right/top/furthest side of each dimension before the
    # left/bottom/closest side.
    # NOTE(review): torchio updates the affine as if the *flipped* array were the
    # real one. Confirm the returned affine is right for the un-flipped result,
    # especially when a dimension needs an odd amount of padding/cropping.
    flipped_img = torchio.ScalarImage(tensor=fr_vol.flip(spatial_dims), affine=fr_affine)
    resized_img = torchio.transforms.CropOrPad(target_spatial_shape, 0, copy=False)(
        flipped_img
    )
    return resized_img["data"].flip(spatial_dims), resized_img["affine"]


def orient_to_viz(vol, affine):
    """Rotate a channel-first volume into the visualization orientation.

    The array is rotated 90 degrees in the plane of axes (1, 3), then 180 degrees
    in the plane of axes (2, 3), so `vol` must have at least 4 dims (dim 0 is left
    untouched). A torch tensor input comes back as a tensor with the input's dtype
    and device. A NumPy input comes back as the (possibly non-contiguous) rot90 result.

    The affine is left-multiplied by a homogeneous rotation built from the quaternion
    product of a 90-degree rotation about [0, 1, 0] and a 180-degree rotation about
    [1, 0, 0]. Assumes `affine` is a 4x4 NumPy array (TODO confirm for all callers).
    NOTE(review): this rotates world coordinates rather than compensating the
    voxel-index permutation; confirm that is the intended affine semantics.
    """
    input_is_tensor = torch.is_tensor(vol)
    arr = vol.detach().cpu().numpy() if input_is_tensor else vol
    arr = np.rot90(arr, k=1, axes=(1, 3))
    arr = np.rot90(arr, k=2, axes=(2, 3))
    if input_is_tensor:
        # rot90 returns a negatively-strided view; copy before handing to torch.
        arr = torch.from_numpy(np.copy(arr)).to(vol)

    rot_quat = nib.quaternions.mult(
        nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0]),
        nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0]),
    )
    homog_rot = np.zeros_like(affine)
    homog_rot[:-1, :-1] = nib.quaternions.quat2mat(rot_quat)
    homog_rot[-1, -1] = 1.0

    return arr, homog_rot @ affine

In [None]:
subj_gt = Box(default_box=True)

meta_keys_to_keep = {"affine", "original_affine"}


def _load_nifti(nifti_f):
    """Read a NIFTI file into a torch tensor plus a trimmed metadata dict.

    Uses the module-level `nib_reader`. Only the keys in `meta_keys_to_keep` are
    retained from the reader's metadata; a fresh dict is returned on every call.
    """
    im = nib_reader.read(nifti_f)
    vol, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    return torch.from_numpy(vol), meta


def _clip_dti(dti: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Clip DTI eigenvalues at `params.data.eigval_clip_cutoff`, then zero outside `mask`.

    Clipping runs on `device`; the result is moved back to `dti`'s dtype/device.
    """
    clipped = pitn.data.outliers.clip_dti_eigvals(
        dti.to(device),
        tensor_components_dim=0,
        eigval_max=params.data.eigval_clip_cutoff,
    ).to(dti)
    return clipped * mask


def _dti_to_log_euclid(dti: torch.Tensor) -> torch.Tensor:
    """Map a lower-triangle DTI volume (components on dim 0) to scaled log-Euclidean space."""
    le = pitn.eig.tril_vec2sym_mat(dti, tril_dim=0)
    le = pitn.riemann.log_euclid.log_map(le)
    le = pitn.eig.sym_mat2tril_vec(le, tril_dim=0)
    # Reweight components so Euclidean distance matches the log-Euclidean metric.
    return le * mat_norm_coeffs


with torch.no_grad():
    for subj_id, subj_dir in subj_dirs.items():

        data = dict()
        data["subj_id"] = subj_id
        fr_subj_dir = subj_dir["fr"]
        lr_subj_dir = subj_dir["lr"]
        data["fr_subj_dir"] = fr_subj_dir
        data["lr_subj_dir"] = lr_subj_dir

        ####### Low-resolution DTIs/volumes
        lr_dti, meta = _load_nifti(
            pitn.utils.system.get_file_glob_unique(
                lr_subj_dir, params.data.dti_fname_pattern
            )
        )
        # May need to handle shape errors when re-upscaling back from LR to HR.
        # The FR volumes are shape-corrected in their *native* orientation (before
        # orient_to_viz), so the target must come from the LR shape before it is
        # re-oriented. orient_to_viz rotates in the (1, 3) axis plane, which swaps
        # those two spatial sizes. Using the re-oriented LR shape would crop/pad the
        # FR volumes along the wrong axes whenever those sizes differ.
        target_fr_shape = np.floor(
            np.asarray(lr_dti.shape[1:]) * params.data.downsampled_by_factor
        ).astype(int)
        lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
        data["lr_dti"] = lr_dti
        data["lr_dti_meta_dict"] = meta

        ####### Full-resolution images/volumes.
        # DTI.
        dti, meta = _load_nifti(
            pitn.utils.system.get_file_glob_unique(
                fr_subj_dir, params.data.dti_fname_pattern
            )
        )
        dti, meta["affine"] = fix_downsample_shape_errors(
            dti, meta["affine"], target_fr_shape
        )
        dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
        data["dti"] = dti
        data["dti_meta_dict"] = meta

        # Diffusion mask.
        mask, meta = _load_nifti(
            pitn.utils.system.get_file_glob_unique(
                fr_subj_dir, params.data.mask_fname_pattern
            )
        )
        # Add channel dim if not available.
        if mask.ndim == 3:
            mask = mask[None]
        mask, meta["affine"] = fix_downsample_shape_errors(
            mask, meta["affine"], target_fr_shape
        )
        mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
        data["mask"] = mask.bool()
        data["mask_meta_dict"] = meta

        # Construct a quick and cheap mask for the LR DTI by nearest-neighbor
        # resampling of the (re-oriented) FR mask onto the LR spatial grid.
        data["lr_mask"] = F.interpolate(
            data["mask"][None].float(),
            size=data["lr_dti"][0].shape,
            mode="nearest",
        )[0].bool()

        # Consider this as the "noise correction" step to have more informative,
        # consistent results with minimal biasing. Otherwise, outliers (which are
        # clearly errors) can change loss and performance metrics by orders of
        # magnitude for no good reason!
        if "eigval_clip_cutoff" in params.data and params.data.eigval_clip_cutoff:
            data["dti"] = _clip_dti(data["dti"], data["mask"])
            data["lr_dti"] = _clip_dti(data["lr_dti"], data["lr_mask"])

        ####### Log-euclid pre-computed volumes
        data["lr_log_euclid"] = _dti_to_log_euclid(data["lr_dti"])
        data["log_euclid"] = _dti_to_log_euclid(data["dti"])

        # Pre-compute FA maps
        fa = pitn.metrics.fast_fa(
            data["dti"][None].to(device), foreground_mask=data["mask"][None].to(device)
        )
        data["fa"] = fa.to(data["dti"])[0]

        # Finalize this subject.
        subj_gt[subj_id] = data
        print("Loaded Subject ", subj_id)
        print("=" * 20)

print("===Data Loaded===")

### Calculate Aggregate Statistics

In [None]:
# Per-channel global min/max of the DTI and log-Euclidean volumes, accumulated over
# every subject in `subj_gt` (both FR and LR volumes), restricted to each volume's mask.
# NOTE(review): `subj_gt` holds all selected subjects, including HOLDOUT_SUBJ_ID; no
# train/val split is applied here. Confirm that is intended for these ranges.


def _masked_channel_min_max(vol, mask, n_channels):
    """Return per-channel (min, max) of `vol` over voxels where `mask` is True.

    `vol` is channel-first and `mask` broadcasts against it along the channel dim
    (e.g. a (1, ...) mask). `masked_select` then returns the selected values
    channel-by-channel, so the result can be viewed as (n_channels, -1).
    """
    selected = torch.masked_select(vol, mask).view(n_channels, -1)
    return selected.min(-1).values, selected.max(-1).values


subj_agg_stats = Box(default_box=True)
# Zero initialization means every channel's range always includes 0.
_init_stat = torch.zeros(params.n_channels).to(subj_gt[selected_ids[0]].dti)
for _key in ("dti", "log_euclid"):
    # Separate tensors per statistic: the log-Euclidean ranges are accumulated on
    # their own instead of aliasing the DTI initial values.
    subj_agg_stats[_key] = dict(min=_init_stat.clone(), max=_init_stat.clone())

for s in subj_gt.values():
    for _key, vol, vol_mask in (
        ("dti", s.dti, s.mask),
        ("dti", s.lr_dti, s.lr_mask),
        ("log_euclid", s.log_euclid, s.mask),
        ("log_euclid", s.lr_log_euclid, s.lr_mask),
    ):
        v_min, v_max = _masked_channel_min_max(vol, vol_mask, params.n_channels)
        stats = subj_agg_stats[_key]
        stats.min = torch.minimum(stats.min, v_min)
        stats.max = torch.maximum(stats.max, v_max)


print(subj_agg_stats.dti.min)
print(subj_agg_stats.dti.max)

In [None]:
# Calculate global ranges of data for PSNR calculations.


def expander(t):
    """Reshape a per-channel vector (C,) to (1, C, 1, 1, 1) for broadcasting."""
    return einops.rearrange(t, "c -> 1 c 1 1 1")


# Observed per-channel DTI range across all loaded subjects.
dti_min = expander(subj_agg_stats.dti.min)
dti_max = expander(subj_agg_stats.dti.max)

# Target feature range [0, 1] for each of the 6 tensor components (int64 tensors).
feat_min = expander(torch.zeros(6, dtype=torch.int64))
feat_max = expander(torch.ones(6, dtype=torch.int64))

# PSNR is calculated on the final output tensor components, so no log-euclidean or
# scaling will occur here.
psnr_range_params = pitn.data.norm.GlobalScaleParams(
    feature_min=feat_min, feature_max=feat_max, data_min=dti_min, data_max=dti_max
)

## Select & Describe Models for Comparison

In [None]:
# def col_params_from_log(log_file: Path):
#     with open(Path(log_file), "r") as f:
#         p_str = list()
#         params = dict(
#             use_anat=False, use_half_precision_float=False, use_log_euclid=False
#         )
#         # 'use_anat': False,
#         # 'use_half_precision_float': False,
#         # 'use_log_euclid': False,

#         search_line = False
#         for line in f.readlines():
#             if "hardware" in line.casefold():
#                 break

#             if search_line:
#                 found_key = None
#                 for k in params.keys():
#                     if k in line:
#                         found_key = k
#                         break
#                 if found_key is not None:
#                     v_str = line[line.index(":") + 1 :]
#                     v_str = (
#                         v_str.strip()
#                         .replace(",", "")
#                         .replace("[", "")
#                         .replace("]", "")
#                         .replace("{", "")
#                         .replace("}", "")
#                     )
#                     v = ast.literal_eval(v_str)
#                     params[k] = v

#             if "timestamp" in line.casefold():
#                 search_line = True

#     return params

def run_params_from_config(run_config_file: Path) -> dict:
    """Load a run's parameters from a file containing a single Python literal.

    The file text is parsed with `ast.literal_eval`, so only literal structures
    (dicts, lists, strings, numbers, ...) are accepted and no code is executed.
    The parsed object is round-tripped through `Box(...).to_dict()` to clean up
    any recursive weirdness and return plain Python containers.
    """
    config_text = Path(run_config_file).resolve().read_text()
    return Box(ast.literal_eval(config_text)).to_dict()

In [None]:
# Names of the trained-model runs to compare, grouped by model family. Each name
# presumably identifies a run directory (likely under `results_dir`; TODO confirm).
# An empty list means no run of that family is selected.
baseline_spline_runs = []
baseline_espcn_runs = []
baseline_revnet_runs = [
    "2022-03-22T11_07_32__uvers_espcn_revnet_split_1",
    "2022-03-22T11_07_32__uvers_espcn_revnet_split_2",
    "2022-03-23T04_03_59__uvers_espcn_revnet_split_3",
]
diqt_carn_single_stream_runs = []
# NOTE(review): mixes "anat_stream_dti" and "anat_stream_le" runs; the name suffix
# appears to distinguish DTI from log-Euclidean training (confirm).
diqt_carn_anat_stream_runs = [
    "2022-03-22T21_33_49__uvers_pitn_anat_stream_dti_split_1",
    "2022-03-24T03_22_55__uvers_pitn_anat_stream_dti_split_2",
    "2022-03-23T17_25_43__uvers_pitn_anat_stream_le_split_1",
]


In [None]:
# Column names for a table describing each selected run (presumably filled from each
# run's config via `run_params_from_config`; TODO confirm once the table is built).
cols = ["run_name", "model_name", "use_le", "use_anat", "use_half_precision"]


## Unpack Model Predictions

## Calculate Metrics

In [None]:
# Save out Results

## Display Results