# In [None]:
#!/usr/bin/env julia
# =============================================================================
# Global uSPAC calibration against tower LE (multi-site) using EKP
# Parameters (8 total, shared across all sites):
#   θ = [
#     "alpha_R", "beta_R",    # ΠR ~ aridity
#     "alpha_F", "beta_F",    # ΠF ~ aridity
#     "alpha_T", "beta_Ts",   # ΠT ~ sand
#     "alpha_S", "beta_Ss"    # ΠS ~ sand
# ]
# Forward model:
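#   For each site: build a LandModel with uSPACConductancePi, run it, and compute the
#   24-value hourly-mean diurnal cycle of latent heat flux (LE) after spin-up.
#   Concatenate all sites' diurnal cycles into one long vector (24 * Nsites), in the
#   same site order as the concatenated observations.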
# =============================================================================

using ClimaLand
using ClimaLand.Domains: Column
using ClimaLand.Canopy
using ClimaLand.Simulations
import ClimaLand.FluxnetSimulations as FNS
import ClimaLand.Parameters as LP
import ClimaLand.LandSimVis as LandSimVis
import ClimaDiagnostics
import ClimaUtilities.TimeManager as TM
import ClimaUtilities.TimeManager: date
import EnsembleKalmanProcesses as EKP
import EnsembleKalmanProcesses.ParameterDistributions as PD
using CairoMakie
using Statistics
using Logging
import Random
using Dates

# ---------------- Helpers ----------------

"""
    get_diurnal_average(var, start_date, spinup_date, stop_date) -> Vector

Return the mean diurnal cycle of a time series as 24 values, one per hour of day
(index 1 ↔ hour 0, …, index 24 ↔ hour 23).

`var` is a `(times, data)` pair of equal-length vectors. If `times` already holds
`DateTime`s (any `AbstractVector{<:DateTime}`) it is used directly; otherwise each entry
is converted with `date`. Only samples with `spinup_date <= t < stop_date` are kept
(assumes `times` is sorted ascending — the window is taken between the first and last
matching indices). Each hour bin is the mean of the kept samples whose timestamp falls in
that hour; a bin with no samples yields `NaN`.

`start_date` is not used; it is kept so existing call sites remain valid.

Throws `ArgumentError` if no sample lies in `[spinup_date, stop_date)`.
"""
function get_diurnal_average(var, start_date, spinup_date, stop_date)
    (times, data) = var
    model_dates = times isa AbstractVector{<:DateTime} ? times : date.(times)
    first_idx = findfirst(>=(spinup_date), model_dates)
    last_idx = findlast(<(stop_date), model_dates)
    # Fail loudly instead of erroring on `nothing:idx` or returning 24 NaNs.
    if isnothing(first_idx) || isnothing(last_idx) || first_idx > last_idx
        throw(ArgumentError(
            "no samples in the averaging window [$spinup_date, $stop_date)",
        ))
    end
    window = first_idx:last_idx
    hours = Dates.hour.(model_dates[window])
    window_data = data[window]
    return [mean(window_data[hours .== h]) for h in 0:23]
end

"""
    get_lhf(simulation)

Fetch the `"lhf_30m_average"` diagnostic from the output writer of the first
diagnostics handler attached to `simulation`, via
`ClimaLand.Diagnostics.diagnostic_as_vectors`. Callers here consume the result as a
`(times, data)` pair.
"""
function get_lhf(simulation)
    writer = simulation.diagnostics[1].output_writer
    return ClimaLand.Diagnostics.diagnostic_as_vectors(writer, "lhf_30m_average")
end

# ---------------- Global configuration ----------------

# Float type shared by the model, the site configs and the parameter conversions.
const FT = Float32

# Seeded RNG: draws the initial EKP ensemble and is handed to the EKP process.
rng_seed = 1234
rng = Random.MersenneTwister(rng_seed)

# Simulation window. It must lie inside every site's forcing record.
start_date = DateTime(2010, 5, 1)
stop_date = DateTime(2010, 7, 1)
Δt = 450.0        # model time step, in seconds
spinup = Day(20)  # initial period excluded from the diurnal comparison

# Sites to calibrate against; forcing must be available for each one.
# Other candidates: "US-Var", "US-SRM", "DE-Hai", ...
const SITES = [
    "US-MOz",
]

# ---------------- Site bundle build ----------------

toml_dict = LP.create_toml_dict(FT)

"""
    SiteConfig{FT, D, F, L}

Per-site inputs assembled once (by `build_site_config`) and reused by every
forward-model run at that site.

# Fields
- `site_id::String`: site identifier as listed in `SITES` (e.g. "US-MOz").
- `site_id_val::Symbol`: `FNS.replace_hyphen(site_id)`; wrapped in `Val` to select the
  per-site FNS lookups.
- `lat`, `long`: site coordinates from `FNS.get_location`.
- `time_offset`: from `FNS.get_location`; passed to the forcing, initial-condition and
  comparison-data loaders.
- `atmos_h`: from `FNS.get_fluxtower_height` (presumably the measurement height in m —
  confirm).
- `domain`: the site's soil column domain.
- `forcing`: prescribed forcing for the site.
- `LAI`: prescribed leaf area index for the site.

`domain`, `forcing` and `LAI` are typed by the parameters `D`, `F`, `L` rather than
`Any`, so field access is concretely typed.
"""
struct SiteConfig{FT, D, F, L}
    site_id::String
    site_id_val::Symbol
    lat::FT
    long::FT
    time_offset::FT
    atmos_h::FT
    domain::D
    forcing::F
    LAI::L
end

# Backward-compatible constructor: `SiteConfig{FT}(...)` infers D, F and L from the
# arguments; scalar fields are converted to FT by the default constructor.
function SiteConfig{FT}(
    site_id, site_id_val, lat, long, time_offset, atmos_h,
    domain::D, forcing::F, LAI::L,
) where {FT, D, F, L}
    return SiteConfig{FT, D, F, L}(
        site_id, site_id_val, lat, long, time_offset, atmos_h, domain, forcing, LAI,
    )
end

"""
    build_site_config(site_id::String) -> SiteConfig{FT}

Gather everything a forward-model run needs for one flux-tower site. This covers:
- the location and tower height from `FNS.get_location` / `FNS.get_fluxtower_height`;
- a 2 m, 10-element soil `Column` placed at the site;
- the site's prescribed forcing starting at `start_date`;
- MODIS LAI for `start_date` to `stop_date`.
"""
function build_site_config(site_id::String)::SiteConfig{FT}
    sym = FNS.replace_hyphen(site_id)
    loc = FNS.get_location(FT, Val(sym))
    lat, long, time_offset = loc.lat, loc.long, loc.time_offset
    atmos_h = FNS.get_fluxtower_height(FT, Val(sym)).atmos_h

    # Soil column from z = -2 to z = 0 with 10 elements, located at this site.
    domain = Column(; zlim = (FT(-2), FT(0)), nelements = 10, longlat = (long, lat))

    forcing = FNS.prescribed_forcing_fluxnet(
        site_id, lat, long, time_offset, atmos_h, start_date, toml_dict, FT,
    )
    LAI = ClimaLand.Canopy.prescribed_lai_modis(
        domain.space.surface, start_date, stop_date,
    )

    return SiteConfig{FT}(
        site_id, sym, lat, long, time_offset, atmos_h, domain, forcing, LAI,
    )
end

# One SiteConfig per entry of SITES (same order), reused by every forward run and by
# the observation loader.
SITE_CFGS = map(build_site_config, SITES)

# ---------------- Parameters ----------------

"""
    unpack_params(θ::AbstractVector) -> NamedTuple

Name the 8 calibrated parameters, which must appear in the order
`θ = [αR, βR, αF, βF, αT, βTs, αS, βSs]`.

Returns `(; αR, βR, αF, βF, αT, βTs, αS, βSs)`. Throws `ArgumentError` if `θ` does not
have exactly 8 entries.
"""
function unpack_params(θ::AbstractVector)
    # Explicit check instead of @assert, which may be compiled out.
    length(θ) == 8 ||
        throw(ArgumentError("expected 8 parameters, got $(length(θ))"))
    αR, βR, αF, βF, αT, βTs, αS, βSs = θ
    return (; αR, βR, αF, βF, αT, βTs, αS, βSs)
end

# ---------------- Per-site model ----------------

"""
    model(θ::AbstractVector, cfg::SiteConfig{FT})

Build and run the land model for one site using the uSPAC conductance parameters `θ`
(ordered as in `unpack_params`).

The steps are:
1. Map `θ` to the four Γ pairs of `uSPACPiParameters`.
2. Build a `CanopyModel` with that `uSPACConductancePi` conductance and a Farquhar
   photosynthesis model (C3 flag and Vcmax25 from `clm_photosynthesis_parameters`).
3. Wrap it in a `LandModel` on `cfg.domain`.
4. Attach a half-hourly `"lhf"` diagnostic written to an in-memory `DictWriter`.
5. Integrate from `start_date` to `stop_date` with step `Δt`.

Returns the solved `LandSimulation`; read LE from it with `get_lhf`.
"""
function model(θ::AbstractVector, cfg::SiteConfig{FT})
    p = unpack_params(θ)

    # uSPAC Π-group parameters; each Γ is an (α, β) pair converted to FT.
    uspac = ClimaLand.Canopy.uSPACPiParameters{FT}(;
        ΓR = FT.((p.αR, p.βR)),   # aridity → ΠR
        ΓF = FT.((p.αF, p.βF)),   # aridity → ΠF
        ΓT = FT.((p.αT, p.βTs)),  # sand    → ΠT
        ΓS = FT.((p.αS, p.βSs)),  # sand    → ΠS
    )
    conductance = ClimaLand.Canopy.uSPACConductancePi{FT}(uspac)
    coupled_components = (:canopy, :snow, :soil, :soilco2)

    # Canopy domain: the surface of the site's column domain.
    surface = ClimaLand.Domains.obtain_surface_domain(cfg.domain)
    canopy_inputs = (;
        atmos = cfg.forcing.atmos,
        radiation = cfg.forcing.radiation,
        ground = ClimaLand.PrognosticGroundConditions{FT}(),
    )

    clm = ClimaLand.Canopy.clm_photosynthesis_parameters(surface.space.surface)
    photosynthesis = FarquharModel{FT}(
        FarquharParameters(toml_dict; is_c3 = clm.is_c3, Vcmax25 = clm.Vcmax25),
    )

    canopy = ClimaLand.Canopy.CanopyModel{FT}(
        surface, canopy_inputs, cfg.LAI, toml_dict;
        photosynthesis = photosynthesis,
        prognostic_land_components = coupled_components,
        conductance = conductance,
    )

    land = LandModel{FT}(cfg.forcing, cfg.LAI, toml_dict, cfg.domain, Δt; canopy = canopy)
    ic_setter = FNS.make_set_fluxnet_initial_conditions(
        cfg.site_id, start_date, cfg.time_offset, land,
    )

    diagnostics = ClimaLand.default_diagnostics(
        land, start_date;
        output_writer = ClimaDiagnostics.Writers.DictWriter(),
        output_vars = ["lhf"],
        reduction_period = :halfhourly,
    )

    dt_period = Second(Δt)
    simulation = Simulations.LandSimulation(
        start_date, stop_date, dt_period, land;
        set_ic! = ic_setter,
        updateat = dt_period,
        user_callbacks = (),
        diagnostics = diagnostics,
    )
    solve!(simulation)
    return simulation
end

# ---------------- Global forward operator ----------------

"""
    G_multi(θ::AbstractVector) -> Vector{Float64}

Forward operator for the multi-site calibration. It runs `model(θ, cfg)` for every
`cfg` in `SITE_CFGS` with the same shared parameters `θ`. It then returns, concatenated
in site order, each site's 24 hourly-mean LE values over
`[start_date + spinup, stop_date)`.

The result has length `24 * length(SITE_CFGS)` (empty, but still `Vector{Float64}`, if
there are no sites). It follows the same site order as `load_obs_diurnals()`.
"""
function G_multi(θ::AbstractVector)
    out = Float64[]
    sizehint!(out, 24 * length(SITE_CFGS))
    for cfg in SITE_CFGS
        sim = model(θ, cfg)
        diurn = get_diurnal_average(get_lhf(sim), start_date, start_date + spinup, stop_date)
        append!(out, diurn)  # converts each value to Float64
    end
    return out
end

# ---------------- Observations (concatenate across sites) ----------------

"""
    load_obs_diurnals() -> Vector{Float64}

Load each site's tower comparison data (`FNS.get_comparison_data`). For every site it
computes the 24 hourly-mean observed LE values over `[start_date + spinup, stop_date)`,
then concatenates them in `SITE_CFGS` order, the same order as `G_multi`.

Throws an error naming the site if any observed hourly mean is not finite (e.g. an
hour with no samples yields `NaN`), since a NaN observation would corrupt every EKP
update.
"""
function load_obs_diurnals()
    out = Float64[]
    sizehint!(out, 24 * length(SITE_CFGS))
    for cfg in SITE_CFGS
        ds = FNS.get_comparison_data(cfg.site_id, cfg.time_offset)
        diurn = Float64.(get_diurnal_average(
            (ds.UTC_datetime, ds.lhf), start_date, start_date + spinup, stop_date,
        ))
        all(isfinite, diurn) || error(
            "non-finite observed LE diurnal mean at site $(cfg.site_id) " *
            "(hours with no valid data?)",
        )
        append!(out, diurn)
    end
    return out
end

observations = load_obs_diurnals()
N_obs = length(observations)  # 24 hourly means per site

# Homoskedastic observation noise: every hourly value gets the same variance, 0.05.
# NOTE(review): confirm 0.05 is the intended variance in (LE units)^2; a per-site
# block-diagonal covariance is a natural refinement.
noise_covariance = 0.05 * EKP.I(N_obs)

# ---------------- Priors & EKP config ----------------

# Parameter names, in the θ order expected by unpack_params.
# NOTE(review): this global shadows `Base.names` in this module.
names = [
    "alpha_R", "beta_R",
    "alpha_F", "beta_F",
    "alpha_T", "beta_Ts",
    "alpha_S", "beta_Ss",
]
# Unbounded Gaussian prior (mean 0.0, σ = 0.30) for every parameter.
priors = PD.ParameterDistribution[
    PD.constrained_gaussian(pname, 0.0, 0.30, -Inf, Inf) for pname in names
]
prior = PD.combine_distributions(priors)

ensemble_size = 80  # number of members; scale with problem size
N_iterations = 4

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, ensemble_size)
ekp = EKP.EnsembleKalmanProcess(
    initial_ensemble,
    observations,
    noise_covariance,
    EKP.Inversion();
    scheduler = EKP.DataMisfitController(
        terminate_at = Inf,
        on_terminate = "continue",
    ),
    rng = rng,
)

# ---------------- EKP loop ----------------

"""
    do_calibration!(ekp)

Run up to `N_iterations` ensemble Kalman iterations on `ekp` in place. Each iteration:
1. takes the current constrained ensemble (`get_ϕ_final`, `N_par × N_ens`);
2. evaluates `G_multi` for every member, giving one column per member
   (`N_obs × N_ens`);
3. passes the result to `update_ensemble!`.

The loop stops early if `update_ensemble!` reports termination.

Log output from the forward runs and the update is discarded. The per-iteration
progress message is emitted with the caller's logger so it stays visible.
Returns `nothing`.
"""
function do_calibration!(ekp)
    quiet = SimpleLogger(devnull, Logging.Error)
    for it in 1:N_iterations
        # Outside the silenced region: previously this message went to devnull.
        @info "Iteration $it (N_obs=$(length(observations)), N_sites=$(length(SITE_CFGS)))"
        θs = EKP.get_ϕ_final(prior, ekp)  # N_par × N_ens
        terminated = Logging.with_logger(quiet) do
            G_ens = stack(G_multi, eachcol(θs))  # N_obs × N_ens
            EKP.update_ensemble!(ekp, G_ens)
        end
        # update_ensemble! returns `nothing` unless the scheduler asks to stop.
        isnothing(terminated) || break
    end
    return nothing
end

do_calibration!(ekp)


# In [None]:
# ---------------- Results viz ----------------

# One 500×500 panel per calibrated parameter, plus a final panel for the error.
dim_size = sum(map(length, EKP.batch(prior)))
n_panels = dim_size + 1
fig = CairoMakie.Figure(size = (500 * n_panels, 500))
for k in 1:dim_size
    EKP.Visualize.plot_ϕ_over_iters(fig[1, k], ekp, prior, k)
end
EKP.Visualize.plot_error_over_iters(fig[1, n_panels], ekp)
CairoMakie.save("constrained_params_and_error_global.png", fig)
fig