# Minimal Working Example for Importance Sampling with Gen

In [None]:
include("../convective_adjustment_utils.jl")
using Gen

### Definition of the Initial Conditions & Global Parameters

In [None]:
# Vertical grid with Nz = 32 levels over a domain of extent Lz = 128.
grid = RegularGrid(Nz=32, Lz=128)
# Simulated duration; `model` derives its number of time steps from this constant.
const stop_time = 1800

# Allocate the initial temperature profile on the device selected by `USE_CPU`.
T0 = USE_CPU ? zeros(Float64, grid.Nz) : CUDA.zeros(Float64, grid.Nz)

# Vertical coordinates of the grid (presumably cell centres, per the ᶜ suffix).
z = zᶜ(grid)
temperature_gradient = 1e-4
surface_flux = 1e-4

# Linear initial profile: T0(z) = 20 + temperature_gradient * z, written in place.
T0 .= 20 .+ temperature_gradient .* z

### Definition of the Baseline Scientific Model (Here, a Convective Adjustment Example)

In [None]:
"""
    model(grid, surface_flux, T, convective_diffusivity, background_diffusivity)

Run the convective adjustment simulation starting from the temperature profile `T`
and return the final profile.

The time step is `Δt = 0.2 * grid.Δz^2 / convective_diffusivity`. The number of steps
is derived from the global constant `stop_time`. Each step is delegated to `convect!`
(defined in `convective_adjustment_utils.jl`).

# Arguments
- `grid`: grid object; `grid.Δz` is read here and the grid is passed on to `convect!`.
- `surface_flux`: passed through to `convect!`.
- `T`: initial temperature profile. It is copied before time stepping, so the caller's
  array is left untouched even if `convect!` modifies its argument in place.
- `convective_diffusivity`: scalar; must be positive because it sets `Δt`.
- `background_diffusivity`: scalar, passed (wrapped in an array) to `convect!`.

# Throws
- `DomainError` if `convective_diffusivity` is not positive (including `NaN`).
"""
function model(grid, surface_flux, T, convective_diffusivity, background_diffusivity)
    # A non-positive diffusivity gives Δt ≤ 0 or Δt = Inf, which would silently skip
    # every time step; NaN would fail later with an opaque InexactError.
    convective_diffusivity > 0 || throw(DomainError(convective_diffusivity,
        "convective_diffusivity must be positive (it sets the time step)"))

    # Calculate Δt & Nt. The factor 0.2 presumably keeps the explicit scheme below the
    # diffusive stability limit Δt ≤ Δz² / (2κ) — TODO confirm against `convect!`.
    Δt = 0.2 * grid.Δz^2 / convective_diffusivity
    Nt = ceil(Int, stop_time / Δt)

    # Work on a private copy. By the `!` convention, `convect!` mutates `T`. Without
    # this copy, repeated calls with the same array (e.g. the global `T0` passed to
    # the generative model once per importance sample) would each continue from the
    # profile left behind by the previous call.
    T = copy(T)

    # Wrap into arrays for Enzyme, matching the array type (CPU/GPU) of `T`
    convective_diffusivity = adapt(typeof(T), [convective_diffusivity])
    background_diffusivity = adapt(typeof(T), [background_diffusivity])

    # NOTE(review): `2:Nt` performs Nt - 1 steps, so the simulated time (Nt - 1) * Δt
    # ends below `stop_time`; confirm this is intended.
    for _ in 2:Nt
        T = convect!(T, grid, background_diffusivity, convective_diffusivity,
                     surface_flux, Δt)
    end
    return T
end

### Construct the Generative Model from the Baseline Model and Its Priors

In [None]:
# Gen generative function wrapping the deterministic `model`.
#
# Random choices (placeholder debug priors):
#   :convective_diffusivity ~ normal(10, 2)
#   :background_diffusivity ~ normal(1e-4, 3e-5)
# Returns the final temperature profile simulated by `model` with the sampled values.
#
# NOTE(review): no observation address (such as `(:y, i)`) is traced here, so
# observations constrained at those addresses are not scored against the output.
# The normal priors also put a small amount of mass on non-positive diffusivities.
@gen function convective_adjustment(grid, surface_flux, T)

    # `x ~ dist(...)` traces the draw at address `:x`, i.e. the same addresses as
    # `@trace(dist(...), :x)`.
    convective_diffusivity ~ normal(10, 2)
    background_diffusivity ~ normal(1e-4, 3e-5)

    # `model` is a plain Julia function with no traced choices, so it is called
    # directly rather than as a generative function.
    return model(grid, surface_flux, T, convective_diffusivity, background_diffusivity)
end

### Generate the Dataset

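Here we introduce diversity into the dataset by varying the temperature gradient of the initial profile. The gradient is drawn from a normal distribution, so each data point starts from a different initial temperature profile `T0`. The diffusivities are held fixed at their true values. As a result, the number of time steps is the same for every data point, since it depends only on `stop_time`, the grid spacing, and the convective diffusivity.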

In [None]:
"""
    dataset_generation(datapoints::Int; true_convective_diffusivity=10.0,
                       true_background_diffusivity=1e-4,
                       gradient_mean=1e-4, gradient_std=2e-4)

Generate `datapoints` synthetic final temperature profiles. Each is produced by running
`model` with fixed "true" diffusivities from a randomly perturbed initial profile.

Each initial profile is `20 .+ g .* z`, where the temperature gradient `g` is drawn
independently from `normal(gradient_mean, gradient_std)`. The globals `grid`,
`surface_flux` and `z` are used. The defaults reproduce the previous hard-coded values.

# Returns
A vector with one simulated final profile per data point.

# Throws
- `ArgumentError` if `datapoints` is negative.
"""
function dataset_generation(datapoints::Int;
                            true_convective_diffusivity=10.0,
                            true_background_diffusivity=1e-4,
                            gradient_mean=1e-4,
                            gradient_std=2e-4)
    datapoints >= 0 ||
        throw(ArgumentError("datapoints must be non-negative, got $datapoints"))

    # The true convective diffusivity is a Float64 (not the Int `10`). Otherwise
    # `model` would wrap it as an Int array via `adapt`, unlike the Float64 draws
    # made by the generative model.

    # Vary T0 by sampling the temperature gradient afresh for every data point.
    sample_initial_profile() = 20 .+ normal(gradient_mean, gradient_std) .* z

    return [
        model(grid, surface_flux, sample_initial_profile(),
              true_convective_diffusivity, true_background_diffusivity)
        for _ in 1:datapoints
    ]
end

We can then choose how many data points to generate; `500` is an arbitrary choice for the time being.

In [None]:
# Number of synthetic observations to generate (an arbitrary choice for now).
num_datapoints = 500
ys = dataset_generation(num_datapoints)

### Inference Program

We construct an inference program that performs [importance resampling](https://www.gen.dev/dev/ref/importance/) on the convective adjustment model. The dataset is passed in as input. From it, we build a Gen [choice map](https://www.gen.dev/dev/ref/choice_maps/) that constrains the address `(:y, i)` to the `i`-th observation. For convenience, and following Gen's own nomenclature, the arguments of the generative model are collected in `xs`. The inference routine returns a [trace](https://www.gen.dev/dev/ref/gfi/#Traces-1), whose address space can be examined later. Note that `convective_adjustment` does not yet make any random choice at the `(:y, i)` addresses, so these observations are not yet tied to the model's output.

> An execution trace (or just trace) is a record of an execution of a generative function. Traces are the primary data structures manipulated by Gen inference programs.

In [None]:
"""
    importance_sampling_inference(model, grid, surface_flux, T, ys, amount_of_computation)

Run Gen's importance resampling on the generative function `model`, called with
arguments `(grid, surface_flux, T)`.

The `i`-th element of `ys` is used as the observed value at address `(:y, i)`.
`amount_of_computation` is the number of importance samples drawn.

# Returns
The single trace selected by `Gen.importance_resampling`. This is an approximate
posterior sample, not the most likely trace. The accompanying log marginal likelihood
estimate is discarded.
"""
function importance_sampling_inference(model, grid, surface_flux, T, ys, amount_of_computation)
    # Choice map constraining (:y, idx) to the idx-th observation.
    constraints = Gen.choicemap()
    for (idx, observed) in enumerate(ys)
        constraints[(:y, idx)] = observed
    end

    # Gen calls the argument tuple of a generative function `xs`.
    model_args = (grid, surface_flux, T)

    trace, _ = Gen.importance_resampling(model, model_args, constraints,
                                         amount_of_computation)
    return trace
end

With this, we can run the inference routine. The amount of computation (the number of importance samples, here `200`) is again an arbitrary choice.

In [None]:
# Number of importance samples drawn before resampling one trace (arbitrary for now).
num_importance_samples = 200
trace = importance_sampling_inference(convective_adjustment, grid, surface_flux, T0, ys,
                                      num_importance_samples)

### Analysis of the Trace

tbd...