# Full Colab Pipeline: Normalized Data + Bare Earth + SpectralGPT + ResNet + Validation

This notebook runs your complete workflow on Colab GPU:
1. Prepare full normalized training data with coordinates.
2. Pull GA Barest Earth (Sentinel-2) for all training points.
3. Train SpectralGPT embeddings.
4. Train ResNet ensembles (`alpha-only` and `fused`).
5. Produce cross-validation and independent validation metrics.

## Data rights note

This notebook uses DEA/Geoscience Australia OWS endpoints for Bare Earth. Ensure your use complies with DEA/GA usage rights and constraints.

## 1) Runtime setup

In Colab, set `Runtime -> Change runtime type -> GPU` before running.

In [None]:
from google.colab import drive
import os
import shutil
import subprocess

# Mount Google Drive
drive.mount('/content/drive')

# ---- Configure these paths ----
# USE_GIT_CLONE=True clones REPO_GIT_URL; False copies DRIVE_REPO_DIR instead.
USE_GIT_CLONE = True
REPO_GIT_URL = "https://github.com/JackOnThePaddock/soil-resnet-model.git"  # e.g. "https://github.com/<user>/soil-resnet-model.git"
DRIVE_REPO_DIR = "/content/drive/MyDrive/soil-resnet-model"
PROJECT_DIR = "/content/soil-resnet-model"

# This cell ends by chdir-ing into PROJECT_DIR. On a re-run the kernel would
# therefore be sitting inside the directory that is about to be deleted, and
# child processes such as `git clone` fail when their inherited working
# directory no longer exists. Step out to the parent directory first.
project_parent = os.path.dirname(PROJECT_DIR)
os.makedirs(project_parent, exist_ok=True)
os.chdir(project_parent)

# Always start from a fresh checkout/copy of the project.
if os.path.exists(PROJECT_DIR):
    shutil.rmtree(PROJECT_DIR)

if USE_GIT_CLONE:
    if not REPO_GIT_URL:
        raise ValueError("Set REPO_GIT_URL or set USE_GIT_CLONE=False")
    subprocess.run(["git", "clone", REPO_GIT_URL, PROJECT_DIR], check=True)
else:
    if not os.path.exists(DRIVE_REPO_DIR):
        raise FileNotFoundError(f"Repo not found at {DRIVE_REPO_DIR}")
    shutil.copytree(DRIVE_REPO_DIR, PROJECT_DIR)

# Later cells use paths relative to the project root.
os.chdir(PROJECT_DIR)
print("Project dir:", os.getcwd())

In [None]:
# Shell commands (IPython `!` syntax): print the interpreter version, upgrade
# pip quietly, then install the project found in the current directory in
# editable mode (`-e .`).
!python -V
!pip -q install --upgrade pip
!pip -q install -e .

In [None]:
import torch
# Report the installed PyTorch build and, when CUDA is usable, the name of
# device 0.
cuda_ok = torch.cuda.is_available()
print("Torch:", torch.__version__)
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))

## 2) Run configuration

Set `FAST_MODE=False` for full training. `FAST_MODE=True` is useful only for quick pipeline checks.

In [None]:
from pathlib import Path
import json
import yaml
import numpy as np
import pandas as pd

# Switches for the whole run. FAST_MODE selects the smaller of each pair of
# values in TRAINING below; the RUN_* flags gate the later cells that pull
# Bare Earth data, train each ensemble, and run independent validation.
FAST_MODE = False
RUN_ALPHA = True
RUN_FUSED = True
RUN_BARE_EARTH_PULL = True
RUN_INDEPENDENT_VALIDATION = True

# Passed to scripts/pull_bare_earth_embeddings.py as --spectral-dim,
# --workers, --timeout and --retries respectively.
SPECTRAL_DIM = 16
WCS_WORKERS = 16
WCS_TIMEOUT = 120
WCS_RETRIES = 3

# Hyperparameters copied into both generated YAML configs. The *_alpha and
# *_fused entries feed the alpha-only and fused configs respectively.
TRAINING = {
    "ensemble_size": 3 if FAST_MODE else 5,
    "n_splits": 3 if FAST_MODE else 5,
    "epochs_alpha": 60 if FAST_MODE else 300,
    "final_epochs_alpha": 40 if FAST_MODE else 200,
    "epochs_fused": 80 if FAST_MODE else 350,
    "final_epochs_fused": 60 if FAST_MODE else 220,
    "batch_size": 32,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "patience": 20 if FAST_MODE else 35,
    "lr_patience": 10 if FAST_MODE else 20,
    "lr_factor": 0.5,
    "grad_clip": 1.0,
    "random_seed": 42,
}

# Lowercase target names: used as the keys of `target_transforms` in the
# configs and as the columns read from the raw training CSV during
# independent validation.
TARGETS = ["ph", "cec", "esp", "soc", "ca", "mg", "na"]

# All paths are relative to the current working directory.
paths = {
    "raw": Path("data/processed/features.csv"),
    "normalized": Path("data/processed/features_normalized.csv"),
    "normalized_points": Path("data/processed/features_normalized_points.csv"),
    "be_sgpt": Path("data/processed/features_normalized_bareearth_sgpt.csv"),
    "sgpt_only": Path("data/processed/features_normalized_sgpt_embeddings.csv"),
    "fused_feat": Path("data/processed/features_normalized_fused_feat.csv"),
    "fused_meta": Path("data/processed/features_normalized_fused_meta.json"),
    "cfg_alpha": Path("configs/colab_resnet_alpha_normalized.yaml"),
    "cfg_fused": Path("configs/colab_resnet_fused_normalized.yaml"),
    "model_alpha": Path("models/colab_resnet_alpha_norm"),
    "model_fused": Path("models/colab_resnet_fused_norm"),
    "validation": Path("data/validation/national_independent_1368.csv"),
}

# Fail fast on every input file that an enabled later step will read, so a
# missing file is reported now rather than after the long training steps.
required_inputs = ["raw", "normalized"]
if not RUN_BARE_EARTH_PULL:
    # The Bare Earth step then reuses an existing table instead of pulling it.
    required_inputs.append("be_sgpt")
if RUN_INDEPENDENT_VALIDATION:
    # The independent validation step reads this table.
    required_inputs.append("validation")

for name in required_inputs:
    if not paths[name].exists():
        raise FileNotFoundError(f"Missing required file: {paths[name]}")

print("Config loaded")

## 3) Build normalized training table with point coordinates

`features_normalized.csv` has normalized targets + AlphaEarth features, but no `lat/lon`. This cell joins `id/lat/lon` from `features.csv` by row order, so both files must list the same samples in the same order; only the row counts are checked.

In [None]:
# Coordinates exist only in the raw table. The two tables are glued side by
# side purely by position, so the only safeguard is the row-count check.
raw_meta = pd.read_csv(paths["raw"], usecols=["id", "lat", "lon"])
norm_df = pd.read_csv(paths["normalized"])

if len(raw_meta) != len(norm_df):
    raise ValueError(f"Row mismatch: raw={len(raw_meta)} normalized={len(norm_df)}")

# Reset both indexes so the column-wise concat aligns on 0..n-1.
positional_frames = [frame.reset_index(drop=True) for frame in (raw_meta, norm_df)]
norm_points = pd.concat(positional_frames, axis=1)

points_csv = paths["normalized_points"]
points_csv.parent.mkdir(parents=True, exist_ok=True)
norm_points.to_csv(points_csv, index=False)

print("Saved:", points_csv, norm_points.shape)

## 4) Pull Bare Earth and train SpectralGPT embeddings on all normalized rows

In [None]:
import subprocess

if not RUN_BARE_EARTH_PULL:
    # Reusing a previous pull: the combined table must already be on disk.
    if not paths["be_sgpt"].exists():
        raise FileNotFoundError(f"Expected existing file: {paths['be_sgpt']}")
else:
    # Flag -> value pairs for the pull script, in the order they are passed.
    pull_options = {
        "--normalized-csv": str(paths["normalized"]),
        "--points-csv": str(paths["raw"]),
        "--output-csv": str(paths["be_sgpt"]),
        "--output-embeddings-csv": str(paths["sgpt_only"]),
        "--workers": str(WCS_WORKERS),
        "--timeout": str(WCS_TIMEOUT),
        "--retries": str(WCS_RETRIES),
        "--spectral-backend": "official_pretrained",
        "--official-request-chunk-size": "64",
        "--output-official-raw-csv": "data/processed/features_normalized_sgpt_official_raw.csv",
        "--spectral-dim": str(SPECTRAL_DIM),
        "--seed": str(TRAINING["random_seed"]),
    }
    cmd = ["python", "scripts/pull_bare_earth_embeddings.py"]
    for flag, value in pull_options.items():
        cmd.extend((flag, value))
    print("Running:", " ".join(cmd))
    # check=True turns a non-zero exit status into CalledProcessError.
    subprocess.run(cmd, check=True)

# Report whether each expected output file is present.
for label, key in (
    ("Bare Earth + SGPT file:", "be_sgpt"),
    ("Embeddings-only file:", "sgpt_only"),
):
    print(label, paths[key].exists())

## 5) Build fused `feat_*` columns for ResNet fused training

The training script expects a single feature namespace (e.g. `feat_000..`). We pack the following blocks, in this order:
- AlphaEarth: `band_*`
- Bare Earth: `be_*`
- SpectralGPT: `sgpt_*`

In [None]:
import re

# Load the table written to --output-csv by the Bare Earth pull step. It is
# expected to hold band_*, be_* and sgpt_* columns (validated further down).
fused = pd.read_csv(paths["be_sgpt"])


def sort_numeric_suffix(cols):
    """Return *cols* as a new list ordered by trailing integer.

    Names ending in digits come first, ordered by that number and then by the
    full name; names without a numeric suffix follow in plain string order.
    """
    trailing_digits = re.compile(r"(\d+)$")

    def _rank(name):
        hit = trailing_digits.search(name)
        if hit is None:
            return (1, -1, name)
        return (0, int(hit.group(1)), name)

    ordered = list(cols)
    ordered.sort(key=_rank)
    return ordered

# Collect each source block by (case-insensitive) prefix, ordered by suffix.
band_cols, be_cols, sgpt_cols = (
    sort_numeric_suffix([c for c in fused.columns if c.lower().startswith(prefix)])
    for prefix in ("band_", "be_", "sgpt_")
)

# Sanity-check the block sizes before packing.
if len(band_cols) != 64:
    raise ValueError(f"Expected 64 alpha band columns, got {len(band_cols)}")
if len(be_cols) < 5:
    raise ValueError(f"Too few Bare Earth columns detected: {len(be_cols)}")
if not sgpt_cols:
    raise ValueError("No SpectralGPT columns found")

# Pack band, then Bare Earth, then SpectralGPT into feat_000, feat_001, ...
# The source columns are kept alongside the new feat_* copies.
source_cols = [*band_cols, *be_cols, *sgpt_cols]
feat_names = [f"feat_{i:03d}" for i in range(len(source_cols))]
fused = fused.assign(**{feat: fused[src] for feat, src in zip(feat_names, source_cols)})

fused_csv = paths["fused_feat"]
fused_csv.parent.mkdir(parents=True, exist_ok=True)
fused.to_csv(fused_csv, index=False)

# Record which source columns ended up in which block.
meta = {
    "alpha_cols": band_cols,
    "bareearth_cols": be_cols,
    "spectral_cols": sgpt_cols,
    "feat_cols": feat_names,
}
paths["fused_meta"].write_text(json.dumps(meta, indent=2), encoding="utf-8")

print("Saved fused table:", fused_csv, fused.shape)
print("Feature blocks:", {"band": len(band_cols), "be": len(be_cols), "sgpt": len(sgpt_cols), "feat_total": len(meta['feat_cols'])})

## 6) Write Colab training configs (normalized targets)

Important: because targets are already normalized, all target transforms are set to `identity`.

In [None]:
paths["cfg_alpha"].parent.mkdir(parents=True, exist_ok=True)


def _resnet_config(
    *,
    input_dim,
    hidden_dim,
    num_blocks,
    epochs,
    final_epochs,
    esp_consistency_weight,
    specialist_epochs,
    specialist_patience,
    feature_prefix,
    n_features,
    model_dir,
):
    """Build one training config dict.

    Only the keyword arguments differ between the alpha-only and fused
    variants; everything else comes from TRAINING/TARGETS or is a shared
    literal. Every nested container is created fresh on each call.
    """
    return {
        "model": {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_blocks": num_blocks,
            "dropout": 0.2,
            "activation": "silu",
        },
        "targets": ["pH", "CEC", "ESP", "SOC", "Ca", "Mg", "Na"],
        "training": {
            "ensemble_size": TRAINING["ensemble_size"],
            "n_splits": TRAINING["n_splits"],
            "epochs": epochs,
            "final_epochs": final_epochs,
            "batch_size": TRAINING["batch_size"],
            "learning_rate": TRAINING["learning_rate"],
            "weight_decay": TRAINING["weight_decay"],
            "patience": TRAINING["patience"],
            "lr_patience": TRAINING["lr_patience"],
            "lr_factor": TRAINING["lr_factor"],
            "grad_clip": TRAINING["grad_clip"],
            "cv_strategy": "group_kfold",
            "loss_name": "weighted_huber",
            "huber_delta": 1.0,
            "esp_consistency_weight": esp_consistency_weight,
            "target_weight_mode": "inverse_frequency",
            "sample_weight_mode": "rare_target_average",
            # Targets are already normalized, so no transform is applied.
            "auto_target_transforms": False,
            "target_transforms": {t: "identity" for t in TARGETS},
            "specialist_targets": ["cec", "esp", "soc"],
            "specialist_epochs": specialist_epochs,
            "specialist_patience": specialist_patience,
            "specialist_val_fraction": 0.2,
            "specialist_blend_weight": 0.4,
            "random_seed": TRAINING["random_seed"],
        },
        "data": {
            "feature_prefix": feature_prefix,
            "n_features": n_features,
            "group_by": "latlon",
            "group_round": 4,
            "reference_data": None,
        },
        "output": {
            "model_dir": str(model_dir),
            "metrics_dir": "results/metrics",
            "scaler_file": "scaler.pkl",
        },
    }


config_alpha = _resnet_config(
    input_dim=64,
    hidden_dim=128,
    num_blocks=2,
    epochs=TRAINING["epochs_alpha"],
    final_epochs=TRAINING["final_epochs_alpha"],
    esp_consistency_weight=0.05,
    specialist_epochs=80 if FAST_MODE else 120,
    specialist_patience=15 if FAST_MODE else 20,
    feature_prefix="band_",
    n_features=64,
    model_dir=paths["model_alpha"],
)

# The fused input width is the number of feat_* columns in the fused table;
# reading a single row is enough to get the header.
fused_cols_df = pd.read_csv(paths["fused_feat"], nrows=1)
n_fused_features = sum(1 for c in fused_cols_df.columns if c.startswith("feat_"))

config_fused = _resnet_config(
    input_dim=n_fused_features,
    hidden_dim=192,
    num_blocks=3,
    epochs=TRAINING["epochs_fused"],
    final_epochs=TRAINING["final_epochs_fused"],
    esp_consistency_weight=0.08,
    specialist_epochs=100 if FAST_MODE else 150,
    specialist_patience=15 if FAST_MODE else 25,
    feature_prefix="feat_",
    n_features=None,
    model_dir=paths["model_fused"],
)

# sort_keys=False keeps the section order written above.
written_configs = (("cfg_alpha", config_alpha), ("cfg_fused", config_fused))
for cfg_key, cfg in written_configs:
    paths[cfg_key].write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")

for cfg_key, _ in written_configs:
    print("Wrote", paths[cfg_key])
print("Fused input_dim:", n_fused_features)

## 7) Train ResNet ensemble: alpha-only (normalized)

In [None]:
import subprocess

if not RUN_ALPHA:
    print("Skipped alpha training")
else:
    # Alpha-only ensemble: trained on the normalized table with coordinates.
    alpha_args = {
        "--data": str(paths["normalized_points"]),
        "--config": str(paths["cfg_alpha"]),
        "--output": str(paths["model_alpha"]),
        "--cv-strategy": "group_kfold",
        "--seed": str(TRAINING["random_seed"]),
    }
    cmd = ["python", "scripts/train_resnet_ensemble.py"]
    for flag, value in alpha_args.items():
        cmd += [flag, value]
    print("Running:", " ".join(cmd))
    # check=True raises CalledProcessError if the script exits non-zero.
    subprocess.run(cmd, check=True)

## 8) Train ResNet ensemble: fused (normalized + Bare Earth + SpectralGPT)

In [None]:
if not RUN_FUSED:
    print("Skipped fused training")
else:
    # Fused ensemble: trained on the table carrying the packed feat_* columns.
    fused_args = {
        "--data": str(paths["fused_feat"]),
        "--config": str(paths["cfg_fused"]),
        "--output": str(paths["model_fused"]),
        "--cv-strategy": "group_kfold",
        "--seed": str(TRAINING["random_seed"]),
    }
    cmd = ["python", "scripts/train_resnet_ensemble.py"]
    for flag, value in fused_args.items():
        cmd += [flag, value]
    print("Running:", " ".join(cmd))
    # check=True raises CalledProcessError if the script exits non-zero.
    subprocess.run(cmd, check=True)

## 9) Cross-validation error summary from trained ensembles

In [None]:
import re


def summarize_ensemble_metrics(csv_path: Path, model_label: str) -> pd.DataFrame:
    """Summarise per-target `r2` / `rmse` columns of an ensemble metrics CSV.

    Args:
        csv_path: CSV whose metric columns are named `<target>_r2` or
            `<target>_rmse`; every other column is ignored.
        model_label: Value stored in the `model` column of each output row.

    Returns:
        One row per metric column with `model`, `target` (lower-cased),
        `metric`, `mean` and `std` (population, ddof=0), sorted by model,
        target and metric. An empty DataFrame when the file does not exist
        or contains no matching column.
    """
    if not csv_path.exists():
        return pd.DataFrame()

    metrics_df = pd.read_csv(csv_path)
    metric_column = re.compile(r"^(.*)_(r2|rmse)$")
    candidates = ((name, metric_column.match(name)) for name in metrics_df.columns)
    records = [
        {
            "model": model_label,
            "target": hit.group(1),
            "metric": hit.group(2),
            "mean": float(metrics_df[name].mean()),
            "std": float(metrics_df[name].std(ddof=0)),
        }
        for name, hit in candidates
        if hit
    ]

    summary = pd.DataFrame(records)
    if summary.empty:
        return summary

    # Lower-case before sorting so the sort order matches the final labels.
    summary["target"] = summary["target"].str.lower()
    ordered = summary.sort_values(["model", "target", "metric"])
    return ordered.reset_index(drop=True)

# Summarise each ensemble's metrics file (an empty frame if it is missing).
metric_sources = (("model_alpha", "alpha_norm"), ("model_fused", "fused_norm"))
alpha_metrics, fused_metrics = (
    summarize_ensemble_metrics(paths[key] / "ensemble_metrics.csv", label)
    for key, label in metric_sources
)
cv_summary = pd.concat([alpha_metrics, fused_metrics], ignore_index=True)

if cv_summary.empty:
    print("No ensemble metrics found")
    pivot = pd.DataFrame()
else:
    display(cv_summary)
    # One row per (target, metric), one column per model, holding the mean.
    pivot = cv_summary.pivot_table(
        index=["target", "metric"],
        columns="model",
        values="mean",
        aggfunc="first",
    )

if not pivot.empty:
    print("\nMean CV metrics")
    display(pivot)

## 10) Independent validation (alpha model) on `national_independent_1368.csv`

This validation file has AlphaEarth `A00..A63` and target subset (`ph`, `cec_cmolkg`, `esp_pct`, `na_cmolkg`).
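Because this notebook trains on normalized targets, predictions are converted back to raw units before metrics are reported. The conversion uses the mean and population standard deviation of the raw training targets in `features.csv`, so it is only valid if the normalized targets were produced by that same z-score.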

In [None]:
if RUN_INDEPENDENT_VALIDATION:
    from src.models.ensemble import SoilEnsemble
    from src.evaluation.metrics import compute_metrics

    if not paths["validation"].exists():
        raise FileNotFoundError(f"Missing validation file: {paths['validation']}")

    val_df = pd.read_csv(paths["validation"])

    # Map A00..A63 -> band_0..band_63 expected by alpha model.
    for i in range(64):
        src = f"A{i:02d}"
        if src in val_df.columns:
            val_df[f"band_{i}"] = val_df[src]

    feature_cols = [f"band_{i}" for i in range(64)]
    missing_feat = [c for c in feature_cols if c not in val_df.columns]
    if missing_feat:
        raise ValueError(f"Validation features missing: {missing_feat[:5]}")

    model = SoilEnsemble(paths["model_alpha"])
    pred_norm, pred_std_norm = model.predict_batch(val_df[feature_cols].values.astype(np.float32))

    # The raw CSV columns and the truth map below use lowercase target names,
    # while the training configs declare mixed-case ones ("pH", "CEC", ...).
    # Align everything case-insensitively so either spelling of
    # model.target_names works.
    target_names = list(model.target_names)
    model_target_for = {str(t).lower(): t for t in target_names}

    # Recover normalization stats from raw training targets.
    # NOTE(review): assumes the normalized targets were produced as
    # (x - mean) / population-std of these raw columns -- confirm against the
    # normalization step.
    train_raw = pd.read_csv(paths["raw"], usecols=TARGETS)
    mu = train_raw.mean(axis=0)
    sigma = train_raw.std(axis=0, ddof=0).replace(0.0, 1.0)  # avoid scaling by 0
    raw_col_for = {str(c).lower(): c for c in train_raw.columns}
    no_stats = [t for t in target_names if str(t).lower() not in raw_col_for]
    if no_stats:
        raise KeyError(f"No raw training column for model targets: {no_stats}")

    # Undo the normalization: raw = normalized * sigma + mu, per target.
    pred_norm_df = pd.DataFrame(pred_norm, columns=model.target_names)
    pred_raw_df = pred_norm_df.copy()
    for t in target_names:
        raw_col = raw_col_for[str(t).lower()]
        pred_raw_df[t] = pred_norm_df[t] * float(sigma[raw_col]) + float(mu[raw_col])

    # Observed values: NaN everywhere except the targets the validation file
    # actually provides.
    y_true_raw = pd.DataFrame(np.nan, index=val_df.index, columns=model.target_names)
    truth_map = {
        "ph": "ph",
        "cec": "cec_cmolkg",
        "esp": "esp_pct",
        "na": "na_cmolkg",
    }
    for key, src_col in truth_map.items():
        t = model_target_for.get(key)
        if t is not None and src_col in val_df.columns:
            y_true_raw[t] = pd.to_numeric(val_df[src_col], errors="coerce")

    metrics = compute_metrics(
        y_true=y_true_raw[model.target_names].values,
        y_pred=pred_raw_df[model.target_names].values,
        target_names=model.target_names,
    )

    if metrics:
        indep_df = pd.DataFrame(metrics).T.reset_index().rename(columns={"index": "target"})
        indep_df = indep_df.sort_values("target").reset_index(drop=True)
        display(indep_df)
    else:
        print("No comparable targets found for independent validation")

    # Optional scatter preview for targets that exist in validation file.
    import matplotlib.pyplot as plt

    # Look metrics up case-insensitively, for the same reason as above.
    metrics_by_key = {str(k).lower(): v for k, v in (metrics or {}).items()}
    available_targets = [
        key
        for key in ["ph", "cec", "esp", "na"]
        if key in metrics_by_key and key in model_target_for
    ]
    if available_targets:
        fig, axes = plt.subplots(1, len(available_targets), figsize=(5 * len(available_targets), 4))
        if len(available_targets) == 1:
            # subplots returns a bare Axes (not an array) for a single panel.
            axes = [axes]
        for ax, key in zip(axes, available_targets):
            t = model_target_for[key]
            observed = y_true_raw[t].values
            predicted = pred_raw_df[t].values
            mask = np.isfinite(observed) & np.isfinite(predicted)
            ax.set_title(f"{key} (R2={metrics_by_key[key]['r2']:.3f})")
            ax.set_xlabel("Observed")
            ax.set_ylabel("Predicted")
            if not mask.any():
                # Nothing to draw; nanmin/nanmax would raise on an empty array.
                continue
            ax.scatter(observed[mask], predicted[mask], s=8, alpha=0.5)
            lo = np.nanmin(observed[mask])
            hi = np.nanmax(observed[mask])
            # 1:1 reference line across the observed range.
            ax.plot([lo, hi], [lo, hi], "k--", linewidth=1)
        plt.tight_layout()
        plt.show()
else:
    print("Skipped independent validation")

## 11) Save outputs back to Drive

In [None]:
from datetime import datetime
import shutil

# Each run gets its own timestamped folder on Drive.
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
drive_out = Path(f"/content/drive/MyDrive/soil-resnet-outputs/full_pipeline_{stamp}")
drive_out.mkdir(parents=True, exist_ok=True)

# Single files: copied flat into the output folder when they exist.
file_keys = (
    "normalized_points",
    "be_sgpt",
    "sgpt_only",
    "fused_feat",
    "fused_meta",
    "cfg_alpha",
    "cfg_fused",
)
for key in file_keys:
    src_file = paths[key]
    if src_file.exists():
        shutil.copy2(src_file, drive_out / src_file.name)

# Model directories: copied recursively under their own names.
for key in ("model_alpha", "model_fused"):
    model_dir = paths[key]
    if model_dir.exists():
        shutil.copytree(model_dir, drive_out / model_dir.name)

print("Saved outputs to:", drive_out)