# Matched envelope comparison 

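This notebook generates plots used in [this paper](https://doi.org/10.1103/PhysRevAccelBeams.24.044201). For four FODO lattice variants (standard, split tunes, skew quadrupoles, and a solenoid insert), it compares the two matched-envelope solutions at several beam perveances, using data previously written to `_output/data/`.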

## Setup

### Imports 

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import proplot as pplt

sys.path.append('/Users/46h/Research/') 
from scdist.tools import animation as myanim
from scdist.tools import beam_analysis as ba
from scdist.tools import plotting as myplt
from scdist.tools import utils
from scdist.tools.ap_utils import get_phase_adv
from scdist.tools.plotting import set_labels
from scdist.tools.utils import file_exists
from scdist.tools.utils import play
from scdist.tools.utils import show
from scdist.tools.plot_utils import moment_label
from scdist.tools.plot_utils import moment_label_string
from scdist.tools.plot_utils import sci_notation
from scdist.tools.plot_utils import PHASE_SPACE_LABELS_UNITS

### Settings

In [None]:
animate = False
# Shared keyword arguments for the pandas line plots of quantities vs. s / L.
plt_kws = {'legend': False, 'xlabel': 's / L'}

# Global proplot style, applied in the order listed.
for _rc_key, _rc_value in {
    'animation.html': 'jshtml',
    'cycle': 'default',
    'figure.facecolor': 'white',
    'grid': False,
    'grid.alpha': 0.04,
    'savefig.dpi': 'figure',
}.items():
    pplt.rc[_rc_key] = _rc_value

# Colormap and the sub-range of it handed to the plotting helpers
# (used for the per-perveance line colors).
cmap = pplt.Colormap('fire_r')
cmap_range = (0, 0.9)
# Resolution for saved figures.
dpi = 500

## Read Data

In [None]:
# Lattice names, one per line. Close the file deterministically and skip blank
# lines, which would otherwise produce keys such as '_mode1'.
with open('_output/data/latnames.txt', 'r') as latnames_file:
    latnames = [line.rstrip() for line in latnames_file if line.strip()]
# perveances[i] labels the i-th tracked envelope of every lattice/mode.
perveances = np.load('_output/data/perveances.npy')

# Containers filled by the loading loop, keyed by '<latname>_mode<1|2>'.
tracked_params_dict = dict()
transfer_mats_dict = dict()
stats_dict = dict()
positions_dict = dict()

In [None]:
for latname in latnames:
    # Positions depend only on the lattice, not on the mode, so load them once.
    positions_all = np.load('_output/data/positions_{}.npy'.format(latname))
    for mode in (1, 2):
        key = latname + '_mode{}'.format(mode)
        print('Loading', key, '...')

        # Load the tracked envelope parameters and transfer matrices; the first
        # axis of each array runs over the perveances.
        tracked_params_list = np.load('_output/data/tracked_params_list_{}_{}.npy'.format(latname, mode))
        tracked_params_list *= 1000.0  # convert from m to mm
        transfer_mat_list = np.load('_output/data/transfer_mats_{}_{}.npy'.format(latname, mode))
        # perveances[i] is paired with tracked_params_list[i] below and in the
        # plots; fail loudly if the saved files are out of sync.
        assert len(tracked_params_list) == len(perveances), (
            '{}: {} tracked envelopes but {} perveances'.format(
                key, len(tracked_params_list), len(perveances)))

        # There are three nodes at the same position at the end of the
        # lattice (cause unknown), so strip the last two.
        positions = positions_all[:-2]
        tracked_params_list = tracked_params_list[:, :-2, :]

        # For the skew lattice, the last position gives NaN for the envelope
        # parameters -- presumably the beam area becomes exactly zero there.
        if key == 'fodo_skew_mode2':
            positions = positions[:-1]
            tracked_params_list = tracked_params_list[:, :-1, :]

        # Compute the beam statistics for each set of envelope parameters.
        stats_list = []
        for i, tracked_params in enumerate(tracked_params_list):
            stats = ba.BeamStats(mode)
            stats.read_env(tracked_params)
            beta_x = stats.twiss2D.loc[:, 'beta_x'].values
            beta_y = stats.twiss2D.loc[:, 'beta_y'].values
            stats.twiss2D['mux'] = get_phase_adv(beta_x, positions, units='deg')
            stats.twiss2D['muy'] = get_phase_adv(beta_y, positions, units='deg')

            # nu is undefined when <x^2> or <y^2> are zero; set it to 90 degrees in this case
            if latname.startswith('fodo_split') and perveances[i] == 0:
                stats.twiss4D['nu'] = 90.0

            # Attach the longitudinal position (absolute and normalized by the
            # last position) to every statistics table.
            for df in stats.dfs():
                df['s'] = positions
                df['s/L'] = positions / positions[-1]
            stats_list.append(stats)

        stats_dict[key] = stats_list
        tracked_params_dict[key] = tracked_params_list
        transfer_mats_dict[key] = transfer_mat_list
        positions_dict[key] = positions

print('Done.')

## Plotting

In [None]:
# Color cycle built from `cmap` / `cmap_range` with one color per perveance.
_n_colors = len(perveances)
_cycler = myplt.colorcycle(cmap, _n_colors, cmap_range)

### Phase space projections at lattice entrance

In [None]:
for key, tracked_params_list in tracked_params_dict.items():
    print('Plotting', key, '...')
    # Corner plot of the envelope projections at the lattice entrance
    # (position index 0), one line per perveance.
    axes = myplt.corner_env(tracked_params_list[:, 0, :], figsize=(5, 5),
                            autolim_kws=dict(pad=0.25),
                            cmap=cmap, cmap_range=cmap_range, lw=1)
    axes[0, 1].annotate('s = 0', xy=(0.5, 0.5), xycoords='axes fraction')
    # Resolution comes from the Settings cell instead of a duplicated literal.
    plt.savefig('_output/figures/corner_vs_sc_{}.png'.format(key), dpi=dpi, facecolor='white')
    plt.close()
print('Done.')

### Twiss parameters within lattice

In [None]:
# (solution 1, solution 2) key pairs, one pair per lattice variant, in plotting order.
grouped_keys = [
    (f'{lattice}_mode1', f'{lattice}_mode2')
    for lattice in ('fodo', 'fodo_split', 'fodo_skew', 'fodo_sol')
]

In [None]:
for keys in grouped_keys:
    fig, axes = pplt.subplots(nrows=3, ncols=2, figsize=(6.5, 6), spany=False, aligny=True)
    for ax in axes:
        ax.set_prop_cycle(_cycler)
    lw = 1.1
    for j, key in enumerate(keys):
        print('Plotting', key, '...')
        stats_list = stats_dict[key]
        # Solid lines: x beam size, x emittance, and horizontal phase advance.
        for stats in stats_list:
            stats.moments[['s/L', 'x_rms']].plot('s/L', ax=axes[0, j], lw=lw, **plt_kws)
            stats.twiss2D[['s/L', 'eps_x']].plot('s/L', ax=axes[1, j], lw=lw, **plt_kws)
            stats.twiss2D[['s/L', 'mux']].plot('s/L', ax=axes[2, j], lw=lw, **plt_kws)
        # Dashed lines drawn behind: the y counterparts. When there is one
        # stats entry per cycler color, these restart at the first color, so
        # each perveance keeps the same color in both planes.
        for stats in stats_list:
            stats.moments[['s/L', 'y_rms']].plot('s/L', ax=axes[0, j], ls='--', zorder=0, legend=False, lw=lw)
            stats.twiss2D[['s/L', 'eps_y']].plot('s/L', ax=axes[1, j], ls='--', zorder=0, legend=False, lw=lw)
            stats.twiss2D[['s/L', 'muy']].plot('s/L', ax=axes[2, j], ls='--', zorder=0, legend=False, lw=lw)
    axes.format(grid=False, toplabels=['Solution 1', 'Solution 2'])
    # Row 2 shows the phase advances mux / muy (computed in degrees when the
    # data were loaded), not nu, so label it accordingly.
    set_labels(axes[:, 0], ['Beam size [mm]', r'Emittance [$mm \cdot mrad$]', 'Phase advance [deg]'], 'ylabel')
    cbar = fig.colorbar(cmap, width=0.075, ticks=[0], label='Perveance', pad=3)
    cbar.set_label('Perveance', labelpad=-5)
    # File tag is the lattice name shared by both keys, e.g. 'fodo_split'.
    tag = keys[0].rsplit('_mode', 1)[0]
    plt.savefig('_output/figures/matched_traj_{}.png'.format(tag), dpi=dpi, facecolor='white')
    plt.close(fig)
print('Done.')

## Combined plot 

In [None]:
# Express the perveances as multiples of a power of ten for colorbar labels.
# Base the exponent on the largest magnitude rather than on the last entry,
# and fall back to 10**0 when every perveance is zero (log10(0) is -inf and
# int() would raise).
max_perveance = np.max(np.abs(perveances))
exponent = int(np.floor(np.log10(max_perveance))) if max_perveance > 0 else 0
perveances_reduced = perveances / 10**exponent
cbar_labels = ['{:.0f}'.format(Q) for Q in perveances_reduced]

In [None]:
def plot_combined(key, lw=1.0, label_kws=None, 
                  pad=0.1,
                  cmap='rocket', cmap_range=(0.0, 1.0),
                  dashed_lw_reduction=0.8):
    """Draw the entrance corner plot and the along-lattice evolution for one key.

    The left 3x3 block (lower triangle only) shows the phase-space projections
    of the tracked envelopes at s = 0. The right column shows beam size,
    emittance, and nu / pi versus s / L. Solid lines are the x quantities;
    dashed, thinner lines drawn behind are the y quantities. One line is drawn
    per perveance, and a colorbar of the reduced perveance is added.

    Parameters
    ----------
    key : str
        Key into `tracked_params_dict` / `stats_dict`, e.g. 'fodo_mode1'.
    lw : float
        Width of the solid lines and of the corner-plot lines (also used as
        the marker size of the flat-beam dot).
    label_kws : dict, optional
        Extra keyword arguments for the axis labels and the 's = 0' annotation.
    pad : float
        Padding passed to the corner plot's automatic limits (`autolim_kws`).
    cmap : str or colormap
        Colormap, or a proplot colormap name, for the per-perveance colors.
    cmap_range : tuple of float
        Colormap range passed to `myplt.colorcycle`.
    dashed_lw_reduction : float
        Dashed lines are drawn with width `dashed_lw_reduction * lw`.

    Returns
    -------
    axes
        The proplot subplot grid (3 rows x 4 columns).

    Reads the notebook globals `tracked_params_dict`, `stats_dict`,
    `perveances`, `perveances_reduced`, `exponent`, and `plt_kws`.
    """
    if label_kws is None:
        label_kws = dict()
    if isinstance(cmap, str):
        cmap = pplt.Colormap(cmap)

    # Columns 0-2: corner plot. Column 3 (wider, after a larger gap):
    # evolution along the lattice.
    fig, axes = pplt.subplots(
        nrows=3, ncols=4, share=False, figwidth=7.35,
        wspace=[1.0, 1.0, 6.0], 
        hspace=[1.0, 1.0],
        width_ratios=[1.0, 1.0, 1.0, 1.5],
        aligny=True, alignx=True,
    )
    # One color per perveance on every axis.
    for ax in axes:
        ax.set_prop_cycle(myplt.colorcycle(cmap, len(perveances), cmap_range))

    # Corner plot: keep only the lower triangle and hide inner tick labels.
    caxes = axes[:, :3]
    for i in range(3):
        for j in range(3):
            ax = caxes[i, j]
            if j > i:
                ax.axis('off')
                continue
            if j > 0:
                ax.format(yticklabels=[])
            if i < 2:
                ax.format(xticklabels=[])
    myplt.despine(caxes, ('top', 'right'))

    for i in range(3):
        caxes[i, 0].set_ylabel(PHASE_SPACE_LABELS_UNITS[i + 1], **label_kws)
        caxes[-1, i].set_xlabel(PHASE_SPACE_LABELS_UNITS[i], **label_kws)

    # Envelopes at the lattice entrance (position index 0), one per perveance.
    myplt.corner_env(
        tracked_params_dict[key][:, 0, :], 
        axes=caxes, 
        use_existing_limits=False, autolim_kws=dict(pad=pad),
        lw=lw,
    )
    caxes[0, 1].annotate('s = 0', xy=(0.5, 0.5), xycoords='axes fraction', 
                         horizontalalignment='center',
                         **label_kws)
    # If the beam is flat, its size is zero in either x-x' or y-y'. Plot a dot
    # at the origin of that panel to indicate this. Keys not listed here get
    # no dot.
    flat_beam_axes = {
        'fodo_split_mode1': axes[2, 2],
        'fodo_split_mode2': axes[0, 0],
    }
    if key in flat_beam_axes:
        flat_beam_axes[key].plot(0.0, 0.0, color=cmap(0), marker='.', ms=lw)

    # Right column: quantities versus s / L.
    paxes = axes[:, -1]
    paxes[:-1].format(xticklabels=[])
    for ax, label in zip(paxes, ['Beam size [mm]', r'Emittance [mm mrad]', r'$\nu$ / $\pi$']):
        ax.set_ylabel(label, **label_kws)
    paxes[-1].set_xlabel('s / L', **label_kws)

    # Solid lines: x beam size, x emittance, and nu (degrees / 180).
    for stats in stats_dict[key]:
        stats.moments[['s/L','x_rms']].plot('s/L', ax=paxes[0], lw=lw, **plt_kws)
        stats.twiss2D[['s/L','eps_x']].plot('s/L', ax=paxes[1], lw=lw, **plt_kws)
        paxes[2].plot(
            stats.twiss4D.loc[:, 's/L'].values,
            stats.twiss4D.loc[:, 'nu'].values / 180.0,
            lw=lw,
        )
    # Dashed, thinner lines drawn behind: y beam size and y emittance.
    ls = 'dashed'
    lw_ = dashed_lw_reduction * lw
    for stats in stats_dict[key]:
        stats.moments[['s/L','y_rms']].plot('s/L', ax=paxes[0], ls=ls, zorder=0, legend=False, lw=lw_)
        stats.twiss2D[['s/L','eps_y']].plot('s/L', ax=paxes[1], ls=ls, zorder=0, legend=False, lw=lw_)

    # Colorbar in reduced units; the label uses the same power of ten that
    # `perveances_reduced` was scaled by, instead of a hard-coded 10^-5.
    norm = matplotlib.colors.Normalize(vmin=perveances_reduced[0], vmax=perveances_reduced[-1])
    fig.colorbar(matplotlib.cm.ScalarMappable(norm, cmap), 
                 width=0.09,
                 label=r'$10^{{{}}}$ Q'.format(exponent))
    return axes

In [None]:
# Colormap settings for the combined figures only. Distinct names keep this
# cell from overwriting the global `cmap` / `cmap_range` set in the Settings
# cell (which other cells read), so results do not depend on execution order.
combined_cmap = pplt.Colormap('fire_r')
combined_cmap_range = (0.0, 0.875)
for keys in grouped_keys:
    for key in keys:
        print('Plotting', key, '...')
        plot_combined(key, lw=0.95, cmap=combined_cmap, cmap_range=combined_cmap_range)
        plt.savefig('_output/figures/matched_vs_sc_{}.png'.format(key), dpi=700)
        plt.show()

### Animation

In [None]:
# Optional animation of the corner plot along the lattice for 'fodo_mode1'.
# Disabled: every line of this cell is commented out, so setting
# `animate = True` in the Settings cell has no effect until these lines are
# uncommented.
# NOTE(review): if re-enabled, this reads the global `cmap` / `cmap_range`,
# which other cells in this notebook may assign different values -- confirm
# the intended range. `text_vals` presumably gives one s / L label per
# tracked position; confirm against myanim.corner_env.
# if animate:
#     key = 'fodo_mode1'
#     tracked_params_list = tracked_params_dict[key]
#     anim = myanim.corner_env(tracked_params_list, skip=99, 
#                              cmap=cmap, cmap_range=cmap_range,
#                              text_vals=np.linspace(0, 1, len(tracked_params_list[0])), 
#                              text_fmt='s / L = {:.2f}')
#     play(anim)

## Effective transfer matrix 

### Eigenvalues

In [None]:
nrows, ncols = 4, 3  # rows: lattices (solution 1 only); columns: first `ncols` perveances
s = 17  # NOTE(review): not used in this cell
fontsize = 7
marker = 'o'
split = 0.05  # vertical offset of the mu_1 / mu_2 annotations from the panel center

fig, axes = pplt.subplots(nrows=nrows, ncols=ncols, figwidth=5.75, hspace=0, wspace=0)
myplt.set_labels(axes[0, :], ['Q = {:.2e}'.format(Q) for Q in perveances[:ncols]], 'title')

# Plot unit circle in background
for ax in axes:
    myplt.unit_circle(ax, color='black', lw=0.75, alpha=0.25, zorder=0)

# Plot eigenvalues of the solution-1 transfer matrices, one row per lattice.
# Match the full '_mode<N>' suffix rather than only the last character.
keys = transfer_mats_dict.keys()
keys_solution1 = [key for key in keys if key.endswith('_mode1')]
keys_solution2 = [key for key in keys if key.endswith('_mode2')]
assert len(keys_solution1) <= nrows, 'more lattices than rows in the eigenvalue figure'
for i, key in enumerate(keys_solution1):
    transfer_mat_list = transfer_mats_dict[key]
    for ax, M in zip(axes[i, :], transfer_mat_list[:ncols]):
        eigvals, eigvecs = np.linalg.eig(M)
        myplt.eigvals_complex_plane(ax, eigvals, colors=('r','b'), 
                                    marker=marker, zorder=1, legend=False, ms=10,)
        # Phase advances from mu = arccos(Re(lambda)) for eigenvalues
        # exp(+-i*mu) on the unit circle. Assumes the four eigenvalues come
        # back as two adjacent conjugate pairs; arccos gives NaN if |Re(lambda)| > 1.
        mu1, _, mu2, _ = np.degrees(np.arccos(eigvals.real))
        ax.annotate(r'$\mu_1 = {:.2f}\degree$'.format(mu1), xy=(0.5, 0.48 + split), 
                    xycoords='axes fraction', horizontalalignment='center', fontsize=fontsize)
        ax.annotate(r'$\mu_2 = {:.2f}\degree$'.format(mu2), xy=(0.5, 0.48 - split), 
                    xycoords='axes fraction', horizontalalignment='center', fontsize=fontsize)

# Formatting. NOTE(review): the row labels assume latnames.txt lists the
# lattices in this order -- confirm.
scale = 1.5
axes.format(
    grid=False,
    xlim=(-scale, scale), ylim=(-scale, scale), xlabel='Real', ylabel='Imaginary',
    xticks=[-1, 0, 1], yticks=[-1, 0, 1],
    leftlabels=['FODO', 'FODO\n(split tunes)', 'FODO\n(skew quads)', 'FODO\n(solenoid insert)'], 
    leftlabels_kw=dict(rotation='horizontal', fontsize='medium'),
    xminorlocator='null', yminorlocator='null',
)
plt.savefig('_output/figures/eigvals.png', dpi=dpi, facecolor='w')
plt.show()