<a href="https://colab.research.google.com/github/RGologorsky/fastmri/blob/master/01_kspace_tfms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# useful constants
class C():
  """Namespace of shared FFT helpers and transform lists used by the k-space transforms.

  Attributes are read straight off the class (e.g. `C.fft`, `C.rfft()`); the
  helper functions are static methods, so they behave the same when reached
  through an instance or a subclass.

  NOTE(review): `torch.fft`/`torch.ifft`/`torch.rfft`/`torch.irfft` called with
  `signal_ndim` are the function-style API of older PyTorch releases -- confirm
  the pinned torch version still provides them.
  """
  # Shift low frequencies to center
  shift  = T.fftshift
  ishift = T.ifftshift

  # Shift low frequencies to center (exclude batch dimension)
  batch_shift  = partial(T.fftshift,  dim=(-3,-2))
  batch_ishift = partial(T.ifftshift, dim=(-3,-2))

  # FFT (complex)
  fft   = partial(torch.fft,  signal_ndim = 2, normalized=True)
  ifft  = partial(torch.ifft, signal_ndim = 2, normalized=True)

  # FFT (real)
  rfft2  = partial(torch.rfft,  signal_ndim = 2, normalized=True)
  irfft2 = partial(torch.irfft, signal_ndim = 2, normalized=True)

  # get rfft, irfft
  @staticmethod
  def rfft(o=True):
    """Return `C.rfft2` with `onesided` fixed to `o`."""
    return partial(C.rfft2, onesided=o)

  @staticmethod
  def irfft(o=True, s=None):
    """Return `C.irfft2` with `onesided` fixed to `o` and `signal_sizes` fixed to `s`."""
    return partial(C.irfft2, onesided=o, signal_sizes=s)

  # convert image to float tensor; real tensor into complex tensor
  @staticmethod
  def im2tn(im):
    """Convert `im` to a double-precision tensor."""
    return tensor(im).double()

  @staticmethod
  def real2complex(t):
    """Stack `t` with an all-zero imaginary part along a new trailing axis of size 2."""
    # zeros_like keeps the imaginary part on the same device as `t`;
    # dtype stays double, matching the previous torch.zeros(...).double().
    imag = torch.zeros_like(t, dtype=torch.double)
    return torch.stack((t, imag), dim=-1)

  # im2k - implemented by subclasses
  @staticmethod
  def im2k(fft):
    """Placeholder overridden by subclasses; returns None here."""
    pass

  @staticmethod
  def k2im(ifft):
    """Placeholder overridden by subclasses; returns None here."""
    pass

  # Viz complex k in magnitude-only kspace (log)
  complex2mgn = [T.complex_abs]
  log = [add(1e-9), torch.log, torch.abs]

  # Bookkeeping lists, kept so the class still exposes them. The methods are
  # made static by the decorators above: rebinding a loop variable
  # (`for f in ...: f = staticmethod(f)`) never changed the class attributes.
  static_methods = [m.__func__ for m in (rfft, irfft, im2tn, real2complex, im2k, k2im)]
  class_methods  = []

In [0]:
# centered image to (centered) kspace (and back)
class CenteredTfms(C):
  """Transform lists for images whose low frequencies are already centered.

  Both helpers are static so they can be called on the class or on an
  instance without the instance being bound to the first parameter.
  """

  @staticmethod
  def im2k(fft=C.fft,  shift=C.shift, ishift=C.ishift):
    """Return the image -> k-space steps: `[ishift, fft, shift]`."""
    return [ishift,  fft, shift]

  @staticmethod
  def k2im(ifft=C.ifft,shift=C.shift, ishift=C.ishift):
    """Return the k-space -> image steps: `[ishift, ifft, shift]`."""
    return [ishift, ifft, shift]

# uncentered image to (centered) kspace (and back)
class UncenteredTfms(C):
  """Transform lists for uncentered images mapped to centered k-space.

  Both helpers are static so they can be called on the class or on an
  instance without the instance being bound to the first parameter.
  """

  @staticmethod
  def im2k(fft=C.fft,   shift=C.shift, ishift=C.ishift):
    """Return the image -> k-space steps: `[fft, shift]` (`ishift` is accepted but unused)."""
    return [fft,   shift]

  @staticmethod
  def k2im(ifft=C.ifft, shift=C.shift, ishift=C.ishift):
    """Return the k-space -> image steps: `[ishift, ifft]` (`shift` is accepted but unused)."""
    return [ishift, ifft]


In [0]:
def apply(data, tfms, pre=None, post=None):
  """Build a Pipeline from `pre` + `tfms` + `post` and call it on `data`.

  Each of `pre`, `tfms` and `post` is wrapped with `L(...)` before being
  concatenated, so each may be a single item or a list.
  NOTE(review): Pipeline may reorder steps by their `order` attribute --
  confirm against fastcore if exact sequencing matters.
  """
  steps = L(pre) + L(tfms) + L(post)
  pipeline = Pipeline(steps)
  return pipeline(data)

In [0]:
# permute kspace tensor: N1HW(Complex) to N(Complex)HW
class Complex2Channel(Transform):
  """Move a trailing size-2 (real, imaginary) axis into the channel position and back."""
  order = 99 # happens after complex k

  # N1HW(Complex) -> N(Complex)HW
  def encodes(self, t:Tensor):
    """Return `t` with its trailing size-2 axis moved in front of the last two axes.

    Tensors whose last axis is not of size 2 are returned unchanged.
    """
    if t.size(-1) != 2: return t
    # (..., H, W, 2) -> (..., H, 2, W) -> (..., 2, H, W)
    out = t.transpose(-1,-2).transpose(-2,-3)
    # Drop only the singleton channel axis that now sits just before the
    # complex axis. A bare torch.squeeze would also remove a batch axis of
    # size 1 (or H/W of size 1), breaking the N(Complex)HW layout.
    if out.dim() >= 4 and out.size(-4) == 1: out = out.squeeze(-4)
    return out

  # NCHW -> NHWC
  def decodes(self, t:Tensor):
    """Return `t` with a size-2 axis at position -3 moved to the last position.

    Tensors whose axis -3 is not of size 2 are returned unchanged.
    """
    # (..., 2, H, W) -> (..., H, 2, W) -> (..., H, W, 2)
    if t.size(-3) == 2: return t.transpose(-3,-2).transpose(-2,-1)
    return t


In [0]:
# converts complex k-space (2channel) to amplitude k-space (1channel)
class ComplexK2LogMgn(Transform):
  """Encode a complex k-space tensor by applying `C.complex2mgn` followed by `C.log`.

  `encodes` is defined twice on purpose, once per annotated input type.
  NOTE(review): this presumably relies on Transform dispatching same-named
  methods by type annotation -- confirm against fastcore.
  No `decodes` is defined for plain tensors.
  """
  order = 51 # needs to run after Real2ComplexK

  # do nothing to tensor categories
  def encodes(self, t:TensorCategory):
    return t

  def decodes(self, t:TensorCategory):
    return t

  def encodes(self, t:Tensor):
    # magnitude steps run first (pre), then the log steps
    return apply(t, tfms=C.log, pre=C.complex2mgn)

In [0]:
def idx(lst,i, default=None):
  """Return `lst[i]` when `i` is below `len(lst)`, otherwise `default`.

  Only the upper bound is checked; a negative `i` is passed to normal
  indexing as-is.
  """
  if i < len(lst):
    return lst[i]
  return default

In [0]:
def plot(imgs, titles=(None,), cmaps=("gray",), nrows=1, ncols=1, figsize = (6,6), **kwargs):
  """Show `imgs` on a grid of subplots, one image per axis.

  imgs:    a single image or a list of images; each must have a `.shape`.
  titles:  per-image titles; images without a matching title get none.
  cmaps:   colormaps, cycled so every image gets one (e.g. one cmap for all
           images, or two cmaps alternating).
  nrows, ncols: grid shape; ignored in favour of a single row when
           `nrows * ncols` does not equal the number of images.
  figsize: passed to `plt.subplots`.
  kwargs:  accepted but currently unused.
  """
  # listify so the input can be a single item instead of a 1-item list
  imgs, titles, cmaps = L(imgs), L(titles), L(cmaps)

  # default set nrows, ncols = 1, len(imgs)
  if nrows * ncols != len(imgs): nrows, ncols = 1, len(imgs)

  # cmaps are cycled by index below; multiplying the list by
  # int(len(imgs)/len(cmaps)) left it empty or too short whenever the image
  # count was smaller than, or not a multiple of, the cmap count.
  n_cmaps = len(cmaps)

  fig,axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
  for i,(ax,im) in enumerate(zip(axes.flatten(), imgs)):
    cmap = cmaps[i % n_cmaps] if n_cmaps else None
    ax.imshow(im, cmap=cmap)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(idx(titles,i))
    ax.set_xlabel(im.shape)
  fig.tight_layout()
  fig.show()