# Pain in the Net
Replication of *Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images*


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Source work:
`S. B. Blumberg, R. Tanno, I. Kokkinos, and D. C. Alexander, “Deeper Image Quality Transfer: Training Low-Memory Neural Networks for 3D Images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2018, Cham, 2018, pp. 118–125, doi: 10.1007/978-3-030-00928-1_14.`


## Imports & Environment Setup

In [79]:
# imports
import math
import random
import itertools
import collections
import functools
import pathlib
from pathlib import Path
import warnings
import sys
import os
import subprocess
import io

import dotenv
# Computation & ML libraries.
import numpy as np
import pandas as pd
import skimage
import skimage.measure, skimage.feature, skimage.filters
import torch
import torch.nn.functional as F
import torchvision
import pytorch_lightning as pl
import nilearn
import nilearn.plotting
import dipy
import dipy.reconst
import dipy.reconst.dti
# Data management libraries.
import nibabel as nib
import torchio
import natsort
from natsort import natsorted

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
# Figure defaults: automatic tight layout and an opaque white figure background.
plt.rcParams.update({
    'figure.autolayout': True,
    'figure.facecolor': [1.0, 1.0, 1.0, 1.0],
})

# Print ndarrays/tensors in plain decimal notation instead of scientific notation.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

In [2]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - dotenv
# - warnings

# Form command to be run in direnv's context. This command will print out
# all environment variables defined in the subprocess/sub-shell.
# The command is an argument list (no shell), so a working directory containing
# spaces or shell metacharacters is passed to direnv verbatim.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    # Run command in a new subprocess.
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # Stay best-effort: without direnv the notebook keeps its current environment.
    warnings.warn("direnv executable not found; environment variables were not updated.")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(
            f"'direnv exec' exited with status {proc.returncode}; "
            + "the loaded environment may be incomplete."
        )
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode('utf-8')
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [3]:
# Project-specific scripts
# The project sources are imported by extending `sys.path` rather than by installing
# a package, to avoid conflicts with local packages and anaconda installations.
# The source root is $PROJECT_ROOT when that variable is set, else two directories up.
src_location = str(Path(os.environ.get('PROJECT_ROOT', "../../")).resolve())
sys.path.append(src_location)
import src as pitn

In [4]:
# torch setup

# Use CUDA when it is available, otherwise fall back to the CPU.
# To force CPU execution, assign `torch.device('cpu')` to `device` instead.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [76]:
# Watermark
# Record the execution environment (author, timestamp, Python/machine info, versions
# of the imported packages, and the git hash) alongside the notebook's outputs.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
# The CUDA version torch reports is printed separately.
print("CUDA Version: ", torch.version.cuda)

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Author: Tyler Spears

Last updated: 2021-03-22T16:52:36.282472-04:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.21.0

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.8.0-3-amd64
Machine     : x86_64
Processor   : 
CPU cores   : 8
Architecture: 64bit

Git hash: 3163f7b31449d9cd01e861f7037463dc8f9a77c4

matplotlib       : 3.3.4
dipy             : 1.3.0
pandas           : 1.2.3
torchvision      : 0.2.2
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
pytorch_lightning: 1.2.3
nibabel          : 3.2.1
seaborn          : 0.11.1
torchio          : 0.18.29
torch            : 1.8.0
nilearn          : 0.7.1
numpy            : 1.20.1
natsort          : 7.1.1
skimage          : 0.18.1

CUDA Version:  11.1


## Variables & Definitions Setup

In [6]:
# Set up directories
# Each location is read from an environment variable and must already exist on disk;
# the two data locations use an "hcp" sub-directory.
data_dir = Path(os.environ['DATA_DIR'], "hcp")
assert data_dir.exists()
write_data_dir = Path(os.environ['WRITE_DATA_DIR'], "hcp")
assert write_data_dir.exists()
results_dir = Path(os.environ['RESULTS_DIR'])
assert results_dir.exists()

### Global Function & Class Definitions

### Global Parameters

In [7]:
# Inclusive (min, max) window of b-values; the data-loading cell keeps only the
# DWI volumes whose b-value satisfies `bvals_range[0] <= bval <= bvals_range[1]`.
bvals_range = (900, 1100)

## Data Loading

In [8]:
# Find data directories for each subject.
subj_dirs: dict = dict()


selected_ids = [
    '397154',
    '224022',
    '140117',
    '751348',
    '894774',
    '156637',
    '227432',
    '303624',
    '185947',
    '810439',
    '753251',
    '644246',
    '141422',
    '135528',
    '103010',
    '700634',
]

## Sub-set the chosen participants for dev and debugging!
# A dedicated, seeded generator keeps the dev subset identical across kernel restarts
# without touching the state of the global `random` module.
num_dev_subjs = 6
dev_subset_seed = 0
selected_ids = random.Random(dev_subset_seed).sample(
    selected_ids, min(num_dev_subjs, len(selected_ids))
)
warnings.warn(
    "WARNING: Sub-selecting participants for dev and debugging. "
    + f"Subj IDs selected: {selected_ids}"
)
##

# Subject IDs are used as integer keys, in ascending order.
selected_ids = sorted(int(s) for s in selected_ids)

for subj_id in selected_ids:
    subj_dir = data_dir / f"{subj_id}/T1w/Diffusion"
    assert subj_dir.exists(), f"Missing diffusion directory for subject {subj_id}: {subj_dir}"
    subj_dirs[subj_id] = subj_dir
subj_dirs



{156637: PosixPath('/srv/data/pitn/hcp/156637/T1w/Diffusion'),
 397154: PosixPath('/srv/data/pitn/hcp/397154/T1w/Diffusion'),
 644246: PosixPath('/srv/data/pitn/hcp/644246/T1w/Diffusion'),
 751348: PosixPath('/srv/data/pitn/hcp/751348/T1w/Diffusion'),
 894774: PosixPath('/srv/data/pitn/hcp/894774/T1w/Diffusion')}

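The 90 scans are taken from the $b = 1000\ \mathrm{s/mm^2}$ shell *only*. To find them, sub-select the volumes whose b-values fall in a window around 1000, roughly $500 < \text{bvals} < 1500$ (the code below uses the tighter window set in `bvals_range`). A recorded b-value of $995$ or $1005$ still counts as b=1000. (Open question: is this deviation due to the direction encoding, or is it where the gradient comes in?)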

In [9]:
# Import all image data into `torchio.Subject` objects, keyed by subject ID.
subj_data: dict = {}

# Named return value for the reader function handed to the `torchio.Image` object:
# the image tensor (`dwi`) and its affine matrix (`affine`).
ReaderOutput = collections.namedtuple('ReaderOutput', 'dwi affine')
def selection_reader(f_dwi, idx):
    """Read a DWI image file and keep only the volumes selected by `idx`.

    Args:
        f_dwi: Path of the image file; passed to `nib.load`.
        idx: Index applied along the last axis of the image array. In this
            notebook the caller passes the tuple returned by `np.where`.
            NOTE(review): with that tuple the result presumably carries a
            singleton axis before the volume axis -- confirm the downstream
            image loader expects this layout before changing the indexing.

    Returns:
        ReaderOutput: `dwi`, the selected volumes as a float32 tensor, and
        `affine`, a float32 tensor copy of the image's affine matrix.
    """
    img = nib.load(f_dwi)

    # Load entire image into memory, then slice that full image.
    # <https://nipy.org/nibabel/images_and_memory.html#saving-time-and-memory>
    # Fancy indexing isn't supported on the lazy array proxy, aka, anything besides
    # a single integer index.
    # Requesting float32 directly avoids materialising the whole series as float64
    # (the `get_fdata` default) before casting it down.
    full_img = img.get_fdata(dtype=np.float32)
    sliced_img = full_img[..., idx]
    affine = img.affine.copy()
    # Hopefully remove any refs to the full-size array.
    del full_img, img
    return ReaderOutput(
        # `copy=False`: the array is already float32, so no second copy is made.
        dwi=torch.from_numpy(sliced_img.astype(np.float32, copy=False)),
        affine=torch.from_numpy(affine.astype(np.float32))
    )

for subj_id, subj_dir in subj_dirs.items():
    # Sub-select volumes with only bvals in a certain range (e.g. bvals ~= 1000 s/mm^2).
    all_bvals = np.loadtxt(subj_dir/"bvals").astype(int)
    # `np.where` returns a 1-tuple of index arrays. The tuple itself is what the
    # reader below receives; its single array is used to index bvals/bvecs.
    idx_to_keep = np.where((all_bvals >= bvals_range[0]) & (all_bvals <= bvals_range[1]))
    vol_idx = idx_to_keep[0]
    bvals = torch.as_tensor(all_bvals[vol_idx])
    # Index with the plain array: nesting the `np.where` tuple inside the index
    # (`[:, idx_to_keep]`) would insert a spurious singleton axis into bvecs.
    bvecs = torch.as_tensor(np.loadtxt(subj_dir/"bvecs")[:, vol_idx])
    # Create a custom reader function that will sub-set the full (~288 for HCP) data
    # points down to only the selected bvals range.
    # This will still be lazily loaded, while not requiring as much memory per subject.
    partial_reader = functools.partial(selection_reader, idx=idx_to_keep)
    dwi = torchio.ScalarImage(
        subj_dir/"data.nii.gz",
        bvals=bvals,
        bvecs=bvecs,
        reader=partial_reader
    )
    grad = torchio.ScalarImage(subj_dir/"grad_dev.nii.gz")
    brain_mask = torchio.LabelMap(subj_dir/"nodif_brain_mask.nii.gz", type=torchio.LABEL)
    # The brain mask is binary, so store it as booleans.
    brain_mask.set_data(brain_mask.data.bool())
    subject_dict = torchio.Subject(
        subj_id=subj_id,
        dwi=dwi,
        grad=grad,
        brain_mask=brain_mask
    )
    subj_data[subj_id] = subject_dict


In [44]:
# Wrap every loaded subject in a single dataset; `load_getitem=False` is passed so
# that indexing the dataset does not itself trigger image loading.
subj_dataset = torchio.SubjectsDataset([*subj_data.values()], load_getitem=False)

## Model Training

In [73]:
# Patch parameters
downsample_factor = 2
batch_size = 12
# One channel per DTI component (6 in total).
channels = 6

# Input patch: cubic, 11 voxels per side.
h_in = w_in = d_in = 11
input_patch_shape = (channels, h_in, w_in, d_in)

# Output patch: every spatial side is scaled by `downsample_factor`.
h_out, w_out, d_out = (downsample_factor * side for side in (h_in, w_in, d_in))
# Before shuffling, the extra resolution lives in the channel axis
# (factor**3 times as many channels) at the input's spatial size.
unshuffled_channels_out = channels * downsample_factor ** 3
unshuffled_output_patch_shape = (unshuffled_channels_out,) + input_patch_shape[1:]
# After shuffling: original channel count at the enlarged spatial size.
output_patch_shape = (channels, h_out, w_out, d_out)

### Model Definition

In [35]:
# Define pytorch-lightning module.
class DIQTSystem(pl.LightningModule):
    """Lightning wrapper around the project's `ESPCN_RN` network.

    Holds the network, the loss and the training hyper-parameters. The
    training/validation/test steps are placeholders that do nothing yet.
    """

    def __init__(self, channels_in, channels_out):
        """Build the network for the given input/output channel counts.

        Args:
            channels_in: Number of channels of the network input.
            channels_out: Number of channels the network produces.
        """
        super().__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out

        # Network hyper-parameters.
        self.num_revnet_layers = 4
        # Training hyper-parameters.
        self.max_epochs = 100
        self.lr = 10e-4  # i.e. 1e-3

        self.net = pitn.models.ESPCN_RN(
            no_RevNet_layers=self.num_revnet_layers,
            no_chans_in=self.channels_in,
            no_chans_out=self.channels_out,
            memory_efficient=True,
        )
        self.loss = torch.nn.MSELoss(reduction='mean')

    def forward(self, x):
        """Run `x` through the wrapped network and return its output."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Placeholder: not implemented yet, returns None."""
        return None

    def validation_step(self, batch, batch_idx):
        """Placeholder: not implemented yet, returns None."""
        return None

    def test_step(self, batch, batch_idx):
        """Placeholder: not implemented yet, returns None."""
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer over the network's parameters, using `self.lr`."""
        return torch.optim.Adam(self.net.parameters(), lr=self.lr)

### Set Up Patch-Based Data Loaders

In [74]:
# Data train/validation/test split
test_percent = 0.2
train_percent = 1 - test_percent
val_percent = 0.1

num_subjs = len(subj_dataset)
num_test_subjs = int(np.ceil(num_subjs * test_percent))
num_train_subjs = num_subjs - num_test_subjs
# Shuffle a copy: `dry_iter()` may hand back the dataset's own subject list, and an
# in-place shuffle would then silently reorder `subj_dataset` as well.
subj_list = list(subj_dataset.dry_iter())
# Randomly shuffle the list of subjects, then choose the first `num_test_subjs` subjects
# for testing. A dedicated, seeded generator keeps the held-out subjects the same
# across kernel restarts.
split_seed = 0
random.Random(split_seed).shuffle(subj_list)
test_dataset = torchio.SubjectsDataset(subj_list[:num_test_subjs], load_getitem=False)
# Choose the remaining for training/validation.
subj_list = subj_list[num_test_subjs:]
train_dataset = torchio.SubjectsDataset(subj_list, load_getitem=False)

# Training patch sampler, random across all patches of all volumes; patch centers
# are drawn only from voxels labelled 1 in the brain mask.
train_sampler = torchio.LabelSampler((h_in, w_in, d_in), 'brain_mask', {0: 0, 1:1})
# A torchio patch sampler is not a `torch.utils.data.Sampler` (it does not yield
# dataset indices), so it cannot be given to `DataLoader(sampler=...)`. Patches are
# drawn through a `torchio.Queue`, which is then batched by a plain DataLoader.
samples_per_volume = 8 * batch_size
train_queue = torchio.Queue(
    train_dataset,
    max_length=samples_per_volume * max(num_train_subjs, 1),
    samples_per_volume=samples_per_volume,
    sampler=train_sampler,
    num_workers=4,
)
train_loader = torch.utils.data.DataLoader(
    train_queue,
    batch_size=batch_size,
    # Subject loading is parallelised by the queue's own workers; the loader that
    # wraps a queue must not spawn additional ones.
    num_workers=0,
    pin_memory=True
)
# Test samplers must be dynamically created during testing.

print("Test subject(s) IDs: ", [s.subj_id for s in test_dataset.dry_iter()])
print("Training subject(s) IDs: ", [s.subj_id for s in train_dataset.dry_iter()])

Test subject(s) IDs:  [644246]
Training subject(s) IDs:  [894774, 751348, 397154, 156637]


### Training

In [None]:
# Build the model: 6 DTI channels in, pre-shuffle channel count out.
model = DIQTSystem(channels_in=channels, channels_out=unshuffled_channels_out)
trainer = pl.Trainer(
    # Only request a GPU when one is actually present, matching the CPU fallback
    # used when `device` was selected.
    gpus=1 if torch.cuda.is_available() else None,
    # Honour the epoch budget declared on the module instead of the Trainer default.
    max_epochs=model.max_epochs,
    progress_bar_refresh_rate=10,
)
trainer.fit(model, train_loader)

## Model Evaluation

### Testing

### Visualization