# Sextupole scan

In [None]:
import math
import os
import sys
import time

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.lines
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
from ipywidgets import interact
from ipywidgets import widgets

import ment

In [None]:
# Notebook-wide plot defaults: continuous (non-discrete) viridis colormaps,
# white figure background, and grid lines turned off.
for _rc_key, _rc_value in {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}.items():
    pplt.rc[_rc_key] = _rc_value

In [None]:
def rotation_matrix(angle: float) -> np.ndarray:
    """Return the 2x2 matrix [[cos a, sin a], [-sin a, cos a]].

    Parameters
    ----------
    angle : float
        Rotation angle in radians.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    top_row = [cos_a, sin_a]
    bottom_row = [-sin_a, cos_a]
    return np.array([top_row, bottom_row])

In [None]:
class Transform:
    """Base class for maps applied to arrays of points.

    Subclasses override ``forward`` and, where one exists, ``inverse``.
    Calling an instance is shorthand for ``forward``.
    """

    def __init__(self) -> None:
        pass

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Map ``x`` to transformed coordinates; subclasses must override."""
        raise NotImplementedError

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Map ``u`` back to input coordinates; subclasses must override."""
        raise NotImplementedError

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply ``forward`` to ``x``."""
        return self.forward(x)

In [None]:
class CompositeTransform(Transform):
    """Chain of transforms applied one after another.

    ``forward`` applies the transforms in the order they were passed to the
    constructor; ``inverse`` undoes them in the opposite order.
    """

    def __init__(self, *transforms) -> None:
        super().__init__()
        # Kept as the tuple of positional arguments.
        self.transforms = transforms

    def forward(self, x: np.ndarray) -> np.ndarray:
        result = x
        for stage in self.transforms:
            result = stage(result)
        return result

    def inverse(self, u: np.ndarray) -> np.ndarray:
        result = u
        for stage in self.transforms[::-1]:
            result = stage.inverse(result)
        return result

In [None]:
class MultipoleTransform(Transform):
    """Applies a thin multipole kick.

    Kick formulas follow TEAPOT in PyORBIT3:
    https://github.com/PyORBIT-Collaboration/PyORBIT3/blob/main/src/teapot/teapotbase.cc

    Input is an array of points with columns presumably ordered
    (x, px, y, py, ...). Positions (columns 0 and 2) are left unchanged and
    only columns 1 and 3 are kicked. Two-column input is treated as having
    y = py = 0, and the output has as many columns as the input.
    """

    def __init__(self, order: int, strength: float, skew: bool = False) -> None:
        """Constructor.

        Parameters
        ----------
        order: int
            The multipole number (1 for dipole, 2 for quad, 3 for sextupole, etc.).
        strength : float
            Integrated kick strength [m^(-pole)].
        skew : bool
            If True, apply the skew-multipole kick instead of the normal one
            (for a quadrupole this corresponds to a 45-degree roll).
        """
        super().__init__()
        self.order = order
        self.strength = strength
        self.skew = skew

    def _kick(self, x: np.ndarray, strength: float) -> np.ndarray:
        """Return a copy of ``x`` with the momenta kicked by ``strength``."""
        # Number of coordinates per point (not the array rank, which is
        # always 2 for an (n_points, n_coords) array).
        n_coords = x.shape[1]
        if n_coords == 2:
            # Pad (x, px) with y = py = 0 so the transverse formulas apply.
            x = np.hstack([x, np.zeros((x.shape[0], 2))])

        # z^(order - 1) with z = x + i*y.
        zn = (x[:, 0] + 1.0j * x[:, 2]) ** (self.order - 1)
        k = strength / math.factorial(self.order - 1)

        u = x.copy()
        if self.skew:
            u[:, 1] = u[:, 1] + k * zn.imag
            u[:, 3] = u[:, 3] + k * zn.real
        else:
            u[:, 1] = u[:, 1] - k * zn.real
            # Kick py from py itself (previously used the already-kicked px).
            u[:, 3] = u[:, 3] + k * zn.imag
        return u[:, :n_coords]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the kick to each row of ``x``."""
        return self._kick(x, self.strength)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Undo the kick.

        The kick depends only on the positions, which it leaves unchanged,
        so applying the opposite-strength kick restores the momenta exactly.
        """
        return self._kick(u, -self.strength)

In [None]:
class LinearTransform(Transform):
    """Multiplies every point (row of the input) by a fixed matrix.

    Because points are stored as rows, ``forward`` computes ``x @ M.T`` and
    ``inverse`` computes ``u @ inv(M).T``.
    """

    def __init__(self, matrix: np.ndarray) -> None:
        super().__init__()
        self.set_matrix(matrix)

    def set_matrix(self, matrix: np.ndarray) -> None:
        """Store ``matrix`` and cache its inverse for later ``inverse`` calls."""
        self.matrix = matrix
        self.matrix_inv = np.linalg.inv(matrix)

    def forward(self, x: np.ndarray) -> np.ndarray:
        transposed = self.matrix.T
        return np.matmul(x, transposed)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        transposed_inv = self.matrix_inv.T
        return np.matmul(u, transposed_inv)

In [None]:
# Ground-truth distribution; it is sampled to build the training
# measurements and to preview the transforms interactively.
# NOTE(review): assumed to be 2D since the MENT model uses ndim = 2 — confirm.
dist = ment.dist.SwissRoll()

In [None]:
# Scan settings: one measurement per multipole strength.
# NOTE(review): under MultipoleTransform's convention (3 = sextupole),
# order = 5 is not a sextupole, although the notebook title says
# "Sextupole scan" — confirm which is intended.
n_meas = 5
order = 5
strength_max = +1.5
strength_min = -strength_max
strengths = np.linspace(strength_min, strength_max, n_meas)

# Each measurement: multipole kick, then a fixed 90-degree rotation.
rotation_angle = np.radians(90.0)
transforms = [
    CompositeTransform(
        MultipoleTransform(order=order, strength=kick_strength),
        LinearTransform(rotation_matrix(rotation_angle)),
    )
    for kick_strength in strengths
]

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(transforms) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.5),
    bins=widgets.IntSlider(min=4, max=200, value=125),
)
def update(index, n, xmax, bins):
    """Plot samples of ``dist`` after applying ``transforms[index]``.

    Parameters
    ----------
    index : int
        Which transform from ``transforms`` to apply.
    n : float
        Requested number of samples (from a log slider, so it arrives as a float).
    xmax : float
        Half-width of the plotted window on both axes.
    bins : int
        Number of bins per axis in the 2D histogram.
    """
    transform = transforms[index]

    # FloatLogSlider yields floats (e.g. 99999.99...), but a sample count
    # must be an integer.
    x = dist.sample(int(round(n)))
    x = transform(x)

    fig, ax = pplt.subplots()
    limits = 2 * [(-xmax, +xmax)]
    ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)

    # Bottom panel: 1D profile of the first coordinate, scaled to a peak of 1.
    pax = ax.panel_axes("bottom", width=0.75)
    hist, edges = np.histogram(x[:, 0], bins=90, density=True)
    hist = hist / hist.max()
    psv.plot_profile(hist, edges=edges, ax=pax, color="black", kind="step")
    pplt.show()

In [None]:
# Measurement binning shared by every diagnostic.
xmax = 6.0
n_bins = 85
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)

# One list of diagnostics per transform: here a single 1D histogram with
# axis=0 (presumably the first coordinate — confirm against ment.diag).
diagnostics = [
    [ment.diag.Histogram1D(axis=0, bin_edges=bin_edges)]
    for _ in transforms
]

# Synthetic "measured" data: push ground-truth samples through each
# transform and record every diagnostic for it.
x_true = dist.sample(1_000_000)

measurements = []
for transform, transform_diagnostics in zip(transforms, diagnostics):
    u_true = transform(x_true)
    measurements.append([diag(u_true) for diag in transform_diagnostics])

In [None]:
# 2D reconstruction with a unit-scale Gaussian prior.
ndim = 2
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# Reconstruction grid covers the same [-xmax, xmax] window as the measurement bins.
limits = [(-xmax, xmax)] * ndim
grid_shape = [100] * ndim
sampler = ment.samp.GridSampler(grid_limits=limits, grid_shape=grid_shape)

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation={"method": "linear"},
    mode="sample",
    sampler=sampler,
    n_samples=200_000,
    verbose=True,
)

# Settings for the Gauss-Seidel training loop.
learning_rate = 0.85
n_epochs = 10

In [None]:
# Layout of the measured-vs-predicted profile grid; it does not change
# between epochs, so compute it once.
ncols = min(n_meas, 7)
nrows = int(np.ceil(n_meas / ncols))
figwidth = 1.6 * ncols
figheight = 1.2 * nrows

for epoch in range(-1, n_epochs):
    # epoch == -1 plots the initial model before any update.
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)

    # Reconstructed distribution on linear (left) and log (right) color scales.
    x = model.sample(1_000_000)
    hist, edges = np.histogramdd(x, bins=100, range=limits)
    fig, axs = pplt.subplots(ncols=2, figwidth=6.0)
    for i, ax in enumerate(axs):
        norm = "log" if i else None
        # +1 keeps empty bins finite on the log color scale.
        ax.pcolormesh(edges[0], edges[1], hist.T + 1.0, norm=norm, colorbar=True)
    pplt.show()

    # Simulate each measurement once per epoch so the linear and log plots
    # show the same prediction. Both curves are scaled by the measured peak.
    profiles = []
    for index in range(len(transforms)):
        y_meas = measurements[index][0]
        y_pred = model.simulate(index, diag_index=0)
        normalization = y_meas.max()
        profiles.append((y_meas / normalization, y_pred / normalization))

    for log in [False, True]:
        fig, axs = pplt.subplots(
            ncols=ncols,
            nrows=nrows,
            figwidth=figwidth,
            figheight=figheight,
            sharex=True,
            sharey=True,
        )
        for index, (y_meas_norm, y_pred_norm) in enumerate(profiles):
            ax = axs[index]
            ax.stairs(y_meas_norm, color="black", lw=1.25)
            ax.stairs(y_pred_norm, color="red", lw=1.25)
            ax.format(ymax=1.25)
            if log:
                ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
        pplt.show()