# Choice of MAST shots to load for scraping

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import pathlib
import tqdm
from functools import partial
import os
import psutil
import time
import json

## Selection of shot IDs and variable channels without NaN

In [None]:
def to_dask(shot: int, group: str, level: int = 2) -> xr.Dataset:
    """
    Open one group of a MAST shot's Zarr store with ``xr.open_zarr``.

    Parameters
    shot: Shot ID whose Zarr store is opened.
    group: Name of the group inside the Zarr store (e.g. "magnetics").
    level: Data level of the archive to read from (default is 2).

    Returns
    xr.Dataset read from the STFC S3 endpoint for that shot and group.
    """
    store_url = f"https://s3.echo.stfc.ac.uk/mast/level{level}/shots/{shot}.zarr"
    dataset = xr.open_zarr(store_url, group=group)
    return dataset

In [None]:
def retry_to_dask(shot_id, group, retries=3, delay=1, backoff=1.0):
    """
    Open a shot's group with ``to_dask``, retrying on failure.

    Parameters
    shot_id: Shot ID to retrieve data for.
    group: Diagnostic group to retrieve data from.
    retries: Total number of attempts, must be >= 1 (default is 3).
    delay: Seconds to wait before the first retry (default is 1).
    backoff: Factor applied to the wait after every failed attempt
        (default is 1.0, i.e. a constant delay; pass e.g. 2.0 for
        exponential backoff).

    Returns
    xr.Dataset
        The Dataset returned by ``to_dask`` for the specified shot and group.

    Raises
    ValueError
        If ``retries`` is smaller than 1.
    Exception
        Whatever ``to_dask`` raised on the last attempt is re-raised.
    """
    if retries < 1:
        # Without this guard the loop body never runs and None is returned silently.
        raise ValueError(f"retries must be >= 1, got {retries}")

    wait = delay
    for attempt in range(retries):
        try:
            return to_dask(shot_id, group)
        except Exception:
            if attempt == retries - 1:
                # Out of attempts: propagate the original error and traceback.
                raise
            print(f"Retrying connection to {shot_id} in group {group} (attempt {attempt + 1}/{retries})")
            time.sleep(wait)
            wait *= backoff

In [None]:
def process_shot(shot_id, group: str, verbose: bool = False):
    """
    Check, for one shot, which variables/channels of a group contain data.

    Only variables having a dimension whose name contains "time"
    (case-insensitive) are inspected; the first such dimension is the time axis.
    Variables without a coordinate for a non-time dimension are skipped for
    that dimension.

    Parameters
    shot_id: Shot ID to retrieve data for.
    group: Diagnostic group to retrieve data from.
    verbose: bool
        Displays messages for debugging or tracking.
    
    Returns
    shot_id: int
        The shot ID processed (echoed back so results can be matched when
        shots are processed out of order).
    shot_result: dict
        Maps a key to a presence flag. The key is the variable name for
        variables with only a time dimension, and "<var>::<coord value>" for
        each coordinate value of the other dimensions. Values:
          - True:  at least one non-null sample (floating dtypes), or any
                   non-floating variable (assumed present);
          - None:  floating variable/channel whose samples are all null;
          - False: an error occurred while evaluating that variable/channel.
        If the shot cannot be opened or processed, an empty dict is returned.
    """
    shot_result = {}
    try:
        ds = retry_to_dask(shot_id, group)
        # NOTE: a set of names iterates in an order that can differ between
        # interpreter runs (string hash randomization); this only affects the
        # order in which keys are inserted into shot_result.
        shot_vars = set(ds.data_vars)

        for var in shot_vars:
            da = ds[var]
            time_dims = [dim for dim in da.dims if 'time' in dim.lower()]
            
            if not time_dims:
                if verbose:
                    print(f"Skipping {var} with no time dimension: {da.dims}")
                continue
            
            time_dim = time_dims[0]
            other_dims = [dim for dim in da.dims if dim != time_dim]

            # 1. Variable 1D (time is the only dimension): one key per variable.
            if not other_dims:
                key = var
                try:
                    if np.issubdtype(da.dtype, np.floating):
                        # .compute() forces evaluation of the reduction.
                        has_valid = da.notnull().any().compute()
                        shot_result[key] = True if has_valid else None
                    else:
                        # For non-floating types, we assume presence if the variable exists.
                        shot_result[key] = True
                except Exception as e:
                    if verbose:
                        print(f"Error processing {var} in shot {shot_id}: {e}")
                    shot_result[key] = False

            # 2. Variable 2D: one key per coordinate value of each non-time dimension.
            else:
                for dim in other_dims:
                    if dim not in ds.coords:
                        if verbose:
                            print(f"Warning: {dim} not found for {var} in shot {shot_id}")
                        continue
                    
                    try:
                        if np.issubdtype(da.dtype, np.floating):
                            # Reduce over time only: one flag per remaining element.
                            channel_has_data = da.notnull().any(dim=time_dim)
                        else:
                            # For non-floating types, we assume presence if the variable exists.
                            channel_has_data = xr.ones_like(ds[dim], dtype=bool)
                        
                        channel_results = channel_has_data.compute()
                        
                        for coord_val in ds[dim].values:
                            key = f"{var}::{coord_val}"
                            try:
                                # NOTE(review): with more than one non-time dimension the
                                # selection presumably keeps several elements, so .item()
                                # raises and the channel is recorded as False — confirm.
                                chan_valid = channel_results.sel({dim: coord_val}).item()
                                shot_result[key] = True if chan_valid else None
                            except Exception as e:
                                if verbose:
                                    print(f"Error accessing {coord_val} in {var}: {e}")
                                shot_result[key] = False
                    
                    except Exception as e:
                        if verbose:
                            print(f"Error processing {var} in shot {shot_id}: {e}")
                        # Whole dimension failed: mark every channel of it as False.
                        for coord_val in ds[dim].values:
                            key = f"{var}::{coord_val}"
                            shot_result[key] = False

    except Exception as e:
        if verbose:
            print(f"Error processing shot {shot_id}: {e}")
        # Discard partial results; the caller fills missing keys with False.
        shot_result = {}

    return shot_id, shot_result



In [None]:
def get_optimal_workers(task_type="cpu"):
    """
    Return a worker count for parallel processing.

    Parameters
    task_type: str
        "cpu" (default): number of physical cores reported by psutil, falling
        back to half of the logical cores when psutil cannot tell.
        Any other value: number of logical cores minus one.

    Returns
    int
        Number of workers, always >= 1.
    """
    # os.cpu_count() returns None when the count cannot be determined.
    cpu_logical = os.cpu_count() or 1

    if task_type != "cpu":
        return max(1, cpu_logical - 1)  # Default to one less than logical cores

    try:
        # psutil returns None (rather than raising) when the physical count is unknown.
        cpu_physical = psutil.cpu_count(logical=False)
    except Exception:
        cpu_physical = None
    if cpu_physical:
        return cpu_physical
    return max(1, cpu_logical // 2)  # Default to half of logical cores
    

In [None]:
def check_variable_presence_parallel(
    shots: list[int],
    group: str,
    verbose: bool = False,
    max_workers: int = None
) -> pd.DataFrame:
    """
    Check presence of all variables (and their channels) across shots in a group,
    processing shots concurrently with a ThreadPoolExecutor.

    Parameters
    shots: list of int
        List of shot IDs to be processed.
    group: str
        Diagnostic group to which the variable belongs (e.g., "magnetics", "summary", etc.).
    verbose: bool
        Displays messages for debugging or tracking.
    max_workers: int, optional
        Maximum number of workers to use for parallel processing. If None, uses
        the value returned by get_optimal_workers().

    Returns
    pd.DataFrame
        One row per key reported by ``process_shot`` (variable name or
        "variable::channel") and one column per shot ID, columns sorted.
        Cells hold the flag reported by ``process_shot``; keys a shot did not
        report are filled with False.
    """
    # concurrent.futures is not imported at the top of the notebook.
    import concurrent.futures

    if max_workers is None:
        max_workers = get_optimal_workers()
    print(f"Use of {max_workers} workers (type: cpu)")

    var_presence = {}
    seen_shots = []
    # Bind the per-call options once rather than building a partial per shot.
    worker = partial(process_shot, group=group, verbose=verbose)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, shot_id) for shot_id in shots]

        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Checking variables (parallel)"
        ):
            shot_id, shot_result = future.result()
            seen_shots.append(shot_id)

            for key, present in shot_result.items():
                if key not in var_presence:
                    # First time this key appears: shots already seen did not report it.
                    var_presence[key] = {s: False for s in seen_shots}
                var_presence[key][shot_id] = present

            # Keys this shot did not report (skipped in process_shot, or the whole
            # shot failed) are marked False.
            for per_shot in var_presence.values():
                per_shot.setdefault(shot_id, False)

    df = pd.DataFrame(var_presence).T
    df = df.reindex(sorted(df.columns), axis=1)

    return df

In [None]:
# Download the level-2 shot catalogue and list its shot IDs ordered by the
# "timestamp" column.
URL = 'https://mastapp.site'
catalogue_url = f'{URL}/parquet/level2/shots'
shots_disappearance = pd.read_parquet(catalogue_url)
sorted_disappearance = shots_disappearance.sort_values(by="timestamp")
shots = sorted_disappearance.loc[:, 'shot_id'].tolist()
n_shots_total = len(shots)
print("Number of shots: ", n_shots_total)

In [None]:
# Diagnostic group inspected by the rest of the notebook.
group = "magnetics"

In [None]:
# Scan every shot of the catalogue for the chosen group; one remote Zarr
# group is opened per shot, so this cell can be slow.
variable_presence_all = check_variable_presence_parallel(
    shots,
    group,
    verbose=True,
    max_workers=None,
)

In [None]:
# Save the presence table three directories above the working directory
# (presumably the repository root) under notebooks/result_files/all_shots_<group>/.
file_path = f"notebooks/result_files/all_shots_{group}"
path = pathlib.Path().absolute().parent.parent.parent / file_path / f"variable_presence_all_shots_{group}.csv"
# to_csv does not create missing directories.
path.parent.mkdir(parents=True, exist_ok=True)
variable_presence_all.to_csv(path, index=True)

#### Plot results

In [None]:
# Visualization of variable presence
from matplotlib.colors import ListedColormap


def plot_variable_presence(
    variable_presence_all: pd.DataFrame,
    plot: bool = True,
    register: bool = False,
    register_path: str = "foo.png",
    group: str = None,
    register_name: str = None,
) -> None:
    """
    Plot the presence of variables across all shots in a group.

    Cells that are exactly True are drawn green; every other value
    (False, None, NaN) is drawn red.

    Parameters
    variable_presence_all: pd.DataFrame
        DataFrame with one row per variable::channel and one column per shot.
    plot: bool
        If True, display the plot.
    register: bool
        If True, save the plot to a file.
    register_path: str or pathlib.Path
        Path of the file to save the plot to (if register is True).
    group: str, optional
        Diagnostic group name (e.g., "magnetics"); prefixed to the title if given.
    register_name: str, optional
        Alias for ``register_path``; takes precedence when given.

    Returns
    None
    """
    if register_name is not None:
        register_path = register_name

    n_vars, n_shots = variable_presence_all.shape
    if n_vars == 0 or n_shots == 0:
        print("Nothing to plot: the variable presence table is empty.")
        return None

    # 1 = present (cell is True), 0 = missing (False / None / NaN).
    presence = variable_presence_all.eq(True).to_numpy().astype(int)

    # Index 0 -> red (missing), index 1 -> green (present). Fixing vmin/vmax
    # keeps this mapping even when every cell has the same value.
    cmap = ListedColormap(["#f44336", "#4caf50"])

    # Clamp the width so small tables still get a usable figure.
    fig, ax = plt.subplots(figsize=(max(8, n_shots * 0.01), max(16, n_vars * 0.25)))
    ax.imshow(presence, aspect='auto', cmap=cmap, vmin=0, vmax=1, interpolation='none')

    ax.set_yticks(np.arange(n_vars))
    ax.set_yticklabels(variable_presence_all.index)
    # At most 1000 evenly spaced shot labels on the x axis.
    xtick_pos = np.linspace(0, n_shots - 1, min(n_shots, 1000), dtype=int)
    ax.set_xticks(xtick_pos)
    ax.set_xticklabels([variable_presence_all.columns[i] for i in xtick_pos], rotation=90)
    ax.set_xlabel("Shot ID")
    ax.set_ylabel("Variable::Channel")
    title = "Variable Presence Across Shots (green = present, red = missing)"
    if group is not None:
        title = f"{group}: {title}"
    ax.set_title(title)
    ax.grid(False, axis='x')
    ax.set_yticks(np.arange(n_vars + 1) - 0.5, minor=True)
    ax.grid(True, axis='y', which='minor', color='gray', linestyle='--')

    if register:
        fig.savefig(register_path, dpi="figure")
        print(f"Plot saved to {register_path}")
    if plot:
        plt.show()
    # Figure objects have no .close(); release the figure through pyplot.
    plt.close(fig)
    return None

In [None]:
register_path = pathlib.Path().absolute().parent.parent.parent / "results/figures" / f"variable_presence_{group}.png"
# savefig does not create missing directories.
register_path.parent.mkdir(parents=True, exist_ok=True)
plot_variable_presence(variable_presence_all, plot=True, register=True, register_path=register_path)

## Selection of shots and variable channels in good health

In [None]:
def check_variable_presence_all(
    path: str,
    group: str = "magnetics",
    shot_threshold: float = 0.99,
    var_threshold: float = 0.8,
    plot: bool = True,
) -> list:
    """
    Select the good shots and variables::channels to avoid the presence of too many NaNs in the dataset.

    Filtering runs in two passes:
    1. drop every variable::channel that is present in fewer than
       ``var_threshold`` x (number of shots) shots;
    2. on the remaining rows, drop every shot in which fewer than
       ``shot_threshold`` x (number of remaining variables::channels) entries are present.
    Only cells that are exactly True count as present.

    Parameters
    path: str
        Path to the CSV file containing variable presence data, relative to the
        directory three levels above the current working directory (an absolute
        path is used as-is).
    group: str
        Diagnostic group name, used in the filtered figure's file name.
        Defaults to "magnetics" (previously the notebook-global ``group``
        captured when the function was defined).
    shot_threshold: float
        Minimum fraction of kept variables::channels a shot must have present.
    var_threshold: float
        Minimum fraction of shots in which a variable::channel must be present.
    plot: bool
        If True (default), show the table after each pass and save the final
        figure under results/figures.

    Returns
    list(
        good_shots: list of shot IDs that are present in enough variables.
        bad_shots: list of shot IDs that are missing too many variables.
        good_vars: list of variable::channel names that are present in enough shots.
        bad_vars: list of variable::channel names that are missing too many shots.)
    Shot IDs come from the CSV header, so they are strings.
    """
    root = pathlib.Path().absolute().parent.parent.parent
    # True stays True; False, None and empty CSV cells (NaN) all become False.
    df = pd.read_csv(root / path, index_col=0).eq(True)

    # Pass 1: drop the variables::channels present in too few shots.
    bad_vars = df.index[df.sum(axis=1) < (len(df.columns) * var_threshold)].tolist()
    df = df.drop(index=bad_vars)
    if plot:
        plot_variable_presence(df, plot=True, register=False)

    # Pass 2: drop the shots missing too many of the remaining variables::channels.
    bad_shots = df.columns[df.sum(axis=0) < (len(df.index) * shot_threshold)].tolist()
    df = df.drop(columns=bad_shots)
    if plot:
        figure_path = root / "results/figures" / f"variable_presence_{group}_filtered.png"
        figure_path.parent.mkdir(parents=True, exist_ok=True)
        plot_variable_presence(df, plot=True, register=True, register_path=figure_path)

    return [
        df.columns.tolist(),    # Good shots
        bad_shots,              # Bad shots
        df.index.tolist(),      # Good variables::channels
        bad_vars                # Bad variables::channels
        ]

In [None]:
# Filter the saved presence table and preview (up to 25 entries of) what was rejected.
selection_csv = f"notebooks/result_files/all_shots_{group}/variable_presence_all_shots_{group}.csv"
good_shot_ids, bad_shot_ids, good_vars_ids, bad_vars_ids = check_variable_presence_all(
    path=selection_csv,
    group=group,
    var_threshold=0.8,
    shot_threshold=0.99,
)

print(f"Number of bad variable::channel IDs: {len(bad_vars_ids)}")
print("Bad variable::channel IDs:", bad_vars_ids[:25])

print(f"Number of bad shot IDs: {len(bad_shot_ids)}")
print("Bad shot IDs:", bad_shot_ids[:25])


In [None]:
# Persist the selection so later notebooks can reuse it without recomputing.
data = {
    "good_shot_ids": good_shot_ids,
    "bad_shot_ids": bad_shot_ids,
    # Previously written under the misspelled key "goobad_vars_ids"; readers of
    # older files should fall back to that key.
    "good_vars_ids": good_vars_ids,
    "bad_vars_ids": bad_vars_ids,
}

path = pathlib.Path().absolute().parent.parent.parent / f"notebooks/result_files/all_shots_{group}"
# open() does not create missing directories.
path.mkdir(parents=True, exist_ok=True)
file_name = path / f"result_lists{group}.json"
with open(file_name, "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)