# 4D ART

In [None]:
from time import time

from tqdm import trange
from tqdm import tqdm
import importlib
import numpy as np
import matplotlib.pyplot as plt
import proplot as pplt
from skimage import transform

import reconstruct as rec

import sys
sys.path.append('/Users/46h/Research/')
from accphys.tools import plotting as myplt
from accphys.tools import utils

In [None]:
# Notebook-wide proplot defaults, applied in order.
_rc_settings = {
    'axes.grid': False,
    'figure.facecolor': 'white',
    'cmap.discrete': False,
    'cmap.sequential': 'mono_r',
    'savefig.dpi': 'figure',
    'animation.html': 'jshtml',
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

# Keyword arguments for saving figures elsewhere in the notebook.
savefig_kws = dict(dpi=300)

## Setup 

### Create distribution 

In [None]:
# Create a rigid rotating distribution: sample the 4D unit sphere, then
# overwrite the momenta so that x' = -y and y' = +x. Only the x and y columns
# survive from the normalized sample.
seed = 0  # fixed for reproducible notebook runs; set to None for a fresh draw
rng = np.random.default_rng(seed)

n_parts = 400000
X = rng.normal(size=(n_parts, 4))
# Normalize every row at once (vectorized; no per-row Python callback).
X /= np.linalg.norm(X, axis=1, keepdims=True)
X[:, 3] = +X[:, 0]
X[:, 1] = -X[:, 2]

# Change the x-y phase difference: rotate the x-x' plane by pi/4 and leave
# y-y' with utils.rotation_matrix(0.0) (presumably the identity).
R = np.zeros((4, 4))
R[:2, :2] = utils.rotation_matrix(np.pi / 4)
R[2:, 2:] = utils.rotation_matrix(0.0)
# Row-vector form of "apply R to every particle": (R @ x)^T == x^T @ R^T.
X = X @ R.T

# Add some Gaussian noise.
X += rng.normal(scale=0.4, size=X.shape)

# Plot the 2D projections.
n_bins = 50
axes = myplt.corner(X, figsize=(6, 6), bins=n_bins, cmap='mono_r')
plt.show()

# Store the axis limits of each dimension (bottom row of the corner plot) so
# the histograms below use the same ranges.
limits = [ax.get_xlim() for ax in axes[-1, :]]
labels = ["x", "x'", "y", "y'"]

In [None]:
# Ground-truth 4D density on the reconstruction grid. The reconstruction
# reuses these same bins and limits so the two can be compared directly.
Z_true, edges = np.histogramdd(X, bins=n_bins, range=limits, density=True)

# Bin centers are the midpoints of consecutive edges, one array per dimension.
centers = [(e[1:] + e[:-1]) / 2 for e in edges]

bin_volume = rec.get_bin_volume(limits, n_bins)

### Simulate measurements

In [None]:
K = 7  # number of angles in x dimension
L = 7  # number of angles in y dimension

# Each plane gets its own angle grid. (Previously both grids were built from K,
# so L was silently ignored whenever K != L.) endpoint=False: a rotation by pi
# only mirrors the 1D projection of the rotation by 0, assuming
# utils.rotation_matrix is a standard 2D rotation.
muxx = np.linspace(0.0, np.pi, K, endpoint=False)
muyy = np.linspace(0.0, np.pi, L, endpoint=False)

# Screen x coordinate after rotating the x-x' plane by each angle.
xx_list = [utils.apply(utils.rotation_matrix(mux), X[:, :2])[:, 0] for mux in tqdm(muxx)]

# Screen y coordinate after rotating the y-y' plane by each angle.
yy_list = [utils.apply(utils.rotation_matrix(muy), X[:, 2:])[:, 0] for muy in tqdm(muyy)]

# One 2D screen image (raw counts) per (mux, muy) pair, with the x angle in the
# outer loop. The transfer matrices must be built in the same order.
projections = []
for xx in tqdm(xx_list):
    for yy in yy_list:
        projection, _, _ = np.histogram2d(xx, yy, bins=n_bins, range=(limits[0], limits[2]))
        projections.append(projection)

In [None]:
def _block_rotation(angle_x, angle_y):
    """Return the 4x4 block-diagonal matrix rotating x-x' by angle_x and y-y' by angle_y."""
    mat = np.zeros((4, 4))
    mat[:2, :2] = utils.rotation_matrix(angle_x)
    mat[2:, 2:] = utils.rotation_matrix(angle_y)
    return mat


# x angle outer, y angle inner: the same ordering as `projections`.
tmats = [_block_rotation(mux, muy) for mux in muxx for muy in muyy]

In [None]:
# The screen uses the x and y bin edges of the 4D reconstruction grid.
screen_edges_x, screen_edges_y = edges[0], edges[2]

In [None]:
# Build the linear system P @ psi = rho. It links the unknown density psi
# (one value per reconstruction bin) to the measured screen images rho.
from scipy.sparse import csc_matrix, csr_matrix
from scipy.sparse import vstack as sparse_vstack
from scipy.sparse.linalg import lsqr

# Treat each reconstruction bin center as a particle. We will call this
# collection of particles the "bunch". NOTE(review): assumes get_grid_coords
# returns an (n_bins**4, 4) array in C order, matching the later
# psi.reshape((n_bins,) * 4) -- confirm in `reconstruct`.
rec_grid_coords = rec.get_grid_coords(*centers)

# Keep this for later. Column j of P <-> row j of rec_grid_coords.
col_indices = np.arange(n_bins**4)

P_list = []
rho_list = []

n_proj = len(tmats)
for proj_index in trange(n_proj):

    # Transport the bunch to the screen in one matmul. Row-vector form of
    # M @ row for every row.
    M = tmats[proj_index]
    screen_grid_coords = np.matmul(rec_grid_coords, M.T)

    # For each particle, record the indices (k, l) of the screen bin it
    # landed in, i.e. x in bin k and y in bin l. np.digitize yields -1 or
    # n_bins for particles that landed outside the screen.
    xidx = np.digitize(screen_grid_coords[:, 0], screen_edges_x) - 1
    yidx = np.digitize(screen_grid_coords[:, 2], screen_edges_y) - 1
    on_screen = (xidx >= 0) & (xidx < n_bins) & (yidx >= 0) & (yidx < n_bins)

    # Flattened (C-order, matching projection.flatten()) screen-bin index of
    # every on-screen particle.
    projection = projections[proj_index]
    screen_idx = np.ravel_multi_index((xidx[on_screen], yidx[on_screen]), projection.shape)

    # P[i, j] = 1 if particle j landed in bin i on the screen, 0 otherwise.
    # Each column has at most one nonzero, so build it sparse. A dense
    # (n_bins**2, n_bins**4) array is ~125 GB for n_bins = 50.
    P_proj = csr_matrix(
        (np.ones(screen_idx.size), (screen_idx, col_indices[on_screen])),
        shape=(n_bins**2, n_bins**4),
    )

    P_list.append(P_proj)
    rho_list.append(projection.flatten())

# Stack all projections into one system; CSC format for the LSQR solve.
P = sparse_vstack(P_list, format='csc')
rho = np.hstack(rho_list)

In [None]:
# Solve min ||P @ psi - rho|| with LSQR (show=True prints the iteration log).
# The solution is unconstrained, so psi may contain negative entries. The
# remaining outputs are LSQR diagnostics (stop reason, iteration count,
# residual and matrix norm estimates, ...).
lsqr_result = lsqr(P, rho, show=True)
(psi, istop, itn, r1norm, r2norm,
 anorm, acond, arnorm, xnorm, var) = lsqr_result

In [None]:
# Fold the flat solution vector back onto the 4D grid (n_bins per axis).
grid_shape = (n_bins,) * 4
Z = psi.reshape(grid_shape)

In [None]:
import os

cmap = 'mono_r'
# Pairs of dimensions to compare; one figure row per pair.
indices = [(0, 1), (2, 3), (0, 2), (0, 3), (2, 1), (1, 3)]
# One title per figure column. The old list also held 'Error', but no such
# column was drawn and zip() silently dropped it.
titles = ['Reconstructed', 'True']
fig, axes = pplt.subplots(nrows=len(indices), ncols=len(titles), sharex=False, sharey=False)
for row, (i, j) in enumerate(indices):
    _Z_true = rec.project(Z_true, [i, j])
    _Z = rec.project(Z, [i, j])
    # pcolormesh(x, y, C) expects C indexed as (y, x), hence the transposes.
    axes[row, 0].pcolormesh(centers[i], centers[j], _Z.T, cmap=cmap)
    axes[row, 1].pcolormesh(centers[i], centers[j], _Z_true.T, cmap=cmap)
    axes[row, 0].annotate('{}-{}'.format(labels[i], labels[j]),
                          xy=(0.02, 0.92), xycoords='axes fraction', color='white')
for ax, title in zip(axes[0, :], titles):
    ax.set_title(title)
axes.format(xticks=[], yticks=[]);

# Save this specific figure with the notebook-wide settings; create the
# output directory if it does not exist yet.
os.makedirs('_output', exist_ok=True)
fig.savefig('_output/rec.png', **savefig_kws)