# Setup

## Imports 

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import proplot as plot

sys.path.append('/Users/46h/Research/code/accphys') 
from tools import beam_analysis as ba
from tools import plotting as myplt
from tools.plotting import save, set_labels
from tools import animation as myanim
from tools import utils
from tools.utils import show, play, file_exists
from tools.accphys_utils import get_phase_adv
from tools.plot_utils import moment_label, moment_label_string

## Settings

### Plotting 

In [None]:
# Keyword arguments shared by most pandas `.plot` calls in this notebook.
plt_kws = {'legend': False, 'xlabel': 's / L'}

# Global proplot rc settings, applied in this order.
for rc_key, rc_value in (
    ('figure.facecolor', 'white'),
    ('grid.alpha', 0.04),
    ('style', None),
    ('savefig.dpi', 'figure'),
    ('animation.html', 'jshtml'),
):
    plot.rc[rc_key] = rc_value
dpi = 500  # resolution passed to every `save` call below

# Animation settings
animate = True  # set False to skip the animation cells
skip = 10       # `skip` argument for the animation helpers (presumably a frame stride)
fps = 3

In [None]:
width, height = 3.5, 2.5  # base panel size [inches]

def setup_figure(opt=1):
    """Create a proplot figure using one of four preset panel layouts.

    Parameters
    ----------
    opt : {1, 2, 3, 4}
        1 -> single panel, figsize (width, height)
        2 -> 3 rows x 1 column, figsize (width, 2*height)
        3 -> 3 rows x 2 columns, figsize (1.6*width, 2*height)
        4 -> 1 row x 2 columns, figsize (7, 2.5)

    Returns
    -------
    fig, axes
        The figure and its axes grid; every axis gets the x label 's / L'.

    Raises
    ------
    ValueError
        If `opt` is not one of the supported layouts (previously this fell
        through to an UnboundLocalError on `nrows`).
    """
    # Built at call time so changes to the global width/height are honoured.
    layouts = {
        1: (1, 1, (width, height)),
        2: (3, 1, (width, 2*height)),
        3: (3, 2, (1.6*width, 2*height)),
        4: (1, 2, (7, 2.5)),
    }
    if opt not in layouts:
        raise ValueError('opt must be one of {}, got {!r}'.format(sorted(layouts), opt))
    nrows, ncols, figsize = layouts[opt]
    fig, axes = plot.subplots(nrows=nrows, ncols=ncols, figsize=figsize, spany=False, aligny=True)
    axes.format(xlabel='s / L')
    return fig, axes

### Files 

In [None]:
# Simulation output files read by this notebook (paths relative to the notebook).
files = dict(
    positions='_output/data/position.npy',
    env_params='_output/data/envelope/env_params.npy',
    testbunch_coords='_output/data/envelope/testbunch_coords.npy',
    bunch_coords='_output/data/bunch/bunch_coords.npy',
    bunch_moments='_output/data/bunch/bunch_moments.npy',
    transfer_matrix='_output/data/transfer_matrix.npy',
)
# Flags used below to skip sections whose input file is missing.
files_exist = dict((name, file_exists(path)) for name, path in files.items())

# Output directories for the saved figures, one per notebook section.
dirs = dict(
    env='./_output/figures/envelope/',
    bunch='./_output/figures/bunch/',
    comparison='./_output/figures/comparison/',
)

# Bare lattice optics 

In [None]:
# Bare-lattice Twiss table; column 's' is the position [m].
twiss_columns = ['s', 'nux', 'nuy', 'ax', 'ay', 'bx', 'by']
lattice_twiss = pd.DataFrame(np.load('_output/data/twiss.npy'), columns=twiss_columns)

# Horizontal and vertical beta functions vs. position.
fig, ax = plot.subplots(figsize=(8, 2))
beta_cols = ['s', 'bx', 'by']
lattice_twiss[beta_cols].plot('s', ax=ax, legend=False)
ax.format(xlabel='s [m]', ylabel=r'$\beta$ [m]')
ax.legend(labels=[r'$\beta_x$', r'$\beta_y$'], ncols=1);
save('bare_optics', dirs['env'], dpi=dpi)

# Envelope

In [None]:
# Envelope-model output: parameters at each saved position, and those positions.
mode = int(np.loadtxt('_output/data/mode.txt'))
env_params_list = np.load(files['env_params'])
positions = np.load(files['positions'])

# Beam statistics derived from the envelope parameters.
env_stats = ba.Stats(mode)
env_stats.read_env(env_params_list)

# Phase advance [deg] in each plane, from that plane's beta function.
for plane in 'xy':
    env_stats.twiss2D['mu' + plane] = get_phase_adv(
        env_stats.twiss2D['b' + plane], positions, units='deg')

# Attach absolute (s) and normalized (s / L) position columns to every table.
positions_normed = positions / positions[-1]
for df in env_stats.dfs():
    df['s'] = positions
    df['s/L'] = positions_normed

### Effective lattice 

In [None]:
def getk(k0, s_over_L):
    """Return the hard-edge bare-lattice focusing strength at a normalized position.

    Parameters
    ----------
    k0 : float
        Magnitude of the focusing strength.
    s_over_L : float
        Position normalized by the lattice length.

    Returns
    -------
    +k0 for s/L in [0, 0.125) or [0.875, ...); -k0 for s/L in [0.375, 0.625);
    0 elsewhere.
    """
    at_edges = s_over_L < 0.125 or s_over_L >= 0.875
    if at_edges:
        return +k0
    at_center = 0.375 <= s_over_L < 0.625
    return -k0 if at_center else 0

# Effective focusing strengths: bare-lattice focusing plus the perveance-
# dependent (space-charge) terms of the tilted beam ellipse from the envelope.
Q = np.loadtxt('_output/data/perveance.txt')
k0 = 0.505  # bare focusing strength [m^-2]

# Ellipse axes (mm -> m) and tilt angle (deg -> rad) at each position.
cx = 0.001 * env_stats.realspace.loc[:, 'cx'].values
cy = 0.001 * env_stats.realspace.loc[:, 'cy'].values
phi = np.radians(env_stats.realspace.loc[:, 'angle'].values)

# Bare lattice: k_yy = -k_xx and no cross-plane terms.
k0xx = np.array([getk(k0, s) for s in positions_normed])
k0yy = -k0xx
k0xy = [0] * len(k0xx)
k0yx = [0] * len(k0xx)

# Perveance terms for an ellipse with axes (cx, cy) rotated by phi.
cos, sin = np.cos(phi), np.sin(phi)
cos2, sin2, sincos = cos**2, sin**2, sin*cos
sc_strength = Q / (cx + cy)
kxx = k0xx - sc_strength * (cos2/cx + sin2/cy)
kyy = k0yy - sc_strength * (cos2/cy + sin2/cx)
kxy = -(sc_strength * (1/cy - 1/cx) * sincos)
kyx = -(sc_strength * (1/cy - 1/cx) * sincos)

fig, axes = plot.subplots(nrows=3, ncols=1, figsize=(5, 4), spany=False, aligny=True)
# Distinct loop names so the scalar `k0` above is not overwritten by an array.
for k_bare, k_eff, ax in zip([k0xx, k0yy, k0xy], [kxx, kyy, kxy], axes):
    ax.plot(positions_normed, k_bare, color='grey', ls='--', lw=1)
    ax.plot(positions_normed, k_eff, color='pink5')
axes.format(xlabel='s / L')
myplt.set_labels(axes[:, 0], [r'$k_{x}$ [m$^{-2}$]', r'$k_{y}$ [m$^{-2}$]', r'$k_{xy}$ [m$^{-2}$]'], 'ylabel')
axes[-1].legend(labels=['Bare lattice', 'Effective lattice'], ncols=1, loc=(1.02, 0));

# Symmetric y-limits so the sign of each strength is easy to read.
for ax in axes:
    ymax = max(np.abs(ax.get_ylim()))
    ax.format(ylim=(-ymax, ymax))

save('effective_lattice', dirs['env'], dpi=dpi)

## Twiss parameters 

### 2D Twiss

In [None]:
# Envelope model: 2D Twiss beta, alpha and emittance, one panel each.
fig, axes = setup_figure(2)
for row, cols in enumerate((['bx', 'by'], ['ax', 'ay'], ['ex', 'ey'])):
    env_stats.twiss2D[['s/L'] + cols].plot('s/L', ax=axes[row], **plt_kws)
set_labels(axes, [r'$\beta$ [m]', r'$\alpha$ [rad]', r'$\varepsilon$ [mm $\cdot$ mrad]'], 'ylabel')
save('twiss2D', dirs['env'], dpi=dpi)

### 4D Twiss 

In [None]:
# Envelope model: 4D Twiss beta and alpha, plus the parameter u (black line).
fig, axes = setup_figure(2)
for row, cols in enumerate((['bx', 'by'], ['ax', 'ay'])):
    env_stats.twiss4D[['s/L'] + cols].plot('s/L', ax=axes[row], **plt_kws)
env_stats.twiss4D[['s/L', 'u']].plot('s/L', ax=axes[2], color='k', **plt_kws)
set_labels(axes, [r'$\beta$ [m]', r'$\alpha$ [rad]', 'u'], 'ylabel')
save('twiss4D', dirs['env'], dpi=dpi)

### Emittance 

In [None]:
# Envelope model: 2D emittances (ex, ey) and 4D emittances (e1, e2) on one axis.
fig, ax = plot.subplots(figsize=(4.5, 2.5))
for table, cols in ((env_stats.twiss2D, ['ex', 'ey']), (env_stats.twiss4D, ['e1', 'e2'])):
    table[['s/L'] + cols].plot('s/L', ax=ax, **plt_kws)
ax.format(ylabel=r'$\varepsilon$ [mm $\cdot$ mrad]')
emittance_labels = [r'$\varepsilon_x$', r'$\varepsilon_y$', r'$\varepsilon_1$', r'$\varepsilon_2$']
ax.legend(labels=emittance_labels, ncols=1, loc=(1.01, 0))
save('emittance', dirs['env'], dpi=dpi)

### Phase advance 
The phase advance in the $x$ dimension is found by integrating the ratio of the emittance to the squared rms beam size:

$$\mu_x(s) = \int_{0}^{s}{\frac{\varepsilon_x(s')}{{\tilde{x}(s')}^2}} ds',$$

where $\tilde{x} = \sqrt{\langle{x^2}\rangle}$ and $s$ is the position in the lattice. The same expression holds with $x \longleftrightarrow y$.

In [None]:
# Envelope model: cumulative phase advance in each plane (degrees).
fig, ax = setup_figure(1)
phase_cols = ['s/L', 'mux', 'muy']
env_stats.twiss2D[phase_cols].plot('s/L', ax=ax, **plt_kws)
ax.format(ylabel='Phase advance', yformatter='deg')
save('phase_adv', dirs['env'], dpi=dpi)

### Phase difference (nu)
The difference between every particle's $x$ and $y$ phases is related to the shape of the beam in real space as

$$
\cos\nu = \frac{\langle{xy}\rangle}{\sqrt{\langle{x^2}\rangle\langle{y^2}\rangle}}
$$

In [None]:
# Envelope model: x-y phase difference nu (degrees).
fig, ax = setup_figure(1)
nu_cols = ['s/L', 'nu']
env_stats.twiss4D[nu_cols].plot('s/L', ax=ax, color='k', **plt_kws)
ax.format(ylabel=r'$\nu$', yformatter='deg')
save('twiss4D-nu', dirs['env'], dpi=dpi)

## Moments 

In [None]:
# Envelope model: rms beam size in each plane.
fig, ax = setup_figure(1)
size_cols = ['s/L', 'x_rms', 'y_rms']
env_stats.moments[size_cols].plot('s/L', ax=ax, **plt_kws)
ax.format(ylabel='Beam size [mm]')
save('beamsize', dirs['env'], dpi=dpi)

In [None]:
# Envelope model: rms beam divergence in each plane.
fig, ax = setup_figure(1)
div_cols = ['s/L', 'xp_rms', 'yp_rms']
env_stats.moments[div_cols].plot('s/L', ax=ax, **plt_kws)
ax.format(ylabel='Beam divergence [mrad]')
save('beamdiv', dirs['env'], dpi=dpi)

In [None]:
# Envelope model: every second moment vs. position, on the lower triangle of a
# 4x4 grid (coordinates ordered x, x', y, y').
fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6), 
                          spany=False, aligny=True)
myplt.make_lower_triangular(axes)
axes.format(xlabel='s / L', suptitle='Transverse moments', titleborder=True)

for i in range(4):
    for j in range(i + 1):
        ax = axes[i, j]
        col = moment_label(i, j)
        env_stats.moments[['s/L',col]].plot('s/L', ax=ax, color='k', **plt_kws)
        ax.format(title=moment_label_string(i, j))

# Units of each moment, column by column (rows i >= j). Every unit is bracketed.
set_labels(axes[0:, 0], [r'[mm$^2$]', r'[mm$\cdot$mrad]', r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
set_labels(axes[1:, 1], [r'[mrad$^2$]', r'[mm$\cdot$mrad]', r'[mrad$^2$]'], 'ylabel')
set_labels(axes[2:, 2], [r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
set_labels(axes[3:, 3], [r'[mrad$^2$]'], 'ylabel')
save('all_moments', dirs['env'], dpi=dpi)

In [None]:
# Envelope model: correlation coefficients on the lower triangle of a 4x4 grid.
fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6),
                          spany=False, aligny=True)
myplt.make_lower_triangular(axes)
axes.format(suptitle='Transverse correlations', titleborder=True)

lower_pairs = [(i, j) for i in range(4) for j in range(i + 1)]
for i, j in lower_pairs:
    ax = axes[i, j]
    env_stats.corr[['s/L', moment_label(i, j)]].plot('s/L', ax=ax, color='k', **plt_kws)
    ax.format(title=moment_label_string(i, j))

save('all_correlations', dirs['env'], dpi=dpi)

## Real space orientation

In [None]:
# Envelope model: tilt angle, ellipse axes and relative area of the beam ellipse.
fig, axes = setup_figure(2)
panels = ((['angle'], {'color': 'k'}), (['cx', 'cy'], {}), (['area_rel'], {'color': 'k'}))
for ax, (cols, extra) in zip(axes, panels):
    env_stats.realspace[['s/L'] + cols].plot('s/L', ax=ax, **extra, **plt_kws)
set_labels(axes, ['tilt angle', 'ellipse axes [mm]', 'area [frac. change]'], 'ylabel')
axes[0].format(yformatter='deg')
save('realspace_ellipse', dirs['env'], dpi=dpi)

## Phase space projections 

In [None]:
# Envelope model: phase-space projections at the first and last positions.
first_and_last = env_params_list[[0, -1]]
axes = myplt.corner_env(
    first_and_last,
    cmap=plot.Colormap(('red7', 'blue7')),
    legend_kws={'labels': ['initial', 'final'], 'loc': (1, 1)},
)
save('init_final', dirs['env'], dpi=dpi)

In [None]:
if animate:
    # The animation is created inside an `if` block, so Jupyter does not
    # auto-display it; play it explicitly, as the bunch animation cells do.
    anim = myanim.corner_env(env_params_list, skip=skip, fps=fps, figsize=5,
                             text_vals=positions_normed, text_fmt='s / L = {:.2f}')
    play(anim)

## Transfer matrix 

In [None]:
if files_exist['transfer_matrix']:
    M = np.load(files['transfer_matrix'])
    M_eigvals, M_eigvecs = np.linalg.eig(M)
    # Eigentunes [deg]; meaningful when the eigenvalues lie on the unit circle.
    M_eigtunes = np.degrees(np.arccos(M_eigvals.real))

    # One eigenvalue from each pair; assumes eig lists pairs at indices (0, 1)
    # and (2, 3) — NOTE(review): numpy does not guarantee this ordering.
    pair_idx = [0, 2]
    reports = [
        (M, 'M'),
        (M_eigvals[pair_idx], 'eigenvalues'),
        (M_eigtunes[pair_idx], 'eigentunes [deg]'),
    ]
    for n, (value, title) in enumerate(reports):
        if n > 0:
            print()
        show(value, title)

In [None]:
if files_exist['transfer_matrix']:

    fig, axes = plot.subplots(ncols=2, figsize=(5.25, 2.5), share=False, span=False)
    axes.format(grid=False)
    myplt.despine(axes)
    ax1, ax2 = axes

    # Left panel: eigenvalues in the complex plane, with the unit circle dashed.
    psi = np.linspace(0, 2*np.pi, 50)
    x_circ, y_circ = np.cos(psi), np.sin(psi)
    ax1.plot(x_circ, y_circ, 'k--', zorder=0)
    ax1.scatter(M_eigvals.real, M_eigvals.imag, c=['r', 'r', 'b', 'b'])
    scale = 1.25
    limits = (-scale, scale)
    ax1.format(
        xticks=[-1, -0.5, 0, 0.5, 1],
        yticks=[-1, -0.5, 0, 0.5, 1],
        ylim=limits,
        xlim=limits,
        xlabel='Real',
        ylabel='Imag',
        title='Eigenvalues',
    )
    tune_notes = (
        (r'$\mu_1 = {:.2f}\degree$', M_eigtunes[0], +0.1),
        (r'$\mu_2 = {:.2f}\degree$', M_eigtunes[2], -0.1),
    )
    for template, tune, y_text in tune_notes:
        ax1.annotate(template.format(tune), xy=(0, y_text), horizontalalignment='center')

    # Right panel: turn-by-turn trajectory of the eigenvectors in x-y.
    myplt.eigvec_trajectory(ax2, M, 'x', 'y', s=10)
    ax2.format(xticklabels=[], yticklabels=[],
               ylabel='y', xlabel='x', title='Eigenvectors')

    # Legend with one proxy line per eigenvector colour.
    custom_lines = [matplotlib.lines.Line2D([0], [0], color=colour, lw=2) for colour in ('r', 'b')]
    ax2.legend(custom_lines, [r'$\vec{v}_1$', r'$\vec{v}_2$'],
               loc=(1.05, 0.7), handlelength=1, ncols=1);

    save('eigvecs_realspace', dirs['env'], dpi=dpi)

In [None]:
if files_exist['transfer_matrix']:

    # 3x3 lower-triangular corner grid: rows are x', y, y'; columns are x, x', y.
    fig, axes = plot.subplots(nrows=3, ncols=3, figsize=(5, 5), span=False)
    axes.format(grid=True, suptitle='Transfer matrix eigenvectors')
    myplt.make_lower_triangular(axes)
    myplt.despine(axes)

    labels = ["x", "x'", "y", "y'"]
    xlabels, ylabels = labels[:-1], labels[1:]
    set_labels(axes[-1, :], xlabels, 'xlabel')
    set_labels(axes[:, 0], ylabels, 'ylabel')

    # Coordinate keys in the same order as the axis labels above, so that
    # column j really shows xlabels[j] and row i shows ylabels[i].
    yvars = ['xp', 'y', 'yp']
    xvars = ['x', 'xp', 'y']

    # Plot eigenvectors and their trajectories (lower triangle only).
    for i in range(3):
        for j in range(i + 1):
            myplt.eigvec_trajectory(axes[i, j], M, xvars[j], yvars[i], s=7, lw=1)

    # Zoom out a bit: symmetric limits about the origin, 20% beyond the data.
    scale = 1.2
    for i in range(3):
        _, ymax = axes[i, 0].get_ylim()
        _, xmax = axes[-1, i].get_xlim()
        axes[i, :].format(ylim=(-scale*ymax, scale*ymax))
        axes[:, i].format(xlim=(-scale*xmax, scale*xmax))

    save('eigvecs', dirs['env'], dpi=dpi)

## Test bunch

In [None]:
if files_exist['testbunch_coords']:
    
    # Test-particle coordinates, shape (nframes, ntestparts, ndims).
    test_coords = np.load(files['testbunch_coords'])
    nframes, ntestparts, ndims = test_coords.shape
    print('nparts, nframes = {}, {}'.format(ntestparts, nframes))
    
    # One DataFrame per frame, one row per particle. All particles in a frame
    # sit at the same lattice position, so `s` and `s/L` are scalars broadcast
    # to every row. Assigning whole position arrays would align on the row
    # index, which counts particles, not positions.
    # Assumes one frame per entry of `positions` (as the next cell also does).
    test_cdfs = []
    for frame_idx, X in enumerate(test_coords):
        cdf = pd.DataFrame(X, columns=['x','xp','y','yp'])
        cdf['s'] = positions[frame_idx]
        cdf['s/L'] = positions_normed[frame_idx]
        test_cdfs.append(cdf)

In [None]:
if files_exist['testbunch_coords']:
    
    # Trajectories of a subset of test particles, overlaid with +/- 2 times the
    # rms beam size from the envelope model.
    fig, axes = plot.subplots(nrows=2, sharey=False, figsize=(3, 4))
    # The panels show particle coordinates, not rms sizes, so label them x / y.
    # Assumes test coordinates are in mm, like the rms sizes on the same axes.
    set_labels(axes, [r'$x$ [mm]', r'$y$ [mm]'], 'ylabel')
    
    # Every 4th test particle.
    for part_idx in range(0, ntestparts, 4):
        X = pd.DataFrame(test_coords[:, part_idx, :], columns=['x','xp','y','yp'])
        X['s/L'] = positions_normed
        X[['s/L','x']].plot('s/L', color='k', lw=1, legend=False, ax=axes[0])
        X[['s/L','y']].plot('s/L', color='k', lw=1, legend=False, ax=axes[1])

    x_env = 2 * env_stats.moments['x_rms']
    y_env = 2 * env_stats.moments['y_rms']
    for ax, env, c in zip(axes, (x_env, y_env), ('blue8','orange6')):
        ax.plot(positions_normed, +env, c=c)
        ax.plot(positions_normed, -env, c=c)
        
    save('testbunch_beamsize', dirs['env'], dpi=dpi)

# Distribution

In [None]:
if files_exist['bunch_coords']:
    coords = np.load(files['bunch_coords'])
    print('Bunch coordinates:')
    # Only the first two dimensions fill the two placeholders.
    print('nframes, nparts = {}, {}'.format(*coords.shape))

if files_exist['bunch_moments']:
    # Bunch statistics rebuilt from the saved moments.
    moments = np.load(files['bunch_moments'])
    bunch_stats = ba.Stats(mode)
    bunch_stats.read_moments(moments)
    twiss_cols = {'mux': 'bx', 'muy': 'by'}
    for mu_col, beta_col in twiss_cols.items():
        bunch_stats.twiss2D[mu_col] = get_phase_adv(bunch_stats.twiss2D[beta_col], positions, 'deg')
    # Same position columns as the envelope tables.
    for df in bunch_stats.dfs():
        df['s'] = positions
        df['s/L'] = positions_normed

## Twiss parameters 

### 2D Twiss

In [None]:
if files_exist['bunch_moments']:
    # Bunch: 2D Twiss beta, alpha and emittance, one panel each.
    fig, axes = setup_figure(2)
    column_pairs = (['bx', 'by'], ['ax', 'ay'], ['ex', 'ey'])
    for ax, cols in zip(axes, column_pairs):
        bunch_stats.twiss2D[['s/L'] + cols].plot('s/L', ax=ax, **plt_kws)
    set_labels(axes, [r'$\beta$ [m]', r'$\alpha$ [rad]', r'$\varepsilon$ [mm $\cdot$ mrad]'], 'ylabel')
    save('twiss2D', dirs['bunch'], dpi=dpi)

### Emittance 

In [None]:
if files_exist['bunch_moments']:
    # Bunch: 2D emittances (ex, ey) and 4D emittances (e1, e2) on one axis.
    fig, ax = plot.subplots(figsize=(4.5, 2.5))
    twiss2D_cols = ['s/L', 'ex', 'ey']
    twiss4D_cols = ['s/L', 'e1', 'e2']
    bunch_stats.twiss2D[twiss2D_cols].plot('s/L', ax=ax, **plt_kws)
    bunch_stats.twiss4D[twiss4D_cols].plot('s/L', ax=ax, **plt_kws)
    ax.format(ylabel=r'$\varepsilon$ [mm $\cdot$ mrad]')
    legend_labels = [r'$\varepsilon_x$', r'$\varepsilon_y$', r'$\varepsilon_1$', r'$\varepsilon_2$']
    ax.legend(labels=legend_labels, ncols=1, loc=(1.01, 0))
    save('emittance', dirs['bunch'], dpi=dpi)

### Phase advance 

In [None]:
if files_exist['bunch_moments']:
    # Bunch: cumulative phase advance in each plane (degrees).
    fig, ax = setup_figure(1)
    phase_table = bunch_stats.twiss2D[['s/L', 'mux', 'muy']]
    phase_table.plot('s/L', ax=ax, **plt_kws)
    ax.format(ylabel='Phase advance', yformatter='deg')
    save('phase_adv', dirs['bunch'], dpi=dpi)

## Moments 

In [None]:
if files_exist['bunch_moments']:
    # Bunch: rms beam size in each plane.
    fig, ax = setup_figure(1)
    size_table = bunch_stats.moments[['s/L', 'x_rms', 'y_rms']]
    size_table.plot('s/L', ax=ax, **plt_kws)
    ax.format(ylabel='Beam size [mm]')
    save('beamsize', dirs['bunch'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # Bunch: rms beam divergence in each plane.
    fig, ax = setup_figure(1)
    div_table = bunch_stats.moments[['s/L', 'xp_rms', 'yp_rms']]
    div_table.plot('s/L', ax=ax, **plt_kws)
    ax.format(ylabel='Beam divergence [mrad]')
    save('beamdiv', dirs['bunch'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # Bunch: every second moment vs. position, on the lower triangle of a
    # 4x4 grid (coordinates ordered x, x', y, y').
    fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6), 
                              spany=False, aligny=True)
    myplt.make_lower_triangular(axes)
    axes.format(suptitle='Transverse moments', titleborder=True)

    for i in range(4):
        for j in range(i + 1):
            ax = axes[i, j]
            col = moment_label(i, j)
            bunch_stats.moments[['s/L',col]].plot('s/L', ax=ax, color='k', **plt_kws)
            ax.format(title=moment_label_string(i, j))

    # Units of each moment, column by column (rows i >= j). Every unit is bracketed.
    set_labels(axes[0:, 0], [r'[mm$^2$]', r'[mm$\cdot$mrad]', r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
    set_labels(axes[1:, 1], [r'[mrad$^2$]', r'[mm$\cdot$mrad]', r'[mrad$^2$]'], 'ylabel')
    set_labels(axes[2:, 2], [r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
    set_labels(axes[3:, 3], [r'[mrad$^2$]'], 'ylabel')
    save('all_moments', dirs['bunch'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # Bunch: correlation coefficients on the lower triangle of a 4x4 grid.
    fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6), 
                              spany=False, aligny=True)
    myplt.make_lower_triangular(axes)
    axes.format(suptitle='Transverse correlations', titleborder=True)

    for i in range(4):
        for j in range(i + 1):
            ax = axes[i, j]
            col = moment_label(i, j)
            bunch_stats.corr[['s/L', col]].plot('s/L', ax=ax, color='k', **plt_kws)
            # Use the formatted label for the title, as every other moment grid does.
            ax.format(title=moment_label_string(i, j))

    save('all_correlations', dirs['bunch'], dpi=dpi)

## Real space orientation

In [None]:
if files_exist['bunch_moments']:
    # Bunch: tilt angle, ellipse axes and relative area of the beam ellipse.
    fig, axes = setup_figure(2)
    angle_ax, radii_ax, area_ax = axes[0], axes[1], axes[2]
    bunch_stats.realspace[['s/L', 'angle']].plot('s/L', ax=angle_ax, color='k', **plt_kws)
    bunch_stats.realspace[['s/L', 'cx', 'cy']].plot('s/L', ax=radii_ax, **plt_kws)
    bunch_stats.realspace[['s/L', 'area_rel']].plot('s/L', ax=area_ax, color='k', **plt_kws)
    set_labels(axes, ['tilt angle', 'ellipse axes [mm]', 'area [frac. change]'], 'ylabel')
    angle_ax.format(yformatter='deg')
    save('beam_dims', dirs['bunch'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # Bunch: four stacked summary panels sharing the s / L axis.
    fig, axes = plot.subplots(nrows=4, ncols=1, figsize=(0.8*width, 2.5*height),
                              spany=False, aligny=True)
    # (table, columns, extra plot kwargs) for each panel, top to bottom.
    panels = [
        (bunch_stats.moments, ['x_rms', 'y_rms'], {}),
        (bunch_stats.twiss2D, ['ex_frac', 'ey_frac'], {}),
        (bunch_stats.realspace, ['angle'], {'color': 'k'}),
        (bunch_stats.twiss2D, ['mux', 'muy'], {}),
    ]
    for ax, (table, cols, extra) in zip(axes, panels):
        table[['s/L'] + cols].plot('s/L', ax=ax, **extra, **plt_kws)
    ylabels = ['beam size [mm]', 'emittance ratio', 'tilt angle [deg]', 'phase adv. [deg]']
    set_labels(axes, ylabels, 'ylabel')
    save('vert', dirs['bunch'], dpi=dpi)

## Phase space projections 

In [None]:
if files_exist['bunch_coords']:
    # Corner plots of the first (index 0) and last (index -1) saved frames.
    for i, name in zip((0, -1), ('Initial', 'Final')):
        axes = myplt.corner(coords[i], text=name, figsize=5, pad=0.25)
        save(name, dirs['bunch'], dpi=dpi)

In [None]:
if animate and files_exist['bunch_coords']:
    # Bunch: animated corner plot over all saved frames, with histogram diagonals.
    anim_kws = dict(skip=skip, figsize=6, diag_kind='hist', fps=fps, pad=0.25,
                    text_fmt='s / L = {:.2f}', text_vals=positions_normed)
    anim = myanim.corner(coords, **anim_kws)
    play(anim)

# Comparison 

In [None]:
if files_exist['bunch_moments']:
    # Envelope model: default line width and no markers, in steelblue.
    plt_kws_env = {'lw': None, 'marker': None, 'markersize': None,
                   'color': 'steelblue', 'legend': False}
    # Bunch: small red 'x' markers with no connecting line.
    plt_kws_bunch = {'lw': 0, 'marker': 'x', 'markersize': 1,
                     'color': 'red', 'legend': False}
    # Stats objects (not DataFrames, despite the name) paired by index with their styles.
    dataframes = [env_stats, bunch_stats]
    kws_list = [plt_kws_env, plt_kws_bunch]

## Moments 

In [None]:
if files_exist['bunch_moments']:
    # rms beam size: envelope model ('theory') vs. bunch ('calc'), one plane per panel.
    fig, axes = setup_figure(4)
    for key, ax in zip(('x_rms', 'y_rms'), axes):
        for stats, kws in zip(dataframes, kws_list):
            stats.moments[['s/L', key]].plot('s/L', ax=ax, **kws)
    axes.format(ylabel='[mm]')
    titles = [r'$\sqrt{\langle{x^2}\rangle}$', r'$\sqrt{\langle{y^2}\rangle}$']
    set_labels(axes, titles, 'title')
    axes[1].legend(labels=['theory', 'calc'], ncols=1, loc=(1.02, 0), fontsize='small')
    save('beamsize', dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # rms divergence: envelope model ('theory') vs. bunch ('calc'), one plane per panel.
    fig, axes = setup_figure(4)
    for ax, key in zip(axes, ('xp_rms', 'yp_rms')):
        for df, kws in zip(dataframes, kws_list):
            df.moments[['s/L', key]].plot('s/L', ax=ax, **kws)
    # Divergences are in mrad, matching the 'Beam divergence [mrad]' plots above.
    axes.format(ylabel='[mrad]')
    set_labels(axes, [r"$\sqrt{\langle{x'^2}\rangle}$", r"$\sqrt{\langle{y'^2}\rangle}$"], 'title')
    axes[1].legend(labels=['theory', 'calc'], ncols=1, loc=(1.02, 0), fontsize='small')
    save('beamdiv', dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # x-y correlation coefficient: envelope model vs. bunch.
    fig, ax = plot.subplots(figsize=(1.25*width, height))
    for df, kws in zip(dataframes, kws_list):
        df.corr[['s/L','xy']].plot('s/L', ax=ax, **kws)
    # The horizontal axis is the normalized position s/L, not a turn count.
    ax.format(title=r"$x$-$y$ corr. coef.", xlabel='s / L')
    ax.legend(labels=['theory', 'calc'], ncols=1, loc=(1.02, 0), fontsize='small')
    save('xy_corr', dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    
    # Envelope model vs. bunch: every second moment on the lower triangle of
    # a 4x4 grid (coordinates ordered x, x', y, y').
    fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6), 
                              spany=False, aligny=True)
    myplt.make_lower_triangular(axes)
    axes.format(suptitle='Transverse moments', titleborder=True)

    for df, kws in zip(dataframes, kws_list):
        for i in range(4):
            for j in range(i + 1):
                ax = axes[i, j]
                col = moment_label(i, j)
                df.moments[['s/L', col]].plot('s/L', ax=ax, **kws)
                ax.format(title=moment_label_string(i, j))

    # Units of each moment, column by column (rows i >= j). Every unit is bracketed.
    set_labels(axes[0:, 0], [r'[mm$^2$]', r'[mm$\cdot$mrad]', r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
    set_labels(axes[1:, 1], [r'[mrad$^2$]', r'[mm$\cdot$mrad]', r'[mrad$^2$]'], 'ylabel')
    set_labels(axes[2:, 2], [r'[mm$^2$]', r'[mm$\cdot$mrad]'], 'ylabel')
    set_labels(axes[3:, 3], [r'[mrad$^2$]'], 'ylabel')
    save('all_moments', dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:

    # Envelope model vs. bunch: correlation coefficients on a 4x4 lower triangle.
    fig, axes = plot.subplots(nrows=4, ncols=4, sharey=False, figsize=(8, 6),
                              spany=False, aligny=True)
    myplt.make_lower_triangular(axes)
    axes.format(suptitle='Transverse correlations', titleborder=True)

    lower_pairs = [(i, j) for i in range(4) for j in range(i + 1)]
    for df, kws in zip(dataframes, kws_list):
        for i, j in lower_pairs:
            ax = axes[i, j]
            df.corr[['s/L', moment_label(i, j)]].plot('s/L', ax=ax, **kws)
            ax.format(title=moment_label_string(i, j))

    save('all_correlations', dirs['comparison'], dpi=dpi)

## Twiss 

In [None]:
if files_exist['bunch_moments']:
    fig, axes = setup_figure(3)
    # Rows: beta, alpha, emittance; columns: horizontal, vertical.
    twiss_grid = (('bx', 'by'), ('ax', 'ay'), ('ex', 'ey'))
    for df, kws in zip(dataframes, kws_list):
        for row, pair in enumerate(twiss_grid):
            for column, key in enumerate(pair):
                df.twiss2D[['s/L', key]].plot('s/L', ax=axes[row, column], **kws)
    axes.format(collabels=['Horizontal', 'Vertical'])
    set_labels(axes[:, 0], [r'$\beta$ [m]', r'$\alpha$ [rad]', r'$\varepsilon$ [mm $\cdot$ mrad]'], 'ylabel')
    save('twiss', dirs['comparison'], dpi=dpi)

## Real space orientation

In [None]:
if files_exist['bunch_moments']:
    # Envelope model vs. bunch: ellipse tilt angle (top) and area (bottom).
    fig, axes = plot.subplots(nrows=2, figsize=(width, 1.5*height), spany=False, aligny=True)
    for df, kws in zip(dataframes, kws_list):
        for ax, key in zip(axes, ('angle', 'area')):
            df.realspace[['s/L', key]].plot('s/L', ax=ax, **kws)
    set_labels(axes, ['tilt angle', r'area [mm$^2$]'], 'ylabel')
    axes[0].format(yformatter='deg')
    save('beam_dims', dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_moments']:
    # Envelope model vs. bunch: ellipse axes cx (left) and cy (right).
    fig, axes = setup_figure(4)
    for key, ax in zip(('cx', 'cy'), axes):
        for stats, kws in zip(dataframes, kws_list):
            stats.realspace[['s/L', key]].plot('s/L', ax=ax, **kws)
    axes.format(ylabel='[mm]')
    set_labels(axes, [r"$c_x$", r"$c_y$"], 'title')
    axes[1].legend(labels=['theory', 'calc'], ncols=1, loc=(1.02, 0), fontsize='small')
    save('radii', dirs['comparison'], dpi=dpi)

## Phase space projections

In [None]:
if files_exist['bunch_coords']:
    # Bunch corner plots with the envelope overlaid, at the first and last frames.
    for frame_idx, label in ((0, 'Initial'), (-1, 'Final')):
        axes = myplt.corner(
            coords[frame_idx], env_params_list[frame_idx],
            text=label, diag_kind='none', figsize=4, pad=0.25,
        )
        save(label, dirs['comparison'], dpi=dpi)

In [None]:
if files_exist['bunch_coords'] and animate:
    # Animated bunch corner plot with the envelope overlaid (every 5th frame).
    anim_kws = dict(skip=5, figsize=5, diag_kind='none', fps=fps, pad=0.25,
                    text_fmt='s / L = {:.2f}', text_vals=positions_normed)
    anim = myanim.corner(coords, env_params_list, **anim_kws)
    play(anim)