## Field-line tracing with scipy's `solve_ivp`

In [None]:
from scipy.integrate import solve_ivp

def integrate_ivp(Bfield_2D, RZstart, phis, **kwargs):
    """Integrate field lines in the toroidal angle and sample them on sections.

    All field lines are advanced together as one flattened state vector
    ``[R0, Z0, R1, Z1, ...]`` with ``scipy.integrate.solve_ivp``, using the
    toroidal angle as the independent variable.

    Parameters
    ----------
    Bfield_2D : callable
        Right-hand side handed to ``solve_ivp``. It is called as
        ``Bfield_2D(t, y)`` for ``direction=1`` and as
        ``Bfield_2D(t, y, direction)`` otherwise.
    RZstart : array_like, shape (n, 2)
        Starting ``(R, Z)`` pair of each field line.
    phis : array_like
        Angles of the sections. They are reduced modulo ``2*pi/nfp`` and
        de-duplicated before use.
    **kwargs
        Overrides for the defaults below: ``rtol``, ``atol``, ``nintersect``
        (number of returns to each section), ``method``, ``direction``
        (``1`` or ``-1``), ``m`` and ``nfp`` (consecutive evaluations of one
        section are ``m * 2*pi / nfp`` apart).

    Returns
    -------
    The result object of ``solve_ivp``; ``t`` holds the evaluation angles and
    ``y`` the flattened state at each of them.

    Raises
    ------
    AssertionError
        If ``RZstart`` is not of shape ``(n, 2)``, ``phis`` is empty,
        ``nintersect`` is not a positive integer, or ``direction`` is not
        ``-1`` or ``1``.
    """
    options = {
        "rtol": 1e-7,
        "atol": 1e-9,
        "nintersect": 100,
        "method": "DOP853",
        "direction": 1,
        "m": 1,
        "nfp": 1,
    }
    options.update(kwargs)

    # Accept nested lists as well as arrays, and check the rank before
    # touching shape[1] so a 1D input fails with the intended message.
    RZstart = np.asarray(RZstart)
    assert RZstart.ndim == 2 and RZstart.shape[1] == 2, "RZstart must be a 2D array with shape (n, 2)"
    assert len(phis) > 0, "phis must be a list of floats with at least one element"
    # numpy integers (e.g. from len() of an array or np.int64) are valid counts too.
    assert (
        isinstance(options["nintersect"], (int, np.integer))
        and options["nintersect"] > 0
    ), "nintersect must be a positive integer"
    assert options["direction"] in [-1, 1], "direction must be either -1 or 1"

    # setup the phis of the poincare sections (np.unique already returns them sorted)
    phis = np.unique(np.mod(phis, 2 * np.pi / options["nfp"]))

    # setup the evaluation points for those sections:
    # row i holds every section angle shifted by i * (m * 2*pi / nfp)
    turns = np.arange(options["nintersect"] + 1)[:, np.newaxis]
    phi_evals = phis[np.newaxis, :] + options["m"] * 2 * np.pi * turns / options["nfp"]

    solver_kwargs = {
        "t_eval": phi_evals.flatten(),
        "method": options["method"],
        "atol": options["atol"],
        "rtol": options["rtol"],
    }
    # Forward a non-default direction to the right-hand side. Previously the
    # option was validated but silently ignored. The default call is left
    # untouched so right-hand sides taking only (t, y) keep working.
    if options["direction"] != 1:
        solver_kwargs["args"] = (options["direction"],)

    return solve_ivp(
        Bfield_2D,
        [0, phi_evals[-1, -1]],
        RZstart.flatten(),
        **solver_kwargs,
    )

def Bfield_2D(t, rzs, direction=1):
    """Right-hand side for field-line tracing with the angle as time.

    ``rzs`` is a flat vector of ``(R, Z)`` pairs. ``ps.B`` is evaluated for
    every pair at the angle ``direction * (t mod 2*pi)``, and the ratios of
    its first and third components to its second are returned, flattened in
    the same pair order as the input.

    NOTE(review): ``direction`` only flips the angle at which the field is
    sampled; the returned slopes are not negated. Confirm this is the
    intended convention for backward tracing.
    """
    print(f"{t:.2f}, {rzs}")
    points = rzs.reshape((-1, 2))
    angle = direction * (t % (2 * np.pi))
    phis = angle * np.ones(len(points))
    Bs = np.array([ps.B([r, phi, z]) for (r, z), phi in zip(points, phis)])

    # Points with a (near-)zero or negative first coordinate, or whose second
    # field component is below 1e-24, get a placeholder field (0, 1, 0) so
    # the division below yields a zero slope instead of blowing up.
    degenerate = (points[:, 0] < 1e-22) | (Bs[:, 1] < 1e-24)
    Bs[degenerate, :] = (0, 1, 0)

    slopes = Bs[:, [0, 2]] / Bs[:, [1]]
    return slopes.flatten()

In [None]:
# Start points: nfieldlines points at R = 3, evenly spaced from Z = 2 down to Z = -2.
nfieldlines = 20
Rs = np.linspace(3, 3, nfieldlines)
Zs = np.linspace(2, -2, nfieldlines)
# One (R, Z) row per field line.
RZs = np.column_stack((Rs, Zs))

In [None]:
# Trace every start point in RZs with integrate_ivp's default options,
# sampling a single section at phi = 0.
out = integrate_ivp(Bfield_2D, RZs, [0])

In [None]:
# Regroup the flat solver state into (field line, coordinate, evaluation point);
# this undoes the flattening of the (n, 2) start array done by integrate_ivp.
ys = out.y.reshape((nfieldlines, 2, -1))
for yy in ys:
    r_hits, z_hits = yy
    plt.scatter(r_hits, z_hits, s=10, marker=".")
# plt.scatter(ps._R0, -2, color="r", s=10)