In [None]:
from functools import partial
import os
import pickle

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0" #select GPU, -1 means use CPU

import equinox as eqx
from flowjax.distributions import StandardNormal, Logistic, StudentT, AbstractDistribution, Normal
from flowjax.flows import masked_autoregressive_flow, block_neural_autoregressive_flow, triangular_spline_flow, coupling_flow
import h5py
import jax
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.colors import LogNorm
import numpy as np
import pandas as pd
from scipy import stats

from tqdm import trange, tqdm
import tqdm.utils as tutils
def ssl(x):
    """Stand-in for tqdm's terminal-size probe that always reports ``(100, 200)``.

    It is installed right after its definition as ``tqdm.utils._screen_shape_linux`` so progress
    bars do not query the real terminal (presumably ``(columns, rows)`` as tqdm expects — confirm).
    The argument is ignored.
    """
    shape = (100, 200)
    return shape
tutils._screen_shape_linux = ssl

from neural_net_defs import *

In [None]:
import gzip
import json
import optax
import diffrax
from jaxtyping import PyTree, Array
from copy import deepcopy

In [None]:
jax.devices()

In [None]:
tpc_r = 66.4

In [None]:
data_obj = np.load('kr83_sr1_50runs_FDCtrain.npz')

In [None]:
z = data_obj['z_corr']
conditions = data_obj['condition']

In [None]:
key = jax.random.PRNGKey(42)
key, flow_key = jax.random.split(key, 2)

In [None]:
conditions.shape

In [None]:
# Rebuild the skeleton of the pretrained conditional coupling flow, then load its saved weights.
# flow_layers, NN_width, NN_depth, activation, bijection and data_inv_transformation are not defined
# in this notebook; presumably they come from `from neural_net_defs import *`. They must match the
# saved model, because tree_deserialise_leaves fills in the leaves of this exact structure.
flow = coupling_flow(
    flow_key, base_dist=StandardNormal((2,)), invert=False, flow_layers=flow_layers, nn_width=NN_width, nn_depth=NN_depth, nn_activation=activation, cond_dim=conditions.shape[1], transformer=bijection
)
flow = eqx.tree_deserialise_leaves("../flow_posrec/posrec_flow_uniform_100e_2to5e_PMTs_turned_off.eqx", flow)

In [None]:
# Jitted sampler: compiled_sample(key, sample_shape, condition=...).
compiled_sample = eqx.filter_jit(flow.sample)

In [None]:
def generate_samples(key, conditions, N_samples):
    """Draw ``N_samples`` points per condition from the pretrained flow.

    The raw flow output is flattened to rows of two coordinates and passed through
    ``data_inv_transformation`` (presumably back to physical x/y — confirm in ``neural_net_defs``).

    Args:
        key: PRNG key for sampling.
        conditions: batch of conditioning vectors passed to ``flow.sample``.
        N_samples: number of samples per condition.

    Returns:
        Array of shape ``(-1, 2)`` after the inverse data transformation.
    """
    drawn = compiled_sample(key, (N_samples,), condition=conditions)
    rows = jnp.reshape(drawn, (-1, 2))
    return data_inv_transformation(rows)

In [None]:
# Sanity check: 2 samples for each of 4 conditions. Expected (8, 2), assuming flow.sample returns
# (N_samples, n_conditions, 2) before generate_samples flattens it — TODO confirm.
generate_samples(key, conditions[0:4], 2).shape

In [None]:
class Func(eqx.Module):
    """Time-dependent MLP vector field built from ``ConcatSquash`` layers.

    Hidden layers use tanh and the last layer is linear. ``args`` exists only to match the
    diffrax ``ODETerm`` signature and is ignored.
    """
    layers: list[eqx.nn.Linear]  # NOTE(review): the elements are actually ConcatSquash modules

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        layer_keys = jax.random.split(key, depth + 1)
        # depth == 0 -> one data->data layer; otherwise data->width, (depth - 1) x width->width,
        # width->data. In every case layer i is initialised with layer_keys[i].
        sizes = [data_size] + [width_size] * depth + [data_size]
        self.layers = [
            ConcatSquash(in_size=n_in, out_size=n_out, key=layer_key)
            for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys)
        ]

    def __call__(self, t, y, args):
        t_vec = jnp.asarray(t)[None]  # each ConcatSquash expects time as a length-1 vector
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            y = jax.nn.tanh(layer(t_vec, y))
        return output_layer(t_vec, y)


# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
    """Linear map of ``y`` gated and shifted by the time ``t``.

    Computes ``lin1(y) * sigmoid(lin2(t)) + lin3(t)``; ``t`` is a length-1 vector.
    """
    lin1: eqx.nn.Linear  # state transform: in_size -> out_size
    lin2: eqx.nn.Linear  # time gate: 1 -> out_size
    lin3: eqx.nn.Linear  # time shift without bias: 1 -> out_size

    def __init__(self, *, in_size, out_size, key, **kwargs):
        super().__init__(**kwargs)
        state_key, gate_key, shift_key = jax.random.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=state_key)
        self.lin2 = eqx.nn.Linear(1, out_size, key=gate_key)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=shift_key)

    def __call__(self, t, y):
        gate = jax.nn.sigmoid(self.lin2(t))
        shift = self.lin3(t)
        return self.lin1(y) * gate + shift


In [None]:
class MLPFunc(eqx.Module):
    """Time-dependent MLP vector field; the scalar time is appended to the state as an extra input.

    Hidden layers use SiLU and the output layer is linear. ``args`` exists only to match the
    diffrax ``ODETerm`` signature and is ignored.
    """
    layers: list[eqx.nn.Linear]

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        layer_keys = jax.random.split(key, depth + 1)
        # Input is [y, t] (data_size + 1 features). depth == 0 -> a single linear layer; otherwise
        # data+1 -> width, (depth - 1) x width -> width, width -> data. Layer i uses layer_keys[i].
        sizes = [data_size + 1] + [width_size] * depth + [data_size]
        self.layers = [
            eqx.nn.Linear(n_in, n_out, key=layer_key)
            for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys)
        ]

    def __call__(self, t, y, args):
        h = jnp.concatenate((y, jnp.asarray(t)[None]), axis=-1)
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            h = jax.nn.silu(layer(h))
        return output_layer(h)

In [None]:
def approx_logp_wrapper(t, y, args):
    """Augmented ODE right-hand side using a Hutchinson estimate of the Jacobian trace.

    Args:
        t: ODE time, forwarded to ``func``.
        y: ``(state, log_density_change)``; only ``state`` is used.
        args: ``(*extra, eps, func)`` — the vector field ``func(t, state, extra)`` and a probe vector ``eps``.

    Returns:
        ``(func(t, state, extra), eps^T J eps)``, where ``J`` is the Jacobian of ``func`` w.r.t. ``state``.
    """
    state, _ = y
    *extra, eps, func = args
    f, pullback = jax.vjp(lambda s: func(t, s, extra), state)
    (eps_jac,) = pullback(eps)  # eps^T J
    return f, jnp.sum(eps_jac * eps)


def exact_logp_wrapper(t, y, args):
    """Augmented ODE right-hand side using the exact Jacobian trace.

    Args:
        t: ODE time, forwarded to ``func``.
        y: ``(state, log_density_change)``; only ``state`` is used.
        args: ``(*extra, eps, func)``; ``eps`` is ignored here.

    Returns:
        ``(func(t, state, extra), trace(J))``. ``J`` is assembled from one VJP per basis vector,
        so ``state`` must be 1-D.
    """
    state, _ = y
    *extra, _, func = args
    f, pullback = jax.vjp(lambda s: func(t, s, extra), state)
    (n,) = state.shape  # only 1-D states are supported
    (jac_rows,) = jax.vmap(pullback)(jnp.eye(n))
    return f, jnp.trace(jac_rows)

In [None]:
class CNF(eqx.Module):
    """Two-stage continuous normalizing flow over ``data_size``-dimensional points.

    A point is integrated through the ``func_extract`` vector field over the fixed interval
    ``[t0, extract_t1]``, and then through the ``func_drift`` field over ``[t0, t1]``, where
    ``t1`` is supplied per call. Every ODE is solved with diffrax's ``Tsit5`` and the configured
    step-size controller.
    """
    func_drift: eqx.Module
    func_extract: eqx.Module
    data_size: int
    exact_logp: bool  # True: exact Jacobian trace; False: Hutchinson estimator (needs a PRNG key)
    t0: float
    extract_t1: float
    dt0: float
    stepsizecontroller: diffrax.AbstractStepSizeController

    def __init__(
        self,
        *,
        data_size,
        exact_logp,
        width_size,
        depth,
        key,
        stepsizecontroller=diffrax.ConstantStepSize(),
        func=Func,
        **kwargs,
    ):
        """Build the drift and extraction vector fields with ``func`` (independent keys).

        Args:
            data_size: dimension of the transformed points.
            exact_logp: use the exact Jacobian trace instead of the Hutchinson estimator.
            width_size, depth: forwarded to ``func``.
            key: PRNG key, split between the two vector fields.
            stepsizecontroller: diffrax step-size controller used for every solve.
            func: vector-field class accepting ``data_size, width_size, depth, key`` keywords.
        """
        keys = jax.random.split(key, 2)
        super().__init__(**kwargs)
        self.func_drift = func(
            data_size=data_size,
            width_size=width_size,
            depth=depth,
            key=keys[0],
        )
        self.func_extract = func(
            data_size=data_size,
            width_size=width_size,
            depth=depth,
            key=keys[1],
        )
        self.data_size = data_size
        self.exact_logp = exact_logp
        self.t0 = 0
        self.extract_t1 = 10
        self.dt0 = 1
        self.stepsizecontroller = stepsizecontroller

    def _solve(self, term, y, t_start, t_end, dt0, args=None):
        """Integrate ``term`` from ``t_start`` to ``t_end`` with Tsit5; return ``sol.ys`` (final time only)."""
        sol = diffrax.diffeqsolve(
            term, diffrax.Tsit5(), t_start, t_end, dt0, y, args,
            stepsize_controller=self.stepsizecontroller,
        )
        return sol.ys

    def _logp_term_and_eps(self, y, key):
        """Select the log-density ODE term and the probe vector it needs.

        The exact wrapper ignores the probe, so a zero placeholder is used and no key is needed.
        The Hutchinson estimator needs a Gaussian probe drawn from an explicit ``key``. Previously,
        the module-level ``key`` was picked up implicitly.

        Raises:
            ValueError: if ``exact_logp`` is False and ``key`` is None.
        """
        if self.exact_logp:
            return diffrax.ODETerm(exact_logp_wrapper), jnp.zeros_like(y)
        if key is None:
            raise ValueError(
                "CNF with exact_logp=False needs a PRNG `key` for the Hutchinson trace estimator."
            )
        return diffrax.ODETerm(approx_logp_wrapper), jax.random.normal(key, y.shape)

    def transform(self, *, y, t1):
        """Map ``y`` through the extraction ODE, then the drift ODE up to time ``t1``."""
        (y,) = self._solve(diffrax.ODETerm(self.func_extract), y, self.t0, self.extract_t1, self.dt0)
        (y,) = self._solve(diffrax.ODETerm(self.func_drift), y, self.t0, t1, self.dt0)
        return y

    def transform_and_log_det(self, *, y, t1, key=None):
        """Forward transform (extraction, then drift to ``t1``) with the accumulated log-density change.

        Args:
            y: input point.
            t1: end time of the drift ODE.
            key: PRNG key for the Hutchinson probe; required only when ``exact_logp`` is False.

        Returns:
            ``(transformed_y, delta_log_likelihood)``, accumulated across both stages.
        """
        term, eps = self._logp_term_and_eps(y, key)
        (y,), (delta_log_likelihood,) = self._solve(
            term, (y, 0.0), self.t0, self.extract_t1, self.dt0, (eps, self.func_extract)
        )
        (y,), (delta_log_likelihood,) = self._solve(
            term, (y, delta_log_likelihood), self.t0, t1, self.dt0, (eps, self.func_drift)
        )
        return y, delta_log_likelihood

    def inverse_and_log_det(self, *, y, t1, key=None):
        """Inverse transform (drift backwards from ``t1``, then extraction backwards) with the log-density change.

        Args:
            y: transformed point.
            t1: start time of the backward drift solve.
            key: PRNG key for the Hutchinson probe; required only when ``exact_logp`` is False.

        Returns:
            ``(original_y, delta_log_likelihood)``, accumulated across both stages.
        """
        term, eps = self._logp_term_and_eps(y, key)
        (y,), (delta_log_likelihood,) = self._solve(
            term, (y, 0.0), t1, self.t0, -self.dt0, (eps, self.func_drift)
        )
        (y,), (delta_log_likelihood,) = self._solve(
            term, (y, delta_log_likelihood), self.extract_t1, self.t0, -self.dt0, (eps, self.func_extract)
        )
        return y, delta_log_likelihood

In [None]:
from jax.scipy.interpolate import RegularGridInterpolator

In [None]:
def load_civ(civ_file_name="field_dependent_radius_depth_maps_B2d75n_C2d75n_G0d3p_A4d9p_T0d9n_PMTs1d3n_FSR0d65p_QPTFE_0d5n_0d4p.json.gz"):
    """Load a gzipped-JSON survival-probability (CIV) map as a JAX ``RegularGridInterpolator``.

    The JSON must contain:
      * ``coordinate_system``: one entry per grid axis. The second element of each entry is a
        ``[start, stop, n]`` triple that is expanded with ``np.linspace`` (the first element is not used).
      * ``survival_probability_map``: the values, reshaped to ``(n_axis0, n_axis1, ...)``.

    Args:
        civ_file_name: path to the ``.json.gz`` map; the default is the map this notebook was built on.

    Returns:
        An interpolator that returns 0 for points outside the grid.
    """
    with gzip.open(civ_file_name, "rb") as f:
        civ_spec = json.load(f)
    axes = [ax[1] for ax in civ_spec['coordinate_system']]
    grid = tuple(np.linspace(*ax) for ax in axes)
    values = np.array(civ_spec['survival_probability_map']).reshape([ax[-1] for ax in axes])
    return RegularGridInterpolator(grid, values, bounds_error=False, fill_value=0)


In [None]:
# Survival-probability (CIV) interpolator, plus a vmapped version that evaluates one (r, z) row at a time.
# Each call returns a length-1 array, so the batched output has shape (N, 1).
civ_map = load_civ()
vec_civ_map = jax.vmap(civ_map)

In [None]:
tpc_height = 148.6515  # TPC height in cm; only events with -tpc_height < z < 0 are kept
z_scale = 5  # depth-to-ODE-time scale: each event's drift ODE runs to t1 = -z / z_scale
data_bool = (z>-tpc_height) & (z<0)
z_sel = z[data_bool]
z_sel_scaled = -z_sel/z_scale
cond_sel = conditions[data_bool]  # filtered like z_sel, so index k refers to the same event in all three

In [None]:
len(z_sel_scaled)/len(z)  # fraction of events kept by the depth cut

In [None]:
@jax.jit
def compute_r(xy_arr):
    """Distance from the origin of each row, using columns 0 and 1 as (x, y); returns one value per row."""
    x, y = xy_arr[:, 0], xy_arr[:, 1]
    return jnp.sqrt(x * x + y * y)

In [None]:
key, model_key = jax.random.split(key, 2)

In [None]:
# CNF whose extraction and drift fields are MLPFunc networks (3 hidden layers of width 48), using
# the exact Jacobian-trace log-density and adaptive PID step sizes (rtol 1e-3, atol 1e-6, dtmax 5).
model = CNF(
    data_size=2,
    exact_logp=True,
    width_size=48,
    depth=3,
    key=model_key,
    stepsizecontroller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=5),
    func=MLPFunc
)

In [None]:
def rolloff_func(x, rolloff=1e-2):
    """Smooth lower floor: ``x + rolloff * exp(-x / rolloff)``.

    Equals ``rolloff`` at ``x = 0`` and tends to ``x`` once ``x`` is several ``rolloff`` above zero.
    This keeps small non-negative probabilities away from 0 without a hard clip.
    """
    floor_term = rolloff * jnp.exp(-x / rolloff)
    return x + floor_term

In [None]:
def curl_loss(key, model, z, x, extract_max_z=10.):
    """Squared 2-D curl of the drift and extraction vector fields at the point ``x``.

    The curl of a field ``f`` is ``df_1/dx_0 - df_0/dx_1``. It is evaluated for ``model.func_drift``
    at time ``z`` and for ``model.func_extract`` at a time drawn uniformly from ``[0, extract_max_z)``.

    Args:
        key: PRNG key for the extraction-time draw.
        model: object exposing ``func_drift(t, y, args)`` and ``func_extract(t, y, args)``.
        z: time at which the drift field is evaluated (the drift ODE's time axis).
        x: 2-D point.
        extract_max_z: upper bound of the extraction-time draw.

    Returns:
        ``curl_drift**2 + curl_extract**2`` as a scalar.
    """
    # Scalar draw: shape=() is the documented form, rather than a bare int shape plus a [0] unwrap.
    extract_t = jax.random.uniform(key, (), minval=0.0, maxval=extract_max_z)
    jac_drift = jax.jacfwd(lambda a: model.func_drift(z, a, 0.))(x)
    jac_ex = jax.jacfwd(lambda a: model.func_extract(extract_t, a, 0.))(x)
    return (jac_drift[1, 0] - jac_drift[0, 1])**2 + (jac_ex[1, 0] - jac_ex[0, 1])**2

In [None]:
def single_likelihood_loss(key, model, condition, t1, z, min_p=1e-3, N_samples=4, tpc_r=66.4, curl_loss_multiplier=1000., outside_scale=10000.):
    """Per-event loss: negative log of a survival-weighted flow likelihood plus a curl penalty.

    ``N_samples`` points are drawn from the pretrained flow for ``condition`` and pushed through
    ``model`` (extraction, then drift up to ``t1``). Each sample is weighted by the CIV survival
    probability at its radius and depth ``z``:

        loss = -log(sum_i p_i * exp(logdet_i)) + log(N_samples) + curl_loss_multiplier * curl

    Args:
        key: PRNG key, split into a sampling key and a curl key.
        model: CNF-like module with ``transform_and_log_det``, ``func_drift`` and ``func_extract``.
        condition: a single conditioning vector for the flow (a batch axis is added here).
        t1: drift-ODE end time for this event (the notebook passes ``-z / z_scale``).
        z: event depth, in the same units as the CIV map's second axis.
        min_p: scale of the soft floor applied to the survival probability via ``rolloff_func``.
        N_samples: number of flow samples per event.
        tpc_r: radius beyond which samples are penalised.
        curl_loss_multiplier: weight of the curl penalty.
        outside_scale: length scale of the exponential penalty for samples beyond ``tpc_r``.
            The default reproduces the previously hard-coded value.

    Returns:
        Scalar loss.
    """
    sample_key, curl_key = jax.random.split(key, 2)
    samples = generate_samples(sample_key, condition[np.newaxis, ...], N_samples)
    transformed_samples, logdet = eqx.filter_vmap(
        lambda y: model.transform_and_log_det(y=y, t1=t1)
    )(samples)  # logdet: one value per sample
    sample_r = compute_r(transformed_samples)

    civ_points = jnp.stack((sample_r, jnp.full_like(sample_r, z)), axis=1)
    # vec_civ_map yields one length-1 result per point, i.e. (N_samples, 1). Flatten it so each
    # weight pairs with its own logdet. Left as a column, logsumexp would broadcast it against the
    # (N_samples,) logdet to an N x N grid, giving log(sum p) + logsumexp(logdet) instead.
    p_surv = jnp.reshape(vec_civ_map(civ_points), (-1,))

    # One shared factor that shrinks every weight when any sample lies outside the TPC radius.
    outside_penalty = jnp.prod(
        jnp.where(sample_r <= tpc_r, 1.0, jnp.exp((tpc_r - sample_r) / outside_scale))
    )
    p_surv = rolloff_func(p_surv, min_p) * outside_penalty

    neg_log_likelihood = -jax.nn.logsumexp(a=logdet, b=p_surv) + jnp.log(N_samples)
    # The drift field's time argument is the ODE time t1, not the physical depth z.
    curl = curl_loss(curl_key, model, t1, transformed_samples[0])
    return neg_log_likelihood + curl_loss_multiplier * curl

In [None]:
curl_loss(key, model, 5, np.array([50,30.]))

In [None]:
-jax.nn.logsumexp(a=jnp.array([1,1]), b=jnp.array([1,0.1]))

In [None]:
single_likelihood_loss(key, model, conditions[0], z_sel_scaled[0], z_sel[0])

In [None]:
@eqx.filter_jit
def likelihood_loss(model, key, conditions, t1s, zs, N_samples=4):
    """Mean of ``single_likelihood_loss`` over a batch of events, jitted with ``eqx.filter_jit``.

    Each event gets its own PRNG subkey. ``conditions``, ``t1s`` and ``zs`` are vmapped over their
    leading axis, and ``N_samples`` flow samples are drawn per event.
    """
    event_keys = jax.random.split(key, len(zs))

    def per_event(event_key, condition, t1, z):
        return single_likelihood_loss(event_key, model, condition, t1, z, N_samples=N_samples)

    per_event_losses = eqx.filter_vmap(per_event)(event_keys, conditions, t1s, zs)
    return jnp.mean(per_event_losses)

In [None]:
likelihood_loss(model, key, conditions[:32], z_sel_scaled[:32], z_sel[:32])

In [None]:
%%timeit
eqx.filter_value_and_grad(likelihood_loss)(model, key, conditions[:32], z_sel_scaled[:32], z_sel[:32])

In [None]:
def train(
    key,
    model,
    optim: optax.GradientTransformation,
    epochs: int,
    conditions: Array,
    t1s: Array,
    zs: Array,
    N_train: int,
    N_batch: int,
    N_samples: int,
    N_test: int,
    use_best: bool = False,
    loss=likelihood_loss,
):
    """Mini-batch training loop that tracks the training and held-out losses.

    The last ``N_test`` events form the test set. The first ``N_train`` of the remaining events form
    the training set; it is reshuffled every epoch and consumed in ``N_train // N_batch`` full
    batches, with the remainder dropped.

    Args:
        key: PRNG key for shuffling and for the stochastic loss.
        model: equinox module to optimise.
        optim: optax transformation, initialised from the model's array leaves.
        epochs: number of passes over the training set.
        conditions, t1s, zs: per-event inputs forwarded to ``loss``.
        N_train: number of training events used (at most ``len(conditions) - N_test``).
        N_batch: mini-batch size.
        N_samples: flow samples per event, forwarded to ``loss``.
        N_test: number of held-out events at the end of the arrays (must be positive).
        use_best: return the model with the lowest test loss seen (including the untrained one)
            instead of the final model.
        loss: ``loss(model, key, conds, t1s, zs, N_samples=...)`` returning a scalar.

    Returns:
        ``(model, train_loss_list, test_loss_list)``. ``test_loss_list`` has ``epochs + 1`` entries,
        the first being the loss before training.

    Raises:
        ValueError: if ``N_test`` is not positive or ``N_train`` exceeds the available training events.
    """
    if N_test <= 0:
        raise ValueError(f"N_test must be positive, got {N_test}")
    n_available = len(conditions) - N_test
    if N_train > n_available:
        raise ValueError(f"N_train={N_train} exceeds the {n_available} events left after the test split")

    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Training pool: the first N_train events before the test split (later ones are never used).
    cond_train = conditions[:-N_test][:N_train]
    t1s_train = t1s[:-N_test][:N_train]
    zs_train = zs[:-N_test][:N_train]
    # Test set: the last N_test events.
    cond_test = conditions[-N_test:]
    t1s_test = t1s[-N_test:]
    zs_test = zs[-N_test:]

    N_data_loops = N_train // N_batch  # full batches per epoch

    @eqx.filter_jit
    def make_step(model, opt_state: PyTree, key, conds, t1s, zs):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, key, conds, t1s, zs, N_samples=N_samples)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def eval_test(model, key):
        return loss(model, key, cond_test, t1s_test, zs_test, N_samples=N_samples)

    loop = trange(epochs)
    best_model = model
    train_loss_list = []
    # Every random draw gets its own subkey, so no key is both consumed and later split.
    key, test_key = jax.random.split(key, 2)
    test_loss_list = [eval_test(model, test_key)]
    for _ in loop:
        key, perm_key = jax.random.split(key, 2)
        perm = jax.random.permutation(perm_key, N_train)
        cond_epoch = cond_train[perm]
        t1s_epoch = t1s_train[perm]
        zs_epoch = zs_train[perm]
        for j in range(N_data_loops):
            key, step_key = jax.random.split(key, 2)
            batch = slice(j * N_batch, (j + 1) * N_batch)
            model, opt_state, train_loss = make_step(
                model, opt_state, step_key, cond_epoch[batch], t1s_epoch[batch], zs_epoch[batch]
            )
            train_loss_list.append(train_loss)
            loop.set_postfix({
                'loss': f'{train_loss_list[-1]:0.2f}',
                'loss MA': f'{np.mean(train_loss_list[-64:]):0.3f}',
                'test loss': f'{test_loss_list[-1]:0.3f}',
            })
        key, test_key = jax.random.split(key, 2)
        test_loss_list.append(eval_test(model, test_key))
        # np.argmin returns the first minimum, so ties keep the earlier model.
        if np.argmin(test_loss_list) == len(test_loss_list) - 1:
            best_model = model
    if use_best:
        model = best_model
    return model, train_loss_list, test_loss_list

In [None]:
key, train_key = jax.random.split(key, 2)
# optax_sched = optax.linear_schedule(init_value=2e-3, end_value=1e-4, transition_steps=20)
# Piecewise-constant learning rate: 2e-3, then 2e-4 from boundary 25, then 2e-5 from boundary 30.
# NOTE(review): join_schedules boundaries count optimizer *updates*, not epochs. Here there are
# N_train // N_batch = 200000 // 2048 = 97 mini-batches per epoch, and MultiSteps applies the inner
# AdamW update only every 4 mini-batches, so there are about 24 updates per epoch. The LR therefore
# drops after about 1 epoch and stays at 2e-5 for almost all 100 epochs — confirm this is intended.
optax_sched = optax.join_schedules([
    optax.constant_schedule(2e-3), 
    optax.constant_schedule(2e-4), 
    optax.constant_schedule(2e-5)
], [25, 30])
# AdamW with weight decay. Gradients are accumulated over 4 mini-batches (MultiSteps), and updates
# containing non-finite values are skipped, up to max_consecutive_errors=4 times in a row.
optimizer = optax.adamw(learning_rate=optax_sched, weight_decay=1e-4)
optimizer = optax.apply_if_finite(optax.MultiSteps(optimizer, every_k_schedule=4), max_consecutive_errors=4)
# 100 epochs on the first 200k non-test selected events (the last 20k are held out), 16 flow samples
# per event; with use_best=True the model with the lowest test loss is returned.
trained_model, train_loss, test_loss = train(
    train_key,
    model,
    optimizer,
    100,  # epochs
    cond_sel,
    z_sel_scaled,  # t1s: drift-ODE end times
    z_sel,  # zs: depths used for the CIV lookup
    N_train=200000,
    N_batch=2048,
    N_samples=16,
    N_test=20000,
    use_best=True
)

In [None]:
fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111)
ax.plot(np.linspace(0, len(test_loss), len(train_loss)), train_loss)
ax.plot(test_loss)
plt.show()

In [None]:
fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111)
# ax.plot(np.linspace(0, len(test_loss), len(train_loss)), train_loss)
ax.plot(test_loss[2:])
plt.show()

In [None]:
print(np.min(test_loss))

In [None]:
# Fine-tune the returned (best) model at a constant LR of 1e-5 for 20 epochs, with gradients
# accumulated over 4 mini-batches.
key, train_key = jax.random.split(key, 2)
optimizer = optax.adamw(learning_rate=1e-5, weight_decay=1e-4)
optimizer = optax.apply_if_finite(optax.MultiSteps(optimizer, every_k_schedule=4), max_consecutive_errors=4)
trained_model, train_loss, test_loss = train(
    train_key,
    trained_model,
    optimizer,
    20,  # epochs
    cond_sel,
    z_sel_scaled,
    z_sel,
    N_train=200000,
    N_batch=2048,
    N_samples=16,
    N_test=20000,
    use_best=True
)

In [None]:
# Test loss of this fine-tuning run only (train returns fresh loss lists on every call).
fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111)
# ax.plot(np.linspace(0, len(test_loss), len(train_loss)), train_loss)
ax.plot(test_loss)
plt.show()

In [None]:
# Second fine-tuning pass: same LR, 40 epochs, gradients accumulated over 8 mini-batches.
key, train_key = jax.random.split(key, 2)
optimizer = optax.adamw(learning_rate=1e-5, weight_decay=1e-4)
optimizer = optax.apply_if_finite(optax.MultiSteps(optimizer, every_k_schedule=8), max_consecutive_errors=4)
trained_model, train_loss, test_loss = train(
    train_key,
    trained_model,
    optimizer,
    40,  # epochs
    cond_sel,
    z_sel_scaled,
    z_sel,
    N_train=200000,
    N_batch=2048,
    N_samples=16,
    N_test=20000,
    use_best=True
)

In [None]:
fig = plt.figure(figsize=(6,5))
ax = fig.add_subplot(111)
# ax.plot(np.linspace(0, len(test_loss), len(train_loss)), train_loss)
ax.plot(test_loss)
plt.show()

In [None]:
# Jitted versions of the pretrained flow's bijection (the map flowjax applies to base samples).
# compiled_flow_transform is not used in the cells shown here.
compiled_flow_transform = eqx.filter_jit(flow.bijection.transform)
vec_flow_transform = eqx.filter_jit(eqx.filter_vmap(flow.bijection.transform))

In [None]:
N_plot = 80000  # number of selected events used in the diagnostic plots

In [None]:
# "Naive" positions: the flow's image of the base-distribution origin (0, 0) for each event's
# condition, i.e. one deterministic position per event before any CNF correction.
naive_xy_sel = data_inv_transformation(vec_flow_transform(np.zeros((N_plot,2)), condition=cond_sel[:N_plot]))

In [None]:
naive_r_sel = compute_r(naive_xy_sel)

In [None]:
# Apply the trained CNF (extraction, then drift to ODE time t1) to a batch of points.
vec_transformation = eqx.filter_vmap(lambda a,b: trained_model.transform(y=a, t1=b))

In [None]:
xy_sel = vec_transformation(naive_xy_sel[:N_plot], z_sel_scaled[:N_plot])

In [None]:
r_sel = compute_r(xy_sel)

In [None]:
naive_xy_sel[:10]

In [None]:
xy_sel[:10]

In [None]:
def plot_wires_and_tpc(ax, color='k', tpc_r=66.4, linewidth=1, wire_offsets=(28.3, 31.8, 26.3), n_circle=200):
    """Overlay the TPC boundary circle and dashed wire lines on ``ax``.

    The boundary is a circle of radius ``tpc_r`` drawn with ``n_circle`` points. For every offset
    ``c`` in ``wire_offsets``, two dashed lines ``y = sqrt(3) * x + c`` and ``y = sqrt(3) * x - c``
    are drawn across ``x`` in ``[-10 * tpc_r, 10 * tpc_r]``. The defaults reproduce the previously
    hard-coded lines, in the same drawing order.

    Args:
        ax: matplotlib axes (anything with a compatible ``plot`` method).
        color: line colour for all elements.
        tpc_r: boundary radius.
        linewidth: line width for all elements.
        wire_offsets: intercept magnitudes of the wire lines.
        n_circle: number of points used for the circle.
    """
    theta = np.linspace(0, 2*np.pi, n_circle)
    ax.plot(tpc_r*np.cos(theta), tpc_r*np.sin(theta), color=color, linewidth=linewidth)
    x = np.array([-10*tpc_r, 10*tpc_r])
    for offset in wire_offsets:
        for sign in (1, -1):
            ax.plot(x, np.sqrt(3)*x + sign*offset, color=color, linestyle='--', linewidth=linewidth)

In [None]:
# Jitted batched CIV lookup: rows of (r, z) in, one length-1 result per row out (hence the [:,0] below).
c_vec_civ_map = jax.jit(vec_civ_map)

In [None]:
# Left: R^2–z histogram of the positions after the trained CNF transform, with the 0.5 contour of the
# CIV survival probability. Right: x–y histogram of the same positions with the TPC wall and wires.
H, xbins, ybins = np.histogram2d(
    r_sel**2, z_sel[:N_plot], 
    bins=(np.linspace(0, 70**2, 100), np.linspace(-148.6515, 0, 81))
)

# civ_H[i, j]: survival probability at R = sqrt(xbins[i]) and z = ybins[j]. xbins are R^2 bin edges,
# so the loop variable r below is an R^2 value, hence np.sqrt(r).
civ_H = np.zeros((len(xbins), len(ybins)))

for i,r in enumerate(tqdm(xbins)):
        civ_H[i] = c_vec_civ_map(np.concatenate((np.repeat([np.sqrt(r)], len(ybins))[:,np.newaxis], ybins[:,np.newaxis]), axis=1))[:,0]
        
fig = plt.figure(figsize=(9,4), tight_layout=True)
ax = fig.add_subplot(121)
ax.pcolormesh(xbins, ybins, H.T)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='white')
ax.set(xlabel='R^2 (cm^2)', ylabel='z (cm)')

H, xbins, ybins = np.histogram2d(
    xy_sel[:N_plot,0], xy_sel[:N_plot,1], 
    bins=(np.linspace(-85, 85, 100), np.linspace(-75, 75, 100))
)
ax = fig.add_subplot(122)
ax.axis('equal')
ax.pcolormesh(xbins, ybins, H.T)
plot_wires_and_tpc(ax, color='w')
ax.set(xlabel='x (cm)', ylabel='y (cm)', xlim=[-75,75], ylim=[-70,70])
plt.show()

In [None]:
# Same two panels for the naive positions (flow image of the base-distribution origin, before the CNF).
# civ_H is recomputed on identical bins, so it equals the previous cell's civ_H.
H, xbins, ybins = np.histogram2d(
    naive_r_sel**2, z_sel[:N_plot], 
    bins=(np.linspace(0, 70**2, 100), np.linspace(-148.6515, 0, 81))
)

civ_H = np.zeros((len(xbins), len(ybins)))

for i,r in enumerate(tqdm(xbins)):
        civ_H[i] = c_vec_civ_map(np.concatenate((np.repeat([np.sqrt(r)], len(ybins))[:,np.newaxis], ybins[:,np.newaxis]), axis=1))[:,0]


fig = plt.figure(figsize=(9,4), tight_layout=True)
ax = fig.add_subplot(121)
ax.pcolormesh(xbins, ybins, H.T)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='white')
ax.set(xlabel='R^2 (cm^2)', ylabel='z (cm)')

H, xbins, ybins = np.histogram2d(
    naive_xy_sel[:N_plot,0], naive_xy_sel[:N_plot,1], 
    bins=(np.linspace(-85, 85, 100), np.linspace(-75, 75, 100))
)
ax = fig.add_subplot(122)
# NOTE(review): unlike the previous figure, this x–y panel does not call ax.axis('equal').
ax.pcolormesh(xbins, ybins, H.T)
plot_wires_and_tpc(ax, color='w')
ax.set(xlabel='x (cm)', ylabel='y (cm)', xlim=[-75,75], ylim=[-70,70])
plt.show()

In [None]:
# Extraction-field map: func_extract evaluated at points (0, r) on the y axis for ODE times
# t = -z_plot in [0, 10], the extraction interval [t0, extract_t1] of the CNF.
# NOTE(review): component [0] (x) is taken at points (0, r), which is the tangential component there,
# whereas the drift-field maps below take the radial component ([1] at (0, r), [0] at (r, 0)).
# Confirm which is intended. The plot also divides by z_scale although the extraction time here is not
# scaled by z_scale; confirm that too.
xbins, ybins=(np.linspace(-70, 70, 100), np.linspace(-10, 0, 81))

flow_field = np.zeros((len(xbins), len(ybins)))

for i,r in enumerate(tqdm(xbins)):
    for j,z_plot in enumerate(ybins):
        flow_field[i,j] = trained_model.func_extract(-z_plot, jnp.array([0,r]), 0)[0]*np.sign(r)
        
fig = plt.figure(figsize=(6,4), tight_layout=True)
ax = fig.add_subplot(111)
c = ax.pcolormesh(xbins, ybins, flow_field.T/z_scale, cmap=colormaps['bwr'], vmin=-0.2, vmax=0.2)
# ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='white')
plt.colorbar(c)
ax.set(xlabel='R (cm)', ylabel='z (AU)')

In [None]:
# Mean of the sampled extraction field times the 10-unit extraction interval; presumably a rough
# average net displacement during extraction — confirm interpretation.
jnp.mean(flow_field)*10

In [None]:
xbins, ybins=(np.linspace(-70, 70, 100), np.linspace(-tpc_height, 0, 81))
flow_field = np.zeros((len(xbins), len(ybins)))

for i,r in enumerate(tqdm(xbins)):
    for j,z_plot in enumerate(ybins):
        flow_field[i,j] = trained_model.func_drift(-z_plot/z_scale, jnp.array([0,r]), 0)[1]*np.sign(r)

field_line_arr = np.linspace(-tpc_r, tpc_r, 20)
lines_arr = np.zeros((len(field_line_arr), len(ybins)))
for i, z_line in enumerate(tqdm(ybins)):
    lines_arr[:,i] = vec_transformation(
    np.concatenate((np.zeros_like(field_line_arr)[:,np.newaxis], field_line_arr[:,np.newaxis]), axis=1), 
    np.repeat([-z_line/z_scale], len(field_line_arr)))[:,1]
        
fig = plt.figure(figsize=(4,4), tight_layout=True)
ax = fig.add_subplot(111)
c = ax.pcolormesh(xbins, ybins, flow_field.T/z_scale, cmap=colormaps['bwr'], vmin=-0.2, vmax=0.2)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='white')
ax.contour(-xbins, ybins, civ_H.T, levels=[0.5], colors='white')
for i in range(len(field_line_arr)):
    ax.plot(lines_arr[i], ybins, color='white', linestyle='--')
plt.colorbar(c, label='Ratio of horizontal to vertical drift velocity')
ax.set(xlabel='y (cm)', ylabel='z (cm)', xlim=[-tpc_r, tpc_r], ylim=[-tpc_height, 0])

In [None]:
fig = plt.figure(figsize=(5,4), tight_layout=True)
ax = fig.add_subplot(111)
c = ax.pcolormesh(xbins, ybins, flow_field.T/z_scale, cmap=colormaps['bwr'], vmin=-0.2, vmax=0.2)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='black')
ax.contour(-xbins, ybins, civ_H.T, levels=[0.5], colors='black')
for i in range(len(field_line_arr)):
    ax.plot(lines_arr[i], ybins, color='grey', linestyle='--')
plt.colorbar(c, label='Ratio of horizontal to vertical drift velocity')
ax.set(xlabel='y (cm)', ylabel='z (cm)', xlim=[-tpc_r, tpc_r], ylim=[-tpc_height, 0])
plt.show()

In [None]:
xbins, ybins=(np.linspace(-70, 70, 100), np.linspace(-tpc_height, 0, 81))
flow_field = np.zeros((len(xbins), len(ybins)))

for i,r in enumerate(tqdm(xbins)):
    for j,z_plot in enumerate(ybins):
        flow_field[i,j] = trained_model.func_drift(-z_plot/z_scale, jnp.array([r,0]), 0)[0]*np.sign(r)

field_line_arr = np.linspace(-tpc_r, tpc_r, 20)
lines_arr = np.zeros((len(field_line_arr), len(ybins)))
for i, z_line in enumerate(tqdm(ybins)):
    lines_arr[:,i] = vec_transformation(
    np.concatenate((field_line_arr[:,np.newaxis], np.zeros_like(field_line_arr)[:,np.newaxis]), axis=1), 
    np.repeat([-z_line/z_scale], len(field_line_arr)))[:,0]
        
fig = plt.figure(figsize=(4,4), tight_layout=True)
ax = fig.add_subplot(111)
c = ax.pcolormesh(xbins, ybins, flow_field.T/z_scale, cmap=colormaps['bwr'], vmin=-0.2, vmax=0.2)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='white')
ax.contour(-xbins, ybins, civ_H.T, levels=[0.5], colors='white')
for i in range(len(field_line_arr)):
    ax.plot(lines_arr[i], ybins, color='white', linestyle='--')
plt.colorbar(c, label='Ratio of horizontal to vertical drift velocity')
ax.set(xlabel='x (cm)', ylabel='z (cm)', xlim=[-tpc_r, tpc_r], ylim=[-tpc_height, 0])

In [None]:
fig = plt.figure(figsize=(5,4), tight_layout=True)
ax = fig.add_subplot(111)
c = ax.pcolormesh(xbins, ybins, flow_field.T/z_scale, cmap=colormaps['bwr'], vmin=-0.2, vmax=0.2)
ax.contour(xbins, ybins, civ_H.T, levels=[0.5], colors='black')
ax.contour(-xbins, ybins, civ_H.T, levels=[0.5], colors='black')
for i in range(len(field_line_arr)):
    ax.plot(lines_arr[i], ybins, color='grey', linestyle='--')
plt.colorbar(c, label='Ratio of horizontal to vertical drift velocity')
ax.set(xlabel='x (cm)', ylabel='z (cm)', xlim=[-tpc_r, tpc_r], ylim=[-tpc_height, 0])
plt.show()

In [None]:
lines_arr.shape  # (number of start points, number of depths)

In [None]:
# Curl penalty of the untrained model at point (50, 30) and drift time 5...
curl_loss(key, model, 5, np.array([50,30.]))

In [None]:
# ...and of the trained model at the same point, for comparison.
curl_loss(key, trained_model, 5, np.array([50,30.]))

In [None]:
# Mean of the x–z drift field times half the TPC height; presumably a rough average radial
# displacement estimate. NOTE(review): flow_field is not divided by z_scale here (the plots divide
# it), so confirm whether /z_scale is needed.
jnp.mean(flow_field)*tpc_height/2