# 4D emittance measurement data analysis 

In [None]:
import sys
import os
import copy
import importlib
from tqdm import trange, tqdm
import gif

import numpy as np
import pandas as pd
from scipy import optimize as opt
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib import animation
import proplot as pplt
import seaborn as sns

sys.path.append('/Users/46h/Research/')
from accphys.tools import beam_analysis as ba
from accphys.tools import coupling as BL
from accphys.tools import plotting as myplt
from accphys.tools import utils
from accphys.tools.accphys_utils import V_matrix_4x4_uncoupled, phase_adv_matrix

from accphys.emittance_measurement_4D.analysis import reconstruct
from accphys.emittance_measurement_4D.analysis import to_mat, to_vec
from accphys.emittance_measurement_4D.plotting import reconstruction_lines

In [None]:
# Global matplotlib defaults for every figure in this notebook.
plt.rcParams.update({
    'figure.facecolor': 'white',
    'grid.alpha': 0.04,
    'axes.grid': False,
    'savefig.dpi': 'figure',
    'animation.html': 'jshtml',  # render animations inline as JS/HTML
})

In [None]:
eps_labels = [r'$\varepsilon_{}$'.format(sub) for sub in ('x', 'y', '1', '2')]  # apparent (x, y) then intrinsic (1, 2)

## Load data 

In [None]:
# Data folder to analyse. Alternative folders:
#   './_saved/2021-08-01/production/'
#   './_saved/2021-05-31/'
#   './_saved/2021-08-10/scan/'
folder = './_saved/2021-08-01/production_scan_phases/'

# Keep at most this many measurements per wire-scanner (passed to the loaders).
max_n_meas = 100

# Wire-scanner ids to drop (passed to the loaders' `exclude`); None presumably
# keeps every wire-scanner — depends on `utils.blacklist`.
exclude = None

In [None]:
def load_tmats_dict(filename, exclude=None, max_n_meas=None):
    """Load dictionary of transfer matrix elements at each wire-scanner.

    Each line in the file reads [node_id, M11, M12, M13, M14, M21, M22,
    M23, M24, M31, M32, M33, M34, M41, M42, M43, M44]. Blank lines are
    skipped.

    Parameters
    ----------
    filename : str
        Path to the transfer-matrix file.
    exclude : optional
        Forwarded to `utils.blacklist`; presumably a collection of node ids
        to drop — TODO confirm against `utils.blacklist`.
    max_n_meas : int, optional
        If truthy, keep only the first `max_n_meas` matrices per node.

    Returns
    -------
    dict
        node_id -> list of (4, 4) arrays, in file order.
    """
    tmats_dict = {}
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:  # e.g. a trailing newline at end of file
                continue
            ws_id = tokens[0]
            tmat_elems = [float(token) for token in tokens[1:]]
            tmats_dict.setdefault(ws_id, []).append(np.reshape(tmat_elems, (4, 4)))
    tmats_dict = utils.blacklist(tmats_dict, exclude)
    if max_n_meas:
        for ws_id in tmats_dict:
            tmats_dict[ws_id] = tmats_dict[ws_id][:max_n_meas]
    return tmats_dict

In [None]:
def load_moments_dict(filename, exclude=None, max_n_meas=None):
    """Load dictionary of moments at each wire-scanner.

    Each line in the file reads [node_id, <xx>, <yy>, <xy>]. Blank lines are
    skipped.

    Parameters
    ----------
    filename : str
        Path to the moments file.
    exclude : optional
        Forwarded to `utils.blacklist`; presumably a collection of node ids
        to drop — TODO confirm against `utils.blacklist`.
    max_n_meas : int, optional
        If truthy, keep only the first `max_n_meas` rows per node.

    Returns
    -------
    dict
        node_id -> ndarray with one row of moments per file line, in file order.
    """
    moments_dict = {}
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:  # e.g. a trailing newline at end of file
                continue
            ws_id = tokens[0]
            moments = [float(token) for token in tokens[1:]]
            moments_dict.setdefault(ws_id, []).append(moments)
    moments_dict = utils.blacklist(moments_dict, exclude)
    for ws_id in moments_dict:
        if max_n_meas:
            moments_dict[ws_id] = moments_dict[ws_id][:max_n_meas]
        moments_dict[ws_id] = np.array(moments_dict[ws_id])
    return moments_dict

In [None]:
def load_phases_dict(filename, exclude=None, max_n_meas=None):
    """Load dictionary of phases at each wire-scanner.

    Each line in the file reads [node_id, mux, muy]. Blank lines are skipped.

    Parameters
    ----------
    filename : str
        Path to the phase-advance file.
    exclude : optional
        Forwarded to `utils.blacklist`; presumably a collection of node ids
        to drop — TODO confirm against `utils.blacklist`.
    max_n_meas : int, optional
        If truthy, keep only the first `max_n_meas` rows per node.

    Returns
    -------
    dict
        node_id -> ndarray of shape (n, 2) holding [mux, muy] per file line.
    """
    phases_dict = {}
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:  # e.g. a trailing newline at end of file
                continue
            ws_id, mux, muy = tokens
            phases_dict.setdefault(ws_id, []).append([float(mux), float(muy)])
    phases_dict = utils.blacklist(phases_dict, exclude)
    for ws_id in phases_dict:
        if max_n_meas:
            phases_dict[ws_id] = phases_dict[ws_id][:max_n_meas]
        phases_dict[ws_id] = np.array(phases_dict[ws_id])
    return phases_dict

In [None]:
# Load every data file from `folder` with the same exclusion / truncation.
tmats_dict = load_tmats_dict(os.path.join(folder, 'transfer_mats.dat'),
                             exclude=exclude, max_n_meas=max_n_meas)
moments_dict = load_moments_dict(os.path.join(folder, 'moments.dat'),
                                 exclude=exclude, max_n_meas=max_n_meas)
phases_dict = load_phases_dict(os.path.join(folder, 'phase_adv.dat'),
                               exclude=exclude, max_n_meas=max_n_meas)
ws_ids = sorted(tmats_dict)                # wire-scanner ids in sorted order
n_meas = len(tmats_dict[ws_ids[0]])        # measurements per wire-scanner

In [None]:
# Left column: raw moments. Right column: rms sizes and x-y correlation coefficient.
fig, axes = pplt.subplots(ncols=2, nrows=3, figsize=(7, 4.5), spany=False, sharey=False, wspace=1, aligny=True)
plt_kws = dict(marker='.')  # also reused by the phase-advance plot
for ws_id in ws_ids:
    moments = moments_dict[ws_id]
    axes[0, 0].plot(moments[:, 0], **plt_kws)
    axes[1, 0].plot(moments[:, 1], **plt_kws)
    axes[2, 0].plot(moments[:, 2], **plt_kws)
    axes[0, 1].plot(np.sqrt(moments[:, 0]), **plt_kws)
    axes[1, 1].plot(np.sqrt(moments[:, 1]), **plt_kws)
    axes[2, 1].plot(moments[:, 2] / np.sqrt(moments[:, 0] * moments[:, 1]), **plt_kws)
axes[0, 0].format(ylabel=r'$\langle{x^2}\rangle$ [mm$^2$]')
axes[1, 0].format(ylabel=r'$\langle{y^2}\rangle$ [mm$^2$]')
axes[2, 0].format(ylabel=r'$\langle{xy}\rangle$ [mm$^2$]')
axes[0, 1].format(ylabel=r'$\sqrt{\langle{x^2}\rangle}$ [mm]')
axes[1, 1].format(ylabel=r'$\sqrt{\langle{y^2}\rangle}$ [mm]')
axes[2, 1].format(ylabel=r'x-y corr. coeff.')
for ax in axes:
    ax.grid(axis='y', alpha=0.1)
# Use n_meas rather than the length of whichever `moments` the loop left behind.
axes.format(xlabel='Measurement index', xtickminor=False, xticks=range(n_meas))
axes[0, 1].legend(labels=ws_ids, ncols=1, loc=(1.02, 0), fontsize='small')
axes[0, 0].set_title('Measured moments', fontsize='medium')
axes[0, 1].set_title('Measured rms size and x-y corr.', fontsize='medium')
plt.savefig('_output/moments.png', facecolor='white', dpi=300)

In [None]:
# Horizontal (left) and vertical (right) phase advance at each wire-scanner.
fig, axes = pplt.subplots(ncols=2)
for ws_id in ws_ids:
    phases = phases_dict[ws_id]
    for ax, column in zip(axes, (phases[:, 0], phases[:, 1])):
        ax.plot(column, **plt_kws)
axes.format(
    xlabel='Measurement index', ylabel='[rad]', ylim=(0, 2 * np.pi),
    xtickminor=False, xticks=range(n_meas),
)
for ax, title in zip(axes, ('Horizontal phase adv.', 'Vertical phase adv.')):
    ax.set_title(title)
axes[1].legend(labels=ws_ids, ncols=1, loc=(1.02, 0), fontsize='small')
for ax in axes:
    ax.grid(axis='y')
plt.savefig('_output/phases.png', facecolor='white', dpi=300)

In [None]:
# Phase-advance gap between consecutive wire-scanners, and |gap_x - gap_y|.
fig, axes = pplt.subplots(ncols=3)
colors = ['red8', 'blue8', 'green8']
ws_pairs = list(zip(ws_ids[:-1], ws_ids[1:]))
for i, (ws_a, ws_b) in enumerate(ws_pairs):
    diffs = phases_dict[ws_b] - phases_dict[ws_a]
    # Negative gaps are shifted up by one full turn.
    diffs = np.where(diffs < 0, diffs + 2 * np.pi, diffs)
    line_style = dict(color=colors[i], marker='.')
    axes[0].plot(np.degrees(diffs[:, 0]), **line_style)
    axes[1].plot(np.degrees(diffs[:, 1]), **line_style)
    axes[2].plot(np.degrees(np.abs(diffs[:, 0] - diffs[:, 1])), **line_style)
for ax in axes:
    ax.grid(axis='y')
labels = [r'{} $\rightarrow$ {}'.format(ws_a, ws_b) for ws_a, ws_b in ws_pairs]
axes[2].legend(labels=labels, ncol=1, loc=(1.02, 0))
axes[0].set_title(r'Horizontal: $\Delta\mu_x$')
axes[1].set_title(r'Vertical: $\Delta\mu_y$')
axes[2].set_title(r'Difference: $\left| \Delta\mu_x - \Delta\mu_y \right|$')
axes.format(ylabel='Gap', yformatter='deg', suptitle='Wire-scanner phase spacing',
            xlabel='Measurement index', xtickminor=False, xticks=range(n_meas))

In [None]:
# One panel per wire-scanner showing the measured rms ellipse in x-y for every
# measurement. The panel count follows len(ws_ids), so `exclude` (or extra
# wire-scanners) neither leaves empty panels nor silently drops wire-scanners.
n_ws = len(ws_ids)
fig, axes = pplt.subplots(ncols=n_ws, figsize=(2 * n_ws, 2))
for ax, ws_id in zip(axes, ws_ids):
    ax.set_title(ws_id)
    for sig_xx, sig_yy, sig_xy in moments_dict[ws_id]:
        # Tilt angle and semi-axes of the ellipse defined by the 2x2 x-y moments.
        angle = -0.5 * np.arctan2(2*sig_xy, sig_xx-sig_yy)
        sn, cs = np.sin(angle), np.cos(angle)
        c1 = np.sqrt(abs(sig_xx*cs**2 + sig_yy*sn**2 - 2*sig_xy*sn*cs))
        c2 = np.sqrt(abs(sig_xx*sn**2 + sig_yy*cs**2 + 2*sig_xy*sn*cs))
        myplt.ellipse(ax, c1, c2, angle)
axes.format(xlim=(-30, 30), ylim=(-30, 30), xlabel='x [mm]', ylabel='y [mm]',
            suptitle='Measured rms ellipses in x-y plane')
plt.savefig('_output/corr.png', facecolor='white', dpi=300)

## Reconstruction

In [None]:
# Pool every measurement from every wire-scanner into two parallel flat lists
# (same ordering: wire-scanner by wire-scanner, measurement by measurement).
tmats_list = [tmat for ws_id in ws_ids for tmat in tmats_dict[ws_id]]
moments_list = [row for ws_id in ws_ids for row in moments_dict[ws_id]]

In [None]:
# Reconstruct the covariance matrix at the reconstruction point from every
# (transfer matrix, measured moments) pair pooled above, i.e. all measurements
# at all wire-scanners. `verbose=2` is forwarded to `reconstruct`.
Sigma = reconstruct(tmats_list, moments_list, verbose=2)

In [None]:
Corr = utils.cov2corr(Sigma)
alpha_x, alpha_y, beta_x, beta_y = ba.twiss2D(Sigma)
eps_x, eps_y, eps_1, eps_2 = ba.emittances(Sigma)
# 1 - sqrt(eps_1 eps_2 / (eps_x eps_y)); zero when the products are equal.
coupling_coeff = 1.0 - np.sqrt((eps_1 * eps_2) / (eps_x * eps_y))

for header, matrix in (('Sigma =', Sigma), ('Corr =', Corr)):
    print(header)
    print(matrix)
print(f'eps_4D = {np.sqrt(np.linalg.det(Sigma)):.3f}')
print(f'eps_1, eps_2 = {eps_1:.3f}, {eps_2:.3f}')
print(f'eps_x, eps_y = {eps_x:.3f}, {eps_y:.3f}')
print(f'alpha_x, alpha_y = {alpha_x:.3f}, {alpha_y:.3f}')
print(f'beta_x, beta_y = {beta_x:.3f}, {beta_y:.3f}')
print(f'Coupling coefficient = {coupling_coeff}')

In [None]:
# Choose the normalization used for the "normalized" corner plot:
#   '2D' -> V from the uncoupled (x-x', y-y') Twiss parameters of Sigma;
#   '4D' -> V from BL.construct_V applied to the eigenvectors of Sigma U,
#           where U is the 4x4 symplectic unit matrix.
# Any other value leaves V as the identity.
norm = '2D'
V = np.identity(4)
if norm == '2D':
    alpha_x, alpha_y, beta_x, beta_y = ba.twiss2D(Sigma)
    V = V_matrix_4x4_uncoupled(alpha_x, alpha_y, beta_x, beta_y)
elif norm == '4D':
    U = np.array([[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]])
    eigvals, eigvecs = np.linalg.eig(np.matmul(Sigma, U))  # eigvals is unused
    V = BL.construct_V(eigvecs)
Vinv = np.linalg.inv(V)
# Normalized covariance matrix: Sigma_n = V^-1 Sigma V^-T.
Sigma_n = np.linalg.multi_dot([Vinv, Sigma, Vinv.T])

In [None]:
# Corner plot of the reconstructed rms ellipses, overlaid with the per-measurement
# reconstruction lines in the x-x' and y-y' planes.
ellipse_kws = dict(color='black', pad=0.75, alpha=0.15, fill=True, lw=0, zorder=0)
axes = myplt.rms_ellipses(Sigma, **ellipse_kws)
plt.suptitle('Reconstructed phase space')
reconstruction_lines(axes[2, 2], tmats_dict, moments_dict, plane='y-yp')
reconstruction_lines(axes[0, 0], tmats_dict, moments_dict, plane='x-xp',
                     legend=True, legend_kws={'loc': (1.15, 0)})
plt.savefig('_output/corner.png', dpi=300, facecolor='white')

In [None]:
# Same corner plot in normalized coordinates (normalization matrix V).
ellipse_kws = dict(color='black', pad=0.75, alpha=0.15, fill=True, lw=0, zorder=0)
axes = myplt.rms_ellipses(Sigma_n, **ellipse_kws)
plt.suptitle(f'Reconstructed phase space (normalized {norm})')
reconstruction_lines(axes[2, 2], tmats_dict, moments_dict, plane='y-yp', norm_mat=V)
reconstruction_lines(axes[0, 0], tmats_dict, moments_dict, plane='x-xp', norm_mat=V,
                     legend=True, legend_kws={'loc': (1.15, 0)})
plt.savefig('_output/corner_norm.png', dpi=300, facecolor='white')

## Errors

In [None]:
# Snapshot of the all-measurement reconstruction; the error studies below
# compare against these reference ("correct") values.
Sigma_all_meas = np.copy(Sigma)
emittances_all_meas = ba.emittances(Sigma_all_meas)
(eps_x_all_meas, eps_y_all_meas,
 eps_1_all_meas, eps_2_all_meas) = emittances_all_meas

In [None]:
def plot_with_error_bars(data, labels, correct=None):
    """Plot per-measurement values with error bars equal to each column's std.

    Parameters
    ----------
    data : array-like, shape (n_rows,) or (n_rows, n_cols)
        One row per measurement, one column per quantity. A 1D input is
        treated as a single column.
    labels : sequence of str
        Legend/title label for each column.
    correct : array-like of length n_cols, optional
        Reference value for each column, drawn as a faint horizontal line.
        May be a tuple or a numpy array.

    Returns
    -------
    ax
        The proplot axes.
    """
    data = np.asarray(data)
    if data.ndim == 1:
        data = data[:, np.newaxis]
    means, stds = np.mean(data, axis=0), np.std(data, axis=0)
    n_rows, n_cols = data.shape

    plt_kws = dict(marker='.', lw=0)
    if n_cols == 1:
        colors = ['black']
    else:
        colors = copy.copy(myplt.DEFAULT_COLORCYCLE)
        colors[2], colors[3] = colors[3], colors[2]

    fig, ax = pplt.subplots(figsize=None)
    for i in range(n_cols):
        ax.errorbar(range(n_rows), data[:, i], yerr=stds[i], capsize=3, elinewidth=1, color=colors[i], **plt_kws)
        # `is not None`: truth-testing a numpy array of references would raise.
        if correct is not None:
            ax.axhline(correct[i], color=colors[i], alpha=0.2)
    # Ticks follow this function's own data, not a global variable.
    ax.format(xlabel='Measurement index', ylabel='[mm mrad]', xtickminor=False, xticks=range(n_rows))

    ax.legend([Line2D([0], [0], c=c, **plt_kws) for c in colors[:n_cols]], labels, ncol=1, loc='r')
    title_lines = [
        r'{} = {:.3f} $\pm$ {:.3f} [mm mrad]'.format(labels[i], means[i], stds[i])
        for i in range(n_cols)
    ]
    ax.set_title('\n'.join(title_lines))
    return ax

In [None]:
def form_coeff_target_arrays(meas_index):
    """Build the linear systems for one measurement across all wire-scanners.

    Reads the module-level `ws_ids`, `tmats_dict` and `moments_dict`. Row k of
    every array corresponds to ws_ids[k], with M the transfer matrix used for
    that measurement:

        Axx @ [s11, s22, s12]       = bxx   (measured <xx>)
        Ayy @ [s33, s44, s34]       = byy   (measured <yy>)
        Axy @ [s13, s23, s14, s24]  = bxy   (measured <xy>)

    Returns
    -------
    Axx, Ayy, Axy, bxx, byy, bxy : ndarray
    """
    rows_xx, rows_yy, rows_xy = [], [], []
    tgt_xx, tgt_yy, tgt_xy = [], [], []
    for ws_id in ws_ids:
        M = tmats_dict[ws_id][meas_index]
        m11, m12 = M[0, 0], M[0, 1]
        m33, m34 = M[2, 2], M[2, 3]
        rows_xx.append([m11**2, m12**2, 2*m11*m12])
        rows_yy.append([m33**2, m34**2, 2*m33*m34])
        rows_xy.append([m11*m33, m12*m33, m11*m34, m12*m34])
        xx, yy, xy = moments_dict[ws_id][meas_index]
        tgt_xx.append(xx)
        tgt_yy.append(yy)
        tgt_xy.append(xy)
    return (np.array(rows_xx), np.array(rows_yy), np.array(rows_xy),
            np.array(tgt_xx), np.array(tgt_yy), np.array(tgt_xy))

### Bounds on cross-plane moments

In [None]:
def solve_corr_coeff_bounds(meas_index, **solver_kws):
    """Reconstruct emittances from one measurement with bounded cross-plane moments.

    The in-plane moments (s11, s22, s12) and (s33, s44, s34) are fit by linear
    least squares. The cross-plane moments (s13, s23, s14, s24) are then fit
    to the measured <xy> with each |s_ij| bounded by sqrt(s_ii * s_jj), i.e.
    the x-y correlation coefficients are restricted to [-1, 1].

    NOTE(review): if a fitted diagonal moment comes out negative the bounds
    become NaN and the solver will fail.

    Parameters
    ----------
    meas_index : int
    **solver_kws
        Forwarded to `scipy.optimize.least_squares`.

    Returns
    -------
    Whatever `ba.emittances` returns for the assembled covariance matrix.
    """
    Axx, Ayy, Axy, bxx, byy, bxy = form_coeff_target_arrays(meas_index)
    sig_11, sig_22, sig_12 = opt.lsq_linear(Axx, bxx).x
    sig_33, sig_44, sig_34 = opt.lsq_linear(Ayy, byy).x

    def residuals(vec):
        # `least_squares` squares and sums the residuals itself; handing it a
        # scalar sum of squares would make it minimize (sum of squares)**2,
        # whose gradient vanishes at the optimum and stalls convergence.
        return np.matmul(Axy, vec) - bxy

    r_13_denom = np.sqrt(sig_11 * sig_33)
    r_23_denom = np.sqrt(sig_22 * sig_33)
    r_14_denom = np.sqrt(sig_11 * sig_44)
    r_24_denom = np.sqrt(sig_22 * sig_44)
    lb = [-r_13_denom, -r_23_denom, -r_14_denom, -r_24_denom]
    ub = [+r_13_denom, +r_23_denom, +r_14_denom, +r_24_denom]
    guess = [0, lb[1], ub[2], 0]
    result = opt.least_squares(residuals, guess, bounds=(lb, ub), **solver_kws)

    sig_13, sig_23, sig_14, sig_24 = result.x
    S = to_mat([sig_11, sig_22, sig_12, sig_33, sig_44, sig_34, sig_13, sig_23, sig_14, sig_24])
    return ba.emittances(S)

In [None]:
# Bounded cross-plane fit, one measurement index at a time.
emittances = list(map(solve_corr_coeff_bounds, trange(n_meas)))
plot_with_error_bars(emittances, eps_labels, correct=emittances_all_meas);

The method seems to "work" only when the initial guess is close to the correct answer.

### Edwards-Teng

We can parameterize the covariance matrix as 

$$
\mathbf{\Sigma} = \mathbf{V} \, \mathbf{C} \, \mathbf{\Sigma}_n \, \mathbf{C}^T \, \mathbf{V}^T,
$$

with

$$
\mathbf{\Sigma}_n = 
\begin{bmatrix}
    \varepsilon_1 & 0 & 0 & 0 \\
    0 & \varepsilon_1 & 0 & 0 \\
    0 & 0 & \varepsilon_2 & 0 \\
    0 & 0 & 0 & \varepsilon_2
\end{bmatrix},
$$

$$
\mathbf{C} = 
\begin{bmatrix}
    1 & 0 & a & b \\
    0 & 1 & c & d \\
    -d & b & 1 & 0 \\
    c & -a & 0 & 1
\end{bmatrix},
$$

$$
\mathbf{V} = 
\begin{bmatrix}
    \sqrt{\beta_x} & 0 & 0 & 0 \\
    -\frac{\alpha_x}{\sqrt{\beta_x}} & \frac{1}{\sqrt{\beta_x}} & 0 & 0 \\
    0 & 0 & \sqrt{\beta_y} & 0 \\
    0 & 0 & -\frac{\alpha_y}{\sqrt{\beta_y}} & \frac{1}{\sqrt{\beta_y}}
\end{bmatrix}.
$$

$\mathbf{C}$ is symplectic if $ad - bc = 0$, so we set $d = bc/a$. If $a = b = c = d = 0$, then we get the normal uncoupled matrix with $\varepsilon_{x,y} = \varepsilon_{1,2}$. 

The strategy is to fit the $x$-$x'$ and $y$-$y'$ planes first, then choose $\varepsilon_1$, $\varepsilon_2$, $a$, $b$, and $c$.

In [None]:
def get_cov(eps_1, eps_2, alpha_x, alpha_y, beta_x, beta_y, a, b, c):
    """Build the 4x4 covariance matrix from Edwards-Teng style parameters.

    Returns V C E C^T V^T, where

    * E = diag(eps_1, eps_1, eps_2, eps_2),
    * V = V_matrix_4x4_uncoupled(alpha_x, alpha_y, beta_x, beta_y),
    * C = [[1, 0, a, b], [0, 1, c, d], [-d, b, 1, 0], [c, -a, 0, 1]],

    and d is chosen so that a*d - b*c = 0 (d = b*c/a, or 0 when a == 0 and
    b*c == 0).

    Raises
    ------
    ValueError
        If a == 0 while b * c != 0 (no d satisfies a*d = b*c).
    """
    eps_mat = np.diag([eps_1, eps_1, eps_2, eps_2])
    norm_mat = V_matrix_4x4_uncoupled(alpha_x, alpha_y, beta_x, beta_y)
    if a != 0:
        d = b * c / a
    elif b == 0 or c == 0:
        d = 0
    else:
        raise ValueError("a is zero but b * c is not zero.")
    coupling = np.array([
        [1, 0, a, b],
        [0, 1, c, d],
        [-d, b, 1, 0],
        [c, -a, 0, 1],
    ])
    return np.linalg.multi_dot([norm_mat, coupling, eps_mat, coupling.T, norm_mat.T])

In [None]:
def solve_edwards_teng(meas_index, **solver_kws):
    """Reconstruct emittances from one measurement via the Edwards-Teng form.

    1. Fit the in-plane 2x2 blocks by linear least squares. These fix the
       Twiss parameters alpha/beta and the apparent emittances.
    2. Find (eps_1, eps_2, a, b, c) with `scipy.optimize.minimize` by
       minimizing the squared misfit to the measured <xy> moments plus the
       squared deviation of the model's in-plane moments from step 1.

    The initial guess is random (NumPy global RNG), so results vary between
    calls unless the RNG is seeded.

    Returns
    -------
    `ba.emittances` of the model covariance matrix, with its in-plane
    blocks overwritten by the values fitted in step 1.
    """
    Axx, Ayy, Axy, bxx, byy, bxy = form_coeff_target_arrays(meas_index)
    sig_11, sig_22, sig_12 = opt.lsq_linear(Axx, bxx).x
    sig_33, sig_44, sig_34 = opt.lsq_linear(Ayy, byy).x
    block_xx = [[sig_11, sig_12], [sig_12, sig_22]]
    block_yy = [[sig_33, sig_34], [sig_34, sig_44]]

    S_uncoupled = np.zeros((4, 4))
    S_uncoupled[:2, :2] = block_xx
    S_uncoupled[2:, 2:] = block_yy
    eps_x, eps_y = ba.apparent_emittances(S_uncoupled)
    alpha_x, alpha_y, beta_x, beta_y = ba.twiss2D(S_uncoupled)
    twiss = (alpha_x, alpha_y, beta_x, beta_y)

    weight = 1.0
    # (row, col, fitted value) for every in-plane moment, in penalty order.
    in_plane_targets = (
        (0, 0, sig_11), (0, 1, sig_12), (1, 1, sig_22),
        (2, 2, sig_33), (2, 3, sig_34), (3, 3, sig_44),
    )

    def cost_func(params):
        e1, e2, p_a, p_b, p_c = params
        model = get_cov(e1, e2, *twiss, p_a, p_b, p_c)
        cross = np.array([model[0, 2], model[1, 2], model[0, 3], model[1, 3]])
        total = np.sum((np.matmul(Axy, cross) - bxy)**2)
        for row, col, target in in_plane_targets:
            total += weight * (model[row, col] - target)**2
        return total

    lower = [0., 0., -np.inf, -np.inf, -np.inf]
    guess = [np.random.random() * eps_x, np.random.random() * eps_y,
             np.random.uniform(-10, 10), np.random.uniform(-10, 10), np.random.uniform(-10, 10)]
    result = opt.minimize(cost_func, guess, bounds=opt.Bounds(lower, np.inf), **solver_kws)
    e1, e2, p_a, p_b, p_c = result.x
    S = get_cov(e1, e2, *twiss, p_a, p_b, p_c)
    S[:2, :2] = block_xx
    S[2:, 2:] = block_yy
    return ba.emittances(S)

In [None]:
solver_kws = {'method': 'trust-constr',
              'options': {'verbose': 1, 'xtol': 1e-15, 'maxiter': 1000}}
emittances = [solve_edwards_teng(i, **solver_kws) for i in trange(n_meas)]
plot_with_error_bars(emittances, eps_labels, correct=emittances_all_meas);

Again, the result depends on how good the initial guess is.

### Cholesky decomposition 

In [None]:
def inverse_cholesky(x):
    """Return L @ L.T for the 4x4 lower-triangular L whose entries are x.

    x[0:10] fills the lower triangle row by row:
    L = [[x0, 0, 0, 0], [x1, x2, 0, 0], [x3, x4, x5, 0], [x6, x7, x8, x9]].
    """
    entries = np.asarray(x)[:10]
    lower = np.zeros((4, 4), dtype=entries.dtype)
    lower[np.tril_indices(4)] = entries  # row-major order of the lower triangle
    return np.matmul(lower, lower.T)

In [None]:
def solve_cholesky(meas_index, **solver_kws):
    """Reconstruct emittances from one measurement via a Cholesky parameterization.

    The covariance matrix is written as S = L L^T with L lower-triangular (see
    `inverse_cholesky`), so S is positive semi-definite for any parameter
    vector. The 10 entries of L minimize, with `scipy.optimize.minimize`, the
    summed squared misfit of the <xx>, <yy> and <xy> systems at every
    wire-scanner. The initial guess is drawn uniformly from [-10, 10) with the
    NumPy global RNG, so results vary between calls unless it is seeded.

    Parameters
    ----------
    meas_index : int
    **solver_kws
        Forwarded to `scipy.optimize.minimize` (e.g. `method`, `options`).

    Returns
    -------
    Whatever `ba.emittances` returns for the fitted matrix.
    """
    # All ten moments are fit jointly below, so no separate in-plane
    # `lsq_linear` fits are needed (their results were previously unused).
    Axx, Ayy, Axy, bxx, byy, bxy = form_coeff_target_arrays(meas_index)

    def cost_func(x):
        S = inverse_cholesky(x)
        vec_xx = np.array([S[0, 0], S[1, 1], S[0, 1]])
        vec_yy = np.array([S[2, 2], S[3, 3], S[2, 3]])
        vec_xy = np.array([S[0, 2], S[1, 2], S[0, 3], S[1, 3]])
        cost = np.sum((np.matmul(Axx, vec_xx) - bxx)**2)
        cost += np.sum((np.matmul(Ayy, vec_yy) - byy)**2)
        cost += np.sum((np.matmul(Axy, vec_xy) - bxy)**2)
        return cost

    guess = np.random.uniform(-10, 10, size=10)
    # Infinite bounds impose no constraint; kept so existing solver settings behave as before.
    result = opt.minimize(cost_func, guess, bounds=opt.Bounds(-np.inf, np.inf), **solver_kws)
    S = inverse_cholesky(result.x)
    return ba.emittances(S)

In [None]:
solver_kws = {'method': 'trust-constr', 'options': {'verbose': 1, 'maxiter': 2000}}
emittances = [solve_cholesky(i, **solver_kws) for i in trange(n_meas)]
plot_with_error_bars(emittances, eps_labels, correct=emittances_all_meas);

### Simple simulation

Let's simulate the measurement very simply. We can start with the "correct" $\Sigma$ at the reconstruction point (the one we reconstructed using all the measurements). We'll transport this to each wire-scanner using the transfer matrices. We can add some error to the measured moments and try the reconstruction method. 

In [None]:
def scatter_hist(x, y, ax, ax_marg_x, ax_marg_y, joint_kws=None, marginal_kws=None):
    """Draw a 2D histogram of (x, y) plus 1D marginal histograms.

    Parameters
    ----------
    x, y : array-like
        Samples.
    ax : axes
        Receives `hist2d(x, y, **joint_kws)`.
    ax_marg_x, ax_marg_y : axes
        Receive the x histogram and the horizontal y histogram.
    joint_kws : dict, optional
        Keyword arguments for `hist2d`. If it has 'range' = (xrange, yrange),
        those ranges are also used for the marginals. If it lacks 'bins', the
        bin count is chosen with numpy's 'auto' rule on x and used everywhere.
        The caller's dict is never modified.
    marginal_kws : dict, optional
        Keyword arguments for both marginal `hist` calls.

    Returns
    -------
    ax
    """
    # Work on copies so a dict shared between calls is never mutated.
    joint_kws = dict(joint_kws) if joint_kws is not None else {}
    marginal_kws = dict(marginal_kws) if marginal_kws is not None else {}
    if 'range' in joint_kws:
        xrange, yrange = joint_kws['range']
    else:
        xrange = yrange = None
    if 'bins' not in joint_kws:
        # Same bin count matplotlib's hist(bins='auto') would pick, computed
        # without drawing (and hiding) a throwaway histogram.
        edges = np.histogram_bin_edges(x, bins='auto', range=xrange)
        joint_kws['bins'] = len(edges) - 1
    bins = joint_kws['bins']
    ax_marg_x.hist(x, range=xrange, bins=bins, **marginal_kws)
    ax_marg_y.hist(y, range=yrange, bins=bins, orientation='horizontal', **marginal_kws)
    ax.hist2d(x, y, **joint_kws)
    return ax

def create_grid(fig, gridspec, row, col):
    """Add a joint axes at gridspec[row, col] with marginal axes beside it.

    The x-marginal sits one row above, the y-marginal one column to the right.
    Both marginals have no ticks and no visible spines.

    Returns
    -------
    ax_joint, ax_marg_x, ax_marg_y
    """
    ax_joint = fig.add_subplot(gridspec[row, col])
    marginals = (
        fig.add_subplot(gridspec[row - 1, col]),
        fig.add_subplot(gridspec[row, col + 1]),
    )
    for marg in marginals:
        marg.set_xticks([])
        marg.set_yticks([])
        for side in ('top', 'bottom', 'left', 'right'):
            marg.spines[side].set_visible(False)
    return (ax_joint, *marginals)

def emittances_joint_hist(emittances, lims=((10, 40), (0, 20))):
    """Plot joint + marginal histograms of apparent and intrinsic emittances.

    Parameters
    ----------
    emittances : ndarray, shape (n, 4)
        Columns are read as (eps_x, eps_y, eps_1, eps_2): columns 0-1 go to the
        left panel and columns 2-3 to the right one.
    lims : ((xmin, xmax), (ymin, ymax))
        Histogram range for both panels; also used as the marginal-axis limits.

    Returns
    -------
    ax1, ax1_marg_x, ax1_marg_y, ax2, ax2_marg_x, ax2_marg_y
    """
    fig = plt.figure(figsize=(10, 4))
    h = 1.5
    gridspec = fig.add_gridspec(2, 5, width_ratios=(7, h, 2.5, 7, h), height_ratios=(h, 7),
                                left=0.1, right=0.9, bottom=0.1, top=0.9,
                                wspace=0, hspace=0)
    ax1, ax1_marg_x, ax1_marg_y = create_grid(fig, gridspec, 1, 0)
    ax2, ax2_marg_x, ax2_marg_y = create_grid(fig, gridspec, 1, 3)

    joint_kws = dict(cmap='fire_r', range=lims, bins=75)
    marginal_kws = dict(histtype='step', color='black')

    scatter_hist(emittances[:, 0], emittances[:, 1], ax1, ax1_marg_x, ax1_marg_y, joint_kws, marginal_kws)
    scatter_hist(emittances[:, 2], emittances[:, 3], ax2, ax2_marg_x, ax2_marg_y, joint_kws, marginal_kws)
    ax1_marg_x.set_xlim(lims[0])
    ax1_marg_y.set_ylim(lims[1])
    ax2_marg_x.set_xlim(lims[0])
    ax2_marg_y.set_ylim(lims[1])
    ax1.set_xlabel(r'$\varepsilon_x$ [mm mrad]')
    ax1.set_ylabel(r'$\varepsilon_y$ [mm mrad]')
    ax2.set_xlabel(r'$\varepsilon_1$ [mm mrad]')
    ax2.set_ylabel(r'$\varepsilon_2$ [mm mrad]')
    ax1_marg_x.set_title(r'Apparent emittances ($\varepsilon_x$, $\varepsilon_y$)')
    ax2_marg_x.set_title(r'Intrinsic emittances ($\varepsilon_1$, $\varepsilon_2$)')
    # Return the second panel's own marginals (previously ax1_marg_y was returned twice).
    return ax1, ax1_marg_x, ax1_marg_y, ax2, ax2_marg_x, ax2_marg_y

In [None]:
def get_moments(Sigma0, tmats, f=None):
    """Return the virtually measured [<xx>, <yy>, <xy>] at each wire-scanner.

    Sigma0 is transported through each transfer matrix M as M Sigma0 M^T.
    When `f` is truthy, <xx>, <yy> and the diagonal moment
    <uu> = (<xx> + <yy> + 2<xy>) / 2 are each multiplied by an independent
    factor (1 + N(0, f)) from the NumPy global RNG. <xy> is then recomputed
    from the three (possibly noisy) values, presumably mimicking a scanner
    that measures x, y and u projections directly.
    """
    def _measure(M):
        transported = np.linalg.multi_dot([M, Sigma0, M.T])
        xx, yy, xy = transported[0, 0], transported[2, 2], transported[0, 2]
        uu = 0.5 * (2 * xy + xx + yy)
        if f:
            xx = xx * (1.0 + np.random.normal(scale=f))
            yy = yy * (1.0 + np.random.normal(scale=f))
            uu = uu * (1.0 + np.random.normal(scale=f))
        return [xx, yy, 0.5 * (2 * uu - xx - yy)]

    return [_measure(M) for M in tmats]

In [None]:
def solve(tmats, moments, method='llsq'):
    """Reconstruct the 4x4 covariance matrix from transfer matrices and moments.

    Parameters
    ----------
    tmats : sequence of (4, 4) arrays
        Transfer matrix from the reconstruction point to each wire-scanner.
    moments : sequence of [<xx>, <yy>, <xy>]
        Moments measured at the matching wire-scanner.
    method : {'llsq'}
        'llsq': solve the <xx>, <yy> and <xy> systems independently with
        `scipy.optimize.lsq_linear`. The result is unconstrained and may not be
        positive definite.

    Returns
    -------
    The matrix built by `to_mat` from the ten fitted moments.

    Raises
    ------
    ValueError
        If `method` is not supported. The check runs before any fitting so a
        bad value fails immediately.
    """
    if method != 'llsq':
        raise ValueError("`method` must be in {'llsq'}")

    Axx, Ayy, Axy = [], [], []
    bxx, byy, bxy = [], [], []
    for M, (sig_xx, sig_yy, sig_xy) in zip(tmats, moments):
        Axx.append([M[0, 0]**2, M[0, 1]**2, 2*M[0, 0]*M[0, 1]])
        Ayy.append([M[2, 2]**2, M[2, 3]**2, 2*M[2, 2]*M[2, 3]])
        Axy.append([M[0, 0]*M[2, 2], M[0, 1]*M[2, 2], M[0, 0]*M[2, 3], M[0, 1]*M[2, 3]])
        bxx.append(sig_xx)
        byy.append(sig_yy)
        bxy.append(sig_xy)

    sig_11, sig_22, sig_12 = opt.lsq_linear(Axx, bxx).x
    sig_33, sig_44, sig_34 = opt.lsq_linear(Ayy, byy).x
    sig_13, sig_23, sig_14, sig_24 = opt.lsq_linear(Axy, bxy).x
    return to_mat([sig_11, sig_22, sig_12, sig_33, sig_44, sig_34, sig_13, sig_23, sig_14, sig_24])

In [None]:
def run_trials(Sigma0, tmats, n_trials, method='llsq', f=None, pbar=False):
    """Repeat a noisy virtual measurement + reconstruction `n_trials` times.

    Each trial draws noisy moments with `get_moments(Sigma0, tmats, f)` and
    reconstructs with `solve(tmats, moments, method)`. Trials whose
    reconstruction is not positive definite (per `utils.is_positive_definite`)
    count as failures and are excluded from the returned emittances.

    Parameters
    ----------
    Sigma0 : (4, 4) array
        "True" covariance matrix at the reconstruction point.
    tmats : sequence of (4, 4) arrays
    n_trials : int
    method : str
        Forwarded to `solve`.
    f : float, optional
        RMS fractional error forwarded to `get_moments`.
    pbar : bool
        Show a tqdm progress bar.

    Returns
    -------
    fail_rate : float
        Fraction of failed trials (NaN when n_trials == 0).
    emittances : ndarray
        One row of `ba.emittances` per successful trial. Shape (0,) if every
        trial failed.
    """
    emittances, n_fail = [], 0
    for _ in (trange(n_trials) if pbar else range(n_trials)):
        moments = get_moments(Sigma0, tmats, f)
        Sigma = solve(tmats, moments, method)
        if not utils.is_positive_definite(Sigma):
            n_fail += 1
            continue
        emittances.append(ba.emittances(Sigma))  # computed once per trial
    fail_rate = n_fail / n_trials if n_trials > 0 else float('nan')
    return fail_rate, np.array(emittances)

In [None]:
f = 0.05          # rms fractional error applied to each virtually measured moment
n_trials = 10000
method = 'llsq'

# For each measurement index, transport the all-measurement Sigma with that
# index's transfer matrices, add noise and re-solve `n_trials` times.
fail_rates, emittances_list = [], []
for meas_index in range(n_meas):
    tmats = [tmats_dict[ws_id][meas_index] for ws_id in ws_ids]
    fail_rate, emittances = run_trials(
        Sigma_all_meas, tmats, n_trials, method=method, f=f, pbar=True
    )
    fail_rates.append(fail_rate)
    emittances_list.append(emittances)

In [None]:
# Fraction of non-positive-definite reconstructions at each measurement index.
fig, ax = pplt.subplots(figsize=(3, 2))
ax.plot(fail_rates, marker='.', color='black')
ax.format(
    xlabel='Measurement index', ylabel='Fail rate', ylim=(0, 1),
    xtickminor=False, xticks=range(len(fail_rates)),
)

In [None]:
line_kws = {'color': 'white', 'lw': 0.1}
hist_kws = {'cmap': 'viridis', 'range': ((0, 60), (0, 60)), 'bins': 50}
fig, axes = plt.subplots(ncols=2, figsize=(6, 3), sharex=True, sharey=True, constrained_layout=True)
# Close the current figure so it isn't shown as a static image; the
# animation still draws into `fig`.
plt.close()

def update(meas_index):
    """Redraw both histogram panels for one measurement index.

    Used as the FuncAnimation frame callback. Reads the module-level
    `fail_rates`, `emittances_list`, `axes`, `hist_kws`, `line_kws`, `f` and
    the `*_all_meas` reference emittances (drawn as white guide lines).
    """
    fail_rate = fail_rates[meas_index]
    emittances = emittances_list[meas_index]
    for j, ax in zip([0, 2], axes):
        ax.clear()
        # If every trial failed, `run_trials` returns an empty array with no
        # columns; leave the panel empty rather than raise IndexError.
        if len(emittances) > 0:
            ax.hist2d(emittances[:, j], emittances[:, j + 1], **hist_kws)
    axes[0].axvline(eps_x_all_meas, **line_kws)
    axes[0].axhline(eps_y_all_meas, **line_kws)
    axes[1].axvline(eps_1_all_meas, **line_kws)
    axes[1].axhline(eps_2_all_meas, **line_kws)
    axes[0].set_xlabel(r'$\varepsilon_x$ [mm mrad]')
    axes[0].set_ylabel(r'$\varepsilon_y$ [mm mrad]')
    axes[1].set_xlabel(r'$\varepsilon_1$ [mm mrad]')
    axes[1].set_ylabel(r'$\varepsilon_2$ [mm mrad]')
    axes[0].set_title('Apparent emittances')
    axes[1].set_title('Intrinsic emittances')
    axes[1].annotate('f = {}'.format(f), xy=(0.05, 0.92), xycoords='axes fraction',
                     color='white', fontsize='small')
    axes[1].annotate('fail rate = {:.3f}'.format(fail_rate), xy=(0.05, 0.87), xycoords='axes fraction',
                     color='white', fontsize='small')
    axes[1].annotate('meas index = {}'.format(meas_index), xy=(0.05, 0.82), xycoords='axes fraction',
                     color='white', fontsize='small')
    
# Keep a reference to the animation; as the cell's last expression it is
# rendered inline (rcParams['animation.html'] = 'jshtml').
anim = animation.FuncAnimation(fig, update, frames=n_meas)
anim

In [None]:
# Sigma0 = np.copy(Sigma)
# meas_index = 5
# n_trials = 1000
# dsig_list = np.linspace(0.0, 0.05, 6)
# dphi = np.radians(0.0)

# # Use real transfer matrices
# tmats = [tmats_dict[ws_id][meas_index] for ws_id in ws_ids]

# emittances_list, fail_rates = [], []
# for dsig in dsig_list:
#     fail_rate, emittances = run(Sigma0, tmats, n_trials, dsig, dphi, pbar=True)
#     fail_rates.append(fail_rate)
#     emittances_list.append(emittances)

In [None]:
# fig, ax = plot.subplots(figsize=(3, 2))
# ax.plot(dsig_list, fail_rates, marker='.', color='k')
# ax.format(xlabel='RMS fractional error in measured moments',
#           ylabel='Fail rate', grid=True,
#           ylim=(0, 1))
# plt.savefig('_output/failrate_vs_meas_error_{}'.format(meas_index), facecolor='white')

In [None]:
#### @gif.frame
# def plot_errors(dsig, fail_rate, emittances):
#     ax1, ax1_marg_x, ax1_marg_y, ax2, ax1_marg_y, ax1_marg_y = emittances_joint_hist(emittances, lims=((15, 60), (0, 40)))
#     ax2.annotate('fail rate = {:.2f}'.format(fail_rate), 
#              xy=(0.03, 0.93), xycoords='axes fraction', color='white')  
#     ax2.annotate(r'rms frac err = {:.0f}%'.format(100 * dsig), 
#              xy=(0.03, 0.87), xycoords='axes fraction', color='white') 
#     line_kws = dict(color='white', lw=0.25, alpha=0.5)
#     ax1.axvline(eps_x, **line_kws)
#     ax1.axhline(eps_y, **line_kws)
#     ax2.axvline(eps_1, **line_kws)
#     ax2.axhline(eps_2, **line_kws)
#     figname = '_output/results_dsig{:.0f}%.png'.format(100*dsig)
#     plt.savefig(figname, facecolor='white', dpi=250)
#     plt.show()
    
    
# frames = []
# for dsig, fail_rate, emittances in zip(dsig_list, fail_rates, emittances_list):
#     frame = plot_errors(dsig, fail_rate, emittances)
#     frames.append(frame)

## Stability conditions

Let "individual reconstruction" refer to the reconstruction with only four wire-scanners at fixed optics. The "badness" of an individual reconstruction seems to depend on the optics. In fact, in some cases it seems to vary smoothly with the measurement index. 

This has apparently been studied before. Below, we investigate the stability conditions derived in this previous work.

In [None]:
def cond(A):
    """Return the Frobenius-norm condition number ||A||_F * ||A^-1||_F.

    For an exactly singular matrix this returns ``inf`` instead of raising,
    so parameter sweeps over degenerate optics don't abort.

    Raises
    ------
    numpy.linalg.LinAlgError
        If A is not a square 2D matrix.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise np.linalg.LinAlgError(
            'cond() requires a square 2D matrix, got shape {}'.format(A.shape))
    try:
        A_inv = np.linalg.inv(A)
    except np.linalg.LinAlgError:
        return np.inf
    return np.linalg.norm(A) * np.linalg.norm(A_inv)

In [None]:
# Condition number of the cross-plane system Axy (third array returned) for
# each measurement index.
condition_numbers = [cond(form_coeff_target_arrays(i)[2]) for i in range(n_meas)]

In [None]:
# Fail rate (left axis) vs. log10 condition number of Axy (right axis).
color_fail, color_cond = 'black', 'red8'
fig, ax1 = pplt.subplots(figsize=(3, 2))
ax1.plot(fail_rates, color=color_fail, marker='.')
ax1.set_ylabel('Fail rate', color=color_fail)
ax1.tick_params(axis='y', labelcolor=color_fail)
ax1.format(xlabel='Measurement index', xtickminor=False, xticks=range(len(fail_rates)), ylim=(0, 1))
ax2 = ax1.twinx()
ax2.plot(np.log10(condition_numbers), color=color_cond, marker='.')
ax2.set_ylabel(r'log$_{10}$ Condition number', color=color_cond)
ax2.tick_params(axis='y', labelcolor=color_cond)

### Regions of stability 

In [None]:
def coeff_arrays(tmats):
    Axx, Ayy, Axy = [], [], []
    for M in tmats:    
        Axx.append([M[0, 0]**2, M[0, 1]**2, 2*M[0, 0]*M[0, 1]])
        Ayy.append([M[2, 2]**2, M[2, 3]**2, 2*M[2, 2]*M[2, 3]])
        Axy.append([M[0, 0]*M[2, 2],  M[0, 1]*M[2, 2],  M[0, 0]*M[2, 3],  M[0, 1]*M[2, 3]])
    return np.array(Axx), np.array(Ayy), np.array(Axy)

In [None]:
N = 4  # four wire-scanners

n_steps = 300
dmuxx = np.radians(np.linspace(0.01, 181, n_steps))
dmuyy = np.radians(np.linspace(0.01, 181, n_steps))

# cond_arr[i, j] = cond(Axy) for evenly spaced wire-scanners with horizontal
# gap dmuxx[i] and vertical gap dmuyy[j].
cond_arr = np.zeros((n_steps, n_steps))
for i, dmux in enumerate(tqdm(dmuxx)):
    # Phases relative to the first wire-scanner: 0, d, 2d, 3d.
    muxx = np.cumsum(np.full(N, dmux)) - dmux
    for j, dmuy in enumerate(dmuyy):
        muyy = np.cumsum(np.full(N, dmuy)) - dmuy
        tmats = [phase_adv_matrix(mux, muy) for mux, muy in zip(muxx, muyy)]
        cond_arr[i, j] = cond(coeff_arrays(tmats)[2])

In [None]:
fig, ax = plt.subplots()
X, Y = np.meshgrid(np.degrees(dmuxx), np.degrees(dmuyy))
# meshgrid's default 'xy' indexing puts dmuyy along rows, while cond_arr[i, j]
# is indexed [dmux, dmuy]; transpose so each cell gets its own (dmux, dmuy).
Z = 2 / (1 + cond_arr.T)
mesh = ax.pcolormesh(X, Y, Z, snap=True, cmap='binary_r', shading='auto')
cbar = fig.colorbar(mesh, ax=ax)
cbar.set_label(r'2 / (1 + cond($A_{xy}$))')  # matches the Z formula above
ax.set_xlabel('Horizontal phase advance gaps [deg]')
ax.set_ylabel('Vertical phase advance gaps [deg]')
ax.set_title(''.join([r'Stability regions (evenly spaced WS)',
                      '\n',
                      r'$\beta = 1$, $\alpha = 0$ (M = rotation matrix)'
                      '\n',
                      'Dark regions are unstable'
                     ]));

We want to create a picture like this for the RTBT wire-scanners. The three phase advances between the wire-scanners in the RTBT are not equal. The wire-scanners are also not located at points where the design $\beta$ function is at a maximum or minimum.

In [None]:
def augmented_matrix(A, b):
    """Return the augmented matrix [A | b], with the 1D array b appended as a final column."""
    return np.concatenate((A, b[:, None]), axis=1)

In [None]:
meas_index = 0
Axx, Ayy, Axy, bxx, byy, bxy = form_coeff_target_arrays(meas_index)
# Augmented matrices [A | b] for the <xx>, <yy> and <xy> systems.
Axx_star, Ayy_star, Axy_star = (
    augmented_matrix(coeffs, target)
    for coeffs, target in ((Axx, bxx), (Ayy, byy), (Axy, bxy))
)

def split_submatrix(A, submat_shape):
    """Split A into non-overlapping p x q blocks.

    Blocks are ordered row-major over the block grid (left to right, then top
    to bottom).

    Parameters
    ----------
    A : array-like, shape (m, n)
    submat_shape : (int, int)
        Block shape (p, q); p must divide m and q must divide n.

    Returns
    -------
    list of ndarray
        One array of shape (1, p, q) per block, as produced by
        `np.array_split` along the block axis.

    Raises
    ------
    ValueError
        If (p, q) does not tile A exactly.
    """
    A = np.asarray(A)
    p, q = submat_shape
    m, n = A.shape
    if m % p or n % q:
        raise ValueError('submatrix shape {} does not tile a matrix of shape {}'.format((p, q), (m, n)))
    # (row-block, row, col-block, col) -> (row-block, col-block, row, col) -> (block, row, col)
    blocks = A.reshape(m // p, p, n // q, q).transpose(0, 2, 1, 3).reshape(-1, p, q)
    # Previously referenced an undefined name `x` here (NameError).
    return np.array_split(blocks, blocks.shape[0], axis=0)

# Demo: a 4x4 matrix splits into four 2x2 blocks (3x3 blocks cannot tile it).
A = np.array([[1, 2, 3, 4], [4, 5, 6, 4], [7, 8, 9, 4], [1, 2, 3, 4]])
split_submatrix(A, (2, 2))