# Linac simulation analysis

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
import yaml
from ipywidgets import interact
from ipywidgets import widgets
from omegaconf import OmegaConf
from omegaconf import DictConfig
from pprint import pprint

from analysis import get_input_dir

In [None]:
# Global proplot styling, applied in this order: continuous (non-discrete)
# colormaps, viridis as the sequential colormap, the "538" color cycle,
# no grid lines, and a white figure background.
for _rc_key, _rc_value in (
    ("cmap.discrete", False),
    ("cmap.sequential", "viridis"),
    ("cycle", "538"),
    ("grid", False),
    ("figure.facecolor", "white"),
):
    pplt.rc[_rc_key] = _rc_value
del _rc_key, _rc_value  # keep the notebook namespace clean

## Setup

In [None]:
# Select the run to analyze. A timestamp of None selects the latest run.
script_name = "sim"
timestamp = None  # None -> latest

input_dir = get_input_dir(
    timestamp,
    script_name,
)
print("input_dir = ", input_dir)

In [None]:
# Load the run's saved YAML config and wrap it in an OmegaConf DictConfig.
# The file is opened in a `with` block so the handle is closed right after
# parsing; the original passed a bare open() to yaml and never closed it.
cfg_path = os.path.join(input_dir, "config/config.yaml")
with open(cfg_path, "r") as cfg_file:
    cfg = yaml.safe_load(cfg_file)
cfg = DictConfig(cfg)

print("config:")
print(OmegaConf.to_yaml(cfg))

In [None]:
# Colors of the configured proplot color cycle (pplt.rc["cycle"]).
cycle_colors = (
    pplt.Cycle(pplt.rc["cycle"])
    .by_key()["color"]
)

## Scalars

In [None]:
# Scalar beam quantities versus position, read as CSV from history.dat in
# the run directory. The trailing expression previews the first rows.
history = pd.read_csv(
    os.path.join(input_dir, "history.dat")
)
history.head()

In [None]:
# Show every column name available in the history table.
print("history keys:")
pprint([*history.columns])

In [None]:
# RMS beam sizes versus position. Values are scaled by 1000 for display in
# mm (assumes history stores meters; TODO confirm).
fig, ax = pplt.subplots(figsize=(4.5, 2.0))
for key in ("x_rms", "y_rms"):
    ax.plot(
        history["position"].values,
        1000.0 * history[key].values,
        label=key,
    )
ax.legend(loc="r", ncols=1)
ax.format(xlabel="Distance [m]", ylabel="[mm]", ymin=0.0)
plt.show()

In [None]:
# Geometric emittances versus position. Values are scaled by 1e6 for
# display in mm mrad (assumes history stores m rad; TODO confirm).
fig, ax = pplt.subplots(figsize=(4.5, 2.0))
for key in ("eps_x", "eps_y"):
    ax.plot(
        history["position"].values,
        1.00e+06 * history[key].values,
        label=key,
    )
ax.legend(loc="r", ncols=1)
ax.format(xlabel="Distance [m]", ylabel="[mm mrad]", ymin=0.0)
plt.show()

In [None]:
# Normalized emittance: eps_norm = eps * (beta * gamma), stored as new
# columns eps_x_norm / eps_y_norm, then plotted scaled by 1e6 ([mm mrad]).
for key in ("x", "y"):
    history[f"eps_{key}_norm"] = history[f"eps_{key}"] * (
        history["beta"] * history["gamma"]
    )

fig, ax = pplt.subplots(figsize=(4.5, 2.0))
for key in ("eps_x_norm", "eps_y_norm"):
    ax.plot(
        history["position"].values,
        1.00e+06 * history[key].values,
        label=key,
    )
ax.legend(loc="r", ncols=1)
ax.format(xlabel="Distance [m]", ylabel="[mm mrad]", ymin=0.0)
plt.show()

## Phase space distribution

In [None]:
# Collect every file in the run directory whose name starts with "bunch",
# in sorted (lexicographic) filename order, as full paths.
bunch_filenames = sorted(f for f in os.listdir(input_dir) if f.startswith("bunch"))
bunch_filenames = [os.path.join(input_dir, f) for f in bunch_filenames]

# Load the first 6 columns of each bunch file ('%' lines are comments) and
# scale every column by 1000 to display units:
#   x [m] -> [mm], x' [rad] -> [mrad], y [m] -> [mm],
#   y' [rad] -> [mrad], z [m] -> [mm], dE [GeV] -> [MeV]
bunches = []
for filename in bunch_filenames:
    # ndmin=2 keeps a single-row file 2-D, so column indexing downstream
    # (X[:, i]) still works.
    X = np.loadtxt(filename, comments="%", usecols=range(6), ndmin=2)
    X *= 1000.0  # same factor for all six columns (see table above)
    bunches.append(X)

In [None]:
# Coordinate names and display units for the 6 phase-space columns, plus
# the derived axis labels, e.g. "x [mm]".
dims = ["x", "xp", "y", "yp", "z", "dE"]
units = ["mm", "mrad", "mm", "mrad", "mm", "MeV"]
labels = [f"{name} [{unit}]" for name, unit in zip(dims, units)]

# Per-bunch limits from psdist, combined into one set shared by all bunches.
# `share` groups columns (0, 2, 4) = x, y, z and (1, 3) = xp, yp; presumably
# each group gets common limits (see psdist docs).
limits = psv.combine_limits(
    [
        ps.points.limits(bunch, zero_center=True, pad=0.0, share=[(0, 2, 4), (1, 3)])
        for bunch in bunches
    ]
)

### Interactive 2D projections

In [None]:
@interact(
    dim1=widgets.Dropdown(options=dims, value=dims[0]),
    dim2=widgets.Dropdown(options=dims, value=dims[1]),
    index=widgets.IntSlider(min=0, max=(len(bunches) - 1), value=0, continuous_update=False),
    bins=widgets.IntSlider(min=32, max=128, value=64, continuous_update=False),
    lim_scale=widgets.FloatSlider(min=0.1, max=4.0, value=1.0, continuous_update=False),
    log=False,
    normalize=False,
)
def update(
    dim1: str,
    dim2: str,
    index: int,
    bins: int,
    lim_scale: float,
    log: bool,
    normalize: bool
):
    """Plot a 2D histogram of one bunch in the (dim1, dim2) plane, with 1D marginal panels.

    Parameters
    ----------
    dim1, dim2 : names from `dims` for the horizontal / vertical axes.
        Nothing is drawn when they are equal.
    index : index into `bunches`.
    bins : number of bins per axis for both the 2D and the 1D histograms.
    lim_scale : multiplicative factor applied to the plot limits.
    log : log color norm on the 2D plot and log count axes on the marginals.
    normalize : transform the bunch with
        `ps.points.norm_xxp_yyp_zzp(..., scale_emittance=True)` and use fixed
        limits of (-6, 6) per axis (before `lim_scale`).
    """
    if dim1 == dim2:
        return

    # Constant added to every bin count, presumably so that empty bins stay
    # plottable on a log scale. Shared by the 2D and 1D histograms.
    offset = 1.0

    axis = (dims.index(dim1), dims.index(dim2))

    X = bunches[index]
    if normalize:
        X = ps.points.norm_xxp_yyp_zzp(X, scale_emittance=True)
        _limits = 2 * [(-6.0, 6.0)]
    else:
        _limits = [limits[axis[0]], limits[axis[1]]]
    _limits = np.multiply(_limits, lim_scale).tolist()

    # 2D histogram of the selected pair of columns.
    rho, edges = np.histogramdd(X[:, axis], bins=bins, range=_limits)
    rho = rho + offset

    fig, ax = pplt.subplots()
    ax.pcolormesh(edges[0], edges[1], rho.T, norm=("log" if log else None))
    ax.format(xlabel=labels[axis[0]], ylabel=labels[axis[1]])
    ax.format(title="")

    # Marginal panels: paxs[0] above the plot (dim1), paxs[1] to the right (dim2).
    paxs = [ax.panel_axes(loc) for loc in ["top", "right"]]
    for pax in paxs:
        pax.format(xspineloc="bottom", yspineloc="left")

    rho_x, edges_x = np.histogram(X[:, axis[0]], range=_limits[0], bins=bins, density=False)
    rho_y, edges_y = np.histogram(X[:, axis[1]], range=_limits[1], bins=bins, density=False)
    rho_x = rho_x + offset
    rho_y = rho_y + offset

    kws = dict(color="black", lw=1.25)
    paxs[0].stairs(rho_x, edges_x, **kws)
    paxs[1].stairs(rho_y, edges_y, orientation="horizontal", **kws)
    if log:
        paxs[0].format(yscale="log")
        paxs[1].format(xscale="log")
    plt.show()

### Interactive corner

In [None]:
@interact(
    ndim=widgets.BoundedIntText(min=4, max=6, value=6),
    index=widgets.IntSlider(min=0, max=(len(bunches) - 1), value=0, continuous_update=False),
    bins=widgets.IntSlider(min=32, max=128, value=64, continuous_update=False),
    lim_scale=widgets.FloatSlider(min=0.1, max=4.0, value=1.0, continuous_update=False),
    log=False,
    ellipse=False,
    mask=False,
    normalize=False,
)
def update(
    ndim: int,
    index: int,
    bins: int,
    lim_scale: float,
    log: bool,
    ellipse: bool,
    mask: bool,
    normalize: bool,
):
    """Draw a corner plot of the first `ndim` phase-space columns of one bunch.

    Parameters
    ----------
    ndim : number of leading columns to plot (4 to 6, from the BoundedIntText).
    index : index into `bunches`.
    bins : number of histogram bins per axis.
    lim_scale : multiplicative factor applied to the plot limits.
    log : log color norm, and log diagonal axes with ymin=1e-4.
    ellipse : overlay RMS ellipses (level=2.0, white).
    mask : forwarded to `CornerGrid.plot_points(mask=...)`.
    normalize : transform with `ps.points.norm_xxp_yyp_zzp(..., scale_emittance=True)`
        and use fixed limits of (-6, 6) per axis (before `lim_scale`).
    """
    X = bunches[index][:, :ndim]
    if normalize:
        # NOTE(review): when ndim < 6 this passes fewer than 6 columns to
        # norm_xxp_yyp_zzp; confirm psdist supports that input.
        X = ps.points.norm_xxp_yyp_zzp(X, scale_emittance=True)
        _limits = ndim * [(-6.0, 6.0)]
    else:
        # Keep only the limits of the plotted columns so they match X's ndim
        # columns (the shared `limits` covers all 6).
        _limits = limits[:ndim]
    _limits = np.multiply(_limits, lim_scale).tolist()

    grid = psv.CornerGrid(ndim, diag_shrink=0.85, diag_rspine=False)
    grid.plot_points(
        X,
        bins=bins,
        limits=_limits,
        mask=mask,
        rms_ellipse=ellipse,
        rms_ellipse_kws=dict(level=2.0, color="white"),
        norm=("log" if log else None),
        offset=1.0,
    )
    if log:
        grid.format_diag(yscale="log", ymin=1.00e-04)
    # One label per plotted column.
    grid.set_labels(labels[:ndim])
    grid.axs.format(suptitle="")
    plt.show()