# Goal

Nyul-Udupa histogram rescaling

1. Compute the landmarks of every image.
2. Standard scale = average of the landmarks (total sum divided by the total number of inputs).

# Imports

In [44]:
# NYU cluster locations of the code checkout and of the project data
code_src, data_src = (
    "/gpfs/home/gologr01",
    "/gpfs/data/oermannlab/private_data/DeepPit",
)

In [45]:
# UMich 
# code src: "/home/labcomputer/Desktop/Rachel"
# data src: "../../../../..//media/labcomputer/e33f6fe0-5ede-4be4-b1f2-5168b7903c7a/home/rachel/"

In [46]:
import os

# Locations of (1) code, (2) data, (3) saved models, (4) saved dataset metadata
deepPit_src = f"{code_src}/DeepPit"
obelisk_src = f"{code_src}/OBELISK"

model_src  = f"{data_src}/saved_models"         # saved models
dsetmd_src = f"{data_src}/saved_dset_metadata"  # saved dataset metadata

dsets_src = f"{data_src}/PitMRdata"             # root folder holding every dataset

# dataset name -> top-level directory of that dataset (relative part joined onto dsets_src).
# NOTE: the "ICMB" key matches the on-disk parent folder and the "{name}_fnames.txt"
# metadata files, even though the inner folder is spelled "ICBM".
dset_dict = {
    name: f"{dsets_src}/{rel_dir}"
    for name, rel_dir in (
        ("ABIDE",                   "ABIDE"),
        ("ABVIB",                   "ABVIB/ABVIB"),
        ("ADNI1_Complete_1Yr_1.5T", "ADNI/ADNI1_Complete_1Yr_1.5T/ADNI"),
        ("AIBL",                    "AIBL/AIBL"),
        ("ICMB",                    "ICMB/ICBM"),
        ("PPMI",                    "PPMI/PPMI"),
    )
}

# Show what is actually present under the dataset root
print("Folders in dset src: ", end="")
print(*os.listdir(dsets_src), sep=", ")

Folders in dset src: ICMB, ABVIB (1).zip, central.xnat.org, ADNI, PPMI, Oasis_long, samir_labels, ACRIN-FMISO-Brain, LGG-1p19qDeletion, REMBRANDT, AIBL, CPTAC-GBM, TCGA-GBM, TCGA-LGG, ABVIB, ABIDE, AIBL.zip


In [47]:
from fastai.vision.core import *

In [48]:
# imports
from transforms import AddChannel, Iso, PadSz

# Utilities
import os
import sys
import time
import pickle
from pathlib import Path

# regex
from re import search

# Input IO
import SimpleITK as sitk
import meshio

# Numpy and Pandas
import numpy as np
import pandas as pd
from pandas import DataFrame as DF

# Fastai + distributed training
from fastai import *
from fastai.torch_basics import *
from fastai.basics import *
from fastai.distributed import *

# PyTorch
from torchvision.models.video import r3d_18
from fastai.callback.all import SaveModelCallback
from torch import nn

# Obelisk
sys.path.append(deepPit_src)
sys.path.append(obelisk_src)

# OBELISK
from utils import *
from models import obelisk_visceral, obeliskhybrid_visceral

# 3D extension to FastAI
# from faimed3d.all import *

# Helper functions
from helpers.preprocess import get_data_dict, paths2objs, folder2objs, seg2mask, mask2bbox, print_bbox, get_bbox_size, print_bbox_size
from helpers.general import sitk2np, np2sitk, print_sitk_info, round_tuple, lrange, lmap, get_roi_range, numbers2groups
from helpers.viz import viz_axis

In [49]:
# Gather the scan-folder paths of every dataset used here into one flat list.
# Each "<dset>_fnames.txt" is actually a pickle (opened in binary mode), presumably
# holding a list of folder paths -- TODO confirm.
# NOTE(security): pickle.load can execute arbitrary code; only load metadata you generated.
fnames = []
for dset_name in ("AIBL", "ABVIB", "ICMB", "PPMI"):
    with open(f"{dsetmd_src}/{dset_name}_fnames.txt", "rb") as f:
        fnames.extend(pickle.load(f))

In [50]:
print(f"{len(fnames)}")  # total number of scan folders across the selected datasets

3411


In [51]:
# Filled by is_corrected: N4-corrected volume paths, folders lacking one, folders with several
corrected, uncorrected, multiple = [], [], []

def is_corrected(f):
    """
    Classify one scan folder by how many ``*corrected_n4.nii`` files it contains.

    Side effects on the module-level lists ``corrected``, ``uncorrected``, ``multiple``:
      * exactly one match -> that file's path is appended to ``corrected``
      * no match          -> the folder ``f`` is appended to ``uncorrected``
      * several matches   -> the folder ``f`` is appended to ``multiple``
                             (nothing is added to ``corrected``)

    :param f: path to a scan folder.
    :return: True if at least one matching file exists, else False.
    """
    import glob

    # Escape the folder part so glob metacharacters in directory names ([, ], *, ?)
    # are matched literally; only the file-name part is a pattern.
    nii_paths = glob.glob(f"{glob.escape(str(f))}/*corrected_n4.nii")

    if len(nii_paths) == 1:
        corrected.append(nii_paths[0])
        return True

    if not nii_paths:
        uncorrected.append(f)
        return False

    multiple.append(f)
    return True
                
# Classify every scan folder; is_corrected fills corrected / uncorrected / multiple.
for folder in fnames:
    is_corrected(folder)

print(f"Corrected: {len(corrected)}, TODO: {len(uncorrected)}, Dupl: {len(multiple)}")

Corrected: 3284, TODO: 7, Dupl: 120


In [52]:
(fnames[0], corrected[0])  # first scan folder vs. first corrected volume (need not be the same scan)

('/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/338/MPRAGE_SAG_ISO_p2_ND/2012-10-06_11_32_11.0/S231118',
 '/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/326/MPRAGE_ADNI_confirmed/2013-09-15_08_42_36.0/S236430/corrected_n4.nii')

# Get chunk

In [53]:
import os

# SLURM array jobs set SLURM_ARRAY_TASK_ID; interactive runs fall back to chunk 0.
try:
    taskid = int(os.getenv('SLURM_ARRAY_TASK_ID'))
except (TypeError, ValueError):  # unset (getenv -> None) or not an integer
    taskid = 0

n_total = len(corrected)

# Split the corrected volumes into consecutive index ranges of at most chunk_len each.
chunk_len = 50
chunks    = [range(i, min(i + chunk_len, n_total)) for i in range(0, n_total, chunk_len)]

print(f"N_chunks = {len(chunks)}")

# Fail loudly on a mis-sized array job instead of wrapping around on negative ids.
if not 0 <= taskid < len(chunks):
    raise IndexError(f"SLURM_ARRAY_TASK_ID={taskid} is out of range for {len(chunks)} chunks")
task_chunk = chunks[taskid]

N_chunks = 66


# Transform

## from FAIMED3D 02_preprocessing

In [18]:
# from FAIMED3D 02_preprocessing

Piecewise linear histogram matching
[1] L. G. Nyúl and J. K. Udupa, “On Standardizing the MR Image Intensity Scale,” Magn. Reson. Med., vol. 42, pp. 1072–1081, 1999.

[2] M. Shah, Y. Xiao, N. Subbanna, S. Francis, D. L. Arnold, D. L. Collins, and T. Arbel, “Evaluating intensity normalization on MRIs of human brain with multiple sclerosis,” Med. Image Anal., vol. 15, no. 2, pp. 267–282, 2011.

Implementation adapted from https://github.com/jcreinhold/intensity-normalization and ported to PyTorch (it avoids NumPy, so it also works on CUDA tensors).

In contrast to `hist_scaled`, piecewise linear histogram matching needs pre-specified values for the new (standard) scale and for the landmarks, so it is meant to normalize a whole dataset consistently.

In [29]:
from torch import Tensor

In [30]:
def get_percentile(t, q):
    """
    Return the ``q``-th percentile of the flattened input tensor's data.

    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., interpolation="nearest")``.

    :param t: Input tensor of any shape; it does not need to be contiguous.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value (float).
    :raises ValueError: if ``t`` is empty or ``q`` is outside [0, 100].

    This function is twice as fast as torch.quantile and has no size limitations
    """
    # float(q) (not q directly) so that round() returns an int even for np.float32 input.
    q = float(q)
    if not 0.0 <= q <= 100.0:
        raise ValueError(f"q must be between 0 and 100, got {q}")

    n = t.numel()
    if n == 0:
        raise ValueError("cannot compute a percentile of an empty tensor")

    # kthvalue() is one-based: the smallest value is k=1, not k=0.
    k = 1 + round(.01 * q * (n - 1))
    # reshape (not view): e.g. path2tensor returns a transposed, non-contiguous tensor.
    return t.reshape(-1).kthvalue(k)[0].item()

In [31]:
def get_landmarks(t: torch.Tensor, percentiles: torch.Tensor)->torch.Tensor:
    """
    Returns the input's landmarks, i.e. its value at each requested percentile.

    :param t (torch.Tensor): Input tensor.
    :param percentiles (torch.Tensor): Percentiles to calculate landmarks for
        (a 1-D tensor or any iterable of numbers in [0, 100]).
    :return: Resulting landmarks (float32 torch.Tensor), one per percentile.
    """
    # float(p) works for 0-d tensor elements and plain Python numbers alike.
    return torch.tensor([get_percentile(t, float(p)) for p in percentiles], dtype=torch.float32)

In [38]:
def find_one_landmark(input_image, i_min=1, i_max=99, i_s_min=1, i_s_max=100, l_percentile=10, u_percentile=90, step=10):
    """
    Compute one image's Nyul landmarks, mapped onto the standard scale.

    The foreground is approximated as the voxels brighter than the image mean.
    Its intensities at the landmark percentiles are linearly mapped so that the
    ``i_min``-th percentile lands on ``i_s_min`` and the ``i_max``-th percentile
    lands on ``i_s_max`` (the ``interp1d([min_p, max_p], [i_s_min, i_s_max])`` step
    of intensity-normalization's Nyul training). Averaging these over a dataset
    gives the standard scale.

    Args:
        input_image (torch.Tensor): image volume (any shape)
        i_min (float): minimum percentile to consider in the image
        i_max (float): maximum percentile to consider in the image
        i_s_min (float): minimum value on the standard scale
        i_s_max (float): maximum value on the standard scale
        l_percentile (int): middle percentile lower bound (e.g., for deciles 10)
        u_percentile (int): middle percentile upper bound (e.g., for deciles 90)
        step (int): step for middle percentiles (e.g., for deciles 10)
    Returns:
        torch.Tensor: 1-D float tensor of mapped landmarks, one per percentile in
        ``[i_min, l_percentile, l_percentile+step, ... (<= u_percentile), i_max]``
    Raises:
        ValueError: if the ``i_min``-th and ``i_max``-th foreground percentiles are
            equal (flat foreground), which makes the linear map undefined.
    """
    percs = torch.cat([torch.tensor([i_min]),
                       torch.arange(l_percentile, u_percentile + 1, step),
                       torch.tensor([i_max])], dim=0)

    # Crude foreground mask: keep voxels brighter than the global mean.
    masked = input_image[input_image > input_image.mean()]

    landmarks = get_landmarks(masked, percs)
    min_p = get_percentile(masked, i_min)
    max_p = get_percentile(masked, i_max)
    if max_p == min_p:
        raise ValueError(f"degenerate intensity range: percentiles {i_min} and {i_max} are both {min_p}")

    # Two-point linear map [min_p, max_p] -> [i_s_min, i_s_max], written out explicitly
    # so it does not depend on a Tensor.interp_1d monkey-patch (plain torch has none).
    scale = (i_s_max - i_s_min) / (max_p - min_p)
    return (i_s_min + (landmarks - min_p) * scale).float()

In [33]:
def find_sum_landmarks(inputs, i_min=1, i_max=99, i_s_min=1, i_s_max=100, l_percentile=10, u_percentile=90, step=10):
    """
    Sum the standard-scale landmarks of a set of images.

    Per image, the foreground (voxels brighter than the image mean) is reduced to its
    values at the landmark percentiles. Those values are then linearly mapped so that
    the ``i_min``-th percentile lands on ``i_s_min`` and the ``i_max``-th lands on
    ``i_s_max`` (intensity-normalization's Nyul training step).

    NOTE: this returns the *sum*, not the average. Divide by the total number of images
    (possibly accumulated over several calls) to obtain the standard scale.

    Args:
        inputs (iterable of torch.Tensor): images to be normalized
        i_min (float): minimum percentile to consider in the images
        i_max (float): maximum percentile to consider in the images
        i_s_min (float): minimum value on the standard scale
        i_s_max (float): maximum value on the standard scale
        l_percentile (int): middle percentile lower bound (e.g., for deciles 10)
        u_percentile (int): middle percentile upper bound (e.g., for deciles 90)
        step (int): step for middle percentiles (e.g., for deciles 10)
    Returns:
        standard_scale (torch.Tensor): element-wise sum of the mapped landmarks
            (all zeros if ``inputs`` is empty)
        percs (torch.Tensor): array of all percentiles used
    Raises:
        ValueError: if an image's ``i_min``-th and ``i_max``-th foreground percentiles
            are equal, which makes the linear map undefined.
    """
    percs = torch.cat([torch.tensor([i_min]),
                       torch.arange(l_percentile, u_percentile + 1, step),
                       torch.tensor([i_max])], dim=0)
    standard_scale = torch.zeros(len(percs))
    span = i_s_max - i_s_min

    for input_image in inputs:
        # Crude foreground mask: keep voxels brighter than the global mean.
        masked = input_image[input_image > input_image.mean()]
        landmarks = get_landmarks(masked, percs)
        min_p = get_percentile(masked, i_min)
        max_p = get_percentile(masked, i_max)
        if max_p == min_p:
            raise ValueError(f"degenerate intensity range: percentiles {i_min} and {i_max} are both {min_p}")
        # Explicit two-point linear map [min_p, max_p] -> [i_s_min, i_s_max].
        standard_scale += i_s_min + (landmarks - min_p) * (span / (max_p - min_p))

    return standard_scale, percs

In [34]:
def path2tensor(mr_path):
    """Read an MR volume as float32 and return it as a torch tensor with axes 0 and 2 swapped.

    SimpleITK's GetArrayFromImage returns the index order reversed relative to the
    image (z, y, x for 3-D), so swapping axes 0 and 2 restores (x, y, z) for 3-D volumes.
    The result is a transposed view, i.e. not contiguous in memory.
    """
    volume = sitk.ReadImage(mr_path, sitk.sitkFloat32)
    voxels = torch.tensor(sitk.GetArrayFromImage(volume))
    return voxels.transpose(0, 2)

# Process

In [35]:
# Paths of the corrected volumes assigned to this array task
corrected_chunk = list(map(corrected.__getitem__, task_chunk))
print(len(corrected_chunk), *corrected_chunk, sep="\n")

5
/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/326/MPRAGE_ADNI_confirmed/2013-09-15_08_42_36.0/S236430/corrected_n4.nii
/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/1413/MPRAGE_ADNI_confirmed/2013-06-29_11_43_08.0/S236380/corrected_n4.nii
/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/1595/MPRAGE_ADNI_confirmed/2013-12-12_12_07_53.0/S235320/corrected_n4.nii
/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/1460/MPRAGE_ADNI_confirmed/2013-06-07_16_11_03.0/S232401/corrected_n4.nii
/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/773/MPRAGE_SAG_ISO_p2/2012-05-03_10_38_50.0/S236142/corrected_n4.nii


In [42]:
# Compute and save the Nyul landmarks of every volume in this task's chunk
# (landmark computation adapted from FAIMED3D 02_preprocessing).
# Per the plan at the top, these per-image landmarks are later summed and divided
# by the number of inputs to form the standard scale.
for count, mr_path in enumerate(corrected_chunk):
    try:
        print(count, mr_path, flush=True)

        input_tensor = path2tensor(mr_path)
        landmark = find_one_landmark(input_tensor)

        # Save the landmarks next to the source volume
        landmark_fn = Path(mr_path).parent / "landmark.pt"
        torch.save(landmark, landmark_fn)
    except Exception as e:
        # Best-effort batch job: keep going, but record why this volume was skipped.
        print("Skipped: ", mr_path, f"({type(e).__name__}: {e})", flush=True)

print("Done.")

0 /gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/326/MPRAGE_ADNI_confirmed/2013-09-15_08_42_36.0/S236430/corrected_n4.nii
Elapsed: 0.69 s
1 /gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/AIBL/AIBL/1413/MPRAGE_ADNI_confirmed/2013-06-29_11_43_08.0/S236380/corrected_n4.nii
Elapsed: 0.76 s
Done.
