# Pain in the Net
Replication of *Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images*


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Source work:
`S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


In [1]:
# imports
import math
import random
import itertools
import collections
import functools
import pathlib
from pathlib import Path
import warnings
import os
import subprocess
import io

import dotenv
# Computation & ML libraries.
import numpy as np
import pandas as pd
import skimage
import skimage.measure, skimage.feature, skimage.filters
import torch
import torch.nn.functional as F
import torchvision
import pytorch_lightning
import nilearn
import nilearn.plotting
# Data management libraries.
import nibabel as nib
import torchio
import natsort
from natsort import natsorted

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
# Global figure defaults: automatic tight layout and an opaque white
# (RGBA) figure background.
plt.rcParams.update({
    'figure.autolayout': True,
    'figure.facecolor': [1.0, 1.0, 1.0, 1.0],
})

  warn("Fetchers from the nilearn.datasets module will be "


In [2]:
# Watermark
# Record provenance for this run: author, last-updated timestamp (ISO 8601),
# Python and machine details, versions of the imported packages, and the
# current git hash.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash 
# Also print the value of `torch.version.cuda`.
print("CUDA Version: ", torch.version.cuda)

Author: Tyler Spears

Last updated: 2021-03-19T11:04:08.389054-04:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.21.0

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.8.0-3-amd64
Machine     : x86_64
Processor   : 
CPU cores   : 8
Architecture: 64bit

Git hash: e85b592aab826d95c03905d7f7b5f67ad1e3fb0a

torchvision      : 0.2.2
torch            : 1.8.0
skimage          : 0.18.1
pytorch_lightning: 1.2.3
natsort          : 7.1.1
torchio          : 0.18.29
nilearn          : 0.7.1
nibabel          : 3.2.1
pandas           : 1.2.3
numpy            : 1.20.1
seaborn          : 0.11.1
matplotlib       : 3.3.4

CUDA Version:  11.1


In [3]:
# torch setup

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
# To force CPU execution, assign `torch.device('cpu')` to `device` instead.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv
# - warnings

# Form the command to be run in direnv's context. This command prints out all
# environment variables defined in the subprocess.
# The command is an argument list executed WITHOUT a shell, so a working
# directory containing spaces or shell metacharacters is passed to direnv
# verbatim instead of being re-parsed (or executed) by the shell.
direnv_dir = os.getcwd()
command = ["direnv", "exec", direnv_dir, "/usr/bin/env"]
try:
    # Run the command in a new subprocess, capturing only stdout.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=direnv_dir)
except FileNotFoundError:
    # direnv is not installed / not on PATH: keep going with the current
    # environment, but say so instead of failing silently.
    proc = None
    proc_out = ""
    warnings.warn(
        "direnv executable not found; no environment variables were loaded."
    )
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            "the loaded environment may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode('utf-8')
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file. Existing variables are overridden.
# The trailing semicolon suppresses the cell's return-value display.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [5]:
# Set up directories
# Every location is taken from an environment variable and must already exist
# on disk; the "hcp" sub-directory is used for both the read and write roots.
data_dir = pathlib.Path(os.environ['DATA_DIR']).joinpath("hcp")
assert data_dir.exists()
write_data_dir = pathlib.Path(os.environ['WRITE_DATA_DIR']).joinpath("hcp")
assert write_data_dir.exists()
results_dir = pathlib.Path(os.environ['RESULTS_DIR'])
assert results_dir.exists()

### Function & Class Definitions

### Global Parameters

In [6]:
# (lower, upper) bounds on the b-values of the DWI volumes to keep; both
# bounds are inclusive where the range is applied during data loading.
bvals_range = (900, 1100)

## Data Loading

In [7]:
# Find data directories for each subject.
subj_dirs: dict = {}

# Full list of candidate subject IDs.
selected_ids = [
    '397154', '224022', '140117', '751348',
    '894774', '156637', '227432', '303624',
    '185947', '810439', '753251', '644246',
    '141422', '135528', '103010', '700634',
]

## Sub-set the chosen participants for dev and debugging!
# NOTE(review): no random seed is set in the visible cells, so the subset
# presumably differs from run to run.
selected_ids = random.sample(selected_ids, 5)
warnings.warn(
    "WARNING: Sub-selecting participants for dev and debugging. "
    f"Subj IDs selected: {selected_ids}"
)
##

# From here on the IDs are integers, in sorted order.
selected_ids = natsorted([int(sid) for sid in selected_ids])

# Map each subject ID to its diffusion directory, which must exist.
for subj_id in selected_ids:
    subj_dirs[subj_id] = data_dir.joinpath(str(subj_id), "T1w", "Diffusion")
    assert subj_dirs[subj_id].exists()
subj_dirs



{185947: PosixPath('/srv/data/pitn/hcp/185947/T1w/Diffusion'),
 303624: PosixPath('/srv/data/pitn/hcp/303624/T1w/Diffusion'),
 397154: PosixPath('/srv/data/pitn/hcp/397154/T1w/Diffusion'),
 644246: PosixPath('/srv/data/pitn/hcp/644246/T1w/Diffusion'),
 894774: PosixPath('/srv/data/pitn/hcp/894774/T1w/Diffusion')}

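The 90 scans are taken from the $b = 1000\ \mathrm{s/mm^2}$ shell *only*. To find those, sub-select the volumes with $500 < \text{bvals} < 1500$, or roughly thereabout (the code uses the tighter, inclusive range `bvals_range = (900, 1100)` defined above). A b-value of $995$ or $1005$ still counts as $b = 1000$. **TODO:** why do the recorded b-values deviate from the nominal value — is this the direction encoding? Or maybe the difference is where the gradient comes in?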

In [8]:
# Import all image data into a sequence of `torchio.Subject` objects.
subj_data: dict = {}

# Named container for what a custom reader function hands back to the
# `torchio.Image` object: the image tensor and its affine.
ReaderOutput = collections.namedtuple('ReaderOutput', ['dwi', 'affine'])


def selection_reader(f_dwi, idx):
    """Load an image with nibabel and keep only part of its last axis.

    Args:
        f_dwi: Path of the image file, passed straight to `nib.load`.
        idx: Index applied to the last axis of the loaded array
            (i.e. `array[..., idx]`).

    Returns:
        ReaderOutput: `dwi` is the indexed image data and `affine` is a copy
        of the image affine, both as float32 torch tensors.
    """
    nifti = nib.load(f_dwi)

    # The whole image is read into memory first and only then indexed; see
    # <https://nipy.org/nibabel/images_and_memory.html#saving-time-and-memory>
    # The original author notes that fancy indexing (anything besides a
    # single integer index) isn't supported on the lazy image proxy.
    selected = nifti.get_fdata()[..., idx]
    affine_copy = nifti.affine.copy()
    # Drop the image object in the hope of releasing the full-size array.
    del nifti

    # Both outputs are stored as float32.
    dwi_tensor = torch.from_numpy(selected.astype(np.float32))
    affine_tensor = torch.from_numpy(affine_copy.astype(np.float32))
    return ReaderOutput(dwi=dwi_tensor, affine=affine_tensor)

# Build one `torchio.Subject` per subject directory and store it in `subj_data`.
for subj_id, subj_dir in subj_dirs.items():
    # Sub-select volumes whose b-value lies inside the inclusive `bvals_range`
    # (e.g. bvals ~= 1000 s/mm^2).
    bvals = torch.as_tensor(np.loadtxt(subj_dir/"bvals").astype(int))
    keep_mask = (bvals >= bvals_range[0]) & (bvals <= bvals_range[1])
    if not bool(keep_mask.any()):
        raise ValueError(
            f"Subject {subj_id}: no volumes with b-values in {bvals_range}"
        )
    # Single-argument `np.where` returns a *tuple* holding one index array.
    # The reader receives that tuple unchanged and applies it as `[..., idx]`.
    # NOTE(review): indexing with the tuple inserts a singleton axis before the
    # selected-volume axis; torchio presumably relies on that layout — confirm
    # before simplifying this index.
    idx_to_keep = np.where(keep_mask)
    # Index the gradient tables with the boolean mask itself so that `bvals`
    # has one entry per kept volume and `bvecs` keeps one column per kept
    # volume. (Indexing `bvecs` with the `np.where` tuple added a stray
    # singleton axis.)
    bvals = bvals[keep_mask]
    bvecs = torch.as_tensor(np.loadtxt(subj_dir/"bvecs"))[:, keep_mask]
    # Create a custom reader function that will sub-set the full (~288 for HCP) data
    # points down to only the selected bvals range.
    # This will still be lazily loaded, while not requiring as much memory per subject.
    partial_reader = functools.partial(selection_reader, idx=idx_to_keep)
    dwi = torchio.ScalarImage(
        subj_dir/"data.nii.gz",
        bvals=bvals,
        bvecs=bvecs,
        reader=partial_reader
    )
    grad = torchio.ScalarImage(subj_dir/"grad_dev.nii.gz")
    brain_mask = torchio.LabelMap(subj_dir/"nodif_brain_mask.nii.gz", type=torchio.LABEL)
    # The brain mask is binary, so store it as a boolean tensor.
    brain_mask.set_data(brain_mask.data.bool())
    subject_dict = torchio.Subject(
        subj_id=subj_id,
        dwi=dwi,
        grad=grad,
        brain_mask=brain_mask
    )
    subj_data[subj_id] = subject_dict


In [9]:
# Collect every loaded subject (in the order they were added to `subj_data`)
# into a single `torchio.SubjectsDataset`.
subj_dataset = torchio.SubjectsDataset(list(subj_data.values()))

## Model Training

### Model Definition

### Training

## Model Evaluation

### Testing

### Visualization