In [1]:
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import pathlib
import random

import h5py
from torch.utils.data import Dataset

import sys
sys.path.insert(0,'../../common/')
sys.path.insert(0,'/home/ubuntu/Downloads/mri_recon/robustness-CS/bart-0.5.00/python/')
import bart

import subsample
import pathlib

import numpy as np
from common import utils
from matplotlib import pyplot as plt
import torch
try:
    import nibabel as nib
except:
    ! pip install nibabel 
    import nibabel as nib

from data import transforms, mri_data

class SliceData(Dataset):
    """
    A PyTorch Dataset that provides access to MR image slices.
    """

    def __init__(self, root,challenge, sample_rate=1):
        """
        Args:
            root (pathlib.Path): Path to the dataset.
            transform (callable): A callable object that pre-processes the raw data into
                appropriate form. The transform function should take 'kspace', 'target',
                'attributes', 'filename', and 'slice' as inputs. 'target' may be null
                for test data.
            challenge (str): "singlecoil" or "multicoil" depending on which challenge to use.
            sample_rate (float, optional): A float between 0 and 1. This controls what fraction
                of the volumes should be loaded.
        """
        self.mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])
        
        if challenge not in ('singlecoil', 'multicoil'):
            raise ValueError('challenge should be either "singlecoil" or "multicoil"')

        #self.transform = transform
        self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \
            else 'reconstruction_rss'

        self.examples = []
        files = list(pathlib.Path(root).iterdir())
        if sample_rate < 1:
            random.shuffle(files)
            num_files = round(len(files) * sample_rate)
            files = files[:num_files]
        for fname in sorted(files):
            
            img = nib.load(fname)
            img_data = img.get_data()
            img_data_arr = np.asarray(img_data)
            img_data_arr = img_data_arr.astype(np.float32)
            img_data_torch = torch.from_numpy(img_data_arr)
            kspace = np.zeros(img_data_torch.shape,dtype='complex')
            kspace = torch.from_numpy(kspace)
            for i in range(img_data_torch.shape[2]):
                fft_img = torch.fft.fft2(img_data_torch[:,:,i],norm="ortho")
                kspace[:,:,i] = torch.fft.fftshift(fft_img ) 
            #kspace = h5py.File(fname, 'r')['kspace']
            num_slices = kspace.shape[2]
            self.masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func)
            self.examples += [(self.masked_kspace[:,:,slice],fname, slice) for slice in range(num_slices)]
    ''''
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        masked_kspace,fname, slice = self.examples[i]
        with h5py.File(fname, 'r') as data:
            kspace = data['kspace'][slice]
            target = data[self.recons_key][slice] if self.recons_key in data else None
            return self.transform(kspace, target, data.attrs, fname.name, slice)
        
    '''
    def __getlist__(self):
        return self.examples
    def __getmaskedkspace__(self):
        return self.masked_kspace

def save_reconstructions(reconstructions, out_dir):
    """
    Saves the reconstructions from a model into h5 files that is appropriate for submission
    to the leaderboard.

    Args:
        reconstructions (dict[str, np.array]): A dictionary mapping input filenames to
            corresponding reconstructions (of shape num_slices x height x width).
        out_dir (pathlib.Path): Path to the output directory where the reconstructions
            should be saved.
    """
    for fname, recons in reconstructions.items():
        with h5py.File(out_dir + fname, 'w') as f:
            f.create_dataset('reconstruction', data=recons)
            
def save_outputs(outputs, output_path):
    reconstructions = defaultdict(list)
    for fname, slice, pred in outputs:
        reconstructions[fname].append((slice, pred))
    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }
    save_reconstructions(reconstructions, output_path)


In [2]:
s = SliceData('/home/ubuntu/Downloads/dataset/', 'singlecoil')
data = (s.__getlist__())


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0


In [3]:
masked_kspace = s.__getmaskedkspace__()

In [12]:
masked_kspace.shape
masked_kspace_cd = masked_kspace.unsqueeze(0)
print(masked_kspace_cd.shape)
masked_kspace_cd = masked_kspace_cd.permute(3,0, 1, 2).unsqueeze(0)
print(masked_kspace_cd.shape)
masked_kspace_cd = masked_kspace_cd.squeeze(0)
print(masked_kspace_cd.shape)

torch.Size([1, 256, 256, 176])
torch.Size([1, 176, 1, 256, 256])
torch.Size([176, 1, 256, 256])


In [19]:
import pathlib
import random

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TestTubeLogger
from torch.nn import functional as F
from torch.optim import RMSprop

from common.args import Args
from common.subsample import create_mask_for_mask_type
from data import transforms
#from mri_model import MRIModel
from unet_model import UnetModel


from pytorch_lightning.callbacks.early_stopping import EarlyStopping

early_stop_callback = EarlyStopping(
   monitor='val_accuracy',
   min_delta=0.00,
   patience=3,
   verbose=False,
   mode='max'
)

In [20]:
from mri_model import MRIModel

In [21]:
import torch
torch.cuda.is_available()

False

In [22]:

class UnetMRIModel(MRIModel):
    def __init__(self, hparams):
        super().__init__(hparams)
        self.unet = UnetModel(
            in_chans=1,
            out_chans=1,
            chans=hparams.num_chans,
            num_pool_layers=hparams.num_pools,
            drop_prob=hparams.drop_prob
        )

    def forward(self, input):
        return self.unet(input.unsqueeze(1)).squeeze(1)

    def training_step(self, batch, batch_idx):
        input, target, mean, std, _, _ = batch
        output = self.forward(input)
        loss = F.l1_loss(output, target)
        logs = {'loss': loss.item()}
        return dict(loss=loss, log=logs)

    def validation_step(self, batch, batch_idx):
        input, target, mean, std, fname, slice = batch
        output = self.forward(input)
        mean = mean.unsqueeze(1).unsqueeze(2)
        std = std.unsqueeze(1).unsqueeze(2)
        return {
            'fname': fname,
            'slice': slice,
            'output': (output * std + mean).cpu().numpy(),
            'target': (target * std + mean).cpu().numpy(),
            'val_loss': F.l1_loss(output, target),
        }

    def test_step(self, batch, batch_idx):
        input, _, mean, std, fname, slice = batch
        output = self.forward(input)
        mean = mean.unsqueeze(1).unsqueeze(2)
        std = std.unsqueeze(1).unsqueeze(2)
        return {
            'fname': fname,
            'slice': slice,
            'output': (output * std + mean).cpu().numpy(),
        }

    def configure_optimizers(self):
        optim = RMSprop(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optim, self.hparams.lr_step_size, self.hparams.lr_gamma)
        return [optim], [scheduler]

    def train_data_transform(self):
        mask = create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions,
                                         self.hparams.accelerations)
        return DataTransform(self.hparams.resolution, self.hparams.challenge, mask, use_seed=False)

    def val_data_transform(self):
        mask = create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions,
                                         self.hparams.accelerations)
        return DataTransform(self.hparams.resolution, self.hparams.challenge, mask)

    def test_data_transform(self):
        return DataTransform(self.hparams.resolution, self.hparams.challenge)

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers')
        parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability')
        parser.add_argument('--num-chans', type=int, default=32, help='Number of U-Net channels')
        parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size')
        parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
        parser.add_argument('--lr-step-size', type=int, default=40,
                            help='Period of learning rate decay')
        parser.add_argument('--lr-gamma', type=float, default=0.1,
                            help='Multiplicative factor of learning rate decay')
        parser.add_argument('--weight-decay', type=float, default=0.,
                            help='Strength of weight decay regularization')
        parser.add_argument('--mask_type',default='random')
        return parser


In [23]:
class Args():
    def __init__(self,mode,challenge,exp_dir,exp,mask_type,num_epochs,gpus,action='store_true'):
        self.mode=mode
        self.num_epochs=num_epochs
        self.gpus=gpus
        self.exp = exp
        self.exp_dir=exp_dir
        self.seed = 42
        self.challenge = challenge
        self.mask_type = mask_type
args = Args('train','singlecoil','/home/ubuntu/Downloads/dataset/','unet','equispaced',50,0)

In [24]:
def create_trainer(args, logger):
    return Trainer(
        #num_nodes=1,
        logger=logger,
        default_root_dir=args.exp_dir,
        checkpoint_callback=True,
        max_epochs=args.num_epochs,
        gpus=args.gpus,
        distributed_backend='ddp',
        check_val_every_n_epoch=1,
        val_check_interval=1.,
        callbacks=[early_stop_callback]
        )

In [25]:
import random 
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
load_version = None
logger = TestTubeLogger(save_dir=args.exp_dir, name=args.exp, version=load_version)
trainer = create_trainer(args, logger)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [27]:
class hparams():
    def __init__(self,in_chans, out_chans, num_chans, num_pools, drop_prob):
        self.num_chans= num_chans
        self.num_pools = num_pools
        self.drop_prob= drop_prob
        self.in_chans = in_chans
        self.out_chans = out_chans
hparams1 = hparams(1,1,32,4,0)

In [None]:
hparams1.__dict__

In [49]:
model = UnetModel(
            in_chans=1,
            out_chans=1,
            chans=hparams1.num_chans,
            num_pool_layers=hparams1.num_pools,
            drop_prob=hparams1.drop_prob
        )

In [50]:
target = model.forward(m_tensor)

RuntimeError: expected scalar type Byte but found Float

In [44]:
masked_kspace_cd = torch.abs(masked_kspace_cd).type(torch.double)

In [48]:
m = np.ones((320,320,176),dtype=np.uint8)
m_tensor = torch.from_numpy(m)
m_tensor = m_tensor.unsqueeze(0)
print(m_tensor.shape)
m_tensor= m_tensor.permute(3,0, 1, 2).unsqueeze(0)
print(m_tensor.shape)
m_tensor= m_tensor.squeeze(0)
print(m_tensor.shape)

torch.Size([1, 320, 320, 176])
torch.Size([1, 176, 1, 320, 320])
torch.Size([176, 1, 320, 320])


In [51]:
m_tensor

tensor([[[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        ...,


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
         