# Pain in the Net
Replication of *Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images*


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Source work:
`S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

In [16]:
# imports
import collections
import functools
import io

import math
import itertools
import os
import pathlib
import pdb
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
from typing import Generator

import ants
import dipy
import dipy.core
import dipy.core.gradients
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import natsort

# Data management libraries.
import nibabel as nib
import nilearn
import nilearn.plotting

# Computation & ML libraries.
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import skimage.transform
import torch
import torch.nn.functional as F
import torchio
import torchvision
from natsort import natsorted

# Figures: automatic layout and an opaque white background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Print ndarrays/tensors in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

In [2]:
# Update the notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows (it runs /usr/bin/env).
# NOTE: This is kind of hacky, and not necessarily safe: every variable direnv
# exports for this directory is copied into the kernel. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Run `/usr/bin/env` in direnv's context; it prints out all environment variables
# defined in that sub-process. Passing an argument list (no shell) keeps a working
# directory that contains spaces or shell metacharacters intact.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # Best-effort, like the shell-based version: if direnv is missing, load nothing.
    warnings.warn("`direnv` was not found; no environment variables were loaded.")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"`{' '.join(command)}` exited with status {proc.returncode}; "
            + "environment variables may be missing."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [3]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# The project root comes from $PROJECT_ROOT when it is set; otherwise it is taken to be
# two directories above the current working directory.
if "PROJECT_ROOT" in os.environ:
    src_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    src_location = str(Path("../../").resolve())
# Only add the path once, so re-running this cell doesn't keep growing `sys.path`.
if src_location not in sys.path:
    sys.path.append(src_location)
import src as pitn

In [4]:
# torch setup: run on the GPU when CUDA is available, otherwise on the CPU.
# (To force CPU-only execution, replace the expression with `torch.device("cpu")`.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
# Watermark: record the author, last-updated timestamp (ISO 8601), Python and
# machine details, versions of the imported packages, and the current git commit,
# so the outputs in this notebook can be tied to a specific environment.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
# CUDA version reported by this PyTorch build.
print("CUDA Version: ", torch.version.cuda)

Author: Tyler Spears

Last updated: 2021-04-01T14:52:07.477953+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.22.0

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-52-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: 1dd50c3b9ef4c348d426c908d9ef77485e3c3ed3

sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
torchvision      : 0.2.2
nibabel          : 3.2.1
ants             : 0.2.7
dipy             : 1.4.0
matplotlib       : 3.4.1
pandas           : 1.2.3
skimage          : 0.18.1
torchio          : 0.18.31
seaborn          : 0.11.1
numpy            : 1.20.2
nilearn          : 0.7.1
torch            : 1.8.1
natsort          : 7.1.1
pytorch_lightning: 1.2.6

CUDA Version:  11.1


## Variables & Definitions Setup

In [6]:
# Set up directories. Paths come from environment variables (a missing variable
# raises KeyError), and each directory must already exist.
data_dir = Path(os.environ["DATA_DIR"], "hcp")
assert data_dir.exists()

write_data_dir = Path(os.environ["WRITE_DATA_DIR"], "hcp")
assert write_data_dir.exists()

results_dir = Path(os.environ["RESULTS_DIR"])
assert results_dir.exists()

### Global Function & Class Definitions

In [8]:
# For more clearly designating the return values of a reader function given to
# the `torchio.Image` object.
ReaderOutput = collections.namedtuple("ReaderOutput", ("dwi", "affine"))


def nifti_reader(
    f_dwi,
) -> ReaderOutput:
    """Read a NIFTI file for use as a `torchio.Image` reader.

    Meant for use with the `torchio.Image` object and its sub-classes.

    Args:
        f_dwi: Path to the NIFTI file.

    Returns:
        ReaderOutput: `dwi` is a tensor copy of the voxel data read by antspy as
        32-bit float; `affine` is the float32 voxel-to-world affine read by nibabel.
    """
    # The affine comes from nibabel, not antspy. See
    # <https://github.com/ANTsX/ANTsPy/issues/52> for why antspy isn't trusted for this.
    # Only the header is needed here, not the full voxel data.
    header_affine = np.asarray(nib.load(f_dwi).affine, dtype=np.float32)

    print("Loading NIFTI image", flush=True)
    ants_img = ants.image_read(str(f_dwi), pixeltype="float")
    print("\tLoaded NIFTI image", flush=True)

    # `torch.tensor()` copies explicitly, so the result never aliases antspy-owned
    # memory, which `torch.from_numpy` would.
    # <https://pytorch.org/docs/1.8.0/generated/torch.tensor.html#torch.tensor>
    voxels = torch.tensor(ants_img.view())
    return ReaderOutput(dwi=voxels, affine=torch.tensor(header_affine))

In [9]:
# torchio.Transform functions/objects.


class BValSelectionTransform(torchio.SpatialTransform):
    """Keep only the volumes whose b-value lies in a closed range.

    Expects:
    - volumes in canonical (RAS+) format with *channels first.*
    - bvecs to be of shape (N, 3), with N being the number of scans/bvals.

    Args:
        bval_range: Sequence whose first and last entries are the inclusive lower
            and upper b-value bounds.
        bval_key: Image key holding the per-volume b-values.
        bvec_key: Image key holding the per-volume b-vectors.
        **kwargs: Forwarded to `torchio.SpatialTransform`.

    The image data, b-values, and b-vectors are all filtered with the same mask.
    """

    def __init__(self, bval_range: tuple, bval_key, bvec_key, **kwargs):
        super().__init__(**kwargs)

        self.bval_range = bval_range
        self.bval_key = bval_key
        self.bvec_key = bvec_key

    def apply_transform(self, subject: torchio.Subject) -> torchio.Subject:
        print("Selecting with bvals", flush=True)

        for image in self.get_images(subject):
            lower, upper = self.bval_range[0], self.bval_range[-1]
            b = image[self.bval_key]
            # Mask is computed from the unfiltered b-values before anything is replaced.
            keep = (lower <= b) & (b <= upper)

            image[self.bvec_key] = image[self.bvec_key][keep, :]
            image.set_data(image.data[keep, ...])
            image[self.bval_key] = b[keep]

        print("\tSelected", flush=True)
        return subject


class MeanDownsampleTransform(torchio.SpatialTransform):
    """Mean downsampling transformation.

    Expects volumes in canonical (RAS+) format with *channels first*, i.e. data of
    shape (C, X, Y, Z). Each of the three spatial axes is reduced by
    `downsample_factor` by averaging non-overlapping blocks of voxels with
    `skimage.transform.downscale_local_mean`; the channel axis (and any axis after
    the spatial ones) is left untouched. Axes whose length is not a multiple of the
    factor are padded with zeros (`cval=0`) before averaging.

    Every processed image gets a `"downsample_factor"` entry. Its affine is updated
    so that each new voxel sits at the world-space centroid of the block of
    original voxels it averages.

    The result is cast back to the image's original dtype. For boolean images
    (e.g. masks), a downsampled voxel is therefore True if any voxel in its block
    was True.
    """

    def __init__(self, downsample_factor: int, **kwargs):
        super().__init__(**kwargs)

        self.downsample_factor = downsample_factor

    def apply_transform(self, subject: torchio.Subject) -> torchio.Subject:
        print("Downsampling", flush=True)
        factor = self.downsample_factor

        for img in self.get_images(subject):
            img["downsample_factor"] = factor
            if factor == 1:
                continue

            img_ndarray = img.data.numpy()
            downsample_vol = skimage.transform.downscale_local_mean(
                img_ndarray,
                factors=self._block_shape(img_ndarray.ndim, factor),
                cval=0,
            )
            downsample_vol = torch.from_numpy(
                downsample_vol.astype(img_ndarray.dtype)
            ).to(img.data.dtype)

            img.set_data(downsample_vol)
            img.affine = self._downsampled_affine(img.affine, factor)
        print("\tDownsampled", flush=True)
        return subject

    @staticmethod
    def _block_shape(ndim: int, factor: int) -> tuple:
        """Per-axis block sizes: `factor` on the spatial axes, 1 on all others."""
        if ndim > 3:
            # Channels first: axis 0 is channels and axes 1-3 are X, Y, Z, so only
            # axes 1-3 are reduced. (Resetting `[3:]` here would skip the Z axis.)
            return (1,) + (factor,) * 3 + (1,) * (ndim - 4)
        return (factor,) * ndim

    @staticmethod
    def _downsampled_affine(affine, factor: int) -> np.ndarray:
        """Voxel-to-world affine of the block-averaged grid.

        New voxel index i covers old indices factor*i .. factor*i + factor - 1 on
        each axis. Its centre is therefore old index factor*i + (factor - 1) / 2, which gives
        A_new = A_old * factor and t_new = t_old + A_old @ ((factor - 1) / 2).
        """
        old = np.asarray(affine, dtype=np.float64)
        new = old.copy()
        # Scale whole voxel-axis columns, which is also correct for oblique affines.
        new[:3, :3] = old[:3, :3] * factor
        new[:3, 3] = old[:3, 3] + old[:3, :3] @ np.full(3, (factor - 1) / 2.0)
        return new


class FitDTITransform(torchio.SpatialTransform, torchio.IntensityTransform):
    """Replace each included DWI with its fitted diffusion-tensor coefficients.

    A dipy `TensorModel` is fit to every included image, using the b-values and
    b-vectors stored on that image under `bval_key` / `bvec_key`. The image data is
    replaced with the 6 unique coefficients of the symmetric tensor, stored channels
    first as (6, X, Y, Z). Per dipy's documentation, `lower_triangular()` orders them
    as Dxx, Dxy, Dyy, Dxz, Dyz, Dzz. The new data takes the dtype and device of the
    image's previous data.

    Args:
        bval_key: Image key holding the per-volume b-values.
        bvec_key: Image key holding the per-volume b-vectors, shape (N, 3).
        mask_img_key: Optional subject key of a mask image on the same grid as the
            fitted images; it is passed to dipy as the fitting mask.
        fit_method: dipy tensor-fitting method name (e.g. "WLS").
        tensor_model_kwargs: Extra keyword arguments for
            `dipy.reconst.dti.TensorModel`. Copied, never mutated.
        **kwargs: Forwarded to the torchio transform base classes.
    """

    def __init__(
        self,
        bval_key,
        bvec_key,
        mask_img_key=None,
        fit_method="WLS",
        tensor_model_kwargs=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.bval_key = bval_key
        self.bvec_key = bvec_key
        self.mask_img_key = mask_img_key
        self.fit_method = fit_method
        # Avoid a shared mutable default and aliasing the caller's dict.
        self.tensor_model_kwargs = (
            dict(tensor_model_kwargs) if tensor_model_kwargs is not None else dict()
        )

    def apply_transform(self, subject: torchio.Subject) -> torchio.Subject:
        print("Fitting to DTI", flush=True)

        # dipy wants a boolean spatial mask; squeeze away the singleton channel axis
        # (mask data presumably (1, X, Y, Z)).
        mask = None
        if self.mask_img_key is not None:
            mask = subject[self.mask_img_key].numpy().squeeze().astype(bool)

        for img in self.get_images(subject):
            gradient_table = dipy.core.gradients.gradient_table_from_bvals_bvecs(
                bvals=np.asarray(img[self.bval_key]),
                bvecs=np.asarray(img[self.bvec_key]),
            )
            tensor_model = dipy.reconst.dti.TensorModel(
                gradient_table, fit_method=self.fit_method, **self.tensor_model_kwargs
            )
            # dipy expects the diffusion-volume axis last; torchio stores it first.
            # `mask=None` fits every voxel.
            dti = tensor_model.fit(np.moveaxis(img.numpy(), 0, -1), mask=mask)

            # `lower_triangular()` puts the 6 coefficients on the last axis. Move them
            # to the channel axis so torchio sees (6, X, Y, Z) instead of
            # misreading X as channels.
            coeffs = np.moveaxis(dti.lower_triangular().astype(np.float32), -1, 0)
            img.set_data(torch.from_numpy(np.ascontiguousarray(coeffs)).to(img.data))
        print("\tFitted DTI model", flush=True)

        return subject


class RenameImageTransform(torchio.Transform):
    """Rename images inside a subject.

    Args:
        name_mapping: Maps each current image name to its new name. Renames are
            applied one at a time, in the mapping's iteration order.
        **kwargs: Forwarded to `torchio.Transform`.
    """

    def __init__(self, name_mapping: dict, **kwargs):
        super().__init__(**kwargs)
        self.name_mapping = name_mapping

    def apply_transform(self, subject: torchio.Subject) -> torchio.Subject:
        for current_name, renamed in self.name_mapping.items():
            # Move the same Image object under its new name.
            image = subject[current_name]
            subject.remove_image(current_name)
            subject.add_image(image, renamed)
        return subject

In [17]:
# Definitions for sampling and loading patches from volumes of different resolutions in a `pytorch.utils.data.DataLoader`.

# Return type wrapper: a batch of full-res patches paired with their low-res patches.
MultiresSample = collections.namedtuple("MultiresSample", ["full_res", "low_res"])

# Custom sampler for sampling multiple volumes of different resolutions.
class MultiresSampler(torchio.LabelSampler):
    """Sample spatially aligned (full-res, low-res) patch pairs from one subject.

    `patch_size` (forwarded to `torchio.LabelSampler`) is the size of the *low-res*
    patch. The paired full-res patch covers the same field of view, so it is
    `downsample_factor` times larger along each spatial axis. Its start index is
    always a multiple of the factor, so the two patches line up block-for-block.

    Patch locations are weighted by the probability map that `torchio.LabelSampler`
    builds from the label map named `label_name`. That label map must lie on the
    full-res grid of `source_img_key`.

    Args:
        source_img_key: Name of the full-res image in the subject.
        low_res_key: Name of the low-res image in the subject.
        downsample_factor_key: Key on the low-res image that holds the integer
            downsample factor between the two grids.
        label_name: Name of the full-res label map used to weight locations.
        subj_keys_to_copy: Subject fields (e.g. an ID) copied into every patch.
        **kwargs: Forwarded to `torchio.LabelSampler` (e.g. `patch_size`,
            `label_probabilities`).
    """

    def __init__(
        self,
        source_img_key,
        low_res_key,
        downsample_factor_key,
        label_name,
        subj_keys_to_copy=tuple(),
        **kwargs,
    ):

        super().__init__(label_name=label_name, **kwargs)
        self.source_img_key = source_img_key
        self.low_res_key = low_res_key
        self.downsample_factor_key = downsample_factor_key
        self.subj_keys_to_copy = subj_keys_to_copy

    def __call__(
        self, subject: torchio.Subject, num_patches=None
    ) -> Generator[torchio.Subject, None, None]:
        """Yield `num_patches` patch subjects (forever if `num_patches` is None).

        Each yielded subject contains:
        - the full-res patch under `source_img_key`;
        - the low-res patch under `low_res_key`;
        - the `subj_keys_to_copy` fields;
        - `"index_ini"`, the full-res voxel index of the full-res patch's first corner.

        Raises:
            RuntimeError: If the patches cannot fit inside the images, or if no valid
                location has a non-zero sampling probability.
        """
        # The two images intentionally have different spatial shapes, so check each
        # against its own grid instead of calling subject-wide consistency checks.
        source_img = subject[self.source_img_key]
        low_res_img = subject[self.low_res_key]
        factor = int(low_res_img[self.downsample_factor_key])

        # Signed ints, so the subtractions below can go negative.
        lr_patch_size = np.asarray(self.patch_size, dtype=int)
        fr_patch_size = lr_patch_size * factor
        fr_shape = np.asarray(source_img.spatial_shape, dtype=int)
        lr_shape = np.asarray(low_res_img.spatial_shape, dtype=int)

        # Largest low-res start index per axis for which the low-res patch and the
        # factor-times-larger full-res patch both stay inside their images.
        max_lr_start = np.minimum(lr_shape, fr_shape // factor) - lr_patch_size
        if np.any(max_lr_start < 0):
            raise RuntimeError(
                f"Low-res patch size {tuple(lr_patch_size)} (full-res"
                f" {tuple(fr_patch_size)}) does not fit low-res image"
                f" {tuple(lr_shape)} / full-res image {tuple(fr_shape)}"
            )

        probability_map = torch.as_tensor(self.get_probability_map(subject))
        probability_map = probability_map.to(torch.float64).reshape(
            tuple(fr_shape.tolist())
        )

        # Weight each candidate (aligned) start by the probability at the centre of
        # its full-res patch.
        candidate_centers = [
            torch.arange(int(n_starts) + 1) * factor + int(p) // 2
            for n_starts, p in zip(max_lr_start, fr_patch_size)
        ]
        weights = probability_map[
            candidate_centers[0][:, None, None],
            candidate_centers[1][None, :, None],
            candidate_centers[2][None, None, :],
        ]
        grid_shape = tuple(weights.shape)
        flat_weights = weights.reshape(-1)
        if not bool(flat_weights.sum() > 0):
            raise RuntimeError(
                "No valid patch location has a non-zero sampling probability."
            )

        patches_left = num_patches if num_patches is not None else True
        while patches_left:
            flat_index = int(torch.multinomial(flat_weights, 1))
            lr_index_ini = np.array(np.unravel_index(flat_index, grid_shape), dtype=int)
            fr_index_ini = lr_index_ini * factor

            patch_fields = {k: subject[k] for k in self.subj_keys_to_copy}
            patch_fields[self.source_img_key] = self.extract_subj_patch(
                subject,
                img_key=self.source_img_key,
                index_ini=fr_index_ini,
                patch_size=fr_patch_size,
            )
            patch_fields[self.low_res_key] = self.extract_subj_patch(
                subject,
                img_key=self.low_res_key,
                index_ini=lr_index_ini,
                patch_size=lr_patch_size,
            )
            # Images are passed at construction time; torchio may refuse to build a
            # Subject that has no images.
            patch_subj = torchio.Subject(**patch_fields)
            patch_subj["index_ini"] = fr_index_ini.copy()

            # Return the new patch subject.
            yield patch_subj
            if num_patches is not None:
                patches_left -= 1

    @classmethod
    def extract_subj_patch(
        cls, subject: torchio.Subject, img_key, index_ini, patch_size
    ) -> torchio.Image:
        """Return a new image holding a patch of `subject[img_key]`.

        The patch starts at voxel `index_ini` and has spatial size `patch_size`.
        Bounds are checked against that image's own spatial shape, so images of
        different resolutions in one subject can be cropped independently. The
        source image is not modified: the patch is a cloned slice, rather than a
        torchio `Crop(copy=False)`, which may crop the image in place. The patch's
        affine origin is moved so its voxels keep their world coordinates.

        Raises:
            RuntimeError: If the patch extends outside the image.
        """
        img = subject[img_key]
        shape = np.asarray(img.spatial_shape, dtype=int)
        ini = np.asarray(index_ini, dtype=int)
        fin = ini + np.asarray(patch_size, dtype=int)
        if np.any(ini < 0) or np.any(fin > shape):
            raise RuntimeError(
                f"Patch [{tuple(ini)}, {tuple(fin)}) lies outside image"
                f" '{img_key}' of spatial shape {tuple(shape)}"
            )

        (i0, j0, k0), (i1, j1, k1) = ini, fin
        patch_data = img.data[:, i0:i1, j0:j1, k0:k1].clone()
        affine = np.asarray(img.affine, dtype=np.float64)
        patch_affine = affine.copy()
        # The world coordinate of the patch's first voxel becomes the new origin.
        patch_affine[:3, 3] = affine[:3, :3] @ ini + affine[:3, 3]
        return type(img)(tensor=patch_data, affine=patch_affine)


# Collate function for the DataLoader to combine multiple samples.
def collate_subj(samples, full_res_key: str, low_res_key: str):
    """Stack a list of patch subjects into one batched `MultiresSample`.

    Args:
        samples: Sequence of patch subjects (mappings from image name to image).
        full_res_key: Name of the full-res image in each sample.
        low_res_key: Name of the low-res image in each sample.

    Returns:
        MultiresSample: `full_res` and `low_res` tensors, each with a new leading
        batch dimension of length `len(samples)`.
    """
    # Each `subj[key]` is a torchio image; its voxel tensor is `.data`. Stacking the
    # image objects themselves fails. (`.data` on a plain tensor is that tensor,
    # detached, so tensor-valued samples also work.)
    full_res_stack = torch.stack([subj[full_res_key].data for subj in samples])
    low_res_stack = torch.stack([subj[low_res_key].data for subj in samples])

    return MultiresSample(full_res=full_res_stack, low_res=low_res_stack)

### Global Parameters

In [11]:
# Integer factor by which each spatial axis of the full-res volumes is downsampled.
downsample_factor = 2
# Inclusive b-value range of the volumes kept for DTI fitting: the b=0 and b=1000 shells.
bval_range = (0, 1500)
# Tensor-fitting method passed to dipy's TensorModel.
dti_fit_method = "WLS"

## Data Loading

In [12]:
# Find data directories for each subject.
subj_dirs: dict = {}

selected_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(selected_ids, 3)
warnings.warn(
    "WARNING: Sub-selecting participants for dev and debugging. "
    + f"Subj IDs selected: {selected_ids}"
)
##

# Subject IDs become integer keys, in ascending order.
selected_ids = sorted(int(s) for s in selected_ids)

for subj_id in selected_ids:
    subj_dirs[subj_id] = data_dir / str(subj_id) / "T1w" / "Diffusion"
    assert subj_dirs[subj_id].exists()
subj_dirs





{135528: PosixPath('/mnt/storage/data/pitn/hcp/135528/T1w/Diffusion'),
 141422: PosixPath('/mnt/storage/data/pitn/hcp/141422/T1w/Diffusion'),
 224022: PosixPath('/mnt/storage/data/pitn/hcp/224022/T1w/Diffusion')}

The 90 diffusion-weighted scans are taken from the $b=1000 \ s/mm^2$ shell. However, the $b=0$ volumes are still required for fitting the diffusion tensors (DTIs), so those need to be kept, too.

To keep both shells, sub-select the volumes with $0 \le b \le 1500 \ s/mm^2$ (inclusive at both ends). Measured b-values scatter slightly around their nominal shell, so a b-value of $995$ or $1005$ still counts as $b=1000$.

In [13]:
# Import all image data into a sequence of `torchio.Subject` objects.
subj_data: dict = dict()

for subj_id, subj_dir in subj_dirs.items():
    # Sub-select volumes with only bvals in a certain range. E.x. bvals <= 1100 mm/s^2,
    # a.k.a. only the b=0 and b=1000 shells.
    bvals = torch.as_tensor(np.loadtxt(subj_dir / "bvals").astype(int))
    bvecs = torch.as_tensor(np.loadtxt(subj_dir / "bvecs"))
    # Reshape to be N x 3
    if bvecs.shape[0] == 3:
        bvecs = bvecs.T

    # grad = torchio.ScalarImage(subj_dir/"grad_dev.nii.gz")
    brain_mask = torchio.LabelMap(
        subj_dir / "nodif_brain_mask.nii.gz",
        type=torchio.LABEL,
        channels_last=False,
    )

    # The brain mask is binary.
    brain_mask.set_data(brain_mask.data.bool())

    dwi = torchio.ScalarImage(
        subj_dir / "data.nii.gz",
        type=torchio.INTENSITY,
        bvals=bvals,
        bvecs=bvecs,
        reader=nifti_reader,
        channels_last=True,
    )

    subject_dict = torchio.Subject(subj_id=subj_id, dwi=dwi, brain_mask=brain_mask)

    preproc_transforms = torchio.Compose(
        [
            torchio.transforms.ToCanonical(include=("dwi", "brain_mask"), copy=False),
            BValSelectionTransform(
                bval_range=bval_range,
                bval_key="bvals",
                bvec_key="bvecs",
                include="dwi",
                copy=False,
            ),
            MeanDownsampleTransform(
                downsample_factor,
                include=("dwi", "brain_mask"),
                keep={"dwi": "fr_dwi", "brain_mask": "fr_brain_mask"},
                copy=False,
            ),
            RenameImageTransform(
                {"dwi": "lr_dwi", "brain_mask": "lr_brain_mask"}, copy=False
            ),
            FitDTITransform(
                "bvals",
                "bvecs",
                "fr_brain_mask",
                fit_method=dti_fit_method,
                include=("fr_dwi"),
                copy=False,
            ),
            FitDTITransform(
                "bvals",
                "bvecs",
                "lr_brain_mask",
                fit_method=dti_fit_method,
                include=("lr_dwi"),
                copy=False,
            ),
            RenameImageTransform({"fr_dwi": "fr_dti", "lr_dwi": "lr_dti"}, copy=False),
        ]
    )

    subj_data[subj_id] = preproc_transforms(subject_dict)

Loading NIFTI image
	Loaded NIFTI image
Selecting with bvals
	Selected
Downsampling
	Downsampled
Fitting to DTI
	Fitted DTI model
Fitting to DTI
	Fitted DTI model


In [14]:
# Dataset over all preprocessed subjects (`load_getitem=False`: don't load images in `__getitem__`).
subj_dataset = torchio.SubjectsDataset([*subj_data.values()], load_getitem=False)

## Model Training

In [None]:
# Patch parameters
batch_size = 16
# 6 channels: one per unique diffusion-tensor coefficient.
channels = 6

# Input (low-res) patch: spatial size along each axis.
h_in, w_in, d_in = 11, 11, 11
input_patch_shape = (channels, h_in, w_in, d_in)

# Output (full-res) patch: every spatial axis is `downsample_factor` times larger.
h_out, w_out, d_out = (downsample_factor * n for n in (h_in, w_in, d_in))

# Channel count before the sub-voxel shuffle back to `channels`.
unshuffled_channels_out = channels * downsample_factor ** 3
unshuffled_output_patch_shape = (unshuffled_channels_out, h_in, w_in, d_in)
output_patch_shape = (channels, h_out, w_out, d_out)

### Set Up Patch-Based Data Loaders

In [None]:
# Data train/test split (no separate validation split yet).
test_percent = 0.2
train_percent = 1 - test_percent

num_subjs = len(subj_dataset)
num_test_subjs = int(np.ceil(num_subjs * test_percent))
num_train_subjs = num_subjs - num_test_subjs
# Copy first: `dry_iter()` may hand back the dataset's own subject list, which
# shuffling would otherwise reorder in place.
subj_list = list(subj_dataset.dry_iter())
# Randomly shuffle the list of subjects, then choose the first `num_test_subjs` subjects
# for testing.
random.shuffle(subj_list)
test_dataset = torchio.SubjectsDataset(subj_list[:num_test_subjs], load_getitem=False)
# Choose the remaining for training/validation.
subj_list = subj_list[num_test_subjs:]
train_dataset = torchio.SubjectsDataset(subj_list, load_getitem=False)

# Training patch sampler, random across all patches of all volumes.
# After preprocessing, each subject's masks are named "fr_brain_mask" and
# "lr_brain_mask". Patch locations are drawn from the full-res mask, inside the brain only.
train_sampler = MultiresSampler(
    source_img_key="fr_dti",
    low_res_key="lr_dti",
    downsample_factor_key="downsample_factor",
    label_name="fr_brain_mask",
    # Patch size of the network input.
    patch_size=(h_in, w_in, d_in),
    label_probabilities={0: 0, 1: 1},
)
# Set up a torchio.Queue to act as a sampler proxy for the torch DataLoader
train_queue = torchio.Queue(
    train_dataset,
    max_length=5 * batch_size,
    samples_per_volume=32,
    sampler=train_sampler,
    shuffle_patches=True,
    shuffle_subjects=True,
    num_workers=2,
)
# Create partial function to collect list of samples and form a tuple of tensors.
collate_fn = functools.partial(
    collate_subj, full_res_key="fr_dti", low_res_key="lr_dti"
)
train_loader = torch.utils.data.DataLoader(
    train_queue,
    batch_size=batch_size,
    collate_fn=collate_fn,
    pin_memory=True,
)
# Test samplers must be dynamically created during testing.

print("Test subject(s) IDs: ", [s.subj_id for s in test_dataset.dry_iter()])
print("Training subject(s) IDs: ", [s.subj_id for s in train_dataset.dry_iter()])

### Model Definition

In [None]:
# Define pytorch-lightning module.


def _pixel_shuffle_3d(x: torch.Tensor, upscale_factor: int) -> torch.Tensor:
    """Rearrange (B, C * r**3, X, Y, Z) into (B, C, X * r, Y * r, Z * r).

    3D analogue of `torch.nn.PixelShuffle`: each group of r**3 consecutive channels
    fills one r x r x r block of output voxels.

    Raises:
        ValueError: If the channel count is not divisible by r**3.
    """
    r = int(upscale_factor)
    batch, packed_channels, size_x, size_y, size_z = x.shape
    if r < 1 or packed_channels % (r ** 3) != 0:
        raise ValueError(
            f"Cannot pixel-shuffle {packed_channels} channels with upscale factor {r}."
        )
    out_channels = packed_channels // r ** 3
    x = x.reshape(batch, out_channels, r, r, r, size_x, size_y, size_z)
    # Interleave every sub-voxel offset with its spatial axis: (B, C, X, r, Y, r, Z, r).
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
    return x.reshape(batch, out_channels, size_x * r, size_y * r, size_z * r)


class DIQTSystem(pl.LightningModule):
    """Lightning wrapper around `pitn.models.ESPCN_RN` with manual optimization.

    The network maps a low-res DTI patch to `channels_out` packed channels at the
    input resolution. Those are pixel-shuffled to the full-res target before the
    RMSE loss is computed. Optimization uses Adam.

    Args:
        channels_in: Number of input channels.
        channels_out: Number of (unshuffled) network output channels.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out
        # Parameters
        # Network parameters
        self.num_revnet_layers = 4
        self.net = pitn.models.ESPCN_RN(
            no_RevNet_layers=self.num_revnet_layers,
            no_chans_in=self.channels_in,
            no_chans_out=self.channels_out,
            memory_efficient=True,
        )

        ## Training parameters
        # NOTE(review): 10e-4 == 1e-3; confirm this (rather than 1e-4) is intended.
        self.lr = 10e-4
        mse_loss = torch.nn.MSELoss(reduction="mean")
        # Root-mean-squared error.
        self.loss_fn = lambda y_hat, y: torch.sqrt(mse_loss(y_hat, y))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # Batches are `MultiresSample(full_res, low_res)`: the low-res patch is the
        # network input and the full-res patch is the target.
        full_res, low_res = batch
        x, y = low_res, full_res
        y_pred_shuffled = self.net(x)
        if y_pred_shuffled.shape != y.shape:
            # Packed channels at the input resolution -> full-res target layout.
            upscale_factor = y.shape[-1] // y_pred_shuffled.shape[-1]
            y_pred = _pixel_shuffle_3d(y_pred_shuffled, upscale_factor)
        else:
            y_pred = y_pred_shuffled
        loss = self.loss_fn(y_pred, y)

        # Perform manual backprop.
        # NOTE(review): `self.net.backward` is the project model's own backward pass.
        # It receives the raw network output and that tensor's `.grad`, which is only
        # populated when that output is a leaf or retains its grad. Confirm this
        # against `pitn.models.ESPCN_RN`.
        loss.backward()
        self.net.backward(y_pred_shuffled, y_pred_shuffled.grad)
        opt.step()
        opt.zero_grad()
        self.log("train_loss", loss.detach())

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        return optimizer

### Training

In [None]:
# Training parameters: epoch budget for a full training run.
max_epochs: int = 100

In [None]:
model = DIQTSystem(channels_in=channels, channels_out=unshuffled_channels_out)
# Create trainer object. Note: `automatic_optimization` needs to be set to `False` when
# manually performing backprop. See
# <https://colab.research.google.com/drive/1nGtvBFirIvtNQdppe2xBes6aJnZMjvl8?usp=sharing>
# `max_epochs=1` is a quick dev/debug run; pass `max_epochs=max_epochs` for full training.
trainer = pl.Trainer(
    # Requesting a GPU on a CPU-only machine raises, so only ask for one when present.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=1,
    automatic_optimization=False,
    progress_bar_refresh_rate=10,
)
trainer.fit(model, train_loader)

## Model Evaluation

### Testing

### Visualization