# Custom dataset construction
This notebook narrows the training ranges and then clips the validation ranges so that each one remains nested inside its corresponding training range.

In [None]:
from copy import deepcopy
from physae.config_loader import load_data_config
from physae import config as physae_config

In [None]:
# Start from the data configuration registered under the name "default".
custom_cfg = load_data_config(name="default")

# Narrowed [low, high] sampling interval for every parameter.
narrow_train_ranges = {
    "sig0": [3085.435, 3085.447],
    "dsig": [0.0015235, 0.0015335],
    "mf_CH4": [5.0e-06, 1.4e-05],
    "baseline0": [0.995, 1.005],
    "baseline1": [-0.00038, -0.00031],
    "baseline2": [-3.8e-08, -3.2e-08],
    "P": [450.0, 550.0],
    "T": [305.0, 309.0],
}

# Install the narrowed intervals as two-element lists of floats.
custom_cfg["train_ranges"] = {
    param: [float(bounds[0]), float(bounds[1])]
    for param, bounds in narrow_train_ranges.items()
}

# Work from a private copy of the configured validation ranges so the values
# read below cannot alias anything still stored in the configuration.
original_val_ranges = deepcopy(custom_cfg.get("val_ranges", {}))

# Clip each validation interval to its training interval.  A parameter that
# has no validation entry falls back to the training interval itself.
adjusted_val_ranges = {}
for name, train_interval in custom_cfg["train_ranges"].items():
    train_lo, train_hi = (float(bound) for bound in train_interval)
    val_lo, val_hi = (
        float(bound) for bound in original_val_ranges.get(name, train_interval)
    )
    clipped_lo = max(val_lo, train_lo)
    clipped_hi = min(val_hi, train_hi)
    if clipped_lo > clipped_hi:
        # The two intervals do not overlap: collapse the validation interval
        # onto the midpoint of the training interval.
        midpoint = 0.5 * (train_lo + train_hi)
        clipped_lo = clipped_hi = midpoint
    adjusted_val_ranges[name] = [clipped_lo, clipped_hi]
custom_cfg["val_ranges"] = adjusted_val_ranges

# Hand the training intervals to physae as (low, high) tuples.
norm_params = {
    param: (float(bounds[0]), float(bounds[1]))
    for param, bounds in custom_cfg["train_ranges"].items()
}
physae_config.set_norm_params(norm_params)

# Last expression of the cell: displays the adjusted validation ranges.
custom_cfg["val_ranges"]


In [None]:
from physae.dataset import SpectraDataset
from physae.physics import parse_csv_transitions

# Coefficients forwarded to SpectraDataset through its `poly_freq_CH4` argument
# (presumably the CH4 frequency-axis polynomial -- confirm in physae.dataset).
poly_freq_CH4 = [-2.3614803e-07, 1.2103413e-10, -3.1617856e-14]

# CH4 transition table: one record per line, fields separated by ';'.
transitions_ch4_str = """6;1;3085.861015;1.013E-19;0.06;0.078;219.9411;0.73;-0.00712;0.0;0.0221;0.96;0.584;1.12\n6;1;3085.832038;1.693E-19;0.0597;0.078;219.9451;0.73;-0.00712;0.0;0.0222;0.91;0.173;1.11\n6;1;3085.893769;1.011E-19;0.0602;0.078;219.9366;0.73;-0.00711;0.0;0.0184;1.14;-0.516;1.37\n6;1;3086.030985;1.659E-19;0.0595;0.078;219.9197;0.73;-0.00711;0.0;0.0193;1.17;-0.204;0.97\n6;1;3086.071879;1.000E-19;0.0585;0.078;219.9149;0.73;-0.00703;0.0;0.0232;1.09;-0.0689;0.82\n6;1;3086.085994;6.671E-20;0.055;0.078;219.9133;0.70;-0.00610;0.0;0.0300;0.54;0.00;0.0"""
transitions_dict = {"CH4": parse_csv_transitions(transitions_ch4_str)}

# Constructor arguments common to the training and validation datasets.
shared_dataset_kwargs = dict(
    num_points=800,
    poly_freq_CH4=poly_freq_CH4,
    transitions_dict=transitions_dict,
    strict_check=True,
    with_noise=False,
)

# Training set: 16 samples drawn from the narrowed training ranges.
train_sample_ranges = {
    param: (float(bounds[0]), float(bounds[1]))
    for param, bounds in custom_cfg["train_ranges"].items()
}
train_dataset = SpectraDataset(
    n_samples=16,
    sample_ranges=train_sample_ranges,
    **shared_dataset_kwargs,
)

# Validation set: 8 samples drawn from the clipped validation ranges.
val_sample_ranges = {
    param: (float(bounds[0]), float(bounds[1]))
    for param, bounds in custom_cfg["val_ranges"].items()
}
val_dataset = SpectraDataset(
    n_samples=8,
    sample_ranges=val_sample_ranges,
    **shared_dataset_kwargs,
)

print("Train dataset mf_CH4 range:", train_dataset.sample_ranges["mf_CH4"])
print("Validation dataset mf_CH4 range:", val_dataset.sample_ranges["mf_CH4"])

# Sanity check: every validation interval must lie inside the training
# interval of the same parameter.  Stop at the first violation.
ranges_are_nested = True
for param in val_dataset.sample_ranges:
    lower_ok = train_dataset.sample_ranges[param][0] <= val_dataset.sample_ranges[param][0]
    if not (lower_ok and val_dataset.sample_ranges[param][1] <= train_dataset.sample_ranges[param][1]):
        ranges_are_nested = False
        break
print("Validation ranges ⊆ training ranges?", ranges_are_nested)