In [1]:
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import pathlib
import sys
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader

from common.args import Args
from common.subsample import MaskFunc
from common.utils import save_reconstructions
from data import transforms
from data.mri_data import SliceData
from anet_model import AnetModel

from utils import reducedimension, unitize, kspaceto2dimage, transformshape, transformback

from data import transforms
class DataTransform:
    """
    Data Transformer for training DAE.
    """

    def __init__(self, mask_func, resolution, reduce, polar, use_seed=True):
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.resolution = resolution
        self.reduce = reduce
        self.polar = polar

    def __call__(self, kspace, target, challenge, fname, slice_index):
        original_kspace = transforms.to_tensor(kspace)
        
        if self.reduce:
            original_kspace = reducedimension(original_kspace, self.resolution)
        
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(original_kspace, self.mask_func, seed)

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)

        if self.polar:
            original_kspace = cartesianToPolar(original_kspace)
            masked_kspace = cartesianToPolar(masked_kspace)
        return original_kspace, masked_kspace, mask, target, fname, slice_index, mean, std


def create_data_loaders(args):
    mask_func = None
    if args.mask_kspace:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    data = SliceData(
        root=args.data_path / (str(args.challenge) + "_" + str(args.data_split)),
        transform=DataTransform(mask_func, args.resolution, True, False),
        sample_rate=1.,
        challenge=args.challenge
    )
    data_loader = DataLoader(
        dataset=data,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
    return data_loader


def load_model(checkpoint_file):
    checkpoint = torch.load(checkpoint_file)
    args = checkpoint['args']
    # model = UnetModel(1, 1, args.num_chans, args.num_pools, args.drop_prob).to(args.device)
    model = AnetModel(in_chans=2, out_chans=2, chans=args.num_chans, num_pool_layers=args.num_pools, drop_prob=args.drop_prob).to(args.device)
    # if args.data_parallel:
        # model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    return model


def run_unet(args, model, data_loader):
    model.eval()
    reconstructions = defaultdict(list)

    count = 0

    with torch.no_grad():
        for (original_kspace, masked_kspace, mask, target, fnames, slices, mean, std) in data_loader:
            
            input, divisor = unitize(masked_kspace)

            output = model(transformshape(input).to(args.device)).to('cpu').squeeze(1)
            print(output.shape)
            recons = kspaceto2dimage(transformback(output), False, cropping=False, resolution=320)
            
            for i in range(recons.shape[0]):
                recons[i] = recons[i] * std[i] + mean[i]
                reconstructions[fnames[i]].append((slices[i].numpy(), recons[i].numpy()))
                count += 1
                if count % 100 == 0:
                    print(count)
    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }
    return reconstructions


def main(args):
    data_loader = create_data_loaders(args)
    model = load_model(args.checkpoint)
    reconstructions = run_unet(args, model, data_loader)
    save_reconstructions(reconstructions, args.out_dir)


def create_arg_parser():
    parser = Args()
    parser.add_argument('--mask-kspace', action='store_true',
                        help='Whether to apply a mask (set to True for val data and False '
                             'for test data')
    parser.add_argument('--data-split', choices=['val', 'test'], required=True,
                        help='Which data partition to run on: "val" or "test"')
    parser.add_argument('--checkpoint', required=True,
                        help='Path to the U-Net model')
    parser.add_argument('--out-dir', type=pathlib.Path, required=True,
                        help='Path to save the reconstructions to')
    parser.add_argument('--batch-size', default=16, type=int, help='Mini-batch size')
    parser.add_argument('--device', type=str, default='cuda', help='Which device to run on')
    return parser


if __name__ == '__main__':
    args = create_arg_parser().parse_args(['--data-path', '/home/sumeet_ranka47_gmail_com/data', '--data-split', 'val', '--checkpoint', 'kspace_unitize_fixed_cartesian/best_model.pt', '--challenge', 'singlecoil', '--out-dir', 'reconstructions_val', '--mask-kspace'])
    main(args)

torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
100
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
200
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
300
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
400
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
torch.Size([16, 2, 320, 320])
500
torch.Size([16, 2, 3