In [None]:
import sys
import os

import numpy as np
import scipy.optimize as opt
import pandas as pd
from matplotlib import pyplot as plt, animation
from matplotlib.lines import Line2D
from matplotlib.patches import Ellipse
import proplot as plot
import seaborn as sns
from tqdm import tqdm

sys.path.append('../../')
from tools import animation as myanim
from tools import beam_analysis as ba
from tools import plotting as myplt
from tools import utils

sys.path.append('..')
from data_analysis import to_vec, to_mat, reconstruct
from data_vis import reconstruction_lines

In [None]:
# Notebook-wide proplot/matplotlib defaults applied to every figure below.
_rc_settings = {
    'figure.facecolor': 'white',
    'savefig.dpi': 'figure',      # savefig uses the figure's own dpi unless one is passed
    'animation.html': 'jshtml',   # animations render as inline JavaScript widgets
    'grid.alpha': 0.04,
    'axes.grid': False,
}
for _key, _value in _rc_settings.items():
    plot.rc[_key] = _value

In [None]:
# Set to True to write figures passed through `save` to disk.
save_figures = False


def save(figname, dpi=250, fig=None, **savefig_kws):
    """Save a figure as ``_output/figures/<figname>.png`` when ``save_figures`` is True.

    Does nothing (and returns None) when ``save_figures`` is False.

    Parameters
    ----------
    figname : str
        File name without directory or extension.
    dpi : float, optional
        Resolution passed to ``savefig`` (default 250).
    fig : matplotlib Figure, optional
        Figure to save; defaults to the current pyplot figure.
    **savefig_kws
        Extra keyword arguments forwarded to ``savefig``; ``facecolor``
        defaults to ``'white'``.
    """
    if not save_figures:
        return None
    out_dir = '_output/figures'
    # savefig does not create missing directories; make a fresh checkout work.
    os.makedirs(out_dir, exist_ok=True)
    filename = os.path.join(out_dir, figname + '.png')
    savefig_kws.setdefault('facecolor', 'white')
    target = plt if fig is None else fig
    target.savefig(filename, dpi=dpi, **savefig_kws)
    return None

# 4D Measurement in RTBT
> This notebook reconstructs the beam covariance matrix at the entrance of the Ring to Target Beam Transport (RTBT) section of the Spallation Neutron Source (SNS).

<img src="_input/rtbt.png" width=800>

## Method summary

The goal is to reconstruct the transverse beam covariance matrix at position $s = s_0$:

$$
\Sigma_{0} = \begin{bmatrix}
    \langle{x^2}\rangle & \langle{xx'}\rangle & \langle{xy}\rangle & \langle{xy'}\rangle \\
    \langle{xx'}\rangle & \langle{{x'}^2}\rangle & \langle{yx'}\rangle & \langle{x'y'}\rangle \\
    \langle{xy}\rangle & \langle{yx'}\rangle & \langle{y^2}\rangle & \langle{yy'}\rangle \\
    \langle{xy'}\rangle & \langle{x'y'}\rangle & \langle{yy'}\rangle & \langle{{y'}^2}\rangle
\end{bmatrix}.
$$

To do this, a set of $n$ wire-scanners can be placed at positions $\{s_i\} > s_0$ with $i = 1, ..., n$. A single measurement from wire-scanner $i$ will produce the real-space moments of the beam at $s_i$: $\langle{x^2}\rangle_{i}$, $\langle{y^2}\rangle_{i}$, and $\langle{xy}\rangle_{i}$. Without space charge, the transfer matrix $M_{s_0 \rightarrow s_i} = M_i$ is known. The moments at $s_0$ are then directly related to those at $s_i$ by

$$\Sigma_i = M_i \Sigma_{0} {M_i}^T.$$ This gives <br>

$$
\begin{align}
    \langle{x^2}\rangle_i &= 
        m_{11}^2\langle{x^2}\rangle_{0} 
      + m_{12}^2\langle{x'^2}\rangle_{0} 
      + 2m_{11}m_{12}\langle{xx'}\rangle_{0} ,\\
    \langle{y^2}\rangle_i &= 
        m_{33}^2\langle{y^2}\rangle_{0} 
      + m_{34}^2\langle{y'^2}\rangle_{0} 
      + 2m_{33}m_{34}\langle{yy'}\rangle_{0} ,\\
    \langle{xy}\rangle_i &= 
        m_{11}m_{33}\langle{xy}\rangle_{0} 
      + m_{12}m_{33}\langle{yx'}\rangle_{0} 
      + m_{11}m_{34}\langle{xy'}\rangle_{0} 
      + m_{12}m_{34}\langle{x'y'}\rangle_{0} ,
\end{align}
$$

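where $m_{lm}$ are the elements of the transfer matrix (these expressions assume uncoupled optics, i.e., $M_i$ has no elements coupling the $x$ and $y$ planes). Each measurement provides 3 equations, so 3 measurements give only 9 equations for the 10 unknowns in $\Sigma_0$. At least 4 measurements with different optics settings between $s_0$ and $s_i$ (and therefore different transfer matrices) are needed; in particular, the 4 unknowns in the $\langle{xy}\rangle_i$ equation require 4 independent measurements. Real measurements are noisy, however, so it is better to take more measurements if possible. Given $N$ measurements, we can form a $3N \times 1$ observation array $\mathbf{b}$ from the measured moments and a $3N \times 10$ coefficient array $\mathbf{A}$ from the transfer-matrix elements such that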

$$\begin{align} \mathbf{A \sigma}_0 = \mathbf{b},\end{align}$$ 

where $\mathbf{\sigma}_0$ is a $10 \times 1$ vector of the moments at $s_0$. There are 5 wire-scanners in the RTBT which operate simultaneously, so if all of them are used, each of the $N$ optics settings contributes 15 equations and the coefficient array is $15N \times 10$. We then choose $\mathbf{\sigma}_0$ such that $|\mathbf{A\sigma}_0 - \mathbf{b}|^2$ is minimized:

$$ \mathbf{\sigma}_0 = (\mathbf{A}^T\mathbf{A})^{-1}\mathbf{A}^T\mathbf{b} $$

## RTBT lattice functions 

In [None]:
# Lattice functions along the RTBT (columns used below include 's', 'bx', 'by')
# and the longitudinal positions of the wire-scanners.
_data_dir = '_output/data'
twiss = pd.read_csv(os.path.join(_data_dir, 'twiss.dat'))
ws_positions = np.loadtxt(os.path.join(_data_dir, 'ws_positions.dat'))

In [None]:
# Horizontal and vertical beta functions along the RTBT; dashed grey lines mark
# the wire-scanner positions.
fig, ax = plot.subplots(figsize=(7, 2))
twiss[['s', 'bx', 'by']].plot('s', ax=ax, legend=False)
ax.format(xlabel='Position [m]', ylabel=r'$\beta$ [m]', toplabels='RTBT lattice functions')
for ws_position in ws_positions:
    ax.axvline(ws_position, color='grey', ls='--', lw=0.5, zorder=0)
ax.format(xlim=(0, twiss['s'].max()))
# A labels-only legend presumably maps onto the first three line artists:
# beta_x, beta_y, and the first wire-scanner marker.
ax.legend(labels=[r'$\beta_x$', r'$\beta_y$', 'WS'],
          ncols=1, loc=(1.01, 0), handlelength=1.5, fontsize='small')
# Honor the notebook-wide `save_figures` flag (same path, facecolor and dpi as before)
# instead of writing the file unconditionally.
save('beta')

## Phase scan

In [None]:
# Wire-scanners with saved phase-scan data.
ws_names = ['ws02', 'ws20', 'ws21', 'ws23', 'ws24']
# Every wire-scanner except ws02 is used in the reconstruction.
active_ws_names = [name for name in ws_names if name != 'ws02']

In [None]:
def load(filename, ws_name):
    """Load a NumPy array saved for one wire-scanner.

    Reads ``_output/data/<ws_name>/<filename>`` with ``np.load``.
    """
    return np.load(f'_output/data/{ws_name}/{filename}')

transfer_mats_dict, moments_dict, phases_dict = {}, {}, {}
for ws_name in ws_names:
    transfer_mats_dict[ws_name] = load('transfer_mats.npy', ws_name)
    # Moments are scaled by 1e6; presumably m^2 -> mm^2 (plots label them in mm^2).
    moments_dict[ws_name] = 1e6 * load('moments.npy', ws_name)
    phases_dict[ws_name] = load('phases.npy', ws_name)

In [None]:
# Reference ("True") covariance matrix and particle coordinates of the initial bunch,
# scaled on load: 1e6 for second moments and 1e3 for coordinates (presumably m -> mm).
Sigma0 = 1e6 * np.loadtxt('_output/data/Sigma0.dat')

X0 = 1e3 * np.loadtxt('_output/data/X0.dat')

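Observe the beam at the wire-scanners. The top row shows the fractional phase advance and the bottom row the measured moments, both versus scan index. Only the wire-scanners in `active_ws_names` (every one except ws02) are used in the reconstruction.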

In [None]:
# Top row: fractional phase advance vs. scan index; bottom row: measured moments
# <x^2>, <y^2>, <xy> [mm^2]. Wire-scanners excluded from the reconstruction
# (not in `active_ws_names`) get grey titles.
plt_kws = dict(marker='.')
fig, axes = plot.subplots(nrows=2, ncols=5, figsize=(8, 3.5), spany=False)
for ws_name, phase_ax, moment_ax in zip(ws_names, axes[0, :], axes[1, :]):
    phase_ax.plot(phases_dict[ws_name] % 1, **plt_kws)
    title_color = 'k' if ws_name in active_ws_names else 'grey'
    phase_ax.set_title(ws_name, fontsize='large', color=title_color)
    for col in range(3):  # columns 0, 1, 2 -> <x^2>, <y^2>, <xy>
        moment_ax.plot(moments_dict[ws_name][:, col], **plt_kws)
axes[0, 0].legend(labels=[r'$\nu_x$', r'$\nu_y$'], ncols=3);
axes[1, 0].format(ylabel='[mm$^2$]')
axes[1, 0].legend(labels=[r'$\langle{x^2}\rangle$', r'$\langle{y^2}\rangle$', r'$\langle{xy}\rangle$'], fontsize='small', ncols=2);
# Raw string: '\p' in a plain literal is an invalid escape sequence.
axes[0, 0].format(ylabel=r'Frac. phase / ($2\pi$)', xlabel='Scan index', xlabel_kw={'size':'large'}, ylabel_kw={'size':'large'})
# Write the figure only when the notebook-wide `save_figures` flag is set.
if save_figures:
    os.makedirs('_output/figures', exist_ok=True)
    plt.savefig('_output/figures/ws_phase_adv.png', facecolor='white', dpi=350)

To do: 
* Add column below showing phase vs scan index.
* Add diagonal line at angle of diagonal wire.

In [None]:
# dims = ('x', 'y')

In [None]:
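# NOTE: disabled; this cell references `env_params_dict`, `ws_ids`, and `active_ws_ids`, which are not defined in this notebook.
# str_to_int = {'x':0, 'xp':1, 'y':2, 'yp':3}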
# i, j = [str_to_int[dim] for dim in dims]

# ell_coords_list = [[ba.get_ellipse_coords(env_params) for env_params in env_params_dict[ws]] 
#                    for ws in ws_ids]
# ell_coords_list = np.array(ell_coords_list)

# pad = 0.25
# limits_list = np.array([(1 + pad) * myplt.max_u_up_global(coords) 
#                         for coords in ell_coords_list])
# umax, upmax = np.max(limits_list, axis=0)
# limits = 2 * [(-umax, umax), (-upmax, upmax)]
# labels = ['x [mm]', "x' [mrad]", 'y [mm]', "y' [mrad]"]

# fig, ax_list = plt.subplots(ncols=6, figsize=(13, 1.75), sharex=True, sharey=True)
# axes, text_ax = ax_list[:-1], ax_list[-1]
# text_ax.grid(False)
# myplt.despine([text_ax], 'all')
# myplt.despine(axes)
# axes[0].set_xlim(limits[i])
# axes[0].set_ylim(limits[j])
# axes[0].set_yticks(axes[0].get_xticks())
# axes[0].set_xlabel(labels[i])
# axes[0].set_ylabel(labels[j])
# for ax, ws_name in zip(axes, ws_ids):
#     ax.set_title(ws_name, color='grey' if ws_name not in active_ws_ids else 'k')
#     ax.grid(False)
#     ax.axvline(0, lw=0.2, c='k', alpha=0.2, zorder=99)
#     ax.axhline(0, lw=0.2, c='k', alpha=0.2, zorder=99)
# plt.close()

# def update(t):
#     for ax, coords in zip(axes, ell_coords_list):
#         for patch in ax.patches:
#             patch.remove()
#         ax.fill(coords[t, :, i], coords[t, :, j], fc='lightsteelblue', ec='k', zorder=10)
#     for ws, ax in zip(ws_ids, axes):
#         nux, nuy = phases_dict[ws][t]
#         for text in ax.texts:
#             text.set_visible(False)
#         ax.annotate(r'$\nu_x = {:.2f}$'.format(nux), xy=(0.75, 0.85), xycoords='axes fraction', zorder=10)
#         ax.annotate(r'$\nu_y = {:.2f}$'.format(nuy), xy=(0.75, 0.75), xycoords='axes fraction', zorder=10)
#     for text in text_ax.texts:
#         text.set_visible(False)
#     text_ax.annotate('Scan index = {}'.format(t), xy=(0.25, 0.5), xycoords='axes fraction', horizontalalignment='center')

# nframes = ell_coords_list[0].shape[0]
# fps = 1
# anim = animation.FuncAnimation(fig, update, frames=nframes, interval=1000/fps)
# if save_figures:
#      anim.save('_output/figures/ws_envelope.mp4', dpi=300)
# anim

In [None]:
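# NOTE: disabled; this cell references `ws_ids`, `env_params_dict`, and `play`, which are not defined in this notebook.
# for ws in ws_ids: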
#     print(ws)
#     anim = myanim.corner_env(env_params_dict[ws], figsize=(4, 4), text_fmt='Scan index = {}')
#     play(anim)

In [None]:
# Corner plot of the reference initial bunch X0 (options are interpreted by tools.plotting.corner).
corner_kws = dict(pad=0.1, samples=20000, moments=True, text='Initial bunch')
myplt.corner(X0, **corner_kws);
save('initial_dist')

## Reconstruction

In [None]:
# Maximum number of scan points taken from each active wire-scanner.
# `active_ws_names` is intentionally not redefined here: the Phase scan section
# already sets it, and a second definition could silently disagree with the one
# used to mark active wire-scanners in the plots.
max_n_meas = 20

In [None]:
# Flatten the first `max_n_meas` scan points of every active wire-scanner into
# parallel lists of transfer matrices and measured moments.
transfer_mats_list = []
moments_list = []
selection = slice(max_n_meas)
for ws_name in active_ws_names:
    transfer_mats_list.extend(transfer_mats_dict[ws_name][selection])
    moments_list.extend(moments_dict[ws_name][selection])

In [None]:
# Fit the covariance matrix at s0 from all stacked measurements (presumably the
# least-squares solve from the Method summary; `verbose` semantics live in data_analysis).
fit_inputs = (transfer_mats_list, moments_list)
Sigma = reconstruct(*fit_inputs, verbose=2)

In [None]:
# Print the reconstructed and reference matrices (blank line between), then
# overlay their rms ellipses in every projection.
for idx, (matrix, label) in enumerate([(Sigma, 'Sigma'), (Sigma0, 'Sigma0')]):
    if idx > 0:
        print()
    utils.show(matrix, label)

axes = myplt.rms_ellipses(Sigma0, color='lightsteelblue', fill=True, lw=0)
axes = myplt.rms_ellipses(Sigma, axes=axes, color='red8', lw=1)
axes[1, 1].legend(labels=['True', 'Reconstructed'], loc=(0, 1.1))
save('projection_default')

## Visualization using lines

In [None]:
def _active_only(mapping):
    """Return the entries of `mapping` whose key is in `active_ws_names`, keeping the mapping's order."""
    return {key: value for key, value in mapping.items() if key in active_ws_names}

# Shaded rms ellipses of the reconstruction, overlaid with `reconstruction_lines`
# (from data_vis) for the active wire-scanners in the y-y' and x-x' planes.
axes = myplt.rms_ellipses(Sigma, color='black', alpha=0.15, fill=True, lw=0)
_transfer_mats_dict = _active_only(transfer_mats_dict)
_moments_dict = _active_only(moments_dict)

reconstruction_lines(axes[2, 2], _transfer_mats_dict, _moments_dict, plane='y-yp')
reconstruction_lines(axes[0, 0], _transfer_mats_dict, _moments_dict, plane='x-xp',
                     legend=True, legend_kws=dict(loc=(1.15, 0)))