In [None]:
import os
import sys

import numpy as np
from numpy import linalg as la
from scipy import optimize as opt
from matplotlib import pyplot as plt
from tqdm import tqdm, trange
import proplot as plot
import gif

plot.rc['grid.alpha'] = 0.05
plot.rc['axes.grid'] = False

sys.path.append('/Users/46h/Research/')
from accphys.tools import plotting as myplt
from accphys.tools import beam_analysis as ba
from accphys.tools import utils
from accphys.tools.accphys_utils import phase_adv_matrix
from accphys.emittance_measurement_4D.analysis import to_mat

In [None]:
# Sample a 4D beam from a standard normal distribution (independent, unit
# variance in each coordinate) and report its covariance and emittances.
n_parts = 200000
X0 = np.random.normal(size=(n_parts, 4))
Sigma0 = np.cov(X0.T)
print('Sigma =')
print(Sigma0)
# The rest of this notebook treats apparent emittances as (eps_x, eps_y) and
# intrinsic emittances as (eps_1, eps_2); label the printout the same way.
print('eps_x, eps_y =', ba.apparent_emittances(Sigma0))
print('eps_1, eps_2 =', ba.intrinsic_emittances(Sigma0))

In [None]:
def scan(Sigma0, tmats, dsig=0.0, dphi=0.0):
    """Simulate one noisy wire-scanner scan and reconstruct the beam moments.

    For every transfer matrix ``M`` in `tmats` the covariance matrix `Sigma0`
    is transported to ``M Sigma0 M^T``. The moments <xx>, <yy> and the
    diagonal-wire moment <uu> (wire at exactly 45 deg) are then "measured"
    with independent multiplicative Gaussian errors of RMS `dsig`. <xy> is
    inferred from these using a wire angle of 45 deg plus one uniform error
    drawn from [-dphi, dphi]. That error is shared by every measurement in
    the scan.

    NOTE(review): the rows of the linear system use only the M[0, :2] and
    M[2, 2:] entries, i.e. they assume the transfer matrices are uncoupled.
    Confirm this for the matrices passed in.

    Parameters
    ----------
    Sigma0 : ndarray, shape (4, 4)
        Covariance matrix at the reconstruction point.
    tmats : sequence of ndarray, shape (4, 4)
        Transfer matrices from the reconstruction point to each measurement.
    dsig : float
        RMS fractional error applied to each measured moment.
    dphi : float
        Maximum wire-angle error [rad].

    Returns
    -------
    moments, res, rank, s
        The outputs of `numpy.linalg.lstsq`. The unknown vector is ordered
        [<xx>, <x'x'>, <xx'>, <yy>, <y'y'>, <yy'>, <xy>, <x'y>, <xy'>, <x'y'>]
        (assuming the index order x, x', y, y').
    """
    # Wire angle assumed by the analysis (the true angle is exactly 45 deg).
    angle = np.radians(45.0) + np.random.uniform(-dphi, dphi)
    sin_a, cos_a = np.sin(angle), np.cos(angle)

    rows, rhs = [], []
    for M in tmats:
        Sigma = la.multi_dot([M, Sigma0, M.T])
        xx, yy, xy = Sigma[0, 0], Sigma[2, 2], Sigma[0, 2]
        # What a wire at exactly 45 deg would see.
        uu = 0.5 * (2 * xy + xx + yy)
        # Multiplicative measurement errors, drawn in the order xx, yy, uu.
        xx *= (1.0 + np.random.normal(scale=dsig))
        yy *= (1.0 + np.random.normal(scale=dsig))
        uu *= (1.0 + np.random.normal(scale=dsig))
        # Invert <uu> = <xx>cos^2 + <yy>sin^2 + 2<xy>sin cos with the assumed angle.
        xy_meas = (uu - xx * cos_a**2 - yy * sin_a**2) / (2 * sin_a * cos_a)

        (a, b), (c, d) = M[0, :2], M[2, 2:]
        rows.append([a**2, b**2, 2*a*b] + [0] * 7)
        rows.append([0] * 3 + [c**2, d**2, 2*c*d] + [0] * 4)
        rows.append([0] * 6 + [a*c, b*c, a*d, b*d])
        rhs.extend([xx, yy, xy_meas])

    solution, residuals, rank, sing_vals = la.lstsq(rows, rhs, rcond=None)
    return solution, residuals, rank, sing_vals

In [None]:
def is_positive_definite(Sigma):
    """Return True if the real square matrix `Sigma` is positive definite.

    Positive definiteness means x^T Sigma x > 0 for every nonzero x. The
    quadratic form depends only on the symmetric part (Sigma + Sigma^T) / 2,
    so that part is tested with `eigvalsh`. `eigvalsh` always returns real
    eigenvalues. Plain `eigvals` on a slightly asymmetric matrix can return
    complex values, or real positive ones for an indefinite form.
    For an exactly symmetric covariance matrix the result is unchanged.
    """
    Sigma = np.asarray(Sigma)
    sym = 0.5 * (Sigma + Sigma.T)
    return bool(np.all(np.linalg.eigvalsh(sym) > 0))

In [None]:
def run(Sigma0, tmats, n_trials, dsig=0, dphi=0, pbar=False):
    """Repeat the noisy reconstruction `n_trials` times and collect emittances.

    Each trial calls `scan` and converts the reconstructed moments to a
    covariance matrix with `to_mat`. If that matrix is not positive definite,
    the trial is counted as a failure and skipped.

    Parameters
    ----------
    Sigma0 : ndarray, shape (4, 4)
        True covariance matrix at the reconstruction point.
    tmats : sequence of ndarray
        Transfer matrices passed to `scan`.
    n_trials : int
        Number of simulated scans. Must be positive (it is the divisor of
        the fail rate).
    dsig, dphi : float
        Measurement and wire-angle errors, passed to `scan`.
    pbar : bool
        If True, iterate with tqdm's `trange` to show a progress bar.

    Returns
    -------
    fail_rate : float
        Fraction of trials whose reconstructed matrix was not positive definite.
    emittances : ndarray, shape (n_success, 4)
        Columns [eps_x, eps_y, eps_1, eps_2] (apparent, then intrinsic) for
        each successful trial. The shape is (0, 4) when every trial fails, so
        column slicing such as ``emittances[:, 0]`` still works.
    """
    emittances = []
    n_fail = 0
    trials = trange(n_trials) if pbar else range(n_trials)
    for _ in trials:
        moments, _res, _rank, _s = scan(Sigma0, tmats, dsig, dphi)
        Sigma = to_mat(moments)
        if not is_positive_definite(Sigma):
            n_fail += 1
            continue
        eps_x, eps_y = ba.apparent_emittances(Sigma)
        eps_1, eps_2 = ba.intrinsic_emittances(Sigma)
        emittances.append([eps_x, eps_y, eps_1, eps_2])
    fail_rate = n_fail / n_trials
    # reshape keeps a 2D (0, 4) array when the list is empty.
    return fail_rate, np.array(emittances).reshape(-1, 4)

In [None]:
# For each total phase coverage, simulate many noisy 4-measurement scans with
# equal x/y phase advances and record the fail rate plus the emittances of
# the successful reconstructions.
n_trials = 1000
dsig = 0.05  # RMS fractional error in each measured moment
dphi = np.radians(0.5)  # maximum diagonal-wire angle error

phase_coverages = np.radians(np.linspace(10., 180., 30))
fail_rates, emittances_list = [], []
for phase_coverage in tqdm(phase_coverages):
    # Four evenly spaced phase advances from 0 to the coverage, same in x and y.
    phixx = phiyy = np.linspace(0.0, phase_coverage, 4)
    tmats = list(map(phase_adv_matrix, phixx, phiyy))
    fail_rate, emittances = run(Sigma0, tmats, n_trials, dsig=dsig, dphi=dphi)
    fail_rates.append(fail_rate)
    emittances_list.append(emittances)

In [None]:
# Fail rate as a function of the total phase coverage of the scan.
fig, ax = plot.subplots()
ax.plot(np.degrees(phase_coverages), fail_rates, marker='.', color='black')
title = (
    f'RMS frac. error in measured moments = {dsig}'
    '\n'
    f'Max diag. wire angle error = {np.degrees(dphi)} [deg]'
)
ax.format(
    xlabel='phase coverage [deg]',
    ylabel='Fail rate',
    title=title,
    xticks=[0, 45, 90, 135, 180],
    ylim=(-0.02, 1.02),
)

## Realistic example

In [None]:
def load_tmats_dict(filename, exclude=None, max_n_meas=None):
    """Load dictionary of transfer matrix elements at each wire-scanner.

    Each line in the file reads [node_id, M11, M12, M13, M14, M21, M22,
    M23, M24, M31, M32, M33, M34, M41, M42, M43, M44].

    Parameters
    ----------
    filename : str
        Path to the whitespace-delimited text file. Blank lines are skipped.
    exclude : optional
        Forwarded to `utils.blacklist` to drop entries from the dictionary
        (presumably a list of node ids -- confirm against `utils.blacklist`).
    max_n_meas : int, optional
        If given and nonzero, keep only the first `max_n_meas` matrices for
        each node.

    Returns
    -------
    dict
        Maps each node id (str) to a list of 4 x 4 arrays, in file order.
    """
    tmats_dict = dict()
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:  # tolerate blank lines
                continue
            ws_id = tokens[0]
            tmat = np.reshape([float(token) for token in tokens[1:]], (4, 4))
            tmats_dict.setdefault(ws_id, []).append(tmat)
    tmats_dict = utils.blacklist(tmats_dict, exclude)
    if max_n_meas:
        for ws_id in tmats_dict:
            tmats_dict[ws_id] = tmats_dict[ws_id][:max_n_meas]
    return tmats_dict

In [None]:
def load_moments_dict(filename, exclude=None, max_n_meas=None):
    """Load dictionary of moments at each wire-scanner.

    Each line in the file reads [node_id, <xx>, <yy>, <xy>].

    Parameters
    ----------
    filename : str
        Path to the whitespace-delimited text file. Blank lines are skipped.
    exclude : optional
        Forwarded to `utils.blacklist` to drop entries from the dictionary
        (presumably a list of node ids -- confirm against `utils.blacklist`).
    max_n_meas : int, optional
        If given and nonzero, keep only the first `max_n_meas` rows for each
        node.

    Returns
    -------
    dict
        Maps each node id (str) to a 2D array with one row of moments per
        measurement, in file order.
    """
    moments_dict = dict()
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:  # tolerate blank lines
                continue
            ws_id = tokens[0]
            moments = [float(token) for token in tokens[1:]]
            moments_dict.setdefault(ws_id, []).append(moments)
    moments_dict = utils.blacklist(moments_dict, exclude)
    for ws_id in moments_dict:
        if max_n_meas:
            moments_dict[ws_id] = moments_dict[ws_id][:max_n_meas]
        moments_dict[ws_id] = np.array(moments_dict[ws_id])
    return moments_dict

In [None]:
# Saved scan data: transfer matrices and measured moments at each wire-scanner.
folder = './_saved/2021-08-10/scan/'
exclude = None  # forwarded to utils.blacklist; None keeps every wire-scanner
max_n_meas = 100  # keep at most this many measurements per wire-scanner

tmats_dict = load_tmats_dict(os.path.join(folder, 'transfer_mats.dat'),
                             exclude=exclude, max_n_meas=max_n_meas)
moments_dict = load_moments_dict(os.path.join(folder, 'moments.dat'),
                                 exclude=exclude, max_n_meas=max_n_meas)
ws_ids = sorted(tmats_dict)

In [None]:
# TODO(review): this cell was an unfinished `for ` statement, which is a
# SyntaxError that stops "Restart & Run All". Complete the loop or delete
# the cell.

In [None]:
# Hard-coded 4 x 4 symmetric covariance matrix used as the initial beam for
# the realistic example. This replaces the Gaussian Sigma0 defined above.
Sigma0 = np.array([
    [132.22438094,  29.02394835,   9.06323335,  -2.50129748],
    [ 29.02394835,  10.29463098,   5.93827178,  -0.51453613],
    [  9.06323335,   5.93827178, 161.91844364, -10.49587901],
    [ -2.50129748,  -0.51453613, -10.49587901,   1.97252215]
])
# Print apparent, then intrinsic, emittances.
for emittance_func in (ba.apparent_emittances, ba.intrinsic_emittances):
    print(emittance_func(Sigma0))

In [None]:
def scatter_hist(x, y, ax, ax_marg_x, ax_marg_y, joint_kws=None, marginal_kws=None):
    """Draw a 2D histogram of (x, y) with 1D marginal histograms beside it.

    Parameters
    ----------
    x, y : array-like
        Samples to histogram.
    ax : Axes
        Axes for the joint `hist2d`.
    ax_marg_x : Axes
        Axes for the marginal histogram of `x`.
    ax_marg_y : Axes
        Axes for the marginal histogram of `y`, drawn horizontally.
    joint_kws : dict, optional
        Keyword arguments for `ax.hist2d`. If it has a non-None 'range' of
        ((xmin, xmax), (ymin, ymax)), the marginals use the same ranges. If
        it has no 'bins', the bin count comes from numpy's 'auto' rule
        applied to `x`, and that count is used for all three histograms.
    marginal_kws : dict, optional
        Keyword arguments for both marginal `hist` calls.

    The caller's dictionaries are never modified.

    Returns
    -------
    ax : Axes
        The joint axes.
    """
    # Work on copies so the caller's dicts are not mutated (e.g. by adding 'bins').
    joint_kws = dict(joint_kws) if joint_kws is not None else dict()
    marginal_kws = dict(marginal_kws) if marginal_kws is not None else dict()
    joint_range = joint_kws.get('range')
    xrange, yrange = joint_range if joint_range is not None else (None, None)
    if 'bins' in joint_kws:
        bins = joint_kws['bins']
    else:
        # Same estimator matplotlib's hist(bins='auto') uses, computed without
        # drawing throwaway patches on the marginal axes.
        bins = len(np.histogram_bin_edges(x, bins='auto', range=xrange)) - 1
        joint_kws['bins'] = bins
    ax_marg_x.hist(x, range=xrange, bins=bins, **marginal_kws)
    ax_marg_y.hist(y, range=yrange, bins=bins, orientation='horizontal', **marginal_kws)
    ax.hist2d(x, y, **joint_kws)
    return ax

def create_grid(fig, gridspec, row, col):
    """Create a joint axes plus bare marginal axes above and to its right.

    The joint axes occupies ``gridspec[row, col]``. The x-marginal goes in
    the cell above it and the y-marginal in the cell to its right. Both
    marginal axes have their ticks removed and all four spines hidden.

    Returns
    -------
    (ax_joint, ax_marg_x, ax_marg_y)
    """
    joint = fig.add_subplot(gridspec[row, col])
    above = fig.add_subplot(gridspec[row - 1, col])
    beside = fig.add_subplot(gridspec[row, col + 1])
    for marginal in (above, beside):
        marginal.set_xticks([])
        marginal.set_yticks([])
        for spine_name in ('top', 'bottom', 'left', 'right'):
            marginal.spines[spine_name].set_visible(False)
    return joint, above, beside

In [None]:
def emittances_joint_hist(emittances, lims=((10, 40), (0, 20))):
    """Plot joint histograms of apparent and intrinsic emittances side by side.

    Parameters
    ----------
    emittances : ndarray, shape (n, 4)
        Columns [eps_x, eps_y, eps_1, eps_2], as returned by `run`.
    lims : ((xmin, xmax), (ymin, ymax))
        Histogram range used for both panels and their marginal axes.

    Returns
    -------
    ax1, ax1_marg_x, ax1_marg_y, ax2, ax2_marg_x, ax2_marg_y
        Joint and marginal axes of the apparent (1) and intrinsic (2) panels.
    """
    fig = plt.figure(figsize=(10, 4))
    h = 1.5  # relative size of the marginal axes
    # Columns: panel 1 joint, panel 1 y-marginal, spacer, panel 2 joint, panel 2 y-marginal.
    gridspec = fig.add_gridspec(2, 5, width_ratios=(7, h, 2.5, 7, h), height_ratios=(h, 7),
                                left=0.1, right=0.9, bottom=0.1, top=0.9,
                                wspace=0, hspace=0)
    ax1, ax1_marg_x, ax1_marg_y = create_grid(fig, gridspec, 1, 0)
    ax2, ax2_marg_x, ax2_marg_y = create_grid(fig, gridspec, 1, 3)

    joint_kws = dict(cmap='fire_r', range=lims, bins=75)
    marginal_kws = dict(histtype='step', color='black')

    scatter_hist(emittances[:, 0], emittances[:, 1], ax1, ax1_marg_x, ax1_marg_y, joint_kws, marginal_kws)
    scatter_hist(emittances[:, 2], emittances[:, 3], ax2, ax2_marg_x, ax2_marg_y, joint_kws, marginal_kws)
    ax1_marg_x.set_xlim(lims[0])
    ax1_marg_y.set_ylim(lims[1])
    ax2_marg_x.set_xlim(lims[0])
    ax2_marg_y.set_ylim(lims[1])
    ax1.set_xlabel(r'$\varepsilon_x$ [mm mrad]')
    ax1.set_ylabel(r'$\varepsilon_y$ [mm mrad]')
    ax2.set_xlabel(r'$\varepsilon_1$ [mm mrad]')
    ax2.set_ylabel(r'$\varepsilon_2$ [mm mrad]')
    ax1_marg_x.set_title(r'Apparent emittances ($\varepsilon_x$, $\varepsilon_y$)')
    ax2_marg_x.set_title(r'Intrinsic emittances ($\varepsilon_1$, $\varepsilon_2$)')
    # Previously returned ax1_marg_y in place of ax2_marg_x and ax2_marg_y.
    return ax1, ax1_marg_x, ax1_marg_y, ax2, ax2_marg_x, ax2_marg_y

In [None]:
# Realistic case: take the `meas_index`-th transfer matrix at every
# wire-scanner and sweep the RMS fractional measurement error, with no
# wire-angle error.
meas_index = 4
n_trials = 50000
dsig_list = np.linspace(0.0, 0.05, 6)
dphi = np.radians(0.0)

tmats = [tmats_dict[ws_id][meas_index] for ws_id in sorted(tmats_dict)]
fail_rates = []
emittances_list = []
for dsig in dsig_list:
    fail_rate, emittances = run(Sigma0, tmats, n_trials, dsig=dsig, dphi=dphi, pbar=True)
    fail_rates.append(fail_rate)
    emittances_list.append(emittances)

In [None]:
# Fail rate vs. RMS fractional measurement error for the realistic example.
fig, ax = plot.subplots(figsize=(3, 2))
ax.plot(dsig_list, fail_rates, marker='.', color='k')
ax.format(xlabel='RMS fractional error in measured moments',
          ylabel='Fail rate', grid=True,
          ylim=(0, 1))
# Create the output folder so a fresh checkout survives "Run All".
os.makedirs('_output', exist_ok=True)
# Save this cell's figure explicitly rather than pyplot's "current" figure.
fig.savefig('_output/failrate_vs_meas_error_{}'.format(meas_index), facecolor='white')

In [None]:
# Emittances of Sigma0; plot_errors draws them as guide lines.
(eps_x, eps_y), (eps_1, eps_2) = ba.apparent_emittances(Sigma0), ba.intrinsic_emittances(Sigma0)

In [None]:
@gif.frame
def plot_errors(dsig, fail_rate, emittances):
    """Plot and save the emittance distributions for one measurement-error level.

    This draws the apparent/intrinsic emittance joint histograms of
    `emittances` and annotates the fail rate and RMS fractional error. It marks
    the emittances of Sigma0 (module-level eps_x, eps_y, eps_1, eps_2) with
    guide lines and saves the figure under '_output/res/'.
    """
    # Only the two joint axes are used; the marginal axes are discarded.
    ax1, _, _, ax2, _, _ = emittances_joint_hist(emittances, lims=((10, 40), (0, 20)))
    ax2.annotate('fail rate = {:.2f}'.format(fail_rate),
                 xy=(0.03, 0.93), xycoords='axes fraction', color='white')
    ax2.annotate(r'rms frac err = {:.0f}%'.format(100 * dsig),
                 xy=(0.03, 0.87), xycoords='axes fraction', color='white')
    line_kws = dict(color='white', lw=0.25, alpha=0.5)
    ax1.axvline(eps_x, **line_kws)
    ax1.axhline(eps_y, **line_kws)
    ax2.axvline(eps_1, **line_kws)
    ax2.axhline(eps_2, **line_kws)
    # '{:g}' keeps file names free of float noise such as '3.0000000000000004'.
    figname = '_output/res/results_dsig{:g}%.png'.format(100 * dsig)
    os.makedirs(os.path.dirname(figname), exist_ok=True)
    plt.savefig(figname, facecolor='white', dpi=250)
    # NOTE(review): with the inline backend plt.show() closes the figure, so
    # the image gif.frame captures after this function returns may be blank.
    # Confirm before re-enabling the GIF export.
    plt.show()

In [None]:
# One plot (and GIF frame) per measurement-error level.
frames = [
    plot_errors(dsig, fail_rate, emittances)
    for dsig, fail_rate, emittances in zip(dsig_list, fail_rates, emittances_list)
]

In [None]:
# gif.options.matplotlib['dpi'] = 200
# fps = 0.75
# gif.save(frames, '_output/example.gif', duration=len(frames)/fps, unit="s", between="startend")