This is an example for the usage of [Kwasniok/astra-web](https://github.com/Kwasniok/astra-web).

Please read the README.md first to ensure a proper setup of the server.

To become familiar with astra-web it is best to run this notebook on the same machine as the astra-server (`localhost`).
Experienced users may connect to a remote astra-server right away by changing the `server_hostname` variable below.


# What you will learn in this example

- configure access to the astra-web server
- instruct the server to do the simulations for you (dispatch)
- optional: instruct the server to use a SLURM server for the simulations instead (dispatch to SLURM)

This example shows a typical workflow for creating a large scale simulation campaign with ASTRA.
For demonstration purposes, the number and the complexity of the simulations are kept low.

# Preparation

1. start the astra-web server (see README.md)
2. create a file called `.env` in this folder and set `ASTRA_WEB_API_KEY="<api_key>"` to match the server's configuration (see README.md)
3. adjust the variables below as you encounter them (e.g. `server_hostname`, `server_port`, `slurm_enabled`, `slurm_user_name`, etc.)
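
Before starting, it can help to check that the API key from step 2 is actually readable. The following is a minimal sketch using only the standard library; the helper names `parse_env_text` and `load_api_key` are hypothetical and not part of astra-web, and the sketch assumes the `.env` file contains plain `KEY=VALUE` lines.

```python
import os


def parse_env_text(text: str) -> dict[str, str]:
    """Parse simple KEY=VALUE lines; blank lines and '#' comments are skipped."""
    values: dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # drop surrounding whitespace and optional quotes around the value
        values[key.strip()] = value.strip().strip("\"'")
    return values


def load_api_key(path: str = ".env") -> str:
    """Return ASTRA_WEB_API_KEY from the environment or, failing that, from `path`."""
    key = os.environ.get("ASTRA_WEB_API_KEY")
    if key is None and os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            key = parse_env_text(f.read()).get("ASTRA_WEB_API_KEY")
    if not key or key == "<api_key>":
        raise RuntimeError("ASTRA_WEB_API_KEY is missing or still a placeholder.")
    return key
```

Calling `load_api_key()` raises early with a readable message if the key is missing, instead of failing on the first request to the server.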

# Import


In [1]:
from typing import Any, Iterable
import os
from time import sleep
from datetime import datetime
from IPython.display import clear_output
from tqdm import tqdm
import numpy as np
import scipy
import pandas as pd
from astra_web_request import request, RequestMethod, ClientResponseError
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots
from sklearn.model_selection import ParameterSampler

import slurm_requests as slurm

# Config


## Connection to Server

In [2]:
# The name of the machine which is running the astra-web server.
# Use `localhost` if you are running the server on the same machine as this notebook.
# Otherwise it might look something like `my.server.com` or `192.168.1.1`.
server_hostname = "localhost"

# Advanced:
# The port on which the astra-web server is running.
# Default is `8000` unless you have changed it in the server configuration.
# Note: You can have multiple astra-web servers running on the same machine by using different ports.
server_port = 8000

os.environ["ASTRA_WEB_URL"] = f"http://{server_hostname}:{server_port}/"

## SLURM (optional)

Enable SLURM only if you have set up and configured your server correctly for SLURM!
If you are unsure, just leave it disabled.

Expert users may be interested to know that astra-web talks to SLURM via its [REST API](https://slurm.schedmd.com/rest_api.html).

In [3]:
slurm_enabled = False
slurm_user_name = "YOUR SLURM USER NAME HERE"
slurm_partition = "cpu"
slurm_constraints = None

# SLURM preparation


### check if SLURM is alive


In [4]:
if slurm_enabled:
    # ping
    response = await request("slurm/ping", RequestMethod.GET)
    print("OK" if len(response["errors"] + response["warnings"]) == 0 else "ERROR")

### configuration


In [5]:
if slurm_enabled:
    config = await request(
        f"slurm/configuration",
        RequestMethod.GET,
    )

    config["user_name"] = slurm_user_name
    config["partition"] = slurm_partition
    config["constraints"] = slurm_constraints

    # general slurm client for queue inspection
    slurm.init_defaults(
        url=config["base_url"],
        api_version=config["api_version"],
        user_name=config["user_name"],
        user_token=config["user_token"],
    )

    config = await request(
        f"slurm/configuration",
        RequestMethod.PUT,
        config,
    )

### define how to renew JWT token (please adapt this if needed!)
SLURM uses temporary tokens to authenticate users.
Each token is valid only for a limited time.
Adapt `get_new_slurm_token` to make sure it works with your cluster.
Some clusters may require you to enable these tokens first (e.g. use portal tokens).
Please contact your local IT service regarding **JWT tokens for SLURM** if you are unsure.

E.g., for the SLURM cluster Maxwell at DESY, Hamburg, see https://docs.desy.de/maxwell/services/slurm_rest_api/#json-web-token-jwt for more info (internal web page accessible to DESY users only).

In [6]:
_last_token_renewal: datetime | None = None

# ATTENTION: Adapt this function to your cluster's token generation method!
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
def get_new_slurm_token():
    o = !/software/tools/bin/slurm_token -l $((60*60*24))
    # split only at the first "=" so that the token itself is kept intact
    slurm_token = o[0].split("=", 1)[1]
    return slurm_token
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

async def renew_slurm_token_if_needed(show: bool = True, max_age_hours: int = 1):

    global _last_token_renewal

    # check if renewal is needed
    if _last_token_renewal is not None:
        delta = datetime.now() - _last_token_renewal
        if delta.total_seconds() < 60*60*max_age_hours:
            if show:
                print(
                    f"SLURM token renewed {delta.total_seconds()/(60*60):.1f} hours ago, skipping renewal."
                )
            # token is still fresh, skip renewal
            return

    # request new token
    slurm_token = get_new_slurm_token()
    _last_token_renewal = datetime.now()

    # update stored token
    await request(
        f"slurm/configuration/user_token?value={slurm_token}",
        RequestMethod.PUT,
    )
    slurm.default.user_token = slurm_token

    # check credentials and overall status
    success = False
    try:
        response = await request("slurm/diagnose", RequestMethod.GET)
        success = len(response["errors"] + response["warnings"]) == 0
        await slurm.diagnose()
    except RuntimeError:
        # one of the checks failed, so do not report success
        success = False

    # report status
    if success:
        print("SLURM credentials and token OK")
    else:
        raise RuntimeError("SLURM credentials or token not OK.")
    
if slurm_enabled:
    await renew_slurm_token_if_needed()

# Definition of Experiment (= Bundle of Related Simulations)


## Name

In [7]:
# The database can hold many simulations at the same time.
# Each experiment (=collection of related simulations) should have a unique name.
experiment_name = "example v1"

# Advanced:
# In case of multiple, large experiments, you can speed up downloading data from the server by pre-filtering the simulations.
# Leave this untouched unless you know what you are doing.
# Improper filters may lead to simulations being silently skipped during feature extraction!
feature_extraction_filter_experiment_name_regex = None  # example: r"example v*"
feature_extraction_filter_date_start = "2000-01-01"
feature_extraction_filter_date_end = "2999-12-31"

## Logbook

- 2026-01-01:
  - `experiment_name = "example v1"`
  - some example runs which can be deleted later

### Field Files


Fields must be provided as field files (see ASTRA manual for details).

ATTENTION: **Field files are shared among all simulations!**

Use unique names per field.
During upload, field files that already exist on the server are not uploaded again.

Here we assume the files describe the amplitudes of the longitudinal field components.
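
To illustrate the expected layout, here is a minimal standard-library sketch that turns a whitespace-separated two-column text (position `z`, amplitude `v`) into the `{"z": [...], "v": [...]}` structure this notebook uploads. The function name `parse_field_table` is hypothetical, and the sketch assumes plain decimal numbers without a header line.

```python
def parse_field_table(text: str) -> dict[str, list[float]]:
    """Parse two whitespace-separated columns into lists of z and v values."""
    z: list[float] = []
    v: list[float] = []
    for line in text.splitlines():
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or incomplete lines
        z.append(float(parts[0]))
        v.append(float(parts[1]))
    return {"z": z, "v": v}


example = parse_field_table("0.0  1.0\n0.1  0.5\n")
print(example)
```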

In [28]:
def load_fields(file_names: Iterable[str]) -> dict[str, dict[str, Any]]:
    return {
        k: pd.read_csv(
            os.path.abspath(f"./data/in/fields/{k}"), names=["z", "v"], sep=r"\s+"
        ).to_dict("list")
        for k in file_names
    }


fields = load_fields(
    (
        "C1_E.dat",
        "S1_B.dat",
    )
)

### Generator Configuration


This example uses the `generator` program provided with ASTRA to create an initial particle distribution based on parameters.
In case this is not sufficient for your needs you can also provide your own initial particle distribution files and upload them instead (see documentation of astra-web for details).

Here we create a simple electron emission profile with a radially uniform distribution in the transverse directions, a flattop distribution in time and an isotropic momentum distribution.
For this experiment the total charge, the laser pulse duration, and the number of macro particles can be adjusted.
All other parameters stay fixed.
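
As a quick sanity check on the charge and particle count, the charge carried by each macro particle can be estimated. This is a rough sketch that assumes the total charge is split equally among all macro particles; the function name `macro_particle_charge` is hypothetical, and the numbers are the default values used in this notebook.

```python
ELEMENTARY_CHARGE = 1.602176634e-19  # C (exact SI value)


def macro_particle_charge(total_charge_nC: float, particle_count: int) -> float:
    """Charge per macro particle in C, assuming equal weights."""
    return total_charge_nC * 1e-9 / particle_count


q = macro_particle_charge(0.250, 20_000)
electrons_per_macro_particle = q / ELEMENTARY_CHARGE
print(f"{q:.3e} C per macro particle")
print(f"about {electrons_per_macro_particle:.0f} electrons per macro particle")
```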

In [9]:
def generator_input(
    comment: str | None = None,
    particle_count: int = 20_000,
    total_charge: float = 0.250,  # nC
    emission_time_fwhm: float = 10e-3,  # ns
) -> dict[str, Any]:

    return {
        "comment": comment,
        "particle_count": particle_count,
        "total_charge": total_charge,
        # x
        "dist_x": "radial_uniform",
        "dist_x_bunch_size_rms": 0.3,
        # y
        "dist_y": "radial_uniform",
        "dist_y_bunch_size_rms": 0.3,
        # z/t
        "dist_z": "flattop",
        "use_time_spread": True,
        "dist_t_emission_time": emission_time_fwhm,
        "dist_t_emission_time_rise": 0.001,  # ns
        # pz
        "dist_pz": "isotropic",
        "dist_pz_energy_width": 0.55e-3,  # keV
        # reference particle
        "reference_kinetic_energy": 0,  # MeV
    }

### Modules Configuration


The modules describe the layout of the accelerator. As an example, a gun and its focussing solenoid are used here.

Notice the attached optional comments for each module for better identification.

In [10]:
def modules_input(
    gun_E_max: float = 60,  # MV/m
    gun_phase: float = 0,  # deg
    sol_B_max: float = 0.21,  # T
) -> dict[str, Any]:

    return {
        "cavities": [
            {
                "comment": "gun",
                "field_file_name": "C1_E.dat",
                "max_field_strength": gun_E_max,  # MV/m
                "frequency": 1.3,  # GHz
                "phase": gun_phase,
                "z": 0.0,  # m
                "smoothing_iterations": 0,
                "higher_order": True,
            },
        ],
        "solenoids": [
            {
                "comment": "focussing solenoid",
                "field_file_name": "S1_B.dat",
                "max_field_strength": sol_B_max,  # T
                "smoothing_iterations": 10,
                "z": 0.0,
            }
        ],
    }

#### Plot Example Setup


Use the default values to plot an example setup of the modules/accelerator layout.

In [11]:
def make_modules_plt(input: dict[str, Any]) -> go.Figure:
    fig = plotly.subplots.make_subplots(specs=[[{"secondary_y": True}]])

    # add cavities
    for c in input["cavities"]:
        field = pd.DataFrame(fields[c["field_file_name"]])
        z = field["z"] + c["z"]
        e = field["v"] * c["max_field_strength"] / field["v"].abs().max()
        fig.add_scatter(
            x=z,
            y=e,
            secondary_y=False,
            mode="lines",
            name=f'{c["comment"]} @ {c["frequency"]} GHz, phi={c["phase"]}°',
        )
        fig.add_vline(x=c["z"], line_width=0.5, line_dash="solid", line_color="gray")
        if c["z"] != 0.0:
            fig.add_annotation(
                text=str(c["z"]),
                x=c["z"],
                y=1,
                xref="x",
                yref="paper",
                textangle=-45,
            )
    # add solenoids
    for s in input["solenoids"]:
        field = pd.DataFrame(fields[s["field_file_name"]])
        z = field["z"] + s["z"]
        b = field["v"] * s["max_field_strength"] / field["v"].abs().max()
        fig.add_scatter(
            x=z,
            y=b,
            secondary_y=True,
            mode="lines",
            name=s["comment"],
        )
        fig.add_vline(x=s["z"], line_width=0.5, line_dash="solid", line_color="gray")
        if s["z"] != 0.0:
            fig.add_annotation(
                text=str(s["z"]),
                x=s["z"],
                y=1,
                xref="x",
                yref="paper",
                textangle=-45,
            )

    fig.update_layout(
        title="Longitudinal Field Amplitudes",
        xaxis_title="z [m]",
    )
    fig.update_yaxes(title="E [MV/m]", secondary_y=False)
    fig.update_yaxes(title="B [T]", secondary_y=True)
    fig.update_xaxes(range=[0, None])

    return fig


make_modules_plt(modules_input()).show()

### Simulation Run Configuration


Each simulation requires the name of the initial particle distribution, the modules layout, and various simulation parameters.

In [12]:
def simulation_input(
    gen_id: str,  # the ID of the initial particle distribution (distributions are usually created by the generator)
    comment: str | None,
    # general:
    z_stop: float = 1.0,  # m
    output_intervals: int = 2,  # 2 = start & end only
    # technical:
    timeout: int = 15 * 60,  # s
    thread_num: int = 1,  #  1=single thread, >1=multi thread via parallel astra
    # initial bunch rescaling:
    bunch_xy_rms: float = 1.0,  # mm, ASTRA allows us to rescale the initial bunch at the beginning of the simulation
    bunch_t_rms: float = 10e-3,  # ns, ASTRA allows us to rescale the initial bunch at the beginning of the simulation
    # gun:
    gun_E_max: float = 60,  # MV/m
    gun_phase: float = 0,  # deg
    sol_B_max: float = 0.21,  # T
) -> dict[str, Any]:

    return {
        "comment": comment,
        # optional optimization:
        # - Compress particle distribution data after run to save storage space.
        # - This will be invisible to the user and is recommended.
        # - Files for which **any** coefficient would exceed the relative error `max_rel_err` during compression stay uncompressed.
        "auto_compress_after_run": {
            "precision": "f32",
            "max_rel_err": 1e-4,
        },
        "run": {
            "generator_id": gen_id,
            "z_cathode": 0.0,
            "bunch_initial_xy_rms": bunch_xy_rms,
            "bunch_initial_t_rms": bunch_t_rms,
            "integrator_iteration_max": 100_000,
            "integrator_step_max": 0.001,
            "cavity_phase_auto": True,  # ASTRA autophasing
            "thread_num": thread_num,
            "timeout": timeout,
            "particle_track_all": True,
        },
        "output": {
            "z_stop": z_stop,
            "emittance_intervals": max(output_intervals, 2),
            "distribution_intervals": max(output_intervals, 2),
            "save_distribution_with_high_resolution": True,
            "save_emittance": True,
            "save_emittance_trace_space": True,
            "save_distribution": output_intervals > 0,
            "save_probe_particles": True,
            "save_space_charge_scaling": True,
            "save_space_charge_field_on_cathode": True,
            # optional screens (for output only):
            # "screens": [
            #     {"z": 2},  # m
            # ],
        },
        "space_charge": {
            "enable": True,
            "enable_mirror_charge": True,
            "scaling_relative_deviation_threshold": 0.05,
            "scaling_step_threshold": 50,
            "grid_2d_radial_cell_count": 35,
            "grid_2d_longitudinal_cell_count": 70,
            "grid_transition_z": 4.5,  # m
            "grid_3d_x_cell_count": 32,
            "grid_3d_y_cell_count": 32,
            "grid_3d_z_cell_count": 64,
            "emitted_particle_num_per_step": 200,
        },
        **modules_input(
            gun_phase=gun_phase,
            gun_E_max=gun_E_max,
            sol_B_max=sol_B_max,
        ),
    }
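
The `auto_compress_after_run` criterion can be pictured with a small standalone check. The sketch below is not the server's implementation; it only illustrates the idea of a relative-error test when values are stored as 32-bit floats, using the standard `struct` module. The names `max_rel_err_f32` and `can_compress` are hypothetical.

```python
import math
import struct


def max_rel_err_f32(values: list[float]) -> float:
    """Largest relative error introduced by a round trip through a 32-bit float."""
    worst = 0.0
    for x in values:
        if x == 0.0:
            continue  # zero is represented exactly
        try:
            y = struct.unpack("<f", struct.pack("<f", x))[0]
        except OverflowError:
            return math.inf  # value does not fit into a 32-bit float at all
        worst = max(worst, abs(y - x) / abs(x))
    return worst


def can_compress(values: list[float], max_rel_err: float = 1e-4) -> bool:
    """True if every value survives the f32 round trip within `max_rel_err`."""
    return max_rel_err_f32(values) <= max_rel_err
```

Typical values pass easily, since a 32-bit float keeps about seven significant digits. A value that underflows to zero, such as `1e-50`, fails the check, which is why the criterion applies to every coefficient and not to an average.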


# Upload Data & Dispatch Simulation Runs


## Ensure Field Files are Uploaded
Note: error 409 can be ignored here (the file already exists).

In [None]:
for file_name, table in fields.items():
    try:
        response = await request(f"fields/{file_name}", RequestMethod.PUT, table)
    except ClientResponseError as e:
        if e.status == 409:
            print(f"Field table '{file_name}' already exists, skipping.")
        else:
            raise RuntimeError(f"Error uploading {file_name}") from e

## Experiment Parameter Distributions


### Helper Functions for Parameter Distributions


In [14]:
constant = lambda x: scipy.stats.rv_discrete(values=([x], [1.0]))
random_uniform = lambda min, max: scipy.stats.uniform(min, max - min)
random_truncated_normal = lambda mu, sigma, min, max: scipy.stats.truncnorm(
    a=(min - mu) / sigma, b=(max - mu) / sigma, loc=mu, scale=sigma
)

In [15]:
def iter_rows_as_dict(df: pd.DataFrame) -> Iterable[dict[str, Any]]:
    """Iterate over rows of a DataFrame as dictionaries."""
    for row in df.itertuples(index=False):
        yield row._asdict()


def unique_values_per_column(df: pd.DataFrame) -> dict[str, set[Any]]:
    """Get unique values per column in a DataFrame."""
    return {col: set(df[col].unique()) for col in df.columns}


def print_unique_values_per_column(df: pd.DataFrame) -> None:
    """Print unique values per column in a DataFrame."""
    uv = unique_values_per_column(df)
    for col, values in uv.items():
        values = sorted(values)
        print(
            f"{col}  ({values[0]}...{values[-1]}, count={len(values)}): "
            + " ".join(map(str, values[:3] + ["..."] + values[-3:]))
        )

### Actual Values


In [16]:
seed = 2026_01_01
number_of_samples = 3
timeout = 60 * 60  # s

param_region = {
    # physical:
    "gun_phase": random_uniform(-10, 10),  # deg
    "gun_sol_B_max": constant(0.21),  # T
    "bunch_xy_rms": constant(1.0),  # mm
    "bunch_t_rms": random_uniform(10e-3, 100e-3),  # ns
    # technical:
    "particle_count": constant(10_000),  # e.g. 10_000 ... 2_000_000
    "z_stop": constant(0.5),  # m
    "output_intervals": constant(10),  # 2 = start & end only
    "thread_num": constant(16),
}

rng = np.random.RandomState(seed)

param_table = pd.DataFrame(
    ParameterSampler(
        param_region,
        n_iter=number_of_samples,
        random_state=rng,
    ),
    columns=param_region.keys(),
)
param_table.drop_duplicates(inplace=True)

# show summary
print(f"{len(param_table)} simulations to run.")
print(f"cumulative simulation timeout: {len(param_table)*timeout/3600:.1f} h")
print()
print_unique_values_per_column(param_table.round(decimals=4))

plot_param_dist = lambda param_name: px.histogram(
    param_table[param_name], title=param_name, width=600, height=300
)

for param_name in param_region.keys():
    plot_param_dist(param_name).show()

3 simulations to run.
cumulative simulation timeout: 3.0 h

gun_phase  (-1.7927...6.1397, count=3): -1.7927 1.0037 6.1397 ... -1.7927 1.0037 6.1397
gun_sol_B_max  (0.21...0.21, count=1): 0.21 ... 0.21
bunch_xy_rms  (1.0...1.0, count=1): 1.0 ... 1.0
bunch_t_rms  (0.0461...0.0715, count=3): 0.0461 0.068 0.0715 ... 0.0461 0.068 0.0715
particle_count  (10000...10000, count=1): 10000 ... 10000
z_stop  (0.5...0.5, count=1): 0.5 ... 0.5
output_intervals  (10...10, count=1): 10 ... 10
thread_num  (16...16, count=1): 16 ... 16


## Dispatch Simulation Runs

### Dispatch Helper Functions


Generator

In [17]:
async def dispatch_particle_generation(
    input: dict[str, Any],
    *,
    host: str = "local",
) -> str:
    """
    Dispatch particle distribution generation via the ASTRA generator.

    note: `host=local` by default, which ensures that the particle distribution exists upon returning.
    """
    response = await request(
        f"particles?host={host}",
        RequestMethod.POST,
        input,
    )
    gen_id = response["gen_id"]
    return gen_id


async def download_particle_distribution(gen_id: str) -> dict[str, Any]:
    response = await request(
        f"particles/{gen_id}",
        RequestMethod.GET,
    )
    return response["output"]["particles"]


async def upload_particle_distribution(
    dist: dict[str, Any],
    comment: str | None = None,
) -> str:
    response = await request(
        f"particles",
        RequestMethod.PUT,
        body=dict(
            comment=comment,
            output=dist,
        ),
    )
    return response["gen_id"]

Simulation

In [18]:
async def dispatch_simulation(
    input: dict[str, Any],
    *,
    host: str = "local",
    timeout: int = 1 * 60 * 60,  # sec
) -> str:
    """
    Dispatch simulation run via ASTRA.

    note: Execution may be done remotely (e.g. "host=slurm"), in which case returning does not mean the run has completed.
    """
    response = await request(
        f"simulations?host={host}",
        RequestMethod.POST,
        input,
        **(dict(timeout=int(timeout * 1.1)) if host == "local" else {}),
    )
    sim_id = response["sim_id"]
    return sim_id

Generation + Simulation in One Go

In [19]:
async def dispatch_generation_and_simulation_run(
    gun_phase: float,
    gun_sol_B_max: float,
    bunch_xy_rms: float,
    bunch_t_rms: float,
    # technical:
    particle_count: int,
    z_stop: float,
    output_intervals: int,
    thread_num: int,
):
    """
    Dispatch particle generation (always local) followed by a simulation run that uses the generated distribution.
    """

    input = generator_input(
        comment=experiment_name,
        particle_count=particle_count,
    )
    gen_id = await dispatch_particle_generation(
        input,
    )

    input = simulation_input(
        comment=experiment_name,
        gen_id=gen_id,
        thread_num=thread_num,
        timeout=timeout,
        z_stop=z_stop,
        output_intervals=output_intervals,
        gun_phase=gun_phase,
        sol_B_max=gun_sol_B_max,
        bunch_xy_rms=bunch_xy_rms,
        bunch_t_rms=bunch_t_rms,
    )
    sim_id = await dispatch_simulation(
        input,
        host="slurm" if slurm_enabled else "local",
        timeout=timeout,
    )

    return sim_id

In [20]:
async def slurm_jobs_in_queue(print_status: bool = False) -> dict[str, int]:
    """Count SLURM jobs per state (total, pending, running, completed, failed) in the configured partition."""

    jobs = (await slurm.jobs_list())["jobs"]

    def filter_and_extract_id(jobs, state=None, partition=slurm_partition):
        return set(
            job["job_id"]
            for job in jobs
            if (
                state is None
                or state.lower()
                in set(map(lambda s: s.lower(), job["state"]["current"]))
            )
            and (partition is None or partition == job.get("partition", None))
        )

    ids_all = filter_and_extract_id(jobs)
    ids_pending = filter_and_extract_id(jobs, "pending")
    ids_running = filter_and_extract_id(jobs, "running")
    ids_completed = filter_and_extract_id(jobs, "completed")
    ids_failed = filter_and_extract_id(jobs, "failed")

    counts = dict(
        total=len(ids_all),
        pending=len(ids_pending),
        running=len(ids_running),
        completed=len(ids_completed),
        failed=len(ids_failed),
    )

    if print_status:
        # print status table (reuses the counts computed above)
        clear_output(wait=True)
        df = pd.DataFrame(
            counts,
            index=[f"all SLURM jobs (in '{slurm_partition}')"],
        )
        print(df)

    return counts


async def pending_and_running_slurm_jobs_in_queue(print_status: bool = False) -> int:
    counts = await slurm_jobs_in_queue(print_status=print_status)
    return counts["pending"] + counts["running"]

### Dispatch


In [21]:
# Dispatch this many simulations in one chunk before waiting for queue space.
# Note: 2 * dispatch_chunk_size has to be below the SLURM queue limit for the user.
dispatch_chunk_size = 50

# Advanced:
# In case of dispatch hiccup, skip the first n simulations.
skip_first_n_sims = 0


async def slurm_chunk_preparation(
    check_every_n_sec: int = 10 * 60,  # seconds
    print_status: bool = False,
):
    sleep(10)

    if print_status:
        print("Waiting for SLURM queue to have space...")

    # wait until enough space in queue
    while True:
        await renew_slurm_token_if_needed(show=print_status)
        if (
            await pending_and_running_slurm_jobs_in_queue(print_status=print_status)
        ) < dispatch_chunk_size:
            # queue has enough space
            break
        sleep(check_every_n_sec)

    # ensure SLURM token is fresh
    await renew_slurm_token_if_needed(show=print_status)


if slurm_enabled:
    await slurm_chunk_preparation()

# dispatch sims in chunks to avoid overfilling the queue
for idx, params in tqdm(
    enumerate(iter_rows_as_dict(param_table), 1),
    total=len(param_table),
    desc="Dispatching simulations",
):

    if idx <= skip_first_n_sims:
        continue

    await dispatch_generation_and_simulation_run(
        **params,
    )

    if idx % dispatch_chunk_size == 0:
        if slurm_enabled:
            await slurm_chunk_preparation()

if slurm_enabled:
    print("Waiting for all simulations to complete...")
    raise NotImplementedError(
        "Automatic waiting not implemented. Please wait manually."
    )

Dispatching simulations: 100%|██████████| 3/3 [03:42<00:00, 74.20s/it]


# Appendix

Automatic brake to prevent appendix from running when notebook is executed end-to-end.

In [22]:
raise RuntimeError(
    "Deliberate stop. The execution of the following cells is reserved for manual actions only."
)

RuntimeError: Deliberate stop. The execution of the following cells is reserved for manual actions only.

## Rerun a failed Simulation
A simulation may fail and need to be restarted from the beginning.

In [None]:
async def redispatch_simulation(
    sim_id: str,
    *,
    host: str = "local",
    timeout: int = 1 * 60 * 60,  # sec
) -> None:
    await request(
        f"simulations/{sim_id}/redispatch?host={host}",
        RequestMethod.PUT,
        **(dict(timeout=int(timeout * 1.1)) if host == "local" else {}),
    )

In [None]:
await redispatch_simulation(
    "2026-01-01-00-00-00-ABCDEFGH",
)