In [None]:
# Generate a synthetic B0 field map: a linear ramp along x plus a linear ramp
# along y, so the off-resonance varies diagonally across the field of view.
import torch
import matplotlib.pyplot as plt

dim = (1, 64, 64)  # (z, y, x) image size; reused for the phantom and the operators below

# Peak value of each individual ramp; the summed map spans [-2 * b0_max, 2 * b0_max].
# NOTE(review): units follow whatever the mrpro operators expect for b0_map
# (presumably Hz or rad/s, paired with readout times in seconds) - TODO confirm.
b0_max = 800
# Broadcast an x-ramp (row vector) against a y-ramp (column vector) rather than
# adding a rot90 copy of the x-ramp: rot90 swaps the last two axes, so that
# construction only worked for square grids (dim[-2] == dim[-1]).
b0_map = (
    torch.linspace(-b0_max, b0_max, dim[-1])[None, :]
    + torch.linspace(-b0_max, b0_max, dim[-2])[:, None]
)[None, ...]

plt.imshow(b0_map.squeeze().numpy())
plt.title("B0 field map")
plt.xlabel("x")
plt.ylabel("y")
_ = plt.colorbar()

In [None]:
# Generate the ground-truth ellipse phantom on the same (z, y, x) grid as the field map.
from mrpro.data import SpatialDimension
from mrpro.phantoms import EllipsePhantom

phantom = EllipsePhantom()
s_dim = SpatialDimension(z=dim[0], y=dim[1], x=dim[2])
img = phantom.image_space(s_dim)

# Magnitude range of the phantom; the plotting cells below reuse it as a shared
# grey-scale display window (vmin/vmax).
img_magnitude = img.abs()
img_max, img_min = img_magnitude.max(), img_magnitude.min()

plt.imshow(img_magnitude.squeeze().numpy(), cmap="gray")
_ = plt.colorbar()

In [None]:
# Simulate B0-induced distortion with the time-segmented (TS) model, then map the
# k-space back with a plain FFT operator (no off-resonance model) so the
# distortion becomes visible in image space.
from mrpro.operators.FastFourierOp import FastFourierOp
from mrpro.operators.TimeSegmentedFastFourierOp import TimeSegmentedFastFourierOp

# Readout sample times: one sample per point along x, spaced 1 / ro_bandwidth apart.
ro_bandwidth = 20e3
t_ro = torch.arange(dim[-1]) / ro_bandwidth

ts_op = TimeSegmentedFastFourierOp(b0_map=b0_map, readout_times=t_ro, num_segments=16)
fft_op = FastFourierOp()

# Forward TS model: undistorted image -> distorted k-space.
(ts_distorted_ksp,) = ts_op.forward(img)
# Adjoint of the plain FFT: distorted k-space -> distorted image.
(ts_distorted_img,) = fft_op.adjoint(ts_distorted_ksp)

fig, ax = plt.subplots(1, 2)
panels = (("Undistorted", img), ("TS Distorted", ts_distorted_img))
for panel_idx, (panel_title, panel_img) in enumerate(panels):
    ax[panel_idx].imshow(panel_img.squeeze().abs().numpy(), cmap="gray", vmin=img_min, vmax=img_max)
    _ = ax[panel_idx].set_title(panel_title)


In [None]:
# Simulate the same B0 distortion with the conjugate-phase (CP) model, using the
# same field map and readout times, and view it through the plain FFT adjoint.
from mrpro.operators.ConjugatePhaseFastFourierOp import ConjugatePhaseFastFourierOp

cp_op = ConjugatePhaseFastFourierOp(b0_map=b0_map, readout_times=t_ro)

# Forward CP model (image -> distorted k-space), then plain FFT adjoint back to image space.
(cp_distorted_ksp,) = cp_op.forward(img)
(cp_distorted_img,) = fft_op.adjoint(cp_distorted_ksp)

fig, ax = plt.subplots(1, 2)
for axis, shown_img, shown_title in zip(ax, (img, cp_distorted_img), ("Undistorted", "CP Distorted")):
    axis.imshow(shown_img.squeeze().abs().numpy(), cmap="gray", vmin=img_min, vmax=img_max)
    _ = axis.set_title(shown_title)

In [None]:
# Compare the conjugate-phase (CP) and time-segmented (TS) simulations: magnitude of
# the difference between their distorted images, as a fraction of the phantom's
# intensity range.
# The absolute difference is only scaled by the range, not offset by img_min:
# subtracting img_min (as one would when windowing an intensity image) would shift
# a zero difference to -img_min / (img_max - img_min).
distortion_dif = (cp_distorted_img - ts_distorted_img).squeeze().abs() / (img_max - img_min)

fig, ax = plt.subplots()
im = ax.imshow(distortion_dif.numpy())
_ = ax.set_title("|CP - TS| / phantom intensity range")
_ = fig.colorbar(im, ax=ax)

In [None]:
# Distortion correction: apply the adjoint of each B0-aware operator to the
# simulated k-space. The adjoint is not in general an exact inverse, so some
# residual error is expected; the error maps below show how much.
# NOTE(review): both corrections start from cp_distorted_ksp (the CP-simulated
# data), presumably so CP and TS are compared on identical input. Use
# ts_distorted_ksp for the TS branch if a self-consistent TS round trip was intended.
(cp_corrected_img,) = cp_op.adjoint(cp_distorted_ksp)
(ts_corrected_img,) = ts_op.adjoint(cp_distorted_ksp)

# Error maps: magnitude of the residual as a fraction of the phantom's intensity range.
# The residual is only scaled, not offset by img_min, so a perfect reconstruction
# maps to exactly 0, the lower end of the fixed [0, 1] colour window used below.
img_range = img_max - img_min
cp_error = (img - cp_corrected_img).squeeze().abs() / img_range
ts_error = (img - ts_corrected_img).squeeze().abs() / img_range

fig, ax = plt.subplots(1, 5, figsize=(15, 3))
ax[0].imshow(img.squeeze().abs().numpy(), cmap="gray", vmin=img_min, vmax=img_max)
_ = ax[0].set_title("Original")
ax[1].imshow(cp_corrected_img.squeeze().abs().numpy(), cmap="gray", vmin=img_min, vmax=img_max)
_ = ax[1].set_title("CP Corrected")
ax[2].imshow(cp_error.numpy(), vmin=0, vmax=1)
_ = ax[2].set_title("CP Error")
ax[3].imshow(ts_corrected_img.squeeze().abs().numpy(), cmap="gray", vmin=img_min, vmax=img_max)
_ = ax[3].set_title("TS Corrected")
ax[4].imshow(ts_error.numpy(), vmin=0, vmax=1)
_ = ax[4].set_title("TS Error")