# ML Model Validation

This notebook demonstrates model validation using `ml_validate`, which evaluates model performance using cross-validation, a hold-out split, or a separate validation set.

We'll validate a Random Forest classifier, which trains faster than deep learning models and is therefore a good fit for this demonstration.

Results are returned as a JSON asset named `metrics`, which is downloaded as a file of the same name (no `.json` extension).

Notes
1. Cross-validation: takes the model parameters and the training set (`sits_kfold_validate`).
2. Validation: takes a fitted model and a validation set (`sits_validate`).

`impute_linear`: gap filling.

- Rewrite `.validate_sits` internally and drop `ml_method`, `validation_split`, and `prediction`.
- Also check spatial cross-validation (Jakub Nowosad and Hanna Meyer).

In [9]:
import openeo # type: ignore

In [2]:
connection = openeo.connect(url="http://127.0.0.1:8000")
connection.authenticate_basic("brian", "123456")

<Connection to 'http://127.0.0.1:8000/' with BasicBearerAuth>

In [3]:
training_set = "https://github.com/e-sensing/sitsdata/raw/main/data/samples_deforestation_rondonia.rds"

## Validation Options

The `ml_validate` process supports:
- **Hold-out split**: set `cv` to 0 or 1; a `validation_split` fraction of the samples (default 0.2) is held out for validation
- **K-fold cross-validation**: set `cv` to the number of folds (e.g., 5 or 10)
- **Separate validation set**: provide the `validation_data` parameter

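Cross-validation gives more robust estimates, because every sample is used for validation exactly once, but it trains one model per fold and therefore takes longer (about 2 minutes for 5 folds versus about 20 seconds for the hold-out split in this notebook).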
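
To make the difference between the two splitting strategies concrete, here is a minimal sketch of how sample indices could be partitioned. `make_splits` is a hypothetical helper written for illustration only; it is not how `ml_validate` is implemented, and it uses a plain (unstratified) shuffle, whereas the back end may stratify by class label.

```python
import random


def make_splits(n_samples, cv=0, validation_split=0.2, seed=42):
    """Return a list of (train_indices, validation_indices) pairs.

    Hypothetical helper mirroring the options above:
    cv in (0, 1) -> a single hold-out split; cv >= 2 -> k-fold cross-validation.
    Uses an unstratified shuffle for simplicity.
    """
    indices = list(range(n_samples))
    rng = random.Random(seed)  # fixed seed -> reproducible splits
    rng.shuffle(indices)

    if cv in (0, 1):
        # Hold-out: reserve a fraction of the shuffled samples for validation
        n_val = round(n_samples * validation_split)
        return [(indices[n_val:], indices[:n_val])]

    # k-fold: deal the shuffled indices into cv folds of (nearly) equal size
    folds = [indices[k::cv] for k in range(cv)]
    splits = []
    for k in range(cv):
        # Train on every fold except fold k, validate on fold k
        train = [i for j, fold in enumerate(folds) if j != k for i in fold]
        splits.append((train, folds[k]))
    return splits


holdout = make_splits(100, cv=0)
kfold = make_splits(100, cv=5)
print(len(holdout), len(holdout[0][1]))
print(len(kfold), [len(val) for _, val in kfold])
```

With 100 samples, the hold-out variant yields one split with 20 validation samples, while 5-fold CV yields five splits whose validation folds together cover every sample once; this is why CV requires five model fits instead of one.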

## Random Forest Validation

Random Forest is a fast classical ML algorithm that works well for satellite time-series classification.

In [4]:
# Validate Random Forest with 5-fold cross-validation
process_graph_rf = {
    "rf1": {
        "process_id": "mlm_class_random_forest",
        "arguments": {
            "num_trees": 100,
            "seed": 42
        },
    },
    "validate1": {
        "process_id": "ml_validate",
        "arguments": {
            "model": {"from_node": "rf1"},
            "training_data": training_set,
            "target": "label",
            "cv": 5,  # 5-fold cross-validation
            "seed": 42
        },
        "result": True,
    },
}

job_rf = connection.create_job(
    process_graph=process_graph_rf,
    title="Random Forest validation (5-fold CV)",
    description="Validate Random Forest with cross-validation",
)
job_rf.start_and_wait()
results_rf = job_rf.get_results()

0:00:00 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': send 'start'
0:00:01 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:06 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:12 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:20 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:29 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:42 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:00:57 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:01:16 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:01:40 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': running (progress N/A)
0:02:10 Job 'f72c5ff20ba45a9fc7a9ad958d52dc1a': finished (progress N/A)


In [5]:
results_rf.download_files("data/outputs_validation_rf/")

[PosixPath('data/outputs_validation_rf/metrics'),
 PosixPath('data/outputs_validation_rf/job-results.json')]

## Hold-out Validation Example

For faster validation (e.g., during development), use a single train/test split instead of cross-validation; only one model is trained.

In [6]:
# Validate with 80/20 hold-out split (faster than CV)
process_graph_holdout = {
    "rf1": {
        "process_id": "mlm_class_random_forest",
        "arguments": {
            "num_trees": 100,
            "seed": 42
        },
    },
    "validate1": {
        "process_id": "ml_validate",
        "arguments": {
            "model": {"from_node": "rf1"},
            "training_data": training_set,
            "target": "label",
            "cv": 0,  # No cross-validation, use hold-out split
            "validation_split": 0.2,  # 20% for validation
            "seed": 42
        },
        "result": True,
    },
}

job_holdout = connection.create_job(
    process_graph=process_graph_holdout,
    title="Random Forest validation (hold-out)",
    description="Validate Random Forest with 80/20 split",
)
job_holdout.start_and_wait()
results_holdout = job_holdout.get_results()

0:00:00 Job '6362c0b3c638477c81e40d343bbefb44': send 'start'
0:00:01 Job '6362c0b3c638477c81e40d343bbefb44': running (progress N/A)
0:00:06 Job '6362c0b3c638477c81e40d343bbefb44': running (progress N/A)
0:00:12 Job '6362c0b3c638477c81e40d343bbefb44': running (progress N/A)
0:00:20 Job '6362c0b3c638477c81e40d343bbefb44': finished (progress N/A)


In [7]:
results_holdout.download_files("data/outputs_validation_holdout/")

[PosixPath('data/outputs_validation_holdout/metrics'),
 PosixPath('data/outputs_validation_holdout/job-results.json')]
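
The two process graphs above (`process_graph_rf` and `process_graph_holdout`) differ only in the `ml_validate` arguments. A small helper can build either one; `build_validation_graph` is a hypothetical convenience function, not part of the openEO client, and it only reproduces the node and argument names already used in this notebook.

```python
def build_validation_graph(training_data, cv=5, validation_split=None,
                           num_trees=100, seed=42):
    """Hypothetical helper: build the Random Forest + ml_validate process graph."""
    validate_args = {
        "model": {"from_node": "rf1"},
        "training_data": training_data,
        "target": "label",
        "cv": cv,
        "seed": seed,
    }
    if validation_split is not None:
        # Only relevant for a hold-out split (cv = 0 or 1)
        validate_args["validation_split"] = validation_split
    return {
        "rf1": {
            "process_id": "mlm_class_random_forest",
            "arguments": {"num_trees": num_trees, "seed": seed},
        },
        "validate1": {
            "process_id": "ml_validate",
            "arguments": validate_args,
            "result": True,
        },
    }
```

For example, `build_validation_graph(training_set, cv=0, validation_split=0.2)` produces the same dictionary as `process_graph_holdout`, and the result can be passed to `connection.create_job(process_graph=...)` as before.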

## View Results

Load and display the validation metrics.
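
Each metrics file holds a one-element list of overall statistics. The field names appear to follow the overall statistics reported by R's caret `confusionMatrix` (an assumption based on the names, since the back end wraps the R package sits): `AccuracyLower`/`AccuracyUpper` bound a 95% confidence interval for accuracy, `AccuracyNull` is the no-information rate (the share of the most frequent reference class), and `AccuracyPValue` tests whether accuracy exceeds that rate. The sketch below shows how `Accuracy`, `Kappa` (Cohen's kappa), and `AccuracyNull` are computed from a confusion matrix; `summary_stats` is a hypothetical helper and the 2x2 matrix is a toy example, not data from this notebook.

```python
def summary_stats(cm):
    """Accuracy, Cohen's kappa and no-information rate from a confusion matrix.

    cm[i][j] counts samples whose reference class is i and predicted class is j.
    """
    k = len(cm)
    n = sum(sum(row) for row in cm)
    observed = sum(cm[i][i] for i in range(k)) / n  # overall accuracy
    ref_totals = [sum(cm[i]) for i in range(k)]  # row sums (reference classes)
    pred_totals = [sum(cm[i][j] for i in range(k)) for j in range(k)]  # column sums
    # Agreement expected by chance if predictions were independent of reference
    expected = sum(r * p for r, p in zip(ref_totals, pred_totals)) / n**2
    kappa = (observed - expected) / (1 - expected)
    nir = max(ref_totals) / n  # accuracy of always predicting the largest class
    return {"Accuracy": observed, "Kappa": kappa, "AccuracyNull": nir}


toy = [[50, 10],
       [5, 35]]
print(summary_stats(toy))
```

Kappa discounts the agreement expected by chance, which is why it is lower than raw accuracy in the results below.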

In [8]:
import json
from pathlib import Path

# Load CV results (the file is named 'metrics', after the result asset key)
cv_metrics_file = Path("data/outputs_validation_rf/metrics")
if cv_metrics_file.exists():
    with open(cv_metrics_file) as f:
        cv_metrics = json.load(f)
    print("5-Fold Cross-Validation Results:")
    print(json.dumps(cv_metrics, indent=2))
else:
    print(f"CV metrics file not found: {cv_metrics_file}")

# Load holdout results  
holdout_metrics_file = Path("data/outputs_validation_holdout/metrics")
if holdout_metrics_file.exists():
    with open(holdout_metrics_file) as f:
        holdout_metrics = json.load(f)
    print("\nHold-out Validation Results:")
    print(json.dumps(holdout_metrics, indent=2))
else:
    print(f"Holdout metrics file not found: {holdout_metrics_file}")

5-Fold Cross-Validation Results:
[
  {
    "Accuracy": 0.8914599633760613,
    "Kappa": 0.8721160418326684,
    "AccuracyLower": 0.8833196017863246,
    "AccuracyUpper": 0.8992179221816992,
    "AccuracyNull": 0.20759114366572332,
    "AccuracyPValue": 0
  }
]

Hold-out Validation Results:
[
  {
    "Accuracy": 0.8860232945091514,
    "Kappa": 0.8658275558916054,
    "AccuracyLower": 0.8666906256110423,
    "AccuracyUpper": 0.9034429051987628,
    "AccuracyNull": 0.20715474209650583,
    "AccuracyPValue": 0
  }
]
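
Both strategies give similar accuracy (about 0.89) and kappa (about 0.87), but the hold-out confidence interval is about 2.3 times wider, presumably because it is estimated from the roughly 20% of samples held out rather than from predictions on every sample across all folds. The sketch below computes the interval widths from the values printed above; `interval_width` is a hypothetical helper, and the example lists simply copy the relevant fields from the output.

```python
def interval_width(metrics):
    """Width of the accuracy confidence interval in a metrics list."""
    m = metrics[0]  # each metrics file holds a one-element list
    return m["AccuracyUpper"] - m["AccuracyLower"]


# Values copied from the output above
cv_example = [{"Accuracy": 0.8914599633760613,
               "AccuracyLower": 0.8833196017863246,
               "AccuracyUpper": 0.8992179221816992}]
holdout_example = [{"Accuracy": 0.8860232945091514,
                    "AccuracyLower": 0.8666906256110423,
                    "AccuracyUpper": 0.9034429051987628}]

cv_w = interval_width(cv_example)
ho_w = interval_width(holdout_example)
print(f"CV width: {cv_w:.4f}, hold-out width: {ho_w:.4f}, ratio: {ho_w / cv_w:.1f}")
```

In the notebook itself, `interval_width(cv_metrics)` and `interval_width(holdout_metrics)` can be called directly on the lists loaded from the downloaded files.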
