In [None]:
import pathlib
import sys
from collections import defaultdict
import h5py
import numpy as np
import torch
import logging
import shutil
import time
import random
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
import torchvision

In [None]:
# Inference configuration: every tunable value for this notebook lives here.
args = {
    'seed': 42,
    'resolution': 320,
    'challenge': 'singlecoil',
    'data_path': pathlib.Path('/content/drive/MyDrive/Dataset'),
    'sample_rate': 1.,
    'accelerations': [4, 8],
    'center_fractions': [0.08, 0.04],

    'mask_kspace': False,
    'data_split': 'test',
    'checkpoint': pathlib.Path('/content/drive/MyDrive/checkpoints/best_model.pt'),
    'out_dir': pathlib.Path('/content/drive/MyDrive/reconstructions'),
    'batch_size': 16,
    # Fall back to CPU so the notebook still runs on a machine without a GPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

# Apply the configured seed to every global RNG the pipeline can touch
# (random.shuffle in SliceData when sample_rate < 1, numpy, torch).
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])

In [None]:
def save_reconstructions(reconstructions, out_dir):
    """Write each volume's reconstruction to ``out_dir / fname`` as an HDF5 file.

    Args:
        reconstructions: mapping of output file name -> array-like of per-slice
            reconstructions, stored under the ``'reconstruction'`` dataset.
        out_dir: destination directory (``pathlib.Path`` or str). It is created,
            including any missing parent folders, if it does not exist.
    """
    out_dir = pathlib.Path(out_dir)
    # parents=True: the configured out_dir sits several levels deep on a mounted
    # drive; without it mkdir raises if an intermediate folder is missing.
    out_dir.mkdir(parents=True, exist_ok=True)
    for fname, recons in reconstructions.items():
        with h5py.File(out_dir / fname, 'w') as f:
            f.create_dataset('reconstruction', data=recons)

In [None]:
class MaskFunc:
    """Random column under-sampling mask for k-space.

    On each call one (center_fraction, acceleration) pair is picked at random.
    A contiguous band of ``round(num_cols * center_fraction)`` central columns is
    always kept. Every column is also kept independently with a probability
    chosen so that the expected number of kept columns is
    ``num_cols / acceleration`` (when that probability lies in [0, 1]).
    """

    def __init__(self, center_fractions, accelerations):
        if len(accelerations) != len(center_fractions):
            raise ValueError('Number of center fractions should match number of accelerations')
        self.center_fractions = center_fractions
        self.accelerations = accelerations
        self.rng = np.random.RandomState()

    def __call__(self, shape, seed=None):
        """Build a mask that broadcasts against data of the given ``shape``.

        Args:
            shape: shape of the data to mask. It needs at least 3 dimensions,
                with columns on axis -2.
            seed: passed to ``self.rng.seed``, so equal seeds give equal masks.

        Returns:
            float32 torch tensor of 0s/1s with ``shape[-2]`` on axis -2 and
            size 1 on every other axis.

        Raises:
            ValueError: if ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError('Shape should have 3 or more dimensions')

        rng = self.rng
        rng.seed(seed)
        n_cols = shape[-2]

        # Draw order (pair choice first, then column uniforms) fixes which mask a seed yields.
        pick = rng.randint(0, len(self.accelerations))
        frac = self.center_fractions[pick]
        accel = self.accelerations[pick]

        n_center = int(round(n_cols * frac))
        keep_prob = (n_cols / accel - n_center) / (n_cols - n_center)
        cols = rng.uniform(size=n_cols) < keep_prob
        start = (n_cols - n_center + 1) // 2
        cols[start:start + n_center] = True

        broadcast_shape = [1] * len(shape)
        broadcast_shape[-2] = n_cols
        return torch.from_numpy(cols.reshape(broadcast_shape).astype(np.float32))

In [None]:
def to_tensor(data):
    """Convert a numpy array to a torch tensor.

    Complex input gets a new trailing axis of size 2 holding (real, imag).
    Real input is wrapped directly with ``torch.from_numpy`` (shares memory).
    """
    if not np.iscomplexobj(data):
        return torch.from_numpy(data)
    parts = np.stack((data.real, data.imag), axis=-1)
    return torch.from_numpy(parts)


def apply_mask(data, mask_func, seed=None):
    """Zero every entry of ``data`` where ``mask_func``'s mask is 0.

    ``mask_func`` is called with the shape of ``data`` (every axis before the
    last three collapsed to 1, as a numpy array) and ``seed``.

    Returns:
        (masked_data, mask)
    """
    mask_shape = np.array(data.shape)
    mask_shape[:-3] = 1
    mask = mask_func(mask_shape, seed)
    zero = torch.Tensor([0])
    masked = torch.where(mask == 0, zero, data)
    return masked, mask


def fft2(data, norm='backward'):
    """Centred 2-D FFT of a complex tensor stored as a trailing (real, imag) axis.

    The transform runs over the two axes just before the real/imag axis, with
    ``ifftshift`` before and ``fftshift`` after, so the zero frequency stays at
    the array centre.

    Args:
        data (torch.Tensor): real tensor of shape (..., H, W, 2).
        norm (str): normalization mode forwarded to ``torch.fft.fftn``
            ('backward', 'ortho' or 'forward'); defaults to 'backward'.

    Returns:
        torch.Tensor: real tensor of shape (..., H, W, 2) in the same layout.
    """
    assert data.size(-1) == 2
    # torch.complex copies, so it works for any strides/storage offset
    # (torch.view_as_complex would require a specially aligned layout).
    data = torch.complex(data[..., 0], data[..., 1])
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.fftn(data, dim=(-2, -1), norm=norm)
    data = torch.fft.fftshift(data, dim=(-2, -1))
    return torch.view_as_real(data)


def ifft2(data, norm='backward'):
    """Centred 2-D inverse FFT of a complex tensor stored as a trailing (real, imag) axis.

    The inverse transform runs over the two axes just before the real/imag axis,
    with ``ifftshift`` before and ``fftshift`` after (centred k-space -> centred image).

    Args:
        data (torch.Tensor): real tensor of shape (..., H, W, 2).
        norm (str): normalization mode forwarded to ``torch.fft.ifftn``
            ('backward' scales by 1/(H*W), 'ortho' by 1/sqrt(H*W), 'forward'
            not at all); defaults to 'backward'.

    Returns:
        torch.Tensor: real tensor of shape (..., H, W, 2) in the same layout.
    """
    assert data.size(-1) == 2
    # torch.complex copies, so it works for any strides/storage offset
    # (torch.view_as_complex would require a specially aligned layout).
    data = torch.complex(data[..., 0], data[..., 1])
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.ifftn(data, dim=(-2, -1), norm=norm)
    data = torch.fft.fftshift(data, dim=(-2, -1))
    return torch.view_as_real(data)


def complex_abs(data):
    """Magnitude of a complex tensor stored as a trailing (real, imag) axis of size 2."""
    assert data.size(-1) == 2
    return torch.sqrt(data.pow(2).sum(dim=-1))


def root_sum_of_squares(data, dim=0):
    """Square root of the sum of squares of ``data`` along ``dim`` (default: axis 0)."""
    squared_total = data.pow(2).sum(dim)
    return squared_total.sqrt()


def center_crop(data, shape):
    """Crop the last two axes of ``data`` to ``shape`` around the centre.

    When a size difference is odd, the extra element is removed from the end.
    Asserts that each target size is positive and no larger than the input.
    """
    assert 0 < shape[0] <= data.shape[-2]
    assert 0 < shape[1] <= data.shape[-1]
    row0 = (data.shape[-2] - shape[0]) // 2
    col0 = (data.shape[-1] - shape[1]) // 2
    rows = slice(row0, row0 + shape[0])
    cols = slice(col0, col0 + shape[1])
    return data[..., rows, cols]


def complex_center_crop(data, shape):
    """Centre-crop the two axes before a trailing (real, imag) axis to ``shape``.

    Same as a centre crop on axes -3 and -2; the last axis is kept whole.
    Asserts that each target size is positive and no larger than the input.
    """
    assert 0 < shape[0] <= data.shape[-3]
    assert 0 < shape[1] <= data.shape[-2]
    row0 = (data.shape[-3] - shape[0]) // 2
    col0 = (data.shape[-2] - shape[1]) // 2
    rows = slice(row0, row0 + shape[0])
    cols = slice(col0, col0 + shape[1])
    return data[..., rows, cols, :]


def normalize(data, mean, stddev, eps=0.):
    """Standardize ``data`` as ``(data - mean) / (stddev + eps)``."""
    centered = data - mean
    scale = stddev + eps
    return centered / scale


def normalize_instance(data, eps=0.):
    """Standardize ``data`` by its own mean and (unbiased) standard deviation.

    Returns:
        (normalized, mean, std). The statistics are returned so the caller can
        undo the scaling later.
    """
    mean, std = data.mean(), data.std()
    return (data - mean) / (std + eps), mean, std


# Helper functions

def roll(x, shift, dim):
    """Circularly shift ``x`` by ``shift`` positions along ``dim`` (like ``torch.roll``).

    ``shift`` and ``dim`` may be parallel tuples/lists, in which case the
    rolls are applied one axis at a time. A zero (mod size) shift returns
    ``x`` itself.
    """
    if isinstance(shift, (tuple, list)):
        assert len(shift) == len(dim)
        out = x
        for sub_shift, sub_dim in zip(shift, dim):
            out = roll(out, sub_shift, sub_dim)
        return out

    size = x.size(dim)
    shift = shift % size
    if not shift:
        return x
    head, tail = torch.split(x, [size - shift, shift], dim=dim)
    return torch.cat((tail, head), dim=dim)


def fftshift(x, dim=None):
    """Roll each axis in ``dim`` by ``size // 2`` (all axes when ``dim`` is None).

    This mirrors ``numpy.fft.fftshift`` on top of the local ``roll`` helper.
    """
    if isinstance(dim, int):
        return roll(x, x.shape[dim] // 2, dim)
    axes = tuple(range(x.dim())) if dim is None else dim
    return roll(x, [x.shape[a] // 2 for a in axes], axes)


def ifftshift(x, dim=None):
    """Roll each axis in ``dim`` by ``(size + 1) // 2`` (all axes when ``dim`` is None).

    This is the inverse of ``fftshift`` and mirrors ``numpy.fft.ifftshift``
    on top of the local ``roll`` helper.
    """
    if isinstance(dim, int):
        return roll(x, (x.shape[dim] + 1) // 2, dim)
    axes = tuple(range(x.dim())) if dim is None else dim
    return roll(x, [(x.shape[a] + 1) // 2 for a in axes], axes)

In [None]:
class SliceData(Dataset):
    """Dataset of individual k-space slices drawn from a directory of HDF5 volumes.

    Each example is a ``(file_path, slice_index)`` pair. ``__getitem__`` reads
    that slice and passes it to ``transform``.
    """

    def __init__(self, root, transform, challenge, sample_rate=1):
        """
        Args:
            root: directory of volume files. Every entry is opened as HDF5 with
                a 'kspace' dataset (TODO confirm nothing else lives there).
            transform: callable ``(kspace, target, attrs, fname, slice)``
                applied to each example.
            challenge: 'singlecoil' or 'multicoil'; selects the target dataset key.
            sample_rate: fraction of volumes to keep. When < 1, a random subset
                is chosen via ``random.shuffle``.

        Raises:
            ValueError: if ``challenge`` is not one of the two supported values.
        """
        if challenge not in ('singlecoil', 'multicoil'):
            raise ValueError('challenge should be either "singlecoil" or "multicoil"')

        self.transform = transform
        self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \
            else 'reconstruction_rss'

        self.examples = []
        files = list(pathlib.Path(root).iterdir())
        if sample_rate < 1:
            random.shuffle(files)
            num_files = round(len(files) * sample_rate)
            files = files[:num_files]
        for fname in sorted(files):
            # Context manager releases the handle as soon as the slice count is
            # read, instead of leaving one open file per volume until GC.
            with h5py.File(fname, 'r') as data:
                num_slices = data['kspace'].shape[0]
            self.examples += [(fname, slice_idx) for slice_idx in range(num_slices)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        """Return ``transform(kspace, target, attrs, fname.name, slice)`` for example ``i``.

        ``target`` is None when the file has no ``recons_key`` dataset.
        """
        fname, slice_idx = self.examples[i]
        with h5py.File(fname, 'r') as data:
            kspace = data['kspace'][slice_idx]
            target = data[self.recons_key][slice_idx] if self.recons_key in data else None
            # Call the transform while the file is open: data.attrs is only
            # readable before the handle closes.
            return self.transform(kspace, target, data.attrs, fname.name, slice_idx)

In [None]:
class ConvBlock(nn.Module):
    """Two stages of 3x3 conv -> instance norm -> ReLU -> 2-D dropout.

    Spatial size is preserved (padding=1). Channels go in_chans -> out_chans.
    """

    def __init__(self, in_chans, out_chans, drop_prob):
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        # Submodule order fixes the state_dict keys (layers.0 / layers.4) and
        # the parameter-initialisation order that checkpoints are loaded into.
        stages = []
        for stage_in in (in_chans, out_chans):
            stages.extend([
                nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=1),
                nn.InstanceNorm2d(out_chans),
                nn.ReLU(),
                nn.Dropout2d(drop_prob),
            ])
        self.layers = nn.Sequential(*stages)

    def forward(self, input):
        return self.layers(input)

    def __repr__(self):
        return (f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, '
                f'drop_prob={self.drop_prob})')


class UnetModel(nn.Module):
    """U-Net built from ConvBlocks.

    Structure: a contracting path (ConvBlock + 2x2 max-pool per level), a
    bottleneck ConvBlock, and an expanding path (bilinear upsampling +
    skip-connection concatenation + ConvBlock per level), followed by three
    1x1 convolutions.

    Args:
        in_chans: channels of the input image.
        out_chans: channels of the output image.
        chans: channels after the first ConvBlock; doubled at each further down level.
        num_pool_layers: number of down/up-sampling levels.
        drop_prob: dropout probability passed to every ConvBlock.
    """

    def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob):
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob

        # Attribute names and construction order define the state_dict keys
        # loaded from checkpoints; keep them stable.
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
            ch *= 2
        self.conv = ConvBlock(ch, ch, drop_prob)

        self.up_sample_layers = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)]
            ch //= 2
        self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)]
        self.conv2 = nn.Sequential(
            nn.Conv2d(ch, ch // 2, kernel_size=1),
            nn.Conv2d(ch // 2, out_chans, kernel_size=1),
            nn.Conv2d(out_chans, out_chans, kernel_size=1),
        )

    def forward(self, input):
        """Map an (N, in_chans, H, W) batch to (N, out_chans, H, W).

        H and W need not be divisible by 2**num_pool_layers, as long as pooling
        never shrinks a side to zero: each upsampled map is resized to its skip
        connection's exact spatial size.
        """
        stack = []
        output = input
        # Contracting path: keep each level's features for the skip connection.
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.max_pool2d(output, kernel_size=2)

        output = self.conv(output)

        # Expanding path.
        for layer in self.up_sample_layers:
            skip = stack.pop()
            # Resize to the skip's exact size. When the skip is exactly twice as
            # large this matches scale_factor=2 (both give a 0.5 source scale).
            # It also avoids a torch.cat mismatch when max_pool2d floored an odd size.
            output = F.interpolate(output, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            output = torch.cat([output, skip], dim=1)
            output = layer(output)
        return self.conv2(output)

In [None]:
class DataTransform:
    """Turn one raw k-space slice into a normalized, cropped zero-filled image.

    Args:
        resolution: side length of the square centre crop applied in image space.
        which_challenge: 'singlecoil' or 'multicoil'.
        mask_func: optional callable ``(shape, seed) -> mask``. When given,
            k-space is masked with a seed derived from the file name.
    """

    def __init__(self, resolution, which_challenge, mask_func=None):
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.mask_func = mask_func

    def __call__(self, kspace, target, attrs, fname, slice):
        """Return ``(image, mean, std, fname, slice)``.

        ``target`` and ``attrs`` are accepted for the SliceData calling
        convention but not used.
        """
        kspace = to_tensor(kspace)
        masked_kspace = kspace
        if self.mask_func is not None:
            # Seeding from the file name gives every slice of a volume the same mask.
            masked_kspace, _ = apply_mask(kspace, self.mask_func, tuple(map(ord, fname)))

        # Zero-filled reconstruction, cropped to resolution x resolution.
        image = complex_center_crop(ifft2(masked_kspace), (self.resolution, self.resolution))
        # NOTE(review): Tensor.abs() is element-wise. On a real (..., 2) real/imag
        # tensor it does not give the complex magnitude (complex_abs does).
        # run_unet indexes a trailing axis of size 2 on this output, so confirm
        # before changing.
        image = image.abs()
        if self.which_challenge == 'multicoil':
            image = root_sum_of_squares(image)

        image, mean, std = normalize_instance(image)
        # Clip outliers to +/-6 standard deviations after normalization.
        return image.clamp(-6, 6), mean, std, fname, slice

In [None]:
def create_data_loaders(args):
    """Build the inference DataLoader over ``args['data_path'] / f'{challenge}_{data_split}'``.

    Reads the keys 'mask_kspace', 'center_fractions', 'accelerations',
    'challenge', 'data_split', 'data_path', 'resolution' and 'batch_size'.
    Optionally reads 'sample_rate' (default 1.0) and 'num_workers' (default 4).

    Returns:
        torch.utils.data.DataLoader over SliceData, without shuffling.
    """
    mask_func = None
    if args['mask_kspace']:
        mask_func = MaskFunc(args['center_fractions'], args['accelerations'])
    challenge = args['challenge']
    data_split = args['data_split']
    data = SliceData(
        root=args['data_path'] / f'{challenge}_{data_split}',
        transform=DataTransform(args['resolution'], challenge, mask_func),
        # Honour the configured sample_rate instead of always loading every volume.
        sample_rate=args.get('sample_rate', 1.),
        challenge=challenge,
    )
    return DataLoader(
        dataset=data,
        batch_size=args['batch_size'],
        # Overridable because DataLoader warns when this exceeds the CPUs
        # available to the process (common on small Colab instances).
        num_workers=args.get('num_workers', 4),
        pin_memory=True,
    )

In [None]:
def load_model(checkpoint_file, device=None):
    """Rebuild the UnetModel stored in a training checkpoint.

    Args:
        checkpoint_file: path to a checkpoint dict with keys 'model' (state dict)
            and 'args' (training config read for 'num_chans', 'num_pools',
            'drop_prob', 'device' and 'data_parallel').
        device: where to place the model. Defaults to the checkpoint's
            ``args['device']``.

    Returns:
        The model (wrapped in DataParallel when the checkpoint says so) with
        its weights loaded.
    """
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict below copies the weights onto the model's device.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # PyTorch >= 2.6 defaults to weights_only=True, which may reject non-tensor
    # objects stored under 'args' (e.g. pathlib.Path); confirm for this checkpoint.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    train_args = checkpoint['args']
    if device is None:
        device = train_args['device']
    model = UnetModel(1, 1, train_args['num_chans'], train_args['num_pools'],
                      train_args['drop_prob']).to(device)
    if train_args['data_parallel']:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    return model

In [None]:
def run_unet(args, model, data_loader):
    """Run ``model`` over ``data_loader`` and regroup the outputs per volume.

    Each batch is ``(input, mean, std, fnames, slices)``. The two entries of
    the input's trailing size-2 axis are averaged into one channel before the
    forward pass. Every output is then de-normalized with its own mean/std.

    Returns:
        dict mapping file name -> numpy array of reconstructions stacked in
        ascending slice order.
    """
    model.eval()
    per_volume = defaultdict(list)
    with torch.no_grad():
        for batch_input, means, stds, fnames, slices in data_loader:
            batch_input = batch_input.unsqueeze(1).to(args['device'])
            # Merge the trailing size-2 axis into a single input channel.
            merged = (batch_input[:, :, :, :, 0] + batch_input[:, :, :, :, 1]) / 2
            outputs = model(merged).to('cpu').squeeze(1)
            for idx in range(outputs.shape[0]):
                # Undo the per-slice normalization applied before the forward pass.
                outputs[idx] = outputs[idx] * stds[idx] + means[idx]
                entry = (slices[idx].numpy(), outputs[idx].numpy())
                per_volume[fnames[idx]].append(entry)

    stacked = {}
    for fname, entries in per_volume.items():
        # Tuples sort on the slice index first, restoring volume order.
        stacked[fname] = np.stack([recon for _, recon in sorted(entries)])
    return stacked

In [None]:
# Build the DataLoader over the configured data split.
data_loader = create_data_loaders(args=args)

  cpuset_checked))


In [None]:
# Rebuild the trained U-Net from the saved checkpoint.
model = load_model(checkpoint_file=args['checkpoint'])

In [None]:
# Inference: {file name: per-slice reconstructions stacked in slice order}.
reconstructions = run_unet(args=args, model=model, data_loader=data_loader)

  cpuset_checked))


In [None]:
# Write one HDF5 file per volume under args['out_dir'].
save_reconstructions(reconstructions=reconstructions, out_dir=args['out_dir'])