# In [None]:
#!/usr/bin/env julia
using ClimaLand
using ClimaLand.Domains: Column
using ClimaLand.Canopy
using ClimaLand.Simulations
import ClimaLand.Parameters as LP
import ClimaDiagnostics
import EnsembleKalmanProcesses as EKP
import EnsembleKalmanProcesses.ParameterDistributions as PD
using NCDatasets
using Statistics, Dates, Random, Logging

# ---------- Config ----------
# Floating-point type used for model construction and grid-point coordinates.
const FT = Float32

# Location of the RS LE product (here: FLUXCOM monthly or 8-daily; adapt to your file).
const FLUXCOM_PATH = "/path/to/FLUXCOM_LE_0p5deg_2010.nc"

# Seeded generator so ensemble draws are reproducible.
rng = Random.MersenneTwister(1234)

# Simulation window (match your RS product availability).
start_date = DateTime(2010, 5)
stop_date = DateTime(2010, 7)
spinup = Day(10)     # leading part of each run that is dropped from the model output
Δt = 60.0 * 60.0     # 3600 s = 1hr; match ERA5-Land cadence

# Target size of the sampled RS grid; keeps first runs reasonable.
MAX_POINTS = 2_000   # increase as you scale up

# ---------- Load RS grid & data ----------
"""
    GridPoint{T}

Minimal geographic record for one remote-sensing grid cell: its coordinates
plus the indices needed to look the cell up again in the RS array.
"""
struct GridPoint{T}
    # Coordinates of the cell, stored in the floating-point type `T`.
    lat::T
    lon::T
    # RS grid row index.
    i::Int
    # RS grid column index.
    j::Int
end

"""
    load_fluxcom_subset(path; max_points=MAX_POINTS)

Open the NetCDF file at `path` and return a `Vector{GridPoint{FT}}` with a
strided subsample of the grid cells whose `"LE"` value at the first time step
is present and finite.

# Keywords
- `max_points`: approximate upper bound on the number of sampled cells. The
  same stride is used along both horizontal axes, so the count can exceed the
  target slightly because of rounding at the grid edges.

# Throws
- `ArgumentError` if `max_points` is not positive or if `"LE"` is not a 3-D
  variable with dimensions named `"time"`, `"lat"` and `"lon"`.
"""
function load_fluxcom_subset(path::AbstractString; max_points=MAX_POINTS)
    max_points > 0 ||
        throw(ArgumentError("max_points must be positive, got $max_points"))

    # The do-block form closes the dataset even when a read below throws.
    return NCDataset(path) do ds
        lat = vec(ds["lat"][:])
        lon = vec(ds["lon"][:])
        # Example var name; change to your FLUXCOM variable (e.g., "LE" or "LE_FluxCom")
        LE = ds["LE"]

        # Locate each axis by name so any ordering of (time, lat, lon) works.
        dnames = NCDatasets.dimnames(LE)
        time_dim = findfirst(==("time"), dnames)
        lat_dim = findfirst(==("lat"), dnames)
        lon_dim = findfirst(==("lon"), dnames)
        if length(dnames) != 3 || any(isnothing, (time_dim, lat_dim, lon_dim))
            throw(ArgumentError(
                "expected LE with dimensions named time, lat and lon; got $(dnames)",
            ))
        end

        # The stride is applied along BOTH axes, so the number of sampled cells
        # shrinks with stride^2; hence the square root.
        ncells = length(lat) * length(lon)
        stride = max(1, ceil(Int, sqrt(ncells / max_points)))

        points = GridPoint{FT}[]
        for i in 1:stride:length(lat), j in 1:stride:length(lon)
            # Test the first time slice for validity.
            idx = ntuple(d -> d == time_dim ? 1 : (d == lat_dim ? i : j), 3)
            val = LE[idx...]
            # Masked cells may come back as `missing`; `isfinite(missing)` is not a Bool.
            if !ismissing(val) && isfinite(val)
                push!(points, GridPoint{FT}(FT(lat[i]), FT(lon[j]), i, j))
            end
        end
        points
    end
end

# Sampled RS grid cells that every later stage of the script iterates over.
POINTS = load_fluxcom_subset(FLUXCOM_PATH; max_points = MAX_POINTS)

# ---------- Forcing/Domain builders ----------
# Parameter dictionary passed to the ClimaLand constructors further down.
toml_dict = LP.create_toml_dict(FT)

"""
    build_column_forcing(lat, lon, start_date)

Provide atmosphere & radiation forcing for a single column.
Replace this stub with your ERA5-Land loader that returns the ClimaLand
`prescribed` forcing structs expected by LandModel.
"""
function build_column_forcing(lat::FT, lon::FT, start_date::DateTime)
    # Placeholder: always raises until a real loader is written.
    # TODO: implement ERA5-Land loader -> (atmos, radiation) with hourly cadence
    # (callers in this file read `forcing.atmos` and `forcing.radiation`).
    message = "Implement ERA5-Land loader for (lat=$lat, lon=$lon)."
    throw(ErrorException(message))
end

"""
    build_column(lat, lon)

Creates a 2-m soil column at (lat, lon) and returns (domain, forcing, LAI).
"""
function build_column(lat::FT, lon::FT)
    # Vertical extent of the column: from -2 up to the surface at 0.
    zlim = (FT(-2), FT(0))
    # Note the argument order expected by `Column`: (lon, lat).
    longlat = (lon, lat)
    domain = Column(; zlim, nelements = 10, longlat)

    # Forcing and LAI both use the global simulation dates.
    forcing = build_column_forcing(lat, lon, start_date)
    surface_space = domain.space.surface
    LAI = ClimaLand.Canopy.prescribed_lai_modis(surface_space, start_date, stop_date)

    return (; domain, forcing, LAI)
end

# Build each column configuration once, keyed by the cell's (i, j) RS grid
# indices, so repeated forward runs can reuse it.
@info "Building $(length(POINTS)) column configs…"
COLUMN_CFGS = Dict{Tuple{Int,Int}, Any}()
foreach(POINTS) do pt
    COLUMN_CFGS[(pt.i, pt.j)] = build_column(pt.lat, pt.lon)
end

# ---------- Your Π-mapping (unchanged) ----------
# θ = [αR, βR, αF, βF, αT, βTs, αS, βSs]
"""
    unpack_params(θ)

Split the 8-element parameter vector `θ = [αR, βR, αF, βF, αT, βTs, αS, βSs]`
into a `NamedTuple` with those field names, in that order.

# Throws
- `ArgumentError` if `θ` does not have exactly 8 entries.
"""
function unpack_params(θ::AbstractVector)
    # Explicit check rather than `@assert`: assertions may be disabled, and
    # destructuring alone would silently ignore any extra trailing entries.
    length(θ) == 8 ||
        throw(ArgumentError("θ must have exactly 8 entries, got $(length(θ))"))
    αR, βR, αF, βF, αT, βTs, αS, βSs = θ
    return (; αR, βR, αF, βF, αT, βTs, αS, βSs)
end

"""
    build_canopy(conductance, domain, forcing, LAI)

Assemble a `CanopyModel{FT}` on the surface domain derived from `domain`,
using the supplied stomatal `conductance` model, the `atmos`/`radiation`
entries of `forcing`, prognostic ground conditions, and a Farquhar
photosynthesis model parameterised from the CLM defaults.
"""
function build_canopy(conductance, domain, forcing, LAI)
    surface_domain = ClimaLand.Domains.obtain_surface_domain(domain)

    # CLM defaults supply the C3 flag and Vcmax25 for the Farquhar parameters.
    clm = ClimaLand.Canopy.clm_photosynthesis_parameters(surface_domain.space.surface)
    photosynthesis = FarquharModel{FT}(
        FarquharParameters(toml_dict; is_c3 = clm.is_c3, Vcmax25 = clm.Vcmax25),
    )

    # Canopy forcing bundle: atmosphere, radiation, and ground conditions.
    ground = ClimaLand.PrognosticGroundConditions{FT}()
    canopy_forcing = (; atmos = forcing.atmos, radiation = forcing.radiation, ground)

    return ClimaLand.Canopy.CanopyModel{FT}(
        surface_domain,
        canopy_forcing,
        LAI,
        toml_dict;
        photosynthesis,
        prognostic_land_components = (:canopy, :snow, :soil, :soilco2),
        conductance,
    )
end

"""
    model_LE_timeseries(θ, cfg) -> Vector{Float64}

Run one column simulation with the uSPAC Π-group parameters in `θ` and return
the simulated latent heat flux after the spin-up period.

# Arguments
- `θ`: 8-element parameter vector, unpacked by `unpack_params`.
- `cfg`: column configuration with `domain`, `forcing` and `LAI` fields.

# Returns
The `"lhf_1h_average"` diagnostic restricted to
`start_date + spinup <= t < stop_date`, converted to `Float64`.

# Throws
- `ErrorException` if no diagnostic output falls inside that window.
"""
function model_LE_timeseries(θ::AbstractVector, cfg)::Vector{Float64}
    pθ = unpack_params(θ)
    ΓR = (FT(pθ.αR), FT(pθ.βR))
    ΓF = (FT(pθ.αF), FT(pθ.βF))
    ΓT = (FT(pθ.αT), FT(pθ.βTs))
    ΓS = (FT(pθ.αS), FT(pθ.βSs))

    uspac_pars = ClimaLand.Canopy.uSPACPiParameters{FT}(; ΓR, ΓF, ΓT, ΓS)
    conductance = ClimaLand.Canopy.uSPACConductancePi{FT}(uspac_pars)

    canopy = build_canopy(conductance, cfg.domain, cfg.forcing, cfg.LAI)
    land = LandModel{FT}(cfg.forcing, cfg.LAI, toml_dict, cfg.domain, Δt; canopy)

    # The simulation calls the initial-condition hook WITH arguments
    # (documented as `set_ic!(Y, p, t0, model)`), so the placeholder must accept
    # them; a zero-argument closure would raise a MethodError.
    # NOTE(review): this no-op sets nothing — provide a real IC before trusting output.
    set_ic! = (args...) -> nothing
    diagnostics = ClimaLand.default_diagnostics(
        land, start_date;
        output_writer = ClimaDiagnostics.Writers.DictWriter(),
        output_vars   = ["lhf"],   # ensure this is total LE, not transpiration-only
        reduction_period = :hourly,
    )

    sim = ClimaLand.Simulations.LandSimulation(
        start_date, stop_date, Second(Δt), land;
        set_ic! = set_ic!, updateat = Second(Δt), user_callbacks=(), diagnostics
    )
    solve!(sim)

    (times, le) = ClimaLand.Diagnostics.diagnostic_as_vectors(
        sim.diagnostics[1].output_writer, "lhf_1h_average" # name may differ; check your diagnostics
    )

    # Drop the spinup window and anything at or after stop_date.
    idx0 = findfirst(>=(start_date + spinup), times)
    idx1 = findlast(<(stop_date), times)
    if isnothing(idx0) || isnothing(idx1) || idx0 > idx1
        error(
            "No \"lhf_1h_average\" output between $(start_date + spinup) and " *
            "$(stop_date); check spinup, the simulation window and the diagnostic name.",
        )
    end
    return Float64.(le[idx0:idx1])
end

# ---------- Global forward operator ----------
"""
    G_global(θ) -> vector

Runs all columns (subset of RS grid) and returns a single vector that
concatenates each column's model LE time series in `POINTS` order (time varies
fastest within a column), the same stacking that `RS_vector` uses below.
"""
function G_global(θ::AbstractVector)
    # Option A (simple, serial). Replace with Threads.@threads or pmap.
    # The typed comprehension keeps the element type concrete even when POINTS is empty.
    outs = Vector{Float64}[
        model_LE_timeseries(θ, COLUMN_CFGS[(p.i, p.j)]) for p in POINTS
    ]
    # `reduce(vcat, …)` concatenates in one pass without splatting thousands of
    # arguments; the guard keeps the empty case a `Vector{Float64}`.
    return isempty(outs) ? Float64[] : reduce(vcat, outs)
end

# ---------- Build observation vector from FLUXCOM ----------
"""
    RS_vector() -> vector

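Loads the FLUXCOM LE, extracts the same (i,j) cells as POINTS, and stacks the
per-cell time series in POINTS order. The result is only comparable with
G_global once the time range (currently every time step in the file) and the
time cadence have been matched to the model output.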
"""
function RS_vector()
    # The do-block form closes the dataset even when a read below throws.
    return NCDataset(FLUXCOM_PATH) do ds
        # You may need to convert RS timestep (8-day or monthly) to hourly with simple
        # replication or, better, compare on a *daily* mean. Here we compare *daily means*.
        # 1) read RS daily or aggregate to daily; 2) aggregate model to daily too.
        # For brevity, we assume RS provides daily LE:
        LE = ds["LE"]    # W m^-2

        # Locate each axis by name, as load_fluxcom_subset does, so the (i, j)
        # indices stored in POINTS address the lat/lon axes whatever the file's order.
        dnames = NCDatasets.dimnames(LE)
        time_dim = findfirst(==("time"), dnames)
        lat_dim = findfirst(==("lat"), dnames)
        lon_dim = findfirst(==("lon"), dnames)
        if length(dnames) != 3 || any(isnothing, (time_dim, lat_dim, lon_dim))
            throw(ArgumentError(
                "expected LE with dimensions named time, lat and lon; got $(dnames)",
            ))
        end

        # Map your start/stop to RS indices
        # (Implement your calendar handling here; omitted for brevity.)
        i0, i1 = 1, size(LE, time_dim)  # TODO: select matching range

        # Extract and stack, one time series per point.
        outs = Vector{Float64}[]
        for p in POINTS
            idx = ntuple(d -> d == time_dim ? (i0:i1) : (d == lat_dim ? p.i : p.j), 3)
            series = LE[idx...]
            # Only the first time step was screened when POINTS was built, so
            # later gaps are possible; fail with the offending cell named.
            if any(ismissing, series)
                error("FLUXCOM LE has missing values at grid cell (i=$(p.i), j=$(p.j)).")
            end
            push!(outs, Float64.(vec(series)))
        end
        isempty(outs) ? Float64[] : reduce(vcat, outs)
    end
end

# Observation vector stacked point by point (see RS_vector).
observations = RS_vector()
N_obs = length(observations)

# Covariance: start simple; consider per-gridcell weights or block-diag later
noise_cov = 0.10 * EKP.I(N_obs)  # tune to RS uncertainty

# ---------- Priors & EKP ----------
# NOTE(review): `names` shadows `Base.names` in this module; kept for
# compatibility with anything that reads this global.
names = ["alpha_R","beta_R","alpha_F","beta_F","alpha_T","beta_Ts","alpha_S","beta_Ss"]
prior = PD.combine_distributions([PD.constrained_gaussian(n, 0.0, 0.30, -Inf, Inf) for n in names])
N_ens, N_iter = 80, 4

ekp = EKP.EnsembleKalmanProcess(
    EKP.construct_initial_ensemble(rng, prior, N_ens),
    observations,
    noise_cov,
    EKP.Inversion();
    scheduler = EKP.DataMisfitController(terminate_at=Inf, on_terminate="continue"),
    rng,
)

@info "Starting global EKP: points=$(length(POINTS))"
for it in 1:N_iter
    @info "EKP iteration $it of $N_iter"
    θs = EKP.get_ϕ_final(prior, ekp)
    # One forward run per ensemble member; each output becomes a column of G_ens.
    G_ens = reduce(hcat, map(G_global, eachcol(θs)))
    # Fail early with a readable message when the model output and the
    # observations are stacked at different lengths (e.g. different cadence).
    size(G_ens, 1) == N_obs || throw(DimensionMismatch(
        "forward map has $(size(G_ens, 1)) entries per member but there are " *
        "$N_obs observations; align the time window/cadence of G_global and RS_vector",
    ))
    EKP.update_ensemble!(ekp, G_ens)
end

# (Optional) visualize parameter evolution & misfit
# EKP.Visualize.plot_ϕ_over_iters(...); EKP.Visualize.plot_error_over_iters(...)
