# McMillan map in a periodic ring

In [1]:
import os
import sys
import time
from typing import Callable

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.lines
import numpy as np
import proplot as pplt
import psdist as ps
from ipywidgets import interact
from ipywidgets import widgets

import ment

In [2]:
# Global proplot styling applied to every figure below: continuous (non-discrete)
# colormaps, viridis as the default sequential map, a white figure background,
# and no grid lines.
_PLOT_STYLE = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _PLOT_STYLE.items():
    pplt.rc[_key] = _value

In [3]:
class ComposedTransform:
    """Chain several transforms into a single callable.

    ``forward`` feeds the input through the transforms in the order they were
    given; ``inverse`` walks them in reverse order, calling each one's
    ``inverse`` method. With no transforms, both directions return the input
    object unchanged.
    """

    def __init__(self, *transforms) -> None:
        # Positional arguments are kept as a tuple, in application order.
        self.transforms = transforms

    @staticmethod
    def _apply_in_sequence(value, functions):
        """Feed ``value`` through each callable in ``functions``, one after another."""
        for function in functions:
            value = function(value)
        return value

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply every transform to ``x`` in order."""
        return self._apply_in_sequence(x, self.transforms)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Undo ``forward`` by applying each transform's inverse, last one first."""
        # Generator keeps the lookup of ``.inverse`` lazy, one step at a time.
        inverses = (step.inverse for step in reversed(self.transforms))
        return self._apply_in_sequence(u, inverses)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return self.forward(x)

In [4]:
class Transform:
    """Base class for point-cloud transforms with a forward and an inverse map.

    Subclasses override ``forward`` (and ``inverse`` when it exists); calling
    an instance dispatches to ``forward``. The base implementations raise
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # No state to set up; subclasses still call super().__init__().
        pass

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Map ``x`` forward. Must be provided by a subclass."""
        raise NotImplementedError

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Map ``u`` back. Must be provided by a subclass that supports it."""
        raise NotImplementedError

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return self.forward(x)

In [5]:
class LinearTransform(Transform):
    """Linear map ``u = M x`` applied to every row of a batch of points.

    Points are rows of ``x`` (presumably shape ``(n, d)`` — confirm at call
    sites), so the map is evaluated as ``x @ M.T``. The inverse matrix is
    computed once when the matrix is set.
    """

    def __init__(self, matrix: np.ndarray) -> None:
        super().__init__()
        self.set_matrix(matrix)

    def set_matrix(self, matrix: np.ndarray) -> None:
        """Store ``matrix`` and cache its inverse (raises if it is singular)."""
        # ``matrix`` is stored before inverting, so it is updated even if
        # the inversion below fails.
        self.matrix = matrix
        self.matrix_inv = np.linalg.inv(matrix)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``M x`` for each row of ``x``."""
        return x @ self.matrix.T

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Return ``M^-1 u`` for each row of ``u``."""
        return u @ self.matrix_inv.T

In [6]:
class AxiallySymmetricNonlinearKick(Transform):
    """Thin nonlinear kick that depends only on the radius r = sqrt(x^2 + y^2).

    The radial kick strength is::

        dr = -E r / (beta sin(phi) (A r^2 + T)) - 2 r / (beta tan(phi))

    and it is applied along the radial direction: ``dr * cos(theta)`` is added
    to x' and ``dr * sin(theta)`` to y', with ``theta = atan2(y, x)``.
    Positions are left untouched, so the inverse simply subtracts the same kick.

    Coordinates are ordered (x, x', y, y'). 2D input (x, x') is treated as
    having y = y' = 0 and is returned with its original two columns.

    Parameters
    ----------
    alpha, beta : float
        Twiss parameters at the kick. ``alpha`` is stored but not used by the
        kick formula.
    phi : float
        Phase advance entering the kick formula.
    A, E, T : float
        Map parameters entering the kick formula.
    """

    def __init__(self, alpha: float, beta: float, phi: float, A: float, E: float, T: float) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.phi = phi
        self.A = A
        self.E = E
        self.T = T

    def _kick(self, x: np.ndarray, sign: float) -> np.ndarray:
        """Return a copy of ``x`` with ``sign`` times the radial kick added to the momenta."""
        ndim = x.shape[1]
        if ndim == 2:
            # Embed (x, x') in 4D with y = y' = 0 so one code path handles both.
            x = np.hstack([x, np.zeros((x.shape[0], 2))])

        r = np.sqrt(x[:, 0] ** 2 + x[:, 2] ** 2)
        theta = np.arctan2(x[:, 2], x[:, 0])

        dr = (
            -(1.0 / (self.beta * np.sin(self.phi))) * ((self.E * r) / (self.A * r**2 + self.T))
            - ((2.0 * r) / (self.beta * np.tan(self.phi)))
        )

        x_out = np.copy(x)
        # Project the radial kick onto the Cartesian momenta:
        # x-component ~ cos(theta) = x / r, y-component ~ sin(theta) = y / r.
        x_out[:, 1] += sign * dr * np.cos(theta)
        x_out[:, 3] += sign * dr * np.sin(theta)
        return x_out[:, :ndim]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the kick to every row of ``x`` (2 or 4 columns)."""
        return self._kick(x, 1.0)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Undo the kick: positions are unchanged, so subtract the same momentum kick."""
        return self._kick(u, -1.0)

In [7]:
def norm_matrix(alpha: float, beta: float) -> np.ndarray:
    """Return the 2x2 normalization matrix V^-1 for Twiss parameters (alpha, beta).

    V = [[sqrt(beta), 0], [-alpha / sqrt(beta), 1 / sqrt(beta)]] is the
    Courant-Snyder (Floquet) matrix taking normalized coordinates to physical
    (x, x'); the returned inverse maps physical coordinates to normalized ones.
    """
    sqrt_beta = np.sqrt(beta)
    unnormalize = np.array(
        [
            [sqrt_beta, 0.0],
            [-alpha / sqrt_beta, 1.0 / sqrt_beta],
        ]
    )
    return np.linalg.inv(unnormalize)

In [8]:
def phase_advance_matrix(phi: float) -> np.ndarray:
    """Return the 2x2 rotation by ``phi`` in normalized (x, x') phase space.

    Acting on a column vector, it maps (x, x') to
    (x cos(phi) + x' sin(phi), -x sin(phi) + x' cos(phi)), i.e. a clockwise
    rotation for positive ``phi``.
    """
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[c, s], [-s, c]])

In [9]:
def track_tbt(x: np.ndarray, transform: Callable, turns: int = 1) -> np.ndarray:
    """Track points through ``transform`` repeatedly and record every turn.

    Parameters
    ----------
    x : np.ndarray
        Initial points, one per row (2D array).
    transform : Callable
        One-turn map applied to the whole array of points each turn.
    turns : int
        Number of applications of ``transform``.

    Returns
    -------
    np.ndarray
        Array of shape ``(turns + 1, n_points, n_dims)``; index 0 holds the
        initial points and index ``k`` the points after ``k`` turns.
    """
    n_points, n_dims = x.shape[0], x.shape[1]
    history = np.zeros((turns + 1, n_points, n_dims))
    history[0] = x
    for turn in range(turns):
        history[turn + 1] = transform(history[turn])
    return history

In [10]:
# Twiss parameters at the origin and the phase advance per turn
# (fractional tune 0.18).
alpha, beta = 0.0, 1.0
phi = 2.0 * np.pi * 0.18

# Linear one-turn matrix M = V R V^-1: normalize, rotate by phi, un-normalize.
Vinv = norm_matrix(alpha, beta)
V = np.linalg.inv(Vinv)
R = phase_advance_matrix(phi)
M = np.linalg.multi_dot([V, R, Vinv])

# One turn of the ring: linear transport followed by the nonlinear kick.
lattice = ComposedTransform(
    LinearTransform(M),
    AxiallySymmetricNonlinearKick(alpha=alpha, beta=beta, phi=phi, A=1.0, E=0.5, T=1.0),
)

In [14]:
# Initial distribution: a ring of radius 2.5 in (x, x') with Gaussian jitter
# (standard deviation 0.1) added to both coordinates.
n_parts = 100_000
t = np.linspace(0.0, 2.0 * np.pi, n_parts)
x_true = 2.5 * np.column_stack([np.cos(t), np.sin(t)])
x_true += np.random.normal(scale=0.1, size=x_true.shape)

## Apply mismatch
# x_true = np.matmul(x_true, np.linalg.inv(norm_matrix(alpha=1.5, beta=1.0).T))

Show turn-by-turn (TBT) tracking:

In [15]:
# coords = track_tbt(x_true, lattice, turns=100)

In [16]:
# @interact(turn=(0, coords.shape[0] - 1), kind=["scatter", "hist"])
# def update(turn: int = 0, kind: str = "scatter"):
#     # Settings
#     xmax = 7.0
#     limits = 2 * [(-xmax, xmax)]

#     x = coords[turn]

#     fig, ax = pplt.subplots()

#     if kind == "scatter":
#         ax.scatter(x[:, 0], x[:, 1], color="k", s=3)
#     elif kind == "hist":
#         ax.hist2d(x[:, 0], x[:, 1], bins=125, range=limits, cmap="mono")
#     else:
#         raise ValueError

#     i = 12
#     ax.plot([0.0, x[i, 0]], [0.0, x[i, 1]], color="red")

#     pax = ax.panel_axes("bottom")
#     hist, edges = np.histogram(x[:, 0], bins=75, range=limits[0], density=True)
#     hist = hist / np.sum(hist)
#     pax.stairs(hist, edges, color="black", fill=True)
#     pax.format(yticks=[])

#     ax.format(xlim=limits[0], ylim=limits[1])
#     plt.show()

In [None]:
# Try fitting into existing workflow.
# Per-turn transforms: entry `turn` applies the one-turn `lattice` map `turn`
# times, so entry 0 is the identity. A comprehension with a distinct variable
# name avoids overwriting the angle array `t` from the distribution cell.
n_turns = 15
transforms = [ComposedTransform(*([lattice] * turn)) for turn in range(n_turns)]

In [None]:
# Square phase-space window shared by the diagnostics and the plots.
xmax = 6.5
limits = 2 * [(-xmax, xmax)]

# Binning of the measured 1D profiles. The edges are derived from n_bins
# (n_bins + 1 edges -> n_bins bins) so the constant controls the resolution.
n_bins = 64
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
bin_coords = 0.5 * (bin_edges[1:] + bin_edges[:-1])

# Every turn is observed with the same 1D histogram along axis 0.
diagnostic = ment.diag.Histogram1D(axis=0, bin_edges=bin_edges)
diagnostics = [[diagnostic] for _ in transforms]

In [None]:
# Simulate the measured data from the true distribution. Presumably this pushes
# x_true through each per-turn transform and evaluates its diagnostics; the
# result is indexed later as measurements[turn][diag_index] — confirm against ment.
measurements = ment.sim.forward(x_true, transforms, diagnostics)

Create MENT model.

In [None]:
# Reconstruction dimension and the square region it covers.
ndim = 2
limits = ndim * [(-xmax, xmax)]

# Gaussian prior (scale 2.0) and a grid sampler with 128 points per axis
# spanning `limits`.
prior = ment.GaussianPrior(ndim=ndim, scale=2.0)
sampler = ment.samp.GridSampler(grid_limits=limits, grid_shape=[128] * ndim)

# MENT model fitted to the per-turn 1D profiles, run in "sample" mode with
# the grid sampler above.
model = ment.MENT(
    ndim=ndim,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    prior=prior,
    sampler=sampler,
    mode="sample",
    n_samples=1_000_000,
    interpolation={"method": "linear"},
    verbose=True,
)

# Settings for the Gauss-Seidel training loop.
learning_rate, n_epochs = 0.5, 10

In [None]:
# Layout of the per-turn profile grid (at most 7 panels per row). It does not
# change between epochs, so compute it once.
ncols = min(n_turns, 7)
nrows = int(np.ceil(n_turns / ncols))
figwidth = 1.6 * ncols
figheight = 1.2 * nrows

# Set to True to draw the 1D profiles on a log scale.
log = False

# epoch == -1 plots the untrained model; every later epoch first takes one
# Gauss-Seidel step with the configured learning rate.
for epoch in range(-1, n_epochs):
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)

    # 2D histograms: model samples (left) vs. the true distribution (right).
    # NOTE(review): the +1.0 offset keeps empty bins nonzero, but no log norm
    # is applied here — confirm it is still wanted.
    x_pred = model.sample(1_000_000)
    fig, axs = pplt.subplots(ncols=2)
    for ax, x in zip(axs, [x_pred, x_true]):
        hist, edges = np.histogramdd(x, bins=100, range=limits)
        ax.pcolormesh(edges[0], edges[1], hist.T + 1.0)
    pplt.show()

    # Simulated (red) vs. measured (black) 1D profiles, one panel per turn.
    fig, axs = pplt.subplots(
        ncols=ncols,
        nrows=nrows,
        figwidth=figwidth,
        figheight=figheight,
        sharex=True,
        sharey=True,
    )
    for index, _transform in enumerate(transforms):
        y_pred = model.simulate(index, diag_index=0).copy()
        y_meas = measurements[index][0].copy()

        # Scale both curves by the measured peak so their shapes are directly
        # comparable; an all-zero measured profile is left unscaled rather
        # than dividing by zero.
        normalization = y_meas.max()
        if normalization > 0:
            y_pred /= normalization
            y_meas /= normalization

        ax = axs[index]
        ax.plot(y_meas, color="black")
        ax.plot(y_pred, color="red")
        ax.format(ymax=1.25)
        if log:
            ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
    pplt.show()