# Tomographic reconstruction using linear transfer matrices 

In [None]:
import sys
import numpy as np
from matplotlib import pyplot as plt
import proplot as pplt
import seaborn as sns
from tqdm import tqdm, trange
from skimage.transform import radon
from skimage.transform import iradon
from scipy import interpolate

sys.path.append('/Users/austin/Research/') 
from scdist.tools import ap_utils
from scdist.tools import beam_analysis as ba
from scdist.tools import plotting as myplt
from scdist.tools import utils
from scdist.measurement.tomography import reconstruct as rec

In [None]:
# Global proplot defaults for every figure in this notebook: no grid,
# continuous colormaps, viridis as the sequential colormap, white background.
for _rc_key, _rc_value in {
    'axes.grid': False,
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'figure.facecolor': 'white',
}.items():
    pplt.rc[_rc_key] = _rc_value

## Setup 

Define the reconstruction grid in normalized phase space.

In [None]:
# Bins per axis, shared by the reconstruction grid and the screen histogram.
n_bins = 75

# Emittance and Twiss parameters at the reconstruction point.
eps = 10.0  # [mm mrad]
alpha = -2.0
beta = 10.0

# Half-width of the square normalized-space reconstruction window. The
# sampled normalized beam has per-coordinate std sqrt(eps), so
# 2 * sqrt(4 * eps) spans +/- 4 standard deviations.
xmax_rec = xpmax_rec = 2.0 * np.sqrt(4.0 * eps)

# Bin edges and centers along x_n and x'_n.
rec_grid_edges = [
    np.linspace(-limit, limit, n_bins + 1) for limit in (xmax_rec, xpmax_rec)
]
rec_grid_centers = [utils.get_bin_centers(edges) for edges in rec_grid_edges]

Create the true distribution.

In [None]:
# Seed the global NumPy RNG so the sampled beam and the random Twiss
# parameters drawn in later cells are reproducible.
np.random.seed(0)

# Round Gaussian beam in normalized phase space: each coordinate has
# standard deviation sqrt(eps).
n = 500_000
X_true_n = np.random.normal(loc=0.0, scale=np.sqrt(eps), size=(n, 2))

# Map the normalized coordinates back to real (x, x') with the Twiss
# parameters at the reconstruction point.
V1 = ap_utils.V_matrix_2x2(alpha, beta)
X_true = utils.apply(V1, X_true_n)

In [None]:
# Quick look at the true distribution in real (x, x') phase space, using the
# same bin count as the rest of the notebook.
fig, ax = pplt.subplots()
ax.hist2d(X_true[:, 0], X_true[:, 1], bins=n_bins)
ax.format(xlabel=r"$x$", ylabel=r"$x'$", title="True distribution")
plt.show()

In [None]:
# Histogram the true beam in normalized phase space on the reconstruction
# grid; Z_true_n is also used later for the true-vs-reconstructed comparison.
fig, ax = pplt.subplots()
Z_true_n, _, _ = np.histogram2d(X_true_n[:, 0], X_true_n[:, 1], bins=rec_grid_edges)
# histogram2d indexes as [x_bin, y_bin] while pcolormesh expects
# [row=y, col=x], hence the transpose.
ax.pcolormesh(rec_grid_centers[0], rec_grid_centers[1], Z_true_n.T, ec='none')
ax.format(xlabel=r"$x_n$", ylabel=r"$x'_n$", title="Reconstruction point")
plt.show()

Simulate the measurements. We scan the phase advance over a 180-degree range and let $\alpha$ and $\beta$ take random values at the measurement location.

In [None]:
# Screen binning: n_bins bins spanning +/- xmax_screen.
xmax_screen = 90.0  # [mm]
screen_xedges = np.linspace(-xmax_screen, xmax_screen, n_bins + 1)
screen_xcenters = (screen_xedges[1:] + screen_xedges[:-1]) / 2.0

# Scan n_proj phase advances evenly over [0, pi) (endpoint excluded).
n_proj = 15
phase_advances = np.linspace(0.0, np.pi, n_proj, endpoint=False)

# Random Twiss parameters at the screen. Keep beta drawn before alpha: both
# come from the seeded global RNG, so the draw order fixes the values.
betas = np.random.uniform(55.0, 75.0, size=n_proj)
alphas = np.random.uniform(-1.5, 1.5, size=n_proj)

In [None]:
# Simulate each measurement: build the transfer matrix from the
# reconstruction point to the screen, transport the particles, and record
# the normalized horizontal projection on the screen.
V1_inv = np.linalg.inv(V1)  # loop invariant; computed once
tmats, projections = [], []
for k in trange(n_proj):
    # Transfer matrix to the screen: undo the Twiss normalization at the
    # reconstruction point (V1^-1), rotate by the phase advance (P), then
    # apply the screen's Twiss parameters (V2).
    V2 = ap_utils.V_matrix_2x2(alphas[k], betas[k])
    P = utils.rotation_matrix(phase_advances[k])
    M = np.linalg.multi_dot([V2, P, V1_inv])
    # Transport the distribution to the screen.
    X_meas = utils.apply(M, X_true)
    # Histogram x on the screen and normalize to unit sum. The result must be
    # assigned; a bare `projection / np.sum(projection)` is a no-op.
    projection, _ = np.histogram(X_meas[:, 0], screen_xedges)
    projection = projection / np.sum(projection)
    projections.append(projection)
    # Save M tilde = M V1, the map from normalized coordinates at the
    # reconstruction point to the screen, so the reconstruction is done in
    # normalized phase space.
    N = np.linalg.multi_dot([M, V1])
    tmats.append(N)

In [None]:
# One row per measurement: the projection recorded on the screen versus the
# screen coordinate x_b.
fig, ax = pplt.subplots()
ax.pcolormesh(
    screen_xcenters,
    np.arange(n_proj),
    projections,
    colorbar=True,
    ec='None',
)
ax.format(
    ylabel='Projection index',
    xlabel=r'$x_b$',
    ytickminor=False,
)
plt.show()

Obtain the projections at the reconstruction location. Let $M$ be the transfer matrix from normalized coordinates at $a$ (the reconstruction location) to $b$ (the measurement location). Because $x_b = M_{11} x_n + M_{12} x'_n = r s$, the projections are related by $p_a(s) \propto p_b(r s)$. Here $s$ is the projection axis at $a$, which is rotated by angle $\theta = \tan^{-1}(M_{11} / M_{12})$, and $r = \sqrt{{M_{11}}^2 + {M_{12}}^2}$ is the projection scaling. The bin positions on the $s$ axis differ depending on the transfer matrix. We therefore first define the reconstruction grid, then linearly interpolate the measured projection at the points $r s$ to obtain $p_b(r s)$.

In [None]:
# Convert each screen projection p_b into the projection p_a(s) on the
# reconstruction grid's s axis, using p_a(s) proportional to p_b(r s).
# Also record each transfer matrix's projection angle and scale factor r.
scaled_projections = np.zeros((n_proj, n_bins))
proj_angles = np.zeros(n_proj)
scale_factors = np.zeros(n_proj)
for k, (tmat, projection) in enumerate(zip(tmats, projections)):
    r = rec.get_projection_scaling(tmat)
    # Linearly interpolate the screen profile at x_b = r * s. Points outside
    # the screen get zero signal. This is the same computation as the legacy
    # scipy interp1d(kind='linear', bounds_error=False, fill_value=0.0).
    scaled_projection = np.interp(
        r * rec_grid_centers[0], screen_xcenters, projection, left=0.0, right=0.0
    )
    total = np.sum(scaled_projection)
    # If the whole projection falls outside the screen, dividing by zero would
    # fill the row with NaNs that propagate into the reconstruction; keep the
    # row at zero instead.
    scaled_projections[k, :] = scaled_projection / total if total > 0 else 0.0
    proj_angles[k] = rec.get_projection_angle(tmat)
    scale_factors[k] = r

In [None]:
# Per-measurement diagnostics: the projection angle (left, converted from
# radians to degrees) and the projection scale factor r (right).
fig, axes = pplt.subplots(ncols=2, sharey=False)
for panel, series in zip(axes, (np.degrees(proj_angles), scale_factors)):
    panel.plot(series, color='black')
axes.format(xlabel='Projection index', xtickminor=False)
axes[0].format(yformatter='deg', ylabel='Projection angle')
axes[1].format(ylabel='Scale factor')
plt.show()

In [None]:
# Interpolated projections p_a(s), one row per measurement. The x-axis is the
# projection axis s on the normalized reconstruction grid, not the screen
# coordinate x_b.
fig, ax = pplt.subplots()
ax.pcolormesh(rec_grid_centers[0], np.arange(n_proj), scaled_projections, ec='None', colorbar=True)
ax.format(xlabel=r'$s$', ytickminor=False, ylabel='Projection index')
plt.show()

## Reconstruction 

In [None]:
# SART reconstruction in normalized phase space from the interpolated
# projections (transposed: presumably one column per projection; TODO
# confirm against rec.sart), run for 2 iterations. keep_positive presumably
# clips negative densities.
Z_rec_n = rec.sart(
    scaled_projections.T,
    proj_angles,
    keep_positive=True,
    iterations=2,
)

In [None]:
# Reference contour: a circle of radius sqrt(4 * eps) in normalized phase
# space (twice the per-coordinate std sqrt(eps) of the sampled beam), plus
# its image (x, y) in real phase space under V1.
psi = np.linspace(0.0, 2.0 * np.pi, 1000)
xn, yn = np.sqrt(4.0 * eps) * np.array([np.cos(psi), np.sin(psi)])
x, y = V1 @ np.array([xn, yn])

In [None]:
# Side-by-side true and reconstructed densities in normalized phase space,
# each with the reference circle overlaid.
fig, axes = pplt.subplots(ncols=2)
plot_kws = dict(ec='None', cmap='viridis')
for ax, Z, label in zip(axes, (Z_true_n, Z_rec_n), ('True', 'Reconstructed')):
    ax.pcolormesh(rec_grid_centers[0], rec_grid_centers[1], Z.T, **plot_kws)
    ax.set_title(label)
axes.format(
    xlabel=r'$x_n$',
    ylabel=r"$x'_n$",
    suptitle='Normalized phase space',
)
for ax in axes:
    ax.plot(xn, yn, color='white', lw=0.5)
plt.show()

In [None]:
# Bring the normalized-space reconstruction back to real phase space.
# rec.transform is not shown here; presumably it applies V1 to the image and
# returns the resulting grid centers as well.
(Z_rec, rec_grid_centers_new) = rec.transform(
    Z_rec_n,
    V1,
    rec_grid_centers,
)

In [None]:
# Side-by-side true and reconstructed densities in real (x, x') phase space,
# each with the transformed reference ellipse overlaid.
# The true normalized histogram goes through the same rec.transform call as
# the reconstruction, so both images live on rec_grid_centers_new. (Neither
# Z_true nor rec_grid_edges_new existed before, which raised NameError.)
# NOTE(review): assumes rec.transform returns (image, grid_centers) with the
# same grid for identical V1/grid inputs, as in the call producing Z_rec.
Z_true, _ = rec.transform(Z_true_n, V1, rec_grid_centers)
fig, axes = pplt.subplots(ncols=2)
plot_kws = dict(ec='None', cmap='viridis')
axes[0].pcolormesh(rec_grid_centers_new[0], rec_grid_centers_new[1], Z_true.T, **plot_kws)
axes[1].pcolormesh(rec_grid_centers_new[0], rec_grid_centers_new[1], Z_rec.T, **plot_kws)
axes[0].set_title('True')
axes[1].set_title('Reconstructed')
axes.format(xlabel=r'$x$', ylabel=r"$x'$",
            suptitle='Actual phase space')
for ax in axes:
    ax.plot(x, y, color='white', lw=0.5)
plt.show()