In [14]:
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import proplot as pplt
import psdist.visualization as psv
import torch
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

In [15]:
# Apply mentflow's proplot rc configuration (presumably global figure-style
# defaults) before any figures in this notebook are created.
mf.train.plot.set_proplot_rc()

In [38]:
# Experiment configuration as an omegaconf DictConfig, so fields are reachable
# with attribute access (cfg.dist.name, cfg.meas.num, ...).
#   dist: keyword arguments handed to the 2D distribution generator.
#   meas: measurement settings; `num` is the number of multipole strengths scanned.
# NOTE(review): meas.min_angle / meas.max_angle are not read by any cell shown here.
cfg = omegaconf.DictConfig(
    {
        "dist": {"name": "rings", "shear": 1.0},
        "meas": {"min_angle": 0.0, "max_angle": 180.0, "num": 6},
    }
)

In [39]:
# Build the 2D test distribution. Every key under cfg.dist (currently `name`
# and `shear`) is forwarded as a keyword argument to the generator.
dist = mf.dist.dist_2d.gen_dist(**cfg.dist)

In [42]:
## Constant linear focusing, varying multipole strength.
# Each transform is a composite of a multipole transform and a fixed 90-degree
# rotation (presumably applied in that order by CompositeTransform). Only the
# multipole strength changes from one transform to the next.
transforms = []
strength_max = +1.5
strength_min = -strength_max
order = 4  # multipole order passed to mf.sim.MultipoleTransform
strengths = np.linspace(strength_min, strength_max, cfg.meas.num)

# The rotation is the same for every transform, so its matrix is computed once
# here rather than on every loop iteration.
angle = np.radians(90.0)
matrix = mf.sim.rotation_matrix(angle)
matrix = matrix.type(torch.float32)

for strength in strengths:
    multipole = mf.sim.MultipoleTransform(order=order, strength=strength)
    # Clone so each LinearTransform owns its own tensor instead of all of them
    # sharing `matrix` (an in-place edit to one would otherwise affect all).
    rotation = mf.sim.LinearTransform(matrix.clone())
    transform = mf.sim.CompositeTransform(multipole, rotation)
    transforms.append(transform)

In [43]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(transforms) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.5),
    bins=widgets.IntSlider(min=4, max=200, value=125),
)
def update(index, n, xmax, bins):
    """Show a 2D histogram of transformed samples with an x-profile panel below.

    Parameters
    ----------
    index : int
        Index into `transforms` selecting the transform to apply.
    n : float
        Requested number of samples. It comes from a FloatLogSlider, so it is a
        float in [1e2, 1e6]; it is truncated to an int before sampling.
    xmax : float
        Half-width of the square plotting window [-xmax, xmax] x [-xmax, xmax].
    bins : int
        Number of bins per axis for the 2D histogram.
    """
    transform = transforms[index]

    # FloatLogSlider produces floats (e.g. 199.5); a sample count must be an int.
    x = dist.sample(int(n))
    x = transform(x)
    x = grab(x)

    _, ax = pplt.subplots()
    limits = 2 * [(-xmax, +xmax)]
    ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)

    # Bottom panel: 1D density of the first coordinate, rescaled to a peak of 1.
    pax = ax.panel_axes("bottom", width=0.75)
    hist, edges = np.histogram(x[:, 0], bins=90, density=True)
    hist = hist / hist.max()
    psv.plot_profile(hist, edges=edges, ax=pax, color="black", kind="step")
    pplt.show()

interactive(children=(IntSlider(value=0, description='index', max=5), FloatLogSlider(value=100000.0, descripti…

In [44]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(transforms) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.5),
    n_bins=widgets.IntSlider(min=4, max=200, value=125),
    n_lines=widgets.IntSlider(min=5, max=50, value=15),
    scale=widgets.FloatSlider(min=1.0, max=4.0, value=1.0),
)
def update(index, n, xmax, n_bins, n_lines, scale):
    """Compare samples before and after a transform, with a mapped grid overlay.

    The left panel shows the untransformed samples and the right panel the
    transformed samples. A family of lines is laid out in the transformed
    (right) space and pulled back through `transform.inverse`, so the left panel
    shows the input-space curves that map onto those lines.

    Parameters
    ----------
    index : int
        Index into `transforms` selecting the transform to apply.
    n : float
        Requested number of samples. It comes from a FloatLogSlider, so it is a
        float in [1e2, 1e6]; it is truncated to an int before sampling.
    xmax : float
        Half-width of the square plotting window in both panels.
    n_bins : int
        Number of bins per axis for the 2D histograms.
    n_lines : int
        Number of overlay lines.
    scale : float
        Factor applied to the spread of line offsets relative to the window.
    """
    transform = transforms[index]

    # FloatLogSlider produces floats (e.g. 199.5); a sample count must be an int.
    x = dist.sample(int(n))
    u = transform(x)
    u = grab(u)
    x = grab(x)

    _, axs = pplt.subplots(ncols=2)
    limits = 2 * [(-xmax, +xmax)]
    for ax, _x in zip(axs, [x, u]):
        ax.hist2d(_x[:, 0], _x[:, 1], bins=n_bins, range=limits)

    # Overlay lines: n_lines offsets (spread by `scale`), each sampled at
    # n_dots_per_line points over a span `line_extent` times the window,
    # presumably so the curves still cross the view after the inverse map.
    # Separate names keep the grid apart from the sample arrays `x`/`u`.
    n_dots_per_line = 150
    line_extent = 3.0
    grid_u = mf.utils.get_grid_points_torch(
        scale * torch.linspace(-xmax, +xmax, n_lines),
        line_extent * torch.linspace(-xmax, +xmax, n_dots_per_line),
    )

    grid_x = transform.inverse(grid_u)
    grid_x = grab(grid_x)
    grid_u = grab(grid_u)

    # NOTE(review): np.split assumes get_grid_points_torch orders points line by
    # line (n_dots_per_line consecutive points per line) — confirm.
    for ax, _grid in zip(axs, [grid_x, grid_u]):
        for line in np.split(_grid, n_lines):
            ax.plot(line[:, 0], line[:, 1], color="white", alpha=0.5)
    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    axs[0].format(title="Input space")
    axs[1].format(title="Transformed space")
    pplt.show()

interactive(children=(IntSlider(value=0, description='index', max=5), FloatLogSlider(value=100000.0, descripti…