# Multi-Harmonic Particle Tracking
This notebook is used to track a particle in a multi-harmonic RF system.
It implements the theory presented in the CERN Accelerator School (CAS) introductory lectures on longitudinal beam dynamics. It is intended for educational use: to experiment with the parameters and build an intuitive feel for beam dynamics.


Author: Anibal Luciano Pastinante (anibalpastinante@gmail.com)

Date: 22/08/2025

Acknowledgements: CAS team and CERN's SY-RF-BR section (builds upon the notebooks offered in the CAS course)

### Imports and Plotting Style


In [None]:
import numpy as np 
from ipywidgets import interactive 
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.integrate import quad
from fractions import Fraction
from matplotlib import animation
from functools import partial
from IPython.display import HTML


# Plotting style: start from the 'fivethirtyeight' theme, then override a few
# rcParams so that the figure, the axes and saved files all get a white
# background, a fixed figure size and consistent font sizes.
plt.style.use('fivethirtyeight')
plt.rcParams.update({
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    'savefig.facecolor': 'white',
    'axes.autolimit_mode': 'data',
    'figure.autolayout': True,
    'figure.figsize': (10, 8),
    'font.size': 12,
    'legend.fontsize': 12,
})


### Longitudinal Beam Dynamics Equations


In [None]:

def voltage_mh(V, r , h, phi, Phis):
    """
    Calculate the total RF voltage of the multi-harmonic system at phase phi.

        V_total(phi) = sum_i r[i] * V * sin(h[i] * phi + Phis[i])

    Parameters:
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic numbers of the n harmonics: list
    - phi: phase(s) at which to evaluate the voltage: float or np.array
    - Phis: relative phase of the n harmonics relative to the 1st harmonic: list
    Returns:
    - V_total: total voltage, with the same shape as phi (a scalar for scalar phi)
    Raises:
    - ValueError: if r, h and Phis do not all have the same length
    """
    # Explicit check rather than `assert`, which is removed when Python runs with -O.
    if not (len(r) == len(h) == len(Phis)):
        raise ValueError(
            f"r, h and Phis must have the same length "
            f"(got {len(r)}, {len(h)} and {len(Phis)})"
        )
    V_total = 0
    for r_i, h_i, Phi_i in zip(r, h, Phis):
        V_total += r_i * V * np.sin(h_i * phi + Phi_i)
    return V_total

def potential_mh(V, r , h, phi, dE_s, Phis, below_transition):
    """
    RF potential for the multi-harmonic RF system as a function of phase phi.

    The potential is the integral from 0 to phi of the net energy gain per turn,
        -dE_s + sum_i r[i] * V * sin(h[i] * x + Phis[i]),
    divided by 2*pi, with the overall sign flipped above transition.
    The integral is evaluated in closed form instead of with one numerical
    quadrature per sample point:
        int_0^phi sin(h x + P) dx = (cos(P) - cos(h phi + P)) / h    (h != 0)
        int_0^phi sin(P) dx       = phi * sin(P)                     (h == 0)

    Parameters:
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic number: list
    - phi: phase(s): float or np.array (integer arrays are converted to float)
    - dE_s: energy gain per turn of the synchronous particle: float
    - Phis: relative phase of the n harmonics relative to the 1st harmonic: list
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    Returns:
    - U: potential, float np.array with the same shape as phi
    Raises:
    - ValueError: if r, h and Phis do not all have the same length
    """
    if not (len(r) == len(h) == len(Phis)):
        raise ValueError(
            f"r, h and Phis must have the same length "
            f"(got {len(r)}, {len(h)} and {len(Phis)})"
        )

    # Force a float array: integrating into an integer array would truncate the result.
    phi = np.asarray(phi, dtype=float)

    # Contribution of the constant -dE_s term.
    U = -dE_s * phi
    for r_i, h_i, Phi_i in zip(r, h, Phis):
        amplitude = r_i * V
        if h_i == 0:
            # Degenerate harmonic: the integrand is the constant amplitude*sin(Phi_i).
            U = U + amplitude * np.sin(Phi_i) * phi
        else:
            U = U + amplitude * (np.cos(Phi_i) - np.cos(h_i * phi + Phi_i)) / h_i

    return U / (2 * np.pi) if below_transition else -U / (2 * np.pi)

def hamiltonian_mh(V, r , h, phi, dE, dE_s, Phis, below_transition, constant, E):
    """
    Evaluate the Hamiltonian of the multi-harmonic RF system at phase-space points (phi, dE).

        H(phi, dE) = U(phi) + 0.5 * constant * dE**2 / E

    where U is the RF potential returned by potential_mh.

    Parameters:
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic number: list
    - phi: phase: np.array (same size as dE)
    - dE: energy deviation from the synchronous particle: np.array (same size as phi)
    - dE_s: energy gain per turn of the synchronous particle: float
    - Phis: relative phase of the n harmonics relative to the 1st harmonic: list
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    - constant: h*eta/beta^2 : float
    - E: energy of the particle: float
    Returns:
    - H: Hamiltonian value for each point: np.array
    """
    potential_term = potential_mh(V, r, h, phi, dE_s, Phis, below_transition)
    kinetic_term = 0.5 * constant * dE**2 / E
    return potential_term + kinetic_term

def tracking_function(V, r , h, dE_s, Phis,  below_transition, constant,particle_state):
    """
    Advance one particle by one turn in the multi-harmonic RF system (kick, then drift).

        dE_{n+1}  = dE_n + V_rf(phi_n) - dE_s      (energy kick, V_rf from voltage_mh)
        E_{n+1}   = E_n + dE_s                     (energy after the turn)
        phi_{n+1} = phi_n -/+ 2*pi * constant * dE_{n+1} / E_{n+1}
                    (minus below transition, plus above)

    Parameters:
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic numbers of the n harmonics: list
    - dE_s: energy gain per turn of the synchronous particle: float
    - Phis: relative phase of the n harmonics relative to the 1st harmonic: list
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    - constant: h*eta/beta^2 : float
    - particle_state: current state of the particle [[phi_n, dE_n], E_n]: list
      (phase FIRST, energy deviation second -- the same layout as the returned state)
    Returns:
    - particle_state_new: new state of the particle [[phi_n+1, dE_n+1], E_n+1]: list
    """
    phi = particle_state[0][0]
    dE = particle_state[0][1]
    E_new = particle_state[1] + dE_s

    # Energy kick: voltage seen at the current phase minus the synchronous energy gain.
    dE_new = dE + voltage_mh(V, r, h, phi, Phis) - dE_s

    # Phase slip over one turn, evaluated with the updated dE and E:
    # 2*pi * (h*eta/beta^2) * dE / E. The slip direction is selected by below_transition
    # rather than by the sign of `constant`.
    # NOTE(review): presumably `constant` is passed as a positive magnitude -- confirm with callers.
    phase_slip = 2*np.pi* constant * dE_new / E_new
    phi_new = phi - phase_slip if below_transition else phi + phase_slip

    return [[phi_new, dE_new], E_new]

### Separatrix Computer and Plotter

In [None]:
def separatrices_mh_colored(
    V: float,
    r: list,
    h: list,
    phi_s: float,
    Phis: list,
    dE_s: float,
    E: float,
    constant: float,
    ax=None,
    n_phi: int = 2000,
    n_delta: int = 400,
    phi_margin_percent: float = 20.0,
    base_color: str | None = None,
    inner_alpha: float = 0.35,
    outer_fill_alpha: float = 0.20,
    inner_linewidth: float = 1.2,
    outer_linewidth: float = 2.0,
    epsilon: float = 1e-10,
    below_transition: bool = True,
    fill_potential: float = None,
    force_ylim: bool = False,
    ):
    """
    Draw separatrix contours H(phi, dE) = U(phi) + 0.5*constant*dE^2 / E using a single base color.
    - Outer separatrix: bold line + filled interior (outer_fill_alpha).
    - Inner separatrices: lighter alpha lines (inner_alpha).

    Separatrix levels come from the maxima of U(phi). For every pair of neighbouring maxima
    that encloses a low enough minimum, the lower of the two maxima is a separatrix level.

    Parameters:
    - V: RF voltage of the main harmonic (V1): float
    - r: voltage ratio value of the all harmonics (Vn/V1): list
    - h: harmonic number of the all harmonics: list
    - phi_s: synchronous phase(s) (computed automatically in dE_to_phi_s): float or list
    - Phis: phase shift for each harmonic: list
    - dE_s: energy gain per turn of the synchronous particle: float
    - E: energy of the synchronous particle: float
    - constant: h*eta/beta^2 : float
    - ax: axes to plot on (a new figure is created if None): matplotlib.axes.Axes
    - n_phi: number of points to sample the potential: int
    - n_delta: number of points to sample the energy difference grid: int
    - phi_margin_percent: margin to add to the potential for maximum detection at the edges: float
    - base_color: color of the separatrix (default "C0"): str
    - inner_alpha: alpha of the inner separatrices: float
    - outer_fill_alpha: alpha of the outer separatrix fill: float
    - inner_linewidth: linewidth of the inner separatrices: float
    - outer_linewidth: linewidth of the outer separatrix: float
    - epsilon: epsilon for numerical stability (currently unused): float
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    - fill_potential: fraction of the bucket depth to fill with "particles" (None = no fill): float
    - force_ylim: always reset the y-limits to +/-1.05*delta_max (necessary for animation when
      making a modification to the RF system); otherwise the y-limits are only enlarged: bool

    Returns:
    - created_fig: figure created when ax is None, otherwise None
    - ax: axes drawn on
    - H_levels: sorted separatrix levels of H (empty list if none were found)
    - phi: phase grid used for plotting: np.array
    - U: potential on phi (shifted so the highest maximum is at 0 when a maximum exists): np.array
    - grid: (PHI, DE, H_grid) used for the contours, or None if no separatrix was found
    """
    def _pi_fraction_label(value):
        """LaTeX label for `value` written as a rational multiple of pi (denominator <= 100)."""
        frac = Fraction(float(value) / np.pi).limit_denominator(100)
        if frac.numerator == 0:
            return r"0"
        sign = "-" if frac.numerator < 0 else ""
        magnitude = abs(frac.numerator)
        pi_term = r"\pi" if magnitude == 1 else f"{magnitude}" + r"\pi"
        if frac.denominator == 1:
            return sign + pi_term
        return sign + r"\frac{" + pi_term + "}{" + str(frac.denominator) + "}"

    margin = phi_margin_percent / 100.0  # Widen the window so maxima at the edges are captured
    if below_transition:
        bounds = [-margin * 2 * np.pi - np.pi, margin * 2 * np.pi + np.pi]
        # Sample 3 periods of the potential to avoid edge effects in the extrema detection
        phi_potential = np.linspace(-2 * np.pi - np.pi, 2 * np.pi + np.pi, n_phi * 2)
    else:
        bounds = [-margin * 2 * np.pi, margin * 2 * np.pi + 2 * np.pi]
        phi_potential = np.linspace(-2 * np.pi, 2 * np.pi + 2 * np.pi, n_phi * 2)
    phi = np.linspace(bounds[0], bounds[1], n_phi)

    U_overkill = potential_mh(V, r, h, phi_potential, dE_s, Phis, below_transition=below_transition)
    U = potential_mh(V, r, h, phi, dE_s, Phis, below_transition=below_transition)

    # Extrema detection on the 3-period potential: dU/dphi going + -> - marks a maximum,
    # - -> + marks a minimum.
    dU = np.gradient(U_overkill, phi_potential)
    dU_prev = np.roll(dU, 1)
    dU_next = np.roll(dU, -1)
    maxima_idx = np.unique(np.where((dU_prev > 0) & (dU_next < 0))[0])
    minima_idx = np.unique(np.where((dU_prev < 0) & (dU_next > 0))[0])

    # Keep only the extrema inside the phase window
    minima_idx = minima_idx[(phi_potential[minima_idx] > bounds[0]) & (phi_potential[minima_idx] < bounds[1])]
    maxima_idx = maxima_idx[(phi_potential[maxima_idx] > bounds[0]) & (phi_potential[maxima_idx] < bounds[1])]

    # Drop extrema closer than 0.1 rad to the next one (duplicates caused by numerical precision)
    minima_idx = np.delete(minima_idx, np.where(np.abs(np.diff(phi_potential[minima_idx])) < 1e-1)[0])
    maxima_idx = np.delete(maxima_idx, np.where(np.abs(np.diff(phi_potential[maxima_idx])) < 1e-1)[0])

    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(1, 1)

    # No maximum -> no separatrix. This must be checked before the normalisation below,
    # because np.max of an empty selection raises ValueError.
    if len(maxima_idx) == 0:
        return created_fig, ax, [], phi, U, None

    # Shift the potential so that the highest maximum in the window sits at U = 0
    U_top = np.max(U_overkill[maxima_idx])
    U -= U_top
    U_overkill -= U_top

    # Build H-levels per well
    H_levels = []
    m = len(maxima_idx)
    identified_H_levels_maxima_pairs = []
    for i in range(m):
        left = maxima_idx[i]
        right = maxima_idx[(i + 1) % m]
        if left < right:
            min_between = minima_idx[(minima_idx > left) & (minima_idx < right)]
        else:
            # Wrap-around pair (last maximum with the first one)
            min_between = minima_idx[(minima_idx < left) | (minima_idx > right)]
        if len(min_between) == 0:
            continue

        # A well counts only if one of its minima lies below both enclosing maxima, and
        # each already-identified pair also has a barrier above one of these minima.
        min_values = U_overkill[min_between]
        barrier = min(U_overkill[left], U_overkill[right])
        below_known_levels = np.all([
            np.any(min_values < min(U_overkill[i_left], U_overkill[i_right]))
            for i_left, i_right in identified_H_levels_maxima_pairs
        ])
        if not np.any((min_values < barrier) & below_known_levels):
            continue

        identified_H_levels_maxima_pairs.append((left, right))
        H_level = float(barrier)
        if not any(abs(H_level - L) <= 1e-9 for L in H_levels):  # Skip duplicate levels
            H_levels.append(H_level)
    H_levels = sorted(H_levels)

    if not H_levels:
        print("No H_levels found. Modify your parameters to have stable acceleration.")
        return created_fig, ax, H_levels, phi, U, None

    # Half-height of the largest bucket: 0.5*constant*dE^2/E = H_max - U_min
    delta_max = np.sqrt((max(H_levels) - np.min(U_overkill[minima_idx])) * E / (0.5 * constant))
    if not np.isfinite(delta_max) or delta_max == 0:
        delta_max = 1.0

    dE = np.linspace(-1.05 * delta_max, 1.05 * delta_max, n_delta)
    PHI, DE = np.meshgrid(phi, dE)
    H_grid = U[np.newaxis, :] + 0.5 * constant * (DE ** 2) / E

    if base_color is None:
        base_color = "C0"

    inner_levels = H_levels[:-1]
    outer_level = H_levels[-1]

    # Inner separatrices, cycling through linestyles so they can be told apart
    linestyles = ['--', '-.', ':', '-']
    for i, level in enumerate(inner_levels):
        ax.contour(
            PHI,
            DE,
            H_grid,
            levels=[level],
            colors=[base_color],
            linewidths=inner_linewidth,
            linestyles=linestyles[i % len(linestyles)],
            alpha=inner_alpha,
        )

    # Outer separatrix line
    ax.contour(
        PHI,
        DE,
        H_grid,
        levels=[outer_level],
        colors=[base_color],
        alpha=1,
        linestyles='solid',
        linewidths=outer_linewidth,
    )

    # Outer separatrix fill: restricted to phases left of the rightmost maximum (below
    # transition) or right of the leftmost maximum (above transition).
    if below_transition:
        mask = PHI < np.max(phi_potential[maxima_idx])
    else:
        mask = PHI > np.min(phi_potential[maxima_idx])
    masked_H_grid = np.ma.masked_where(~mask, H_grid)

    ax.contourf(
        PHI,
        DE,
        masked_H_grid,
        levels=[-np.inf, outer_level],
        colors=[base_color],
        alpha=outer_fill_alpha,
    )

    # Plot the synchronous phase(s)
    if not isinstance(phi_s, list):
        label = r"$\phi_s = " + _pi_fraction_label(phi_s) + r"$ rad"
        ax.scatter(phi_s, 0, color=base_color, label=label)
    else:
        label = "\n".join(
            r"$\phi_{s" + str(k + 1) + r"} = " + _pi_fraction_label(phi_s_k) + r"$ rad"
            for k, phi_s_k in enumerate(phi_s)
        )
        ax.scatter(phi_s, [0] * len(phi_s), color=base_color, label=label)

    # Fill the bucket up to a fraction of its depth, if requested
    if fill_potential is not None:
        # Depth reference: deepest minimum up to the lower of the two outermost maxima
        edge_maxima_idx = maxima_idx[[0, -1]] if len(maxima_idx) > 2 else maxima_idx
        U_bottom = np.min(U_overkill[minima_idx])
        U_rim = np.min(U_overkill[edge_maxima_idx])
        fill_level = U_bottom + fill_potential * (U_rim - U_bottom)
        ax.contourf(
            PHI,
            DE,
            H_grid,
            levels=[-np.inf, fill_level],
            colors=['purple'],
            alpha=0.6
        )
        ax.contour(
            PHI,
            DE,
            H_grid,
            levels=[fill_level],
            colors=['purple'],
            linewidths=outer_linewidth,
            label='Particles'
        )

    if below_transition:
        ax.set_xlim(-1.01*np.pi, 1.01*np.pi)
    else:
        ax.set_xlim(-0.01*np.pi, 2.01*np.pi)

    # Enlarge the y-range if it is too small; always reset it when force_ylim is set
    y_half_range = 1.05 * delta_max
    if force_ylim or ax.get_ylim()[1] < y_half_range:
        ax.set_ylim(-y_half_range, y_half_range)

    ax.set_xlabel(r"$\phi$ [rad]")
    ax.set_ylabel(r"$\Delta E$ [arb. units]")
    ax.set_title("Separatrices")
    ax.legend(loc='upper right')

    return created_fig, ax, H_levels, phi, U, (PHI, DE, H_grid)

def separatrices_mh_contours(
    V: float,
    r: list,
    h: list,
    phi_s: float,
    Phis: list,
    dE_s: float,
    E: float,
    constant: float,
    ax=None,
    n_phi: int = 2000,
    n_delta: int = 400,
    phi_margin_percent: float = 10.0,
    below_transition: bool = True,
    contour_margin: float = 1.5,
    no_fill: bool = False,
):
    """
    Draw Hamiltonian contours of H(phi, dE) = U(phi) + 0.5*constant*dE^2 / E with the separatrices on top.
    - Contours: 31 viridis levels from the bottom of the deepest well up to contour_margin
      times the well depth (filled with a colorbar, or as lines when no_fill is set).
    - Outer separatrix: blue dashed line.
    - Inner separatrices: red lines, cycling through the linestyles '--', '-.', ':', '-'.

    Parameters:
    - V: RF voltage of the main harmonic (V1): float
    - r: voltage ratio value of the all harmonics (Vn/V1): list
    - h: harmonic number of the all harmonics: list
    - phi_s: synchronous phase(s) (computed automatically in dE_to_phi_s), shown in the title: float or list
    - Phis: phase shift for each harmonic: list
    - dE_s: energy gain per turn of the synchronous particle: float
    - E: energy of the synchronous particle: float
    - constant: h*eta/beta^2 : float
    - ax: axes to plot on (a new figure is created if None): matplotlib.axes.Axes
    - n_phi: number of points to sample the potential: int
    - n_delta: number of points to sample the energy difference grid: int
    - phi_margin_percent: margin to add to the potential for maximum detection at the edges: float
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    - contour_margin: highest contour level as a multiple of the well depth; also sets the
      y-limits to +/-contour_margin*delta_max: float
    - no_fill: draw the contours as lines instead of filled regions: bool

    Returns:
    - created_fig: figure created when ax is None, otherwise None
    - ax: axes drawn on
    - H_levels: sorted separatrix levels of H (empty list if none were found)
    - phi: phase grid used for plotting: np.array
    - U: potential on phi (shifted so the highest maximum is at 0 when a maximum exists): np.array
    - grid: (PHI, DE, H_grid) used for the contours, or None if no separatrix was found
    """
    def _pi_fraction_label(value):
        """LaTeX label for `value` written as a rational multiple of pi (denominator <= 100)."""
        frac = Fraction(float(value) / np.pi).limit_denominator(100)
        if frac.numerator == 0:
            return r"0"
        sign = "-" if frac.numerator < 0 else ""
        magnitude = abs(frac.numerator)
        pi_term = r"\pi" if magnitude == 1 else f"{magnitude}" + r"\pi"
        if frac.denominator == 1:
            return sign + pi_term
        return sign + r"\frac{" + pi_term + "}{" + str(frac.denominator) + "}"

    margin = phi_margin_percent / 100.0  # Widen the window so maxima at the edges are captured
    if below_transition:
        bounds = [-margin * 2 * np.pi - np.pi, margin * 2 * np.pi + np.pi]
        # Sample 3 periods of the potential to avoid edge effects in the extrema detection
        phi_potential = np.linspace(-2 * np.pi - np.pi, 2 * np.pi + np.pi, n_phi * 2)
    else:
        bounds = [-margin * 2 * np.pi, margin * 2 * np.pi + 2 * np.pi]
        phi_potential = np.linspace(-2 * np.pi, 2 * np.pi + 2 * np.pi, n_phi * 2)
    phi = np.linspace(bounds[0], bounds[1], n_phi)

    U_overkill = potential_mh(V, r, h, phi_potential, dE_s, Phis, below_transition=below_transition)
    U = potential_mh(V, r, h, phi, dE_s, Phis, below_transition=below_transition)

    # Extrema detection on the 3-period potential: dU/dphi going + -> - marks a maximum,
    # - -> + marks a minimum.
    dU = np.gradient(U_overkill, phi_potential)
    dU_prev = np.roll(dU, 1)
    dU_next = np.roll(dU, -1)
    maxima_idx = np.unique(np.where((dU_prev > 0) & (dU_next < 0))[0])
    minima_idx = np.unique(np.where((dU_prev < 0) & (dU_next > 0))[0])

    # Keep only the extrema inside the phase window
    minima_idx = minima_idx[(phi_potential[minima_idx] > bounds[0]) & (phi_potential[minima_idx] < bounds[1])]
    maxima_idx = maxima_idx[(phi_potential[maxima_idx] > bounds[0]) & (phi_potential[maxima_idx] < bounds[1])]

    # Drop extrema closer than 0.1 rad to the next one (duplicates caused by numerical precision)
    minima_idx = np.delete(minima_idx, np.where(np.abs(np.diff(phi_potential[minima_idx])) < 1e-1)[0])
    maxima_idx = np.delete(maxima_idx, np.where(np.abs(np.diff(phi_potential[maxima_idx])) < 1e-1)[0])

    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(1, 1)

    # No maximum -> no separatrix. This must be checked before the normalisation below,
    # because np.max of an empty selection raises ValueError.
    if len(maxima_idx) == 0:
        return created_fig, ax, [], phi, U, None

    # Shift the potential so that the highest maximum in the window sits at U = 0
    U_top = np.max(U_overkill[maxima_idx])
    U -= U_top
    U_overkill -= U_top

    # Build H-levels per well
    H_levels = []
    m = len(maxima_idx)
    identified_H_levels_maxima_pairs = []
    for i in range(m):
        left = maxima_idx[i]
        right = maxima_idx[(i + 1) % m]
        if left < right:
            min_between = minima_idx[(minima_idx > left) & (minima_idx < right)]
        else:
            # Wrap-around pair (last maximum with the first one)
            min_between = minima_idx[(minima_idx < left) | (minima_idx > right)]
        if len(min_between) == 0:
            continue

        # A well counts only if one of its minima lies below both enclosing maxima, and
        # each already-identified pair also has a barrier above one of these minima.
        min_values = U_overkill[min_between]
        barrier = min(U_overkill[left], U_overkill[right])
        below_known_levels = np.all([
            np.any(min_values < min(U_overkill[i_left], U_overkill[i_right]))
            for i_left, i_right in identified_H_levels_maxima_pairs
        ])
        if not np.any((min_values < barrier) & below_known_levels):
            continue

        identified_H_levels_maxima_pairs.append((left, right))
        H_level = float(barrier)
        if not any(abs(H_level - L) <= 1e-9 for L in H_levels):  # Skip duplicate levels
            H_levels.append(H_level)
    H_levels = sorted(H_levels)

    if not H_levels:
        return created_fig, ax, H_levels, phi, U, None

    # Half-height of the largest bucket: 0.5*constant*dE^2/E = H_max - U_min
    delta_max = np.sqrt((max(H_levels) - np.min(U_overkill[minima_idx])) * E / (0.5 * constant))
    if not np.isfinite(delta_max) or delta_max == 0:
        delta_max = 1.0

    # Energy grid: at least +/-1.5*delta_max, widened to +/-contour_margin*delta_max so the
    # y-range set below is always covered by the contours.
    dE_extent = max(1.5, contour_margin) * delta_max
    dE = np.linspace(-dE_extent, dE_extent, n_delta)
    PHI, DE = np.meshgrid(phi, dE)
    H_grid = U[np.newaxis, :] + 0.5 * constant * (DE ** 2) / E

    inner_levels = H_levels[:-1]
    outer_level = H_levels[-1]

    # Contour levels span from the deepest minimum to contour_margin times the well depth,
    # where the depth is measured up to the lower of the two outermost maxima.
    edge_maxima_idx = maxima_idx[[0, -1]] if len(maxima_idx) > 2 else maxima_idx
    U_bottom = np.min(U_overkill[minima_idx])
    U_rim = np.min(U_overkill[edge_maxima_idx])
    contour_levels = list(np.linspace(U_bottom, U_bottom + contour_margin * (U_rim - U_bottom), 31))

    if not no_fill:
        contour_filled = ax.contourf(
            PHI,
            DE,
            H_grid,
            levels=contour_levels,
            cmap='viridis',
            alpha=0.6
        )
        plt.colorbar(contour_filled, ax=ax, label='Hamiltonian [arb. units]')
    else:
        ax.contour(
            PHI,
            DE,
            H_grid,
            levels=contour_levels,
            cmap='viridis',
            linewidths=0.5,
            alpha=0.6
        )

    # Inner separatrices, cycling through linestyles so they can be told apart
    linestyles = ['--', '-.', ':', '-']
    for i, level in enumerate(inner_levels):
        ax.contour(
            PHI,
            DE,
            H_grid,
            levels=[level],
            colors=['red'],
            linestyles=linestyles[i % len(linestyles)],
        )

    # Outer separatrix line
    ax.contour(
        PHI,
        DE,
        H_grid,
        levels=[outer_level],
        colors=['blue'],
        linestyles='--',
    )

    # Title label with the synchronous phase(s), two per line
    if not isinstance(phi_s, list):
        label = r"$\phi_s = " + _pi_fraction_label(phi_s) + r"$ rad"
    else:
        label = ""
        for k, phi_s_k in enumerate(phi_s):
            if k > 0:
                label += ", \n" if k % 2 == 0 else ", "
            label += r"$\phi_{s" + str(k + 1) + r"} = " + _pi_fraction_label(phi_s_k) + r"$ rad"

    if below_transition:
        ax.set_xlim(-1.01*np.pi, 1.01*np.pi)
    else:
        ax.set_xlim(-0.01*np.pi, 2.01*np.pi)

    ax.set_ylim(-contour_margin*delta_max, contour_margin*delta_max)
    ax.set_xlabel(r"$\phi$ [rad]")
    ax.set_ylabel(r"$\Delta E$ [arb. units]")
    ax.set_title("Hamiltonian Contours for \n" + label)
    ax.legend(loc='upper right')

    return created_fig, ax, H_levels, phi, U, (PHI, DE, H_grid)

def separatrix_plotter(PHI, DE, H_grid, H_levels, contours =False, ax = None, ylim = None, xlim = None):
    """
    Plot separatrices from a precomputed Hamiltonian grid
    (use when computational time is a concern: the grid does not have to be recomputed).

    Parameters:
    - PHI: phase grid (2D, e.g. from np.meshgrid): np.array
    - DE: energy difference grid (same shape as PHI): np.array
    - H_grid: Hamiltonian evaluated on (PHI, DE): np.array
    - H_levels: sorted separatrix levels; the last one is drawn as the outer separatrix: list
    - contours: also fill the regions between the H_levels with a colormap and add a colorbar
      (needs at least two levels; skipped otherwise): bool
    - ax: axes to plot on (a new figure is created if None): matplotlib.axes.Axes
    - ylim: optional y-limits: tuple
    - xlim: optional x-limits: tuple
    Returns:
    - ax: the axes drawn on
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    if len(H_levels) > 0:
        inner_levels = H_levels[:-1]
        outer_level = H_levels[-1]

        if len(inner_levels) > 0:
            ax.contour(PHI, DE, H_grid, levels=inner_levels, colors=['red'], linestyles='dashed', alpha=0.3)

        ax.contour(PHI, DE, H_grid, levels=[outer_level], colors=['blue'], linestyles='solid', alpha=1)

    # Filled contours need at least two levels to define a band.
    if contours and len(H_levels) >= 2:
        filled = ax.contourf(PHI, DE, H_grid, levels=H_levels, cmap='viridis')
        # Pass the ContourSet explicitly: without a mappable, plt.colorbar looks up the pyplot
        # "current image", which Axes.contourf does not set.
        plt.colorbar(filled, ax=ax, label='Hamiltonian [arb. units]')

    ax.set_xlabel(r"$\phi$ [rad]")
    ax.set_ylabel(r"$\Delta E$ [arb. units]")
    ax.set_title("Separatrices")
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    return ax
    


### RF Potential $U(\phi)$ Plotter and $\phi_s$ computer

In [None]:
# Helper to plot potentials for multiple r values into a provided axes

def plot_potentials_multi_phi_s(V, r, h, phi_s_list, Phis ,dE_list, ax=None, n_phi: int = 2000, fill_potential = None, phi_margin_percent = 20, below_transition = True, **plot_kwargs):
    """Plot the RF potential U(phi) for several energy gains per turn (one curve per entry of dE_list).

    Parameters:
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic numbers of the n harmonics: list
    - phi_s_list: synchronous phase(s) for each curve, used in the legend; an entry may be a
      float or a list of floats (several synchronous phases): list
    - Phis: relative phase of the n harmonics relative to the 1st harmonic (shared by all curves): list
    - dE_list: energy gain per turn of the synchronous particle, one per curve: list
    - ax: axes to plot the potential into (if None, a new figure is created)
    - n_phi: the potential is sampled with 2*n_phi points over [-3*pi, 3*pi]
    - fill_potential: fraction of the well depth to fill (for visualization); the fill is
      computed from the potential of the LAST entry of dE_list
    - phi_margin_percent: margin added to the bucket window when searching for extrema
    - below_transition: boolean to indicate if the particle is below the transition energy
    - **plot_kwargs: additional plotting arguments passed to ax.plot for every curve
    Returns:
    - created_fig: figure created when ax is None, otherwise None
    - ax: axes drawn on
    - phi: full phase grid on which the potentials were evaluated: np.array
    - colors: line color of each curve, in dE_list order: list
    """
    def _pi_fraction(value):
        """Rational approximation (denominator <= 100) of value / pi."""
        return Fraction(float(value) / np.pi).limit_denominator(100)

    def _pi_fraction_label(frac):
        """LaTeX label for a non-zero multiple `frac` of pi."""
        sign = "-" if frac.numerator < 0 else ""
        magnitude = abs(frac.numerator)
        pi_term = r"\pi" if magnitude == 1 else f"{magnitude}" + r"\pi"
        if frac.denominator == 1:
            return sign + pi_term
        return sign + r"\frac{" + pi_term + "}{" + str(frac.denominator) + "}"

    def _legend_label(phi_s):
        """Legend entry for one curve: a single phi_s, or one line per phi_s when given a list."""
        if not isinstance(phi_s, list):
            frac = _pi_fraction(phi_s)
            if frac.numerator == 0:
                return r"$\phi_s$ = 0 rad"
            return r"$\phi_s = " + _pi_fraction_label(frac) + r"$ rad"
        lines = []
        for k, phi_s_k in enumerate(phi_s):
            frac = _pi_fraction(phi_s_k)
            value = "0" if frac.numerator == 0 else _pi_fraction_label(frac)
            lines.append(r"$\phi_{s" + str(k + 1) + r"} = " + value + r"$ rad")
        return "\n".join(lines)

    margin = phi_margin_percent / 100.0
    # Sample 3 periods so that extrema near the edges of the plotted window are detected
    phi = np.linspace(-3*np.pi, 3*np.pi, n_phi*2)
    if below_transition:
        bounds = [-(1+margin)*np.pi, (1+margin)*np.pi]  # window searched for extrema
        bucket = (-np.pi, np.pi)                        # bucket edges (drawn and fillable)
        view = (-1.5*np.pi, 1.5*np.pi)                  # plotted window
    else:
        bounds = [-margin*np.pi, (2+margin)*np.pi]
        bucket = (0, 2*np.pi)
        view = (-0.5*np.pi, 2.5*np.pi)
    # Only plot the visible window so that the vertical axis is not zoomed out
    plot_mask = (phi >= view[0]) & (phi <= view[1])

    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(1, 1)

    colors = []
    U = None
    for idx, dE in enumerate(dE_list):
        U = potential_mh(V, r, h, phi, dE, Phis, below_transition)
        line, = ax.plot(phi[plot_mask], U[plot_mask], **plot_kwargs, label=_legend_label(phi_s_list[idx]))
        colors.append(line.get_color())

    if U is not None:
        # Bucket edges, drawn once rather than once per curve
        ax.axvline(bucket[0], color='red', linestyle='--')
        ax.axvline(bucket[1], color='red', linestyle='--')

    if fill_potential is not None and U is not None:
        # Extrema of the last curve's potential: dU/dphi going + -> - is a maximum, - -> + a minimum
        dU = np.gradient(U, phi)
        dU_prev = np.roll(dU, 1)
        dU_next = np.roll(dU, -1)
        maxima_idx = np.unique(np.where((dU_prev > 0) & (dU_next < 0))[0])
        minima_idx = np.unique(np.where((dU_prev < 0) & (dU_next > 0))[0])

        # Remove the minima and maxima that are outside the bounds
        minima_idx = minima_idx[(phi[minima_idx] > bounds[0]) & (phi[minima_idx] < bounds[1])]
        maxima_idx = maxima_idx[(phi[maxima_idx] > bounds[0]) & (phi[maxima_idx] < bounds[1])]

        if len(maxima_idx) == 0 or len(minima_idx) == 0:
            print("No potential well found: fill_potential is ignored.")
        else:
            # Only keep the two outermost maxima
            if len(maxima_idx) > 2:
                maxima_idx = maxima_idx[[0, -1]]

            # Well depth: deepest minimum up to the lower of the outermost maxima
            U_bottom = np.min(U[minima_idx])
            U_rim = np.min(U[maxima_idx])
            limit_energy = (U_rim - U_bottom)*fill_potential + U_bottom

            fill_mask = (phi >= bucket[0]) & (phi <= bucket[1])
            phi_fill = phi[fill_mask]
            U_fill = U[fill_mask]
            # `where=` keeps disjoint wells separate; boolean-indexing x instead would join
            # them with a spurious sliver across the barrier between them.
            ax.fill_between(
                phi_fill, U_fill, limit_energy,
                where=U_fill < limit_energy, interpolate=True,
                color='purple', alpha=0.6,
                label = r'Particles ($FF=' + str(round(fill_potential*100, 1)) + r"\%$)",
            )

    ax.set_xlim(*view)

    ax.set_xlabel(r"$\phi$ [rad]")
    ax.set_ylabel(r"Potential $U(\phi)$ [arb. units]")
    if len(r) == 1:
        ax.set_title(r"$U(\phi)$")
    else:
        # List the ratios of the higher harmonics, three per title line
        r_label = r"$U(\phi)$ for "
        for k, r_k in enumerate(r[1:]):
            if k > 0:
                r_label += ", \n" if k % 3 == 0 else ", "
            r_label += r"$r_{" + str(k+2) + r"} = " + str(r_k) + r"$"
        ax.set_title(r_label)

    ax.legend()

    return created_fig, ax, phi, colors


def dE_to_phi_s(dE_s, V, r, h, Phis, below_transition, margin = 0.2, n_phi = 1000):
    """ Find the phi_s that corresponds to the energy gain per turn of the synchronous particle.

    The potential returned by ``potential_mh`` is sampled on a grid spanning three
    RF periods. Its local maxima inside a search window (one reference period
    widened by ``margin * 2*pi`` on each side) are the candidate synchronous phases;
    the first and last maxima are discarded because they sit at the window edges.

    Parameters:
    - dE_s: energy gain per turn of the synchronous particle: float
    - V: voltage of the 1st harmonic: float
    - r: ratios of the n harmonics to the 1st harmonic: list
    - h: harmonic numbers of the n harmonics: list
    - Phis: relative phase of the n harmonics relative to the 1st harmonic: list
    - below_transition: boolean to indicate if the particle is below the transition energy: bool
    - margin: fraction of 2*pi added on each side of the reference period when collecting extrema
    - n_phi: half the number of grid points (the potential is evaluated at 2*n_phi points)
    Returns:
    - a single float when exactly one candidate maximum survives;
    - a sorted list of floats when several candidates were found, with duplicates and
      phases within 5e-2 rad of the period edges removed (the list may end up empty
      or with a single element);
    - the phase of the lowest potential minimum when no candidate maximum exists.
    Phases are wrapped to [-pi, pi) below transition and to [0, 2*pi) above.
    Raises:
    - ValueError: if neither a candidate maximum nor any minimum is found.
    """
    # Reference period and evaluation grid (three periods, so that the np.roll
    # wrap-around at the grid ends never falls inside the search window).
    if below_transition:
        bounds = [-margin * 2 * np.pi - np.pi, margin * 2 * np.pi + np.pi]
        phi = np.linspace(-2 * np.pi - np.pi, 2 * np.pi + np.pi, n_phi * 2)
        period_low, period_high = -np.pi, np.pi
    else:
        bounds = [-margin * 2 * np.pi, margin * 2 * np.pi + 2 * np.pi]
        phi = np.linspace(-2 * np.pi, 2 * np.pi + 2 * np.pi, n_phi * 2)
        period_low, period_high = 0.0, 2 * np.pi

    def wrap(angle):
        """Wrap a phase into the reference period."""
        if below_transition:
            return (angle + np.pi) % (2 * np.pi) - np.pi
        return angle % (2 * np.pi)

    def in_window_and_separated(idx):
        """Keep extrema inside `bounds`, dropping near-duplicates (numerical noise)."""
        idx = idx[(phi[idx] > bounds[0]) & (phi[idx] < bounds[1])]
        too_close = np.where(np.diff(phi[idx]) < 1e-1)[0]
        return np.delete(idx, too_close)

    # The potential only needs to be evaluated once; the gradient and the
    # ordering of minima are both unaffected by a constant offset.
    U = potential_mh(V, r, h, phi, dE_s, Phis, below_transition=below_transition)
    dU = np.gradient(U, phi)

    # Sign change of the derivative marks an extremum (np.where output is already sorted/unique)
    dU_prev = np.roll(dU, 1)
    dU_next = np.roll(dU, -1)
    maxima_idx = in_window_and_separated(np.where((dU_prev > 0) & (dU_next < 0))[0])
    minima_idx = in_window_and_separated(np.where((dU_prev < 0) & (dU_next > 0))[0])

    # Drop the first and last maxima (the window edges). With fewer than three
    # maxima this slice is simply empty.
    candidates = phi[maxima_idx[1:-1]]

    if len(candidates) == 0:
        # No interior maximum: fall back to the lowest minimum of the potential
        if len(minima_idx) == 0:
            raise ValueError("No extrema of the potential found in the search window; phi_s cannot be determined.")
        phi_min_energy = phi[minima_idx[np.argmin(U[minima_idx])]]
        return wrap(phi_min_energy)

    if len(candidates) == 1:
        return wrap(candidates[0])

    # Several candidates: wrap, de-duplicate, sort, and discard those at the period edges
    solutions = sorted(np.unique([wrap(phi_s) for phi_s in candidates]))
    edge = 5e-2
    return [phi_s for phi_s in solutions if period_low + edge < phi_s < period_high - edge]



### Particle Distribution Generator

In [None]:
def generateBunch(bunch_position, bunch_length,
                  bunch_energy, energy_spread,
                  n_macroparticles, type='gaussian', n_grid=100):
    """
    Generate macroparticle coordinates (phase, energy) for a bunch in longitudinal phase space.

    Particles are sampled on an ``n_grid`` x ``n_grid`` mesh covering
    [bunch_position +/- bunch_length/2] x [bunch_energy +/- energy_spread/2], with
    cell probabilities given by the selected density, and are then spread
    uniformly inside their cell.

    Parameters:
    bunch_position: float, position of the bunch in phase
    bunch_length: float, full length of the bunch in phase
    bunch_energy: float, energy of the bunch centre
    energy_spread: float, full spread of the energy of the bunch
    n_macroparticles: int, number of macroparticles to generate
    type: str, type of distribution to generate:
        'gaussian'    -> density 1 - R**4, with R the elliptical radius normalised to the
                         half-widths, truncated at R = 1 (a smooth compact profile, not a
                         true Gaussian despite the name)
        'rectangular' -> uniform density over the whole rectangle
    n_grid: int, number of grid points per axis (default 100)

    Returns:
    particle_phase, particle_energy: 1-D numpy arrays of length n_macroparticles

    Raises:
    ValueError: for an unknown distribution type, n_grid < 2, or a zero
        bunch_length / energy_spread with the 'gaussian' distribution
    """
    if type not in ('gaussian', 'rectangular'):
        raise ValueError("Invalid distribution type")
    if n_grid < 2:
        raise ValueError("n_grid must be at least 2")

    # Generating phase and energy arrays
    phase_array = np.linspace(bunch_position - bunch_length / 2,
                              bunch_position + bunch_length / 2,
                              n_grid)
    energy_array = np.linspace(bunch_energy - energy_spread / 2,
                               bunch_energy + energy_spread / 2,
                               n_grid)

    # Coordinates of every grid node
    phase_grid, deltaE_grid = np.meshgrid(phase_array, energy_array)

    # Bin sizes (used to jitter particles inside their cell)
    bin_phase = phase_array[1] - phase_array[0]
    bin_energy = energy_array[1] - energy_array[0]

    # Density grid
    if type == 'gaussian':
        if bunch_length == 0 or energy_spread == 0:
            # Normalising by a zero width would produce NaN probabilities
            raise ValueError("bunch_length and energy_spread must be non-zero for the 'gaussian' distribution")
        isodensity_lines = ((phase_grid - bunch_position) / bunch_length * 2) ** 2. + \
            ((deltaE_grid - bunch_energy) / energy_spread * 2) ** 2.
        density_grid = 1 - isodensity_lines ** 2.
    else:
        density_grid = np.ones((n_grid, n_grid))

    # Truncate outside the ellipse and normalise into a probability mass function
    density_grid[density_grid < 0] = 0
    density_grid /= np.sum(density_grid)

    # Pick grid cells according to the density
    indexes = np.random.choice(np.arange(0, np.size(density_grid)),
                               n_macroparticles, p=density_grid.flatten())

    # Randomize particles inside each grid cell (uniform distribution)
    particle_phase = np.ascontiguousarray(phase_grid.flatten()[indexes] +
        (np.random.rand(n_macroparticles) - 0.5) * bin_phase)
    particle_energy = np.ascontiguousarray(deltaE_grid.flatten()[indexes] +
        (np.random.rand(n_macroparticles) - 0.5) * bin_energy)

    return particle_phase, particle_energy


### Particle Tracking Class and Animator

In [None]:

class ParticleBeam:
    """
    Macroparticle beam tracked turn by turn in a multi-harmonic RF system.

    The full history is kept in ``self.state``, an array of shape
    ``(turns + 1, 2, n_particles)`` holding ``[phase, energy]`` for every turn.
    """

    def __init__(self,
                 particles,
                 V,
                 r,
                 h,
                 dE_s,
                 Phis,
                 below_transition,
                 constant,
                 E,
                 use_contour_bg = False,
                 efficiency_mode = False):
        """
        Initialize the multi-harmonic particle beam

        Parameters:
        -----------
        particles: (2,n_points) array of [phase, energy]
        V: voltage of the 1st harmonic: float
        r: ratios of the n harmonics to the 1st harmonic: list
        h: harmonic numbers of the n harmonics: list
        dE_s: energy gain per turn of the synchronous particle: float
        Phis: relative phase of the n harmonics relative to the 1st harmonic: list
        below_transition: boolean to indicate if the particle is below the transition energy: bool
        constant: |h*eta/beta^2| : float
        E: energy of the particle: float
        use_contour_bg: boolean to indicate if the separatrix should be plotted with a contour background: bool
        efficiency_mode: boolean to indicate if the separatrix data is stored and plotted in a more efficient way: bool

        Raises:
        -------
        ValueError: if the synchronous phase cannot be computed for these RF settings
        """
        self.V = V
        self.r = r
        self.h = h
        self.dE_s = dE_s
        self.Phis = Phis
        self.below_transition = below_transition
        self.constant = constant
        self.E = E
        self.state = np.expand_dims(particles, axis=0)  # (1,2,n_particles) array of [[phase], [energy]]
        self.turn = 0
        self.n_particles = np.size(particles, 1)  # number of particles
        self.use_contour_bg = use_contour_bg
        self.efficiency_mode = efficiency_mode

        # Compute phi_s; keep the user-facing hint but chain the real cause
        try:
            self.phi_s = dE_to_phi_s(dE_s, V, r, h, Phis, below_transition)
        except Exception as err:
            raise ValueError("phi_s is not provided and could not be computed. Acceleration with current voltage is not possible. Try to modify the RF system. \n Tip: Increase V") from err

        self.limits = [-np.pi, np.pi] if self.below_transition else [0, 2*np.pi]

        # Separatrix figure used as background of every plot (not filled, so
        # it does not hide the histogram); its y-range drives all plot limits.
        separatrices_plot = self._compute_separatrices()
        self.y_lims = separatrices_plot.get_ylim()

    # ------------------------------------------------------------------ helpers

    def _compute_separatrices(self):
        """
        (Re)compute the stand-alone separatrix figure for the current RF settings.

        Stores ``sep_fig``, ``phi_potential`` and ``potential`` on the instance
        (plus ``H_levels`` and ``H_data`` in efficiency mode) and returns the axes
        the separatrices were drawn on.
        """
        if self.use_contour_bg:
            result = separatrices_mh_contours(
                self.V, self.r, self.h, self.phi_s, self.Phis, self.dE_s, self.E, self.constant,
                ax=None, below_transition=self.below_transition, contour_margin=1.1, no_fill=True
            )
        else:
            result = separatrices_mh_colored(
                self.V, self.r, self.h, self.phi_s, self.Phis, self.dE_s, self.E, self.constant,
                ax=None, fill_potential=None, below_transition=self.below_transition, outer_fill_alpha=0.0
            )
        sep_fig, separatrices_plot, H_levels, phi_potential, potential, H_data = result
        self.sep_fig = sep_fig
        self.phi_potential = phi_potential
        self.potential = potential
        if self.efficiency_mode:
            # Cached Hamiltonian grid so later plots can skip the recomputation
            self.H_levels = H_levels
            self.H_data = H_data
        return separatrices_plot

    def _draw_separatrices(self, ax):
        """
        Draw the separatrices on ``ax`` and return the axes to keep drawing on.

        In efficiency mode the cached Hamiltonian grid is re-plotted with
        ``separatrix_plotter`` (whose return value is passed back); otherwise the
        separatrices are recomputed directly on ``ax``.
        """
        if self.efficiency_mode:
            return separatrix_plotter(self.H_data[0], self.H_data[1], self.H_data[2], self.H_levels,
                                      contours=self.use_contour_bg, ax=ax, ylim=self.y_lims, xlim=self.limits)
        if self.use_contour_bg:
            separatrices_mh_contours(
                self.V, self.r, self.h, self.phi_s, self.Phis, self.dE_s, self.E, self.constant,
                ax=ax, below_transition=self.below_transition, contour_margin=1.1, no_fill=True
            )
        else:
            separatrices_mh_colored(
                self.V, self.r, self.h, self.phi_s, self.Phis, self.dE_s, self.E, self.constant,
                ax=ax, fill_potential=None, below_transition=self.below_transition, outer_fill_alpha=0.0)
        return ax

    def _phase_space_figure(self, state, xbins, ybins, separatrix):
        """Create a phase-space figure (2D + 1D histograms) for ``state``, with optional separatrices."""
        fig, axScatter, axHistx, axHisty = plotPhaseSpace(state, xbins, ybins, xlim=self.limits, ylim=self.y_lims)
        if separatrix:
            self._draw_separatrices(axScatter)
        axScatter.set_title('')
        return fig, axScatter, axHistx, axHisty

    def _draw_voltage_and_potential(self, ax_voltage, ax_potential):
        """Draw V(phi) and U(phi) on the given axes, with dashed lines marking the h_1 period."""
        voltage = voltage_mh(self.V, self.r, self.h, self.phi_potential, self.Phis)
        panels = (
            (ax_voltage, voltage, r"Voltage $V(\phi)$", r"$V(\phi)$ [arb. units]"),
            (ax_potential, self.potential, r"Potential $U(\phi)$", r"$U(\phi)$ [arb. units]"),
        )
        for panel_ax, values, title, ylabel in panels:
            panel_ax.plot(self.phi_potential, values)
            panel_ax.axvline(self.limits[0], color='r', linestyle='--', label=r"$h_1$ period")
            panel_ax.axvline(self.limits[1], color='r', linestyle='--')
            panel_ax.set_title(title)
            panel_ax.set_xlabel(r"$\phi$ [rad]")
            panel_ax.set_ylabel(ylabel)
            panel_ax.legend()

    # ----------------------------------------------------------------- tracking

    def plus_one_turn(self):
        """
        Track the particle beam for one turn using multi-harmonic tracking.

        Appends the new [phase, energy] state to ``self.state`` and updates ``self.E``.
        """
        new_state = tracking_function(self.V, self.r, self.h, self.dE_s,
                                      self.Phis, self.below_transition,
                                      self.constant, [self.state[-1], self.E])
        # NOTE: vstack copies the whole history each turn (quadratic in turns)
        self.state = np.vstack([self.state, np.expand_dims(new_state[0], axis=0)])
        self.E = new_state[1]
        self.turn += 1

    def advance_x_turns(self, x_turns):
        """
        Advance the particle beam x_turns turns
        """
        for _ in range(x_turns):
            self.plus_one_turn()

    # ----------------------------------------------------------------- plotting

    def plot_state(self, xbins=100, ybins=100, ax=None, separatrix = True):
        """
        Plot the current state as a 2D histogram (empty bins left transparent).

        When ``ax`` is None a new figure is created, titled, labelled and shown;
        otherwise the histogram is only drawn on the given axes.
        """
        own_figure = ax is None
        if own_figure:
            plt.close()
            fig, ax = plt.subplots()

        if separatrix:
            self._draw_separatrices(ax)

        # Histogram over the RF period and the separatrix energy range
        xedges = np.linspace(self.limits[0], self.limits[1], xbins+1, endpoint=True)
        yedges = np.linspace(self.y_lims[0], self.y_lims[1], ybins+1, endpoint=True)
        hist, xedges, yedges = np.histogram2d(self.state[-1, 0, :], self.state[-1, 1, :], (xedges, yedges))

        # Mask empty bins so they appear transparent
        masked_hist = np.ma.masked_where(hist == 0, hist)
        hist2d = ax.pcolormesh(xedges, yedges, masked_hist.T, cmap='jet')
        # Guard against vmin > vmax when no particle falls inside the window
        hist2d.set_clim(0.0001, max(np.max(hist), 1))

        if own_figure:
            fig.suptitle('Particle beam trajectory at turn {}'.format(self.turn))
            ax.set_xlabel(r'$\phi$ [rad]')
            ax.set_ylabel(r'$\Delta E$ [arb. units]')
            plt.show(block=True)

    def advance_and_plot_phase_space(self, x_turns, xbins=100, ybins=100, separatrix = True, override_limits = None):
        """
        Plot initial and final distribution in phase space, but keeping the same limits so that 
        changes in bunch distribution are not lost.

        Returns the (y-limits of the x-profile, x-limits of the y-profile) of the
        final plot, which can be passed back as ``override_limits``.
        """
        plt.close()

        self.advance_x_turns(x_turns)

        # Initial state
        _, axScatter, axHistx, axHisty = self._phase_space_figure(self.state[0], xbins, ybins, separatrix)
        if override_limits is not None:
            axHistx.set_ylim(override_limits[0])
            axHisty.set_xlim(override_limits[1])
            ylim_x = override_limits[0]
            xlim_y = override_limits[1]
        else:
            # Remember the 1D-profile limits so changes are visible in the final plot
            ylim_x = axHistx.get_ylim()
            xlim_y = axHisty.get_xlim()

        axHistx.set_xlim(axScatter.get_xlim())
        axHisty.set_ylim(axScatter.get_ylim())
        plt.show()

        # Final state, with the same 1D-profile limits as the initial one
        _, axScatter2, axHistx2, axHisty2 = self._phase_space_figure(self.state[-1], xbins, ybins, separatrix)
        axHistx2.set_ylim(ylim_x)
        axHisty2.set_xlim(xlim_y)
        axHistx2.set_xlim(axScatter2.get_xlim())
        axHisty2.set_ylim(axScatter2.get_ylim())
        plt.show()
        return (axHistx2.get_ylim(), axHisty2.get_xlim())

    def plot_phase_space(self, xbins=100, ybins=100, separatrix = True):
        """
        Plot the phase space with 2D and 1D histograms
        """
        plt.close()
        _, axScatter, axHistx, axHisty = self._phase_space_figure(self.state[-1], xbins, ybins, separatrix)
        axHistx.set_xlim(axScatter.get_xlim())
        axHisty.set_ylim(axScatter.get_ylim())
        plt.show()

    def plot_single_trajectory(self, i, single_plotting=True, ax=None):
        """
        Plot the trajectory of particle ``i`` (start marked 'x', end marked 'o').

        With ``single_plotting`` the plot is limited to one RF period, labelled and shown.
        """
        if ax is None:
            plt.close()
            _, ax = plt.subplots()

        x = self.state[:, 0, i]
        y = self.state[:, 1, i]
        ax.plot(x, y)
        ax.scatter(x[0], y[0], marker='x', facecolor=None)
        ax.scatter(x[-1], y[-1], marker='o', facecolor=None)

        if single_plotting:
            ax.set_xlim(self.limits[0], self.limits[1])
            ax.set_xlabel(r"$\phi$ [rad]")
            ax.set_ylabel(r"$\Delta E$ [arb. units]")
            ax.set_title(r"Particle {} trajectory for {} turns".format(i, self.turn))
            plt.show()

    def plot_trajectory(self, turns, ax=None, separatrix=True):
        """
        Advance ``turns`` turns and plot the trajectory of every particle.

        Parameters:
        -----------
        turns: number of turns to plot
        ax: optional axes to draw on (a new figure is created when None)
        separatrix: whether to draw the separatrices in the background
        """
        self.advance_x_turns(turns)

        if ax is None:
            plt.close()
            _, ax = plt.subplots()

        if separatrix:
            ax = self._draw_separatrices(ax)

        for i in range(self.n_particles):
            self.plot_single_trajectory(i, single_plotting=False, ax=ax)

        # Use the axes' own figure so this also works when `ax` was passed in
        ax.figure.suptitle('Particle beam trajectory after {} turns'.format(self.turn))
        ax.set_xlabel(r'$\phi$ [rad]')
        ax.set_ylabel(r'$\Delta E$ [arb. units]')
        plt.tight_layout()
        plt.show()

    def compute_spectrum(self, turns, indices= None):
        """
        Compute oscillation spectrum for selected particles.

        By default up to three representative particles are used (indices 1,
        n//2 and n-1); indices outside the beam and duplicates are dropped so
        small beams do not raise IndexError.
        """
        if indices is None:
            candidates = (1, self.n_particles // 2, self.n_particles - 1)
            indices = list(dict.fromkeys(i for i in candidates if 0 <= i < self.n_particles))

        self.advance_x_turns(turns)
        plt.close()

        for index in indices:
            plt.plot(*oscillation_spectrum(self.state[:, 0, index], fft_zero_padding = 0), color=plt.cm.viridis(index/self.n_particles))
        plt.xlim(0, 0.01)
        plt.xlabel("Synchrotron tune")
        plt.ylabel("Amplitude [norm.]")
        plt.show()

    def plot_separatrices(self):
        """
        Plot the separatrices
        """
        plt.close()
        _, ax = plt.subplots()
        self._draw_separatrices(ax)
        plt.tight_layout()
        plt.show()

    def plot_potential(self):
        """
        Plot the potential
        """
        plt.close()
        plt.figure()
        plt.title(r"Potential $U(\phi)$")
        plt.xlabel(r"$\phi$ [rad]")
        plt.ylabel(r"$U(\phi)$ [arb. units]")
        plt.plot(self.phi_potential, self.potential)
        plt.show()

    def plot_voltage_and_potential(self):
        """
        Plot the voltage and potential side by side
        """
        plt.close()
        _, ax = plt.subplots(1, 2)
        self._draw_voltage_and_potential(ax[0], ax[1])
        plt.tight_layout()
        plt.show()

    def plot_all(self):
        """
        Plot the voltage, potential, and separatrices (with the current beam
        histogram) in one figure. The separatrices are drawn once, honouring
        ``use_contour_bg``.
        """
        plt.close()
        _, ax = plt.subplots(1, 3)
        self._draw_voltage_and_potential(ax[0], ax[1])

        sep_ax = self._draw_separatrices(ax[2])
        # Separatrix already drawn above: only add the histogram
        self.plot_state(ax=sep_ax, separatrix=False)

        plt.tight_layout()
        plt.show()

    def modify_RF_system(self, V, r, h, Phis, dE_s):
        """
        Modify the RF system, recompute phi_s and recalculate the separatrices
        (and the y-limits derived from them).
        """
        self.V = V
        self.r = r
        self.h = h
        self.Phis = Phis
        self.dE_s = dE_s
        self.phi_s = dE_to_phi_s(self.dE_s, self.V, self.r, self.h, self.Phis, self.below_transition)

        separatrices_plot = self._compute_separatrices()
        self.y_lims = separatrices_plot.get_ylim()

class MultiHarmonicTrackAnimation(object):
    def __init__(
        self, particle_beam: ParticleBeam, figname: str, iterations: int, framerate: int,
        xbins: int = 50, ybins: int = 50, xlim: list[float] = None, ylim: list[float] = None,
        name: str = None, fill_potential: list[float] = None,
        modifications: list[list[float]] = None, modification_turns: list[int] = None):
        """
        Set up an animation of a ParticleBeam tracked in a multi-harmonic RF system.

        Parameters:
        -----------
        particle_beam: beam to track and draw
        figname: figure name forwarded to plotPhaseSpace
        iterations: number of animation frames (one tracked turn per frame)
        framerate: frames per second of the animation
        xbins, ybins: number of histogram bins along phase / energy
        xlim: phase limits of the histogram; defaults to ``particle_beam.limits``
        ylim: energy limits of the histogram; defaults to ``particle_beam.y_lims``
        name: output file stem used when saving (``'multi_harmonic_animation'`` if None)
        fill_potential: fill potential forwarded to the separatrix plot (default None)
        modifications: RF changes, each as [V, r, h, Phis, dE_s] (default None)
        modification_turns: turns at which the matching modifications are applied (default None)
        """
        self.particle_beam = particle_beam
        self.figname = figname
        self.iterations = iterations
        self.framerate = framerate
        self.xbins, self.ybins = xbins, ybins

        # Fall back to the beam's RF period and separatrix energy range
        self.xlim = particle_beam.limits if xlim is None else xlim
        self.ylim = particle_beam.y_lims if ylim is None else ylim

        self.name = name
        self.fill_potential = fill_potential
        self.modifications = modifications
        self.modification_turns = modification_turns

        # Blitting needs a static background; scheduled RF changes redraw the separatrix
        has_schedule = modifications is not None and modification_turns is not None
        self.blitting = not has_schedule

    def run_animation(self):
        """
        Render all frames with FuncAnimation and save them as a GIF (pillow writer).

        Returns the output path ``<name>.gif``; when ``name`` is None it is set to
        ``'multi_harmonic_animation'`` on the instance.
        """
        # Build the first frame (and self.anim_fig) without axis labels
        self._init(remove_labels=True)

        frame_interval_ms = 1000 / self.framerate
        anim = animation.FuncAnimation(
            self.anim_fig,
            self._animate,
            init_func=self._init,
            frames=self.iterations,
            interval=frame_interval_ms,
            blit=self.blitting,
        )

        if self.name is None:
            self.name = "multi_harmonic_animation"
        output_file = self.name + ".gif"

        anim.save(output_file, writer='pillow')
        print(f'Animation saved to {output_file}')
        return output_file

    def _init(self, remove_labels=False):
        """
        Draw the first animation frame: phase-space histograms plus separatrices.

        Also serves as FuncAnimation's ``init_func``. Stores the current figure in
        ``self.anim_fig`` and returns the artists used for blitting.

        NOTE(review): ``axScatter``, ``axHistx``, ``axHisty``, ``line_phase``,
        ``line_energy`` and ``distri_plot`` are not assigned in this method; they
        are presumably module-level globals set by ``plotPhaseSpace`` — confirm.
        """
        beam = self.particle_beam
        plotPhaseSpace(beam.state[-1], figname=self.figname,
                       xbins=self.xbins, ybins=self.ybins,
                       xlim=self.xlim, ylim=self.ylim)

        # Shared RF arguments for either separatrix renderer
        rf_settings = (beam.V, beam.r, beam.h, beam.phi_s, beam.Phis,
                       beam.dE_s, beam.E, beam.constant)
        if beam.use_contour_bg:
            (_fig, _ax, _levels, _phi, _potential, _data) = separatrices_mh_contours(
                *rf_settings,
                ax=axScatter,
                below_transition=beam.below_transition,
                contour_margin=1.1,
                no_fill=True,
            )
        else:
            (_fig, _ax, _levels, _phi, _potential, _data) = separatrices_mh_colored(
                *rf_settings,
                ax=axScatter,
                fill_potential=self.fill_potential,
                below_transition=beam.below_transition,
                outer_fill_alpha=0.0,
                force_ylim=True,
            )

        axScatter.set_title("")
        if remove_labels:
            # Bare axes for the GIF
            for clear_text in (axScatter.set_ylabel, axScatter.set_xlabel):
                clear_text('')
            for clear_ticks in (axScatter.set_xticklabels, axScatter.set_yticklabels):
                clear_ticks([])

        # Align the marginal histograms with the central scatter axes
        axHistx.set_xlim(axScatter.get_xlim())
        axHisty.set_ylim(axScatter.get_ylim())

        self.anim_fig = plt.gcf()
        return (line_phase, line_energy, distri_plot)

    def _animate(self, i):
        """
        Advance the beam by one turn and update the animation artists for frame ``i``.

        If ``i + 1`` is listed in ``self.modification_turns``, the matching entry
        of ``self.modifications`` is passed to ``modify_RF_system``. The
        phase-space axis is then redrawn, with new separatrices and a new 2D
        histogram artist. Otherwise the existing 2D histogram artist is updated
        in place. The bunch-profile and energy projections are refreshed on
        every frame.

        Parameters:
        - i: frame index passed by FuncAnimation: int

        Returns:
        - (line_phase, line_energy, distri_plot): the artists of this frame.
          They are returned on every path, because FuncAnimation raises when a
          blitting animation function does not return a sequence of artists.
        """
        global distri_plot

        def profile(values, bins, value_range):
            # Histogram ``values`` and return (bin centres, counts scaled to a peak of 1)
            counts, edges = np.histogram(values, bins, range=value_range)
            centres = edges[:-1] + (edges[1] - edges[0]) / 2
            peak = counts.max()
            # All particles outside value_range -> flat zero line instead of 0/0 = NaN
            scaled = counts / peak if peak > 0 else np.zeros(counts.shape)
            return centres, scaled

        beam = self.particle_beam
        # Track particles for one turn
        beam.plus_one_turn()

        # Apply an RF modification scheduled for this turn, if any
        turn = i + 1
        scheduled_turns = [] if self.modification_turns is None else list(self.modification_turns)
        modify = self.modifications is not None and turn in scheduled_turns
        if modify:
            mod_index = scheduled_turns.index(turn)
            # Each entry is [V, r, h, Phis, dE_s]
            beam.modify_RF_system(*self.modifications[mod_index][:5])

            # Clear the phase-space axis and redraw the separatrices of the new RF system
            axScatter.cla()
            rf_parameters = (beam.V, beam.r, beam.h, beam.phi_s, beam.Phis,
                             beam.dE_s, beam.E, beam.constant)
            if beam.use_contour_bg:
                separatrices_mh_contours(
                    *rf_parameters,
                    ax=axScatter,
                    below_transition=beam.below_transition,
                    contour_margin=1.1, no_fill=True,
                )
            else:
                separatrices_mh_colored(
                    *rf_parameters,
                    ax=axScatter,
                    fill_potential=self.fill_potential,
                    below_transition=beam.below_transition,
                    outer_fill_alpha=0.0,
                    force_ylim=True,
                )
            axScatter.set_title("")

        phases = beam.state[-1][0]
        energies = beam.state[-1][1]

        # 2D phase-space density on the fixed animation grid; empty bins are
        # masked so they are not coloured
        xedges = np.linspace(self.xlim[0], self.xlim[1], self.xbins + 1, endpoint=True)
        yedges = np.linspace(self.ylim[0], self.ylim[1], self.ybins + 1, endpoint=True)
        hist, xedges, yedges = np.histogram2d(phases, energies, (xedges, yedges))
        masked_hist = np.ma.masked_where(hist == 0, hist)

        if modify:
            # cla() removed the previous mesh, so a new artist is required
            distri_plot = axScatter.pcolormesh(xedges, yedges, masked_hist.T, cmap='jet')
            distri_plot.set_clim(0.0001, np.max(hist))
            # Keep the projection axes aligned with the redrawn phase-space axis
            axHisty.set_ylim(axScatter.get_ylim())
            axHistx.set_xlim(axScatter.get_xlim())
        else:
            distri_plot.set_array(masked_hist.T.ravel())

        # 1D projections: phase (row 0 of the state) and energy (row 1)
        phase_centres, phase_profile = profile(phases, self.xbins, self.xlim)
        energy_centres, energy_profile = profile(energies, self.ybins, self.ylim)
        line_phase.set_data(phase_centres, phase_profile)
        line_energy.set_data(energy_profile, energy_centres)

        return (line_phase, line_energy, distri_plot)


def run_multi_harmonic_animation(particle_beam, figname, iterations, framerate,
                                name=None, xbins=50, ybins=50, xlim=None, ylim=None, 
                                fill_potential=None, modifications=None, modification_turns=None):
    """
    Build a MultiHarmonicTrackAnimation from the given settings and run it.

    Parameters:
    - particle_beam: the particle beam object to track
    - figname: the name of the figure
    - iterations: the number of iterations (animation frames)
    - framerate: frames per second of the animation
    - name: base name of the saved animation (default 'multi_harmonic_animation')
    - xbins: number of histogram bins along the phase axis
    - ybins: number of histogram bins along the energy axis
    - xlim: phase-axis histogram limits; None leaves the choice to
      MultiHarmonicTrackAnimation (documented as the beam limits)
    - ylim: energy-axis histogram limits; None leaves the choice to
      MultiHarmonicTrackAnimation (documented as the separatrix limits)
    - fill_potential: fill potential forwarded to the animation (default None)
    - modifications: list of RF-system modifications, each one a list
      [V, r, h, Phis, dE_s] (default None)
    - modification_turns: list of turns (ints) at which the corresponding
      entry of ``modifications`` is applied (default None)

    Returns:
    - the value returned by MultiHarmonicTrackAnimation.run_animation()

    Raises:
    - ValueError: if ``modifications`` is given without ``modification_turns``,
      or the two lists differ in length (checked before any tracking, rather
      than failing part-way through the animation)
    """
    if modifications is not None:
        if modification_turns is None:
            raise ValueError("modification_turns must be provided together with modifications")
        if len(modifications) != len(modification_turns):
            raise ValueError(
                "modifications and modification_turns must have the same length "
                "(got %d and %d)" % (len(modifications), len(modification_turns)))

    trackanim = MultiHarmonicTrackAnimation(particle_beam, figname, iterations, framerate,
                                           xbins=xbins, ybins=ybins, xlim=xlim, ylim=ylim,
                                           name=name, fill_potential=fill_potential,
                                           modifications=modifications,
                                           modification_turns=modification_turns)

    return trackanim.run_animation()


def plotPhaseSpace(distribution, figname=None,
                   xbins=50, ybins=50,
                   xlim=None, ylim=None, color=None, normalize_profiles=True):
    """
    Plot a phase-space distribution as a 2D histogram with its two 1D
    projections: the bunch profile on top and the energy spread on the right.

    The axes and artists are also stored in module-level globals (axScatter,
    axHistx, axHisty, line_phase, line_energy, distri_plot) so the animation
    code can update them frame by frame.

    Parameters:
    - distribution: pair (phases, energy deviations) of array-likes, indexed
      as distribution[0] and distribution[1]
    - figname: figure label/number passed to plt.figure
    - xbins, ybins: number of bins along the phase / energy axis: int
    - xlim, ylim: (min, max) histogram ranges; None uses the data range
      (min, max), widened by 0.5 on each side if all values are equal
    - color: unused; kept for backward compatibility
    - normalize_profiles: if True (default), each 1D projection is scaled to a
      peak of 1. This matches the scaling the animation applies to these lines
      when it updates them, so the projections stay inside the axes.

    Returns:
    - fig, axScatter, axHistx, axHisty

    Raises:
    - ValueError: if a limit is None and the corresponding data is empty
    """
    global axHistx, axHisty, axScatter, line_phase, line_energy, distri_plot

    phases, energies = distribution[0], distribution[1]
    # The 2D grid needs explicit edges, so resolve missing limits from the data
    xlim = _phase_space_limits(phases, xlim)
    ylim = _phase_space_limits(energies, ylim)

    fig = plt.figure(figname, figsize=(8, 8))
    # Axes layout in figure coordinates: main panel bottom-left,
    # projections above it and to its right
    left, width = 0.115, 0.63
    bottom, height = 0.115, 0.63
    bottom_h = left_h = left + width + 0.03

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom_h, width, 0.2]
    rect_histy = [left_h, bottom, 0.2, height]

    # Created in this order so that axScatter ends up as the current axes
    axHistx = plt.axes(rect_histx)
    axHisty = plt.axes(rect_histy)
    axScatter = plt.axes(rect_scatter)

    # Bunch profile: projection onto the phase axis
    phase_centres, phase_profile = _phase_space_profile(phases, xbins, xlim, normalize_profiles)
    line_phase, = axHistx.plot(phase_centres, phase_profile)
    axHistx.axes.get_xaxis().set_ticklabels([])
    axHistx.axes.get_yaxis().set_ticklabels([])
    axHistx.set_ylabel('Bunch profile $\\lambda_{\\phi}$')

    # Energy spread: projection onto the energy axis, drawn sideways
    energy_centres, energy_profile = _phase_space_profile(energies, ybins, ylim, normalize_profiles)
    line_energy, = axHisty.plot(energy_profile, energy_centres)
    axHisty.axes.get_xaxis().set_ticklabels([])
    axHisty.axes.get_yaxis().set_ticklabels([])
    axHisty.set_xlabel('Energy spread $\\lambda_{\\Delta E}$')

    # 2D density; empty bins are masked so they are left uncoloured
    xedges = np.linspace(xlim[0], xlim[1], xbins + 1, endpoint=True)
    yedges = np.linspace(ylim[0], ylim[1], ybins + 1, endpoint=True)
    hist, xedges, yedges = np.histogram2d(phases, energies, (xedges, yedges))
    masked_hist = np.ma.masked_where(hist == 0, hist)
    distri_plot = axScatter.pcolormesh(xedges, yedges, masked_hist.T, cmap='jet')
    distri_plot.set_clim(0.0001, np.max(hist))

    return fig, axScatter, axHistx, axHisty


def _phase_space_limits(values, limits):
    """Return ``limits`` unchanged, or infer (min, max) from ``values`` as np.histogram does."""
    if limits is not None:
        return limits
    values = np.asarray(values)
    if values.size == 0:
        raise ValueError("cannot infer histogram limits from an empty distribution")
    low, high = float(values.min()), float(values.max())
    if low == high:
        # Same widening np.histogram applies to a zero-width range
        low, high = low - 0.5, high + 0.5
    return (low, high)


def _phase_space_profile(values, bins, value_range, normalize=True):
    """Histogram ``values`` and return (bin centres, counts), optionally scaled to a peak of 1."""
    counts, edges = np.histogram(values, bins, range=value_range)
    centres = edges[:-1] + (edges[1] - edges[0]) / 2
    if not normalize:
        return centres, counts
    peak = counts.max()
    # Empty window -> zero line instead of 0/0 = NaN
    return centres, (counts / peak if peak > 0 else np.zeros(counts.shape))


def oscillation_spectrum(phase_track, fft_zero_padding=0):
    """
    Compute the amplitude spectrum of a turn-by-turn oscillation.

    The mean is removed before transforming. Amplitudes are scaled by
    2/n_turns, so a pure sinusoid of amplitude A (not at DC or Nyquist)
    appears with height A.

    Parameters:
    - phase_track: turn-by-turn samples (any array-like, e.g. list or np.array)
    - fft_zero_padding: number of zeros appended before the FFT, giving a finer
      frequency grid with the same amplitude scaling: int

    Returns:
    - freq_array: frequencies in cycles per sample (per turn), from 0 to 0.5: np.ndarray
    - fft_osc: oscillation amplitude at each frequency: np.ndarray

    Raises:
    - ValueError: if phase_track is empty
    """
    # asarray so plain lists work (list - float would raise TypeError)
    track = np.asarray(phase_track)
    n_turns = len(track)
    if n_turns == 0:
        raise ValueError("phase_track must contain at least one sample")

    n_fft = n_turns + fft_zero_padding
    freq_array = np.fft.rfftfreq(n_fft)
    fft_osc = np.abs(np.fft.rfft(track - np.mean(track), n_fft) * 2 / n_turns)

    return freq_array, fft_osc

### Basic Accelerator and Beam Parameters
- Note that `r`, `h` and `Phis` are lists with one entry per harmonic, while `dE_list` is a list of energy gains per turn, one for each example case.
- Keep the first entry of `r` equal to 1, since every ratio in `r` multiplies $V_1$.


In [None]:
energy = 10          # typically [MeV], [GeV], [TeV]
beta = 0.9           # relativistic velocity factor
charge = 1           # in units of the electron charge, just unity for protons
V = 2                # RF voltage, typically [V], [kV], [MV] (similar order of magnitude as the energy)
main_harmonic = 1    # harmonic number of the RF system
eta = -0.01          # slip factor: 1/gamma_transition**2 - 1/gamma**2

# A negative slip factor means the particle is below the transition energy
below_transition = bool(eta < 0)

constant = np.abs(main_harmonic * eta / beta**2)

# Fraction of the potential filled by particles (visualization only)
filling_factor = None
# Voltage ratio of each harmonic to the 1st harmonic (keep r[0] = 1)
r = [1, 1.3, 0.9, 1.2, 0.9]
# Harmonic numbers
h = [1, 2, 3, 4, 5]
# Synchronous energy gain per turn, one entry per example case
dE_list = [0, 0.2, 3]
# Phase offset of each harmonic, added to h*phi inside the sine of the RF voltage
Phis = [0 * np.pi, np.pi, 0 * np.pi, np.pi, 0 * np.pi]


#### Example of a plotting sequence 
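Shows the accelerating voltage, RF potential and separatrices for the multi-harmonic configuration defined above (one case per entry of `dE_list`), followed by Hamiltonian contour plots without and with acceleration.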


In [None]:
# Example: side-by-side accelerating voltage, potentials and separatrices for every entry of dE_list.
# Tip: for more room to explore the plot, open an interactive window with mpl.use('TkAgg').
# Be careful: it blocks code execution until you close the window and, MORE IMPORTANTLY,
# it breaks the animations, so restart your kernel if you used it before making an animation.

# mpl.use('TkAgg') # Uncomment this line to open up an interactive window with the plot

# Synchronous phase for each energy gain per turn
phi_s_list = [dE_to_phi_s(dE, V, r, h, Phis, below_transition) for dE in dE_list]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

# Phase window: the 2*pi interval marked by red dashed lines, plus pi/2 on each side
if below_transition:
    phi = np.linspace(-1.5*np.pi, 1.5*np.pi, 1000)
    window_edges = (-np.pi, np.pi)
else:
    phi = np.linspace(-0.5*np.pi, 2.5*np.pi, 1000)
    window_edges = (0, 2*np.pi)

# Plot the voltage on ax1
ax1.plot(phi, voltage_mh(V, r, h, phi, Phis))
for edge in window_edges:
    ax1.axvline(edge, color='red', linestyle='--')
ax1.set_xlabel(r"$\phi$ [rad]")
ax1.set_ylabel(r"Voltage $V(\phi)$ [arb. units]")

label = r"$V(\phi)$ for $r_2 = %.2f, r_3 = %.2f$" % (r[1], r[2])
ax1.set_title(label)

# Potentials for every synchronous phase on ax2; the 4th return value is used
# below as the colour of each case
_, _, _, colors = plot_potentials_multi_phi_s(V, r, h, phi_s_list, Phis, dE_list, ax=ax2,
                                              fill_potential=filling_factor,
                                              below_transition=below_transition)

# Separatrices for every synchronous phase on ax3, in the same colours as the potentials
for i, phi_s in enumerate(phi_s_list):
    separatrices_mh_colored(V, r, h, phi_s, Phis, dE_list[i], energy, constant, ax=ax3,
                            base_color=colors[i], phi_margin_percent=10, epsilon=1e-12,
                            fill_potential=filling_factor, below_transition=below_transition)
plt.tight_layout()
plt.show()


# Example of contour plots without acceleration (index 0)
fig, ax1 = plt.subplots(1, 1)
fig.suptitle(r"Hamiltonian Contours for $\Delta E_s = " + str(dE_list[0]) + r"$ arb. units")
separatrices_mh_contours(V, r, h, phi_s_list[0], Phis, dE_list[0], energy, constant,
                         ax=ax1, below_transition=below_transition)
plt.tight_layout()
plt.show()


# Example of contour plots with acceleration (index 2)
fig, ax1 = plt.subplots(1, 1)
# The title must report the same dE_list entry that is plotted (index 2)
fig.suptitle(r"Hamiltonian Contours for $\Delta E_s = " + str(dE_list[2]) + r"$ arb. units")
separatrices_mh_contours(V, r, h, phi_s_list[2], Phis, dE_list[2], energy, constant,
                         ax=ax1, below_transition=below_transition)
plt.tight_layout()
plt.show(block=True)


### Example Usage of the ParticleBeam Class

In [None]:
# Generate the initial particle distribution
n_particles = 30000
bunch_position = 0
bunch_length = 4      # rad
bunch_energy = 0
energy_spread = 40    # same units as the beam energy
# Distribution shape: 'gaussian' or 'rectangular'. Deliberately not named `type`,
# which would shadow the Python builtin for the rest of the notebook session.
distribution_type = 'gaussian'

particle_phases, particle_energies = generateBunch(bunch_position, bunch_length,
                                                  bunch_energy, energy_spread,
                                                  n_particles, distribution_type)

# Initial particle state: row 0 holds the phases, row 1 the energies
particles = np.vstack([particle_phases, particle_energies])

# Create the ParticleBeam from the RF and beam parameters defined in the cells above
particle_beam = ParticleBeam(
    particles=particles,
    V=V,
    r=r,
    h=h,
    dE_s=dE_list[0],  # first example case of dE_list
    Phis=Phis,
    below_transition=below_transition,
    constant=constant,
    E=energy,
    use_contour_bg=True,
)

# Plot the waveform, potential, separatrix and distribution
particle_beam.plot_all()
plt.show()

# Plot the initial and final phase space after 1000 turns
particle_beam.advance_and_plot_phase_space(1000)
plt.show()





### Example Animation with RF System Changes 

In [None]:
# Run the animation, changing the RF system twice during the simulation
turns = 400  # number of turns to simulate

# Each modification is [V, r, h, Phis, dE_s]
rf_modifications = [
    [V*1.1, r, h, [0*np.pi, 0*np.pi, np.pi, 0*np.pi, np.pi], dE_list[0]],
    [V*1.5, r, h, [0*np.pi, np.pi, 0*np.pi, np.pi, 0*np.pi], dE_list[0]],
]
# Turn at which each modification above is applied
rf_modification_turns = [turns//3, turns//2]

phase_window = (-np.pi, np.pi) if below_transition else (0, 2*np.pi)

run_multi_harmonic_animation(
    particle_beam=particle_beam,
    figname='Multi-Harmonic Animation',
    iterations=turns,
    framerate=30,
    name='multi_harmonic_demo_mod',
    xbins=50,
    ybins=50,
    xlim=phase_window,
    ylim=None,
    modifications=rf_modifications,
    modification_turns=rf_modification_turns,
)