# Composable Callables & Full‑Stack `GrowthModel` — Practical Guide

This notebook shows how to:

1. **Wrap scientific models as `Callable`s** (e.g., Elfving 2010 DBH increment, Söderberg 1986 heights,
   Fridman–Ståhl 2006 mortality, and Edgren–Nylinder 1949 taper + a `PriceList`).
2. Build a **full‑stack `GrowthModel`** that runs all components inside `grow()`.
3. Use the **factory** to adapt Angle‑Count (Bitterlich) stands to pseudo tree‑list working inventories without modifying the original Stand.
4. Run the **DSL** controller (triggers/schedules).
5. Broadcast aggregate runs with the **ContextEnsemble** (Python/Numba/JAX engines).

> All types and utilities come from `pyforestry.base.simulation` & `pyforestry.base.helpers`.
Everything lives in `base/` as per the project conventions.

## 1) Imports & quick recap

- `GrowthModel` is the **factory + behavior**.
- `SimulationContext` is the **sandbox** (all mutation + history live here).
- `SimulationSetup` is the **controller** (pre/post triggers, schedules).
- `AdapterRegistry` provides **Angle‑Count → pseudo tree‑list/spatial/diameter‑class** adapters.
- `ContextEnsemble` broadcasts aggregate steps through a **batch engine** (Python/Numba/JAX).
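
The three roles can be illustrated without pyforestry. The sketch below is a minimal, hypothetical controller loop (`ToyContext`, `ToyTrigger` and `run_toy` are invented names, not the library's API). On each step it checks pre‑phase triggers, calls the model's `grow()`, advances time, checks post‑phase triggers and appends a history snapshot. The real `SimulationSetup` may order or record these steps differently.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class ToyContext:
    """Hypothetical sandbox: mutable state plus a per-step history."""
    state: Dict[str, Any] = field(default_factory=dict)
    history: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class ToyTrigger:
    """Hypothetical trigger: runs `action` when `predicate` is true in its phase."""
    phase: str                                # "pre" or "post"
    predicate: Callable[[ToyContext], bool]
    action: Callable[[ToyContext], None]
    once: bool = False
    fired: bool = False


def run_toy(ctx: ToyContext, grow: Callable[[ToyContext, float], None],
            start_t: float, end_t: float, dt: float,
            triggers: List[ToyTrigger]) -> None:
    """Controller loop: pre-triggers -> grow -> post-triggers -> history row."""

    def check(phase: str) -> None:
        for trg in triggers:
            if trg.phase != phase or (trg.once and trg.fired):
                continue
            if trg.predicate(ctx):
                trg.action(ctx)
                trg.fired = True

    t = start_t
    ctx.state["t"] = t
    while t < end_t - 1e-9:                   # tolerance guards against float drift in t
        check("pre")
        grow(ctx, dt)                         # the model mutates only the sandbox
        t += dt
        ctx.state["t"] = t
        check("post")
        ctx.history.append(dict(ctx.state))   # snapshot of the state after each step


# Usage: 3 %/yr basal-area growth and a one-shot post trigger at BA > 22 m2/ha
def toy_grow(ctx: ToyContext, dt: float) -> None:
    ctx.state["ba"] = ctx.state.get("ba", 20.0) * (1.0 + 0.03 * dt)


fired_at: List[float] = []
toy_ctx = ToyContext()
run_toy(toy_ctx, toy_grow, start_t=0.0, end_t=5.0, dt=1.0,
        triggers=[ToyTrigger("post", lambda c: c.state["ba"] > 22.0,
                             lambda c: fired_at.append(c.state["t"]), once=True)])
print(fired_at, len(toy_ctx.history))
```

Because `once=True`, the trigger fires the first time the predicate holds and is skipped afterwards, even though BA stays above the threshold.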

In [None]:
from typing import Callable, Optional, Any, Dict
from dataclasses import dataclass

from pyforestry.base.simulation import (
    GrowthModel, Requirements, SimulationSetup, TriggerSpec, SimulationContext,
    AdapterRegistry, ContextEnsemble, PythonEngine
)
from pyforestry.base.helpers import (
    Stand, CircularPlot, Tree, PICEA_ABIES, PINUS_SYLVESTRIS, AngleCount
)

## 2) Define the scientific components as `Callable`s

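Each study/model is kept as a plain function so that it can be tested in isolation. The bodies below are placeholders; in production, replace them with (or delegate to) your calibrated implementations.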

In [None]:
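# Signatures (type hints document the expected arguments; they are not enforced at runtime)
# (species, dbh_cm, height_m or None, site, dt_years, state) -> Δdbh_cm
ElfvingDbhIncrementFn  = Callable[[Any, float, Optional[float], Any, float, Dict[str, Any]], float]
# (species, dbh_cm, site, age or None) -> height_m
SoderbergHeightFn      = Callable[[Any, float, Any, Optional[float]], float]
# (species, dbh_cm, height_m or None, site, dt_years, state) -> survival fraction in [0, 1]
FridmanStahlSurvivalFn = Callable[[Any, float, Optional[float], Any, float, Dict[str, Any]], float]
# (species, dbh_cm, height_m, site) -> stem volume in m3
EdgrenNylinderVolFn    = Callable[[Any, float, float, Any], float]
# (species, volume_m3, site) -> price in SEK per m3
PriceFn                = Callable[[Any, float, Any], float]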

# --- Replace these with your calibrated implementations ---
def elfving_dbh_increment(sp, dbh_cm, h_m, site, dt, state):
    # Δdbh in cm over dt years (placeholder)
    return 0.25 * dt

def soderberg_height(sp, dbh_cm, site, age):
    # height in m from DBH (placeholder)
    return max(1.3, 1.3 + 0.6 * (dbh_cm ** 0.5))

def fridman_stahl_survival(sp, dbh_cm, h_m, site, dt, state):
    # survival fraction in [0,1] for the step (placeholder)
    return max(0.0, min(1.0, 1.0 - 0.006 * dt))

def edgren_nylinder_volume(sp, dbh_cm, h_m, site):
    # placeholder: cylinder volume in m3, π/4 · (dbh_cm/100)² · h_m; replace with the Edgren–Nylinder taper model
    return 0.00007854 * (dbh_cm ** 2) * h_m

def price_list(sp, vol_m3, site):
    # SEK per m3 (placeholder)
    return 500.0
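
Because each component is a plain function, it can be unit‑tested before it is wired into a model. The sketch below re‑declares two of the placeholders so that it runs on its own and checks (a) that survival stays in [0, 1] for any step length and (b) that the volume placeholder is just the cylinder formula $V = \frac{\pi}{4}\,(d/100)^2\,h$, since $\frac{\pi}{4}\cdot 10^{-4} \approx 0.00007854$. The helper names `check_survival_bounds` and `check_cylinder_volume` are illustrative; calibrated Fridman–Ståhl or Edgren–Nylinder functions would need their own reference values.

```python
import math


def fridman_stahl_survival(sp, dbh_cm, h_m, site, dt, state):
    # Same placeholder as above: survival fraction for the step, clamped to [0, 1]
    return max(0.0, min(1.0, 1.0 - 0.006 * dt))


def edgren_nylinder_volume(sp, dbh_cm, h_m, site):
    # Same placeholder as above: cylinder volume in m3
    return 0.00007854 * (dbh_cm ** 2) * h_m


def check_survival_bounds(fn) -> None:
    """Survival must be a fraction for any step length, including very long ones."""
    for dt in (0.0, 1.0, 5.0, 100.0, 1000.0):
        s = fn(None, 20.0, 15.0, None, dt, {})
        assert 0.0 <= s <= 1.0, (dt, s)


def check_cylinder_volume(fn) -> None:
    """Compare against the cylinder formula pi/4 * (d/100)^2 * h."""
    for dbh_cm, h_m in ((10.0, 8.0), (20.0, 15.0), (35.0, 25.0)):
        expected = math.pi / 4.0 * (dbh_cm / 100.0) ** 2 * h_m
        assert math.isclose(fn(None, dbh_cm, h_m, None), expected, rel_tol=1e-5)


check_survival_bounds(fridman_stahl_survival)
check_cylinder_volume(edgren_nylinder_volume)
print("component checks passed")
```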

## 3) A full‑stack `GrowthModel` that chains the callables inside `grow()`

This model needs a per‑tree inventory (`tree_list` or `spatial`). If the input Stand uses **Angle‑Count** tallies, we ask the factory to adapt it to a pseudo tree‑list so that per‑tree operations work safely.

In [None]:
@dataclass
class FullStackCallableModel(GrowthModel):
    dbh_increment_fn: ElfvingDbhIncrementFn
    height_fn:       SoderbergHeightFn
    survival_fn:     FridmanStahlSurvivalFn
    volume_fn:       EdgrenNylinderVolFn
    price_fn:        PriceFn
    remove_zero_weight: bool = True

    def requirements(self) -> Requirements:
        # We need per-tree inventories to run taper/price properly.
        return Requirements(inventory="tree_list")

    def grow(self, ctx: SimulationContext, dt: float) -> None:
        if ctx.mode not in ("tree_list", "spatial"):
            raise RuntimeError(f"{self.__class__.__name__} requires per-tree mode; got {ctx.mode}.")
        ctx.state["years_since_thin"] = ctx.state.get("years_since_thin", 0.0) + dt
        step_value = 0.0
        for p in ctx.plots:
            for t in p.trees:
                sp = getattr(t, "species", None)
                if sp is None:
                    continue
                dbh = float(getattr(t, "diameter_cm", 0.0) or 0.0)
                h   = float(getattr(t, "height_m",   0.0) or 0.0)
                age = getattr(t, "age", None)

                # 1) DBH increment (Elfving 2010)
                ddbh = float(self.dbh_increment_fn(sp, dbh, (h if h > 0 else None), ctx.site, dt, ctx.state))
                dbh = max(0.0, dbh + ddbh)
                t.diameter_cm = dbh

                # 2) Height update (Söderberg 1986)
                h = float(self.height_fn(sp, dbh, ctx.site, age))
                t.height_m = h

                # 3) Mortality (Fridman–Ståhl 2006): scale the stem weight by the survival fraction, clamped to [0, 1]
                surv = float(self.survival_fn(sp, dbh, h, ctx.site, dt, ctx.state))
                surv = min(1.0, max(0.0, surv))
                t.weight_n = float(getattr(t, "weight_n", 1.0) or 0.0) * surv

                # 4) Value (Edgren–Nylinder 1949 taper + PriceList)
                if dbh > 0.0 and h > 0.0 and t.weight_n > 0.0:
                    vol_m3 = float(self.volume_fn(sp, dbh, h, ctx.site))
                    price  = float(self.price_fn(sp, vol_m3, ctx.site))
                    step_value += vol_m3 * price * float(t.weight_n)

        if self.remove_zero_weight:
            for p in ctx.plots:
                p.trees = [t for t in p.trees if float(getattr(t, "weight_n", 0.0) or 0.0) > 1e-9]

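        # Standing value of the surviving trees after this step. Summing it across steps
        # (cum_value) is a simple running total, not a harvest or NPV figure.
        # The *_per_ha keys assume weight_n is expressed as stems per hectare.
        ctx.state["last_step_value_SEK_per_ha"] = step_value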
        ctx.state["cum_value_SEK_per_ha"] = ctx.state.get("cum_value_SEK_per_ha", 0.0) + step_value
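
The order of operations inside `grow()` matters: height is evaluated on the *updated* DBH, and value is computed on the surviving stem weight. The library‑free sketch below replays one step for a single tree with the same placeholder formulas; `ToyTree` and `step_tree` are hypothetical names, and `weight_n` is treated simply as a multiplier.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyTree:
    """Hypothetical stand-in for a pyforestry Tree."""
    diameter_cm: float
    height_m: float
    weight_n: float


def step_tree(t: ToyTree, dt: float, price_sek_per_m3: float = 500.0) -> float:
    """Apply DBH -> height -> survival -> value to one tree; return its value in SEK."""
    # 1) DBH increment (placeholder: 0.25 cm per year)
    t.diameter_cm = max(0.0, t.diameter_cm + 0.25 * dt)
    # 2) Height from the *updated* DBH (placeholder height-diameter curve)
    t.height_m = 1.3 + 0.6 * math.sqrt(t.diameter_cm)
    # 3) Survival fraction, clamped to [0, 1], scales the stem weight
    surv = min(1.0, max(0.0, 1.0 - 0.006 * dt))
    t.weight_n *= surv
    # 4) Value of the surviving stems (placeholder cylinder volume in m3)
    vol_m3 = 0.00007854 * t.diameter_cm ** 2 * t.height_m
    return vol_m3 * price_sek_per_m3 * t.weight_n


tree = ToyTree(diameter_cm=20.0, height_m=15.0, weight_n=6.0)
value = step_tree(tree, dt=1.0)
print(tree, round(value, 2))
```

With these placeholders the height drops from 15 m to about 4 m in the first step, because the toy height curve ignores the measured height; that is one reason not to read anything into the numbers until the real functions are plugged in.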

## 4) Build a Stand and run the model via the DSL

We’ll use a small synthetic tree‑list Stand. The `GrowthModel` factory builds a sandboxed `SimulationContext` for us, and `SimulationSetup` orchestrates the run.

In [None]:
# --- A tiny tree-list stand ---
stand = Stand(
    area_ha=1.0,
    plots=[
        CircularPlot(id=1, area_m2=200.0, trees=[
            Tree(species=PICEA_ABIES, diameter_cm=20.0, height_m=15.0, weight_n=6),
            Tree(species=PICEA_ABIES, diameter_cm=18.0, height_m=13.0, weight_n=5),
        ])
    ],
)

model = FullStackCallableModel(
    dbh_increment_fn=elfving_dbh_increment,
    height_fn=soderberg_height,
    survival_fn=fridman_stahl_survival,
    volume_fn=edgren_nylinder_volume,
    price_fn=price_list,
)

ok, missing = model.can_build(stand, allow_adapters=True, mode_hint="tree_list")
assert ok, f"missing: {missing}"
ctx = model.build_context(stand, mode_hint="tree_list")

setup = SimulationSetup(start_t=0.0, end_t=10.0, dt=1.0)
setup.run(ctx)
print("Cumulative value (SEK/ha):", ctx.state.get("cum_value_SEK_per_ha", 0.0))
ctx.to_pandas().tail()
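
Building a context leaves the input Stand untouched, which is what allows several contexts to be built from the same stand (section 7 does exactly that). A deep copy is one way to get that isolation; the sketch below assumes this general idea and is not pyforestry's implementation. `ToyStand`, `ToyTree` and `build_sandbox` are hypothetical names.

```python
import copy
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyTree:
    diameter_cm: float


@dataclass
class ToyStand:
    trees: List[ToyTree] = field(default_factory=list)


def build_sandbox(stand: ToyStand) -> ToyStand:
    """Hypothetical factory step: return an independent deep copy to mutate."""
    return copy.deepcopy(stand)


original = ToyStand(trees=[ToyTree(20.0), ToyTree(18.0)])
sandbox_a = build_sandbox(original)
sandbox_b = build_sandbox(original)

# "Grow" only sandbox_a; the original and sandbox_b are unaffected
for t in sandbox_a.trees:
    t.diameter_cm += 0.25

print([t.diameter_cm for t in original.trees], [t.diameter_cm for t in sandbox_a.trees])
```

A shallow `copy.copy` would share the tree objects, so per‑tree updates like those in `grow()` would leak back into the original stand.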

## 5) Angle‑Count stands → pseudo tree‑lists (safe adapter)

If the Stand carries Bitterlich tallies, you can still run per‑tree logic by **opting in** to a pseudo tree‑list at build time (your original Stand is left unchanged).

In [None]:
# --- Bitterlich tally stand ---
ac_plot = CircularPlot(id="ac1", area_m2=10000.0, AngleCount=[
    AngleCount(ba_factor=10.0, species=[PICEA_ABIES, PINUS_SYLVESTRIS], value=[3.0, 2.0])
])
ac_stand = Stand(area_ha=1.0, plots=[ac_plot])

ok, missing = model.can_build(ac_stand, allow_adapters=True, mode_hint="tree_list")
assert ok, missing
ctx_ac = model.build_context(ac_stand, mode_hint="tree_list")  # auto AC→pseudo tree-list

SimulationSetup(start_t=0.0, end_t=5.0, dt=1.0).run(ctx_ac)
print("Inventory origin:", ctx_ac.attrs.get("inventory_origin"))
ctx_ac.to_pandas().head(3)
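
An Angle‑Count tally on its own fixes only basal area. With basal‑area factor BAF (m²/ha), each tallied tree contributes BAF m²/ha, so a tally of $z$ trees gives $BA = z \cdot BAF$, and a tree of diameter $d$ represents $BAF / g(d)$ stems per hectare, where $g(d) = \frac{\pi}{4}(d/100)^2$ m². A pseudo tree‑list therefore needs diameters from somewhere. The sketch below assumes one representative DBH per species and reads the `value` entries in the cell above as tree counts; `pseudo_trees` and the `assumed_dbh_cm` values are hypothetical, and pyforestry's adapter may derive diameters differently.

```python
import math
from typing import Dict, List, Tuple


def basal_area_m2(dbh_cm: float) -> float:
    """Cross-sectional area at breast height, in m2, for a DBH in cm."""
    return math.pi / 4.0 * (dbh_cm / 100.0) ** 2


def pseudo_trees(tallies: Dict[str, float], baf: float,
                 assumed_dbh_cm: Dict[str, float]) -> List[Tuple[str, float, float]]:
    """Hypothetical adapter: (species, dbh_cm, stems_per_ha) rows from a Bitterlich tally.

    Each tallied tree stands for baf / g(dbh) stems/ha, so a tally z gives
    z * baf m2/ha of basal area spread over stems of the assumed diameter.
    """
    rows = []
    for sp, z in tallies.items():
        d = assumed_dbh_cm[sp]
        stems_per_ha = z * baf / basal_area_m2(d)
        rows.append((sp, d, stems_per_ha))
    return rows


# Same tally as the cell above: BAF 10, spruce 3 trees, pine 2 trees
rows = pseudo_trees({"PICEA_ABIES": 3.0, "PINUS_SYLVESTRIS": 2.0}, baf=10.0,
                    assumed_dbh_cm={"PICEA_ABIES": 20.0, "PINUS_SYLVESTRIS": 25.0})
for sp, d, n in rows:
    print(f"{sp}: {d:.0f} cm, {n:.0f} stems/ha, BA {n * basal_area_m2(d):.1f} m2/ha")
```

Whatever diameters are assumed, the pseudo trees reproduce the tallied basal area exactly; only the split between stem number and tree size depends on the assumption.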

## 6) Using triggers/schedules (DSL)

Management actions are added through the DSL. Here we add a simple **post**‑phase trigger that prints a message whenever the quadratic mean diameter (QMD) exceeds a threshold (a placeholder for a thinning action).

In [None]:
def qmd_exceeds(ctx: SimulationContext, threshold_cm: float = 22.0) -> bool:
    return float(ctx.metrics["QMD"]["TOTAL"]) > threshold_cm

def announce(ctx: SimulationContext):
    print(f"t={ctx.state['t']:.1f} → QMD now {float(ctx.metrics['QMD']['TOTAL']):.2f} cm")

trg = TriggerSpec(
    name="qmd_watch",
    check_phase="post",
    predicate=lambda c: qmd_exceeds(c, 22.0),
    action=announce,
    once=False,
)

ctx2 = model.build_context(stand, mode_hint="tree_list")
SimulationSetup(start_t=0.0, end_t=5.0, dt=1.0, triggers=[trg]).run(ctx2)
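
The trigger reads `ctx.metrics["QMD"]["TOTAL"]`. Assuming this is the usual stem‑weighted quadratic mean diameter, $QMD = \sqrt{\sum_i n_i d_i^2 / \sum_i n_i}$, it can be reproduced from a tree list as a sanity check. `qmd_cm` is a hypothetical helper, not a pyforestry function.

```python
import math
from typing import Iterable, Tuple


def qmd_cm(trees: Iterable[Tuple[float, float]]) -> float:
    """Quadratic mean diameter (cm) from (dbh_cm, weight_n) pairs."""
    sum_n = 0.0
    sum_nd2 = 0.0
    for dbh_cm, weight_n in trees:
        sum_n += weight_n
        sum_nd2 += weight_n * dbh_cm ** 2
    if sum_n <= 0.0:
        return 0.0  # empty or fully dead stand
    return math.sqrt(sum_nd2 / sum_n)


# The two spruces from section 4 at t = 0
q0 = qmd_cm([(20.0, 6.0), (18.0, 5.0)])
fires = q0 > 22.0  # same threshold as the qmd_watch trigger
print(f"QMD = {q0:.2f} cm, trigger fires: {fires}")
```

If the metric is computed this way, the placeholder growth of 0.25 cm per year keeps QMD below 22 cm over the five‑year run (about 20.4 cm at t = 5), so `announce` would not print unless the threshold is lowered or the real increment model grows trees faster.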

## 7) Vectorized aggregate runs with `ContextEnsemble`

If you also keep an **aggregate** approximation of your model, you can opt in to the batch engine by
implementing `has_batch_engine()` and `batch_grow_step(ba, n, dt, fert_mask)`.

In [None]:
import numpy as np

class AggregateBANModel(GrowthModel):
    """Aggregate basal-area (BA) / stem-number (N) model with constant relative rates."""

    def __init__(self, ba_rel_per_year=0.03, mort_per_year=0.006):
        self.ba_rel = ba_rel_per_year
        self.mort = mort_per_year

    def requirements(self) -> Requirements:
        return Requirements(inventory="aggregate")

    def has_batch_engine(self) -> bool:
        return True

    def batch_grow_step(self, ba: np.ndarray, n: np.ndarray, dt: float, fert_mask: np.ndarray):
        # Vectorized step: relative BA growth and proportional mortality on N.
        # fert_mask is part of the batch signature but unused by this simple model.
        ba2 = ba * (1.0 + self.ba_rel * dt)
        n2 = n * max(0.0, 1.0 - self.mort * dt)
        return ba2, n2

    def grow(self, ctx: SimulationContext, dt: float) -> None:
        # Scalar fallback for non-batched runs
        ba = float(ctx.metrics["BasalArea"]["TOTAL"]) * (1.0 + self.ba_rel * dt)
        n = float(ctx.metrics["Stems"]["TOTAL"]) * max(0.0, 1.0 - self.mort * dt)
        ctx.set_aggregate_metrics(ba_total=ba, stems_total=n)

agg_model = AggregateBANModel()

# Build a few contexts in aggregate mode (forced with mode_hint)
stands = [stand, stand]  # reuse the same stand for brevity
agg_contexts = []
for s in stands:
    ok, missing = agg_model.can_build(s, mode_hint="aggregate")
    assert ok, f"missing: {missing}"
    agg_contexts.append(agg_model.build_context(s, mode_hint="aggregate"))

ens = ContextEnsemble(contexts=agg_contexts, model=agg_model)
for _ in range(5):
    ens.grow(1.0)

ens.to_pandas().tail()
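
The point of `batch_grow_step` is that one NumPy expression advances every context at once. The sketch below, which needs only NumPy, checks that a vectorized update matches a per‑context Python loop and that repeated steps compound as $BA_t = BA_0\,(1 + r\,\Delta t)^{t/\Delta t}$. The fertilization boost applied through `fert_mask` is a hypothetical extension (the model above ignores the mask), and `FERT_BOOST`, `batch_step` and `scalar_step` are invented names.

```python
import numpy as np

BA_REL, MORT = 0.03, 0.006  # same rates as AggregateBANModel's defaults
FERT_BOOST = 0.01           # hypothetical extra relative BA growth on fertilized stands


def batch_step(ba: np.ndarray, n: np.ndarray, dt: float, fert_mask: np.ndarray):
    """Vectorized step over all contexts; fert_mask is a boolean array."""
    rate = BA_REL + FERT_BOOST * fert_mask  # per-context growth rate (float array)
    return ba * (1.0 + rate * dt), n * max(0.0, 1.0 - MORT * dt)


def scalar_step(ba: float, n: float, dt: float, fertilized: bool):
    """Reference implementation, one context at a time."""
    rate = BA_REL + (FERT_BOOST if fertilized else 0.0)
    return ba * (1.0 + rate * dt), n * max(0.0, 1.0 - MORT * dt)


rng = np.random.default_rng(0)
ba = rng.uniform(10.0, 35.0, size=1000)     # m2/ha per context
n = rng.uniform(400.0, 2000.0, size=1000)   # stems/ha per context
fert = rng.random(1000) < 0.3               # roughly 30 % of contexts fertilized

# Five yearly steps, all contexts at once
ba_b, n_b = ba.copy(), n.copy()
for _ in range(5):
    ba_b, n_b = batch_step(ba_b, n_b, 1.0, fert)

# The same five steps, one context at a time
ba_s, n_s = [], []
for b0, n0, f in zip(ba, n, fert):
    b, m = float(b0), float(n0)
    for _ in range(5):
        b, m = scalar_step(b, m, 1.0, bool(f))
    ba_s.append(b)
    n_s.append(m)

print(np.allclose(ba_b, ba_s), np.allclose(n_b, n_s))
```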

## 8) Wrap‑up

- Keep **each scientific component** testable as a `Callable`.
- The **full‑stack model** chains them in one place (`grow`).
- The **factory** adapts inventories (Angle‑Count → pseudo tree‑list) when you ask for per‑tree modes.
- The **DSL** orchestrates **when** things happen; history is logged on every step.
- For many aggregate scenarios, opt in to the **batch engine** by implementing `has_batch_engine()` and `batch_grow_step`.