# Plasma Equilibrium Dataset Preparation

## Introduction

This notebook follows the same format as the Plasma Volume Dataset notebook, but prepares data for the Plasma Equilibrium challenge. As in the other challenge notebooks, we read data from FAIR-MAST (here, the level-2 Zarr store on S3) and process and format it into training and test sets.

The Plasma Equilibrium challenge differs in its target: a 2D map of poloidal flux (`psi`) at each time slice, rather than a scalar plasma volume. The dataset preparation follows these key steps:

1. Accessing diagnostic data from the MAST database
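2. Combining inputs from multiple diagnostics (magnetics, visible spectrometer, soft X-rays, and Thomson scattering)
3. Interpolating all signals onto a common time base: the time slices of the equilibrium reconstruction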
4. Using the `to_dataset` function to combine data from multiple plasma shots
5. Shuffling shot IDs with a fixed random seed (7) for reproducible results
6. Splitting data into training (5 shots) and testing (2 shots) sets
7. Keeping both signals and targets in the training set
8. Removing target values (`psi`) from the test set and saving them separately
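Steps 5 and 6 can be sketched on their own, without any network access. `split_shots` below is a hypothetical helper that mirrors the notebook's shuffle-and-split cell; it shows that seeding `np.random.default_rng` with 7 reproduces the same split on every run, and that the two test shots keep positional indices 5 and 6 in the shuffled `pd.Series`.

```python
import numpy as np
import pandas as pd

# Shot IDs as listed in the notebook's split cell.
SHOT_IDS = [15585, 15212, 15010, 14998, 30410, 30418, 30420]


def split_shots(shot_ids, seed=7, n_train=5):
    """Hypothetical helper: seeded shuffle, then a positional train/test split."""
    ids = np.array(shot_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(ids)  # in-place; the permutation depends only on the seed
    ids = pd.Series(ids)  # default RangeIndex 0..N-1 later becomes shot_index
    return {"train": ids.iloc[:n_train], "test": ids.iloc[n_train:]}


first = split_shots(SHOT_IDS)
second = split_shots(SHOT_IDS)
print(first["test"].equals(second["test"]))  # True: same seed, same split
print(first["test"].index.tolist())  # [5, 6]
```

Keeping the positional index (rather than resetting it) is what lets later cells label the test rows by shot.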

This structured approach lets participants focus on developing predictive models rather than on data wrangling. Ground-truth `psi` values are stored in a separate solution file that is used only for final validation, mirroring real-world settings where models must predict outcomes for unseen data. Each solution row is labelled `Public` or `Private` according to which of the two test shots it belongs to.
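To illustrate how a held-out solution file with a `Usage` column might be used, here is a sketch that scores a submission separately on the Public and Private rows. The root-mean-square error metric, the `score_by_usage` name, the toy data, and the assumption that a submission shares the solution's index and flattened `psi` columns are all illustrative; the challenge's actual scoring is not specified in this notebook.

```python
import numpy as np
import pandas as pd


def score_by_usage(solution: pd.DataFrame, submission: pd.DataFrame) -> dict:
    """Hypothetical metric: RMSE computed separately for each Usage label."""
    # Keep only the flattened psi columns and align the submission to them.
    target = solution.drop(columns="Usage")
    prediction = submission.loc[target.index, target.columns]
    squared_error = (prediction - target) ** 2
    scores = {}
    for usage, rows in solution.groupby("Usage").groups.items():
        scores[usage] = float(np.sqrt(squared_error.loc[rows].to_numpy().mean()))
    return scores


# Toy data: 4 time slices, 3 flattened grid points, 2 rows per Usage label.
solution = pd.DataFrame(np.zeros((4, 3)))
solution.index.name = "index"
solution["Usage"] = ["Public", "Public", "Private", "Private"]

submission = pd.DataFrame(np.zeros((4, 3)))
submission.index.name = "index"
submission.iloc[2:] = 2.0  # wrong by 2.0 on the Private rows only

print(score_by_usage(solution, submission))  # Public RMSE 0.0, Private RMSE 2.0
```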

In [None]:
# Uncomment and run to install the package in a Colab notebook
# !uv pip install git+https://github.com/Simon-McIntosh/data-science-challenges.git

In [5]:
from importlib import resources
import pathlib

import numpy as np
import pandas as pd
import xarray as xr

In [6]:
def to_dask(shot: int, group: str, level: int = 2) -> xr.Dataset:
    """Lazily open one diagnostic group of a MAST shot from the FAIR-MAST Zarr store."""
    return xr.open_zarr(
        f"https://s3.echo.stfc.ac.uk/mast/level{level}/shots/{shot}.zarr",
        group=group,
    )


In [7]:
def to_dataset(shots: pd.Series) -> xr.Dataset:
    """Return a concatenated xarray Dataset for a series of input shots.

    Every diagnostic signal is interpolated onto the time base (and, where
    present, the major-radius grid) of the equilibrium `psi` target.
    """
    dataset = []
    for shot_index, shot_id in shots.items():
        target = to_dask(shot_id, "equilibrium")["psi"]
        signal = []
        for group in ["magnetics", "spectrometer_visible", "soft_x_rays", "thomson_scattering"]:
            data = to_dask(shot_id, group).interp({"time": target.time})
            if "major_radius" in data:
                data = data.interp({"major_radius": target.major_radius})
            other_times = set()
            for var in data.data_vars:
                # Some variables use their own time axis (a dim whose name starts
                # with "time"); interpolate each one onto the target time base.
                time_dim = next(
                    (dim for dim in data[var].dims if dim.startswith("time")), "time"
                )
                if time_dim != "time":
                    other_times.add(time_dim)
                data[var] = data[var].interp({time_dim: target.time})
                data[var] = data[var].transpose("time", ...)
                data[var].attrs |= {"group": group}  # record the source diagnostic
            data = data.drop_vars(other_times)  # drop the now-unused time coordinates
            signal.append(data)
        signal = xr.merge(signal, combine_attrs="drop_conflicts")
        # Tag each time slice with the shot's position in `shots` (used for Usage labels).
        signal["shot_index"] = "time", np.full(target.sizes["time"], shot_index, dtype=int)
        dataset.append(xr.merge([signal, target], combine_attrs="override"))
    return xr.concat(dataset, "time", join="override", combine_attrs="drop_conflicts")
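The core of `to_dataset` is resampling each signal from its own time axis onto the equilibrium time base. The sketch below shows the same idea with plain NumPy linear interpolation on made-up data; it is not the notebook's xarray code path, and `resample_to_target` is a hypothetical helper. As with xarray's default `interp`, target times outside the signal's range become NaN.

```python
import numpy as np


def resample_to_target(signal_time, signal_values, target_time):
    """Hypothetical helper: linearly interpolate a 1-D signal onto target_time.

    Target times outside [signal_time[0], signal_time[-1]] are set to NaN.
    """
    return np.interp(target_time, signal_time, signal_values, left=np.nan, right=np.nan)


# Made-up data: a densely sampled diagnostic on its own time axis...
signal_time = np.linspace(0.0, 0.3, 301)
signal_values = 2.0 * signal_time  # a linear ramp, so expected values are easy to check
# ...and a coarser target time base, with one point beyond the signal's range.
target_time = np.array([0.05, 0.10, 0.25, 0.40])

resampled = resample_to_target(signal_time, signal_values, target_time)
print(resampled)  # approximately [0.1, 0.2, 0.5, nan]
```

Using the equilibrium times as the master clock means every signal row lines up with exactly one `psi` map, at the cost of NaNs wherever a diagnostic did not record data.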


In [8]:
source_ids = np.array([15585, 15212, 15010, 14998, 30410, 30418, 30420])

rng = np.random.default_rng(7)
rng.shuffle(source_ids)
source_ids = pd.Series(source_ids)

split_ids = {
    "train": source_ids[:5],
    "test": source_ids[5:],
}

dataset = {mode: to_dataset(shot_ids) for mode, shot_ids in split_ids.items()}

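# extract solution: one row per time slice, one column per flattened psi grid point
psi = dataset["test"].psi.transpose("time", ...).values  # .values loads dask data into NumPy
psi = psi.reshape((dataset["test"].sizes["time"], -1))
solution = pd.DataFrame(psi)
solution.index.name = "index"
# shot_index 5 and 6 are the two test shots (positions in the shuffled series)
shot_index = dataset["test"].shot_index.values
solution["Usage"] = [{5: "Public", 6: "Private"}.get(index) for index in shot_index]
# delete solution from test file
dataset["test"] = dataset["test"].drop_vars("psi")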

# write to file
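pkg_path = resources.files("data_science_challenges")
# resources.files returns a Traversable; convert it to a Path so directories can be created
data_path = pathlib.Path(pkg_path) / "fair_mast_data" / "plasma_equilibrium"
data_path.mkdir(parents=True, exist_ok=True)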
dataset["train"].to_netcdf(data_path / "train.nc")
dataset["test"].to_netcdf(data_path / "test.nc")
solution.to_csv(data_path / "solution.csv")
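Because each row of `solution.csv` is a flattened time slice of `psi`, turning rows back into 2D maps requires the grid shape. The sketch below uses made-up sizes; in practice the shape and dimension order come from `psi` in the training file. NumPy's default C-order `reshape` on both sides makes the round trip exact.

```python
import numpy as np
import pandas as pd

# Made-up sizes; the real grid shape and dimension order come from `psi`.
n_time, n_rows, n_cols = 3, 4, 5
psi = np.arange(n_time * n_rows * n_cols, dtype=float).reshape(n_time, n_rows, n_cols)

# Flatten each time slice into one row, as the solution cell does.
flat = pd.DataFrame(psi.reshape((n_time, -1)))
print(flat.shape)  # (3, 20)

# Invert: C-order reshape on both sides restores the original maps.
restored = flat.to_numpy().reshape((n_time, n_rows, n_cols))
print(np.array_equal(restored, psi))  # True
```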