<a href="https://colab.research.google.com/github/RGologorsky/fastmri/blob/master/01_kspace_tfms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Kspace tfms

- Implements k-space transforms for NumPy arrays and PyTorch tensors.
- Tests are in testing_01_kspace_tfms.ipynb.

In [0]:
[hobbit-hole](https://en.wikipedia.org/wiki/Hobbit#Lifestyle "Hobbit lifestyles")

For real images:
- The FFT of a purely real array (e.g. an image) has conjugate symmetry: `fft2_result[a, b] = fft2_result[-a, -b].conj()`
- So `rfft2` outputs only the left half (plus one column) of the `fft2` result, saving space and memory.

- An RFFT output could correspond to either an odd- or an even-length signal.
- By default, `irfft` assumes an even output length.
- To avoid losing information, pass the correct length of the original input when computing the inverse (`irfft`).

Sources: 

- https://stackoverflow.com/questions/43001729/how-should-i-interpret-the-output-of-numpy-fft-rfft2

- https://numpy.org/doc/stable/reference/generated/numpy.fft.irfft.html

In [0]:
class TensorTfmsBase():
  """Shared helpers for k-space tensor transforms.

  Complex tensors are stored with a trailing real/imag axis of size 2
  (see `real2complex`). The shifts therefore act on dims (-3, -2), which are
  presumably the two spatial axes H, W of a ...HW2 layout (TODO confirm
  against callers).
  """
  # Shift low frequencies to / away from the center of the spatial axes
  ishift = partial(T.ifftshift, dim=(-3,-2))
  shift  = partial(T.fftshift,  dim=(-3,-2))

  # NOTE: real decorators. The previous `for f in ...: f = staticmethod(f)`
  # loop only rebound the loop variable and never decorated anything.
  @staticmethod
  def im2tn(im):
    "Convert an image (anything `tensor` accepts) to a float64 tensor."
    return tensor(im).double()

  @staticmethod
  def real2complex(t):
    "Real tensor (...) -> complex tensor (..., 2) with a zero imaginary part, float64."
    # Create the zeros on t's device, so GPU inputs don't hit a device mismatch.
    # Cast t as well, so stack never mixes dtypes (e.g. a float32 input).
    imag = torch.zeros(t.shape, dtype=torch.double, device=t.device)
    return torch.stack((t.double(), imag), dim=-1)

In [0]:
class MRTensorTfms(TensorTfmsBase):
  """K-space transforms for centered MR images, using the full complex 2D FFT.

  Uses the PyTorch < 1.8 function API (`torch.fft` / `torch.ifft`) on tensors
  with a trailing real/imag axis of size 2.
  """
  # FFT: forward and inverse complex 2D transforms with orthonormal scaling.
  # (Previously both were bound to torch.ifft, so "fft" was really an inverse FFT.)
  fft  = partial(torch.fft,  signal_ndim = 2, normalized=True)
  ifft = partial(torch.ifft, signal_ndim = 2, normalized=True)

  # Centered image to (centered) k-space, and back.
  # Use cls.fft / cls.ifft: class-body names are not visible inside methods.
  @classmethod
  def mr2k(cls):
    "Pipeline: centered image -> uncenter -> FFT -> centered k-space."
    return [cls.ishift, cls.fft, cls.shift]

  @classmethod
  def k2mr(cls):
    "Pipeline: centered k-space -> uncenter -> inverse FFT -> centered image."
    return [cls.ishift, cls.ifft, cls.shift]

  # Viz: complex k as magnitude-only k-space (log).
  # add(1e-9) presumably adds a small epsilon so log never sees 0.
  mgn  = [T.complex_abs]
  complex2log_mgn = mgn + [add(1e-9), torch.log, torch.abs]

In [0]:
class IMTensorTfms(TensorTfmsBase):
  """K-space transforms for real-valued, uncentered images, using the real FFT.

  Uses the PyTorch < 1.8 function API (`torch.rfft` / `torch.irfft`). With
  `onesided=True` only half of the conjugate-symmetric spectrum is kept. The
  inverse then cannot tell odd from even widths unless `signal_sizes` is given.
  """
  # FFT (real): forward and inverse 2D transforms with orthonormal scaling
  fft  = partial(torch.rfft,  signal_ndim = 2, normalized=True)
  ifft = partial(torch.irfft, signal_ndim = 2, normalized=True)

  # Uncentered image to (centered) k-space, and back.
  # Use cls.fft / cls.ifft: class-body names are not visible inside methods.
  @classmethod
  def im2k(cls, onesided=True):
    "Pipeline: uncentered real image -> rfft -> centered k-space."
    return [partial(cls.fft, onesided=onesided), cls.shift]

  @classmethod
  def k2im(cls, onesided=True, signal_sizes=None):
    """Pipeline: centered k-space -> uncenter -> irfft -> real image.

    signal_sizes: spatial size of the original image. Pass it so that
    odd-sized inputs survive the round trip; by default irfft assumes an
    even length along the last signal dimension.
    """
    return [cls.ishift, partial(cls.ifft, onesided=onesided, signal_sizes=signal_sizes)]

In [0]:
# class TensorTfmsBase():
  
#   # FFT
#   fft  = partial(torch.ifft, signal_ndim = 2, normalized=True)
#   ifft = partial(torch.ifft, signal_ndim = 2, normalized=True)

#   # Shift low frequencies to center
#   ishift = partial(T.ifftshift, dim=(-3,-2))
#   shift  = partial(T.fftshift,  dim=(-3,-2))

#   # convert image to float tensor; real tensor into complex tensor
#   def im2tn(im): return tensor(im).double()
#   def real2complex(t): return torch.stack((t, torch.zeros(t.shape).double()), axis=-1)

#   # centered image to (centered) kspace (and back)
#   def center_im2k(cls): return [cls.ishift, cls.fft, cls.shift]
#   def k2center_im(cls): return [cls.ishift, cls.ifft, cls.shift]

#   # uncentered image to (centered) kspace (and back)
#   def uncentered_im2k(cls): return [cls.fft,   cls.shift]
#   def uncentered_k2im(cls): return [cls.ishift, cls.ifft]

#   # Viz complex k in magnitude-only kspace (log)
#   mgn  = [T.complex_abs]
#   complex2log_mgn = mgn + [add(1e-9), torch.log, torch.abs]

#   # Decorators
#   static_methods = [im2tn, real2complex]
#   class_methods  = [center_im2k, k2center_im, im2k, k2im]

#   for f in static_methods: f = staticmethod(f)
#   for f in class_methods:  f = classmethod(f)

In [0]:
# class MRTensorTfms(TensorTfmsBase):
#   # centered real image to centered kspace (and back)
#   def im2k(cls): return super().centered_im2k()     
#   def k2im(cls): return super().centered_k2im()
  
#   # decorators
#   for f in [im2k, k2im]: f = classmethod(f)

# class RealImageTensorTfms(TensorTfmsBase):
#   fft  = partial(torch.rfft,  signal_ndim = 2, normalized=True)
#   ifft = partial(torch.irfft, signal_ndim = 2, normalized=True)

#   # uncentered real image to centered kspace (and back)
#   @classmethod
#   def im2k(cls, onesided=True): 
#     cls.fft = partial(fft, onesided=onesided)
#     return super().uncentered_im2k()
     
#   @classmethod
#   def k2im(cls, onesided=True, s=None):
#     cls.ifft =  partial(ifft, onesided=onesided, signal_sizes=s)
#     return super().uncentered_k2im()

In [0]:
# class BatchShift():
#   # batch fft & ifft shift -- shift all but batch dimension
#   def ifftshift(x):
#     dim = tuple(range(x.dim()))[1:]
#     shift = [(dim + 1) // 2 for dim in x.shape[1:]]
#     return T.roll(x, shift, dim)

#   def fftshift(x):
#     dim = tuple(range(x.dim()))[1:]
#     shift = [dim // 2 for dim in x.shape[1:]]
#     return T.roll(x, shift, dim)

# class BatchTensorTfms(TensorTfms, BatchShift): pass
# class BatchRealTensorTfms(RealTensorTfms, BatchShift): pass

In [0]:
def apply(data, tfms, pre=None, post=None):
  "Concatenate `pre` + `tfms` + `post` into one `Pipeline` and run `data` through it."
  # pre/post default to None; presumably L(None) is empty, so they simply drop out
  steps = L(pre) + L(tfms) + L(post)
  pipeline = Pipeline(steps)
  return pipeline(data)

## Viz: function to plot images

In [0]:
def idx(lst, i, default=None):
  "Return `lst[i]`, or `default` if `i` is out of range (negative indices count from the end)."
  # Check both bounds: a large negative i previously raised IndexError
  return lst[i] if -len(lst) <= i < len(lst) else default

In [0]:
def plot(imgs, titles=(None,), cmaps=("gray",), nrows=1, ncols=1, figsize = (6,6), **kwargs):
  """Show one or more images in a grid of subplots.

  imgs, titles, cmaps: a single item or a list of items (each wrapped with `L`).
  titles: missing titles are shown as None.
  cmaps: cycled so that every image gets a colormap (default "gray").
  nrows, ncols: grid shape. Used only when nrows * ncols == len(imgs);
    otherwise the images are laid out in a single row.
  **kwargs: accepted but currently unused.
  """
  # listify so a single image/title/cmap can be passed instead of a 1-item list
  imgs, titles, cmaps = L(imgs), L(titles), L(cmaps)
  if not len(imgs): return  # nothing to draw; plt.subplots(1, 0) would fail

  # default: one row holding every image
  if nrows * ncols != len(imgs): nrows, ncols = 1, len(imgs)

  fig,axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
  axes = axes.flatten()
  for i,im in enumerate(imgs):
    # Cycle cmaps. The old `cmaps * int(len(imgs)/len(cmaps))` could produce 0
    # copies (more cmaps than images) or too few, leaving images with no cmap.
    cmap = cmaps[i % len(cmaps)] if len(cmaps) else None
    axes[i].imshow(im, cmap=cmap)
    axes[i].set_xticklabels([]), axes[i].set_yticklabels([])
    axes[i].set_title(idx(titles,i))
    axes[i].set_xlabel(im.shape)
  fig.tight_layout()
  fig.show()

# Common KSpace Transforms

In [0]:
# permute kspace tensor: N1HW(Complex) to N(Complex)HW
class Complex2Channel(Transform):
  "Move the trailing real/imag axis (size 2) of a k-space tensor into the channel position."
  order = 99 # happens after complex k

  # N1HW(Complex) -> N(Complex)HW
  def encodes(self, t:Tensor):
    if t.size(-1) != 2: return t
    # ...HW2 -> ...2HW (moves the last axis to position -3)
    t = t.transpose(-1,-2).transpose(-2,-3)
    # Squeeze only the singleton channel axis now at -4. A bare squeeze()
    # also dropped a batch of size 1 and any H or W of size 1.
    if t.dim() >= 4: t = t.squeeze(-4)
    return t

  # NCHW -> NHWC
  # NOTE(review): this does not re-insert the singleton axis removed by encodes,
  # so decodes(encodes(x)) is N,H,W,2 rather than N,1,H,W,2; confirm that is intended.
  def decodes(self, t:Tensor):
    if t.size(-3) == 2: return t.transpose(-3,-2).transpose(-2,-1)
    return t
    
# Pair (k-space view, real image), displayed side by side in one image
class ShowK(Tuple):
  def show(self, ctx=None, **kwargs):
    "Concatenate k-space | 10-column zero gap | real image along dim 1, then display it."
    kspace, image = self
    # blank separator with the same number of rows as the k-space view
    n_rows = kspace.shape[0]
    gap = kspace.new_zeros(n_rows, 10)
    combined = torch.cat([kspace, gap, image], dim=1)
    return show_image(combined, title = "K & Real", ctx=ctx, **kwargs)

# take dataset item (real img, category), convert to (k arr, category)
class BatchReal2ComplexK(Transform):
  "Encode image tensors to k-space; decode k-space to a `ShowK` pair for display."
  order = 50 # needs to run after save shape

  # do nothing to tensor categories
  # NOTE(review): the repeated `encodes`/`decodes` names presumably rely on
  # fastai's Transform dispatching on the type annotation. In a plain Python
  # class only the last def of each name would survive.
  def encodes(self, t:TensorCategory): return t
  def decodes(self, t:TensorCategory): return t

  # Image tensor -> k-space, per the name `batch_real2k` (full, not one-sided, spectrum).
  # NOTE(review): `TensorTfms` is not defined in the visible source. The classes
  # defined there (TensorTfmsBase, MRTensorTfms, IMTensorTfms) do not provide
  # batch_real2k / t_abs / log_abs / batch_k2real; confirm these still resolve.
  def encodes(self, t:Tensor): return apply(t, TensorTfms.batch_real2k(onesided=False))


  def decodes(self, t_k:Tensor):
    # presumably magnitude, then log-magnitude, of k-space (for viewing)
    t_k_abs         = apply(t_k, TensorTfms.t_abs)
    t_k_log_abs     = apply(t_k_abs, TensorTfms.log_abs)

    # k-space back to image space, per the name `batch_k2real`
    t_real     = apply(t_k, TensorTfms.batch_k2real(onesided=False))
    
    return ShowK(t_k_log_abs, t_real)

# converts complex k-space (2channel) to (log-)amplitude k-space (1channel)
class ComplexK2LogAbs(Transform):
  # order 51: must run after BatchReal2ComplexK (order 50)
  order = 51 # needs to run after Real2ComplexK

  # do nothing to tensor categories
  # NOTE(review): duplicate `encodes` names presumably rely on fastai's type dispatch
  def encodes(self, t:TensorCategory): return t
  def decodes(self, t:TensorCategory): return t

  # Runs t_abs (as `pre`) then log_abs, presumably giving log-magnitude k-space.
  # NOTE(review): `TensorTfms` is not defined in the visible source; confirm it resolves.
  def encodes(self, t:Tensor): return apply(t, TensorTfms.log_abs, pre=TensorTfms.t_abs)