In [1]:
import xarray as xr
import numpy as np

from pathlib import Path
import os


In [2]:
# Load the MAST magnetics dataset from the NetCDF file sitting next to this notebook.
# The bare `ds` on the last line makes the notebook display the dataset summary.

MAST_DATASET_FILE = "mast_magnetics_data_constant_time.nc"
ds = xr.open_dataset(MAST_DATASET_FILE)
ds

In [19]:
def build_and_save_npz_from_ds(
        ds: xr.Dataset, 
        var_name: str, 
        channel: int | str | None = None, 
        val_ratio: float = 0.2, 
        filename: str = 'dataset_from_ds.npz', 
        out_dir: str = 'data/uploaded', 
        normalize: bool = True, 
        coef: float = 2.0, 
        seed: int = 42
        ) -> str:
    """
    Utility: build and save a .npz dataset from an xarray Dataset
             This function extracts a 2D (shots x time) array from `ds[var_name]` and optionally
             selects one channel when the variable has a channel dimension.

    Parameters:
        - ds: xarray.Dataset
        - var_name: str, name of the DataArray in ds to extract (e.g. 'ip')
        - channel: optional int or str. If int -> select by index on the 3rd dim.
            If str -> select by coordinate label on the 3rd dim. If None -> use first channel or the 2D array as-is.
        - val_ratio: fraction of samples to mark as validation (default 0.2)
        - filename: output filename (if no path provided it will be created in current dir)
        - out_dir: directory under which to save the file if filename is a basename
        - normalize: whether to perform a simple min-max normalization (global)
        - coef: scale factor applied after normalization

    Returns: path to saved .npz file
    """
    if var_name not in ds:
        raise KeyError(f"Variable '{var_name}' not found in dataset")

    da = ds[var_name]

    # Only support 2D (shots x time/features) or 3D (shots x time x channel)
    if da.ndim == 3:
        # decide which dim is channel (assume last)
        channel_dim = da.dims[2]
        if channel is None:
            # default: take first channel index
            sel_da = da.isel({channel_dim: 0})
        else:
            # channel provided: either int index or label
            if isinstance(channel, int):
                sel_da = da.isel({channel_dim: int(channel)})
            elif isinstance(channel, str):
                try:
                    # try selection by label (works if coordinate values are strings)
                    sel_da = da.sel({channel_dim: channel})
                except Exception:
                    # fallback: try to find matching value in coordinate values and use isel
                    coord_vals = list(map(lambda x: str(x), list(da.coords[channel_dim].values)))
                    if channel in coord_vals:
                        idx = coord_vals.index(channel)
                        sel_da = da.isel({channel_dim: idx})
                    else:
                        raise ValueError(f"Channel label '{channel}' not found in coords for dim '{channel_dim}'")
            else:
                raise TypeError('channel must be int, str or None')
        data = np.array(sel_da.values)
        # if after channel selection we still have >2 dims, flatten the trailing dims
        if data.ndim > 2:
            data = data.reshape(data.shape[0], -1)

    elif da.ndim == 2:
        data = np.array(da.values)

    else:
        # Per your note: do not attempt to handle arbitrary higher dims here.
        raise ValueError('DataArray must be 2D or 3D (shots, time[, channel]); higher dims are not supported in this helper')

    # Now data should be 2D: (n_samples, n_time/features)
    if data.ndim != 2:
        raise ValueError(f'Extracted array has unexpected ndim={data.ndim}; expected 2')

    # Handle NaNs: replace with 0
    if np.isnan(data).any():
        data[np.isnan(data)] = 0

    # Ensure numeric dtype
    data = data.astype(np.float32)

    # Optional normalization (global min-max * coef)
    if normalize:
        amin = data.min()
        amax = data.max()
        if amax - amin != 0:
            data = (data - amin) / (amax - amin) * coef
        else:
            data = data - amin      # avoid division by zero

    n_samples = data.shape[0]

    # Create labels (default zeros)
    labels = np.zeros(n_samples, dtype=int)

    # Create is_train boolean mask (reproducible)
    rng = np.random.default_rng(seed)
    is_train = np.ones(n_samples, dtype=bool)
    n_val = max(1, int(n_samples * float(val_ratio)))
    val_idx = rng.choice(n_samples, size=n_val, replace=False)
    is_train[val_idx] = False

    # Prepare path and save
    if not os.path.isabs(filename):
        os.makedirs(out_dir, exist_ok=True)
        filepath = os.path.join(out_dir, filename)
    else:
        filepath = filename

    np.savez_compressed(filepath, data=data, labels=labels, is_train=is_train)

    print(f"Saved dataset to: {filepath}")
    print('data.shape =', data.shape)
    print('labels.shape =', labels.shape)
    print('n_train =', int(is_train.sum()), 'n_val =', int((~is_train).sum()))

    return filepath


# Example usage:
# fp = build_and_save_npz_from_ds(ds, 'ip', channel=None, filename='ip_dataset_from_notebook.npz')


In [20]:
# Save the exported archives next to the notebook: use the current working directory.
out_directory = Path.cwd()

In [21]:
# Export the 2D 'ip' variable; `channel` is irrelevant for 2D data and left as None.
ip_fp = build_and_save_npz_from_ds(
    ds,
    var_name='ip',
    channel=None,
    filename='ip_dataset_from_notebook.npz',
    out_dir=out_directory,
)

Saved dataset to: /home/ITER/brussel/Documents/ITER-autoencoder-for-labelling/notebooks/ip_dataset_from_notebook.npz
data.shape = (6102, 1000)
labels.shape = (6102,)
n_train = 4882 n_val = 1220


In [22]:
# Export a single probe channel, selected by its coordinate label, from the
# 'b_field_pol_probe_ccbv_field' variable.
ccbv1_fp = build_and_save_npz_from_ds(
    ds,
    var_name='b_field_pol_probe_ccbv_field',
    channel='AMB_CCBV01',
    filename='ccbv1_dataset_from_notebook.npz',
    out_dir=out_directory,
)

Saved dataset to: /home/ITER/brussel/Documents/ITER-autoencoder-for-labelling/notebooks/ccbv1_dataset_from_notebook.npz
data.shape = (6102, 1000)
labels.shape = (6102,)
n_train = 4882 n_val = 1220
