# In [None]:
# NOTE: This file explains every line using simple words.
# The functions below rely on these names existing at module level
# before they are called:
#   import numpy as np
#   import os
#   DATA_PATH = "/path/to/folder/with/csvs"
# Each CSV is expected to have, in column order:
#   - Many spectral columns first; each header is a wavenumber that can be parsed as a float
#   - Then 4 label columns (positions -5:-1)
#   - Then 1 final column holding the cross-validation fold index (position -1)

import os  # Build file paths in a way that works on every operating system

import numpy as np  # Load the numpy library and call it "np" for array math
import pandas as pd  # Load the pandas library and call it "pd" to read and handle CSV files

# Devices/datasets we know about; each one has a CSV file and an entry
# in both bound tables below (same order everywhere).
dataset_names = [
    'anton_532',
    'anton_785',
    'kaiser',
    'mettler_toledo',
    'metrohm',
    'tec5',
    'timegate',
    'tornado',
]

# Smallest wavenumber we accept for each device
lower_bounds = dict(
    anton_532=200,
    anton_785=100,
    kaiser=-37,
    mettler_toledo=300,
    metrohm=200,
    tec5=85,
    timegate=200,
    tornado=300,
)

# Largest wavenumber we accept for each device
upper_bounds = dict(
    anton_532=3500,
    anton_785=2300,
    kaiser=1942,
    mettler_toledo=3350,
    metrohm=3350,
    tec5=3210,
    timegate=2000,
    tornado=3300,
)

def get_csv_dataset(
    dataset_name,      # text name of the dataset (must match a CSV file name)
    lower_wn=-1000,    # requested minimum wavenumber (we will clamp to device-specific bounds)
    upper_wn=10000,    # requested maximum wavenumber (we will clamp to device-specific bounds)
    dtype=None,        # desired numeric dtype, default will be np.float64
):
    """Load one device's CSV and split it into spectra, labels and CV folds.

    The file "<dataset_name>.csv" is read from the module-level DATA_PATH
    folder. All columns except the last five are treated as spectral
    columns whose headers are wavenumbers; the next four columns are
    labels; the last column is the cross-validation fold id of each row.

    Args:
        dataset_name: Key into lower_bounds / upper_bounds and base name of
            the CSV file.
        lower_wn: Requested minimum wavenumber; raised to the device's
            lower bound if it is smaller.
        upper_wn: Requested maximum wavenumber; lowered to the device's
            upper bound if it is larger.
        dtype: Numeric dtype of the returned arrays; np.float64 when None.

    Returns:
        A 5-tuple (spectra, label, None, cv_folds, wavenumbers):
        spectra is (num_rows x num_selected_columns), label is
        (num_rows x 4), the third item is always None, cv_folds is a list
        of (train_indices, val_indices) pairs with one pair per distinct
        fold id in ascending order, and wavenumbers holds the wavenumber
        of every selected spectral column.

    Raises:
        KeyError: If dataset_name has no entry in the bound tables.
    """
    # Use an explicit None check. "dtype or np.float64" is unreliable because
    # NumPy dtype instances (e.g. np.dtype('float32')) can evaluate as falsy
    # and would then be silently replaced by the default.
    if dtype is None:
        dtype = np.float64

    # Never go outside the range the device supports
    lower_wn = max(lower_wn, lower_bounds[dataset_name])
    upper_wn = min(upper_wn, upper_bounds[dataset_name])

    # Read the CSV file for this dataset from the DATA_PATH folder
    df = pd.read_csv(os.path.join(DATA_PATH, '%s.csv' % dataset_name))

    # Spectral part of the table (everything except the last 5 columns)
    spectral_df = df.iloc[:, :-5]
    # Column headers are strings; turn them into float wavenumbers once
    all_wavenumbers = np.array([float(name) for name in spectral_df.columns])
    # True for every spectral column whose wavenumber is inside [lower_wn, upper_wn]
    in_window = np.logical_and(
        lower_wn <= all_wavenumbers,
        all_wavenumbers <= upper_wn,
    )

    # Keep only the selected spectral columns, as a NumPy array
    spectra = spectral_df.iloc[:, in_window].values
    # The 4 columns right before the last column are the labels
    label = df.iloc[:, -5:-1].values
    # The last column tells us which CV fold each row belongs to
    cv_indices = df.iloc[:, -1].values
    # Row numbers [0, 1, ..., N-1]; np.arange keeps an integer dtype even when N is 0
    all_indices = np.arange(len(cv_indices))

    # One (train, validation) pair per fold id that actually occurs in the
    # file. np.unique returns the distinct ids in sorted order, so ids that
    # do not start at 0 or that have gaps no longer produce empty or missing
    # folds.
    cv_folds = [
        (
            all_indices[cv_indices != fold_id],  # rows outside this fold
            all_indices[cv_indices == fold_id],  # rows inside this fold
        )
        for fold_id in np.unique(cv_indices)
    ]

    # Wavenumber axis that matches the columns of `spectra`
    wavenumbers = all_wavenumbers[in_window]

    return (
        spectra.astype(dtype),
        label.astype(dtype),
        None,  # placeholder slot; always None here
        cv_folds,
        wavenumbers.astype(dtype),
    )

def load_joint_dataset(
    dataset_names,          # list of dataset names to load and merge
    lower_wn=-1000,         # requested min wavenumber across all devices
    upper_wn=10000,         # requested max wavenumber across all devices
    dtype=None,             # desired numeric dtype, default np.float64
    leave_out_one_device=False,  # if True, each device becomes its own validation split
):
    """Load several devices, resample them onto one grid and stack them.

    Every dataset is loaded with get_csv_dataset using the wavenumber
    window shared by all requested devices, interpolated onto a common
    grid with step 1, scaled, and concatenated row-wise.

    Scaling note: each device's block is divided by the single largest
    value found anywhere in that block (one scalar per device). Spectra
    are NOT divided by their own individual maximum.

    Args:
        dataset_names: Names of the datasets to load; must not be empty.
        lower_wn: Requested minimum wavenumber; raised to the largest
            device lower bound if needed.
        upper_wn: Requested maximum wavenumber; lowered to the smallest
            device upper bound if needed.
        dtype: Numeric dtype of the returned spectra and labels;
            np.float64 when None.
        leave_out_one_device: If True, build one fold per device whose
            validation set is that whole device. Otherwise build one fold
            per (device, per-device CV fold) pair whose validation set is
            that device's fold.

    Returns:
        A 4-tuple (spectra, labels, cv_folds, boundaries): the stacked
        scaled spectra, the stacked labels restricted to their first 3
        columns, a list of (train_indices, val_indices) pairs in stacked
        row numbering, and an array of length num_datasets + 1 where
        boundaries[i]:boundaries[i + 1] is the row range of dataset i.

    Raises:
        ValueError: If dataset_names is empty or the shared wavenumber
            window is empty.
        KeyError: If a dataset name has no entry in the bound tables.
    """
    # Explicit None check: NumPy dtype instances can evaluate as falsy, so
    # "dtype or np.float64" could silently ignore a dtype the caller passed.
    if dtype is None:
        dtype = np.float64

    # We walk over the names several times, so make sure we hold a real list
    dataset_names = list(dataset_names)
    if not dataset_names:
        raise ValueError("dataset_names must contain at least one dataset")

    # Common window shared by all devices: largest lower bound ...
    lower_wn = max(lower_wn, *(lower_bounds[name] for name in dataset_names))
    # ... and smallest upper bound
    upper_wn = min(upper_wn, *(upper_bounds[name] for name in dataset_names))

    # Print the final chosen wavenumber limits
    print("Lower WN: ", lower_wn)
    print("Upper WN: ", upper_wn)

    # Without any shared range there is nothing to interpolate onto
    if lower_wn > upper_wn:
        raise ValueError(
            "empty wavenumber window: lower bound %s is above upper bound %s"
            % (lower_wn, upper_wn)
        )

    # Load each dataset using the same clamped window and dtype
    datasets = [
        get_csv_dataset(
            name,
            lower_wn=lower_wn,
            upper_wn=upper_wn,
            dtype=dtype,
        )
        for name in dataset_names
    ]

    # Common wavenumber grid from lower_wn to upper_wn inclusive, step 1
    joint_wns = np.arange(lower_wn, upper_wn + 1)
    print("Joint WNS: ", joint_wns)

    # Resample every spectrum onto the common grid.
    interpolated_data = []
    for spectra, _, _, _, wns in datasets:
        # np.interp needs increasing x positions and does not check this, so
        # put the columns in ascending wavenumber order first. The stable
        # sort leaves already-sorted data untouched.
        order = np.argsort(wns, kind='stable')
        sorted_wns = wns[order]
        interpolated_data.append(
            np.array([
                np.interp(joint_wns, xp=sorted_wns, fp=spectrum[order])
                for spectrum in spectra
            ])
        )

    # Divide each device block by that block's overall maximum (one scalar
    # per device), then stack all devices on top of each other.
    # np.interp always produces float64, so cast to the requested dtype.
    normed_spectra = np.concatenate(
        [block / np.max(block) for block in interpolated_data],
        axis=0,
    ).astype(dtype, copy=False)

    # Row boundaries of each dataset in the stacked array:
    # [0, N0, N0+N1, ..., total]
    boundaries = np.concatenate(
        [
            [0],
            np.cumsum([len(one[0]) for one in datasets]),
        ]
    )
    # Where each dataset starts (all boundaries except the final total)
    dataset_offsets = boundaries[:-1]
    # Total number of spectra across all datasets
    num_items = int(boundaries[-1])

    if leave_out_one_device:
        # Validation indices are entire device blocks (leave-one-device-out CV)
        val_indices = [
            np.arange(start, end, 1)
            for start, end in zip(boundaries[:-1], boundaries[1:])
        ]
    else:
        # Otherwise reuse the per-dataset CV folds, shifted by the dataset's
        # offset so they point at rows of the stacked array
        val_indices = [
            val_idxs + offset
            for one, offset in zip(datasets, dataset_offsets)
            for _, val_idxs in one[3]
        ]

    # Training rows are all rows that are not validation rows. np.setdiff1d
    # returns them sorted and with an integer dtype even when the result is
    # empty, so they can always be used for indexing.
    all_indices = np.arange(num_items)
    cv_folds = [
        (np.setdiff1d(all_indices, val_idxs), val_idxs)
        for val_idxs in val_indices
    ]

    return (
        # Stacked, interpolated spectra, scaled by each device's overall maximum
        normed_spectra,
        # Stacked labels; only the first 3 of the 4 label columns are kept
        np.concatenate([one[1] for one in datasets])[:, :3],
        # Cross-validation folds in stacked (global) row numbering
        cv_folds,
        # Row boundaries of each dataset (length = num_datasets + 1)
        boundaries,
    )
