# Non-Learning Baseline Comparisons for Super-Resolution of fODFs in Diffusion MRI

Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

## Imports & Setup

In [19]:
# Imports
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import copy
import datetime
import functools
import inspect
import io
import itertools
import math
import os
import pathlib
import pdb
import random
import shutil
import subprocess
import sys
import tempfile
import time
import typing
import warnings
import zipfile
from pathlib import Path
from pprint import pprint as ppr

import aim
import dotenv
import einops

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import monai

# Data management libraries.
import nibabel as nib
import nibabel.processing
import transforms3d

# Computation & ML libraries.
import numpy as np
import SimpleITK as sitk
import pandas as pd
import seaborn as sns
import skimage
import torch
import torch.nn.functional as F
from box import Box
from icecream import ic
from natsort import natsorted
import dipy

import pitn

# Matplotlib defaults: automatic layout, opaque white figure background, and a
# grayscale default colormap for images.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
        "image.cmap": "gray",
    }
)

# Print options shared by numpy arrays and torch tensors (no scientific notation,
# summarize large arrays, wrap at 88 columns).
_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, **_print_opts)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Run `/usr/bin/env` inside direnv's context for the current directory; its stdout is
# every environment variable defined in that sub-environment. The command is passed
# as an argument list (no shell), so a working directory containing spaces or shell
# metacharacters is handled correctly. direnv's own stderr still goes to the notebook.
cwd = os.getcwd()
try:
    direnv_proc = subprocess.run(
        ["direnv", "exec", cwd, "/usr/bin/env"],
        stdout=subprocess.PIPE,
        cwd=cwd,
    )
    direnv_env = direnv_proc.stdout.strip().decode("utf-8")
except FileNotFoundError:
    # Best-effort: without direnv, keep the current environment unchanged.
    warnings.warn("direnv executable not found; environment variables not loaded.")
    direnv_env = ""
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(direnv_env), override=True)

direnv: loading ~/work/pitn/.envrc
direnv: creating conda environment

CondaValueError: prefix already exists: /home/tas6hh/miniconda/envs/pitn2


EnvironmentLocationNotFound: Not a conda environment: /



True

In [3]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():
    # GPU information
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
        print(gpu_info)
    except NameError:
        print("CUDA Version: ", torch.version.cuda)
else:
    print("CUDA not in use, falling back to CPU")

In [4]:
# `cap` is created by the `%%capture` magic in the watermark cell; it only exists if
# that cell has been run, so print it when present and otherwise do nothing.
if "cap" in globals():
    print(cap)

Author: Tyler Spears

Last updated: 2022-12-01T19:34:45.848421+00:00

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.4.0

Compiler    : GCC 10.4.0
OS          : Linux
Release     : 5.15.0-53-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 20
Architecture: 64bit

Git hash: 0d7cc23be643dc3985734d4d9cac80a355c2ebe9

json        : 2.0.9
aim         : 3.14.4
skimage     : 0.19.3
sys         : 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0]
nibabel     : 4.0.1
SimpleITK   : 2.1.0
monai       : 1.0.0
pandas      : 1.5.2
matplotlib  : 3.5.2
transforms3d: 0.4.1
numpy       : 1.23.4
seaborn     : 0.12.1
torch       : 1.12.1
einops      : 0.4.1
pitn        : 0.0.post1.dev206+gf002231.d20220911

CUDA not in use, falling back to CPU



## Experiment & Parameters Setup

In [5]:
p = Box(default_box=True)
# Experiment defaults, can be overridden in a config file.

# General experiment-wide params
###############################################
# p.experiment_name = "interpolation_baseline"
p.override_experiment_name = False
p.results_dir = "/data/srv/outputs/pitn/results/runs"
p.tmp_results_dir = "/data/srv/outputs/pitn/results/tmp"

# Sorted so the order of the split files does not depend on the filesystem.
p.tvt_split_files = sorted(Path("./data_splits").glob("HCP*train-val-test_split*.csv"))

# If a config file exists, override the defaults with those values. Loading is
# best-effort: on any failure the defaults are kept and the cause is reported.
try:
    if "PITN_CONFIG" in os.environ:
        config_fname = Path(os.environ["PITN_CONFIG"])
    else:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
    f_type = config_fname.suffix.casefold()
    if f_type in {".yaml", ".yml"}:
        f_params = Box.from_yaml(filename=config_fname)
    elif f_type == ".json":
        f_params = Box.from_json(filename=config_fname)
    elif f_type == ".toml":
        f_params = Box.from_toml(filename=config_fname)
    else:
        raise RuntimeError(f"Unsupported config file type '{f_type}'")

    p.merge_update(f_params)

except Exception as e:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; the repr tells the reader *why* the config was not used.
    print(f"WARNING: Config file not loaded: {e!r}")

# Remove the default_box behavior now that params have been fully read in.
_p = Box(default_box=False)
_p.merge_update(p)
p = _p



In [6]:
# Collect every subject that appears in any train/val/test split file.
subj_ids = set()
for f in p.tvt_split_files:
    subj_ids |= set(pd.read_csv(f).subj_id.tolist())
if not subj_ids:
    raise RuntimeError(
        f"No subject IDs found in split files: {list(p.tvt_split_files)}"
    )

p.subj_ids = natsorted(subj_ids)
# Choose the visualization subjects with a dedicated, seeded RNG. This plots the same
# subjects on every run and leaves the global `random` state untouched. The seed can
# be overridden with `viz_subj_seed` in the config file.
viz_rng = random.Random(p.get("viz_subj_seed", 0))
p.viz_subjs = viz_rng.sample(p.subj_ids, k=min(4, len(p.subj_ids)))

## Data Loading

In [7]:
hcp_full_res_data_dir = Path("/data/srv/data/pitn/hcp")
hcp_full_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/full-res/fodf")
hcp_low_res_data_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/vol")
hcp_low_res_fodf_dir = Path("/data/srv/outputs/pitn/hcp/downsample/scale-2.00mm/fodf")

# Fail fast before any data loading, naming every missing directory at once.
_missing_dirs = [
    d
    for d in (
        hcp_full_res_data_dir,
        hcp_full_res_fodf_dir,
        hcp_low_res_data_dir,
        hcp_low_res_fodf_dir,
    )
    if not d.exists()
]
assert not _missing_dirs, f"Missing data directories: {[str(d) for d in _missing_dirs]}"

### Test Dataset

In [8]:
# Build the test dataset over every subject in the split files, recording (rather
# than printing) any warnings raised during construction. The test dataset won't be
# cached, as each image should only be loaded once.
with warnings.catch_warnings(record=True) as warn_list:
    paths_tf = pitn.data.datasets.HCPfODFINRDataset.default_pre_sample_tf(
        0, skip_sample_mask=True
    )
    test_paths_dataset = pitn.data.datasets.HCPfODFINRDataset(
        subj_ids=p.subj_ids,
        dwi_root_dir=hcp_full_res_data_dir,
        fodf_root_dir=hcp_full_res_fodf_dir,
        lr_dwi_root_dir=hcp_low_res_data_dir,
        lr_fodf_root_dir=hcp_low_res_fodf_dir,
        transform=paths_tf,
    )
    whole_vol_tf = pitn.data.datasets.HCPfODFINRWholeVolDataset.default_tf()
    test_dataset = pitn.data.datasets.HCPfODFINRWholeVolDataset(
        test_paths_dataset, transform=whole_vol_tf
    )

# Report the warnings recorded while building the datasets, one non-blank line each.
print("=" * 10)
print("Warnings caught:")
# `warnings.formatwarning` takes (message, category, filename, lineno, line). Passing
# `WarningMessage.file` as an extra positional argument raises TypeError, so it is
# not passed.
ws = "\n".join(
    warnings.formatwarning(w.message, w.category, w.filename, w.lineno, w.line)
    for w in warn_list
)
ws = "\n".join(filter(lambda s: bool(s.strip()), ws.splitlines()))
print(ws, flush=True)
print("=" * 10)




In [9]:
# One whole-volume subject per batch, loaded by 3 persistent background workers.
# NOTE(review): shuffling is not needed for evaluation, and it makes the per-subject
# order (and the row order of the results CSV) vary between runs. Consider
# shuffle=False.
test_loader_kwargs = dict(
    batch_size=1,
    shuffle=True,
    pin_memory=False,
    num_workers=3,
    prefetch_factor=3,
    persistent_workers=True,
)
test_dataloader = monai.data.DataLoader(test_dataset, **test_loader_kwargs)

## Evaluation

In [10]:
# Empty per-subject results template; each baseline deep-copies it before filling it.
result_table = {column: list() for column in ("subj_id", "mse_mean", "mse_var")}

### Linear Interpolation

In [13]:
SHOW_WARNINGS = False

RUN_NAME = "linear_interpolation_baseline"
ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
tmp_res_dir = Path(p.tmp_results_dir) / "_".join([ts, RUN_NAME])
tmp_res_dir.mkdir(parents=True)


def _vol_to_sitk(vol: torch.Tensor, affine: np.ndarray) -> sitk.Image:
    """Wrap a channel-first volume and its voxel->world affine as a vector sitk Image.

    The channel axis is moved last so SimpleITK treats each voxel's coefficients as a
    single vector pixel. Spacing, origin, and direction come from decomposing
    `affine` with transforms3d; the shear component of the decomposition is ignored.
    """
    # NOTE(review): the "c z y x" labels assume the spatial axes are stored in
    # reversed (z, y, x) order, which is what sitk.GetImageFromArray expects;
    # confirm against the dataset's layout.
    transl, rot, zoom, _shear = transforms3d.affines.decompose(affine)
    img = sitk.GetImageFromArray(einops.rearrange(vol, "c z y x -> z y x c").numpy())
    img.SetSpacing(tuple(zoom))
    img.SetOrigin(tuple(transl))
    img.SetDirection(tuple(rot.flatten()))
    return img


# Baseline: linearly resample each low-res fODF volume onto its full-res grid and
# score it against the full-res fODF inside the mask.
lin_results = copy.deepcopy(result_table)
with warnings.catch_warnings(record=True) as warn_list:
    for subj_dict in test_dataloader:
        subj_id = subj_dict["subj_id"][0]
        x = subj_dict["lr_fodf"][0].cpu()
        y = subj_dict["fodf"][0].cpu()
        mask = subj_dict["mask"][0].bool().cpu()
        x_affine = subj_dict["affine_lrvox2acpc"][0].cpu().numpy().astype(np.double)
        y_affine = subj_dict["affine_vox2acpc"][0].cpu().numpy().astype(np.double)

        # The target image only supplies the output grid for the resampling.
        x_img = _vol_to_sitk(x, x_affine)
        y_img = _vol_to_sitk(y, y_affine)
        resampled = sitk.Resample(x_img, y_img, interpolator=sitk.sitkLinear)
        # Copy out of the sitk buffer (instead of a read-only view) so the tensor owns
        # writable memory that does not alias `resampled`.
        y_pred = sitk.GetArrayFromImage(resampled)
        y_pred = einops.rearrange(y_pred, "z y x c -> c z y x")
        y_pred = torch.from_numpy(y_pred).to(y)

        # Broadcast the mask across the channel axis so every coefficient inside the
        # mask contributes to the squared-error statistics.
        mask_broad = torch.broadcast_to(mask, y.shape)
        mse = F.mse_loss(y_pred[mask_broad], y[mask_broad], reduction="none")
        mse_mean = mse.mean().item()
        mse_var = torch.var(mse).item()

        lin_results["subj_id"].append(subj_id)
        lin_results["mse_mean"].append(mse_mean)
        lin_results["mse_var"].append(mse_var)

        print(f"Subj {subj_id} MSE {mse_mean} |", end=" ")

        if int(subj_id) in p.viz_subjs:
            print("Creating prediction viz")
            with mpl.rc_context({"font.size": 6.0}):
                fig = plt.figure(dpi=175, figsize=(9, 5))
                fig, _ = pitn.viz.plot_fodf_coeff_slices(
                    y_pred,
                    y,
                    y_pred - y,
                    fig=fig,
                    fodf_vol_labels=("Predicted", "Target", "Pred - GT"),
                    imshow_kwargs={"interpolation": "antialiased", "cmap": "gray"},
                )
                fig_fname = f"subj_{subj_id}_{RUN_NAME}_viz.png"
                fig.savefig(tmp_res_dir / fig_fname)
                plt.close(fig)

pd.DataFrame.from_dict(lin_results).to_csv(tmp_res_dir / f"run_results_{RUN_NAME}.csv")
shutil.copytree(tmp_res_dir, Path(p.results_dir) / tmp_res_dir.name)

if SHOW_WARNINGS:
    print("=" * 10)
    print("Warnings caught:")
    # `warnings.formatwarning` takes (message, category, filename, lineno, line);
    # passing `w.file` as an extra positional argument raises TypeError.
    ws = "\n".join(
        warnings.formatwarning(w.message, w.category, w.filename, w.lineno, w.line)
        for w in warn_list
    )
    ws = "\n".join(filter(lambda s: bool(s.strip()), ws.splitlines()))
    print(ws, flush=True)
    print("=" * 10)

Subj 647858 MSE 0.00026799298939295113 | Subj 510225 MSE 0.0006521903560496867 | Subj 165032 MSE 0.0007730822544544935 | Subj 497865 MSE 0.0007140398374758661 | Subj 519950 MSE 0.0012175769079476595 | Subj 153227 MSE 0.0015535678248852491 | Subj 149236 MSE 0.0006453408859670162 | Subj 135225 MSE 0.00028459817986004055 | Subj 134425 MSE 0.0008819212089292705 | Subj 845458 MSE 0.0015323726693168283 | Subj 500222 MSE 0.0012331182369962335 | Subj 896778 MSE 0.0014263910707086325 | Subj 236130 MSE 0.00031864785705693066 | Subj 217429 MSE 0.0015833323122933507 | Subj 143830 MSE 0.00038678274722769856 | Subj 139839 MSE 0.0004411425907164812 | Subj 929464 MSE 0.00023061926185619086 | Subj 480141 MSE 0.001694306032732129 | Subj 195849 MSE 0.0009307701839134097 | Subj 769064 MSE 0.0005133814993314445 | Subj 248339 MSE 0.0008642168249934912 | Subj 749058 MSE 0.00035164383007213473 | Subj 159946 MSE 0.00028469215612858534 | Subj 857263 MSE 0.0004839627363253385 | Subj 580751 MSE 0.0002975362294819

In [17]:
# Median over subjects of each subject's mean masked squared error.
df = pd.DataFrame(lin_results)
df["mse_mean"].median()

0.0005742022767663002

In [39]:
# `import dipy` is not guaranteed to load the `dipy.data` subpackage, so import it
# explicitly instead of relying on another module having already imported it.
import dipy.data

sph = dipy.data.get_sphere("symmetric362")
# NOTE(review): the query point mixes coordinates from three different vertices
# (x of vertex 43, y of vertex 56, z of vertex 8), so it is generally not a point on
# the sphere. Confirm this is intended, or remove this scratch cell.
sph.find_closest([sph.x[43], sph.y[56], sph.z[8]])

34