In [1]:
import sys
sys.path.insert(0,'/root/bart-0.5.00/python/')


import logging
import multiprocessing
import pathlib
import random
import time
from collections import defaultdict

import numpy as np
import torch

import bart
from common import utils
from common.args import Args
from common.subsample import create_mask_for_mask_type
from common.utils import tensor_to_complex_np
from data import transforms
from data.mri_data import SliceData

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger so INFO-level records (e.g. the run-time messages
# emitted by main()) are actually printed.
logging.basicConfig(level=logging.INFO)


class DataTransform:
    """
    Callable transform that converts raw k-space to a tensor and sub-samples it.

    Only ``kspace`` and ``fname`` influence the result; ``target`` and ``attrs``
    are accepted but not used.
    """

    def __init__(self, mask_func):
        """
        Args:
            mask_func (common.subsample.MaskFunc): Callable that builds a sampling
                mask of the appropriate shape; it is handed to
                ``transforms.apply_mask`` on every call.
        """
        self.mask_func = mask_func

    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
                for multi-coil data or (rows, cols, 2) for single-coil data.
            target (numpy.array, optional): Target image (unused).
            attrs (dict): Acquisition information from the HDF5 object (unused).
            fname (str): File name; also the source of the mask seed.
            slice (int): Index of the slice within its volume.

        Returns:
            tuple: ``(masked_kspace, fname, slice)`` where ``masked_kspace`` is the
            sub-sampled k-space tensor produced by ``transforms.apply_mask``.
        """
        kspace_tensor = transforms.to_tensor(kspace)
        # The seed depends only on the file name (its character code points), so
        # it is the same for every slice of a file and across runs.
        per_file_seed = tuple(ord(char) for char in fname)
        masked_kspace, _mask = transforms.apply_mask(
            kspace_tensor, self.mask_func, per_file_seed
        )
        return masked_kspace, fname, slice


def create_data_loader(args):
    """
    Build the validation ``SliceData`` dataset described by ``args``.

    Despite its name, this returns the dataset itself, not a ``DataLoader``.

    Args:
        args: Configuration object; reads ``mask_type``, ``center_fractions``,
            ``accelerations``, ``data_path``, ``challenge`` and ``sample_rate``.

    Returns:
        SliceData: dataset rooted at ``<data_path>/<challenge>_val`` whose items
        are produced by ``DataTransform`` with the configured sub-sampling mask.
    """
    dev_mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations)
    # Join with pathlib so ``data_path`` works with or without a trailing
    # separator (plain concatenation turned "/hdd" into "/hddmulticoil_val").
    # Converted back to str so SliceData receives the same type as before.
    root = str(pathlib.Path(args.data_path) / f'{args.challenge}_val')
    return SliceData(
        root=root,
        transform=DataTransform(dev_mask),
        challenge=args.challenge,
        sample_rate=args.sample_rate,
    )


def cs_total_variation(args, kspace):
    """
    Reconstruct one slice with the BART toolkit.

    ESPIRiT (``ecalib``) estimates the coil sensitivity maps, and ``pics`` then
    runs a Total-Variation-regularised reconstruction with regularisation
    weight ``args.reg_wt`` for ``args.num_iters`` iterations.

    Args:
        args: Configuration; reads ``challenge``, ``reg_wt``, ``num_iters`` and
            ``resolution``.
        kspace (torch.Tensor): Masked k-space of shape (coils, rows, cols, 2), or
            (rows, cols, 2) for the single-coil challenge.

    Returns:
        torch.Tensor: Magnitude image, centre-cropped so that neither spatial
        dimension exceeds ``args.resolution``.
    """
    if args.challenge == 'singlecoil':
        # Give single-coil data a leading coil axis so both cases share a layout.
        kspace = kspace.unsqueeze(0)
    # (coils, rows, cols, 2) -> (1, rows, cols, coils, 2), then to complex numpy.
    bart_kspace = tensor_to_complex_np(kspace.permute(1, 2, 0, 3).unsqueeze(0))

    ecalib_cmd = 'ecalib -d0 -m1'
    sensitivity = bart.bart(1, ecalib_cmd, bart_kspace)

    pics_cmd = f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}'
    recon = bart.bart(1, pics_cmd, bart_kspace, sensitivity)
    magnitude = torch.from_numpy(np.abs(recon[0]))

    # Only crop dimensions that are larger than the requested resolution.
    crop_shape = (
        min(args.resolution, magnitude.shape[-2]),
        min(args.resolution, magnitude.shape[-1]),
    )
    return transforms.center_crop(magnitude, crop_shape)


def run_model(i):
    """
    Reconstruct item ``i`` of the module-level ``data`` list.

    Each item is a ``(masked_kspace, fname, slice)`` triple; ``args`` is read
    from module scope as well.

    Returns:
        tuple: ``(fname, slice, prediction)`` with ``prediction`` being the
        output of ``cs_total_variation`` for the item's masked k-space.
    """
    kspace_in, file_name, slice_idx = data[i]
    return file_name, slice_idx, cs_total_variation(args, kspace_in)


def main():
    """
    Reconstruct every item of the module-level ``data`` list and save the results.

    Reads the module-level ``args`` (``num_procs``, ``output_path``) and ``data``.
    With ``args.num_procs == 0`` the items are reconstructed sequentially in this
    process; otherwise they are distributed over a ``multiprocessing.Pool`` with
    ``args.num_procs`` workers. All predictions are saved with a single
    ``save_outputs`` call and the reconstruction wall-clock time is logged.
    """
    indices = range(len(data))
    if args.num_procs == 0:
        start_time = time.perf_counter()
        # Exactly one run_model call per item: each is a full BART reconstruction.
        outputs = [run_model(i) for i in indices]
        time_taken = time.perf_counter() - start_time
    else:
        with multiprocessing.Pool(args.num_procs) as pool:
            start_time = time.perf_counter()
            outputs = pool.map(run_model, indices)
            time_taken = time.perf_counter() - start_time
    # Save all slices at once: save_reconstructions opens each file with mode
    # 'w', so saving slice by slice would leave only the last slice per file.
    save_outputs(outputs, args.output_path)
    logger.info('Run Time = %ss', time_taken)
    


import json

import h5py


def save_reconstructions(reconstructions, out_dir):
    """
    Save reconstructions into h5 files suitable for submission to the leaderboard.

    Args:
        reconstructions (dict[str, np.array]): Maps input file names to the
            corresponding reconstructions (of shape num_slices x height x width).
        out_dir (str or pathlib.Path): Output directory. It is created (with
            parents) if missing. Each volume is written to ``out_dir / fname``
            under the dataset key ``'reconstruction'``, replacing any existing file.
    """
    # Accept both str and Path; joining with "/" does not depend on a trailing
    # separator (string concatenation did, and failed outright for Path input).
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for fname, recons in reconstructions.items():
        with h5py.File(out_dir / fname, 'w') as f:
            f.create_dataset('reconstruction', data=recons)
            
def save_outputs(outputs, output_path):
    """
    Group per-slice predictions by file and save one stacked volume per file.

    Args:
        outputs (iterable of tuple): ``(fname, slice_index, prediction)`` triples,
            such as those returned by ``run_model``.
        output_path: Output directory handed to ``save_reconstructions``.
    """
    reconstructions = defaultdict(list)
    for fname, slice_num, pred in outputs:
        reconstructions[fname].append((slice_num, pred))
    # Order slices by index only. Sorting whole tuples would fall through to
    # comparing prediction tensors if two entries shared a slice index, which
    # raises for multi-element tensors; the key-based sort is stable instead.
    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds, key=lambda sp: sp[0])])
        for fname, slice_preds in reconstructions.items()
    }
    save_reconstructions(reconstructions, output_path)

In [3]:
class Args:
    """
    Plain configuration container for this notebook.

    Note: this definition shadows ``Args`` imported from ``common.args``.
    Positional parameters map to attributes as follows:
    ch -> challenge, path -> data_path, rate -> sample_rate,
    acc -> accelerations, cent -> center_fractions, outpath -> output_path,
    iters -> num_iters, reg -> reg_wt, procs -> num_procs,
    maskt -> mask_type, seed -> seed, res -> resolution.
    """

    def __init__(self, ch, path, rate, acc, cent, outpath, iters, reg, procs, maskt, seed, res):
        settings = {
            'challenge': ch,
            'data_path': path,
            'sample_rate': rate,
            'accelerations': acc,
            'center_fractions': cent,
            'output_path': outpath,
            'num_iters': iters,
            'reg_wt': reg,
            'num_procs': procs,
            'mask_type': maskt,
            'seed': seed,
            'resolution': res,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


# Run configuration: multicoil validation data, 4x random under-sampling with a
# 0.07 centre fraction, 100 TV iterations at weight 0.01, 4 worker processes.
args = Args(
    ch="multicoil",
    path="/hdd/",
    rate=1,
    acc=[4],
    cent=[0.07],
    outpath="/root/multires_deep_decoder/mri/FINAL/TV/",
    iters=100,
    reg=0.01,
    procs=4,
    maskt="random",
    seed=42,
    res=320,
)

In [4]:
import os
import subprocess
import sys

# Point the BART Python wrapper at the toolbox install; os.environ changes are
# visible to this process and to child processes started afterwards.
os.environ['TOOLBOX_PATH'] = "/root/bart-0.5.00/"

# Seed Python's, NumPy's and PyTorch's RNGs from the single configured seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)

dataset = create_data_loader(args)

In [4]:
# Quick sanity check: one (file, slice) entry, total slice count, and the number
# of fields in a transformed item.
(dataset.examples[80], len(dataset), len(dataset[0]))

((PosixPath('/hdd/multicoil_val/file1000017.h5'), 7), 7135, 3)

In [6]:
import itertools

# Reconstruct and save one volume at a time. main()/run_model read the
# module-level ``data`` list, so it is rebound to the current volume's
# transformed slices before each call.
# ``dataset.examples`` holds (file path, slice index) pairs; volumes are
# delimited by the file path rather than inferred from slice indices.
# NOTE(review): assumes all slices of a file are contiguous in
# ``dataset.examples`` — confirm against SliceData.
for _volume_path, _volume_indices in itertools.groupby(
    range(len(dataset)), key=lambda idx: dataset.examples[idx][0]
):
    data = [dataset[idx] for idx in _volume_indices]
    main()

INFO:root:Run Time = 333.7682563047856s
INFO:root:Run Time = 340.6625652872026s
INFO:root:Run Time = 301.1880727428943s
INFO:root:Run Time = 274.9473853390664s
INFO:root:Run Time = 395.30600419454277s
INFO:root:Run Time = 338.08261120319366s
INFO:root:Run Time = 357.19963121786714s
INFO:root:Run Time = 415.9970267675817s
INFO:root:Run Time = 351.20377369225025s
INFO:root:Run Time = 335.272654408589s
INFO:root:Run Time = 320.07918173260987s
INFO:root:Run Time = 337.77712962403893s
INFO:root:Run Time = 350.6248935814947s
INFO:root:Run Time = 319.2754284348339s
INFO:root:Run Time = 339.5878656320274s
INFO:root:Run Time = 308.0043003074825s
INFO:root:Run Time = 358.6736814174801s
INFO:root:Run Time = 358.03209225833416s
INFO:root:Run Time = 369.3443151284009s
INFO:root:Run Time = 376.01454484649s
INFO:root:Run Time = 327.3354206569493s
INFO:root:Run Time = 311.5621940083802s
INFO:root:Run Time = 315.0435768086463s
INFO:root:Run Time = 325.7221032604575s
INFO:root:Run Time = 291.21864583529