# DTI Noise Removal via Threshold Finding

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import ipyplot

# Data management libraries.
import nibabel as nib
import nibabel.processing
import natsort
from natsort import natsorted
from addict import Addict
import pprint
from pprint import pprint as ppr
import box
from box import Box

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

import pitn

# Figure defaults: automatic tight layout and an opaque white figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Set print options for ndarrays/tensors; both libraries share the same summarization
# threshold and line width.
_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, **_print_opts)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and that direnv is installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv
# - warnings

# Form the command to be run in direnv's context. It is an argument list that is run
# without a shell, so a working directory containing spaces or shell metacharacters is
# passed through verbatim. The command prints out all environment variables defined in
# the subprocess.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    # Run command in a new subprocess.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except OSError as e:
    # The executable could not be launched (e.g. direnv is not installed). Keep this
    # step best-effort: leave the environment as it is, but say so.
    warnings.warn(f"Could not run direnv; environment left unchanged: {e!r}")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            + "the loaded environment may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
# Report the compute device. Everything printed here ends up in the captured output.
if not torch.cuda.is_available():
    print("CUDA not in use, falling back to CPU")
else:
    # GPU information
    try:
        gpu_info = pitn.utils.system.get_gpu_specs()
        print(gpu_info)
    except NameError:
        # Fall back to the CUDA version reported by torch.
        print("CUDA Version: ", torch.version.cuda)

In [None]:
# `cap` is defined by the `%%capture --no-stderr cap` cell magic of the specs-recording
# cell, so that cell must have been run first. Printing it shows the captured output.
print(cap)

Author: Tyler Spears

Last updated: 2022-02-15T18:11:32.447400+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-99-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: 4152448860d966bfad5dd6d91175a4c83ce3512d

numpy     : 1.20.2
dipy      : 1.4.1
scipy     : 1.5.3
monai     : 0.8.0
skimage   : 0.18.1
box       : 5.4.1
pitn      : 0.0.post1.dev132+g02c0d1a
matplotlib: 3.4.1
seaborn   : 0.11.1
torch     : 1.10.0
torchio   : 0.18.30
natsort   : 7.1.1
json      : 2.0.9
nibabel   : 3.2.1
sys       : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
pandas    : 1.2.3

  id  Name              Driver Version      CUDA Version  Total Memory    uuid
----  ----------------  ----------------  --------------  --------------  ----------------------------------------
   0  NVIDIA TITAN RTX  470.103.01                  11.3  24217.0MB       GPU-

### Data Variables & Definitions Setup

In [None]:
# Set up directories
def _require_env_dir(env_var):
    """Return the path stored in environment variable `env_var`, asserting it exists."""
    found_dir = pathlib.Path(os.environ[env_var])
    assert found_dir.exists()
    return found_dir


# Each directory is read and checked in turn, so the first missing variable or
# directory stops the cell.
data_dir = _require_env_dir("DATA_DIR")
write_data_dir = _require_env_dir("WRITE_DATA_DIR")
results_dir = _require_env_dir("RESULTS_DIR")
tmp_results_dir = _require_env_dir("TMP_RESULTS_DIR")

## Parameter Reading & Experiment Setup

### Parameters

<div class="alert alert-block alert-info"> <b>NOTE</b> Here are all the parameters! This makes it easy to find them! </div>

In [None]:
# Default experiment parameters. `default_box=True` lets nested sections such as
# `params.data` or `params.net.kwargs` be created implicitly on first attribute access.
params = Box(default_box=True)

# General experiment-wide params
###############################################
params.experiment_name = "debug_rmse_metrics"
###############################################
# 6 channels for the 6 DTI components
params.n_channels = 6
params.n_subjs = 16
# Voxel sizes of the low-res (lr) and full-res (fr) data; they are formatted into the
# "scale-X.XXmm" directory names below.
params.lr_vox_size = 2.5
params.fr_vox_size = 1.25
params.use_anat = True

# Data params
params.data.fr_dir = data_dir / f"scale-{params.fr_vox_size:.2f}mm"
params.data.lr_dir = data_dir / f"scale-{params.lr_vox_size:.2f}mm"
# Glob patterns used to locate each subject's files.
params.data.dti_fname_pattern = r"sub-*dti.nii.gz"
params.data.mask_fname_pattern = r"dti/sub-*mask.nii.gz"
params.data.anat_type = "t2w"
params.data.anat_fname_pattern = f"sub-*{params.data.anat_type}.nii.gz"
# The data were downsampled artificially by this factor.
params.data.downsampled_by_factor = params.lr_vox_size / params.fr_vox_size
# Store the factor as an int when it is a whole number, otherwise keep the float.
params.data.downsampled_by_factor = (
    int(params.data.downsampled_by_factor)
    if int(params.data.downsampled_by_factor) == params.data.downsampled_by_factor
    else params.data.downsampled_by_factor
)

# This is the number of voxels to remove (read: center crop out) from the network's
# prediction. This allows for an "oversampling" of the low-res voxels to help inform a
# more constrained HR prediction. This value of voxels will be removed from each spatial
# dimension (D, H, and W) starting at the center of the output patches.
# Ex. A size of 1 will remove the 2 outer-most voxels from each dimension in the output,
# while still keeping the corresponding voxels in the LR input.
params.hr_center_crop_per_side = 0

# Quantile clamping to be done on the outer edge of the mask.
# NOTE: This clamping *will* affect the test and validation scores, as those voxels
# clamped by this are considered as "errors"/"noise" and will be discarded in testing.
# 80,000 is ~ the average volume of the entire brain mask.
# params.data.edge_correction_max_vox_to_change = 300
# # params.data.mask_edge_clamp_max_quantile = False
# # Second data scaling method, where the training data will be scaled and possibly clipped,
# # but the testing data will be compared on the originals.
# # Scale input data by the valid values of each channel of the DTI.
# # I.e., Dx,x in [0, 1], Dx,y in [-1, 1], Dy,y in [0, 1], Dy,z in [-1, 1], etc.
# params.data.dti_scale_range = ((0, -1, 0, -1, -1, 0), (1, 1, 1, 1, 1, 1))
# params.data.anat_scale_range = (0, 1)
# params.data.scale_to_quantiles = (0.0001, 0.9999)
# params.data.clip_to_quantiles = True

# Network params.
# The network's goal is to upsample the input by this factor.
params.net.upscale_factor = params.data.downsampled_by_factor
params.net.kwargs.n_res_units = 3
params.net.kwargs.n_dense_units = 3
params.net.kwargs.interior_channels = params.n_channels * 2
params.net.kwargs.activate_fn = F.elu
params.net.kwargs.upsample_activate_fn = F.elu
params.net.kwargs.center_crop_output_side_amt = params.hr_center_crop_per_side

# Adam optimizer kwargs
params.optim.name = "AdamW"
params.optim.kwargs.lr = 5e-4
params.optim.kwargs.betas = (0.9, 0.999)

# Testing params
params.test.dataset_subj_percent = 0.4

# Validation params
params.val.dataset_subj_percent = 0.2

# Training params
params.train.in_patch_size = (24, 24, 24)
params.train.batch_size = 32
params.train.samples_per_subj_per_epoch = 8000
params.train.max_epochs = 50
params.train.loss_name = "mse"
# Percentage of subjs in dataset that go into the training set.
params.train.dataset_subj_percent = 1 - (
    params.test.dataset_subj_percent + params.val.dataset_subj_percent
)

# Learning rate scheduler config.
params.train.lr_scheduler.name = "OneCycleLR"
params.train.lr_scheduler.kwargs.max_lr = 3e-3
params.train.lr_scheduler.kwargs.epochs = params.train.max_epochs
# Test and validation subject counts are rounded up; the training set gets whatever
# subjects remain.
num_test_subjs = int(np.ceil(params.n_subjs * params.test.dataset_subj_percent))
num_val_subjs = int(np.ceil(params.n_subjs * params.val.dataset_subj_percent))
num_train_subjs = params.n_subjs - (num_test_subjs + num_val_subjs)
# Scheduler steps per epoch = total training samples per epoch // batch size.
params.train.lr_scheduler.kwargs.steps_per_epoch = (
    params.train.samples_per_subj_per_epoch * num_train_subjs // params.train.batch_size
)

# If a config file exists, override the defaults with those values.
# The file is taken from the PITN_CONFIG environment variable when it is set; otherwise
# a uniquely-matching "config.*" file is looked up in the current directory.
config_fname = None
if "PITN_CONFIG" in os.environ:
    config_fname = Path(os.environ["PITN_CONFIG"])
else:
    try:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
    except Exception:
        # Best-effort lookup: a failed search is treated as "no config file", and the
        # defaults defined above are kept.
        config_fname = None

if config_fname is not None:
    try:
        # `Path.suffix` keeps the leading dot (e.g. ".yaml"), so strip it before
        # comparing against the bare extension names.
        f_type = Path(config_fname).suffix.casefold().lstrip(".")
        # The path must be given as `filename=`; the first positional parameter of the
        # Box.from_* constructors is the serialized string itself.
        if f_type in {"yaml", "yml"}:
            f_params = Box.from_yaml(filename=config_fname)
        elif f_type == "json":
            f_params = Box.from_json(filename=config_fname)
        elif f_type == "toml":
            f_params = Box.from_toml(filename=config_fname)
        else:
            raise RuntimeError(f"Unsupported config file type for '{config_fname}'")

        params.merge_update(f_params)
    except Exception as e:
        # Keep the override best-effort, but do not hide the fact that a config file
        # was found and could not be applied.
        warnings.warn(
            f"Could not load config file '{config_fname}', using defaults: {e!r}"
        )

# Remove the default_box behavior now that params have been fully read in.
p = Box(default_box=False)
p.merge_update(params)
params = p

## Data Loading

### Subject ID Selection

In [None]:
# Locate the full-res and low-res data directory of every selected subject.
subj_dirs: Box = Box()

selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
# Randomly draw `params.n_subjs` IDs without replacement and warn whenever that is
# fewer than the full list.
sampled_subjs = random.sample(selected_ids, params.n_subjs)
if len(selected_ids) > len(sampled_subjs):
    warnings.warn(
        f"WARNING: Sub-selecting {len(sampled_subjs)}/{len(selected_ids)} "
        "participants for dev and debugging. "
        f"Subj IDs selected: {sampled_subjs}"
    )
# ### A nested warning! For debugging only.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])
# ###
##

# Process the subjects in natural sort order.
selected_subjs = natsorted(sampled_subjs)

find_unique = pitn.utils.system.get_file_glob_unique
for subj_id in selected_subjs:
    subj_pattern = f"*{subj_id}*"
    subj_dirs[subj_id] = Box(
        fr=find_unique(params.data.fr_dir, subj_pattern),
        lr=find_unique(params.data.lr_dir, subj_pattern),
    )
    # Both resolutions must be present on disk.
    for res in ("fr", "lr"):
        assert subj_dirs[subj_id][res].exists()
ppr(subj_dirs)

{'103010': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-103010'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-103010')},
 '135528': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-135528'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-135528')},
 '140117': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-140117'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-140117')},
 '141422': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/sub-141422'),
            'lr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-2.50mm/sub-141422')},
 '156637': {'fr': PosixPath('/srv/tmp/data/pitn/hcp/derivatives/diqt/mean-downsample/scale-1.25mm/su

### Loading and Preprocessing

In [None]:
# Objects shared by every subject during dataset loading.

# Reader for NIFTI files, built with `as_closest_canonical=True`.
nib_reader = monai.data.NibabelReader(as_closest_canonical=True)

# HR -> LR patch coordinate conversion function, keyed by the LR volume's name.
_lr_dti_coords_fn = functools.partial(
    pitn.coords.transform.int_downscale_patch_idx,
    downscale_factor=params.data.downsampled_by_factor,
    downscale_patch_shape=params.train.in_patch_size,
)
fr2lr_patch_coords_fn = {"lr_dti": _lr_dti_coords_fn}

# Kwargs for the patches dataset (the _VolPatchDataset class) of the HR volumes.
# The HR patch shape is the LR input patch shape scaled by the downsampling factor and
# floored to integers.
_fr_patch_shape = (
    np.asarray(params.train.in_patch_size) * params.data.downsampled_by_factor
)
patch_kwargs = {
    "patch_shape": tuple(np.floor(_fr_patch_shape).astype(int)),
    "stride": 1,
    "meta_keys_to_patch_index": {"dti", "mask", params.data.anat_type},
    "mask_name": "mask",
}


def fix_downsample_shape_errors(
    fr_vol: torch.Tensor, fr_affine: torch.Tensor, target_spatial_shape: tuple
):
    """Crop or pad a full-res volume so its spatial shape matches a target shape.

    Parameters
    ----------
    fr_vol
        Volume whose dims after the first (``fr_vol.shape[1:]``) are compared
        against the target shape.
    fr_affine
        Affine matrix that accompanies ``fr_vol``.
    target_spatial_shape
        Desired spatial shape; anything ``np.asarray`` accepts.

    Returns
    -------
    tuple
        ``(volume, affine)``. Both inputs are returned unchanged when the shapes
        already agree; otherwise the cropped/padded volume and the affine
        reported by torchio are returned.
    """
    wanted_shape = tuple(np.asarray(target_spatial_shape))
    # Nothing to fix when the spatial dims already match.
    if fr_vol.shape[1:] == wanted_shape:
        return fr_vol, fr_affine

    # torchio objects are used because they fix the affine matrix, too. The volume is
    # flipped before the transform (and unflipped afterwards) so that padding lands on
    # the right/top/furthest side of each dimension first, before the
    # left/bottom/closest.
    spatial_dims = [1, 2, 3]
    image = torchio.ScalarImage(tensor=fr_vol.flip(spatial_dims), affine=fr_affine)
    image = torchio.transforms.CropOrPad(target_spatial_shape, 0, copy=False)(image)
    return image["data"].flip(spatial_dims), image["affine"]


def orient_to_viz(vol, affine):
    """Rotate a volume for visualization and rotate its affine to go with it.

    The data are rotated by 90 degrees in the plane of axes (1, 3), then by 180
    degrees in the plane of axes (2, 3). A tensor input is converted to numpy for
    the rotation and converted back with ``.to(vol)``; any other input is handed
    to ``np.rot90`` as-is.

    Returns a ``(rotated_vol, new_affine)`` pair, where ``new_affine`` is the
    composed rotation matrix multiplied onto ``affine`` from the left.
    """
    was_tensor = torch.is_tensor(vol)
    arr = vol.detach().cpu().numpy() if was_tensor else vol
    arr = np.rot90(arr, k=1, axes=(1, 3))
    arr = np.rot90(arr, k=2, axes=(2, 3))
    if was_tensor:
        # Copy the rotated result before handing it back to torch.
        arr = torch.from_numpy(np.copy(arr)).to(vol)

    # Adjust the affine matrix.
    # 90 degree rot around the second axis, composed with a 180 degree rot around the
    # first axis.
    quat_y90 = nib.quaternions.angle_axis2quat(np.pi / 2, [0, 1, 0])
    quat_x180 = nib.quaternions.angle_axis2quat(np.pi, [1, 0, 0])
    rotation = nib.quaternions.quat2mat(nib.quaternions.mult(quat_y90, quat_x180))

    # Embed the rotation in a matrix shaped like `affine`, with a 1 in the last
    # diagonal entry.
    rot_affine = np.zeros_like(affine)
    rot_affine[:-1, :-1] = rotation
    rot_affine[-1, -1] = 1.0

    return arr, rot_affine @ affine

#### Data Loading

In [None]:
# Import and organize all data.
# For every subject: read the LR DTI, FR DTI, mask, and anatomical volumes, fix FR
# shapes against the upscaled LR shape, re-orient everything for visualization,
# optionally apply noise correction and scaling (only when the matching keys exist in
# `params.data`), and wrap the result in a `pitn.data.SubjSesDataset`.
subj_data: dict = dict()

# Only these metadata entries are kept from the reader's metadata dict.
meta_keys_to_keep = {"affine", "original_affine"}
# NOTE(review): `subj_dir` is unused in the loop body; the directories are looked up
# again through `subj_dirs[subj_id]`.
for subj_id, subj_dir in subj_dirs.items():

    data = dict()
    data["subj_id"] = subj_id
    fr_subj_dir = subj_dirs[subj_id]["fr"]
    lr_subj_dir = subj_dirs[subj_id]["lr"]
    data["fr_subj_dir"] = fr_subj_dir
    data["lr_subj_dir"] = lr_subj_dir

    # Low-resolution DTI.
    lr_dti_f = pitn.utils.system.get_file_glob_unique(
        lr_subj_dir, params.data.dti_fname_pattern
    )
    im = nib_reader.read(lr_dti_f)
    lr_dti, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    lr_dti = torch.from_numpy(lr_dti)
    lr_dti, meta["affine"] = orient_to_viz(lr_dti, meta["affine"])
    data["lr_dti"] = lr_dti
    # Store raw LR DTI for cubic spline comparisons.
    # NOTE(review): this is the same tensor object as data["lr_dti"], not a copy; it
    # stays "raw" only if later processing does not modify it in place - confirm.
    data["raw_lr_dti"] = lr_dti
    data["lr_dti_meta_dict"] = meta

    # May need to handle shape errors when re-upscaling back from LR to HR.
    # The target FR spatial shape is the LR spatial shape times the upscale factor,
    # floored to integers.
    lr_dti_shape = np.asarray(lr_dti.shape[1:])
    target_fr_shape = np.floor(lr_dti_shape * params.net.upscale_factor).astype(int)

    # Full-resolution images/volumes.
    # DTI.
    dti_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.dti_fname_pattern
    )
    im = nib_reader.read(dti_f)
    dti, meta = nib_reader.get_data(im)
    dti = torch.from_numpy(dti)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    dti, meta["affine"] = fix_downsample_shape_errors(
        dti, meta["affine"], target_fr_shape
    )
    dti, meta["affine"] = orient_to_viz(dti, meta["affine"])
    data["dti"] = dti
    # Store raw DTI for validation and testing.
    # NOTE(review): same object as data["dti"] at this point; data["dti"] is only
    # rebound (not mutated) by the correction/scaling branches below, as far as this
    # cell shows - confirm the pitn helpers do not work in place.
    data["raw_dti"] = dti
    data["dti_meta_dict"] = meta

    # Diffusion mask.
    mask_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.mask_fname_pattern
    )
    im = nib_reader.read(mask_f)
    mask, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    mask = torch.from_numpy(mask)
    # Add channel dim if not available.
    if mask.ndim == 3:
        mask = mask[
            None,
        ]
    mask, meta["affine"] = fix_downsample_shape_errors(
        mask, meta["affine"], target_fr_shape
    )
    mask, meta["affine"] = orient_to_viz(mask, meta["affine"])
    mask = mask.bool()
    data["mask"] = mask
    data["mask_meta_dict"] = meta

    # Anatomical/structural volume.
    anat_f = pitn.utils.system.get_file_glob_unique(
        fr_subj_dir, params.data.anat_fname_pattern
    )
    im = nib_reader.read(anat_f)
    anat, meta = nib_reader.get_data(im)
    meta = {k: meta[k] for k in meta_keys_to_keep}
    anat = torch.from_numpy(anat)
    # Add channel dim if not available.
    if anat.ndim == 3:
        anat = anat[
            None,
        ]
    anat, meta["affine"] = fix_downsample_shape_errors(
        anat, meta["affine"], target_fr_shape
    )
    anat, meta["affine"] = orient_to_viz(anat, meta["affine"])
    # The anatomical volume is stored under its type name (e.g. "t2w").
    data[params.data.anat_type] = anat
    data[params.data.anat_type + "_meta_dict"] = meta

    # plt.clf()
    # dti = data['dti']
    # fig, axs = plt.subplots(nrows=3, dpi=110)
    # plt.title(f"Subject {subj_id}")
    # axs[0].plot(dti.view(6, 144, -1).max(2).values.numpy().T)
    # axs[1].plot(dti.permute(0, 2, 1, 3).contiguous().view(6, 174, -1).max(2).values.numpy().T)
    # axs[2].plot(dti.permute(0, 3, 1, 2).contiguous().view(6, 144, -1).max(2).values.numpy().T)
    # plt.show(close=True)

    # Consider this as the "noise correction" step to have more understandable,
    # stable, and consistent metric results. Otherwise, metrics can change by orders of
    # magnitude for no good reason!
    # This branch only runs when the (currently commented-out) parameter
    # `edge_correction_max_vox_to_change` has been set in `params.data`.
    if "edge_correction_max_vox_to_change" in params.data:

        correct_dti = pitn.data.norm.correct_edge_noise_with_median(
            data["dti"],
            data["mask"],
            max_num_vox_to_change=params.data.edge_correction_max_vox_to_change,
            erosion_st_elem=skimage.morphology.ball(4),
            median_st_elem=skimage.morphology.cube(2),
        )
        print("DTI Correction Mean Abs. Error Per Tensor Component: ")
        abs_diff = torch.abs(data["dti"] - correct_dti)
        # Sum of absolute differences per component, divided by the number of mask
        # voxels.
        # NOTE(review): the 6 is hard-coded here rather than using params.n_channels.
        mae = abs_diff.view(6, -1).sum(1)
        mae = mae / data["mask"].sum()
        mae_str = ""

        # Build one report line per tensor component.
        for c in range(params.n_channels):
            mae_v = mae.detach().cpu()[c]
            # iqr = torch.quantile(data["dti"][c][data['mask'][0]], torch.as_tensor([0.25, 0.75]))
            # iqr = torch.abs(iqr[1] - iqr[0]).item()
            # Count voxels whose value changed by more than 1e-13.
            num_changed = (abs_diff[c] > 1e-13).sum().item()
            # Median of the uncorrected component inside the mask, for comparison.
            med = torch.median(data["dti"][c][data["mask"][0]]).item()
            s = (
                "\t"
                + str(mae_v.item())
                + f" with {num_changed} changes vs. median "
                + str(med)
                + "\n"
            )
            mae_str = mae_str + s

        print(mae_str)
        data["dti"] = correct_dti

    # Alternative correction, used only when `mask_edge_quantile_clamp` is set instead.
    elif "mask_edge_quantile_clamp" in params.data:
        data["dti"] = pitn.data.norm.mask_constrain_clamp(
            data["dti"],
            data["mask"],
            quantile_clamp=params.data.mask_edge_quantile_clamp,
            selection_st_elem=skimage.morphology.ball(2),
        )

        # cheap_lr_mask = F.interpolate(
        #     data["mask"][
        #         None,
        #     ].float(),
        #     size=tuple(data["lr_dti"].shape[1:]),
        #     mode="nearest",
        # ).bool()[0]
        # data["lr_dti"] = pitn.data.norm.mask_constrain_clamp(
        #     data["lr_dti"],
        #     cheap_lr_mask,
        #     quantile_clamp=params.data.mask_edge_quantile_clamp,
        #     selection_st_elem=skimage.morphology.ball(1),
        # )
    # plt.clf()
    # dti = data['dti']
    # fig, axs = plt.subplots(nrows=3, dpi=110)
    # plt.title(f"Subject {subj_id}")
    # axs[0].plot(dti.view(6, 144, -1).max(2).values.numpy().T)
    # axs[1].plot(dti.permute(0, 2, 1, 3).contiguous().view(6, 174, -1).max(2).values.numpy().T)
    # axs[2].plot(dti.permute(0, 3, 1, 2).contiguous().view(6, 144, -1).max(2).values.numpy().T)
    # plt.show(close=True)

    # Perform scaling of input data?
    # Only when `dti_scale_range` is set; the FR and LR DTIs each get their own scaler,
    # which is stored next to the scaled volume.
    if "dti_scale_range" in params.data:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.dti_scale_range[0],
            params.data.dti_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=params.n_channels,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(data["dti"], stateful=True, keep_orig=False)
        data["dti"] = scaled
        data["dti_scaler"] = scaler

        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.dti_scale_range[0],
            params.data.dti_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=params.n_channels,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(data["lr_dti"], stateful=True, keep_orig=False)
        data["lr_dti"] = scaled
        data["lr_dti_scaler"] = scaler

    # Perform scaling of input data?
    # Same as above for the single-channel anatomical volume, when `anat_scale_range`
    # is set.
    if "anat_scale_range" in params.data:
        scaler = pitn.data.norm.DTIMinMaxScaler(
            params.data.anat_scale_range[0],
            params.data.anat_scale_range[1],
            quantile_low=params.data.scale_to_quantiles[0],
            quantile_high=params.data.scale_to_quantiles[1],
            dim=(1, 2, 3),
            channel_size=1,
            clip=params.data.clip_to_quantiles,
        )
        scaled = scaler.scale(
            data[params.data.anat_type], stateful=True, keep_orig=False
        )
        # *****Disables anat mode refinement*******
        # scaled = (scaled * 0) + 1
        data[params.data.anat_type] = scaled
        data[params.data.anat_type + "_scaler"] = scaler

    # Split the collected entries into volumes and everything else; the remaining
    # entries (raw volumes, meta dicts, scalers, ids, dirs) are passed as keyword
    # arguments to the dataset.
    vol_names = {"dti", "mask", "lr_dti", params.data.anat_type}
    metadata_names = set(data.keys()) - vol_names
    vol_d = {k: data[k] for k in vol_names}
    meta_d = {k: data[k] for k in metadata_names}

    # Create multi-volume dataset for this subj-session.
    subj_dataset = pitn.data.SubjSesDataset(
        vol_d,
        primary_vol_name="dti",
        special_secondary2primary_coords_fns=fr2lr_patch_coords_fn,
        transform=None,
        primary_patch_kwargs=patch_kwargs,
        **meta_d,
    )
    print("Creating patches")
    # Init the patches dataset.
    # The bare attribute access is intentional; presumably it triggers lazy creation
    # of the patches - verify against SubjSesDataset.
    subj_dataset.patches

    # Finalize this subject.
    subj_data[subj_id] = subj_dataset

    print("=" * 20)

print("===Data Loaded===")

## Threshold Calculation

In [None]:
# Store the upper thresholds found for the DTI components, the eigenvalues of the
# DTIs, and the threshold for the CSF found in the MD volumes.
# Each list receives one entry per subject, in the iteration order of `subj_data`.
dti_component_thresholds = list()
eigvals_thresholds = list()
csf_thresholds = list()

# Figures and threshold files go to a "dti_threshold" directory next to `results_dir`.
output_dir = results_dir.parent / "dti_threshold"
output_dir.mkdir(parents=True, exist_ok=True)

for k_subj in subj_data.keys():
    # Take the first sample of this subject's dataset.
    # NOTE(review): assumes sample 0 carries the full "raw_dti" volume and a "mask" of
    # matching spatial shape - confirm against SubjSesDataset.__getitem__.
    s = subj_data[k_subj][0]
    dti = torch.clone(s["raw_dti"]).detach().cpu().numpy()
    mask = torch.clone(s["mask"]).detach().cpu().numpy()[0].astype(bool)
    # Move the component axis last, build the tensors from their lower-triangular
    # components, then decompose them and compute the mean diffusivity (MD).
    dti_tensor = dipy.reconst.dti.from_lower_triangular(dti.transpose(1, 2, 3, 0))
    eigvals, eigvecs = dipy.reconst.dti.decompose_tensor(dti_tensor)
    md = dipy.reconst.dti.mean_diffusivity(eigvals)

    # Erode the mask with a ball structuring element so that voxels near the mask
    # boundary are excluded from the threshold estimation.
    st_elem = skimage.morphology.ball(8)
    mask = skimage.morphology.binary_erosion(mask, st_elem)
    # Focus more removal of the superior of the brain, but leave rest of sides more-or-less
    # unchanged.
    # Indices 10 and above along the structuring element's first axis are zeroed,
    # which makes this second erosion one-sided along that axis.
    st_elem = skimage.morphology.ball(10)
    st_elem[10:, :, :] = 0
    mask = skimage.morphology.binary_erosion(mask, st_elem)

    # Li threshold over the MD values inside the eroded mask; voxels above it are
    # taken as CSF candidates.
    selected_md = md * mask
    thresh = skimage.filters.threshold_li(
        selected_md[mask], tolerance=1e-6, initial_guess=0.0013
    )
    print(thresh)
    csf_thresholds.append(thresh)
    thresh_md = selected_md > thresh
    # Constrain the CSF selection to larger objects.
    thresh_md = skimage.morphology.remove_small_objects(thresh_md, 12**3, 2)
    csf_mask = thresh_md

    print(mask.sum())
    # Save and show slices of the thresholded mask, the masked MD, and the MD
    # restricted to the CSF selection.
    fig = plt.figure(dpi=120)
    pitn.viz.plot_vol_slices(thresh_md, md * mask, md * csf_mask, cmap="gray", fig=fig)
    plt.title(k_subj)
    plt.savefig(output_dir / f"sub-{k_subj}_thresh_md__mask_md__csf_md.png")
    plt.show()

    # Select some multiple of some quantile of CSF as the actual "cutoff"
    # (These two constants do not change between iterations.)
    cutoff_quantile = 0.99
    cutoff_factor = 1.1

    # Select CSF over all tensor components.
    # One cutoff per tensor component: the quantile is taken over the selected CSF
    # voxels (axis 0) and then multiplied by the factor.
    select_dti = dti.transpose(1, 2, 3, 0)[csf_mask]
    component_thresholds = (
        np.quantile(select_dti, cutoff_quantile, axis=0) * cutoff_factor
    )
    dti_component_thresholds.append(component_thresholds)
    plt.figure(dpi=120)
    # NOTE(review): `hist` is not used afterwards in this cell.
    hist = sns.histplot(select_dti, legend=True, bins=100)
    # Mark every component cutoff with a vertical line spanning the plot height.
    plt.vlines(component_thresholds, 0, plt.gca().get_ylim()[1], color="black", lw=0.8)
    plt.title(k_subj)
    plt.savefig(output_dir / f"sub-{k_subj}_dti_component_hist.png")
    plt.show()
    print(component_thresholds)

    # Select CSF over all eigenvalues components.
    # Same quantile-times-factor cutoff, computed per eigenvalue.
    select_eigvals = eigvals[csf_mask]
    eigval_cutoff = np.quantile(select_eigvals, cutoff_quantile, axis=0) * cutoff_factor
    eigvals_thresholds.append(eigval_cutoff)
    plt.figure(dpi=120)
    sns.histplot(select_eigvals, legend=True, bins=100)
    plt.vlines(eigval_cutoff, 0, plt.gca().get_ylim()[1], color="black", lw=0.8)
    plt.title(k_subj)
    plt.savefig(output_dir / f"sub-{k_subj}_eigenval_hist.png")
    plt.show()
    print(eigval_cutoff)

# Turn the per-subject lists into arrays (one row/entry per subject).
dti_component_thresholds = np.stack(dti_component_thresholds)
eigvals_thresholds = np.stack(eigvals_thresholds)
csf_thresholds = np.asarray(csf_thresholds)

# Write each array as a plain-text file in the output directory.
for out_fname, out_arr in (
    ("dti_component_cutoffs_over_subjs.txt", dti_component_thresholds),
    ("eigval_cutoffs_over_subjs.txt", eigvals_thresholds),
    ("md_csf_thresholds_over_subjs.txt", csf_thresholds),
):
    np.savetxt(output_dir / out_fname, out_arr)