# Pain in the Net - Baseline Shallow Methods on T1w Images


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dipy.viz
import dipy.viz.regtools
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import mpl_toolkits
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown
import ipyplot

# Data management libraries.
import nibabel as nib
import natsort
from natsort import natsorted
import addict
from addict import Addict
import box
from box import Box
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy
import sklearn
import sklearn.neighbors
import sklearn.pipeline
import sklearn.mixture

plt.rcParams.update({"figure.autolayout": True})
plt.rcParams.update({"figure.facecolor": [1.0, 1.0, 1.0, 1.0]})

# Set print options for ndarrays/tensors.
np.set_printoptions(suppress=True, edgeitems=2, threshold=100, linewidth=88)
torch.set_printoptions(
    sci_mode=False, edgeitems=2, threshold=100, linewidth=88, profile="short"
)

In [None]:
# Update the notebook's environment variables with direnv.
# This requires the python-dotenv package and a `direnv` executable on PATH.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe: every variable that direnv
# exports for this directory overrides the kernel's current environment.
#
# `direnv exec <dir> /usr/bin/env` runs `env` inside direnv's context, printing every
# environment variable visible to that sub-process. That output is then handed to
# python-dotenv as if it were a 'dummy' .env file.
# The command is given as an argument list (no shell), so a working directory that
# contains spaces or shell metacharacters is passed through intact.
direnv_command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
proc_out = ""
try:
    proc = subprocess.run(direnv_command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # direnv is not installed; keep the kernel's existing environment untouched.
    warnings.warn("WARNING: direnv executable not found, environment not updated")
else:
    if proc.returncode != 0:
        warnings.warn(
            f"WARNING: direnv exited with code {proc.returncode}, "
            + "environment variables may be incomplete"
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Load the captured `KEY=value` lines into os.environ, overriding existing values.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../../").resolve())
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [None]:
# torch setup: run on CUDA when it is available, otherwise fall back to the CPU.
# To force CPU execution, replace the expression with torch.device("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    except NameError:
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# cap is defined in an ipython magic command: the `%%capture --no-stderr cap` at the
# top of the specs-recording cell. Printing it replays that cell's captured stdout
# (stderr was not captured).
print(cap)

Author: Tyler Spears

Last updated: 2021-10-27T15:47:32.446382+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-89-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: d256576f3fe98ffb4827d194f1aa31d6bae082c1

GPUtil           : 1.4.0
torchio          : 0.18.37
skimage          : 0.18.1
addict           : 2.4.0
monai            : 0.7.dev2138
seaborn          : 0.11.1
matplotlib       : 3.4.1
torch            : 1.9.0
sklearn          : 0.0
scipy            : 1.5.3
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
IPython          : 7.23.1
pandas           : 1.2.3
ants             : 0.2.7
dipy             : 1.4.1
pytorch_lightning: 1.4.5
nibabel          : 3.2.1
natsort          : 7.1.1
numpy            : 1.20.2
box              : 5.4.1
json             : 2.0.9

  id  Name              Driver Version      CUDA Ve

### Data Variables & Definitions Setup

In [None]:
# Set up directories. Every root path comes from an environment variable (e.g. loaded
# through direnv). Each required directory is checked on its own, so a failure names
# the missing path instead of raising one opaque combined assertion.
data_dir = pathlib.Path(os.environ["DATA_DIR"])

processed_data_dir = pathlib.Path(os.environ["WRITE_DATA_DIR"])
hcp_processed_data_dir = (
    processed_data_dir / "hcp/derivatives/mean-downsample/scale-1.00mm"
)
hcp_orig_scale_data_dir = (
    processed_data_dir / "hcp/derivatives/mean-downsample/scale-orig"
)
clinic_processed_data_dir = (
    processed_data_dir / "oasis3/derivatives/mean-downsample/scale-orig"
)
results_dir = pathlib.Path(os.environ["RESULTS_DIR"])
tmp_results_dir = pathlib.Path(os.environ["TMP_RESULTS_DIR"])

for required_dir in (
    hcp_processed_data_dir,
    clinic_processed_data_dir,
    hcp_orig_scale_data_dir,
    results_dir,
    tmp_results_dir,
):
    assert required_dir.exists(), f"Required directory not found: {required_dir}"

### Experiment Logging Setup

### Experiment Parameters

In [None]:
# Parameters
params = Box(default_box=True)

# Global RNG seed. Subject sub-selection (`random.sample`) and voxel sub-sampling
# (`torch.randperm`) both draw from the global RNGs. Seeding them here makes the
# selected subjects and the KDE samples reproducible across Restart & Run All.
params.seed = 0
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)

# Data params.
params.num_channels = 1
params.hcp.num_subjects = 13
params.clinic.num_subjects = 9
# Lower/upper percentiles (0-100) at which intensities are clipped when scaling.
params.clamp_percentiles = (0.01, 99.0)
# Target (min, max) range for min-max scaling of the masked T1w intensities, or None
# to leave the intensities unscaled.
params.data_scale_range = None


params.use_grad_penalty = False

## Data Loading

In [None]:
# Transformation pipeline.
# Every spatial dimension fed to the Laplacian pyramid must be divisible by
# 2 ** (number of high-frequency levels); the pyramid here uses 3 such levels.
laplace_pyramid_divisible_by_shape = 2**3

pre_process_pipeline = monai.transforms.Compose(
    [
        # Crop both volumes to the foreground of the mask, keeping a 3-voxel margin.
        monai.transforms.CropForegroundd(
            keys=["t1w", "mask"], source_key="mask", margin=3
        ),
        # Pad both volumes so each spatial dim is a multiple of the value above.
        monai.transforms.DivisiblePadd(
            keys=["t1w", "mask"], k=laplace_pyramid_divisible_by_shape
        ),
        # Convert each volume to a float tensor.
        monai.transforms.ToTensord(keys="t1w", dtype=torch.float),
        monai.transforms.ToTensord(keys="mask", dtype=torch.float),
    ]
)

### Load and Pre-Process HCP Data

In [None]:
# Candidate HCP subject IDs.
possible_ids = [
    "sub-397154",
    "sub-224022",
    "sub-140117",
    "sub-751348",
    "sub-894774",
    "sub-156637",
    "sub-227432",
    "sub-303624",
    "sub-185947",
    "sub-810439",
    "sub-753251",
    "sub-644246",
    "sub-141422",
    "sub-135528",
    "sub-103010",
    "sub-700634",
]

# Randomly sub-set the chosen participants for dev and debugging, warning whenever
# not every candidate is used.
selected_ids = random.sample(possible_ids, params.hcp.num_subjects)
if len(possible_ids) > params.hcp.num_subjects:
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        + f"Subj IDs selected: {selected_ids}"
    )
selected_ids = natsorted(selected_ids)

# Map each selected subject ID to its processed-data directory; each must exist.
hcp_subj_dirs: dict = {
    subj_id: hcp_processed_data_dir / f"{subj_id}" for subj_id in selected_ids
}
for subj_dir in hcp_subj_dirs.values():
    assert subj_dir.exists()
ppr(hcp_subj_dirs)

In [None]:
# Data loading and processing loop.
hcp_subj_data = list()
# NIfTI reader; volumes are reoriented to the closest canonical orientation on read.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# Sub-directory names (within each subject directory) for each image type.
t1w_file_prefix = "anat"
mask_file_prefix = "mask"


def _read_single_nifti(search_dir, glob_pattern):
    """Read the single NIfTI file in `search_dir` that matches `glob_pattern`.

    Returns (volume, metadata) from `nib_reader`. A leading channel axis is added
    when the volume is 3D.
    """
    matches = list(search_dir.glob(glob_pattern))
    # Make sure the glob pattern only matches one file.
    assert len(matches) == 1
    volume, metadata = nib_reader.get_data(nib_reader.read(matches[0]))
    if len(volume.shape) == 3:
        volume = volume[None,]
    return volume, metadata


for subj_id, subj_dir in hcp_subj_dirs.items():
    subj_data = {"subj_id": subj_id}

    # T1w volume, then mask. The "<key>_meta_dict" names are monai's default
    # metadata key names.
    subj_data["t1w"], subj_data["t1w_meta_dict"] = _read_single_nifti(
        subj_dir / t1w_file_prefix, f"{subj_id}*T1w*.nii.gz"
    )
    subj_data["mask"], subj_data["mask_meta_dict"] = _read_single_nifti(
        subj_dir / mask_file_prefix, f"{subj_id}*mask*.nii.gz"
    )

    # Pre-process subject vols.
    subj_data = pre_process_pipeline(subj_data)

    # Optionally scale the masked intensities into params.data_scale_range.
    if params.data_scale_range is not None:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data_scale_range[0],
            params.data_scale_range[1],
            quantile_low=params.clamp_percentiles[0] / 100,
            quantile_high=params.clamp_percentiles[1] / 100,
            dim=(1, 2, 3),
            channel_size=params.num_channels,
            clip=True,
        )
        masked_t1w = subj_data["t1w"] * subj_data["mask"]
        subj_data["t1w"] = scaler.scale(masked_t1w, stateful=True) * subj_data["mask"]
        subj_data["scaler"] = scaler

    hcp_subj_data.append(subj_data)

# Create dataset with all HCP subjects included.
hcp_subj_dataset = monai.data.Dataset(hcp_subj_data)

### Load & Pre-Process Clinical Data

In [None]:
# Candidate clinical (OASIS-3) subject/session IDs.
possible_ids = [
    "sub-OAS30188_MR_d3844",
    "sub-OAS30375_MR_d5792",
    "sub-OAS30558_MR_d2148",
    "sub-OAS30643_MR_d0280",
    "sub-OAS30685_MR_d0032",
    "sub-OAS30762_MR_d0043",
    "sub-OAS30770_MR_d1201",
    "sub-OAS30944_MR_d0089",
    "sub-OAS31018_MR_d0041",
    "sub-OAS31157_MR_d4924",
]

# Randomly sub-set the chosen participants for dev and debugging, then natural-sort.
selected_ids = random.sample(possible_ids, params.clinic.num_subjects)
if params.clinic.num_subjects < len(possible_ids):
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        + f"Subj IDs selected: {selected_ids}"
    )
selected_ids = natsorted(selected_ids)

# Processed-data directory for each selected subject; each must exist.
clinic_subj_dirs: dict = dict()
for subj_id in selected_ids:
    subj_dir = clinic_processed_data_dir / f"{subj_id}"
    clinic_subj_dirs[subj_id] = subj_dir
    assert subj_dir.exists()
ppr(clinic_subj_dirs)

In [None]:
# Data loading and processing loop.
clinic_subj_data = list()
# NIfTI reader; volumes are reoriented to the closest canonical orientation on read.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# Sub-directory names (within each subject directory) for each image type.
t1w_file_prefix = "anat"
mask_file_prefix = "mask"

for subj_id, subj_dir in clinic_subj_dirs.items():
    subj_data = {"subj_id": subj_id}

    # (volume key, metadata key, sub-directory, glob pattern) for each image.
    # The "<key>_meta_dict" names are monai's default metadata key names.
    image_specs = (
        ("t1w", "t1w_meta_dict", t1w_file_prefix, f"{subj_id}*T1w*.nii.gz"),
        ("mask", "mask_meta_dict", mask_file_prefix, f"{subj_id}*mask*.nii.gz"),
    )
    for vol_key, meta_key, sub_dir, pattern in image_specs:
        found = list((subj_dir / sub_dir).glob(pattern))
        # Make sure the glob pattern only matches one file.
        assert len(found) == 1
        vol, meta = nib_reader.get_data(nib_reader.read(found[0]))
        # Add a leading channel dimension to 3D volumes.
        if len(vol.shape) == 3:
            vol = vol[None,]
        subj_data[vol_key] = vol
        subj_data[meta_key] = meta

    # Pre-process subject vols.
    subj_data = pre_process_pipeline(subj_data)

    # Optionally scale the masked intensities into params.data_scale_range.
    if params.data_scale_range is not None:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data_scale_range[0],
            params.data_scale_range[1],
            quantile_low=params.clamp_percentiles[0] / 100,
            quantile_high=params.clamp_percentiles[1] / 100,
            dim=(1, 2, 3),
            channel_size=params.num_channels,
            clip=True,
        )
        scaled = scaler.scale(subj_data["t1w"] * subj_data["mask"], stateful=True)
        subj_data["t1w"] = scaled * subj_data["mask"]
        subj_data["scaler"] = scaler

    clinic_subj_data.append(subj_data)

# Create dataset with all "clinical quality" subjects included.
clinic_subj_dataset = monai.data.Dataset(clinic_subj_data)

## Data Exploration

### T1 Histograms of Intensities

In [None]:
# HCP: one histogram of masked T1w intensities per subject, laid out on a square grid.
# All subplots share the same x/y limits so subjects can be compared directly; tick
# labels are only drawn on the outer (bottom/left) edge of the occupied grid cells.
num_hcp_subjs = len(hcp_subj_dataset)
nrows = np.ceil(np.sqrt(num_hcp_subjs)).astype(int)
ncols = nrows

fig = plt.figure(dpi=120, figsize=(7, 5))
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, wspace=0.1, hspace=0.1)

xlim = [np.inf, -np.inf]
ylim = [np.inf, -np.inf]
# (row, col) -> whether a subplot was drawn in that grid cell.
grid_cell_used = collections.defaultdict(lambda: False)

for subj, (i, j) in zip(
    hcp_subj_dataset, itertools.product(range(nrows), range(ncols))
):
    ax = fig.add_subplot(grid[i, j])
    ax.hist(subj["t1w"][subj["mask"].bool()].flatten().cpu().numpy(), bins=1000)
    # Track the union of every subplot's auto-scaled limits.
    xlim[0] = min(xlim[0], ax.get_xlim()[0])
    xlim[1] = max(xlim[1], ax.get_xlim()[1])
    ylim[0] = min(ylim[0], ax.get_ylim()[0])
    ylim[1] = max(ylim[1], ax.get_ylim()[1])
    grid_cell_used[i, j] = True

xlim[1] += 10
ylim[1] += 10
# Tick spacing: ~1/4 of the upper limit, rounded to one significant digit. It is
# clamped to >= 1 so a small (e.g. rescaled) range cannot give np.arange a zero step.
x_step = max(round(xlim[1] // 4, -len(str(round(xlim[1] // 4))) + 1), 1)
y_step = max(round(ylim[1] // 4, -len(str(round(ylim[1] // 4))) + 1), 1)
xticks = np.round(np.arange(xlim[0], xlim[1], x_step)).astype(int)
yticks = np.round(np.arange(ylim[0], ylim[1], y_step)).astype(int)

for ax in fig.axes:
    ax_spec = ax.get_subplotspec()
    i = ax_spec.rowspan.start
    j = ax_spec.colspan.start

    # Apply the shared limits; without this every subplot keeps its own auto-scale
    # and the hidden interior ticks would hide that the axes differ.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # x tick labels only where no subplot sits directly below.
    if not grid_cell_used[i + 1, j]:
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, fontdict={"fontsize": "x-small"}, rotation=30)
    else:
        ax.set_xticks([])
    # y tick labels only on the left-most subplot of each row.
    if not grid_cell_used[i, j - 1]:
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticks, fontdict={"fontsize": "x-small"})
    else:
        ax.set_yticks([])

fig.suptitle("Masked T1w Intensities, HCP Subjects", fontsize="small")
fig.align_labels()

In [None]:
# Clinic: one seaborn histogram of masked T1w intensities per clinical subject, laid
# out on a square grid. All subplots share the same x/y limits so subjects can be
# compared directly; tick labels are only drawn on the outer (bottom/left) edge.
num_clinic_subjs = len(clinic_subj_dataset)
nrows = np.ceil(np.sqrt(num_clinic_subjs)).astype(int)
ncols = nrows

fig = plt.figure(dpi=120, figsize=(7, 5))
grid = mpl.gridspec.GridSpec(nrows, ncols, figure=fig, wspace=0.1, hspace=0.1)

xlim = [np.inf, -np.inf]
ylim = [np.inf, -np.inf]
# (row, col) -> whether a subplot was drawn in that grid cell.
grid_cell_used = collections.defaultdict(lambda: False)

for subj, (i, j) in zip(
    clinic_subj_dataset, itertools.product(range(nrows), range(ncols))
):
    ax = fig.add_subplot(grid[i, j])
    sns.histplot(subj["t1w"][subj["mask"].bool()].flatten().cpu().numpy(), ax=ax)
    # Track the union of every subplot's auto-scaled limits.
    xlim[0] = min(xlim[0], ax.get_xlim()[0])
    xlim[1] = max(xlim[1], ax.get_xlim()[1])
    ylim[0] = min(ylim[0], ax.get_ylim()[0])
    ylim[1] = max(ylim[1], ax.get_ylim()[1])
    grid_cell_used[i, j] = True

xlim[1] += 10
ylim[1] += 10
# Tick spacing: ~1/4 of the upper limit, rounded to one significant digit. It is
# clamped to >= 1 so a small (e.g. rescaled) range cannot give np.arange a zero step.
x_step = max(round(xlim[1] // 4, -len(str(round(xlim[1] // 4))) + 1), 1)
y_step = max(round(ylim[1] // 4, -len(str(round(ylim[1] // 4))) + 1), 1)
xticks = np.round(np.arange(xlim[0], xlim[1], x_step)).astype(int)
yticks = np.round(np.arange(ylim[0], ylim[1], y_step)).astype(int)

for ax in fig.axes:
    ax_spec = ax.get_subplotspec()
    i = ax_spec.rowspan.start
    j = ax_spec.colspan.start

    # Apply the shared limits; without this every subplot keeps its own auto-scale
    # and the hidden interior ticks would hide that the axes differ.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # x tick labels only where no subplot sits directly below.
    if not grid_cell_used[i + 1, j]:
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, fontdict={"fontsize": "x-small"}, rotation=30)
    else:
        ax.set_xticks([])
    # y tick labels (and seaborn's y-axis label) only on the left-most subplot.
    if not grid_cell_used[i, j - 1]:
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticks, fontdict={"fontsize": "x-small"})
    else:
        ax.set_yticks([])
        ax.set_ylabel("")

fig.suptitle("Masked T1w Intensities, Clinical Subjects", fontsize="small")
fig.align_labels()

In [None]:
# Show which voxels of one clinical subject fall inside a few hand-picked intensity
# windows (inclusive bounds). Voxels outside the window or outside the mask are 0.
d = clinic_subj_dataset[-2]
t1 = d["t1w"][0]
mask = d["mask"][0]

for intensity_range in [(0, 20), (70, 90), (100, 112)]:
    in_window = (t1 >= intensity_range[0]) & (t1 <= intensity_range[1])
    # The volume is cast to double before being combined with the python-float fill
    # value (presumably to keep the dtypes compatible for torch.where).
    filtered_t1 = torch.where(in_window & mask.bool(), t1.double(), 0.0)
    dipy.viz.regtools.plot_slices(filtered_t1.cpu().numpy()).set_dpi(120)
    plt.suptitle(f"Range [{intensity_range[0]}, {intensity_range[1]}]")
    plt.show()

### Peak Finding

#### HCP

In [None]:
# Peak width maximization without overlap.
def peak_contour_intersection_loss(
    rel_height, peaks, signal, lambda_height, lambda_non_curvedness
):
    """Loss for choosing a relative contour height for each peak of a 1D signal.

    For every peak, the contour at its relative height (as defined by
    `scipy.signal.peak_widths`) gives an index range [left, right] of `signal`.
    The loss combines three terms:

    * overlap: ``2 ** mean(overlap_ij / len(signal))`` over all pairs of peaks.
      This is 1.0 when no contours overlap, or when there are fewer than two peaks.
    * height: ``mean(|1 - rel_height|)``, which favours wide contours.
    * non-curvedness: for each peak, the relative residual
      ``||s - q|| / ||s - mean(s)||``, where ``q`` is the least-squares quadratic
      fit to the signal ``s`` inside the contour range. The value lies in [0, 1].

    Parameters
    ----------
    rel_height : float or array-like of float
        Relative contour height per peak; a scalar is applied to every peak.
    peaks : array-like of int
        Indices of the peaks in `signal`.
    signal : array-like of float
        The 1D signal, e.g. a sampled density estimate.
    lambda_height, lambda_non_curvedness : float
        Weights of the height and non-curvedness terms.

    Returns
    -------
    float
        ``overlap + lambda_height * height + lambda_non_curvedness * non_curvedness``
    """
    # `import scipy` alone does not guarantee that the signal submodule is loaded.
    import scipy.signal

    signal = np.asarray(signal, dtype=float)
    try:
        rel_height = np.asarray(rel_height, dtype=float)
        if rel_height.ndim == 0:
            rel_height = np.repeat(rel_height, len(peaks))

        # Fractional index range [left, right] of each peak's contour.
        peak_ranges = list()
        for rel_h, peak in zip(rel_height, peaks):
            _, _, left, right = scipy.signal.peak_widths(
                signal, np.asarray([peak]), rel_height=rel_h
            )
            peak_ranges.append((float(left[0]), float(right[0])))

        # Pairwise overlap length of the contour ranges.
        intersections = list()
        for i_peak in range(len(peak_ranges)):
            for j_peak in range(i_peak + 1, len(peak_ranges)):
                # Order the pair so that p1 starts first.
                (p1_low, p1_high), (p2_low, p2_high) = sorted(
                    (peak_ranges[i_peak], peak_ranges[j_peak])
                )
                # Covers containment, partial overlap, equal starts and disjoint
                # ranges (negative -> clamped to 0).
                intersections.append(max(0.0, min(p1_high, p2_high) - p2_low))

        # How far does the signal inside each contour range differ from a curve,
        # parameterized by a least-squares quadratic? (A quadratic *interpolant*
        # would pass through every sample and make this term identically zero.)
        non_curvedness = list()
        for low, high in peak_ranges:
            selected_signal = signal[int(np.round(low)) : int(np.round(high))]
            if selected_signal.size < 3:
                # A quadratic fits <= 2 points exactly, so the residual is 0.
                non_curvedness.append(0.0)
                continue
            x = np.arange(selected_signal.size)
            coeffs = np.polyfit(x, selected_signal, deg=2)
            residual = selected_signal - np.polyval(coeffs, x)
            spread = np.linalg.norm(selected_signal - selected_signal.mean(), ord=2)
            non_curvedness.append(
                float(np.linalg.norm(residual, ord=2) / spread) if spread > 0 else 0.0
            )

        # With fewer than two peaks there are no pairs: overlap term is 2 ** 0.
        mean_overlap = (
            (np.asarray(intersections) / len(signal)).mean() if intersections else 0.0
        )
        intersect_loss = 2 ** mean_overlap
        height_loss = np.abs(1.0 - rel_height).mean() if rel_height.size else 0.0
        non_curve_loss = np.mean(non_curvedness) if non_curvedness else 0.0
    except Exception as e:
        # Report the offending heights (e.g. an optimizer's trial point), then
        # re-raise with the original traceback.
        print(rel_height)
        print(e)
        raise
    return float(
        intersect_loss
        + lambda_height * height_loss
        + lambda_non_curvedness * non_curve_loss
    )

In [None]:
# Pool a random sub-sample of masked T1w voxels from every HCP subject. The KDE
# estimation does not scale well with 1 million+ samples, so at most
# `size_subsample` voxels per subject are kept, drawn without replacement.
size_subsample = 50000
t1_sample = list()
for subj in hcp_subj_dataset:
    masked_t1 = subj["t1w"][0][subj["mask"][0].bool()].flatten().cpu()
    keep = torch.randperm(len(masked_t1))[:size_subsample]
    t1_sample.append(masked_t1[keep].numpy())
t1_sample = np.concatenate(t1_sample).flatten()

In [None]:
# Shape of the pooled (flattened, 1-D) HCP voxel sample.
print(t1_sample.shape)

In [None]:
print("Starting KDE")
# Bandwidth heuristic. n ** (-1 / (d + 4)) is Scott's *factor* as used by scipy's
# gaussian_kde, which additionally scales it by the data covariance:
# <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html>
# sklearn's KernelDensity takes an absolute bandwidth in data units, and no std
# scaling is applied here, so the effective smoothing depends on the intensity scale.
# The factor is inflated 10x because much more smoothness was wanted.
d = 1  # dimensionality of the samples
n = len(t1_sample)
bandwidth = 10 * n ** (-1.0 / (d + 4))
kde = sklearn.neighbors.KernelDensity(
    bandwidth=bandwidth, atol=1e-8, rtol=1e-5
).fit(t1_sample.reshape(-1, 1))
print("Finished KDE")

In [None]:
# Evaluate the KDE on a regular grid of 10000 intensities spanning the samples.
intensities = np.linspace(t1_sample.min(), t1_sample.max(), 10000)

# score_samples returns log-densities. Exponentiating and multiplying by the grid
# step turns them into per-bin masses whose sum approximates the integral of the
# density over the sampled intensity range (~1.0).
integration_step = intensities[1] - intensities[0]
log_like = kde.score_samples(intensities[:, np.newaxis])
likelihoods = integration_step * np.exp(log_like)

In [None]:
# Find the `num_peaks` most prominent peaks of the HCP intensity density and mark
# their widths at half prominence.
num_peaks = 2
plt.figure(dpi=120, figsize=(8, 4))
# (find_peaks only uses `rel_height` together with `width`, so it is not passed.)
peaks, properties = scipy.signal.find_peaks(likelihoods)
prominences = scipy.signal.peak_prominences(likelihoods, peaks)
# Keep the most prominent peaks...
peaks = peaks[np.argsort(prominences[0])][-num_peaks:]
# ...and ensure the peaks are in order of intensity window.
peaks = np.sort(peaks)
prominences = scipy.signal.peak_prominences(likelihoods, peaks)
_, contour_height, peak_left, peak_right = scipy.signal.peak_widths(
    likelihoods, peaks, prominence_data=prominences, rel_height=0.5
)
# peak_widths returns fractional *indices* into `intensities`. Grid index k is
# intensities[k] (step = (max - min) / (len - 1)), so interpolate on the grid
# rather than scaling by len, which is off by one.
grid_idx = np.arange(len(intensities))
peak_left = np.interp(peak_left, grid_idx, intensities)
peak_right = np.interp(peak_right, grid_idx, intensities)

plt.plot(intensities, likelihoods)
plt.vlines(intensities[peaks], 0, likelihoods.max(), color="gray", alpha=0.5)
plt.plot(
    np.stack([peak_left, peak_right]),
    np.stack([contour_height, contour_height]),
    color="red",
)
plt.xlabel("T1w Intensity")
plt.ylabel("KDE Likelihood Estimation")
plt.title("Intensity Peaks & Widths in HCP Subject Data");

In [None]:
# Disabled, kept for reference: optimize each peak's relative contour height by
# minimizing `peak_contour_intersection_loss`. The relative heights are set by hand
# in the next cell instead.
# lambda_height = 0.001
# lambda_non_curvedness = 1e11
# result = scipy.optimize.minimize(
#     functools.partial(
#         peak_contour_intersection_loss,
#         peaks=peaks,
#         signal=likelihoods,
#         lambda_height=lambda_height,
#         lambda_non_curvedness=lambda_non_curvedness,
#     ),
#     [0.9, 0.9],
#     bounds=((0.1, 1.0), (0.1, 1.0)),
#     # bounds=scipy.optimize.Bounds([0.1] * 2, [1.0] * 2, keep_feasible=True),
#     # method="Powell",
#     # options={"maxiter": 100000, "maxfun": 100000, "eps": 1e-7, "maxls": 100, 'gtol':1e-10},
# )
# print(result)

In [None]:
# The optimization sucks, so the relative contour height of each peak (in peak
# order) is set by hand. The result mimics the `x` field of an optimizer result.
result = {"x": np.asarray([0.25, 0.75])}

In [None]:
# Mark each peak's contour at its hand-chosen relative height and save the
# resulting intensity windows for later.
plt.figure(dpi=120, figsize=(8, 4))

# peak number -> {
#   "range": [[left_intensity, contour_height], [right_intensity, contour_height]],
#   "peak": index of the peak in `intensities` (an index, not an intensity),
# }
peak_ranges = dict()
grid_idx = np.arange(len(intensities))
for i_peak, rel_h in enumerate(result["x"]):
    p = np.asarray([peaks[i_peak]])
    _, contour_height, peak_left, peak_right = scipy.signal.peak_widths(
        likelihoods, p, rel_height=rel_h
    )
    # Fractional grid indices -> intensities (index k <-> intensities[k]).
    peak_left = np.interp(peak_left, grid_idx, intensities)
    peak_right = np.interp(peak_right, grid_idx, intensities)
    plt.plot(
        np.stack([peak_left, peak_right]),
        np.stack([contour_height, contour_height]),
        color="red",
    )

    peak_ranges[i_peak] = {
        "range": np.asarray(
            [
                [peak_left[0], peak_right[0]],
                [contour_height[0], contour_height[0]],
            ],
            dtype=float,
        ).T,
        "peak": float(p[0]),
    }


plt.plot(intensities, likelihoods)
plt.vlines(intensities[peaks], 0, likelihoods.max(), color="gray", alpha=0.5)

plt.xlabel("T1w Intensity")
plt.ylabel("KDE Likelihood Estimation")
plt.title("Intensity Peaks & Widths in HCP Subject Data");

In [None]:
ppr(peak_ranges)

# Show which voxels of one HCP subject fall inside each peak's intensity window
# (inclusive bounds). Voxels outside the window or outside the mask are 0.
d = hcp_subj_dataset[-2]
t1 = d["t1w"][0]
mask = d["mask"][0]

for peak_info in peak_ranges.values():
    # Column 0 of "range" holds the [low, high] intensity bounds.
    intensity_range = tuple(peak_info["range"][:, 0])
    inside = (t1 >= intensity_range[0]) & (t1 <= intensity_range[1]) & mask.bool()
    filtered_t1 = torch.where(inside, t1.double(), 0.0)
    dipy.viz.regtools.plot_slices(filtered_t1.cpu().numpy()).set_dpi(120)
    plt.suptitle(f"Range [{intensity_range[0]}, {intensity_range[1]}]")
    plt.show()

In [None]:
# Landmark intensities to align between datasets: both ends of the intensity grid,
# the intensities with minimum and maximum estimated likelihood, and the detected
# peaks (intended to be the gray- and white-matter peaks). np.unique both sorts and
# removes duplicates, e.g. when the maximum-likelihood point is also a peak.
peak_intensities = np.unique(
    np.asarray(
        [
            intensities.min(),
            intensities.max(),
            intensities[np.flatnonzero(likelihoods == likelihoods.min())[0]].item(),
            intensities[np.flatnonzero(likelihoods == likelihoods.max())[0]].item(),
            *intensities[peaks],
        ]
    )
)
# `intensities` is ascending and `peak_intensities` is sorted, so the boolean mask
# picks out the matching likelihoods in the same order.
peak_likelihoods = likelihoods[
    np.isin(intensities, peak_intensities, assume_unique=True)
]
# One (intensity, likelihood) row per landmark.
hcp_peaks = np.column_stack([peak_intensities, peak_likelihoods])
ppr(hcp_peaks)
ppr(hcp_peaks.shape)

In [None]:
# Keep the HCP results under dataset-specific names. The generic names (intensities,
# likelihoods, kde, t1_sample, peak_ranges) are reused by the clinical-data cells.
hcp_hist, hcp_kde, hcp_samples, hcp_peak_ranges = (
    np.column_stack([intensities, likelihoods]),
    kde,
    t1_sample,
    peak_ranges,
)

#### Clinical

In [None]:
# Pool a random sub-sample (without replacement) of masked T1w voxels from every
# clinical subject. The KDE estimation does not scale well with 1 million+ samples.
size_subsample = 50000
t1_sample = list()
for subj in clinic_subj_dataset:
    masked_voxels = subj["t1w"][0][subj["mask"][0].bool()].flatten().cpu()
    order = torch.randperm(len(masked_voxels))
    t1_sample.append(masked_voxels[order[:size_subsample]].numpy())
t1_sample = np.concatenate(t1_sample).flatten()

In [None]:
print("Starting KDE")
# Bandwidth heuristic. n ** (-1 / (d + 4)) is Scott's *factor* as used by scipy's
# gaussian_kde, which additionally scales it by the data covariance:
# <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html>
# sklearn's KernelDensity takes an absolute bandwidth in data units, and no std
# scaling is applied here, so the effective smoothing depends on the intensity scale.
# The factor is inflated 10x because much more smoothness was wanted.
d = 1  # dimensionality of the samples
n = len(t1_sample)
bandwidth = 10 * n ** (-1.0 / (d + 4))
kde = sklearn.neighbors.KernelDensity(
    bandwidth=bandwidth, atol=1e-8, rtol=1e-5
).fit(t1_sample.reshape(-1, 1))
print("Finished KDE")

In [None]:
# Evaluate the clinical KDE on a regular grid of 10000 intensities spanning the
# samples. Log-densities are exponentiated and multiplied by the grid step, so their
# sum approximates the integral of the density over the sampled range (~1.0).
intensities = np.linspace(t1_sample.min(), t1_sample.max(), 10000)
integration_step = intensities[1] - intensities[0]
log_like = kde.score_samples(intensities[:, np.newaxis])
likelihoods = integration_step * np.exp(log_like)

In [None]:
# Find the `num_peaks` most prominent peaks of the clinical intensity density and
# mark their widths at half prominence.
num_peaks = 2
plt.figure(dpi=120, figsize=(8, 4))
# (find_peaks only uses `rel_height` together with `width`, so it is not passed.)
peaks, properties = scipy.signal.find_peaks(likelihoods)
prominences = scipy.signal.peak_prominences(likelihoods, peaks)
# Keep the most prominent peaks...
peaks = peaks[np.argsort(prominences[0])][-num_peaks:]
# ...and ensure the peaks are in order of intensity window.
peaks = np.sort(peaks)
prominences = scipy.signal.peak_prominences(likelihoods, peaks)
_, contour_height, peak_left, peak_right = scipy.signal.peak_widths(
    likelihoods, peaks, prominence_data=prominences, rel_height=0.5
)
# peak_widths returns fractional *indices* into `intensities`. Grid index k is
# intensities[k] (step = (max - min) / (len - 1)), so interpolate on the grid
# rather than scaling by len, which is off by one.
grid_idx = np.arange(len(intensities))
peak_left = np.interp(peak_left, grid_idx, intensities)
peak_right = np.interp(peak_right, grid_idx, intensities)

plt.plot(intensities, likelihoods)
plt.vlines(intensities[peaks], 0, likelihoods.max(), color="gray", alpha=0.5)
plt.plot(
    np.stack([peak_left, peak_right]),
    np.stack([contour_height, contour_height]),
    color="red",
)
plt.xlabel("T1w Intensity")
plt.ylabel("KDE Likelihood Estimation")
plt.title("Intensity Peaks & Widths in Clinical Subject Data");

In [None]:
# Disabled, kept for reference: optimize each peak's relative contour height by
# minimizing `peak_contour_intersection_loss`. The relative heights are set by hand
# in the next cell instead.
# lambda_height = 0.001
# lambda_non_curvedness = 1e11
# result = scipy.optimize.minimize(
#     functools.partial(
#         peak_contour_intersection_loss,
#         peaks=peaks,
#         signal=likelihoods,
#         lambda_height=lambda_height,
#         lambda_non_curvedness=lambda_non_curvedness,
#     ),
#     [0.9, 0.9],
#     bounds=((0.1, 1.0), (0.1, 1.0)),
#     # bounds=scipy.optimize.Bounds([0.1] * 2, [1.0] * 2, keep_feasible=True),
#     # method="Powell",
#     # options={"maxiter": 100000, "maxfun": 100000, "eps": 1e-7, "maxls": 100, 'gtol':1e-10},
# )
# print(result)

In [None]:
# Just set the relative heights by hand, the optimization isn't tuned well.
# `result` stands in for the commented-out `scipy.optimize.minimize` output: its "x"
# holds one relative contour height per peak.
result = {"x": np.full(2, 0.75)}

In [None]:
plt.figure(dpi=120, figsize=(8, 4))

# Fractional sample indices from `peak_widths` are mapped onto the intensity grid by
# linear interpolation over the sample positions.
sample_idx = np.arange(len(intensities))

# Save the peak ranges for later. For each peak, "range" is a (2, 2) array whose rows are
# the (intensity, contour height) of the left and right crossings, and "peak" is the
# peak's index into `intensities`.
peak_ranges = dict()
for i_peak, rel_h in enumerate(result["x"]):
    p = peaks[i_peak : i_peak + 1]
    # `rel_height` must be a scalar float. Prominences are computed internally.
    _, contour_height, peak_left, peak_right = scipy.signal.peak_widths(
        likelihoods, p, rel_height=float(rel_h)
    )

    peak_left = np.interp(peak_left, sample_idx, intensities)
    peak_right = np.interp(peak_right, sample_idx, intensities)
    plt.plot(
        np.stack([peak_left, peak_right]),
        np.stack([contour_height, contour_height]),
        color="red",
    )

    # `.item()` extracts scalars explicitly; `float()` on ndim>0 arrays is deprecated.
    left, right, height = peak_left.item(), peak_right.item(), contour_height.item()
    peak_ranges[i_peak] = {
        "range": np.asarray([[left, right], [height, height]]).T,
        "peak": float(p.item()),
    }


plt.plot(intensities, likelihoods)
plt.vlines(intensities[peaks], 0, likelihoods.max(), color="gray", alpha=0.5)

plt.xlabel("T1w Intensity")
plt.ylabel("KDE Likelihood Estimation")
plt.title("Intensity Peaks & Widths in Clinical Subject Data");

In [None]:
ppr(peak_ranges)

# For one clinical subject, show the in-mask voxels whose intensity lies inside each
# peak's width range. All other voxels are set to 0.
d = clinic_subj_dataset[-2]
t1, mask = d["t1w"][0], d["mask"][0]

for p in peak_ranges.values():
    intensity_range = tuple(p["range"][:, 0])
    low, high = intensity_range
    in_range = (t1 >= low) & (t1 <= high) & mask.bool()
    filtered_t1 = torch.where(in_range, t1.double(), 0.0)
    slices_fig = dipy.viz.regtools.plot_slices(filtered_t1.cpu().numpy())
    slices_fig.set_dpi(120)
    plt.suptitle(f"Range [{low}, {high}]")
    plt.show()

In [None]:
# Landmark points used to align histograms: the intensity extremes, the intensities of
# the smallest and largest KDE values, and the detected tissue peaks.
landmark_candidates = [
    intensities.min(),
    intensities.max(),
    intensities[likelihoods == likelihoods.min()][0].item(),
    intensities[likelihoods == likelihoods.max()][0].item(),
    *intensities[peaks],
]
# `np.unique` returns sorted, de-duplicated values, so no separate sort is needed.
peak_intensities = np.unique(np.asarray(landmark_candidates))
peak_likelihoods = likelihoods[np.isin(intensities, peak_intensities)]
clinic_peaks = np.column_stack((peak_intensities, peak_likelihoods))

ppr(clinic_peaks)
ppr(clinic_peaks.shape)

In [None]:
# Keep clinic-specific names for the KDE results, used in the HCP/clinic comparison below.
clinic_hist = np.column_stack((intensities, likelihoods))
clinic_kde, clinic_samples, clinic_peak_ranges = kde, t1_sample, peak_ranges

## Histogram Modeling and Interpolation

**Ignore this section: none of these models fit the histograms well enough, so their cells are left commented out.**

Approaches tried:
- Fitting regression models
- Fitting polynomial interpolation models
- Fitting the histogram CDFs
- Fitting Gaussian Mixture Models

### sklearn Regression

In [None]:
# # Fit a cubic polynomial to these points by projecting into polynomial feature space,
# # then using linear regression.

# scaler = sklearn.preprocessing.MinMaxScaler(
#     feature_range=(clinic_peaks[:, 0].min(), clinic_peaks[:, 0].max())
# )
# proj = sklearn.preprocessing.PolynomialFeatures(degree=2)
# model = sklearn.linear_model.LinearRegression()

# X = hcp_peaks
# Y = clinic_peaks
# X_p = X.copy()
# X_p[:, 0] = scaler.fit_transform(X_p[:, 0].reshape(-1, 1)).flatten()
# X_p = proj.fit_transform(X_p)
# model = model.fit(X_p, Y)
# model.coef_

In [None]:
# plt.figure(dpi=120, figsize=(9, 6))

# plt.plot(hcp_peaks[:, 0], hcp_peaks[:, 1], label="HCP", marker="o", ls="-")
# plt.plot(clinic_peaks[:, 0], clinic_peaks[:, 1], label="Cinical", marker="o", ls="-")

# X = hcp_peaks
# X_p = X.copy()
# X_p[:, 0] = scaler.fit_transform(X_p[:, 0].reshape(-1, 1)).flatten()
# plt.plot(
#     X_p[:, 0],
#     X_p[:, 1],
#     label="HCP Scaled to Clinical",
#     marker=".",
#     ls="--",
#     alpha=0.7,
# )

# X_p = proj.fit_transform(X_p)
# adapt_pred = model.predict(X_p)
# plt.plot(
#     adapt_pred[:, 0],
#     adapt_pred[:, 1],
#     label="HCP adapted to Clinical",
#     marker=".",
#     ls="--",
#     alpha=0.7,
# )
# plt.xlabel("T1w Intensity")
# plt.ylabel("Likelihood")
# plt.legend();

In [None]:
# print(proj.get_feature_names())
# print(model.coef_)

In [None]:
# # Compare entire distributions.
# plt.figure(dpi=100, figsize=(7, 5))
# plt.plot(clinic_hist[:, 0], clinic_hist[:, 1], label="Clinic Original")
# plt.vlines(
#     clinic_peaks[:, 0],
#     clinic_hist[:, 1].min(),
#     clinic_hist[:, 1].max(),
#     ls="--",
#     color="black",
# )
# # plt.plot(hcp_hist[:, 0], hcp_hist[:, 1], label='HCP Original')

# X = hcp_hist
# X_p = X.copy()
# X_p[:, 0] = scaler.fit_transform(X_p[:, 0].reshape(-1, 1)).flatten()
# plt.plot(
#     X_p[:, 0],
#     X_p[:, 1],
#     label="HCP Scaled to Clinical",
#     alpha=0.7,
# )

# X_p = proj.fit_transform(X_p)
# clinified = model.predict(X_p)
# plt.plot(clinified[:, 0], clinified[:, 1], label="Clinifed HCP")

# plt.legend();

### scipy Interpolation

In [None]:
# # Perform 1D interpolation for each axis independently.

# X_intens = hcp_peaks[:, 0]
# Y_intens = clinic_peaks[:, 0]
# f_intens = scipy.interpolate.interp1d(X_intens, Y_intens, kind="linear")

# X_likelihood = hcp_peaks[:, 1]
# Y_likelihood = clinic_peaks[:, 1]
# f_likelihood = scipy.interpolate.interp1d(X_likelihood, Y_likelihood, kind="linear")

In [None]:
# # Compare entire distributions.
# plt.figure(dpi=100, figsize=(7, 5))
# plt.plot(clinic_hist[:, 0], clinic_hist[:, 1], label="Clinic Original")
# plt.vlines(
#     clinic_peaks[:, 0],
#     clinic_hist[:, 1].min(),
#     clinic_hist[:, 1].max(),
#     ls="--",
#     color="black",
# )
# # plt.plot(hcp_hist[:, 0], hcp_hist[:, 1], label='HCP Original')

# plt.plot(
#     f_intens(hcp_hist[:, 0]),
#     f_likelihood(hcp_hist[:, 1]),
#     label="HCP Interpolated to Clinic",
# )
# plt.legend();

### Gaussian Mixture Model Fitting & Transformation

#### HCP

In [None]:
# # Fit HCP data.
# hcp_gmm = sklearn.mixture.BayesianGaussianMixture(
#     n_components=50,
#     weight_concentration_prior=1e-5,
#     weight_concentration_prior_type="dirichlet_process",
#     n_init=2,
#     # verbose=1,
#     max_iter=500,
#     tol=1e-2 / 2,
# )
# X = hcp_samples.reshape(-1, 1)[::50]
# hcp_gmm = hcp_gmm.fit(X)
# print(f"Number of Components: {(hcp_gmm.weights_ >= 1e-5).sum()}")

In [None]:
# intensities_to_predict = hcp_samples[::100]
# plt.hist(intensities_to_predict, label="Original Data", density=True, bins=100)
# intensities_to_predict = np.sort(intensities_to_predict).reshape(-1, 1)
# predict_log_likelihoods = hcp_gmm.score_samples(intensities_to_predict).flatten()
# predict_likelihoods = np.exp(predict_log_likelihoods)

# plt.plot(intensities_to_predict, predict_likelihoods, label="GMM Predicted")
# plt.legend();

In [None]:
# fig, ax = plt.subplots(1, 1, dpi=100)
# x, component_num = hcp_gmm.sample(100000)
# x = x.flatten()

# plt.hist(
#     hcp_samples, label="Original Data", bins=100, density=True, color="gray", alpha=0.2
# )
# plot_data = pd.DataFrame(
#     np.stack([x, component_num], axis=-1), columns=["intensity", "component"]
# )
# sns.histplot(
#     plot_data, x="intensity", hue="component", ax=ax, palette="tab10", stat="density"
# );

#### Clinic

In [None]:
# # Fit clinic data.
# clinic_gmm = sklearn.mixture.BayesianGaussianMixture(
#     n_components=50,
#     weight_concentration_prior=1e-5,
#     weight_concentration_prior_type="dirichlet_process",
#     n_init=2,
#     # verbose=1,
#     max_iter=500,
#     tol=1e-2 / 2,
# )
# X = clinic_samples.reshape(-1, 1)[::50]
# clinic_gmm = clinic_gmm.fit(X)
# print(f"Number of Components: {(clinic_gmm.weights_ >= 1e-5).sum()}")

In [None]:
# intensities_to_predict = clinic_samples[::10]
# plt.hist(intensities_to_predict, label="Original Data", density=True, bins=100)
# intensities_to_predict = np.sort(intensities_to_predict).reshape(-1, 1)
# predict_log_likelihoods = clinic_gmm.score_samples(intensities_to_predict).flatten()
# predict_likelihoods = np.exp(predict_log_likelihoods)

# plt.plot(intensities_to_predict, predict_likelihoods, label="GMM Predicted")
# plt.legend();

In [None]:
# fig, ax = plt.subplots(1, 1, dpi=100)
# x, component_num = clinic_gmm.sample(100000)
# x = x.flatten()

# plt.hist(
#     clinic_samples,
#     label="Original Data",
#     bins=100,
#     density=True,
#     color="gray",
#     alpha=0.2,
# )
# plot_data = pd.DataFrame(
#     np.stack([x, component_num], axis=-1), columns=["intensity", "component"]
# )
# sns.histplot(
#     plot_data, x="intensity", hue="component", ax=ax, palette="tab10", stat="density"
# );

## T1w Intensity CDF

In [None]:
# Plot the normalized CDFs of the HCP and clinic KDE histograms. Dots mark the
# landmark intensities stored in `hcp_peaks` / `clinic_peaks`.
hcp_cdf = np.zeros_like(hcp_hist)
hcp_cdf[:, 0] = hcp_hist[:, 0]
hcp_cdf[:, 1] = np.cumsum(hcp_hist[:, 1])
hcp_cdf[:, 1] = hcp_cdf[:, 1] / hcp_cdf[:, 1].max()
plt.plot(hcp_cdf[:, 0], hcp_cdf[:, 1], label="HCP", color="C0")

# `np.interp` returns the exact CDF value for landmarks on the grid and still works
# (instead of mis-sizing the array) for landmarks that are not exactly on it.
hcp_peak_heights = np.interp(hcp_peaks[:, 0], hcp_cdf[:, 0], hcp_cdf[:, 1])
plt.plot(hcp_peaks[:, 0], hcp_peak_heights, "o", label="HCP landmarks", color="C0")

clinic_cdf = np.zeros_like(clinic_hist)
clinic_cdf[:, 0] = clinic_hist[:, 0]
clinic_cdf[:, 1] = np.cumsum(clinic_hist[:, 1])
clinic_cdf[:, 1] = clinic_cdf[:, 1] / clinic_cdf[:, 1].max()
plt.plot(clinic_cdf[:, 0], clinic_cdf[:, 1], label="Clinic", color="C1")
clinic_peak_heights = np.interp(clinic_peaks[:, 0], clinic_cdf[:, 0], clinic_cdf[:, 1])
plt.plot(
    clinic_peaks[:, 0], clinic_peak_heights, "o", label="Clinic landmarks", color="C1"
)
plt.title("CDFs of KDEs of HCP & Clinic T1 Histograms")
plt.xlabel("Intensity")
plt.ylabel("Cumulative probability")
plt.legend();

## Histogram Transformation

Based on algorithms in:

```L. G. Nyul, J. K. Udupa, and X. Zhang, “New variants of a method of MRI scale standardization,” IEEE Transactions on Medical Imaging, vol. 19, no. 2, pp. 143–150, Feb. 2000, doi: 10.1109/42.836373.```

### Determine Landmarks in Source and Target Histograms

In [None]:
# Columns: (intensity, cumulative probability).
hcp_cdf = np.column_stack((hcp_hist[:, 0], np.cumsum(hcp_hist[:, 1])))
clinic_cdf = np.column_stack((clinic_hist[:, 0], np.cumsum(clinic_hist[:, 1])))

# Rescale so each CDF tops out at exactly 1.0, absorbing numerical error in the KDE sums.
hcp_cdf[:, 1] /= hcp_cdf[:, 1].max()
clinic_cdf[:, 1] /= clinic_cdf[:, 1].max()

In [None]:
# Histogram landmarks that delimit each piece of the piecewise-linear intensity mapping,
# plus the parameters used to select them. `default_box=True` lets nested keys such as
# `landmarks.source.min` be assigned without creating the parents first.
landmarks, params = Box(default_box=True), Box(default_box=True)

In [None]:
# Add min and max intensities as landmarks, take only quantiles to remove tails & noise.
# Each histogram keeps only the bins whose CDF lies in [min_quantile, max_quantile].
# Those densities are renormalized, and the lowest/highest retained intensities
# become the min/max landmarks.
min_quantile = 0.00
max_quantile = 0.99
params.min_quantile = min_quantile
params.max_quantile = max_quantile

hcp_keep = (hcp_cdf[:, 1] >= min_quantile) & (hcp_cdf[:, 1] <= max_quantile)
if not hcp_keep.any():
    raise ValueError("No HCP histogram bins fall within the requested quantile range.")
hcp_transform_hist = hcp_hist.copy()
hcp_transform_hist[:, 1] = np.where(hcp_keep, hcp_hist[:, 1], 0.0)

# Renormalize to sum to 1.0.
hcp_transform_hist[:, 1] = hcp_transform_hist[:, 1] / hcp_transform_hist[:, 1].sum()

# Only the density column is trimmed; the intensity column still spans the full
# histogram. The landmarks must therefore come from the retained bins, otherwise
# `max_quantile` would have no effect on them.
landmarks.source.min = hcp_transform_hist[hcp_keep, 0].min()
landmarks.source.max = hcp_transform_hist[hcp_keep, 0].max()

clinic_keep = (clinic_cdf[:, 1] >= min_quantile) & (clinic_cdf[:, 1] <= max_quantile)
if not clinic_keep.any():
    raise ValueError("No clinic histogram bins fall within the requested quantile range.")
clinic_transform_hist = clinic_hist.copy()
clinic_transform_hist[:, 1] = np.where(clinic_keep, clinic_hist[:, 1], 0.0)

# Renormalize to sum to 1.0.
clinic_transform_hist[:, 1] = (
    clinic_transform_hist[:, 1] / clinic_transform_hist[:, 1].sum()
)

landmarks.target.min = clinic_transform_hist[clinic_keep, 0].min()
landmarks.target.max = clinic_transform_hist[clinic_keep, 0].max()

In [None]:
# Select quantiles of the gray matter peak/curve as individual points to match.
# The GM curve is the part of each quantile-trimmed histogram that lies inside peak 0's
# width range. A landmark is the curve intensity whose within-curve CDF is closest to
# the requested quantile.
gm_quantiles = (0.25, 0.5, 0.75)
params.gm_quantiles = gm_quantiles

# ---- Clinic / target images.
gm_clinic_peak = clinic_peak_ranges[0]
gm_clinic_curve = clinic_transform_hist[
    (clinic_transform_hist[:, 0] >= gm_clinic_peak["range"][0, 0])
    & (clinic_transform_hist[:, 0] <= gm_clinic_peak["range"][1, 0])
]
# CDF of the GM curve alone, normalized to end at 1.0.
gm_cdf = np.column_stack((gm_clinic_curve[:, 0], gm_clinic_curve[:, 1].cumsum()))
gm_cdf[:, 1] /= gm_cdf[:, 1].max()
gm_landmarks = [gm_cdf[np.abs(q - gm_cdf[:, 1]).argmin(), 0] for q in gm_quantiles]
landmarks.target.gm = gm_landmarks

# ---- HCP / source images.
gm_hcp_peak = hcp_peak_ranges[0]
gm_hcp_curve = hcp_transform_hist[
    (hcp_transform_hist[:, 0] >= gm_hcp_peak["range"][0, 0])
    & (hcp_transform_hist[:, 0] <= gm_hcp_peak["range"][1, 0])
]
gm_cdf = np.column_stack((gm_hcp_curve[:, 0], gm_hcp_curve[:, 1].cumsum()))
gm_cdf[:, 1] /= gm_cdf[:, 1].max()
gm_landmarks = [gm_cdf[np.abs(q - gm_cdf[:, 1]).argmin(), 0] for q in gm_quantiles]
landmarks.source.gm = gm_landmarks

In [None]:
# Select quantiles of the white matter peak/curve as individual points to match.
# The WM curve is the part of each quantile-trimmed histogram that lies inside peak 1's
# width range. For each quantile, choose the intensity that is the closest match to
# this region's quantile.
wm_quantiles = (0.25, 0.5, 0.75)
params.wm_quantiles = wm_quantiles

# ---- Clinic / target images.
wm_clinic_peak = clinic_peak_ranges[1]
wm_clinic_curve = clinic_transform_hist[
    (clinic_transform_hist[:, 0] >= wm_clinic_peak["range"][0, 0])
    & (clinic_transform_hist[:, 0] <= wm_clinic_peak["range"][1, 0])
]
# CDF of the WM curve alone, normalized to end at 1.0.
wm_cdf = np.column_stack((wm_clinic_curve[:, 0], wm_clinic_curve[:, 1].cumsum()))
wm_cdf[:, 1] /= wm_cdf[:, 1].max()
wm_landmarks = [wm_cdf[np.abs(q - wm_cdf[:, 1]).argmin(), 0] for q in wm_quantiles]
landmarks.target.wm = wm_landmarks

# ---- HCP / source images.
wm_hcp_peak = hcp_peak_ranges[1]
wm_hcp_curve = hcp_transform_hist[
    (hcp_transform_hist[:, 0] >= wm_hcp_peak["range"][0, 0])
    & (hcp_transform_hist[:, 0] <= wm_hcp_peak["range"][1, 0])
]
wm_cdf = np.column_stack((wm_hcp_curve[:, 0], wm_hcp_curve[:, 1].cumsum()))
wm_cdf[:, 1] /= wm_cdf[:, 1].max()
wm_landmarks = [wm_cdf[np.abs(q - wm_cdf[:, 1]).argmin(), 0] for q in wm_quantiles]
landmarks.source.wm = wm_landmarks

In [None]:
# Pretty-print the selected landmarks, then the parameters used to pick them.
for box_obj in (landmarks, params):
    ppr(box_obj)

### Interpolation of Histogram x-Axis

In [None]:
# Create an interpolation function that maps every source (HCP) intensity to a target
# (clinic) intensity, piecewise-linearly between corresponding landmarks.
source_landmarks = np.asarray(
    [
        landmarks.source.min,
        *landmarks.source.gm,
        *landmarks.source.wm,
        landmarks.source.max,
    ],
    dtype=float,
)
target_landmarks = np.asarray(
    [
        landmarks.target.min,
        *landmarks.target.gm,
        *landmarks.target.wm,
        landmarks.target.max,
    ],
    dtype=float,
)
# Landmarks are paired by position. Sorting each list independently (below) only keeps
# the pairs aligned when both are already in increasing order, so flag any violation
# instead of silently mis-pairing source and target landmarks.
for lm_name, lm_values in (("source", source_landmarks), ("target", target_landmarks)):
    if np.any(np.diff(lm_values) < 0):
        warnings.warn(
            f"{lm_name} landmarks are not in increasing order ({lm_values}); "
            "sorting them will mis-pair source and target landmarks."
        )
x = np.sort(source_landmarks)
y = np.sort(target_landmarks)
# Allow out-of-range interpolation as just clamping to the max and min.
intensity_map = scipy.interpolate.interp1d(
    x, y, kind="linear", bounds_error=False, fill_value=(y.min(), y.max())
)

In [None]:
# Plot the intensity mapping over (and slightly beyond) the HCP intensity range, with the
# landmark pairs as black dots.
fig, ax = plt.subplots(dpi=100)
source_intens = np.linspace(-1, 1.1 * hcp_hist[:, 0].max(), 10000)
ax.plot(source_intens, intensity_map(source_intens))
ax.plot(x, y, ".", color="black")
ax.set_xlabel("Original Intensity")
ax.set_ylabel("Mapped Intensity");

In [None]:
# Plot the HCP histogram before and after its intensity axis is remapped. Dashed lines
# mark the source landmarks (`x`) and the target landmarks (`y`), respectively.
landmark_line_top = hcp_hist[:, 1].max() * 1.1
panels = (
    (hcp_hist[:, 0], x, "original"),
    (intensity_map(hcp_hist[:, 0]), y, "transformed"),
)
for panel_intensities, panel_landmarks, panel_label in panels:
    plt.plot(panel_intensities, hcp_hist[:, 1], label=panel_label)
    plt.vlines(panel_landmarks, 0, landmark_line_top, ls="--", color="gray")
    plt.legend()
    plt.show()

In [None]:
# HCP histogram with its intensity column pushed through the landmark map; densities unchanged.
scaled_hcp_hist = hcp_hist.copy()
scaled_hcp_hist[:, 0] = intensity_map(hcp_hist[:, 0])

### Create Final Image Voxel LUT


In [None]:
# Compare cumulative sums of the HCP histogram before and after the intensity remap.
for curve, curve_label in ((hcp_hist, "original"), (scaled_hcp_hist, "transformed")):
    plt.plot(curve[:, 0], np.cumsum(curve[:, 1]), label=curve_label)

plt.legend();

In [None]:
# Mappings on the CDF.
# Both cumulative curves are rescaled to end at exactly 1.0, as is done for
# `hcp_cdf`/`clinic_cdf` earlier. The KDE-derived densities in `scaled_hcp_hist` only
# sum to approximately 1.0. Without rescaling, the brightest source voxels would map to
# a CDF value below 1.0 and never reach the top of the target distribution.
source_cumulative = np.cumsum(scaled_hcp_hist[:, 1])
source_cumulative = source_cumulative / source_cumulative.max()
target_cumulative = np.cumsum(clinic_transform_hist[:, 1])
target_cumulative = target_cumulative / target_cumulative.max()

# Source Pixel intensity -> Source Cumulative Density
source_cdf_mapper = scipy.interpolate.interp1d(
    scaled_hcp_hist[:, 0],
    source_cumulative,
    kind="linear",
    bounds_error=False,
    fill_value=(0.0, 1.0),
)
# Target Cumulative Density -> Target Pixel Intensity
target_cdf_inverse_mapper = scipy.interpolate.interp1d(
    target_cumulative,
    clinic_transform_hist[:, 0],
    kind="linear",
    bounds_error=False,
    fill_value=(clinic_transform_hist[:, 0].min(), clinic_transform_hist[:, 0].max()),
)

In [None]:
def adapt_img(
    img, source_hist_intensity_mapper, source_cdf_mapper, target_cdf_inverse_mapper
):
    """Run every voxel of ``img`` through a three-stage intensity transform.

    The voxels are flattened and passed, in order, through:
      1. ``source_hist_intensity_mapper`` (source intensity -> remapped intensity),
      2. ``source_cdf_mapper`` (remapped intensity -> source cumulative density),
      3. ``target_cdf_inverse_mapper`` (cumulative density -> target intensity).

    Args:
        img: Array-like object providing ``.shape`` and ``.flatten()``.
        source_hist_intensity_mapper: Callable applied to the flattened voxels.
        source_cdf_mapper: Callable applied to the output of the first stage.
        target_cdf_inverse_mapper: Callable applied to the output of the second stage;
            its result must support ``.reshape``.

    Returns:
        The final stage's output, reshaped to ``img.shape``.
    """
    original_shape = img.shape
    values = img.flatten()
    for stage in (
        source_hist_intensity_mapper,
        source_cdf_mapper,
        target_cdf_inverse_mapper,
    ):
        values = stage(values)
    return values.reshape(original_shape)

In [None]:
# Bind the fitted mappers so that `f_hcp_to_clinic(img)` adapts an HCP image's intensities.
hcp_to_clinic_mappers = dict(
    source_hist_intensity_mapper=intensity_map,
    source_cdf_mapper=source_cdf_mapper,
    target_cdf_inverse_mapper=target_cdf_inverse_mapper,
)
f_hcp_to_clinic = functools.partial(adapt_img, **hcp_to_clinic_mappers)

## Results Visualization

In [None]:
# Row by row: raw HCP T1w, HCP after the landmark intensity map, HCP after full CDF
# resampling, and clinical T1w. Each row shows one slice (fixed index on the last axis)
# and the in-mask intensity histogram. The bottom panel overlays KDEs of rows 2-4.
fig = plt.figure(dpi=140, constrained_layout=True, figsize=(8, 10))

hcp_d = hcp_subj_dataset[-3]
hcp_t1, hcp_mask = (hcp_d[key][0].cpu().numpy() for key in ("t1w", "mask"))

clinic_d = clinic_subj_dataset[-1]
clinic_t1, clinic_mask = (clinic_d[key][0].cpu().numpy() for key in ("t1w", "mask"))

hcp_slice = np.index_exp[:, :, 84]
clinic_slice = np.index_exp[:, :, 90]
hcp_t1_slice = hcp_t1[hcp_slice]
imgs = [
    hcp_t1_slice,
    intensity_map(hcp_t1_slice).reshape(hcp_t1.shape[:2]),
    f_hcp_to_clinic(hcp_t1_slice),
    clinic_t1[clinic_slice],
]

# Display windows run from 0 to the 99th percentile; the clinic-scale rows share one.
clinic_bounds = (0, max(np.quantile(img.flatten(), 0.99) for img in imgs[1:]))
bounds = [(0, np.quantile(imgs[0], 0.99))] + [clinic_bounds] * (len(imgs) - 1)
colors = ["orange", "red", "blue", "green"]

hcp_brain_voxels = hcp_t1[hcp_mask.astype(bool)]
histograms = [
    hcp_brain_voxels,
    intensity_map(hcp_brain_voxels),
    f_hcp_to_clinic(hcp_brain_voxels),
    clinic_t1[clinic_mask.astype(bool)],
]

gs = fig.add_gridspec(nrows=len(imgs) + 1, ncols=3, hspace=0.2)

for row, (img, (lo, hi), voxels, color) in enumerate(
    zip(imgs, bounds, histograms, colors)
):
    img_ax = fig.add_subplot(gs[row, 0])
    im = img_ax.imshow(np.rot90(img), cmap="gray", vmin=lo, vmax=hi)
    img_ax.set_xticks([])
    img_ax.set_yticks([])
    img_ax.set_xticklabels([])
    img_ax.set_yticklabels([])
    plt.colorbar(im, ax=img_ax, location="left", shrink=0.85, fraction=0.1, aspect=10)

    hist_ax = sns.histplot(
        voxels,
        kde=True,
        stat="probability",
        binrange=(lo, hi),
        ax=fig.add_subplot(gs[row, 1:]),
        color=color,
    )
    hist_ax.set_xlim(lo, hi)
    hist_ax.set_ylim(0, 0.04)

combine_hist_ax = fig.add_subplot(gs[-1, :])
for voxels, color, lab in zip(
    histograms[1:], colors[1:], ["Lin Map", "CDF Resample", "Clinic"]
):
    combine_hist_ax = sns.kdeplot(
        voxels,
        ax=combine_hist_ax,
        label=lab,
        color=color,
        fill=True,
        alpha=0.4,
        legend=True,
    )

combine_hist_ax.set_xlim(*clinic_bounds)
combine_hist_ax.set_ylim(0, 0.04)
combine_hist_ax.legend();

In [None]:
# Re-display each peak's intensity-range mask.
# NOTE(review): this reads `t1` and `mask` left over from the earlier cell that loads
# `clinic_subj_dataset[-2]`, not the subjects shown in the figure above.
for p in peak_ranges.values():
    intensity_range = tuple(p["range"][:, 0])
    range_lo, range_hi = intensity_range
    selected = (t1 >= range_lo) & (t1 <= range_hi) & mask.bool()
    filtered_t1 = torch.where(selected, t1.double(), 0.0)
    range_fig = dipy.viz.regtools.plot_slices(filtered_t1.cpu().numpy())
    range_fig.set_dpi(120)
    plt.suptitle(f"Range [{range_lo}, {range_hi}]")
    plt.show()