# Setup

## Imports 

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import proplot as plot

sys.path.append('/Users/46h/Research/code/accphys') 
from tools import beam_analysis as ba
from tools import plotting as myplt
from tools.plotting import save, set_labels
from tools import animation as myanim
from tools import utils
from tools.utils import show, play, file_exists
from tools.accphys_utils import get_phase_adv
from tools.plot_utils import moment_label, moment_label_string

## Settings

In [None]:
# Keyword arguments for pandas plots against the normalized coordinate s / L
# (not referenced by the plotting cells in this section — confirm later use).
plt_kws = dict(legend=False, xlabel='s / L')

# Global proplot rc settings, applied one key at a time in this fixed order.
# NOTE(review): the order is kept deliberately in case assigning 'style'
# resets other rc keys — confirm before reordering.
for _rc_key, _rc_value in (
    ('figure.facecolor', 'white'),
    ('grid.alpha', 0.04),
    ('style', None),
    ('savefig.dpi', 'figure'),
    ('animation.html', 'jshtml'),
):
    plot.rc[_rc_key] = _rc_value
del _rc_key, _rc_value

# Resolution handed explicitly to savefig when writing figures.
dpi = 500

# Colormap and its normalized range, presumably for later plots — confirm usage.
cmap = plot.Colormap('Blues')
cmap_range = (0, 1)

# Load data

In [None]:
# Longitudinal positions, plus the same positions as a fraction of the last one.
positions = np.load('_output/data/positions.npy')
positions_normed = positions / positions[-1]
costs = np.load('_output/data/costs.npy')

# Saved tracked-parameter arrays, one entry per data set:
# 'tbt' (turn-by-turn) and 'sdep' (s-dependent).
tracked_params_lists = {
    'tbt': np.load('_output/data/tbt_params_list.npy'),
    'sdep': np.load('_output/data/sdep_params_list.npy'),
}


def _stats_from_params(tracked_params, s_dependent):
    """Build a mode-1 ``ba.Stats`` object from one saved parameter array.

    When ``s_dependent`` is true, every DataFrame returned by ``stats.dfs()``
    receives the absolute position column 's' and the normalized column 's/L'.
    """
    new_stats = ba.Stats(mode=1)
    new_stats.read_env(tracked_params)
    if s_dependent:
        for frame in new_stats.dfs():
            frame['s'] = positions
            frame['s/L'] = positions_normed
    return new_stats


# For each data set, an array of Stats objects (one per stored parameter array).
stats_lists = {}
for key in ('tbt', 'sdep'):
    stats_lists[key] = np.array(
        [_stats_from_params(params, key == 'sdep') for params in tracked_params_lists[key]]
    )

# Plot 

### Turn-by-turn parameters 

In [None]:
nshow = 3
data_kws = dict(lw=None, ls='-', marker='.', legend=False)
mean_kws = dict(lw=0.75, ls='-', alpha=0.5, zorder=0)

# Grid of turn-by-turn parameters: rows are parameter groups, columns are the
# first `nshow` iterations stored in stats_lists['tbt'].
fig, axes = plot.subplots(nrows=4, ncols=nshow, figsize=(5, 4.5), spany=False, aligny=True)
for j, stats in enumerate(stats_lists['tbt'][:nshow]):
    n = stats.twiss4D.shape[0]
    for i, cols in enumerate([['bx', 'by'], ['ax', 'ay'], ['u'], ['nu']]):
        ax = axes[i, j]
        data = stats.twiss4D[cols]
        # Single-parameter rows are black; paired rows use the default cycle
        # for the data and blue/orange for their mean lines.
        if len(cols) == 1:
            color = 'k'
            mean_colors = ['k']
        else:
            color = None
            mean_colors = ['blue7', 'orange7']
        data.plot(ax=ax, color=color, **data_kws)
        # Horizontal line at each column's mean. Use a scalar mean per column
        # (data.mean() is a Series) and the DataFrame index as x, which is
        # what DataFrame.plot used for the data points above.
        for col, mean_color in zip(cols, mean_colors):
            ax.plot(data.index, np.full(n, data[col].mean()), color=mean_color, **mean_kws)

axes.format(xlabel='Turn number', grid=False,
            toplabels=['Iteration {}'.format(i) for i in range(nshow)])
axes[-1, 0].format(yformatter='deg')
# Subscript 1 matches the s-dependent figure's labels; Twiss alpha is dimensionless.
myplt.set_labels(axes[:, 0], [r'$\beta_{1}$ [m]', r'$\alpha_{1}$', 'u', r'$\nu$'], 'ylabel')

fig.savefig('_output/figures/iters.png', dpi=dpi, facecolor='white')

### s-dependent parameters 

In [None]:
nshow = 8

# NOTE(review): no 'marker' is set, so 'ms' likely has no visible effect —
# confirm whether markers were intended.
data_kws = dict(lw=None, ls='-', legend=False, ms=1)
# Not used in this cell; kept in case later cells rely on it.
mean_kws = dict(lw=0.75, ls='--', alpha=0.5, zorder=0)

# Grid of s-dependent parameters: rows are parameter groups, columns are the
# first `nshow` iterations stored in stats_lists['sdep'].
fig, axes = plot.subplots(nrows=3, ncols=nshow, figsize=(12, 4), spany=False, aligny=True)
for j, stats in enumerate(stats_lists['sdep'][:nshow]):
    for i, cols in enumerate([['bx', 'by'], ['u'], ['nu']]):
        ax = axes[i, j]
        # Single-parameter rows are drawn in black; paired rows use the
        # default color cycle.
        color = 'k' if len(cols) == 1 else None
        # Plot the selected columns against the normalized position 's/L'.
        stats.twiss4D[['s/L'] + cols].plot('s/L', ax=ax, color=color, **data_kws)

axes.format(xlabel='s / L',
            toplabels=['Iteration {}'.format(i) for i in range(nshow)])
axes[-1, 0].format(yformatter='deg')
myplt.set_labels(axes[:, 0], [r'$\beta_{1}$ [m]', 'u', r'$\nu$'], 'ylabel')