In [1]:
from pyoculus.problems import CylindricalBfield
from pyoculus.solvers import PoincarePlot
import matplotlib.pyplot as plt
import numpy as np

### Analytically deriving the toy model and perturbation

$$
\text{As}\quad\vec{\nabla}\cdot\vec{B} = 0\quad\text{it get to}\quad\vec{B} = \vec{\nabla} \times \vec{A}\quad\text{thus with}\quad\vec{A} = (A_r, A_\phi, A_z)\\
$$
$$
\vec{B} = (-\partial_z A_\phi,\, \partial_z A_r - \partial_r A_z,\, 1/r\cdot\partial_z(r A_\phi))
$$

In [None]:
import sympy as sp

In [None]:
r, phi, z, d, m, n, R, sf, shear = sp.symbols('r phi z d m n R sf shear', real=True, positive=True)
rr = sp.symbols('rr[0] rr[1] rr[2]')
I = sp.I  # Imaginary unit

replacement_dict = {r: rr[0], phi: rr[1], z: rr[2], I: 1j}

# Define b
b = sp.sqrt((R - r)**2 + z**2)

# Define Psi
Psi = b**2

# Define F
F = 2 * (sf + shear * b**2) * sp.sqrt(R**2 - b**2)

# Define the function
psi_mb = sp.sqrt(2/sp.pi) * (b**2/d**3) * sp.exp(-b**2/(2*d**2)) * ((r - R) + I*z)**m * sp.exp(I*n*phi)
psi_gaussian = (1/sp.sqrt(2*sp.pi*d**2)) * sp.exp(-b**2/(2*d**2)) * ((r - R) + I*z)**m * sp.exp(I*n*phi)

In [None]:
psi_mb

In [None]:
psi_gaussian

In [None]:
B_equilibrium = sp.Matrix([-1/r * sp.diff(Psi, z), F/r, 1/r * sp.diff(Psi, r)])
B_equilibrium

In [None]:
B_mb = sp.simplify(sp.re(sp.Matrix([-1/r * sp.diff(psi_mb, z), 0, 1/r * sp.diff(psi_mb, r)])))
B_gaussian = sp.simplify(sp.re(sp.Matrix([-1/r * sp.diff(psi_gaussian, z), 0, 1/r * sp.diff(psi_gaussian, r)])))

Transform the $\vec{B}$ expressions from `sympy` to actual functions

In [None]:
temp1 = sp.lambdify((rr, R, sf, shear), B_equilibrium.subs(replacement_dict), 'numpy')
toroidal = lambda *args: np.squeeze(temp1(*args))

temp2 = sp.lambdify((rr, R, d, m, n), B_mb.subs(replacement_dict), 'numpy')
mb_perturbation = lambda *args: np.squeeze(temp2(*args))

temp3 = sp.lambdify((rr, R, d, m, n), B_gaussian.subs(replacement_dict), 'numpy')
gaussian_perturbation = lambda *args: np.squeeze(temp3(*args))

In [None]:
B_equilibrium

Compute the jacobians

In [None]:
grad_B_equilibrium = sp.zeros(3)

for i in range(3):
    grad_B_equilibrium[i, 0] = sp.diff(B_equilibrium[i], r)
    grad_B_equilibrium[i, 1] = 1/r * sp.diff(B_equilibrium[i], phi)
    grad_B_equilibrium[i, 2] = sp.diff(B_equilibrium[i], z)


grad_B_mb = sp.zeros(3)

for i in range(3):
    grad_B_mb[i, 0] = sp.diff(B_mb[i], r)
    grad_B_mb[i, 1] = 1/r * sp.diff(B_mb[i], phi)
    grad_B_mb[i, 2] = sp.diff(B_mb[i], z)

grad_B_gaussian = sp.zeros(3)

for i in range(3):
    grad_B_gaussian[i, 0] = sp.diff(B_gaussian[i], r)
    grad_B_gaussian[i, 1] = 1/r * sp.diff(B_gaussian[i], phi)
    grad_B_gaussian[i, 2] = sp.diff(B_gaussian[i], z)

In [None]:
grad_B_equilibrium = sp.simplify(grad_B_equilibrium)
grad_B_mb = sp.simplify(grad_B_mb)
grad_B_gaussian = sp.simplify(grad_B_gaussian)

#### Printing in python format

In [None]:
sp.printing.python(B_gaussian.subs(replacement_dict))

### Numerical adventure

In [30]:
import sys
import os
sys.path.append('..')

from toybox import *

In [3]:
def gaussian_psi(rr, R=3.0, d=0.1, m=2, n=1):
    return ((rr[0] - R + rr[2]*1j)**m)*np.exp(-0.5*((-rr[0] + R)**2 + rr[2]**2)/d**2 + 1j*n*0)/(np.sqrt(d**2)*np.sqrt(2*np.pi))

def mb_psi(rr, R=3.0, d=0.1, m=2, n=1):
    return np.exp(-0.5*((-rr[0] + R)**2 + rr[2]**2)/d**2 + n*rr[1]*1j)*np.sqrt(2/np.pi)*(rr[0] - R + rr[2]*1j)**m*((-rr[0] + R)**2 + rr[2]**2)/d**3

In [75]:
import jax.numpy as jnp
from jax import jacfwd
from functools import partial

PERT_TYPES_DICT = {"maxwell-boltzmann": pert_maxwellboltzmann, "gaussian": pert_gaussian}

class AnalyticCylindricalBfield(CylindricalBfield):
    """
    
    """
    def __init__(
        self,
        R,
        sf,
        shear,
        perturbations_args
    ):
        """
        pert1_dict = {m:1, n:2, d:1, type: "maxwell-boltzmann", amplitude: 10}
        pert2_dict = {m:1, n:2, d:1, type: "gaussian", amplitude: 10}
        myBfield = AnalyticalBfieldProblem(R= 3, sf = 1.1, shear=3 pert=[pert1_dict, pert2_dict])
        myBfield.pert_list[0](rphiz)
        >> value
        """

        # Define the equilibrium field and its gradient
        self.B_equilibrium = partial(equ_squared, R=R, sf=sf, shear=shear)
        self.dBdX_equilibrium = lambda rr: jacfwd(self.B_equilibrium)(rr)

        # Define the perturbations and their gradients
        self._perturbations = [None] * len(perturbations_args)
        for pertdic in perturbations_args:
            pertdic.update({"R": R})

        self.perturbations_args = perturbations_args
        self._initialize_perturbations()

        # Call the CylindricalBfield constructor that needs (R,Z) of the axis 
        super().__init__(R, 0)

    @property
    def amplitudes(self):
        return [pert["amplitude"] for pert in self.perturbations_args]
    
    @amplitudes.setter
    def amplitudes(self, value):
        self.amplitudes = value
        self._initialize_perturbations()

    def set_amplitude(self, index, value):
        self.amplitudes[index] = value
        self._initialize_perturbations(index)

    def _initialize_perturbations(self, index = None):
        if index is not None:
            indices = [index]
        else:
            indices = range(len(self.perturbations_args))

        for i in indices:
            tmp_args = self.perturbations_args[i].copy()
            tmp_args.pop("amplitude")
            tmp_args.pop("type")

            self._perturbations[i] = partial(PERT_TYPES_DICT[self.perturbations_args[i]["type"]], **tmp_args)
        
        self.B_perturbation = lambda rr: jnp.sum(jnp.array([pertdic["amplitude"] * self._perturbations[i](rr) for i, pertdic in enumerate(self.perturbations_args)]), axis=0)
        self.dBdX_perturbation = lambda rr: jacfwd(self.B_perturbation)(rr)

    @property
    def perturbations(self):
        return [lambda rr: pertdic["amplitude"] * self._perturbations[i](rr) for i, pertdic in enumerate(self.perturbations_args)]

    # BfieldProblem methods implementation 
    def B(self, rr):
        return self.B_equilibrium(rr) + self.B_perturbation(rr)

    def dBdX(self, rr):
        return self.dBdX_equilibrium(rr) + self.dBdX_perturbation(rr)

    # def B_many(self, r, phi, z, input1D=True):
    #     return jnp.array([self.B([r[i], phi[i], z[i]]) for i in range(len(r))])

    # def dBdX_many(self, r, phi, z, input1D=True):
    #     return jnp.array([self.dBdX([r[i], phi[i], z[i]]) for i in range(len(r))])


In [76]:
pert_1 = {"m": 1, "n": -1, "d": 1, "type": "maxwell-boltzmann", "amplitude": 0.2}
pert_2 = {"m": 1, "n": 0, "d": 1, "type": "gaussian", "amplitude": -1.5}

ps = AnalyticCylindricalBfield(3, 3/3, 0.7/3, [pert_1, pert_2])

In [77]:
ps.amplitudes

[0.2, -1.5]

In [79]:
jnp.sum(jnp.array([pertdic["amplitude"] * ps._perturbations[i]([1,1,1]) for i, pertdic in enumerate(ps.perturbations_args)]), axis=0)

Array([0.03373277, 0.        , 0.16395476], dtype=float32)

In [80]:
ps.perturbations[0]([1,1,1])

Array([ 0.09824152, -0.        ,  0.14736232], dtype=float32)

In [81]:
ps.B_perturbation([1, 1, 1])

Array([0.03373277, 0.        , 0.16395476], dtype=float32)

In [82]:
ps.B([1, 1, 1])

Array([-1.9662672,  8.666666 , -3.8360453], dtype=float32)

In [None]:
def perturbation(rr, R, d, m, n):
    return -15*tb.pert_gaussian(rr, R, 1., 1, 0) + 2*tb.pert_maxwellboltzmann(rr, R, 1, 1, -1)

def perturbation_dBdX(rr, R, d, m, n):
    return tb.pert_gaussian_dBdX(rr, R, 0.1, 1, 0).reshape(3,3) + tb.pert_maxwellboltzmann_dBdX(rr, R, d, m, n).reshape(3,3)

# def perturbation(rr, R, d, m, n):
#     return tb.pert_gaussian(rr, R, 0.1, 1, 0)

# def perturbation_dBdX(rr, R, d, m, n):
#     return tb.pert_gaussian_dBdX(rr, R, 0.1, 1, 0).reshape(3,3)

# def perturbation(rr, R, d, m, n):
#     return tb.pert_maxwellboltzmann(rr, R, d, m, n) 

# def perturbation_dBdX(rr, R, d, m, n):
#     return tb.pert_maxwellboltzmann_dBdX(rr, R, d, m, n).reshape(3,3)

# def perturbation(rr, R, d, m, n):
#     return tb.pert_gaussian(rr, R, d, m, n)

# def perturbation_dBdX(rr, R, d, m, n):
#     return tb.pert_gaussian_dBdX(rr, R, d, m, n).reshape(3,3)

equ_args = {
    "R": 3.0,
    "shear": 3/3,
    "sf": 0.7/3,
}

pert_args = {
    "R": 3.0,
    "d": 1,
    "m": 2,
    "n": -1,
}

ps = tb.AnalyticCylindricalBfield([tb.equ_squared, tb.equ_squared_dBdX], [perturbation, perturbation_dBdX], A_p=0.1, equilibrium_args=equ_args, perturbation_args=pert_args)

In [None]:
fig, axs = plt.subplots(1,4, figsize=(20, 5))

r = np.linspace(2, 5, 200)
z = np.linspace(-2, 2, 200)

R, Z = np.meshgrid(r, z)
Bs = np.array([ps.B([r, 0.0, z]) for r, z in zip(R.flatten(), Z.flatten())]).reshape(
    R.shape + (3,)
)
mappable = axs[2].contourf(R, Z, np.linalg.norm(Bs, axis=2))
fig.colorbar(mappable)

Bs = np.array([ps.B_equilibrium([r, 0.0, z]) for r, z in zip(R.flatten(), Z.flatten())]).reshape(
    R.shape + (3,)
)
mappable = axs[3].contourf(R, Z, np.linalg.norm(Bs, axis=2))
fig.colorbar(mappable)

r = np.linspace(2, 5, 80)
z = np.linspace(-2, 2, 100)

R, Z = np.meshgrid(r, z)
Bs = np.array([ps._A_p*ps.B_perturbation([r, 0.0, z]) for r, z in zip(R.flatten(), Z.flatten())]).reshape(
    R.shape + (3,)
)
mappable = axs[1].contourf(R, Z, np.linalg.norm(Bs, axis=2))
fig.colorbar(mappable)

psi = np.array([mb_psi([r, 0.0, z], ps.perturbation_args['R'], ps.perturbation_args['d'], ps.perturbation_args['m'], ps.perturbation_args['n']) for r, z in zip(R.flatten(), Z.flatten())]).reshape(
    R.shape
)
psi = np.real(psi)
mappable = axs[0].contourf(R, Z, psi)
axs[0].quiver(R, Z, Bs[:, :, 0], Bs[:, :, 2])
# fig.colorbar(mappable)

# Set the aspect equal
for ax in axs:
    ax.set_aspect("equal")
    ax.scatter(ps._R0, 0, color="r", s=1)

# Add circle to the plot
# crop_circle(ps._R0, ps._d, ax, color='r')
# crop_circle(ps._R0, np.sqrt(2)*ps._d, ax, color='r')

In [None]:
# set up the integrator
iparams = dict()
iparams["rtol"] = 1e-7

# set up the Poincare plot
pparams = dict()
pparams["Rbegin"] = 2.98
pparams["Rend"] = 3.6
pparams["nPtrj"] = 20
pparams["nPpts"] = 500
pparams["zeta"] = 0

# Set RZs
nfieldlines = pparams["nPtrj"]+1
Rs = np.linspace(3.0, 3.75, nfieldlines)
Zs = np.linspace(0., 0., nfieldlines)
RZs = np.array([[r, z] for r, z in zip(Rs, Zs)])

pplot = PoincarePlot(ps, pparams, integrator_params=iparams)
pdata = pplot.compute(RZs)
# pdata = pplot.compute()

In [None]:
pplot.plot(marker=".", s=1)
plt.xlim(2.6, 3.5)
plt.scatter(ps._R0, 0, color="r", s=4)
fig = plt.gcf()

In [None]:
pplot.compute_iota()
pplot.plot_q()

plt.hlines(np.abs(ps.perturbation_args['m']/ps.perturbation_args['n']), 3, 4, color="k", linestyle="--")
# plt.hlines(ps._n/ps._m, 3, 3.7, color="k", linestyle="--")

# rho = np.linspace(3.001, 3.7, 20)
# q=ps._sf+ps._shear*(ps._R0-rho)**2

# plt.plot(rho, q, color="orange", linestyle="--")

In [None]:
from pyoculus.solvers import FixedPoint

# set up the integrator
iparams = dict()
iparams["rtol"] = 1e-22

pparams = dict()
pparams["nrestart"] = 0
pparams['niter'] = 200

fp = FixedPoint(ps, pparams, integrator_params=iparams)

# fp.compute(guess=[3.5, 0], pp=1, qq=2, sbegin=2, send=4, tol = 1e-12)
fp.compute(guess=[3., -0.6], pp=1, qq=2, sbegin=2, send=4, tol = 1e-12)

In [None]:
results = [list(p) for p in zip(fp.x, fp.y, fp.z)]
results

In [None]:
ax = fig.get_axes()[0]
ax.scatter(results[0][0], results[0][2], color="b", s=5, marker="X") 
ax.scatter(results[1][0], results[1][2], color="b", s=5, marker="X")
fig

## with ivp

In [None]:
from scipy.integrate import solve_ivp

In [None]:
def integrate_ivp(Bfield_2D, RZstart, phis, **kwargs):
    options = {
        "rtol": 1e-7,
        "atol": 1e-9,
        "nintersect": 400,
        "method": "DOP853",
        "direction": 1,
        "m": 1,
        "nfp": 1,
    }
    options.update(kwargs)

    assert RZstart.shape[1] == 2, "RZstart must be a 2D array with shape (n, 2)"
    assert len(phis) > 0, "phis must be a list of floats with at least one element"
    assert isinstance(options["nintersect"], int) and options["nintersect"] > 0, "nintersect must be a positive integer"
    assert options["direction"] in [-1, 1], "direction must be either -1 or 1"
    
    # setup the phis of the poincare sections
    phis = np.unique(np.mod(phis, 2 * np.pi / options['nfp']))
    phis.sort()

    # setup the evaluation points for those sections
    phi_evals = np.array(
        [
            phis + options['m'] * 2 * np.pi * i / options['nfp']
            for i in range(options["nintersect"] + 1)
        ]
    )

    out = solve_ivp(
        Bfield_2D,
        [0, phi_evals[-1, -1]],
        RZstart.flatten(),
        t_eval=phi_evals.flatten(),
        method=options["method"],
        atol=options["atol"],
        rtol=options["rtol"],
    )

    return out

In [None]:
def Bfield_2D(t, rzs, direction = 1):
    rzs = rzs.reshape((-1, 2))
    phis = direction*(t % (2 * np.pi)) * np.ones(rzs.shape[0])
    Bs = np.array([ps.B([rzs[i, 0], phis[i], rzs[i, 1]]) for i in range(len(rzs))])

    return np.array([Bs[:,0]/Bs[:,1], Bs[:,2]/Bs[:,1]]).T.flatten()

In [None]:
nfieldlines = 1
Rs = np.linspace(3, 3, nfieldlines)
Zs = np.linspace(0.01, 0.5, nfieldlines)
RZs = np.array([[r, z] for r, z in zip(Rs, Zs)])

In [None]:
out = integrate_ivp(Bfield_2D, RZs, [0])

In [None]:
ys = out.y.reshape(nfieldlines, 2, -1)
for yy in ys:
    plt.scatter(yy[0, :], yy[1, :], s=10, marker=".")

## The tokamap

In [None]:
def P(psi, theta, K):
    return psi - 1 - K/(2*np.pi) * np.sin(2*np.pi*theta)

def W(psi, w):
    return w/4*(2-psi)*(2-2*psi+psi**2)

In [None]:
def tokamap(psi, theta, K, w):
    p = P(psi, theta, K)
    psi_evolved = 0.5*(p + np.sqrt(p**2 + 4*psi))
    theta_evolved = theta + W(psi_evolved, w) - K/((2*np.pi)**2)/((1+psi_evolved)**2)*np.cos(2*np.pi*theta)
    return np.array([
        psi_evolved,
        theta_evolved % 1
    ])

In [None]:
x = np.array([0., 0.5])

In [None]:
x = tokamap(x[0], x[1], 3.7, 1.)
print(x)

In [None]:
xi = np.linspace(0, 1, 7)
Xi = np.meshgrid(xi, xi)
Xi = np.array(Xi).reshape(2, -1)

In [None]:
nev = 1000
Ev = np.empty((Xi.shape[0], Xi.shape[1], nev))

In [None]:
K = 3.7
w = 1.0

for i, x in enumerate(Xi.T):
    Ev[:, i, 0] = x
    xt = x.copy()
    for j in range(nev-1):
        xt = tokamap(xt[0], xt[1], K, w)
        Ev[:, i, j+1] = xt

In [None]:
for i in range(len(Ev[0,:,0])):
    plt.scatter(Ev[1, i, :], Ev[0, i, :], s=0.5, alpha=0.5)