In [7]:
import xarray as xr
import numpy as np

from pathlib import Path


In [8]:
# Load the MAST magnetics dataset from the working directory and display its summary.
# xr.open_dataset reads variables lazily and keeps the underlying file open.
ds = xr.open_dataset(Path("mast_magnetics_data_constant_time.nc"))
ds

In [9]:
# Build a .npz dataset from the xarray `ip` variable
# Assumes `ds` is already defined in the notebook and contains a DataArray 'ip'
import os
import numpy as np

# Extract the `ip` DataArray (dims (shot_id, time)): below, axis 0 indexes
# shots and axis 1 indexes time samples.
ip_da = ds['ip']

# Convert to a numpy array of shape (n_shots, n_time).
# `.values` can hand back the very array that backs `ds`, so nothing below
# writes into `data` in place -- every step builds a new array. (An in-place
# NaN fill here would silently modify `ds['ip']` and make re-runs differ.)
data = ip_da.values

# Replace NaNs by the per-time (column) mean of the non-NaN entries.
nan_mask = np.isnan(data)
if nan_mask.any():
    # nansum / count instead of np.nanmean: an all-NaN column then gets 0.0
    # rather than NaN (plus a "Mean of empty slice" warning), which would
    # otherwise propagate through the row-wise scaling and NaN out whole shots.
    valid_counts = (~nan_mask).sum(axis=0)
    col_sum = np.nansum(data, axis=0)
    col_mean = np.divide(col_sum, valid_counts, out=np.zeros_like(col_sum), where=valid_counts > 0)
    # col_mean has shape (n_time,) and broadcasts across the shot axis.
    data = np.where(nan_mask, col_mean, data)

# Ensure a numeric float dtype suitable for the model
data = data.astype(np.float32)

# Per-shot (row-wise) min-max scaling to [0, 1]. A constant row (max == min)
# would divide by zero and turn into NaN, so such rows are mapped to 0 instead.
row_min = data.min(axis=1, keepdims=True)
row_range = data.max(axis=1, keepdims=True) - row_min
data = np.divide(data - row_min, row_range, out=np.zeros_like(data), where=row_range > 0)

# The saved dataset must be fully finite; fail loudly instead of writing NaN/inf.
assert np.isfinite(data).all(), "preprocessed `ip` data contains non-finite values"

n_samples = len(data)

# No anomaly annotations are available, so every shot gets label 0.
labels = np.zeros(n_samples, dtype=int)

# Reproducible train/validation split (seeded): hold out val_ratio of the
# shots, but never fewer than one.
val_ratio = 0.2
n_val = max(1, int(n_samples * val_ratio))
rng = np.random.default_rng(42)
val_idx = rng.choice(n_samples, size=n_val, replace=False)

# True = training shot, False = held-out validation shot.
is_train = np.full(n_samples, True)
is_train[val_idx] = False

# Write to the current working directory. The consuming app reads the keys
# `data`, `labels` and `is_train` from this archive.
filename = 'ip_dataset_from_notebook.npz'
filepath = filename
np.savez_compressed(filepath, data=data, labels=labels, is_train=is_train)

# Summary of what was written.
print(f"Saved dataset to: {filepath}")
print('data.shape =', data.shape)
print('labels.shape =', labels.shape)
print('n_train =', int(is_train.sum()), 'n_val =', int((~is_train).sum()))


Saved dataset to: ip_dataset_from_notebook.npz
data.shape = (6102, 1000)
labels.shape = (6102,)
n_train = 4882 n_val = 1220
