In [25]:
import functools
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import Dropdown, IntRangeSlider, Checkbox, VBox, interactive_output
from IPython.display import Markdown, display

plt.rcParams['figure.figsize'] = (10, 6)

# Locate the data directory: walk from the working directory up through its
# ancestors and use the first one that has a `data/` sub-directory. If none
# does, fall back to a cwd-relative `data` path.
candidates = [pathlib.Path.cwd(), *pathlib.Path.cwd().parents]
DATA_DIR = next(
    (root / 'data' for root in candidates if (root / 'data').is_dir()),
    pathlib.Path('data'),
)

# Shot number -> training-pack filename inside DATA_DIR.
SHOT_FILES = {
    '24209': '24209_torax_training.npz',
    '24210': '24210_torax_training.npz',
    '24211': '24211_torax_training.npz',
}

# Names of the control-signal arrays read from each pack.
CONTROL_NAMES = [
    'P_nbi', 'Ip', 'nebar',
    'S_gas', 'S_rec', 'S_nbi',
]

In [26]:
@functools.lru_cache(None)
def load_shot(shot_key):
    """Load the training pack for ``shot_key`` and return its arrays.

    The file read is ``DATA_DIR / SHOT_FILES[shot_key]``. Results are memoised
    per ``shot_key`` by ``lru_cache``, so repeated calls return the *same* dict
    object; treat it (and its arrays) as read-only.

    Returns
    -------
    dict
        ``'t_raw'`` (the pack's ``t`` array), ``'t_ts'``, ``'rho'``, ``'Te'``,
        ``'ne'`` and ``'controls'``. ``'controls'`` is a dict that maps each
        name in ``CONTROL_NAMES`` to its array.

    Raises
    ------
    KeyError
        If ``shot_key`` is not in ``SHOT_FILES`` or a required array is absent
        from the pack.
    """
    path = DATA_DIR / SHOT_FILES[shot_key]
    # Loading an .npz yields an NpzFile that keeps the file open; the ``with``
    # block closes it once every array below has been read into memory.
    with np.load(path) as data:
        return {
            't_raw': data['t'],
            't_ts': data['t_ts'],
            'rho': data['rho'],
            'Te': data['Te'],
            'ne': data['ne'],
            'controls': {name: data[name] for name in CONTROL_NAMES},
        }

In [27]:
def verify_training_packs():
    """Display a Markdown summary of every training pack listed in ``SHOT_FILES``.

    For each shot under ``DATA_DIR`` this renders:

    - the number of summary time samples (``t``);
    - the length of the Te/ne sampling grid (``t_ts``);
    - the shape and min/mean/max of the ``Te`` and ``ne`` grids;
    - the range and nonzero count of ``S_rec``;
    - which of ``CONTROL_NAMES`` are present.

    A pack file that does not exist is reported and skipped. An absent
    ``S_rec`` array and absent controls are reported explicitly rather than
    being masked.

    Raises
    ------
    KeyError
        If an existing pack lacks ``Te``, ``ne``, ``t`` or ``t_ts``.
    """
    display(Markdown('## Training pack verification'))
    for shot_key, filename in SHOT_FILES.items():
        path = DATA_DIR / filename
        display(Markdown(f'### Shot {shot_key}'))
        if not path.exists():
            display(Markdown(f'*Missing pack:* {path}'))
            continue
        # The context manager closes the .npz file handle once the summary is built.
        with np.load(path) as pack:
            te = pack['Te']
            ne = pack['ne']
            t_len = pack['t'].size
            ts_len = pack['t_ts'].size
            if 'S_rec' in pack:
                s_rec = pack['S_rec']
                s_rec_line = (
                    f'- S_rec range: {np.nanmin(s_rec):.1e} to {np.nanmax(s_rec):.1e} '
                    f'(nonzero {np.count_nonzero(s_rec)})\n'
                )
            else:
                # Substituting zeros would make a missing array look like a
                # recorded all-zero signal, so say it is missing instead.
                s_rec_line = '- S_rec range: *missing from pack*\n'
            present = [name for name in CONTROL_NAMES if name in pack]
            absent = [name for name in CONTROL_NAMES if name not in pack]
        control_summary = ', '.join(present) if present else 'none'
        controls_line = f'- controls included: {control_summary}'
        if absent:
            controls_line += f"\n- controls missing: {', '.join(absent)}"
        display(Markdown(
            f'- summary time samples: {t_len}\n'
            f'- Te sampling grid length: {ts_len}\n'
            f'- Te grid: {te.shape}, min {np.nanmin(te):.1f}, mean {np.nanmean(te):.1f}, max {np.nanmax(te):.1f}\n'
            f'- ne grid: {ne.shape}, min {np.nanmin(ne):.1e}, mean {np.nanmean(ne):.1e}, max {np.nanmax(ne):.1e}\n'
            + s_rec_line
            + controls_line
        ))

# Run the check now so the summary below reflects the packs currently on disk.
verify_training_packs()

## Training pack verification

### Shot 24209

- summary time samples: 2200
- Te sampling grid length: 110
- Te grid: (110, 65), min 4.2, mean 89.7, max 284.7
- ne grid: (110, 65), min 3.1e+17, mean 1.0e+19, max 2.5e+19
- S_rec range: 0.0e+00 to 5.4e+00 (nonzero 1752)
- controls included: P_nbi, Ip, nebar, S_gas, S_rec, S_nbi

### Shot 24210

- summary time samples: 2353
- Te sampling grid length: 118
- Te grid: (118, 65), min 4.5, mean 63.9, max 138.2
- ne grid: (118, 65), min 2.9e+17, mean 1.3e+19, max 5.1e+19
- S_rec range: 0.0e+00 to 4.8e+00 (nonzero 1896)
- controls included: P_nbi, Ip, nebar, S_gas, S_rec, S_nbi

### Shot 24211

- summary time samples: 2062
- Te sampling grid length: 104
- Te grid: (104, 65), min 5.6, mean 84.9, max 172.0
- ne grid: (104, 65), min 4.4e+17, mean 1.4e+19, max 5.3e+19
- S_rec range: 0.0e+00 to 5.2e+00 (nonzero 1607)
- controls included: P_nbi, Ip, nebar, S_gas, S_rec, S_nbi

In [28]:
def plot_shot_series(shot_key, time_slice, show_controls):
    """Plot Te and ne at three radial nodes, and optionally the controls, for one shot.

    Displays the shot's duration and the mean/std sampling interval of
    ``t_ts``. It then draws one figure each for Te and ne at the first
    (``core``), middle (``mid``) and last (``edge``) entries of ``rho``. If
    requested, it also draws a figure of every control signal linearly
    interpolated from ``t_raw`` onto ``t_ts``.

    Parameters
    ----------
    shot_key : str
        Key into ``SHOT_FILES``, passed to ``load_shot``.
    time_slice : tuple of int
        ``(start, end)`` indices into ``t_ts``. The window between them is
        shaded on every panel; the plotted data itself is not cropped. Bounds
        are ordered and clamped to the shot's valid index range, because the
        slider range may not match the selected shot's length.
    show_controls : bool
        Whether to draw the control-signal figure.
    """
    shot = load_shot(shot_key)
    t_ts = shot['t_ts']
    last_index = len(t_ts) - 1
    # Order the bounds first, then clamp BOTH into [0, last_index]. Clamping
    # only one end before swapping can leave an index past the end of t_ts.
    lo, hi = sorted(time_slice)
    start = min(max(lo, 0), last_index)
    end = min(max(hi, 0), last_index)

    dt = np.diff(t_ts)
    duration = t_ts[-1] - t_ts[0]
    info = (
        f"**Shot:** {shot_key}  \n"
        f"**Duration:** {duration:.3f} s  \n"
        f"**Mean sampling:** {dt.mean():.4f} s (std {dt.std():.4f})"
    )
    display(Markdown(info))
    node_indices = {
        'core': 0,
        'mid': shot['rho'].size // 2,
        'edge': shot['rho'].size - 1,
    }

    def draw_series_grid(series_map, title, ylabel):
        """Draw one stacked panel per ``series_map`` entry against ``t_ts``."""
        if not series_map:
            return
        rows = len(series_map)
        # squeeze=False always returns a 2-D (rows, 1) array of axes, so a
        # single-row figure needs no special case.
        fig, axs = plt.subplots(rows, 1, sharex=True, squeeze=False, figsize=(10, 2.4 * rows))
        axs = axs[:, 0]
        for ax, (label, values) in zip(axs, series_map.items()):
            ax.plot(t_ts, values, color='C0')
            ax.set_ylabel(f"{ylabel} ({label})")
            ax.axvspan(t_ts[start], t_ts[end], color='0.94', zorder=-1)
            ax.grid(True, linewidth=0.5, alpha=0.3)
        axs[-1].set_xlabel('Time (s)')
        fig.suptitle(title, fontsize=12)
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)

    profile_figures = (
        ('Te', 'Electron temperature across nodes', 'Electron temp (eV)'),
        ('ne', 'Electron density across nodes', 'Electron density (m^-3)'),
    )
    for field, title, ylabel in profile_figures:
        node_series = {alias: shot[field][:, idx] for alias, idx in node_indices.items()}
        draw_series_grid(node_series, title, ylabel)

    if show_controls:
        controls_interp = {
            name: np.interp(t_ts, shot['t_raw'], shot['controls'][name])
            for name in CONTROL_NAMES
        }
        draw_series_grid(controls_interp, 'Control signals (interpolated)', 'Control signal (arb)')

shot_dropdown = Dropdown(options=list(SHOT_FILES.keys()), description='Shot')
# Size the slider for the initially selected shot.
ts_length = len(load_shot(shot_dropdown.value)['t_ts'])
time_slider = IntRangeSlider(value=(0, ts_length - 1), min=0, max=ts_length - 1, description='Time slice')
controls_checkbox = Checkbox(value=True, description='Show controls')


def _resize_time_slider(change):
    """Match the time slider's range to the newly selected shot's ``t_ts`` length.

    Shots have different numbers of Te/ne samples. A range sized for one shot
    would hide the tail of a longer shot or point past the end of a shorter
    one. The slider is reset to the full range of the new shot.
    """
    n_samples = len(load_shot(change['new'])['t_ts'])
    # Update max before value so the new full-range value is within bounds.
    time_slider.max = n_samples - 1
    time_slider.value = (0, n_samples - 1)


# Registered before interactive_output, intending the resize to happen before
# the plot callback reacts to the same dropdown change.
shot_dropdown.observe(_resize_time_slider, names='value')

ui = VBox([shot_dropdown, time_slider, controls_checkbox])
output = interactive_output(plot_shot_series, {
    'shot_key': shot_dropdown,
    'time_slice': time_slider,
    'show_controls': controls_checkbox,
})
display(Markdown('### Shot controls & timing'))
display(ui, output)

### Shot controls & timing

VBox(children=(Dropdown(description='Shot', options=('24209', '24210', '24211'), value='24209'), IntRangeSlide…

Output()