In [1]:
import tempfile
import shutil
from pprint import pprint
import yaml
import sys
from pathlib import Path

# Make sure the project root (the folder that contains "src/") is on sys.path
project_root = Path.cwd().parent  # since the notebook is inside /notebooks/
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.utils.config_reader import ConfigReader


# ----------------------------------------------------------------------
# Utility helpers
# ----------------------------------------------------------------------

def write_yaml(path: Path, data: dict):
    with open(path, "w") as f:
        yaml.safe_dump(data, f)


def make_base_config(tmpdir: Path):
    """Generate a fully valid config structure based on your current schema."""
    (tmpdir / "data/raw/pd").mkdir(parents=True, exist_ok=True)
    (tmpdir / "data/raw/lgd").mkdir(parents=True, exist_ok=True)

    data = {
        "split": {"test_size": 0.2, "val_size": 0.2, "cv_splits": 3, "seed": 42, "row_limit": 1000},
        "paths": {"pd_dir": "data/raw/pd", "lgd_dir": "data/raw/lgd"},
        "dataset_pd": {"0001.gmsc": True},
        "dataset_lgd": {"0001.heloc": False},
    }

    experiment = {
        "categorical_encoding": "ordinal",
        "numerical_encoding": "quantile",
        "normalization": "standard",
        "num_nan_policy": "mean",
        "cat_nan_policy": "most_frequent",
        "max_epochs": 10,
        "batch_size": 32,
        "tune": True,
        "n_trials": 5,
        "early_stopping": True,
        "early_stopping_patience": 5,
    }

    evaluation = {
        "round_digits": 5,
        "cv_metric": "f1",
        "metrics": {
            "pd": {"accuracy": True, "f1": True, "aucroc": True},
            "lgd": {"mse": True, "mae": True, "r2": False, "rmse": False},
        },
    }

    methods = {
        "methods": {
            "pd": {"tabpfn": True, "rf": False},
            "lgd": {"lr": False, "tabpfn": False},
        }
    }

    cfg_dir = tmpdir / "config"
    cfg_dir.mkdir()

    write_yaml(cfg_dir / "CONFIG_DATA.yaml", data)
    write_yaml(cfg_dir / "CONFIG_EXPERIMENT.yaml", experiment)
    write_yaml(cfg_dir / "CONFIG_EVALUATION.yaml", evaluation)
    write_yaml(cfg_dir / "CONFIG_METHOD.yaml", methods)
    return cfg_dir


# ----------------------------------------------------------------------
# Core test runner
# ----------------------------------------------------------------------

def run_test_case(title, override_fn=None, expect_ok=True):
    tmpdir = Path(tempfile.mkdtemp())
    cfg_dir = make_base_config(tmpdir)

    # Apply optional override
    if override_fn:
        override_fn(cfg_dir)

    print(f"\n=== 🧪 TEST CASE: {title} ===")
    try:
        cfg = ConfigReader(config_dir=str(cfg_dir)).load().validate().to_dict()
        if expect_ok:
            print("✅ Validation PASSED")
            pprint(cfg)
        else:
            print("❌ Expected failure but passed unexpectedly!")
    except Exception as e:
        if expect_ok:
            print("❌ Validation FAILED unexpectedly:")
            print(e)
        else:
            print("✅ Validation FAILED as expected:")
            print(e)
    finally:
        shutil.rmtree(tmpdir)


# ----------------------------------------------------------------------
# ✅ VALID CASES
# ----------------------------------------------------------------------

run_test_case("VALID: Base config (PD only)")

run_test_case(
    "VALID: row_limit None",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_DATA.yaml")),
                "split": {"test_size": 0.2, "val_size": 0.2, "cv_splits": 3, "seed": 42, "row_limit": None},
            },
            open(d / "CONFIG_DATA.yaml", "w"),
        )
    ),
)

run_test_case(
    "VALID: early_stopping + patience",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_EXPERIMENT.yaml")),
                "early_stopping": True,
                "early_stopping_patience": 3,
            },
            open(d / "CONFIG_EXPERIMENT.yaml", "w"),
        )
    ),
)


# ----------------------------------------------------------------------
# ❌ INVALID CASES
# ----------------------------------------------------------------------

run_test_case(
    "ERROR: test_size + val_size > 0.6",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_DATA.yaml")),
                "split": {"test_size": 0.5, "val_size": 0.2, "cv_splits": 3, "seed": 42, "row_limit": 1000},
            },
            open(d / "CONFIG_DATA.yaml", "w"),
        )
    ),
    expect_ok=False,
)

run_test_case(
    "ERROR: Missing cv_metric",
    lambda d: (
        yaml.safe_dump(
            {k: v for k, v in yaml.safe_load(open(d / "CONFIG_EVALUATION.yaml")).items() if k != "cv_metric"},
            open(d / "CONFIG_EVALUATION.yaml", "w"),
        )
    ),
    expect_ok=False,
)

run_test_case(
    "ERROR: categorical_encoding invalid",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_EXPERIMENT.yaml")),
                "categorical_encoding": "invalid_option",
            },
            open(d / "CONFIG_EXPERIMENT.yaml", "w"),
        )
    ),
    expect_ok=False,
)

run_test_case(
    "ERROR: early_stopping_patience not integer",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_EXPERIMENT.yaml")),
                "early_stopping_patience": "five",
            },
            open(d / "CONFIG_EXPERIMENT.yaml", "w"),
        )
    ),
    expect_ok=False,
)

run_test_case(
    "ERROR: PD dataset active but no PD metrics",
    lambda d: (
        yaml.safe_dump(
            {
                **yaml.safe_load(open(d / "CONFIG_EVALUATION.yaml")),
                "metrics": {"pd": {k: False for k in ["accuracy", "f1", "aucroc"]}, "lgd": {"mse": True, "mae": True}},
            },
            open(d / "CONFIG_EVALUATION.yaml", "w"),
        )
    ),
    expect_ok=False,
)



=== 🧪 TEST CASE: VALID: Base config (PD only) ===
❌ Validation FAILED unexpectedly:
Configuration validation failed:
[DATA] paths.pd_dir does not exist: data/raw/pd
[DATA] paths.lgd_dir does not exist: data/raw/lgd

=== 🧪 TEST CASE: VALID: row_limit None ===
❌ Validation FAILED unexpectedly:
Configuration validation failed:
[DATA] paths.pd_dir does not exist: data/raw/pd
[DATA] paths.lgd_dir does not exist: data/raw/lgd

=== 🧪 TEST CASE: VALID: early_stopping + patience ===
❌ Validation FAILED unexpectedly:
Configuration validation failed:
[DATA] paths.pd_dir does not exist: data/raw/pd
[DATA] paths.lgd_dir does not exist: data/raw/lgd

=== 🧪 TEST CASE: ERROR: test_size + val_size > 0.6 ===
✅ Validation FAILED as expected:
Configuration validation failed:
[DATA] test_size + val_size must be ≤ 0.8
[DATA] paths.pd_dir does not exist: data/raw/pd
[DATA] paths.lgd_dir does not exist: data/raw/lgd

=== 🧪 TEST CASE: ERROR: Missing cv_metric ===
✅ Validation FAILED as expected:
Configuration