# Continuous-Space Super-Resolution of fODFs in Diffusion MRI

Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

## Imports & Setup

In [None]:
# Imports
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 2

# imports
import collections
import copy
import datetime
import functools
import inspect
import io
import itertools
import math
import os
import pathlib
import pdb
import random
import shutil
import subprocess
import sys
import tempfile
import time
import typing
import warnings
import zipfile
from pathlib import Path
from pprint import pprint as ppr

import aim
import dotenv
import einops

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import monai

# Data management libraries.
import nibabel as nib
import nibabel.processing

# Computation & ML libraries.
import numpy as np
import pandas as pd
import seaborn as sns
import skimage
import torch
import torch.nn.functional as F
import torchinfo
from box import Box
from icecream import ic
from natsort import natsorted

from lightning_fabric.fabric import Fabric

import pitn

# Matplotlib defaults for every figure in this notebook: automatic tight layout,
# an opaque white figure background, and a grayscale colormap for images.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
        "image.cmap": "gray",
    }
)

# Keep printed ndarrays/tensors compact: summarize past 100 elements, wrap at 88
# columns, and avoid scientific notation.
_print_opts = dict(threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_opts)
torch.set_printoptions(sci_mode=False, **_print_opts)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv

# Command to be run in direnv's context. It prints out all environment variables
# defined in the subprocess. The command is an argument list that is executed
# without a shell, so a working directory containing spaces or shell
# metacharacters is passed through verbatim instead of being word-split or
# interpreted.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
# Stays empty when direnv could not be run or reported a failure.
proc_out = ""
try:
    proc = subprocess.run(
        command, stdout=subprocess.PIPE, cwd=os.getcwd(), check=False
    )
except OSError as e:
    # direnv is missing or not executable; this step is best-effort, so only warn.
    proc = None
    print(f"WARNING: Could not run direnv; environment variables not updated: {e}")
else:
    if proc.returncode == 0:
        # Store and format the subprocess' output.
        proc_out = proc.stdout.strip().decode("utf-8")
        # Use python-dotenv to load the environment variables by using the output
        # of 'direnv exec ...' as a 'dummy' .env file.
        dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True)
    else:
        # Do not load partial output from a failed direnv invocation.
        print(
            f"WARNING: direnv exited with status {proc.returncode}; "
            "environment variables not updated"
        )

In [None]:
# torch setup
# Select the default compute device: one CUDA GPU when available, otherwise the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    # Only a single GPU is chosen as the default; multiple GPUs may be used for
    # training later. The index can be set through the CUDA_PYTORCH_DEVICE_IDX
    # environment variable and falls back to 0.
    dev_idx = int(os.environ.get("CUDA_PYTORCH_DEVICE_IDX", 0))
    device = torch.device(f"cuda:{dev_idx}")
    print("CUDA Device IDX ", dev_idx)
    torch.cuda.set_device(device)
    print("CUDA Current Device ", torch.cuda.current_device())
    print("CUDA Device properties: ", torch.cuda.get_device_properties(device))
    # Allow TF32 on matmul; this flag defaults to False in PyTorch 1.12 and later.
    # Details:
    # <https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices>
    torch.backends.cuda.matmul.allow_tf32 = True

    if torch.backends.cudnn.enabled:
        # Let cuDNN benchmark convolution algorithms to pick the fastest one.
        torch.backends.cudnn.benchmark = True
        print("CuDNN convolution optimization enabled.")
        # Allow TF32 on cuDNN; this flag defaults to True.
        torch.backends.cudnn.allow_tf32 = True

# To force CPU execution, override the selection here:
# device = torch.device('cpu')
print(device)

## Experiment & Parameters Setup

In [None]:
# Experiment parameters. `default_box=True` lets nested keys such as
# `p.test.subj_ids` be assigned without creating the intermediate Box first.
p = Box(default_box=True)
# Experiment defaults, can be overridden in a config file.

p.results_dir = "/data/srv/outputs/pitn/results/runs"
p.tmp_results_dir = "/data/srv/outputs/pitn/results/tmp"
p.test.subj_ids = ["299154"]
# Checkpoint (state dict) file holding the trained model weights.
p.model_weight_f = str(
    Path(p.tmp_results_dir) / "2023-02-09T21_09_47/state_dict_epoch_174_step_35000.pt"
)
p.target_vox_size = 0.374
###############################################
# Network/model parameters.
p.encoder = dict(
    interior_channels=80,
    out_channels=128,
    n_res_units=3,
    n_dense_units=3,
    activate_fn="relu",
)
p.decoder = dict(
    context_v_features=128,
    # The decoder consumes the encoder's output channels.
    in_features=p.encoder.out_channels,
    out_features=45,
    m_encode_num_freqs=36,
    sigma_encode_scale=3.0,
)


# If a config file exists, override the defaults with those values.
# The file is taken from the PITN_CONFIG environment variable when set, otherwise
# it is looked up as `config.*` in the current directory. Loading is best-effort:
# any failure leaves the defaults in place and is reported with its reason.
try:
    if "PITN_CONFIG" in os.environ:
        config_fname = Path(os.environ["PITN_CONFIG"])
    else:
        config_fname = pitn.utils.system.get_file_glob_unique(Path("."), r"config.*")
    f_type = config_fname.suffix.casefold()
    if f_type in {".yaml", ".yml"}:
        f_params = Box.from_yaml(filename=config_fname)
    elif f_type == ".json":
        f_params = Box.from_json(filename=config_fname)
    elif f_type == ".toml":
        f_params = Box.from_toml(filename=config_fname)
    else:
        raise RuntimeError(
            f"Unsupported config file type '{f_type}' for '{config_fname}'"
        )

    p.merge_update(f_params)

# Catch only Exception so that KeyboardInterrupt/SystemExit still propagate.
except Exception as e:
    print(f"WARNING: Config file not loaded: {type(e).__name__}: {e}")


# Remove the default_box behavior now that params have been fully read in.
_p = Box(default_box=False)
_p.merge_update(p)
p = _p

## Data Loading

In [None]:
# Locations of the HCP inputs: full-resolution volumes and fODFs, plus their
# 2.00mm downsampled counterparts.
_hcp_outputs_dir = Path("/data/srv/outputs/pitn/hcp")
_hcp_low_res_dir = _hcp_outputs_dir / "downsample" / "scale-2.00mm"

hcp_full_res_data_dir = Path("/data/srv/data/pitn/hcp")
hcp_full_res_fodf_dir = _hcp_outputs_dir / "full-res" / "fodf"
hcp_low_res_data_dir = _hcp_low_res_dir / "vol"
hcp_low_res_fodf_dir = _hcp_low_res_dir / "fodf"

# Fail early if any of the data directories is missing. One assert per directory
# keeps the failing path visible in the traceback.
assert hcp_full_res_data_dir.exists()
assert hcp_full_res_fodf_dir.exists()
assert hcp_low_res_data_dir.exists()
assert hcp_low_res_fodf_dir.exists()

### Create Patch-Based Training Dataset

### Validation & Test Datasets

In [None]:
# Build the test dataset while recording any warnings raised during construction,
# so they can be printed together afterwards.
with warnings.catch_warnings(record=True) as warn_list:

    # File-path level dataset for the test subjects.
    test_paths_dataset = pitn.data.datasets.HCPfODFINRDataset(
        subj_ids=p.test.subj_ids,
        dwi_root_dir=hcp_full_res_data_dir,
        fodf_root_dir=hcp_full_res_fodf_dir,
        lr_dwi_root_dir=hcp_low_res_data_dir,
        lr_fodf_root_dir=hcp_low_res_fodf_dir,
    )
    # Cache the pre-sampling transform results.
    # NOTE: an uncached alternative would pass this same transform directly to
    # HCPfODFINRDataset through its `transform` argument and skip CacheDataset.
    pre_sample_tf = test_paths_dataset.default_pre_sample_tf(
        0, skip_sample_mask=True
    )
    cached_test_dataset = monai.data.CacheDataset(
        test_paths_dataset,
        transform=pre_sample_tf,
        copy_cache=False,
        num_workers=2,
    )
    # Whole-volume view over the cached samples.
    whole_vol_tf = pitn.data.datasets.HCPfODFINRWholeVolDataset.default_tf()
    test_dataset = pitn.data.datasets.HCPfODFINRWholeVolDataset(
        cached_test_dataset,
        transform=whole_vol_tf,
    )

# Report the recorded warnings, dropping blank lines from the formatted text.
print("=" * 10)
print("Warnings caught:")
formatted_warnings = "\n".join(
    warnings.formatwarning(
        w.message, w.category, w.filename, w.lineno, w.file, w.line
    )
    for w in warn_list
)
ws = "\n".join(
    line for line in formatted_warnings.splitlines() if line.strip()
)
print(ws, flush=True)
print("=" * 10)

## Evaluation

In [None]:
# Create a timestamped scratch directory for this evaluation run. Colons in the
# ISO timestamp are replaced with underscores because many programs don't like
# having colons ':' in a filename.
ts = datetime.datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
tmp_res_dir = Path(p.tmp_results_dir) / f"{ts}_super_res_odf_test"
tmp_res_dir.mkdir(parents=True)

In [None]:
# Iterate over the test volumes one at a time, in dataset order, loading in the
# main process (no worker subprocesses) with pinned host memory.
_test_loader_kwargs = dict(
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
)
test_dataloader = monai.data.DataLoader(test_dataset, **_test_loader_kwargs)

### INR

In [None]:
# Encoding model
class INREncoder(torch.nn.Module):
    """3D convolutional encoder producing context features for the INR decoder.

    The network is: two input convolutions, a densely-connected cascade of
    residual blocks (from ``pitn.nn.layers``), then a stack of output
    convolutions that includes a 2x2x2 stride-1 average pooling.

    Args:
        in_channels: Number of channels of the input volume.
        interior_channels: Number of channels used inside the cascade.
        out_channels: Number of channels of the encoded output.
        n_res_units: Number of residual blocks inside each dense unit.
        n_dense_units: Number of dense units inside the top-level cascade.
        activate_fn: Activation class (called with no arguments to create the
            layer), or a string key into ``pitn.utils.torch_lookups.activation``.
    """

    def __init__(
        self,
        in_channels: int,
        interior_channels: int,
        out_channels: int,
        n_res_units: int,
        n_dense_units: int,
        activate_fn,
    ):
        super().__init__()

        # Constructor arguments exactly as given, kept so the module can be
        # re-created later.
        self.init_kwargs = {
            "in_channels": in_channels,
            "interior_channels": interior_channels,
            "out_channels": out_channels,
            "n_res_units": n_res_units,
            "n_dense_units": n_dense_units,
            "activate_fn": activate_fn,
        }

        self.in_channels = in_channels
        self.interior_channels = interior_channels
        self.out_channels = out_channels

        # Resolve an activation given by name into its class.
        activation_cls = (
            pitn.utils.torch_lookups.activation[activate_fn]
            if isinstance(activate_fn, str)
            else activate_fn
        )
        self._activation_fn_init = activation_cls
        # One activation instance is shared by the Sequentials and forward().
        self.activate_fn = activation_cls()

        # Input convolutions; "same" padding maintains the input shape.
        self.pre_conv = torch.nn.Sequential(
            self._same_conv3d(self.in_channels, self.in_channels, 1),
            self.activate_fn,
            self._same_conv3d(self.in_channels, self.interior_channels, 3),
        )

        # Densely-connected cascading layers: every dense unit wraps
        # `n_res_units` residual blocks, and `n_dense_units` dense units are
        # themselves wrapped into one top-level cascade.
        dense_units = [
            pitn.nn.layers.DenseCascadeBlock3d(
                self.interior_channels,
                *[
                    pitn.nn.layers.ResBlock3dNoBN(
                        self.interior_channels,
                        kernel_size=3,
                        activate_fn=activation_cls,
                        padding="same",
                        padding_mode="reflect",
                    )
                    for _ in range(n_res_units)
                ],
            )
            for _ in range(n_dense_units)
        ]
        self.cascade = pitn.nn.layers.DenseCascadeBlock3d(
            self.interior_channels, *dense_units
        )

        # Output convolutions. The one-sided replication pad followed by the
        # kernel-2/stride-1 average pool leaves the spatial size unchanged.
        self.post_conv = torch.nn.Sequential(
            self._same_conv3d(self.interior_channels, self.interior_channels, 5),
            self.activate_fn,
            self._same_conv3d(self.interior_channels, self.out_channels, 3),
            self.activate_fn,
            torch.nn.ReplicationPad3d((1, 0, 1, 0, 1, 0)),
            torch.nn.AvgPool3d(kernel_size=2, stride=1),
            self._same_conv3d(self.out_channels, self.out_channels, 1),
        )

    @staticmethod
    def _same_conv3d(n_in: int, n_out: int, kernel_size: int) -> torch.nn.Conv3d:
        """Create a Conv3d with "same" padding and reflect padding mode."""
        return torch.nn.Conv3d(
            n_in,
            n_out,
            kernel_size=kernel_size,
            padding="same",
            padding_mode="reflect",
        )

    def forward(self, x: torch.Tensor):
        """Encode the input volume `x` into `out_channels` feature channels."""
        feats = self.activate_fn(self.pre_conv(x))
        feats = self.activate_fn(self.cascade(feats))
        return self.post_conv(feats)


class ReducedDecoder(torch.nn.Module):
    """Coordinate-conditioned MLP decoder over an encoded low-res feature volume.

    For every query coordinate, the decoder gathers the context feature vectors
    of the (up to) 8 low-res voxels surrounding the query point, runs the MLP
    once per surrounding voxel, and blends the 8 predictions with normalized
    inverse-distance weights.

    Args:
        context_v_features: Width of the context features used inside the MLP.
        out_features: Number of output channels predicted per query coordinate.
        m_encode_num_freqs: Number of frequencies passed to the Fourier position
            encoding.
        sigma_encode_scale: Scale passed to the Fourier position encoding.
        in_features: Optional number of channels of the incoming context volume.
            When given, a linear layer maps `in_features` to
            `context_v_features` before the MLP blocks.
    """

    # Small offset used to break ties at voxel boundaries, and as the lower bound
    # on distances when computing inverse-distance weights.
    TARGET_COORD_EPSILON = 1e-7

    def __init__(
        self,
        context_v_features: int,
        out_features: int,
        m_encode_num_freqs: int,
        sigma_encode_scale: float,
        in_features=None,
    ):
        super().__init__()
        # Constructor arguments as given, kept so the module can be re-created.
        self.init_kwargs = dict(
            context_v_features=context_v_features,
            out_features=out_features,
            m_encode_num_freqs=m_encode_num_freqs,
            sigma_encode_scale=sigma_encode_scale,
            in_features=in_features,
        )

        # Determine the number of input features needed for the MLP.
        # The coordinate features concatenated in `sub_grid_forward` are
        # 1) coords of the context (low-res) voxel used in this forward pass
        # 2) coords of the query (high-res target) voxel
        # 3) Fourier encoding of the normalized relative coords between the two
        self.context_v_features = context_v_features
        self.ndim = 3
        self.m_encode_num_freqs = m_encode_num_freqs
        self.sigma_encode_scale = torch.as_tensor(sigma_encode_scale)
        self.n_encode_features = self.ndim * 2 * self.m_encode_num_freqs
        self.n_coord_features = 2 * self.ndim + self.n_encode_features
        self.internal_features = self.context_v_features + self.n_coord_features

        self.in_features = in_features
        self.out_features = out_features

        # "Swish" function, recommended in MeshFreeFlowNet
        activate_cls = torch.nn.SiLU
        self.activate_fn = activate_cls(inplace=True)
        # Optional resizing linear layer, if the input size should be different than
        # the hidden layer size.
        if self.in_features is not None:
            self.lin_pre = torch.nn.Linear(self.in_features, self.context_v_features)
        else:
            self.lin_pre = None
        # No input normalization is used; `sub_grid_forward` skips it when None.
        self.norm_pre = None

        # Internal hidden layers are two res MLPs.
        self.internal_res_repr = torch.nn.ModuleList(
            [
                pitn.nn.inr.SkipMLPBlock(
                    n_context_features=self.context_v_features,
                    n_coord_features=self.n_coord_features,
                    n_dense_layers=3,
                    activate_fn=activate_cls,
                )
                for _ in range(2)
            ]
        )
        self.lin_post = torch.nn.Linear(self.context_v_features, self.out_features)

    @classmethod
    def _inv_distance(cls, query_coord, context_coord):
        """Return the inverse L2 distance between coordinates over dim 1.

        The distance is clamped below by `TARGET_COORD_EPSILON`. Without the
        clamp, a query coordinate that coincides exactly with a context grid
        coordinate gives 1/0 = inf, and the normalized weight inf/inf is NaN.

        Args:
            query_coord: Tensor with the coordinate components along dim 1.
            context_coord: Tensor broadcastable against `query_coord`.

        Returns:
            Tensor of inverse distances with dim 1 reduced to size 1.
        """
        dist = torch.linalg.vector_norm(
            query_coord - context_coord, ord=2, dim=1, keepdim=True
        )
        return 1 / torch.clamp_min(dist, cls.TARGET_COORD_EPSILON)

    def encode_relative_coord(self, coords):
        """Fourier-encode a `b d x y z` coordinate volume, keeping its layout."""
        c = einops.rearrange(coords, "b d x y z -> (b x y z) d")
        sigma = self.sigma_encode_scale.expand_as(c).to(c)[..., None]
        encode_pos = pitn.nn.inr.fourier_position_encoding(
            c, sigma_scale=sigma, m_num_freqs=self.m_encode_num_freqs
        )

        encode_pos = einops.rearrange(
            encode_pos,
            "(b x y z) d -> b d x y z",
            x=coords.shape[2],
            y=coords.shape[3],
            z=coords.shape[4],
        )
        return encode_pos

    def sub_grid_forward(
        self,
        context_val,
        context_coord,
        query_coord,
        context_vox_size,
        query_vox_size,
        return_rel_context_coord=False,
    ):
        """Run the MLP for one context voxel per query coordinate.

        Args:
            context_val: Context features, laid out `b c x y z`.
            context_coord: Coordinates of the context voxels, `b c x y z`.
            query_coord: Coordinates of the query points, `b c x y z`.
            context_vox_size: Voxel size of the context grid, broadcastable
                against the coordinate volumes.
            query_vox_size: Currently unused; kept for interface compatibility.
            return_rel_context_coord: If True, also return the (clamped) relative
                coordinates between context and query.

        Returns:
            The prediction laid out `b c x y z`, or a tuple
            `(prediction, rel_context_coord)` when `return_rel_context_coord`.
        """
        # Take relative coordinate difference between the current context
        # coord and the query coord.
        rel_context_coord = torch.clamp_min(
            context_coord - query_coord,
            (-context_vox_size / 2) + self.TARGET_COORD_EPSILON,
        )
        # Also normalize to [0, 1)
        # Coordinates are located in the center of the voxel. By the way
        # the context vector is being constructed surrounding the query
        # coord, the query coord is always within 1.5 x vox_size of the
        # context (low-res space) coordinate. So, subtract the
        # batch-and-channel-wise minimum, and divide by the known upper
        # bound.
        rel_norm_context_coord = (
            rel_context_coord
            - torch.amin(rel_context_coord, dim=(2, 3, 4), keepdim=True)
        ) / (1.5 * context_vox_size)
        assert (rel_norm_context_coord >= 0).all() and (
            rel_norm_context_coord < 1.0
        ).all()
        encoded_rel_norm_context_coord = self.encode_relative_coord(
            rel_norm_context_coord
        )

        # Perform forward pass of the MLP.
        if self.norm_pre is not None:
            context_val = self.norm_pre(context_val)
        context_feats = einops.rearrange(context_val, "b c x y z -> (b x y z) c")

        coord_feats = (
            context_coord,
            query_coord,
            encoded_rel_norm_context_coord,
        )
        coord_feats = torch.cat(coord_feats, dim=1)
        # Remember the volume layout to restore it after the per-point MLP.
        spatial_layout = {
            "b": coord_feats.shape[0],
            "x": coord_feats.shape[2],
            "y": coord_feats.shape[3],
            "z": coord_feats.shape[4],
        }

        coord_feats = einops.rearrange(coord_feats, "b c x y z -> (b x y z) c")
        x_coord = coord_feats
        sub_grid_pred = context_feats

        if self.lin_pre is not None:
            sub_grid_pred = self.lin_pre(sub_grid_pred)
            sub_grid_pred = self.activate_fn(sub_grid_pred)

        for res_block in self.internal_res_repr:
            sub_grid_pred, x_coord = res_block(sub_grid_pred, x_coord)
        sub_grid_pred = self.lin_post(sub_grid_pred)
        sub_grid_pred = einops.rearrange(
            sub_grid_pred, "(b x y z) c -> b c x y z", **spatial_layout
        )
        if return_rel_context_coord:
            ret = (sub_grid_pred, rel_context_coord)
        else:
            ret = sub_grid_pred
        return ret

    def forward(
        self,
        context_v,
        context_spatial_extent,
        query_vox_size,
        query_coord,
    ) -> torch.Tensor:
        """Predict the output at every query coordinate.

        Args:
            context_v: Encoded context feature volume, indexed as
                `[batch, channel, i, j, k]`.
            context_spatial_extent: Coordinates of the context grid, indexed as
                `[batch, dim, i, j, k]`.
            query_vox_size: Voxel size of the query grid; a 2D tensor is expanded
                with three trailing singleton dims.
            query_coord: Coordinates to predict at, laid out `b dim i j k`.

        Returns:
            Inverse-distance weighted blend of the per-context-voxel predictions,
            laid out `b c x y z` over the query grid.
        """
        if query_vox_size.ndim == 2:
            query_vox_size = query_vox_size[:, :, None, None, None]
        # Context voxel size, taken as the coordinate difference between grid
        # points (1, 1, 1) and (0, 0, 0).
        context_vox_size = torch.abs(
            context_spatial_extent[..., 1, 1, 1] - context_spatial_extent[..., 0, 0, 0]
        )
        context_vox_size = context_vox_size[:, :, None, None, None]

        batch_size = query_coord.shape[0]
        # Construct a grid of nearest indices in context space by sampling a grid of
        # *indices* given the coordinates in mm.
        # The channel dim is just repeated for every
        # channel, so that doesn't need to be in the idx grid.
        nearest_coord_idx = torch.stack(
            torch.meshgrid(
                *[
                    torch.arange(
                        0,
                        context_spatial_extent.shape[i],
                        dtype=context_spatial_extent.dtype,
                        device=context_spatial_extent.device,
                    )
                    for i in (0, 2, 3, 4)
                ],
                indexing="ij",
            ),
            dim=1,
        )

        # Find the nearest grid point, where the batch+spatial dims are the
        # "channels."
        nearest_coord_idx = pitn.nn.inr.weighted_ctx_v(
            encoded_feat_vol=nearest_coord_idx,
            input_space_extent=context_spatial_extent,
            target_space_extent=query_coord,
            reindex_spatial_extents=True,
            sample_mode="nearest",
        ).to(torch.long)
        # Expand along channel dimension for raw indexing.
        nearest_coord_idx = einops.rearrange(
            nearest_coord_idx, "b dim i j k -> dim (b i j k)"
        )
        batch_idx = nearest_coord_idx[0]

        # Use the world coordinates to determine the necessary voxel coordinate
        # offsets such that the offsets enclose the query point.
        # World coordinate in the low-res input grid that is closest to the
        # query coordinate.
        phys_coords_0 = context_spatial_extent[
            batch_idx,
            :,
            nearest_coord_idx[1],
            nearest_coord_idx[2],
            nearest_coord_idx[3],
        ]

        phys_coords_0 = einops.rearrange(
            phys_coords_0,
            "(b x y z) c -> b c x y z",
            b=batch_size,
            c=query_coord.shape[1],
            x=query_coord.shape[2],
            y=query_coord.shape[3],
            z=query_coord.shape[4],
        )
        # Determine the quadrants that the query point lies in relative to the
        # context grid. We only care about the spatial/non-batch coordinates.
        surround_query_point_quadrants = (
            query_coord - self.TARGET_COORD_EPSILON - phys_coords_0
        )
        # 3 x batch_and_spatial_size
        # The signs of the "query coordinate - grid coordinate" should match the
        # direction the indexing should go for the nearest voxels to the query.
        surround_offsets_vox = einops.rearrange(
            surround_query_point_quadrants.sign(), "b dim i j k -> dim (b i j k)"
        ).to(torch.int8)
        del surround_query_point_quadrants

        # Now, find sum of inverse distances to normalize the distance-weighted
        # weight vector for in-place 'linear interpolation.'
        inv_dist_total = torch.zeros_like(phys_coords_0)
        inv_dist_total = (inv_dist_total[:, 0])[:, None]
        surround_offsets_vox_volume_order = einops.rearrange(
            surround_offsets_vox,
            "dim (b i j k) -> b dim i j k",
            b=batch_size,
            i=query_coord.shape[2],
            j=query_coord.shape[3],
            k=query_coord.shape[4],
        )
        for (
            offcenter_indicate_i,
            offcenter_indicate_j,
            offcenter_indicate_k,
        ) in itertools.product((0, 1), (0, 1), (0, 1)):
            # Physical offset (in coordinate units) of this surrounding voxel
            # from the nearest voxel, per dimension.
            phys_coords_offset = torch.ones_like(phys_coords_0)
            phys_coords_offset[:, 0] *= (
                offcenter_indicate_i * surround_offsets_vox_volume_order[:, 0]
            ) * context_vox_size[:, 0]
            phys_coords_offset[:, 1] *= (
                offcenter_indicate_j * surround_offsets_vox_volume_order[:, 1]
            ) * context_vox_size[:, 1]
            phys_coords_offset[:, 2] *= (
                offcenter_indicate_k * surround_offsets_vox_volume_order[:, 2]
            ) * context_vox_size[:, 2]
            phys_coords = phys_coords_0 + phys_coords_offset
            # Clamped inverse distance: stays finite when the query coincides
            # with a context grid coordinate.
            inv_dist_total += self._inv_distance(query_coord, phys_coords)
        # Potentially free some memory here.
        del phys_coords
        del phys_coords_0
        del phys_coords_offset
        del surround_offsets_vox_volume_order

        y_weighted_accumulate = None
        # Build the low-res representation one sub-window voxel index at a time.
        # The indicators specify if the current voxel index that surrounds the
        # query coordinate should be "off the center voxel" or not. If not, then
        # the center voxel (read: no voxel offset from the center) is selected
        # (for that dimension).
        for (
            offcenter_indicate_i,
            offcenter_indicate_j,
            offcenter_indicate_k,
        ) in itertools.product((0, 1), (0, 1), (0, 1)):
            # Rebuild indexing tuple for each element of the sub-window
            i_idx = nearest_coord_idx[1] + (
                offcenter_indicate_i * surround_offsets_vox[0]
            )
            j_idx = nearest_coord_idx[2] + (
                offcenter_indicate_j * surround_offsets_vox[1]
            )
            k_idx = nearest_coord_idx[3] + (
                offcenter_indicate_k * surround_offsets_vox[2]
            )
            context_val = context_v[batch_idx, :, i_idx, j_idx, k_idx]
            context_val = einops.rearrange(
                context_val,
                "(b x y z) c -> b c x y z",
                b=batch_size,
                x=query_coord.shape[2],
                y=query_coord.shape[3],
                z=query_coord.shape[4],
            )
            context_coord = context_spatial_extent[batch_idx, :, i_idx, j_idx, k_idx]
            context_coord = einops.rearrange(
                context_coord,
                "(b x y z) c -> b c x y z",
                b=batch_size,
                x=query_coord.shape[2],
                y=query_coord.shape[3],
                z=query_coord.shape[4],
            )

            sub_grid_pred_ijk = self.sub_grid_forward(
                context_val=context_val,
                context_coord=context_coord,
                query_coord=query_coord,
                context_vox_size=context_vox_size,
                query_vox_size=query_vox_size,
                return_rel_context_coord=False,
            )
            # Initialize the accumulated prediction after finding the
            # output size; easier than trying to pre-compute it.
            if y_weighted_accumulate is None:
                y_weighted_accumulate = torch.zeros_like(sub_grid_pred_ijk)

            # Weigh this cell's prediction by the inverse of the distance
            # from the cell physical coordinate to the true target
            # physical coordinate. Normalize the weight by the
            # "sum of the inverse distances" found before.
            w = self._inv_distance(query_coord, context_coord) / inv_dist_total

            # Accumulate weighted cell predictions to eventually create
            # the final prediction.
            y_weighted_accumulate += w * sub_grid_pred_ijk
            del sub_grid_pred_ijk

        y = y_weighted_accumulate

        return y

In [None]:
def validate_stage(
    fabric,
    encoder,
    decoder,
    val_dataloader,
    step: int,
    epoch: int,
    aim_run,
    val_viz_subj_id,
):
    """Run one validation pass over ``val_dataloader`` and log results to Aim.

    Both networks are put into eval mode for the duration of the pass and are
    returned to the mode they had on entry, even if validation raises part-way
    through.

    For every batch the low-res DWI is encoded, the decoder is evaluated over
    the target coordinates with sliding-window inference, and a masked MSE
    against the target fODF coefficients is recorded. For the visualization
    subject, a slice figure and a per-SH-coefficient error boxplot are tracked
    as Aim images.

    Args:
        fabric: Object providing ``print``; only used for progress messages.
        encoder: Module applied to ``batch_dict["lr_dwi"]`` to produce the
            context features.
        decoder: Module called with the keyword arguments ``query_coord``,
            ``context_v``, ``context_spatial_extent`` and ``query_vox_size``.
        val_dataloader: Iterable of batch dicts with the keys "subj_id",
            "lr_dwi", "lr_extent_acpc", "fodf", "mask", "extent_acpc" and
            "vox_size".
        step: Step index attached to everything tracked in ``aim_run``.
        epoch: Epoch index attached to everything tracked in ``aim_run``.
        aim_run: Object whose ``track`` method receives the metrics and images.
        val_viz_subj_id: Subject ID to visualize. If None, the first subject
            yielded by ``val_dataloader`` is used.

    Returns:
        ``(aim_run, val_viz_subj_id)``; the ID is returned so the caller can
        pass the same visualization subject to later validation runs.

    Raises:
        Any exception from the models, metric, or logging calls is propagated
        unchanged (this includes ``torch.cat`` failing when ``val_dataloader``
        yields no batches); the train/eval modes are restored first.
    """
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            # Set up validation metrics to track for this validation run.
            val_metrics = {"mse": list()}
            for batch_dict in val_dataloader:
                subj_id = batch_dict["subj_id"]
                if len(subj_id) == 1:
                    subj_id = subj_id[0]
                if val_viz_subj_id is None:
                    val_viz_subj_id = subj_id
                x = batch_dict["lr_dwi"]
                x_coords = batch_dict["lr_extent_acpc"]
                y = batch_dict["fodf"]
                y_mask = batch_dict["mask"].to(torch.bool)
                y_coords = batch_dict["extent_acpc"]
                y_vox_size = torch.atleast_2d(batch_dict["vox_size"])

                ctx_v = encoder(x)
                # Whole-volume inference is memory-prohibitive, so use a sliding
                # window inference method on the encoded volume.
                pred_fodf = monai.inferers.sliding_window_inference(
                    y_coords,
                    roi_size=(32, 32, 32),
                    sw_batch_size=y_coords.shape[0],
                    predictor=lambda q: decoder(
                        query_coord=q,
                        context_v=ctx_v,
                        context_spatial_extent=x_coords,
                        query_vox_size=y_vox_size,
                    ),
                    overlap=0,
                    padding_mode="replicate",
                )

                y_mask_broad = torch.broadcast_to(y_mask, y.shape)
                # Calculate performance metrics
                mse_loss = batchwise_masked_mse(y, pred_fodf, mask=y_mask_broad)
                val_metrics["mse"].append(mse_loss.detach().cpu().flatten())
                # Scalar used for captions/progress output. Reducing with
                # mean() first keeps this valid when the batch holds more than
                # one subject (a bare .item() raises on multi-element tensors);
                # for a single-subject batch the value is unchanged.
                subj_mse = val_metrics["mse"][-1].mean().item()

                # If visualization subj_id is in this batch, create the visual
                # and log it.
                if subj_id == val_viz_subj_id:
                    with mpl.rc_context({"font.size": 6.0}):
                        fig = plt.figure(dpi=175, figsize=(8, 5))
                        fig, _ = pitn.viz.plot_fodf_coeff_slices(
                            pred_fodf,
                            y,
                            pred_fodf - y,
                            fig=fig,
                            fodf_vol_labels=("Predicted", "Target", "Pred - GT"),
                            imshow_kwargs={
                                "interpolation": "antialiased",
                                "cmap": "gray",
                            },
                        )
                        aim_run.track(
                            aim.Image(
                                fig,
                                caption=f"Val Subj {subj_id}, "
                                + f"MSE = {subj_mse}",
                                optimize=True,
                                quality=100,
                                format="png",
                            ),
                            name="sh_whole_volume",
                            context={"subset": "val"},
                            epoch=epoch,
                            step=step,
                        )
                        plt.close(fig)

                    # Plot MSE as distributed over the SH orders.
                    # Labels for 45 coefficients: even orders l = 0, 2, ..., 8
                    # with 2l + 1 coefficients each.
                    sh_coeff_labels = {
                        "idx": list(range(0, 45)),
                        "l": np.concatenate(
                            [
                                np.array([l_order] * (2 * l_order + 1))
                                for l_order in range(0, 9, 2)
                            ],
                            dtype=int,
                        ).flatten(),
                    }
                    error_fodf = F.mse_loss(pred_fodf, y, reduction="none")
                    # Move the SH axis last so that masking keeps the 45
                    # coefficients of each selected voxel contiguous.
                    error_fodf = einops.rearrange(
                        error_fodf, "b sh_idx x y z -> b x y z sh_idx"
                    )
                    error_fodf = error_fodf[
                        y_mask[:, 0, ..., None].broadcast_to(error_fodf.shape)
                    ]
                    error_fodf = einops.rearrange(
                        error_fodf, "(elem sh_idx) -> elem sh_idx", sh_idx=45
                    )
                    error_fodf = error_fodf.flatten().detach().cpu().numpy()
                    error_df = pd.DataFrame.from_dict(
                        {
                            "MSE": error_fodf,
                            "SH_idx": np.tile(
                                sh_coeff_labels["idx"], error_fodf.shape[0] // 45
                            ),
                            "L Order": np.tile(
                                sh_coeff_labels["l"], error_fodf.shape[0] // 45
                            ),
                        }
                    )
                    with mpl.rc_context({"font.size": 6.0}):
                        fig = plt.figure(dpi=140, figsize=(6, 2))
                        sns.boxplot(
                            data=error_df,
                            x="SH_idx",
                            y="MSE",
                            hue="L Order",
                            linewidth=0.8,
                            showfliers=False,
                            width=0.85,
                            dodge=False,
                        )
                        aim_run.track(
                            aim.Image(
                                fig, caption="MSE over SH orders", optimize=True
                            ),
                            name="mse_over_sh_orders",
                            epoch=epoch,
                            step=step,
                        )
                        plt.close(fig)
                fabric.print(f"MSE {subj_mse}")
                fabric.print("Finished validation subj ", subj_id)
                del pred_fodf

        val_metrics["mse"] = torch.cat(val_metrics["mse"])
        # Log metrics
        aim_run.track(
            {"mse": val_metrics["mse"].mean().numpy()},
            context={"subset": "val"},
            step=step,
            epoch=epoch,
        )
    finally:
        # Always hand the networks back in the mode the caller left them in,
        # so a failed validation pass cannot silently leave them in eval mode.
        encoder.train(mode=encoder_was_training)
        decoder.train(mode=decoder_was_training)
    return aim_run, val_viz_subj_id

In [None]:
# Test all given subjects.


def _strip_wrapper_keys(state_dict):
    """Return a copy of ``state_dict`` with module-wrapper key artifacts removed.

    Entries whose key contains "_forward_module" are dropped, and the
    "_original_module." prefix is removed from the remaining keys. These look
    like leftovers from saving Fabric-wrapped modules -- TODO confirm.
    """
    return {
        key.replace("_original_module.", ""): weight
        for key, weight in state_dict.items()
        if "_forward_module" not in key
    }


# Load onto the CPU first so the checkpoint can be read regardless of which
# device it was saved from; the models are moved to `device` after loading.
system_state_dict = torch.load(p.model_weight_f, map_location="cpu")
encoder_state_dict = _strip_wrapper_keys(system_state_dict["encoder"])
decoder_state_dict = _strip_wrapper_keys(system_state_dict["decoder"])

# Fall back to the channel count of the first test sample when the encoder
# config does not pin the number of input channels.
if "in_channels" not in p.encoder:
    in_channels = int(test_dataset[0]["lr_dwi"].shape[0])
else:
    in_channels = p.encoder.in_channels

encoder = INREncoder(**{**p.encoder.to_dict(), "in_channels": in_channels})
encoder.load_state_dict(encoder_state_dict)
encoder.to(device)

decoder = ReducedDecoder(**p.decoder.to_dict())
decoder.load_state_dict(decoder_state_dict)
decoder.to(device)
# The weights now live in the modules; drop the extra copies.
del system_state_dict, encoder_state_dict, decoder_state_dict

encoder.eval()
decoder.eval()

# Super-resolve every test subject with the loaded encoder/decoder, then save
# the predicted fODF coefficients and the resampled mask as NIfTI files.
for batch_dict in test_dataloader:

    subj_id = batch_dict["subj_id"]
    # Unwrap the subject ID when the batch holds a single subject.
    if len(subj_id) == 1:
        subj_id = subj_id[0]
    x = batch_dict["lr_dwi"].to(device)
    x_coords = batch_dict["lr_extent_acpc"].to(device)
    x_vox_size = torch.atleast_2d(batch_dict["lr_vox_size"]).to(device)
    x_mask = batch_dict["lr_mask"].to(torch.bool).to(device)

    # Smallest coordinate of each of the three coordinate channels, read at the
    # first index of the matching spatial axis of batch element 0.
    # Assumes x_coords is (batch, 3, dim0, dim1, dim2) with channel c varying
    # along spatial axis c -- TODO confirm.
    lower_lim = torch.stack(
        [
            x_coords[0, 0][0].unique()[0],
            x_coords[0, 1][:, 0].unique()[0],
            x_coords[0, 2][:, :, 0].unique()[0],
        ]
    )
    # Same lookup at the last index of each spatial axis.
    upper_lim = torch.stack(
        [
            x_coords[0, 0][-1].unique()[0],
            x_coords[0, 1][:, -1].unique()[0],
            x_coords[0, 2][:, :, -1].unique()[0],
        ]
    )
    # Regularly spaced target coordinates per axis. torch.arange excludes the
    # end point, so upper_lim itself is not sampled. The z/y/x names label
    # coordinate channels 0/1/2 respectively.
    super_z = torch.arange(lower_lim[0], upper_lim[0], step=p.target_vox_size).to(
        x_coords
    )
    super_y = torch.arange(lower_lim[1], upper_lim[1], step=p.target_vox_size).to(
        x_coords
    )
    super_x = torch.arange(lower_lim[2], upper_lim[2], step=p.target_vox_size).to(
        x_coords
    )

    # Dense query grid, stacked channel-first with a leading batch axis of 1.
    super_zzz, super_yyy, super_xxx = torch.meshgrid(
        [super_z, super_y, super_x], indexing="ij"
    )
    super_coords = torch.stack([super_zzz, super_yyy, super_xxx], dim=0)[None]
    super_vol_shape = tuple(super_coords.shape[2:])

    super_vox_size = torch.ones_like(x_vox_size) * p.target_vox_size

    # Affine for the output grid: the low-res voxel->ACPC affine right-multiplied
    # by a diagonal scaling of (target voxel size / low-res voxel size) per axis.
    vox2acpc = batch_dict["affine_lrvox2acpc"][0].cpu()
    scale = (p.target_vox_size / x_vox_size)[0].cpu()
    # Append a 1 for the homogeneous coordinate before building the diagonal.
    scale = torch.cat([scale, scale.new_ones(1)]).cpu()
    scale = torch.diag_embed(scale).to(vox2acpc).cpu()
    new_aff = vox2acpc @ scale
    new_aff = new_aff.numpy()

    with torch.no_grad():
        print("Starting net inference.")
        ctx_v = encoder(x)

        # Whole-volume inference is memory-prohibitive, so use a sliding
        # window inference method on the encoded volume.
        # The query grid is passed in on the CPU; each window is moved to
        # `device` inside the predictor and its prediction is moved back to the
        # CPU (presumably to keep the full-size output off the accelerator).
        pred_super_fodf = monai.inferers.sliding_window_inference(
            super_coords.cpu(),
            roi_size=(48, 48, 48),
            sw_batch_size=super_coords.shape[0],
            predictor=lambda q: decoder(
                query_coord=q.to(device),
                context_v=ctx_v,
                context_spatial_extent=x_coords,
                query_vox_size=super_vox_size,
            ).cpu(),
            overlap=0,
            padding_mode="replicate",
        )
    print("Finished network inference.")
    # Resample the low-res mask at the same query coordinates with
    # nearest-neighbour sampling; sample_3d takes the coordinates flattened to
    # (batch, n_points, coord).
    mask_coords = einops.rearrange(super_coords, "b coord z y x -> b (z y x) coord")
    super_mask = pitn.affine.sample_3d(
        x_mask.cpu(), mask_coords.cpu(), vox2acpc, mode="nearest", align_corners=True
    )
    # Fold the flat samples back into the target volume shape and drop the
    # singleton axes, leaving an int8 NumPy array.
    super_mask = (
        einops.rearrange(
            super_mask,
            "b (z y x) c -> b z y x c",
            z=super_vol_shape[0],
            y=super_vol_shape[1],
            x=super_vol_shape[2],
        )
        .squeeze()
        .cpu()
        .to(torch.int8)
        .numpy()
    )
    superres_pred = pred_super_fodf.cpu().numpy()
    # Zero out predictions outside the mask; the spatial-only mask broadcasts
    # over the leading batch and coefficient axes.
    superres_pred = superres_pred * super_mask
    # Move the coefficient axis (axis 1) last and drop the batch axis before
    # writing the volume.
    odf_coeffs = np.moveaxis(superres_pred, 1, -1).squeeze()
    print("Saving super-res fodf coeffs.")
    nib.save(
        nib.Nifti1Image(odf_coeffs, affine=new_aff),
        tmp_res_dir / f"{subj_id}_odf-coeff_inr-super-res_{p.target_vox_size}mm.nii.gz",
    )
    print("Saving mask.")
    nib.save(
        nib.Nifti1Image(super_mask, affine=new_aff),
        tmp_res_dir / f"{subj_id}_mask-super-res_{p.target_vox_size}mm.nii.gz",
    )

### Tri-Linear Interp

In [None]:
# Baseline: resample each subject's low-res fODF coefficients onto the target
# grid with tri-linear interpolation and save the result as a NIfTI file.
for batch_dict in test_dataloader:

    subj_id = batch_dict["subj_id"]
    # Unwrap the subject ID when the batch holds a single subject.
    if len(subj_id) == 1:
        subj_id = subj_id[0]
    x = batch_dict["lr_fodf"].to(device)
    x_coords = batch_dict["lr_extent_acpc"].to(device)
    x_vox_size = torch.atleast_2d(batch_dict["lr_vox_size"]).to(device)
    x_mask = batch_dict["lr_mask"].to(torch.bool).to(device)

    # Smallest coordinate of each of the three coordinate channels, read at the
    # first index of the matching spatial axis of batch element 0.
    # Assumes x_coords is (batch, 3, dim0, dim1, dim2) with channel c varying
    # along spatial axis c -- TODO confirm.
    lower_lim = torch.stack(
        [
            x_coords[0, 0][0].unique()[0],
            x_coords[0, 1][:, 0].unique()[0],
            x_coords[0, 2][:, :, 0].unique()[0],
        ]
    )
    # Same lookup at the last index of each spatial axis.
    upper_lim = torch.stack(
        [
            x_coords[0, 0][-1].unique()[0],
            x_coords[0, 1][:, -1].unique()[0],
            x_coords[0, 2][:, :, -1].unique()[0],
        ]
    )
    # Regularly spaced target coordinates per axis. torch.arange excludes the
    # end point, so upper_lim itself is not sampled.
    super_z = torch.arange(lower_lim[0], upper_lim[0], step=p.target_vox_size).to(
        x_coords
    )
    super_y = torch.arange(lower_lim[1], upper_lim[1], step=p.target_vox_size).to(
        x_coords
    )
    super_x = torch.arange(lower_lim[2], upper_lim[2], step=p.target_vox_size).to(
        x_coords
    )

    # Dense query grid, stacked channel-last with a leading batch axis of 1,
    # then flattened to (batch, n_points, coord) for sample_3d.
    super_zzz, super_yyy, super_xxx = torch.meshgrid(
        [super_z, super_y, super_x], indexing="ij"
    )
    super_coords = torch.stack([super_zzz, super_yyy, super_xxx], dim=-1)[None]
    # (batch, z, y, x); indices 1..3 are the spatial sizes used below.
    super_vol_shape = tuple(super_coords.shape[:-1])
    super_coords = einops.rearrange(super_coords, "b z y x coord -> b (z y x) coord")

    # Not needed by the resampling below; kept because it is a notebook-level
    # name that was defined by this cell before.
    super_vox_size = torch.ones_like(x_vox_size) * p.target_vox_size

    # Affine for the output grid: the low-res voxel->ACPC affine right-multiplied
    # by a diagonal scaling of (target voxel size / low-res voxel size) per axis.
    vox2acpc = batch_dict["affine_lrvox2acpc"][0].cpu()
    scale = (p.target_vox_size / x_vox_size)[0].cpu()
    # Append a 1 for the homogeneous coordinate before building the diagonal.
    scale = torch.cat([scale, scale.new_ones(1)]).cpu()
    scale = torch.diag_embed(scale).to(vox2acpc).cpu()
    new_aff = vox2acpc @ scale
    new_aff = new_aff.numpy()
    print("Resample fodf coeffs")
    # "bilinear" on a 3D volume is tri-linear interpolation, matching this
    # section's title and the output filename. The earlier "bicubic" setting
    # contradicted both, and torch's grid_sample only supports bicubic for 4-D
    # inputs. NOTE(review): assumes sample_3d forwards `mode` to grid_sample --
    # confirm against pitn.affine.sample_3d.
    pred_super_fodf = pitn.affine.sample_3d(
        x.cpu(), super_coords.cpu(), vox2acpc, mode="bilinear", align_corners=True
    )
    super_mask = pitn.affine.sample_3d(
        x_mask.cpu(), super_coords.cpu(), vox2acpc, mode="nearest", align_corners=True
    )
    # Zero out interpolated coefficients that fall outside the resampled mask.
    pred_super_fodf = pred_super_fodf * super_mask.bool()
    superres_pred = pred_super_fodf.detach().cpu()
    # Fold the flat samples back into the target volume shape.
    superres_pred = einops.rearrange(
        superres_pred,
        "b (z y x) c -> b z y x c",
        z=super_vol_shape[1],
        y=super_vol_shape[2],
        x=super_vol_shape[3],
    )
    superres_pred = superres_pred.numpy().astype(np.float32).squeeze()

    super_mask = (
        einops.rearrange(
            super_mask,
            "b (z y x) c -> b z y x c",
            z=super_vol_shape[1],
            y=super_vol_shape[2],
            x=super_vol_shape[3],
        )
        .squeeze()
        .cpu()
        .to(torch.int8)
        .numpy()
    )
    nib.save(
        nib.Nifti1Image(superres_pred, affine=new_aff),
        tmp_res_dir
        / f"{subj_id}_odf-coeff_tri-linear-super-res_{p.target_vox_size}mm.nii.gz",
    )
    # Mask saving is intentionally left disabled in this cell.
    # nib.save(
    #     nib.Nifti1Image(super_mask, affine=new_aff),
    #     tmp_res_dir / f"{subj_id}_mask-super-res_{p.target_vox_size}mm.nii.gz",
    # )