# Step 2: Train ResNet From Precomputed Fused Data (Colab GPU)

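Run this notebook after Step 1.
It trains and evaluates the stable, normalized ResNet pipeline on the precomputed Bare Earth + SpectralGPT fused features produced by Step 1. Before running, set `DRIVE_PRECOMPUTED_DIR` in step [4/11] to your Step 1 output folder on Google Drive (or set `COPY_PRECOMPUTED_FROM_DRIVE = False` if the files are already in `data/processed/`).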


## Runtime

Use `Runtime -> Change runtime type` to select a GPU (T4 or A100), then run all cells in order.


In [None]:
from google.colab import drive
import os
import shutil
import subprocess
from pathlib import Path

print("[1/11] Mount + clone repo")

drive.mount('/content/drive', force_remount=True)

# Where the code comes from: a fresh GitHub clone (default) or a copy kept on Drive.
USE_GIT_CLONE = True
REPO_GIT_URL = "https://github.com/JackOnThePaddock/soil-resnet-model.git"
DRIVE_REPO_DIR = "/content/drive/MyDrive/soil-resnet-model"
PROJECT_DIR = "/content/soil-resnet-model"

# Start from an empty working copy every time the notebook runs.
project_path = Path(PROJECT_DIR)
if project_path.exists():
    shutil.rmtree(project_path)

if not USE_GIT_CLONE:
    if not Path(DRIVE_REPO_DIR).exists():
        raise FileNotFoundError(f"Repo not found at {DRIVE_REPO_DIR}")
    shutil.copytree(DRIVE_REPO_DIR, PROJECT_DIR)
else:
    clone_cmd = ["git", "clone", REPO_GIT_URL, PROJECT_DIR]
    subprocess.run(clone_cmd, check=True)

# All later cells use paths relative to the repo root.
os.chdir(PROJECT_DIR)
print("Project dir:", os.getcwd())


In [None]:
print("[2/11] Install dependencies")
# Report the interpreter version that the shell's `python` resolves to
# (the same `python` the training cells invoke via subprocess).
!python -V
!pip -q install --upgrade pip
# Editable install of the repo in the current working directory (set by os.chdir in cell 1).
!pip -q install -e .


In [None]:
print("[3/11] Check GPU")
import torch

# Query CUDA once and reuse the answer for both the report and the device name.
has_cuda = torch.cuda.is_available()
print("Torch:", torch.__version__)
print("CUDA available:", has_cuda)
if has_cuda:
    print("GPU:", torch.cuda.get_device_name(0))


In [None]:
print("[4/11] Configure run")
import json
import yaml
import numpy as np
import pandas as pd
from pathlib import Path

# ---- What to run ---------------------------------------------------------
RUN_ALPHA = True
RUN_FUSED = True
RUN_INDEPENDENT_VALIDATION = True
FAST_MODE = False  # shorter schedules and smaller ensembles for quick checks

# If True, copy precomputed Step 1 files from Drive into this repo clone.
COPY_PRECOMPUTED_FROM_DRIVE = True
# Replace the YYYYMMDD_HHMMSS placeholder with the real Step 1 output folder name.
DRIVE_PRECOMPUTED_DIR = "/content/drive/MyDrive/soil-resnet-outputs/step1_bareearth_sgpt_YYYYMMDD_HHMMSS"


def _mode(fast, full):
    """Return *fast* when FAST_MODE is enabled, otherwise *full*."""
    return fast if FAST_MODE else full


# ---- Shared training hyper-parameters ------------------------------------
TRAINING = {
    "ensemble_size": _mode(3, 5),
    "n_splits": _mode(3, 5),
    "epochs_alpha": _mode(60, 300),
    "final_epochs_alpha": _mode(40, 200),
    "epochs_fused": _mode(80, 350),
    "final_epochs_fused": _mode(60, 220),
    "batch_size": 32,
    # NOTE(review): FAST_MODE uses a *lower* learning rate than the full run; confirm intended.
    "learning_rate": _mode(5e-5, 8e-5),
    "weight_decay": 1e-4,
    "patience": _mode(20, 35),
    "lr_patience": _mode(10, 20),
    "lr_factor": 0.5,
    "grad_clip": 1.0,
    "random_seed": 42,
}

# ---- File locations (relative to the repo root) --------------------------
_PROCESSED = Path("data/processed")
paths = {
    "raw": _PROCESSED / "features.csv",
    "normalized_points": _PROCESSED / "features_normalized_points.csv",
    "be_sgpt": _PROCESSED / "features_normalized_points_be_sgpt.csv",
    "sgpt_only": _PROCESSED / "features_normalized_points_sgpt_embeddings.csv",
    "sgpt_raw": _PROCESSED / "features_normalized_points_sgpt_official_raw.csv",
    "fused_feat": _PROCESSED / "features_normalized_points_fused_feat.csv",
    "fused_meta": _PROCESSED / "features_normalized_points_fused_meta.json",
    "cfg_alpha": Path("configs/colab_resnet_alpha_normalized_stable.yaml"),
    "cfg_fused": Path("configs/colab_resnet_fused_normalized_stable.yaml"),
    "model_alpha": Path("models/colab_resnet_alpha_norm_stable"),
    "model_fused": Path("models/colab_resnet_fused_norm_stable"),
    "validation": Path("data/validation/national_independent_1368.csv"),
}

# The raw feature table must ship with the repo clone.
if not paths["raw"].exists():
    raise FileNotFoundError(f"Missing required input: {paths['raw']}")

print("Config ready")



In [None]:
print("[5/11] Import Step 1 precomputed files (optional)")
# Step 1 outputs copied from Drive, identified by their `paths` keys.
PRECOMPUTED_KEYS = ["normalized_points", "be_sgpt", "sgpt_only", "sgpt_raw", "fused_feat", "fused_meta"]

if COPY_PRECOMPUTED_FROM_DRIVE:
    drive_dir = Path(DRIVE_PRECOMPUTED_DIR)
    if not drive_dir.exists():
        # The default DRIVE_PRECOMPUTED_DIR is a template (..._YYYYMMDD_HHMMSS); list the
        # Step 1 runs that exist next to it so the correct folder is easy to pick.
        parent = drive_dir.parent
        runs = sorted(p.name for p in parent.glob("step1_*") if p.is_dir()) if parent.exists() else []
        hint = f" (Step 1 runs found in {parent}: {runs})" if runs else ""
        raise FileNotFoundError(f"DRIVE_PRECOMPUTED_DIR not found: {drive_dir}{hint}")

    # Validate every source before copying anything, so a missing file cannot leave
    # data/processed holding a mix of freshly copied and stale inputs.
    sources = {key: drive_dir / paths[key].name for key in PRECOMPUTED_KEYS}
    missing_src = [str(src) for src in sources.values() if not src.exists()]
    if missing_src:
        raise FileNotFoundError(f"Missing precomputed file: {', '.join(missing_src)}")

    for key, src in sources.items():
        dst = paths[key]  # copy to exactly the path later cells read
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        print("Copied", src, "->", dst)
else:
    print("Using precomputed files already present in repo clone")

# These three are read directly by later cells, whichever branch ran above.
for req in [paths["normalized_points"], paths["fused_feat"], paths["fused_meta"]]:
    if not req.exists():
        raise FileNotFoundError(f"Missing required precomputed file: {req}")

print("Precomputed inputs ready")


In [None]:
print("[6/11] Write stable training configs")
# Clean model dirs for a fresh run, but only for models that are retrained below.
# Removing the directory of a disabled model would delete checkpoints that the
# checkpoint check and independent validation cells may still load.
for enabled, model_dir in ((RUN_ALPHA, paths["model_alpha"]), (RUN_FUSED, paths["model_fused"])):
    if enabled and model_dir.exists():
        shutil.rmtree(model_dir)

TARGETS = ["ph", "cec", "esp", "soc", "ca", "mg", "na"]

# Only the column names are needed to size the fused model's input layer.
fused_cols_df = pd.read_csv(paths["fused_feat"], nrows=1)
n_fused_features = len([c for c in fused_cols_df.columns if c.startswith("feat_")])
if RUN_FUSED and n_fused_features == 0:
    raise ValueError(
        f"No 'feat_' columns in {paths['fused_feat']}; the fused model would get input_dim=0"
    )


def _training_section(epochs, final_epochs, specialist_epochs, specialist_patience):
    """Return the `training` config block shared by the alpha and fused models.

    Only the epoch budgets and the specialist schedule differ between the two
    models; every other value comes from TRAINING or is fixed here.
    """
    return {
        "ensemble_size": TRAINING["ensemble_size"],
        "n_splits": TRAINING["n_splits"],
        "epochs": epochs,
        "final_epochs": final_epochs,
        "batch_size": TRAINING["batch_size"],
        "learning_rate": TRAINING["learning_rate"],
        "weight_decay": TRAINING["weight_decay"],
        "patience": TRAINING["patience"],
        "lr_patience": TRAINING["lr_patience"],
        "lr_factor": TRAINING["lr_factor"],
        "grad_clip": TRAINING["grad_clip"],
        "cv_strategy": "group_kfold",
        "loss_name": "weighted_huber",
        "huber_delta": 1.0,
        "esp_consistency_weight": 0.0,
        "target_weight_mode": "inverse_frequency",
        "sample_weight_mode": "rare_target_average",
        "auto_target_transforms": False,
        "target_transforms": {t: "identity" for t in TARGETS},
        "specialist_targets": ["cec", "esp", "soc"],
        "specialist_epochs": specialist_epochs,
        "specialist_patience": specialist_patience,
        "specialist_val_fraction": 0.2,
        "specialist_blend_weight": 0.4,
        "random_seed": TRAINING["random_seed"],
    }


alpha_cfg = {
    "model": {"input_dim": 64, "hidden_dim": 128, "num_blocks": 2, "dropout": 0.2, "activation": "silu"},
    "targets": ["pH", "CEC", "ESP", "SOC", "Ca", "Mg", "Na"],
    "training": _training_section(
        epochs=TRAINING["epochs_alpha"],
        final_epochs=TRAINING["final_epochs_alpha"],
        specialist_epochs=60 if FAST_MODE else 120,
        specialist_patience=15 if FAST_MODE else 20,
    ),
    "data": {"feature_prefix": "band_", "n_features": 64, "group_by": "latlon", "group_round": 4, "reference_data": None},
    "output": {"model_dir": str(paths["model_alpha"]), "metrics_dir": "results/metrics", "scaler_file": "scaler.pkl"},
}

fused_cfg = {
    "model": {"input_dim": n_fused_features, "hidden_dim": 192, "num_blocks": 3, "dropout": 0.2, "activation": "silu"},
    "targets": ["pH", "CEC", "ESP", "SOC", "Ca", "Mg", "Na"],
    "training": _training_section(
        epochs=TRAINING["epochs_fused"],
        final_epochs=TRAINING["final_epochs_fused"],
        specialist_epochs=80 if FAST_MODE else 150,
        specialist_patience=15 if FAST_MODE else 25,
    ),
    "data": {"feature_prefix": "feat_", "n_features": None, "group_by": "latlon", "group_round": 4, "reference_data": None},
    "output": {"model_dir": str(paths["model_fused"]), "metrics_dir": "results/metrics", "scaler_file": "scaler.pkl"},
}

for cfg_path, cfg in ((paths["cfg_alpha"], alpha_cfg), (paths["cfg_fused"], fused_cfg)):
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    # sort_keys=False keeps the YAML in the same key order as the dicts above.
    cfg_path.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
print("Wrote", paths["cfg_alpha"])
print("Wrote", paths["cfg_fused"])
print("Fused feature count:", n_fused_features)


In [None]:
print("[7/11] Train alpha model")
if not RUN_ALPHA:
    print("Skipped alpha training")
else:
    # CLI flags for the ensemble training script, in the order they are passed.
    cli_args = {
        "--data": paths["normalized_points"],
        "--config": paths["cfg_alpha"],
        "--output": paths["model_alpha"],
        "--cv-strategy": "group_kfold",
        "--seed": TRAINING["random_seed"],
    }
    cmd = ["python", "scripts/train_resnet_ensemble.py"]
    for flag, value in cli_args.items():
        cmd += [flag, str(value)]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)


In [None]:
print("[8/11] Train fused model")
if not RUN_FUSED:
    print("Skipped fused training")
else:
    # CLI flags for the ensemble training script, in the order they are passed.
    cli_args = {
        "--data": paths["fused_feat"],
        "--config": paths["cfg_fused"],
        "--output": paths["model_fused"],
        "--cv-strategy": "group_kfold",
        "--seed": TRAINING["random_seed"],
    }
    cmd = ["python", "scripts/train_resnet_ensemble.py"]
    for flag, value in cli_args.items():
        cmd += [flag, str(value)]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)


In [None]:
print("[9/11] Validate checkpoints finite")
# Re-imported here (also imported in the GPU-check cell) so this cell does not depend on it.
import torch


def check_model_dir(model_dir: Path) -> None:
    """Fail loudly if any ``model_*.pth`` checkpoint in *model_dir* holds non-finite weights.

    Checkpoints are loaded on CPU. The state dict is read from the
    ``"model_state_dict"`` entry when present; a checkpoint saved as a bare
    state dict is accepted as well.

    Raises:
        FileNotFoundError: *model_dir* contains no ``model_*.pth`` files.
        RuntimeError: a checkpoint is not a dict, contains no tensors, or has a
            tensor with NaN or +/-inf values.
    """
    model_files = sorted(model_dir.glob("model_*.pth"))
    if not model_files:
        raise FileNotFoundError(f"No model_*.pth found in {model_dir}")

    print("Checking", model_dir)
    for mf in model_files:
        # weights_only=False unpickles arbitrary objects; acceptable only because these
        # checkpoints are produced by this project's own training script.
        ckpt = torch.load(mf, map_location="cpu", weights_only=False)
        if not isinstance(ckpt, dict):
            raise RuntimeError(f"Unexpected checkpoint type in {mf.name}: {type(ckpt).__name__}")
        state = ckpt.get("model_state_dict", ckpt)
        tensors = {k: v for k, v in state.items() if torch.is_tensor(v)}
        if not tensors:
            # Without this, a checkpoint using an unexpected key would pass vacuously.
            raise RuntimeError(f"No tensors found in {mf.name}")
        bad = [k for k, v in tensors.items() if not torch.isfinite(v).all()]
        if bad:
            raise RuntimeError(f"Invalid weights in {mf.name}: {bad[:5]}")
        print("  OK", mf.name)

# Check each model that was trained in this session, alpha first.
for enabled, model_dir in ((RUN_ALPHA, paths["model_alpha"]), (RUN_FUSED, paths["model_fused"])):
    if enabled:
        check_model_dir(model_dir)
print("All checkpoints finite")


In [None]:
print("[10/11] CV summary + independent validation")
# Used by summarize_ensemble_metrics to split "<target>_<metric>" column names.
import re


def summarize_ensemble_metrics(csv_path: Path, model_label: str) -> pd.DataFrame:
    """Collapse a per-member metrics CSV into one mean/std row per (target, metric).

    Columns named ``<target>_<r2|rmse|mae>`` are summarised (target lower-cased,
    population std); every other column is ignored. When *csv_path* does not
    exist, or has no matching columns, a single placeholder row with NaN
    statistics is returned instead.
    """
    def _placeholder(label: str) -> pd.DataFrame:
        return pd.DataFrame([{"model": model_label, "target": label, "metric": "", "mean": np.nan, "std": np.nan}])

    if not csv_path.exists():
        return _placeholder("(missing)")

    metrics_df = pd.read_csv(csv_path)
    column_pattern = re.compile(r"^(.*)_(r2|rmse|mae)$")
    records = []
    for column in metrics_df.columns:
        hit = column_pattern.match(column)
        if hit is None:
            continue
        target, metric = hit.groups()
        values = metrics_df[column]
        records.append({
            "model": model_label,
            "target": target.lower(),
            "metric": metric,
            "mean": float(values.mean()),
            "std": float(values.std(ddof=0)),
        })

    return pd.DataFrame(records) if records else _placeholder("(no per-target metrics)")

alpha_cv = summarize_ensemble_metrics(paths["model_alpha"] / "ensemble_metrics.csv", "alpha_norm_stable")
fused_cv = summarize_ensemble_metrics(paths["model_fused"] / "ensemble_metrics.csv", "fused_norm_stable")
cv_summary = pd.concat([alpha_cv, fused_cv], ignore_index=True)
display(cv_summary)

cv_out = Path("results/metrics/colab_cv_summary_train_only.csv")
cv_out.parent.mkdir(parents=True, exist_ok=True)
cv_summary.to_csv(cv_out, index=False)
print("Saved", cv_out)

if RUN_INDEPENDENT_VALIDATION:
    from src.models.ensemble import SoilEnsemble
    from src.evaluation.metrics import compute_metrics

    val_path = paths["validation"]
    if not val_path.exists():
        raise FileNotFoundError(f"Missing validation file: {val_path}")
    if not paths["model_alpha"].exists():
        raise FileNotFoundError(f"Missing alpha model dir: {paths['model_alpha']} (set RUN_ALPHA = True)")

    val_df = pd.read_csv(val_path)
    # The validation CSV may name embedding bands A00..A63; expose them as band_0..band_63.
    # Assign them in one call instead of 64 single-column inserts (which fragment the frame).
    band_aliases = {f"band_{i}": val_df[f"A{i:02d}"] for i in range(64) if f"A{i:02d}" in val_df.columns}
    val_df = val_df.assign(**band_aliases)

    feature_cols = [f"band_{i}" for i in range(64)]
    missing = [c for c in feature_cols if c not in val_df.columns]
    if missing:
        raise ValueError(f"Validation missing feature columns: {missing[:5]}")

    model = SoilEnsemble(paths["model_alpha"])
    pred_norm, _ = model.predict_batch(val_df[feature_cols].values.astype(np.float32))

    targets = ["ph", "cec", "esp", "soc", "ca", "mg", "na"]
    train_raw = pd.read_csv(paths["raw"], usecols=targets)
    mu = train_raw.mean(axis=0)
    # Population std; constant columns get 1.0 to avoid a zero scale.
    sigma = train_raw.std(axis=0, ddof=0).replace(0.0, 1.0)

    # Model target names may be cased differently from the lowercase CSV columns
    # (the training config lists "pH", "CEC", ...), so resolve lookups case-insensitively.
    target_keys = {t: t.lower() for t in model.target_names}
    unknown = sorted({k for k in target_keys.values() if k not in mu.index})
    if unknown:
        raise KeyError(f"Model targets without training statistics in {paths['raw']}: {unknown}")

    # Undo the per-target standardisation: raw = norm * sigma + mu.
    pred_norm_df = pd.DataFrame(pred_norm, columns=model.target_names)
    pred_raw_df = pred_norm_df.copy()
    for t, k in target_keys.items():
        pred_raw_df[t] = pred_norm_df[t] * float(sigma[k]) + float(mu[k])

    # Validation columns used as ground truth; targets without a mapping, or whose
    # column is absent from the file, stay NaN.
    y_true_raw = pd.DataFrame(np.nan, index=val_df.index, columns=model.target_names)
    truth_map = {"ph": "ph", "cec": "cec_cmolkg", "esp": "esp_pct", "na": "na_cmolkg"}
    for t, k in target_keys.items():
        src_col = truth_map.get(k)
        if src_col is not None and src_col in val_df.columns:
            y_true_raw[t] = pd.to_numeric(val_df[src_col], errors="coerce")

    metrics = compute_metrics(
        y_true=y_true_raw[model.target_names].values,
        y_pred=pred_raw_df[model.target_names].values,
        target_names=model.target_names,
    )

    indep_df = pd.DataFrame(metrics).T.reset_index().rename(columns={"index": "target"}) if metrics else pd.DataFrame()
    display(indep_df)

    indep_out = Path("results/metrics/colab_independent_validation_alpha_train_only.csv")
    indep_out.parent.mkdir(parents=True, exist_ok=True)
    indep_df.to_csv(indep_out, index=False)
    print("Saved", indep_out)
else:
    print("Skipped independent validation")


In [None]:
print("[11/11] Save outputs to Drive")
from datetime import datetime

# Each run uploads into its own timestamped folder.
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
drive_out = Path("/content/drive/MyDrive/soil-resnet-outputs") / f"step2_resnet_train_{stamp}"
drive_out.mkdir(parents=True, exist_ok=True)

metric_files = [
    Path("results/metrics/colab_cv_summary_train_only.csv"),
    Path("results/metrics/colab_independent_validation_alpha_train_only.csv"),
]
# Single files: both configs plus whichever metric CSVs were produced.
for artifact in [paths["cfg_alpha"], paths["cfg_fused"], *metric_files]:
    if not artifact.exists():
        continue
    shutil.copy2(artifact, drive_out / artifact.name)

# Whole model directories (checkpoints and their side files).
for model_dir in (paths["model_alpha"], paths["model_fused"]):
    if model_dir.exists():
        shutil.copytree(model_dir, drive_out / model_dir.name)

print("Saved outputs to:", drive_out)
