<a href="https://colab.research.google.com/github/RGologorsky/fastmri/blob/master/01_kspace_tfms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Kspace tfms

- Implements k-space transforms for NumPy arrays and PyTorch tensors.
- Tests live in testing_01_kspace_tfms.ipynb.

For real images:
- The FFT of a purely real array (e.g. an image) has conjugate (Hermitian) symmetry: `fft2_result[a, b] == fft2_result[-a, -b].conj()`.
- So `rfft2` returns only the non-redundant part of the `fft2` output: for an input of width `W`, the first `W // 2 + 1` columns. This saves time and memory.

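- An rfft output can correspond to either an odd- or an even-length signal: widths `2k` and `2k + 1` both give `k + 1` columns.
- By default, `irfft` assumes an even output length.
- To avoid losing information when inverting with `irfft`/`irfft2`, pass the original input's size: `n`/`s` in NumPy/SciPy, or `signal_sizes` in the legacy PyTorch API.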

Sources: 

- https://stackoverflow.com/questions/43001729/how-should-i-interpret-the-output-of-numpy-fft-rfft2

- https://numpy.org/doc/stable/reference/generated/numpy.fft.irfft.html

In [0]:
import scipy.fft as SP
import numpy.fft as NP

In [0]:
class NpTfms():
  "NumPy/SciPy k-space transforms, each exposed as a list of callables for `apply`."

  # 2D transforms with orthonormal scaling (norm="ortho"); fft2 also accepts real input
  rfft2, irfft2 = partial(SP.rfft2, norm="ortho"), partial(SP.irfft2, norm="ortho")
  fft2, ifft2 = partial(SP.fft2, norm="ortho"), partial(SP.ifft2, norm="ortho")

  # image -> numpy array
  im2arr = [np.array]

  # (uncentered) real array -> (centered) half-spectrum k-space.
  # NOTE(review): NP.fftshift shifts every axis, including rfft2's half-spectrum
  # last axis; k2real's NP.ifftshift undoes exactly that shift.
  real2k = [rfft2, NP.fftshift]

  @classmethod
  def k2real(cls, s=None):
    "Centered rfft k-space -> real array; pass the original shape `s` to recover odd sizes."
    inverse = partial(cls.irfft2, s=s)
    return [NP.ifftshift, inverse]

  # two-sided fft counterparts of real2k / k2real (fft_k2real returns the complex result)
  fft_real2k = [fft2, NP.fftshift]
  fft_k2real = [NP.ifftshift, ifft2]

  # (centered) complex array <-> (centered) k-space
  complex2k = [NP.ifftshift, fft2, NP.fftshift]
  k2complex = [NP.ifftshift, ifft2, NP.fftshift]

  # visualisation: magnitude, and the add-eps -> log -> abs chain used for log-scale display
  np_abs = [np.abs]
  log_abs = [add(1e-9), np.log, np.abs]


In [0]:
class TensorTfms():
  """PyTorch k-space transforms, each exposed as a list of callables for `apply`.

  Built on the legacy `torch.rfft`/`irfft`/`fft`/`ifft` functions and the helpers in `T`.
  Complex tensors carry a trailing axis of size 2 holding (real, imag); see `real2complex`.
  """

  # 2D transforms with normalized=True scaling; fft2/ifft2 expect complex input
  rfft2  = partial(torch.rfft,  signal_ndim=2, normalized=True)
  irfft2 = partial(torch.irfft, signal_ndim=2, normalized=True)
  fft2   = partial(torch.fft,   signal_ndim=2, normalized=True)
  ifft2  = partial(torch.ifft,  signal_ndim=2, normalized=True)

  # Batch shifts: roll every axis except axis 0 (the batch axis). Plain functions, no self.
  # NOTE(review): a trailing (real, imag) axis of size 2 is rolled by 1 as well, which
  # swaps real and imag; batch_ifftshift swaps them back, so round trips are unaffected.
  def batch_ifftshift(x):
    axes = tuple(range(1, x.dim()))
    return T.roll(x, [(x.size(a) + 1) // 2 for a in axes], axes)

  def batch_fftshift(x):
    axes = tuple(range(1, x.dim()))
    return T.roll(x, [x.size(a) // 2 for a in axes], axes)

  # image -> double tensor
  def im2arr(im): return tensor(im).double()

  # real tensor -> complex tensor whose imaginary part is zero
  def real2complex(t):
    zeros = torch.zeros(t.shape).double()
    return torch.stack((t, zeros), axis=-1)

  # (uncentered) real tensor <-> (centered) k-space via rfft
  @classmethod
  def real2k(cls, onesided=True):
    forward = partial(cls.rfft2, onesided=onesided)
    return [forward, T.fftshift]

  @classmethod
  def k2real(cls, s=None, onesided=True):
    inverse = partial(cls.irfft2, signal_sizes=s, onesided=onesided)
    return [T.ifftshift, inverse]

  # batch variants: the batch axis is never shifted
  @classmethod
  def batch_real2k(cls, onesided=True):
    forward = partial(cls.rfft2, onesided=onesided)
    return [forward, cls.batch_fftshift]

  @classmethod
  def batch_k2real(cls, s=None, onesided=True):
    inverse = partial(cls.irfft2, signal_sizes=s, onesided=onesided)
    return [cls.batch_ifftshift, inverse]

  # complex-fft counterparts; the *_k2real lists end with a magnitude (T.complex_abs)
  fft_real2k       = [fft2, T.fftshift]
  fft_k2real       = [T.ifftshift, ifft2, T.complex_abs]
  batch_fft_real2k = [fft2, batch_fftshift]
  batch_fft_k2real = [batch_ifftshift, ifft2, T.complex_abs]

  # complex tensor <-> k-space, delegated to T.fft2 / T.ifft2
  complex2k = [T.fft2]
  k2complex = [T.to_tensor, T.ifft2]

  # visualisation: magnitude, and the add-eps -> log -> abs chain used for log-scale display
  t_abs   = [T.complex_abs]
  log_abs = [add(1e-9), torch.log, torch.abs]

NameError: ignored

In [0]:
class TensorTfmsBase():
  """Shared k-space transform lists; subclasses supply the transform pair.

  Subclasses implement the classmethod factories `fft(...)` and `ifft(...)`. Each returns
  the callable to put in a pipeline. Extra arguments given to `real2k`, `k2real`,
  `batch_real2k` and `batch_k2real` are forwarded to those factories.
  """

  # convert image to Pytorch tensor, real tensor into complex tensor (trailing (real, imag) axis)
  def im2arr(im):      return tensor(im).double()
  def real2complex(t): return torch.stack((t, torch.zeros(t.shape).double()), axis=-1)

  # transform factories -- must be overridden by subclasses
  @classmethod
  def fft(cls, *args, **kwargs):
    raise NotImplementedError(f"{cls.__name__} must implement the fft() factory")

  @classmethod
  def ifft(cls, *args, **kwargs):
    raise NotImplementedError(f"{cls.__name__} must implement the ifft() factory")

  # batch fft & ifft shift -- roll every axis except axis 0 (the batch axis)
  @staticmethod
  def batch_ifftshift(x):
    dims = tuple(range(1, x.dim()))
    return T.roll(x, [(x.size(d) + 1) // 2 for d in dims], dims)

  @staticmethod
  def batch_fftshift(x):
    dims = tuple(range(1, x.dim()))
    return T.roll(x, [x.size(d) // 2 for d in dims], dims)

  # real to centered kspace (and back); the factories are called to build the callables
  @classmethod
  def real2k(cls, *args, **kwargs): return [cls.fft(*args, **kwargs), T.fftshift]

  @classmethod
  def k2real(cls, *args, **kwargs): return [T.ifftshift, cls.ifft(*args, **kwargs)]

  # batch methods do not shift batch axis
  @classmethod
  def batch_real2k(cls, *args, **kwargs): return [cls.fft(*args, **kwargs), cls.batch_fftshift]

  @classmethod
  def batch_k2real(cls, *args, **kwargs): return [cls.batch_ifftshift, cls.ifft(*args, **kwargs)]

  # (centered) complex arr obj to (centered) kspace (and back)
  complex2k  = [T.fft2]
  k2complex  = [T.to_tensor, T.ifft2]

  # Viz complex k in magnitude-only kspace (log)
  magn = [T.complex_abs]
  log_abs = [add(1e-9), torch.log, torch.abs]


class TensorTfmsReal(TensorTfmsBase):
  "rfft-based transforms built on the legacy `torch.rfft` / `torch.irfft` API."

  @classmethod
  def fft(cls, onesided=True):
    "Return a 2D normalized real-to-complex FFT callable."
    return partial(torch.rfft, signal_ndim=2, normalized=True, onesided=onesided)

  @classmethod
  def ifft(cls, s=None, onesided=True):
    "Return a 2D normalized complex-to-real inverse FFT; `s` is the original signal size."
    return partial(torch.irfft, signal_ndim=2, normalized=True, signal_sizes=s, onesided=onesided)


In [0]:
def apply(data, tfms, pre=None, post=None):
  "Run `data` through a fastai `Pipeline` of `pre` + `tfms` + `post` (lists of callables, or None)."
  steps = L(pre) + L(tfms) + L(post)
  pipeline = Pipeline(steps)
  return pipeline(data)

## Viz: function to plot images

In [0]:
def idx(lst, i, default=None):
  """Return `lst[i]`, or `default` when `i` is out of range.

  Negative indices count from the end as usual. An out-of-range negative index
  also yields `default` instead of raising IndexError.
  """
  return lst[i] if -len(lst) <= i < len(lst) else default

In [0]:
def plot(imgs, titles=(None,), cmaps=("gray",), nrows=1, ncols=1, figsize = (6,6), **kwargs):
  """Plot `imgs` on an `nrows` x `ncols` grid, each with a title, a colormap and a shape label.

  `imgs`, `titles` and `cmaps` may be single items or lists (they are wrapped with `L`).
  Titles are matched to images by position; images without a title get none.
  `cmaps` is cycled over the images. If `nrows * ncols` does not equal the number of
  images, a single row is used. Extra `**kwargs` are accepted but not used.
  """
  # listify so the input can be a single item instead of a 1-item list
  imgs, titles, cmaps = L(imgs), L(titles), L(cmaps)
  n_imgs = len(imgs)
  if n_imgs == 0: return  # nothing to draw; plt.subplots needs at least one column

  # fall back to a single row when the requested grid does not match the image count
  if nrows * ncols != n_imgs: nrows, ncols = 1, n_imgs

  fig,axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
  axes = axes.flatten()
  for i,im in enumerate(imgs):
    # Cycle through cmaps. Repeating the list int(n_imgs/len(cmaps)) times dropped
    # colormaps when n_imgs was smaller than, or not a multiple of, len(cmaps).
    cmap = cmaps[i % len(cmaps)] if len(cmaps) else None
    axes[i].imshow(im, cmap=cmap)
    axes[i].set_xticklabels([]), axes[i].set_yticklabels([])
    axes[i].set_title(idx(titles,i))
    axes[i].set_xlabel(im.shape)
  fig.tight_layout()
  fig.show()

# Common KSpace Transforms

In [0]:
# permute kspace tensor: N1HW(Complex) to N(Complex)HW
class Complex2Channel(Transform):
  "Move the trailing (real, imag) axis of size 2 into the channel position (and back)."
  order = 99 # happens after complex k

  # N1HW(Complex) -> N(Complex)HW
  def encodes(self, t:Tensor):
    # only tensors whose last axis has size 2 are treated as complex
    if t.size(-1) != 2: return t
    # (..., H, W, 2) -> (..., 2, H, W)
    t = t.transpose(-1,-2).transpose(-2,-3)
    # Drop only the singleton channel axis in front of the complex axis. A bare
    # torch.squeeze() also removed a batch axis of size 1 (e.g. a final batch of one item).
    if t.dim() >= 4 and t.size(-4) == 1: t = t.squeeze(-4)
    return t

  # N(Complex)HW -> NHW(Complex); the singleton channel axis is not re-inserted
  def decodes(self, t:Tensor):
    if t.size(-3) == 2: return t.transpose(-3,-2).transpose(-2,-1)
    return t
    
# shows k-space and its real reconstruction side by side in one image
class ShowK(Tuple):
  "Pair (k, real) displayed as one image: k on the left, a 10-column zero gap, real on the right."

  def show(self, ctx=None, **kwargs):
    # NOTE(review): the 2D separator and torch.cat(dim=1) assume k and real are 2D with equal heights
    k_img, real_img = self
    gap = k_img.new_zeros(k_img.shape[0], 10)
    side_by_side = torch.cat([k_img, gap, real_img], dim=1)
    return show_image(side_by_side, title = "K & Real", ctx=ctx, **kwargs)

# take dataset item (real img, category), convert to (k arr, category)
class BatchReal2ComplexK(Transform):
  "Batch transform: real image tensor -> centered two-sided k-space; categories pass through."
  order = 50 # needs to run after save shape

  # categories are returned unchanged in both directions
  def encodes(self, t:TensorCategory): return t
  def decodes(self, t:TensorCategory): return t

  def encodes(self, t:Tensor):
    to_kspace = TensorTfms.batch_real2k(onesided=False)
    return apply(t, to_kspace)

  def decodes(self, t_k:Tensor):
    # log-magnitude k-space for display: magnitude first, then the log_abs chain
    k_display = apply(t_k, TensorTfms.log_abs, pre=TensorTfms.t_abs)
    # inverse transform back to image space
    recon = apply(t_k, TensorTfms.batch_k2real(onesided=False))
    return ShowK(k_display, recon)

# converts complex k-space (2channel) to amplitude k-space (1channel)
# Tensors are reduced with TensorTfms.t_abs (magnitude), then passed through TensorTfms.log_abs.
class ComplexK2LogAbs(Transform):
  order = 51 # needs to run after Real2ComplexK (presumably BatchReal2ComplexK, order 50 -- confirm)
  # NOTE(review): order 51 < 99, so this runs before Complex2Channel moves the (real, imag) axis;
  # t_abs (T.complex_abs) presumably expects that axis last -- confirm.

  # do nothing to tensor categories
  def encodes(self, t:TensorCategory): return t
  def decodes(self, t:TensorCategory): return t

  # pipeline order: t_abs (pre) first, then log_abs
  def encodes(self, t:Tensor): return apply(t, TensorTfms.log_abs, pre=TensorTfms.t_abs)