Note: To speed up iteration and debugging, this notebook defines the custom modules inline instead of importing them from the external .jl files. Once the code stabilizes, I will switch back to proper `include("Module.jl")`-style imports.

# Packages

In [None]:
# Environment bootstrap for the notebook: everything below runs against the
# project activated here.
using Pkg

# Activate a brand-new project located in a directory created by `mktempdir()`.
Pkg.activate(mktempdir())

# Install the registered packages listed here into the active project.
Pkg.add([
    "Distributions",
    "StatsFuns",
    "DataStructures",
    "JSON",
    "StatsBase",
    "DataFrames",
    "Plots",
    "Zygote",
    "Optimisers",
    "SpecialFunctions"
])
# Precompile the project's dependencies up front.
Pkg.precompile()


 # Utils

In [2]:
module SampleHandling

using Random
using Statistics
using LinearAlgebra
using DataStructures: OrderedDict

# Flatten sample function
"""
    flatten_sample(sample)

Concatenate a vector whose static element type is itself a vector into one flat
vector. Any other input (scalars, flat vectors, `Vector{Any}`) is returned unchanged.
"""
function flatten_sample(sample::AbstractVector{<:AbstractVector})
    pieces = map(vec, sample)
    return vcat(pieces...)
end

# Fallback: nothing to flatten.
flatten_sample(sample) = sample

# Create unique list function
"""
    create_unique_list(list_with_duplicates)

Return the distinct elements of `list_with_duplicates`, keeping the order in which
each element first appears.
"""
function create_unique_list(list_with_duplicates)
    # Collecting an `OrderedDict` yields `element => index` pairs rather than the
    # elements themselves; `unique` returns the elements in first-occurrence order.
    return unique(list_with_duplicates)
end

# Burn chain function
"""
    burn_chain(samples, weights, burn_frac=nothing)

Drop the leading `burn_frac` fraction of `samples` and the matching `weights`.

The number of discarded entries is `floor(burn_frac * length(samples))`. When
`burn_frac` is `nothing` both inputs are returned untouched.

Returns the tuple `(samples, weights)` after burn-in.
"""
function burn_chain(samples, weights, burn_frac=nothing)
    burn_frac === nothing && return samples, weights
    # `Int(x)` throws `InexactError` unless the product is exactly integral
    # (e.g. 0.25 * 10 or 0.3 * 10), so round down explicitly.
    nburn = floor(Int, burn_frac * length(samples))
    return samples[nburn+1:end], weights[nburn+1:end]
end

end # module SampleHandling


Main.SampleHandling

In [3]:
module MathOps

using Random
using Statistics
using LinearAlgebra
using StatsFuns: logsumexp

# Softplus function
"""
    softplus(x, beta=1.0, threshold=20.0)

Compute `log(1 + exp(beta * x)) / beta`.

When `beta * x` exceeds `threshold` the function is numerically the identity, so
`x` is returned directly.
"""
function softplus(x, beta=1.0, threshold=20.0)
    z = beta * x
    # The linear regime is decided on the scaled input `beta * x`; comparing the
    # raw `x` gives wrong values whenever `beta != 1`.
    # `log1p` keeps precision for very negative `z`, where `exp(z) + 1` rounds to 1.
    return ifelse(z <= threshold, log1p(exp(z)) / beta, x)
end

# Inverse Softplus function
"""
    inverse_softplus(s, beta=1.0, threshold=20.0)

Compute `log(exp(beta * s) - 1) / beta`, the inverse of the softplus transform.

When `beta * s` exceeds `threshold` the transform is numerically the identity, so
`s` is returned directly. A negative `s` raises `DomainError` from `log`.
"""
function inverse_softplus(s, beta=1.0, threshold=20.0)
    z = beta * s
    # `expm1` avoids the catastrophic cancellation of `exp(z) - 1` for tiny `z`
    # (which would otherwise collapse to 0 and yield `-Inf`).
    return ifelse(z <= threshold, log(expm1(z)) / beta, s)
end

# Covariance function
"""
    covariance(x, y)

Return `E[x .* y] - E[x] .* E[y]`, with every mean taken along dimension 1.
"""
function covariance(x, y)
    mean_of_product = mean(x .* y, dims=1)
    product_of_means = mean(x, dims=1) .* mean(y, dims=1)
    return mean_of_product - product_of_means
end

end # module MathOps


Main.MathOps

In [4]:
module Plotting

using Random
using Statistics
using StatsBase
using DataFrames
ENV["GKSwstype"] = "100"
using Plots

# Fetch one column from `data`.
# `df[column]` is rejected by DataFrames, so data frames go through `df[!, column]`;
# every other container (e.g. a `Dict`) is indexed as-is.
_column(data::AbstractDataFrame, column) = data[!, column]
_column(data, column) = data[column]

"""
    create_histogram(data, column, title)

Plot a histogram of `column` taken from `data`, titled `title`.
"""
function create_histogram(data, column, title)
    return histogram(_column(data, column), title=title)
end

"""
    create_scatter(data, x_column, y_column, title)

Scatter-plot `y_column` against `x_column` from `data`, titled `title`.
"""
function create_scatter(data, x_column, y_column, title)
    return scatter(_column(data, x_column), _column(data, y_column), title=title)
end

"""
    create_heatmap(xlabels, ylabels, matrix, title)

Draw `matrix` as a heatmap with the given axis labels and `title`.
"""
function create_heatmap(xlabels, ylabels, matrix, title)
    # NOTE(review): `show_text` is forwarded unchanged; confirm the plotting backend
    # recognises this attribute.
    return heatmap(xlabels, ylabels, matrix, title=title, show_text=true)
end

"""
    create_line(data, x_column, y_column, title)

Line-plot `y_column` against `x_column` from `data`, titled `title`.
"""
function create_line(data, x_column, y_column, title)
    return plot(_column(data, x_column), _column(data, y_column), title=title)
end


end # module Plotting

Main.Plotting

# Distributions

In [5]:
module NormalDistribution

using Distributions: logpdf, Normal, ContinuousUnivariateDistribution
import Distributions: logpdf, Normal, pdf, quantile
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random
export NormalDist
import SpecialFunctions: erfinv

# Which bijection maps the unconstrained optimisation parameter to a positive scale.
const scaling_function = "softplus"
# `const` keeps calls through these bindings type-stable.
const positive_function, positive_inverse = if scaling_function == "softplus"
    (softplus, inverse_softplus)
elseif scaling_function == "exponential"
    (exp, log)
else
    error("Scaling function not recognized")
end

"""
    NormalDist(loc, scale; kwargs...)

Normal distribution that stores its scale in unconstrained form
(`optim_scale = positive_inverse(scale)`).

Any pair of `Real` arguments is accepted; both fields are stored in the promoted
floating-point type. Keyword arguments such as `validate_args` are accepted and
ignored.
"""
struct NormalDist{T<:Real} <: ContinuousUnivariateDistribution
    loc::T
    optim_scale::T
    function NormalDist(loc::Real, scale::Real; kwargs...)
        # Promote to a float type first: with integer inputs `new{Int}` could not
        # hold the (non-integer) transformed scale, and mixed float types would
        # otherwise never match a same-type constructor.
        T = promote_type(float(typeof(loc)), float(typeof(scale)))
        return new{T}(loc, positive_inverse(scale))
    end
end

# Alias kept for code that still refers to the constructor under this name.
const _origNormalDist = NormalDist

# Positive scale recovered from the unconstrained parameter.
σ(d) = positive_function(d.optim_scale)

# `x::Real` keeps these methods strictly more specific than the generic
# `(UnivariateDistribution, Real)` / array methods, avoiding dispatch ambiguities.
function logpdf(d::NormalDist, x::Real)
    return logpdf(Normal(d.loc, σ(d)), x)
end

pdf(d::NormalDist, x::Real) = exp(logpdf(d, x))

# Constrained parameters `[loc, scale]`.
function params(d::NormalDist)
    return [d.loc, σ(d)]
end

# Unconstrained parameters `[loc, optim_scale]`.
function optim_params(d::NormalDist)
    return [d.loc, d.optim_scale]
end

# Extend `Random.rand` explicitly; an unqualified definition would create a new
# module-local `rand` that generic sampling code never calls.
Random.rand(rng::Random.AbstractRNG, d::NormalDist) =
    d.loc + σ(d) * randn(rng)

quantile(d::NormalDist, p::Real) =
    d.loc + σ(d) * √2 * erfinv(2p - 1)

end # module NormalDistribution

Main.NormalDistribution

In [6]:
module CustomGamma

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

# Which bijection maps the unconstrained optimisation parameter to a positive rate.
const scaling_function = "softplus"
# `const` keeps calls through these bindings type-stable.
const positive_function, positive_inverse = if scaling_function == "softplus"
    (softplus, inverse_softplus)
elseif scaling_function == "exponential"
    (exp, log)
else
    error("Scaling function not recognized")
end

"""
    GammaDist(concentration, rate; kwargs...)

Gamma distribution parameterised by concentration (shape) and rate. The rate is
stored in unconstrained form (`optim_rate = positive_inverse(rate)`); the
concentration is stored as given.

Any pair of `Real` arguments is accepted and stored in the promoted floating-point
type. Keyword arguments such as `validate_args` are accepted and ignored.
"""
struct GammaDist{T<:Real} <: ContinuousUnivariateDistribution
    concentration::T
    optim_rate::T
    function GammaDist(concentration::Real, rate::Real; kwargs...)
        # Integer inputs would make `new{Int}` fail on the transformed rate.
        T = promote_type(float(typeof(concentration)), float(typeof(rate)))
        return new{T}(concentration, positive_inverse(rate))
    end
end

# Positive rate recovered from the unconstrained parameter.
_rate(d::GammaDist) = positive_function(d.optim_rate)

# `Distributions.Gamma` takes (shape, scale), so the rate must be inverted.
# The concentration is stored untransformed and is therefore used as-is.
_gamma(d::GammaDist) = Gamma(d.concentration, inv(_rate(d)))

# `x::Real` avoids dispatch ambiguities with the generic array methods.
function logpdf(d::GammaDist, x::Real)
    return logpdf(_gamma(d), x)
end

# Constrained parameters `[concentration, rate]`.
function params(d::GammaDist)
    return [d.concentration, _rate(d)]
end

# Parameters handed to the optimiser `[concentration, optim_rate]`.
function optim_params(d::GammaDist)
    return [d.concentration, d.optim_rate]
end

# Sampling delegates to the equivalent `Distributions.Gamma`.
Random.rand(rng::Random.AbstractRNG, d::GammaDist) = rand(rng, _gamma(d))

end # module CustomGamma


Main.CustomGamma

In [7]:
module CustomExponential

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

# Which bijection maps the unconstrained optimisation parameter to a positive rate.
const scaling_function = "softplus"
# `const` keeps calls through these bindings type-stable.
const positive_function, positive_inverse = if scaling_function == "softplus"
    (softplus, inverse_softplus)
elseif scaling_function == "exponential"
    (exp, log)
else
    error("Scaling function not recognized")
end

"""
    ExponentialDist(rate; kwargs...)

Exponential distribution parameterised by its rate, stored in unconstrained form
(`optim_rate = positive_inverse(rate)`).

Any `Real` rate is accepted and stored as a floating-point value. Keyword arguments
such as `validate_args` are accepted and ignored.
"""
struct ExponentialDist{T<:Real} <: ContinuousUnivariateDistribution
    optim_rate::T
    function ExponentialDist(rate::Real; kwargs...)
        # Integer input would make `new{Int}` fail on the transformed rate.
        T = float(typeof(rate))
        return new{T}(positive_inverse(rate))
    end
end

# Positive rate recovered from the unconstrained parameter.
_rate(d::ExponentialDist) = positive_function(d.optim_rate)

# `Distributions.Exponential` takes the scale (mean), i.e. the inverse rate.
_exponential(d::ExponentialDist) = Exponential(inv(_rate(d)))

# `x::Real` avoids dispatch ambiguities with the generic array methods.
function logpdf(d::ExponentialDist, x::Real)
    return logpdf(_exponential(d), x)
end

# Constrained parameters `[rate]`.
function params(d::ExponentialDist)
    return [_rate(d)]
end

# Unconstrained parameters `[optim_rate]`.
function optim_params(d::ExponentialDist)
    return [d.optim_rate]
end

# Sampling delegates to the equivalent `Distributions.Exponential`.
Random.rand(rng::Random.AbstractRNG, d::ExponentialDist) = rand(rng, _exponential(d))

end # module CustomExponential


Main.CustomExponential

In [8]:
module CustomDirichlet

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

# Which bijection maps unconstrained optimisation parameters to positive values.
const scaling_function = "softplus"
# `const` keeps calls through these bindings type-stable.
const positive_function, positive_inverse = if scaling_function == "softplus"
    (softplus, inverse_softplus)
elseif scaling_function == "exponential"
    (exp, log)
else
    error("Scaling function not recognized")
end

"""
    DirichletDist(concentration; kwargs...)

Dirichlet distribution whose concentration vector is stored element-wise in
unconstrained form (`optim_concentration = positive_inverse.(concentration)`).

Keyword arguments such as `validate_args` are accepted and ignored.
"""
struct DirichletDist{T<:AbstractVector} <: ContinuousMultivariateDistribution
    optim_concentration::T
    function DirichletDist(concentration::AbstractVector; kwargs...)
        optim = positive_inverse.(concentration)
        # Parameterise on the transformed vector's type: the input may be an
        # integer or untyped vector that cannot hold the transformed values.
        return new{typeof(optim)}(optim)
    end
end

# Positive concentration vector recovered from the unconstrained parameters.
_concentration(d::DirichletDist) = positive_function.(d.optim_concentration)

# The vector signature is strictly more specific than the generic multivariate
# method, so there is no dispatch ambiguity. The call is unqualified because
# only individual names (not the `Distributions` module itself) are imported.
function logpdf(d::DirichletDist, x::AbstractVector{<:Real})
    return logpdf(Dirichlet(_concentration(d)), x)
end

# Constrained parameters, wrapped in a one-element vector.
function params(d::DirichletDist)
    return [_concentration(d)]
end

# Unconstrained concentration vector.
function optim_params(d::DirichletDist)
    return d.optim_concentration
end

# Sampling delegates to the equivalent `Distributions.Dirichlet`.
Random.rand(rng::Random.AbstractRNG, d::DirichletDist) =
    rand(rng, Dirichlet(_concentration(d)))

end # module CustomDirichlet


Main.CustomDirichlet

In [9]:
module CustomBeta

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

# Which bijection maps unconstrained optimisation parameters to positive values.
const scaling_function = "softplus"
# `const` keeps calls through these bindings type-stable.
const positive_function, positive_inverse = if scaling_function == "softplus"
    (softplus, inverse_softplus)
elseif scaling_function == "exponential"
    (exp, log)
else
    error("Scaling function not recognized")
end

"""
    BetaDist(concentration1, concentration0; kwargs...)

Beta distribution whose two concentrations are stored in unconstrained form via
`positive_inverse`.

Any pair of `Real` arguments is accepted and stored in the promoted floating-point
type. Keyword arguments such as `validate_args` are accepted and ignored.
"""
struct BetaDist{T<:Real} <: ContinuousUnivariateDistribution
    optim_concentration1::T
    optim_concentration0::T
    function BetaDist(concentration1::Real, concentration0::Real; kwargs...)
        # Integer inputs would make `new{Int}` fail on the transformed values.
        T = promote_type(float(typeof(concentration1)), float(typeof(concentration0)))
        return new{T}(positive_inverse(concentration1), positive_inverse(concentration0))
    end
end

# Equivalent `Distributions.Beta` built from the constrained parameters.
_beta(d::BetaDist) = Beta(
    positive_function(d.optim_concentration1),
    positive_function(d.optim_concentration0),
)

# `x::Real` avoids dispatch ambiguities with the generic array methods.
function logpdf(d::BetaDist, x::Real)
    return logpdf(_beta(d), x)
end

# Constrained parameters `[concentration1, concentration0]`.
function params(d::BetaDist)
    return [positive_function(d.optim_concentration1), positive_function(d.optim_concentration0)]
end

# Unconstrained parameters.
function optim_params(d::BetaDist)
    return [d.optim_concentration1, d.optim_concentration0]
end

# Sampling delegates to the equivalent `Distributions.Beta`.
Random.rand(rng::Random.AbstractRNG, d::BetaDist) = rand(rng, _beta(d))

end # module CustomBeta


Main.CustomBeta

In [10]:
module CustomCategorical

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

"""
    CategoricalDist(probs; kwargs...)

Categorical distribution stored as normalised log-probabilities (`logits`).
`probs` may be unnormalised non-negative weights.

Keyword arguments such as `validate_args` are accepted and ignored.
"""
struct CategoricalDist{T<:AbstractVector} <: DiscreteUnivariateDistribution
    logits::T
    function CategoricalDist(probs::AbstractVector; kwargs...)
        raw = log.(probs)
        logits = raw .- logsumexp(raw)
        # Parameterise on the log-probability vector's own type: the input may be
        # an integer or untyped vector.
        return new{typeof(logits)}(logits)
    end
end

# Categories are treated as zero-based (0, 1, …, K-1), matching the file's
# `zero_based_categorical` primitive.
# NOTE(review): confirm zero-based labels are what callers expect.
function logpdf(d::CategoricalDist, x::Real)
    (isinteger(x) && 0 <= x < length(d.logits)) || return -Inf
    return d.logits[Int(x) + 1]
end

# Normalised log-probabilities, wrapped in a one-element vector.
function params(d::CategoricalDist)
    return [d.logits]
end

# Same as `params`: the logits are already unconstrained.
function optim_params(d::CategoricalDist)
    return [d.logits]
end

# `Distributions.Categorical` is one-based; shift the draw to the zero-based labels.
Random.rand(rng::Random.AbstractRNG, d::CategoricalDist) =
    rand(rng, Categorical(exp.(d.logits))) - 1

end # module CustomCategorical


Main.CustomCategorical

In [11]:
module DiracDeltaDistribution

import Random: rand
using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

"""
    DeltaDistribution(x)

Point mass at `x`: sampling always returns `x`.
"""
struct DeltaDistribution{T}
    x::T
end

rand(::Random.AbstractRNG, dist::DeltaDistribution) = dist.x
rand(dist::DeltaDistribution)                      = dist.x

# Log probability for the delta distribution: `Inf` at the atom, `-Inf` elsewhere.
# NOTE(review): adding `Inf` to a running log-probability makes later differences
# `NaN`; confirm `Inf` (rather than `0.0`) is the intended value at the atom.
function logpdf(d::DeltaDistribution, x)
    return d.x == x ? Inf : -Inf
end

"""
    dirac_delta_distribution(x...; scheme="normal", kwargs...)

Return a distribution concentrated around `x[1]`:

- `"normal"`:  `Normal(x[1], 0.1)`
- `"uniform"`: `Uniform(x[1] - 0.05, x[1] + 0.05)`
- `"delta"`:   `DeltaDistribution(x[1])`

Extra keyword arguments (e.g. `validate_args`, which the evaluator passes to every
distribution primitive) are accepted and ignored. Throws `ArgumentError` when no
positional argument is given and an `ErrorException` for an unknown `scheme`.
"""
function dirac_delta_distribution(x...; scheme="normal", kwargs...)
    isempty(x) && throw(ArgumentError("dirac_delta_distribution needs a location argument"))
    centre = x[1]
    if scheme == "normal"
        return Normal(centre, 0.1)
    elseif scheme == "uniform"
        return Uniform(centre - 0.05, centre + 0.05)
    elseif scheme == "delta"
        return DeltaDistribution(centre)
    else
        error("Dirac delta scheme not recognized")
    end
end

end # module DiracDeltaDistribution


Main.DiracDeltaDistribution

In [12]:
module CustomBernoulli

using Distributions: logpdf,
    Normal, Beta, Exponential, Uniform, Gamma, Dirichlet, Categorical, Bernoulli,
    ContinuousUnivariateDistribution, DiscreteUnivariateDistribution,
    ContinuousMultivariateDistribution, DiscreteMultivariateDistribution
import Distributions: logpdf, Normal
using Main.MathOps: softplus, inverse_softplus, logsumexp
using Random

"""
    BernoulliDist(probs; kwargs...)

Bernoulli distribution stored as the log-odds `logits = log(probs / (1 - probs))`.

Keyword arguments such as `validate_args` are accepted and ignored.
"""
struct BernoulliDist{T<:Real} <: DiscreteUnivariateDistribution
    logits::T
    function BernoulliDist(probs::Real; kwargs...)
        logit = log(probs / (1 - probs))
        # Parameterise on the log-odds type: an integer `probs` (0 or 1) yields a
        # floating-point (infinite) logit that `new{Int}` could not store.
        return new{typeof(logit)}(logit)
    end
end

# Success probability recovered from the log-odds (logistic function).
_prob(d::BernoulliDist) = inv(one(d.logits) + exp(-d.logits))

# `Distributions.Bernoulli` is constructed from a probability; it has no `logits`
# keyword. `x::Real` avoids dispatch ambiguities with the generic array methods.
function logpdf(d::BernoulliDist, x::Real)
    return logpdf(Bernoulli(_prob(d)), x)
end

# Log-odds, as a one-element vector.
function params(d::BernoulliDist)
    return [d.logits]
end

# Same as `params`: the log-odds are already unconstrained.
function optim_params(d::BernoulliDist)
    return [d.logits]
end

# Sampling delegates to the equivalent `Distributions.Bernoulli`.
Random.rand(rng::Random.AbstractRNG, d::BernoulliDist) = rand(rng, Bernoulli(_prob(d)))

end # module CustomBernoulli


Main.CustomBernoulli

# Main

In [13]:
module DistributionsPrep

using Distributions: Normal, Beta, Exponential, Uniform, Categorical, Bernoulli, Gamma, Dirichlet
using StatsFuns: logsumexp

using Random

# Import custom distributions
using Main.NormalDistribution: NormalDist
using Main.CustomGamma: GammaDist
using Main.CustomBeta: BetaDist
using Main.CustomExponential: ExponentialDist
using Main.CustomDirichlet: DirichletDist
using Main.CustomCategorical: CategoricalDist
using Main.CustomBernoulli: BernoulliDist
using Main.DiracDeltaDistribution: dirac_delta_distribution

# List of all supported distributions
const distributions = [
    "normal",
    "beta",
    "exponential",
    "uniform-continuous",
    "discrete",
    "bernoulli",
    "gamma",
    "dirichlet",
    "flip",
    "dirac"
]

# `Distributions.Uniform` has no `validate_args` keyword, but the evaluator passes
# that keyword to every distribution primitive; swallow extra keywords here.
_uniform(a, b; kwargs...) = Uniform(a, b)

# Dictionary mapping distribution names to Julia distribution constructors
const distribution_constructors = Dict(
    "normal" => NormalDist,
    "beta" => BetaDist,
    "exponential" => ExponentialDist,
    "uniform-continuous" => _uniform,
    "discrete" => CategoricalDist,
    "bernoulli" => BernoulliDist,
    "gamma" => GammaDist,
    "dirichlet" => DirichletDist,
    "flip" => BernoulliDist,
    "dirac" => dirac_delta_distribution
)

# Starting parameters for distributions for necessary cases
# (there is no entry for "dirac").
const distribution_params = Dict(
    "normal-params" => (0.0, 1.0),
    "beta-params" => (1.0, 1.0),
    "exponential-params" => (1.0,),
    "uniform-continuous-params" => (0.0, 1.0),
    "discrete-params" => ([1/3, 1/3, 1/3],),
    "bernoulli-params" => (0.5,),
    "gamma-params" => (1.0, 1.0),
    "dirichlet-params" => ([1.0, 1.0, 1.0],),
    "flip-params" => (0.5,)
)

end # module DistributionsPrep


Main.DistributionsPrep

In [None]:
module BasicPrimitives

using Random
using LinearAlgebra
using Distributions
import Main.DistributionsPrep: distribution_constructors

# helper functions

# Gather the positional arguments into a vector (element type as chosen by `collect`).
function vector(xs...)
    return collect(xs)
end

# Zero-based lookup for arrays: the index is converted with `Int` and shifted by one.
_get(container::AbstractArray, idx) = container[Int(idx)+1]

# Dictionaries are looked up by key, unchanged.
_get(container::Dict, idx) = container[idx]

# Anything else is unsupported.
_get(container, idx) = error("`get` not defined for $(typeof(container))")

# Zero-based in-place store for arrays; returns the (mutated) container.
function _put(container::AbstractArray, idx, val)
    container[Int(idx)+1] = val
    return container
end

# In-place store by key for dictionaries; returns the (mutated) container.
function _put(container::Dict, idx, val)
    container[idx] = val
    return container
end

# Anything else is unsupported.
_put(container, idx, val) = error("`put` not defined for $(typeof(container))")

# Append `v` as one new trailing element, returning a new vector.
# `push!` would throw when `v` does not fit the vector's element type (e.g. adding
# 2.5 to a `Vector{Int}`) and would also mutate the vector held in the caller's
# environment; `vcat` widens the element type as needed and leaves `a` untouched.
_append(a::AbstractVector, v) = vcat(a, [v])

# Build a dictionary from either `key => value` pairs or a flat
# `key, value, key, value, …` argument list.
function hashmap(kv...)
    all(item -> item isa Pair, kv) && return Dict(kv...)
    iseven(length(kv)) ||
        throw(ArgumentError("hash-map expects an even number of key/value arguments"))
    return Dict{Any,Any}(kv[i] => kv[i+1] for i in 1:2:length(kv))
end

# Short-circuit boolean connectives exposed as ordinary functions.
_and(a, b) = a && b
_or(a, b) = a || b

# Element-wise hyperbolic tangent.
_mat_tanh(M) = tanh.(M)

# Convert a nested vector to a proper Matrix{Float64}
"""
    to_matrix(V)

Return `V` unchanged when it is already a matrix; otherwise treat `V` as a vector
of equally long rows and copy it into a `Matrix{Float64}`.

Throws `DimensionMismatch` when the rows differ in length.
"""
function to_matrix(V)
    V isa AbstractMatrix && return V
    rows = length(V)
    cols = length(V[1])
    # Validate the shape up front instead of relying on `@inbounds`, which would
    # read out of bounds on ragged input.
    for i in 1:rows
        length(V[i]) == cols || throw(DimensionMismatch(
            "row $i has $(length(V[i])) entries, expected $cols"))
    end
    M = Array{Float64}(undef, rows, cols)
    for i in 1:rows, j in 1:cols
        M[i, j] = Float64(V[i][j])
    end
    return M
end

# Create a zero-based categorical distribution
"""
    zero_based_categorical(p)

Build a `DiscreteNonParametric` over the labels `0, 1, …, length(p) - 1` with
probabilities `p`.
"""
function zero_based_categorical(p::AbstractVector{<:Real})
    last_label = length(p) - 1
    support = collect(0:last_label)
    return DiscreteNonParametric(support, p)
end


# Table of primitive operations, keyed by their name in the source language.
# A `Dict` built from duplicate keys keeps the LAST entry, so the generic
# distribution constructors are listed first and the hand-written "normal" /
# "discrete" primitives afterwards; otherwise the splat would silently replace them.
const primitives = Dict{String,Any}(

    # logic & comparison
    "<"   => (<),
    "<="  => (<=),
    ">"   => (>),
    ">="  => (>=),
    "="   => (==),
    "and" => _and,
    "or"  => _or,

    # arithmetic
    "+"   => (+),
    "-"   => (-),
    "*"   => (*),
    "/"   => (/),
    "exp" => exp,
    "sqrt"=> sqrt,
    "abs" => abs,

    # container ops
    "vector" => vector,
    "get"    => _get,
    "put"    => _put,
    "append" => _append,
    "first"  => xs -> xs[1],
    "second" => xs -> xs[2],
    "last"   => xs -> xs[end],
    "rest"   => xs -> xs[2:end],
    "hash-map" => hashmap,

    # matrix ops
    "mat-mul"       => (A,B)   -> to_matrix(A) * to_matrix(B),
    "mat-add"       => (A,B)   -> to_matrix(A) .+ to_matrix(B),
    "mat-transpose" => M       -> transpose(to_matrix(M)),
    "mat-repmat"    => (M,r,c) -> repeat(to_matrix(M), (Int(r),Int(c))),
    "mat-tanh"      => _mat_tanh,

    # all distribution constructors and their guides
    distribution_constructors...,
    [k * "-guide" => v for (k,v) in distribution_constructors]...,

    # “normal” primitive (overrides the generic constructor entry above)
    "normal"        => (μ, σ; kwargs...) -> Normal(float(μ), float(σ)),

    # discrete / categorical (override the generic constructor entries above)
    "discrete"       => (p; kwargs...) -> zero_based_categorical(p),
    "discrete-guide" => (p; kwargs...) -> zero_based_categorical(p),

    # domain-specific stub
    "oneplanet" => (args...)->error("`oneplanet` not implemented in Julia backend.")
)

export primitives

end # module


Main.BasicPrimitives

# Inference

In [15]:
module EvaluationBasedSampling

using Random
using Statistics
using LinearAlgebra
using Main.BasicPrimitives: primitives
using Main.DistributionsPrep: distributions
using Distributions

export AbstractSyntaxTree, eval, bind_functions, evaluate_program

# Define the AbstractSyntaxTree type
"""
    AbstractSyntaxTree(ast_json)

Split a parsed program into its leading entries (`functions`) and its final entry
(`program`).

Accepts any non-empty vector; throws `ArgumentError` when `ast_json` is empty.
"""
struct AbstractSyntaxTree
    functions::Vector{Any}
    program::Any

    function AbstractSyntaxTree(ast_json::AbstractVector)
        isempty(ast_json) &&
            throw(ArgumentError("AST must contain at least the program expression"))
        return new(ast_json[1:end-1], ast_json[end])
    end
end

# Evaluate function
#
# Recursively evaluate expression `e` and return its value.
#   e    - a number/boolean, a variable name (String), or an array whose first
#          element selects the form ("let", "if", "sample", "observe", …) or
#          names a function/primitive.
#   sig  - dictionary holding the running "logP" and "logW" totals; updated in
#          place by the sample/observe forms.
#   l    - environment mapping variable names to values; "let" writes into it.
#   rho  - user-defined functions, name => (parameter names, body).
#   verbose - print the expression and result for THIS call only; the recursive
#          calls below do not forward the flag.
function eval(e, sig, l, rho=Dict{String,Any}(); verbose=false)
    if verbose println("Expression (before): ", e) end

    # Literals evaluate to themselves.
    if e isa Number || e isa Bool
        result = e

    # A string is a variable reference looked up in the environment.
    elseif e isa String
        result = l[e]

    elseif e isa Array
        if e[1] == "defn"
            error("This defn case should never happen!")

        elseif e[1] == "let"
            # e[2] is a [name, expression] pair; the binding is stored directly
            # in `l` before the body e[3] is evaluated.
            expression, name = e[2][2], e[2][1]
            c1 = eval(expression, sig, l, rho)
            l[name] = c1
            result = eval(e[3], sig, l, rho)

        elseif e[1] == "if"
            # Only the selected branch is evaluated.
            e1 = eval(e[2], sig, l, rho)
            result = e1 ? eval(e[3], sig, l, rho) : eval(e[4], sig, l, rho)

        elseif e[1] in ["sample", "sample*"]
            # Draw from the distribution and add its log-density to logP only.
            d = eval(e[2], sig, l, rho)
            s = rand(d)
            log_prob = logpdf(d, s)
            sig["logP"] += log_prob
            result = s

        elseif e[1] in ["observe", "observe*"]
            # Score the observed value; it contributes to both logP and logW.
            d = eval(e[2], sig, l, rho)
            y = eval(e[3], sig, l, rho)
            log_prob = logpdf(d, y)
            sig["logP"] += log_prob
            sig["logW"] += log_prob
            result = y

        else
            # Function / primitive application: arguments are evaluated first,
            # left to right.
            cs = [eval(element, sig, l, rho) for element in e[2:end]]

            if e[1] isa Array
                println("List: ", e[1])
                error("This list case should never happen!")

            elseif (e[1] isa String) && (haskey(rho, e[1]))
                # User-defined function: evaluate the body in a deep copy of the
                # environment extended with the argument bindings.
                variables, function_body = rho[e[1]]
                func_env = deepcopy(l)
                for (variable, exp) in zip(variables, cs)
                    func_env[variable] = exp
                end
                func_env[e[1]] = function_body
                result = eval(function_body, sig, func_env, rho)

            elseif (e[1] isa String) && (e[1] in distributions) && (haskey(primitives, e[1]))
                # Distribution primitives additionally receive `validate_args=false`.
                result = primitives[e[1]](cs...; validate_args=false)

            elseif (e[1] isa String) && (haskey(primitives, e[1]))
                result = primitives[e[1]](cs...)

            else
                println("List expression not recognised: ", e)
                error("List expression not recognised")
            end
        end

    else
        println("Expression not recognised: ", e)
        error("Expression not recognised")
    end

    if verbose
        println("Expression (after): ", e)
        println("Result: ", result, typeof(result))
    end

    return result
end

# Bind functions
"""
    bind_functions(ast)

Collect every `"defn"` entry of `ast.functions` into a dictionary mapping the
function name to the tuple `(entry[3], entry[4])`.
"""
function bind_functions(ast::AbstractSyntaxTree)
    table = Dict{String,Any}()
    definitions = Iterators.filter(entry -> entry[1] == "defn", ast.functions)
    for entry in definitions
        name = entry[2]
        table[name] = (entry[3], entry[4])
    end
    return table
end

# Evaluate program
"""
    evaluate_program(ast; verbose=false)

Evaluate `ast.program` with fresh "logW"/"logP" totals and an empty environment.

Returns `(value, sig, environment)`.
"""
function evaluate_program(ast::AbstractSyntaxTree; verbose=false)
    user_functions = bind_functions(ast)
    sig = Dict("logW" => 0.0, "logP" => 0.0)
    environment = Dict{String,Any}()
    value = eval(ast.program, sig, environment, user_functions; verbose=verbose)
    return value, sig, environment
end

end # module EvaluationBasedSampling

Main.EvaluationBasedSampling

In [16]:
module MiniTopologicalSorter

export TopologicalSorter, static_order

# Topologically sort the graph `arrows` (node => list of children) with Kahn's
# algorithm. Nodes that only appear as children are added to `arrows` with an
# empty child list. Returns the nodes ordered parents → children and throws an
# `ErrorException` when the graph contains a cycle.
function _kahn_order(arrows::Dict)
    # make sure every node appears as a key
    # (snapshot the children first: inserting keys while iterating
    #  `values(arrows)` has undefined behaviour)
    children = Any[child for v in values(arrows) for child in v]
    for child in children
        get!(arrows, child, [])
    end

    # compute in-degrees
    indeg = Dict(n => 0 for n in keys(arrows))
    for child in children
        indeg[child] += 1
    end

    # queue all zero-in-degree nodes
    ready  = [n for n in keys(indeg) if indeg[n] == 0]
    sorted = Any[]

    # Kahn’s loop
    while !isempty(ready)
        n = popfirst!(ready)
        push!(sorted, n)

        for child in arrows[n]
            indeg[child] -= 1
            indeg[child] == 0 && push!(ready, child)
        end
    end

    length(sorted) == length(indeg) ||
        throw(ErrorException("cycle detected – graph is not a DAG"))

    return sorted        # parents → children
end

# Holds a graph (node => children) together with its topological order, which is
# computed once at construction from a deep copy of the input.
struct TopologicalSorter
    arrows::Dict{Any,Vector{Any}}
    order::Vector{Any}         # cached parent→child order
    function TopologicalSorter(arrows::Dict)
        sorted = _kahn_order(deepcopy(arrows))
        return new(arrows, sorted)
    end
end

"""
    static_order(ts::TopologicalSorter)

Return a lazy iterator over the nodes of `ts`, in the order stored in `ts.order`.
"""
function static_order(ts::TopologicalSorter)
    return (node for node in ts.order)
end

end # module MiniTopologicalSorter

Main.MiniTopologicalSorter

In [17]:
module GraphBasedSamplingUtils

using LinearAlgebra
using Printf
using Random
using Statistics
using Distributions: logpdf
using Optimisers
using Optimisers: Adam
using Zygote: gradient
using Main.BasicPrimitives: primitives
using Main.EvaluationBasedSampling: eval
using Main.MathOps: covariance
using Main.MiniTopologicalSorter: TopologicalSorter

export Graph, evaluate_node, evaluate_graph, generate_IC, log_joint

# Define the Graph type
#
# Built from a three-element vector: [functions, body, program], where `body`
# is a dictionary with the keys "V" (nodes), "A" (arrows), "P" (expressions) and
# "Y" (observations). After construction `nodes` holds the result of
# `topological_sort`.
# NOTE(review): `functions` must be convertible to `Vector{Any}`; confirm the
# first element of the incoming JSON is always an array.
mutable struct Graph
    functions::Vector{Any}
    nodes::Vector{Any}
    arrows::Dict{Any, Any}
    expressions::Dict{Any, Any}
    observe::Any
    program::Any

    function Graph(graph_json::Vector{Any})
        functions = graph_json[1]
        body = graph_json[2]
        nodes = body["V"]
        arrows = body["A"]
        expressions = body["P"]
        observations = body["Y"]
        program = graph_json[3]
        graph = new(functions, nodes, arrows, expressions, observations, program)
        graph.nodes = topological_sort(graph)
        return graph
    end
end

"""
    topological_sort(g::Graph; verbose=false)

Return the nodes of `g` ordered so that every node comes after all of its parents
(`g.arrows` maps a node to its children).

Nodes without an entry in `g.arrows` are given an empty child list in place.
Throws an `ErrorException` when the graph has a cycle.
"""
function topological_sort(g::Graph; verbose=false)
    for node in g.nodes
        get!(g.arrows, node, [])
    end
    if verbose println("arrows: ", g.arrows) end

    sorter = TopologicalSorter(g.arrows)
    # `TopologicalSorter` is not itself iterable; its cached `order` is already
    # parents → children, which is the order graph evaluation needs, so it must
    # not be reversed.
    return collect(sorter.order)
end

"""
    split_nodes_into_sample_observe(g::Graph)

Partition `g.nodes` by name: nodes containing `"sample"` first, otherwise nodes
containing `"observe"`. Any other node raises an error.

Returns `(sample_nodes, observe_nodes)`.
"""
function split_nodes_into_sample_observe(g::Graph)
    sample_nodes = []
    observe_nodes = []
    for node in g.nodes
        bucket = if occursin("sample", node)
            sample_nodes
        elseif occursin("observe", node)
            observe_nodes
        else
            error("Node present that is neither sample nor observe")
        end
        push!(bucket, node)
    end
    return sample_nodes, observe_nodes
end

### Evaluation ###

"""
    evaluate_node(node, exp, sig, l; fixed_dists=Dict(), fixed_nodes=Dict(),
                  fixed_probs=Dict(), verbose=false)

Evaluate one graph node whose expression is `exp`, updating the "logP"/"logW"
totals in `sig`, and return the node's value.

- `fixed_dists[node]`: sample from this proposal instead; logP gets the proposal
  log-density and logW the difference between model and proposal log-densities.
- `fixed_nodes[node]` with `fixed_probs[node]`: reuse both value and log-probability.
- `fixed_nodes[node]` only: reuse the value and score it under the node's distribution.
- otherwise: evaluate `exp` normally.

For fixed nodes, logW is only updated when the node name contains `"observe"`.
"""
function evaluate_node(node, exp, sig, l; fixed_dists=Dict(), fixed_nodes=Dict(), fixed_probs=Dict(), verbose=false)
    # Every module owns an implicit one-argument `eval`, which conflicts with an
    # imported `eval`; call the interpreter through its fully qualified name.
    evaluate = Main.EvaluationBasedSampling.eval
    if verbose println("Node: ", node) end
    if haskey(fixed_dists, node)
        result = rand(fixed_dists[node])
        p_log_prob = logpdf(evaluate(exp[2], sig, l), result)
        q_log_prob = logpdf(fixed_dists[node], result)
        sig["logP"] += q_log_prob
        sig["logW"] += p_log_prob - q_log_prob
    elseif haskey(fixed_nodes, node) && haskey(fixed_probs, node)
        result = fixed_nodes[node]
        log_prob = fixed_probs[node]
        sig["logP"] += log_prob
        if occursin("observe", node)
            sig["logW"] += log_prob
        end
    elseif haskey(fixed_nodes, node)
        result = fixed_nodes[node]
        log_prob = logpdf(evaluate(exp[2], sig, l), result)
        sig["logP"] += log_prob
        if occursin("observe", node)
            sig["logW"] += log_prob
        end
    else
        result = evaluate(exp, sig, l, verbose=verbose)
    end
    if verbose println("Value: ", result) end
    return result
end

"""
    evaluate_graph(g::Graph; fixed_dists=Dict(), fixed_nodes=Dict(),
                   fixed_probs=Dict(), verbose=false)

Evaluate every node of `g` in the order of `g.nodes`, then the return expression
`g.program`.

Each node's value is stored in the environment under its name, and the change it
caused in "logP" under `node * "_logP"`.

Returns `(result, sig, environment)`.
"""
function evaluate_graph(g::Graph; fixed_dists=Dict(), fixed_nodes=Dict(), fixed_probs=Dict(), verbose=false)
    if verbose println(g) end

    sig = Dict("logW" => 0.0, "logP" => 0.0)
    l = Dict{Any, Any}()
    for node in g.nodes
        exp = g.expressions[node]
        original_logP = sig["logP"]
        result = evaluate_node(node, exp, sig, l, fixed_dists=fixed_dists, fixed_nodes=fixed_nodes, fixed_probs=fixed_probs, verbose=verbose)
        l[node] = result
        l[node * "_logP"] = sig["logP"] - original_logP
    end

    # Every module owns an implicit one-argument `eval`, which conflicts with an
    # imported `eval`; call the interpreter through its fully qualified name.
    result = Main.EvaluationBasedSampling.eval(g.program, sig, l, verbose=verbose)
    if verbose println("Result: ", result) end
    return result, sig, l
end

### Hamiltonian Monte Carlo ###

"""
    generate_IC(g::Graph; verbose=false)

Evaluate the graph once and return the values of all nodes whose name contains
`"sample"`, in the order of `g.nodes`.
"""
function generate_IC(g::Graph; verbose=false)
    environment = evaluate_graph(g, verbose=verbose)[3]
    start = [environment[name] for name in g.nodes if occursin("sample", name)]
    verbose && println("Initial conditions: ", start)
    return start
end

"""
    log_joint(g::Graph, x; verbose=false)

Evaluate the graph with its sample nodes pinned to the entries of `x` (assigned
in the order the sample nodes occur in `g.nodes`) and return the accumulated
"logP".
"""
function log_joint(g::Graph, x; verbose=false)
    pinned = Dict()
    sample_names = Iterators.filter(name -> occursin("sample", name), g.nodes)
    for (position, name) in enumerate(sample_names)
        pinned[name] = x[position]
    end
    sig = evaluate_graph(g, fixed_nodes=pinned, verbose=verbose)[2]
    return sig["logP"]
end

### Variational Inference Utilities ###

"""
    save_parameters(parameters, variationals)

Append a snapshot of the current parameters of every distribution in
`variationals` to `parameters` (one flat vector per call) and return `parameters`.
"""
function save_parameters(parameters::Vector, variationals::Dict)
    params_here = []
    for dist in values(variationals)
        # `dist.params()` would be a field access, and the distribution structs have
        # no such field. Each distribution module defines its own `params`
        # function, so look it up in the module that owns the type.
        params_of = getproperty(parentmodule(typeof(dist)), :params)
        append!(params_here, deepcopy(params_of(dist)))
    end
    push!(parameters, params_here)
    return parameters
end

# Compute the scalar `b = cov(F, G) / var(G)` for `node` over a batch, where G is
# a per-sample gradient and F = G * logW; returns 0.0 when `zero` is true or when
# the variance term is exactly zero.
#   logQs - one entry per sample, indexed by node name
#   logWs - one entry per sample
# NOTE(review): the keyword `zero` shadows `Base.zero` inside this function.
function calculate_b(node::String, variational, logQs::Vector, logWs::Vector; zero=false)
    if zero
        b = 0.0
    else
        Fs, Gs = [], []
        for (logQ, logW) in zip(logQs, logWs)
            Q = logQ[node]
            # NOTE(review): `variational.optim_params()` is field-call syntax; the
            # distribution structs in this file expose `optim_params` as a
            # function, not a field — confirm this call can succeed.
            # NOTE(review): the closure returns the already-computed value `Q` and
            # takes no arguments — confirm `gradient` can produce a useful result.
            grads = gradient(() -> Q, variational.optim_params())[1]
            G = length(grads) == 1 ? grads[1] : collect(grads)
            # NOTE(review): `zero!` and a `.grad` property are not defined anywhere
            # in the visible code — verify.
            for param in variational.optim_params()
                zero!(param.grad)
            end
            F = G * logW
            push!(Fs, F)
            push!(Gs, G)
        end
        # Stack the per-sample values before taking covariances.
        Fs = vcat(Fs...)
        Gs = vcat(Gs...)
        cov_FG = sum(covariance(Fs, Gs))
        var_GG = sum(covariance(Gs, Gs))
        # Guard against division by zero.
        b = var_GG == 0 ? 0.0 : cov_FG / var_GG
    end
    return b
end

# Accumulate, over all `nodes`, the batch-averaged ELBO term
# (-logQ[node] * logW) and loss term (-logQ[node] * (logW - b)), then attempt one
# optimiser step on the loss and return the ELBO total.
#   zero_b - forwarded to `calculate_b` as `zero`
function update_parameters(nodes::Vector{String}, variationals::Dict, logQs::Vector, logWs::Vector, optimizer; zero_b=false)
    total_ELBO, total_loss = 0.0, 0.0
    batch_size = length(logQs)
    for node in nodes
        b = calculate_b(node, variationals[node], logQs, logWs; zero=zero_b)
        ELBO, loss = 0.0, 0.0
        for (logQ, logW) in zip(logQs, logWs)
            ELBO -= logQ[node] * logW
            loss -= logQ[node] * (logW - b)
        end
        # Average over the batch.
        ELBO /= batch_size
        loss /= batch_size
        total_ELBO += ELBO
        total_loss += loss
    end
    # NOTE(review): `optim` is not defined in this function (the parameter is
    # named `optimizer`) — this looks like it raises UndefVarError; confirm.
    # NOTE(review): the closure returns the already-computed `total_loss` and takes
    # no arguments — confirm `gradient` can produce a useful result here.
    gs = gradient(() -> total_loss, optim.params)[1]
    # NOTE(review): `initialize_optimizer` returns a NamedTuple, which does not
    # support field assignment — verify what type is expected here.
    optim.state = Optimisers.update!(optim.opt, optim.state, optim.params, gs)
    return deepcopy(total_ELBO)
end

"""
    initialize_optimizer(variationals, learning_rate)

Gather the unconstrained parameters of every distribution in `variationals` into
one flat `Vector{Float64}` and set up an `Adam` optimiser for it.

Returns the named tuple `(opt, state, params)`.
"""
function initialize_optimizer(variationals::Dict, learning_rate::Float64)
    # Float64 storage: a Float32 buffer would silently round the parameters.
    all_params = Float64[]
    for dist in values(variationals)
        # `dist.optim_params()` would be a field access, and the distribution
        # structs have no such field. Each distribution module defines its own
        # `optim_params` function, so look it up in the module that owns the type.
        optim_params_of = getproperty(parentmodule(typeof(dist)), :optim_params)
        for p in optim_params_of(dist)
            # `p` may be a scalar or a vector; `append!` flattens either.
            append!(all_params, p)
        end
    end
    opt = Adam(learning_rate)
    state = Optimisers.setup(opt, all_params)
    return (opt=opt, state=state, params=all_params)
end

end # module GraphBasedSamplingUtils




Main.GraphBasedSamplingUtils

In [18]:
module GibbsSampling

using Random
using Statistics
using LinearAlgebra
using Printf
using Main.GraphBasedSamplingUtils: Graph, evaluate_graph, split_nodes_into_sample_observe, evaluate_node
using Main.SampleHandling: burn_chain, flatten_sample
#using Main.Plotting: log_sample

export Gibbs_samples

### MH within Gibbs ###

"""
    Gibbs_samples(g::Graph, num_samples; tmax=nothing, burn_frac=nothing,
                  wandb_name=nothing, debug=false, verbose=false)

Metropolis-Hastings-within-Gibbs sampler over the sample nodes of `g`.

The first sweep evaluates the whole graph. Every later sweep proposes, for each
sample node in turn, a fresh value from that node's expression and accepts or
rejects it with the Metropolis-Hastings ratio. One sample is recorded per sweep.

# Keywords
- `tmax`: stop once this many seconds have elapsed (checked after each sweep).
- `burn_frac`: fraction of the leading sweeps to discard.
- `wandb_name`: when set, each step is passed to `log_sample`.
- `debug`: print proposal details and stop after the first proposal.
- `verbose`: forwarded to `evaluate_graph`.

Returns `(samples, weights)`, with every weight equal to 1.0.
"""
function Gibbs_samples(g::Graph, num_samples; tmax=nothing, burn_frac=nothing, wandb_name=nothing, debug=false, verbose=false)
    # `eval` is not imported into this module, and every module owns an implicit
    # one-argument `eval`; call the interpreter through its fully qualified name.
    evaluate = Main.EvaluationBasedSampling.eval
    sample_nodes, _ = split_nodes_into_sample_observe(g)

    samples, weights = [], []
    accepted_small_steps = 0; num_small_steps = 0; num_big_steps = 0
    max_time = tmax !== nothing ? time() + tmax : nothing

    # The chain state must outlive a single loop iteration: variables first
    # assigned inside a `for` body are local to that iteration.
    local result, sig, l

    for i in 1:num_samples
        if i == 1
            result, sig, l = evaluate_graph(g, verbose=verbose)
        else
            for resample_node in sample_nodes
                resample_logP = l[resample_node * "_logP"]
                # Propose on copies so a rejected move leaves the state untouched.
                sig_here, l_here = deepcopy(sig), deepcopy(l)
                d = evaluate(g.expressions[resample_node], sig_here, l_here)
                resample_logP_new = sig_here["logP"] - sig["logP"]
                # `Dict{Any,Any}`: the other nodes' values need not share the type
                # of the proposed value.
                fixed_nodes = Dict{Any,Any}(resample_node => d)
                fixed_probs = Dict{Any,Any}(resample_node => resample_logP_new)
                if debug
                    println("Original node value: ", l[resample_node])
                    println("Original node logP: ", resample_logP)
                    println("Resampled node value: ", d)
                    println("Resampled node logP: ", resample_logP_new)
                end

                # Keep every other node at its current value; only the children of
                # the resampled node have their log-probability recomputed.
                for node in g.nodes
                    if node != resample_node
                        fixed_nodes[node] = l[node]
                        if !(node in g.arrows[resample_node])
                            fixed_probs[node] = l[node * "_logP"]
                        end
                    end
                end
                if debug
                    println("Fixed nodes: ", fixed_nodes)
                    println("Fixed probabilities: ", fixed_probs)
                end
                result_new, sig_new, l_new = evaluate_graph(g, fixed_nodes=fixed_nodes, fixed_probs=fixed_probs, verbose=verbose)
                if debug
                    println("Old sig: ", sig)
                    println("New sig: ", sig_new)
                    println("Old environment: ", l)
                    println("New environment: ", l_new)
                end

                acceptance = exp(sig_new["logP"] - sig["logP"] - resample_logP_new + resample_logP)
                alpha = min(1.0, acceptance)
                accept = rand() < alpha
                if accept
                    result, sig, l = result_new, sig_new, l_new
                    accepted_small_steps += 1
                end
                if wandb_name !== nothing
                    # NOTE(review): `log_sample` is not defined or imported in the
                    # visible code — confirm where it comes from.
                    log_sample(result, i, wandb_name)
                end
                num_small_steps += 1
                if debug
                    break
                end
            end
            if debug
                break
            end
        end

        num_big_steps += 1
        push!(samples, result)
        push!(weights, 1.0)
        if tmax !== nothing && time() > max_time
            break
        end
    end

    @printf("Acceptance fraction: %.3f\n", accepted_small_steps / num_small_steps)
    println("Number of samples: ", num_big_steps)
    if burn_frac !== nothing
        println("Burn fraction: ", burn_frac)
        # `Int(x)` throws unless the product is exactly integral; round down.
        nburn = floor(Int, burn_frac * num_big_steps)
        println("Burning up to: ", nburn)
        samples, weights = burn_chain(samples, weights, burn_frac)
    end

    return samples, weights
end

end # module GibbsSampling


Main.GibbsSampling

In [19]:
module GeneralSampling

using Random
using Statistics
using LinearAlgebra
using Printf
using StatsBase: sample, Weights
using Main.EvaluationBasedSampling: evaluate_program
using Main.GraphBasedSamplingUtils: evaluate_graph
using Main.SampleHandling: burn_chain, flatten_sample
# using Main.Plotting: log_sample

export get_sample, prior_samples, calculate_effective_sample_size, resample_using_importance_weights, Metropolis_Hastings_samples

"""
    get_sample(ast_or_graph, mode::String; verbose=false)

Evaluate `ast_or_graph` once and return `(sample, sig)`.

`mode` selects the evaluator: `"desugar"` calls `evaluate_program`, `"graph"` calls
`evaluate_graph`; any other value raises an error. The first value returned by the
evaluator is passed through `flatten_sample`; the second is returned unchanged.
"""
function get_sample(ast_or_graph, mode::String; verbose=false)
    # Reject unknown modes before touching either evaluator.
    mode in ("desugar", "graph") || error("Mode not recognised")
    evaluator = mode == "desugar" ? evaluate_program : evaluate_graph
    value, sig, _ = evaluator(ast_or_graph, verbose=verbose)
    return flatten_sample(value), sig
end

"""
    prior_samples(ast_or_graph, mode::String, num_samples::Int;
                  tmax=nothing, wandb_name=nothing, verbose=false)

Draw up to `num_samples` samples by calling `get_sample` repeatedly.

Returns `(samples, weights)`, where `weights[k]` is the value stored under
`sig["logW"]` for the k-th draw.

# Keywords
- `tmax`: optional time budget in seconds; the loop stops after the first draw that
  finishes past the deadline.
- `wandb_name`: when not `nothing`, every draw is passed to `log_sample`.
- `verbose`: forwarded to `get_sample`.
"""
function prior_samples(ast_or_graph, mode::String, num_samples::Int; tmax=nothing, wandb_name=nothing, verbose=false)
    samples, weights = [], []
    deadline = tmax === nothing ? nothing : time() + tmax
    for i in 1:num_samples
        # `verbose` is a keyword argument of `get_sample`; passing it positionally
        # raises a MethodError.
        draw, sig = get_sample(ast_or_graph, mode; verbose=verbose)
        if wandb_name !== nothing
            # NOTE(review): `log_sample` is not imported in this module (the
            # `Main.Plotting` import is commented out), so this branch needs that
            # import restored before `wandb_name` can be used.
            log_sample(draw, i, wandb_name=wandb_name)
        end
        push!(samples, draw)
        push!(weights, sig["logW"])
        if deadline !== nothing && time() > deadline
            break
        end
    end
    return samples, weights
end

"""
    calculate_effective_sample_size(weights::Vector; verbose=false)

Return the effective sample size `1 / sum(w_i^2)` of `weights`, where `w_i` are the
weights divided by their sum.

With `verbose=true` the effective sample size, its ratio to the number of weights and
the sum of the normalised weights are printed.
"""
function calculate_effective_sample_size(weights::Vector; verbose=false)
    total = sum(weights)
    normalised = weights ./ total
    ess = 1.0 / sum(normalised .^ 2)
    verbose || return ess
    println("Effective sample size: ", ess)
    println("Fractional sample size: ", ess / length(normalised))
    println("Sum of weights: ", sum(normalised))
    return ess
end

"""
    resample_using_importance_weights(samples, log_weights;
                                      normalize=true, wandb_name=nothing)

Resample `samples` with replacement, with probabilities proportional to
`exp.(log_weights)`, and return `samples[indices, :]` for the drawn indices.

The number of draws equals `size(samples, 1)`. The effective sample size of the
weights is printed as a side effect.

# Keywords
- `normalize`: when `true`, the largest log-weight is subtracted before
  exponentiating, which keeps the largest weight at exactly 1.
- `wandb_name`: when not `nothing`, every resampled element is passed to `log_sample`.

Relies on `sample` and `Weights` from StatsBase being imported by the enclosing module.
"""
function resample_using_importance_weights(samples, log_weights; normalize=true, wandb_name=nothing)
    nsamples = size(samples, 1)
    if normalize
        log_weights = log_weights .- maximum(log_weights)
    end
    weights = exp.(log_weights)
    # Called for its verbose report only; the returned value is not needed here.
    calculate_effective_sample_size(weights, verbose=true)
    indices = sample(1:nsamples, Weights(weights), nsamples)
    new_samples = samples[indices, :]
    if wandb_name !== nothing
        # The loop variable is deliberately not called `sample`, so it cannot be
        # confused with the StatsBase function used above.
        for (i, resampled) in enumerate(new_samples)
            log_sample(resampled, i, wandb_name, resample=true)
        end
    end
    return new_samples
end

"""
    Metropolis_Hastings_samples(ast_or_graph, mode::String, num_samples::Int;
                                tmax=nothing, burn_frac=nothing,
                                wandb_name=nothing, verbose=false)

Run a Metropolis-Hastings chain whose proposals are fresh draws from `get_sample`.

A proposal is accepted with probability `min(1, exp(logW_new - logW_current))`, where
`logW` is read from `sig["logW"]`; the first proposal is always accepted. After every
step the current state is appended to the chain with weight `1.0`.

Returns `(samples, weights)` after `burn_chain` has been applied with `burn_frac`. The
acceptance fraction is printed before returning.

# Keywords
- `tmax`: optional time budget in seconds; the chain stops after the first step that
  finishes past the deadline.
- `burn_frac`: forwarded to `burn_chain`.
- `wandb_name`: when not `nothing`, every proposal is passed to `log_sample`.
- `verbose`: forwarded to `get_sample`.
"""
function Metropolis_Hastings_samples(ast_or_graph, mode::String, num_samples::Int; tmax=nothing, burn_frac=nothing, wandb_name=nothing, verbose=false)
    accepted_steps = 0
    num_steps = 0
    samples, weights = [], []
    deadline = tmax === nothing ? nothing : time() + tmax
    current_sample, current_logw = nothing, nothing
    for i in 1:num_samples
        # `verbose` is a keyword argument of `get_sample`; passing it positionally
        # raises a MethodError.
        proposal, sig = get_sample(ast_or_graph, mode; verbose=verbose)
        proposal_logw = sig["logW"]
        # Form the ratio from the log-weight difference: exponentiating each weight
        # separately underflows to 0/0 = NaN for very negative log-weights.
        # `rand()` is only consumed from the second step on, as before.
        accept = i == 1 || rand() < min(1.0, exp(proposal_logw - current_logw))
        if accept
            current_sample, current_logw = proposal, proposal_logw
            accepted_steps += 1
        end
        num_steps += 1
        if wandb_name !== nothing
            # NOTE(review): `log_sample` is not imported in this module (the
            # `Main.Plotting` import is commented out).
            log_sample(proposal, i, wandb_name)
        end
        push!(samples, current_sample)
        push!(weights, 1.0)
        if deadline !== nothing && time() > deadline
            break
        end
    end
    println("Acceptance fraction: ", accepted_steps / num_steps)
    # `burn_chain` takes `burn_frac` as an optional positional argument, not a keyword.
    samples, weights = burn_chain(samples, weights, burn_frac)
    return samples, weights
end

end # module GeneralSampling


Main.GeneralSampling

# Tests

In [20]:
# Environment setup for the tests below.
# Remove any apt-installed `clojure` package before installing from the official script.
run(`bash -lc "sudo apt-get remove -y clojure"`)
# Install the Clojure CLI with the installer script from download.clojure.org.
run(`bash -lc "curl -fsSL https://download.clojure.org/install/linux-install.sh | sudo bash"`)
# Fetch the Daphne compiler. Skip the clone when a checkout is already present:
# `git clone` fails on an existing non-empty directory, which would make `run` throw
# whenever this cell is executed a second time.
if !isdir("daphne")
    run(`bash -lc "git clone https://github.com/plai-group/daphne.git"`)
end

Reading package lists...
Building dependency tree...
Reading state information...
Package 'clojure' is not installed, so not removed
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.
Downloading and expanding tar


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 14.4M  100 14.4M    0     0  27.8M      0 --:--:-- --:--:-- --:--:-- 70.7M


Installing libs into /usr/local/lib/clojure
Installing clojure and clj into /usr/local/bin
Installing man pages into /usr/local/share/man/man1
Removing download
Use clj -h for help.


Cloning into 'daphne'...


Process(`bash -lc 'git clone https://github.com/plai-group/daphne.git'`, ProcessExited(0))

In [21]:
using JSON
using Main.EvaluationBasedSampling: AbstractSyntaxTree, evaluate_program
using Main.GeneralSampling: get_sample, calculate_effective_sample_size

In [27]:
# Model 1: `mu` is sampled from normal(1, sqrt 5) and the values 8 and 9 are observed
# under normal(mu, sqrt 2); the program returns `mu`.
model1 = """
(let [mu    (sample (normal 1 (sqrt 5)))
      sigma (sqrt 2)
      lik   (normal mu sigma)]
  (observe lik 8)
  (observe lik 9)
  mu)
"""
write("daphne/model1.daphne", model1)

# Run the Daphne `desugar` command from inside the checkout to produce the JSON AST.
cd("daphne") do
  run(`bash -lc "clojure -M:run desugar  -i model1.daphne  -o model1_ast.json"`)
end

# Evaluate the desugared program once and print the returned value and sig["logP"].
ast = AbstractSyntaxTree(JSON.parsefile("daphne/model1_ast.json"))
μ, sig, _ = evaluate_program(ast; verbose=true)
println("\n→ Single draw μ = ", μ)
println("→ log-joint = ", sig["logP"])

# Draw N samples, storing each returned value together with its sig["logW"].
N = 10000
samples = Vector{Float64}(undef, N)
logws   = Vector{Float64}(undef, N)
for i in 1:N
    s, sg = get_sample(ast, "desugar"; verbose=false)
    samples[i] = s
    logws[i]   = sg["logW"]
end

# Subtract the largest log-weight before exponentiating so the largest weight is 1,
# then compute the effective sample size of the weighted draws.
ws = exp.(logws .- maximum(logws))
ess = calculate_effective_sample_size(ws)
println("\nDrew $N samples; ESS = ", round(ess, digits=5))  # ESS is tiny compared to N: extremely inefficient, no convergence



Expression (before): Any["let", Any["mu", Any["sample", Any["normal", 1, Any["sqrt", 5]]]], Any["let", Any["sigma", Any["sqrt", 2]], Any["let", Any["lik", Any["normal", "mu", "sigma"]], Any["let", Any["dontcare0", Any["observe", "lik", 8]], Any["let", Any["dontcare1", Any["observe", "lik", 9]], "mu"]]]]]
Expression (after): Any["let", Any["mu", Any["sample", Any["normal", 1, Any["sqrt", 5]]]], Any["let", Any["sigma", Any["sqrt", 2]], Any["let", Any["lik", Any["normal", "mu", "sigma"]], Any["let", Any["dontcare0", Any["observe", "lik", 8]], Any["let", Any["dontcare1", Any["observe", "lik", 9]], "mu"]]]]]
Result: 0.6023801023788922Float64

→ Single draw μ = 0.6023801023788922
→ log-joint = -35.58169191833986

Drew 10000 samples; ESS = 88.11749


In [23]:
# Model 2: `slope` and `bias` are each sampled from normal(0, 10). `data` holds six
# (x, y) pairs flattened into one vector; each loop step observes y under
# normal(slope * x + bias, 1.0) and passes on the remaining data. The program returns
# (vector slope bias).
model2 = """
(defn observe-data [_ data slope bias]
  (let [xn (first data)
        yn (second data)
        zn (+ (* slope xn) bias)]
    (observe (normal zn 1.0) yn)
    (rest (rest data))))
(let [slope (sample (normal 0.0 10.0))
      bias  (sample (normal 0.0 10.0))
      data  (vector 1.0 2.1 2.0 3.9 3.0 5.3
                   4.0 7.7 5.0 10.2 6.0 12.9)]
  (loop 6 data observe-data slope bias)
  (vector slope bias))
"""
write("daphne/model2.daphne", model2)

# Run the Daphne `desugar` command from inside the checkout to produce the JSON AST.
cd("daphne") do
  run(`bash -lc "clojure -M:run desugar  -i model2.daphne  -o model2_ast.json"`)
end

# Evaluate the desugared program once and print the returned value and sig["logP"].
ast1 = AbstractSyntaxTree(JSON.parsefile("daphne/model2_ast.json"))
params1, sig1, _ = evaluate_program(ast1; verbose=true)
println("\n→ Sampled [slope, bias] = ", params1)
println("→ log-joint = ", sig1["logP"])



Expression (before): Any["let", Any["slope", Any["sample", Any["normal", 0.0, 10.0]]], Any["let", Any["bias", Any["sample", Any["normal", 0.0, 10.0]]], Any["let", Any["data", Any["vector", 1.0, 2.1, 2.0, 3.9, 3.0, 5.3, 4.0, 7.7, 5.0, 10.2, 6.0, 12.9]], Any["let", Any["dontcare1", Any["let", Any["a2", "slope"], Any["let", Any["a3", "bias"], Any["let", Any["acc4", Any["observe-data", 0, "data", "a2", "a3"]], Any["let", Any["acc5", Any["observe-data", 1, "acc4", "a2", "a3"]], Any["let", Any["acc6", Any["observe-data", 2, "acc5", "a2", "a3"]], Any["let", Any["acc7", Any["observe-data", 3, "acc6", "a2", "a3"]], Any["let", Any["acc8", Any["observe-data", 4, "acc7", "a2", "a3"]], Any["let", Any["acc9", Any["observe-data", 5, "acc8", "a2", "a3"]], "acc9"]]]]]]]]], Any["vector", "slope", "bias"]]]]]
Expression (after): Any["let", Any["slope", Any["sample", Any["normal", 0.0, 10.0]]], Any["let", Any["bias", Any["sample", Any["normal", 0.0, 10.0]]], Any["let", Any["data", Any["vector", 1.0, 2.1, 

In [24]:
# Model 3: a three-state chain over 16 data points. The initial state is sampled from
# discrete [0.33 0.33 0.34]; each `hmm-step` samples the next state from the
# transition distribution indexed by the last state, observes data[t] under the
# likelihood indexed by the new state, and appends the state to `states`.
model3 = """
(defn hmm-step [t states data trans-dists likes]
  (let [z (sample (get trans-dists (last states)))]
    (observe (get likes z) (get data t))
    (append states z)))
(let [data        [0.9 0.8 0.7 0.0 -0.025 -5.0 -2.0 -0.1
                   0.0 0.13 0.45 6.0 0.2 0.3 -1.0 -1.0]
      trans-dists [(discrete [0.10 0.50 0.40])
                   (discrete [0.20 0.20 0.60])
                   (discrete [0.15 0.15 0.70])]
      likes       [(normal -1.0 1.0)
                   (normal  1.0 1.0)
                   (normal  0.0 1.0)]
      states      [(sample (discrete [0.33 0.33 0.34]))]]
  (loop 16 states hmm-step data trans-dists likes))
"""

write("daphne/model3.daphne", model3)

# Run the Daphne `desugar` command from inside the checkout to produce the JSON AST.
cd("daphne") do
  run(`bash -lc "clojure -M:run desugar  -i model3.daphne  -o model3_ast.json"`)
end



# TODO: Update the following methods inside the primitives module itself
# Bind the original constructor under a private alias so the wrapper below can call it.
import Main.CustomCategorical: CategoricalDist as _origCatDist

# Wrapper that accepts (and ignores) a `validate_args` keyword and forwards `probs`
# to the original constructor.
# NOTE(review): this defines `CategoricalDist` in the current module; it does not add
# a method to `Main.CustomCategorical.CategoricalDist` — confirm callers resolve to it.
function CategoricalDist(probs::AbstractVector; validate_args::Bool=false)
    return _origCatDist(probs)
end

# Replace the "discrete" primitive: any keyword arguments are accepted and dropped,
# and the probabilities are passed to `zero_based_categorical`.
Main.BasicPrimitives.primitives["discrete"] =
    (p; kwargs...) -> Main.BasicPrimitives.zero_based_categorical(p)

# Same replacement for the "discrete-guide" primitive.
Main.BasicPrimitives.primitives["discrete-guide"] =
    (p; kwargs...) -> Main.BasicPrimitives.zero_based_categorical(p)



#11 (generic function with 1 method)

In [25]:
using JSON
# Load the desugared model 3 AST written by the previous cell and wrap it.
ast_json = JSON.parsefile("daphne/model3_ast.json")
ast      = Main.EvaluationBasedSampling.AbstractSyntaxTree(ast_json)

# Evaluate the program once; `l` (third return value) is not used in this cell.
result, sig, l = Main.EvaluationBasedSampling.evaluate_program(ast; verbose=true)

# Print the returned value and sig["logP"].
println("\nResult                  = ", result)
println("Log-joint probability   = ", sig["logP"])

Expression (before): Any["let", Any["data", Any["vector", 0.9, 0.8, 0.7, 0.0, -0.025, -5.0, -2.0, -0.1, 0.0, 0.13, 0.45, 6.0, 0.2, 0.3, -1.0, -1.0]], Any["let", Any["trans-dists", Any["vector", Any["discrete", Any["vector", 0.1, 0.5, 0.4]], Any["discrete", Any["vector", 0.2, 0.2, 0.6]], Any["discrete", Any["vector", 0.15, 0.15, 0.7]]]], Any["let", Any["likes", Any["vector", Any["normal", -1.0, 1.0], Any["normal", 1.0, 1.0], Any["normal", 0.0, 1.0]]], Any["let", Any["states", Any["vector", Any["sample", Any["discrete", Any["vector", 0.33, 0.33, 0.34]]]]], Any["let", Any["a1", "data"], Any["let", Any["a2", "trans-dists"], Any["let", Any["a3", "likes"], Any["let", Any["acc4", Any["hmm-step", 0, "states", "a1", "a2", "a3"]], Any["let", Any["acc5", Any["hmm-step", 1, "acc4", "a1", "a2", "a3"]], Any["let", Any["acc6", Any["hmm-step", 2, "acc5", "a1", "a2", "a3"]], Any["let", Any["acc7", Any["hmm-step", 3, "acc6", "a1", "a2", "a3"]], Any["let", Any["acc8", Any["hmm-step", 4, "acc7", "a1", "a2

In [26]:
import Main.BasicPrimitives
import Distributions

# TODO: Update the following methods inside the primitives module itself

# Replace the "normal" primitive: both parameters are converted with `float` before
# constructing `Distributions.Normal`; any keyword arguments are accepted and dropped.
BasicPrimitives.primitives["normal"] =
    (μ, σ; kwargs...) -> Distributions.Normal(float(μ), float(σ))

# "normal-guide" shares the exact same function object as "normal".
BasicPrimitives.primitives["normal-guide"] =
    BasicPrimitives.primitives["normal"]


#14 (generic function with 1 method)

In [None]:
# Model 4: every entry of W_0, W_1, W_2, b_0, b_1 and b_2 is sampled from normal(0, 1).
# Three tanh layers map x = 1..5 to `mu`, and each y in [1 4 9 16 25] is observed
# under normal(mu, 1). The program returns [W_0 b_0 W_1 b_1]; W_2 and b_2 are sampled
# but are not part of the returned value.
model4 = """
(let [weight-prior (normal 0 1)
      W_0 (foreach 10 [] (foreach 1 [] (sample weight-prior)))
      W_1 (foreach 10 [] (foreach 10 [] (sample weight-prior)))
      W_2 (foreach 1 [] (foreach 10 [] (sample weight-prior)))

      b_0 (foreach 10 [] (foreach 1 [] (sample weight-prior)))
      b_1 (foreach 10 [] (foreach 1 [] (sample weight-prior)))
      b_2 (foreach 1 [] (foreach 1 [] (sample weight-prior)))

      x   (mat-transpose [[1] [2] [3] [4] [5]])
      y   [[1] [4] [9] [16] [25]]
      h_0 (mat-tanh (mat-add (mat-mul W_0 x)
                             (mat-repmat b_0 1 5)))
      h_1 (mat-tanh (mat-add (mat-mul W_1 h_0)
                             (mat-repmat b_1 1 5)))
      mu  (mat-transpose
            (mat-tanh (mat-add (mat-mul W_2 h_1)
                               (mat-repmat b_2 1 5))))]
  (foreach 5 [y_r y
              mu_r mu]
    (foreach 1 [y_rc y_r
                mu_rc mu_r]
      (observe (normal mu_rc 1) y_rc)))
  [W_0 b_0 W_1 b_1])
"""
write("daphne/model4.daphne", model4)

# Run the Daphne `desugar` command from inside the checkout to produce the JSON AST.
cd("daphne") do
  run(`bash -lc "clojure -M:run desugar -i model4.daphne -o model4_ast.json"`)
end

# Evaluate the desugared program once.
ast3 = AbstractSyntaxTree(JSON.parsefile("daphne/model4_ast.json"))
params3, sig3, _ = evaluate_program(ast3; verbose=true)

# Print the returned value (the label says shapes/types, but the value itself is
# printed) and sig["logP"].
println("\n→ Sampled [W_0, b_0, W_1, b_1] shapes/types = ", params3)
println("→ log-joint = ", sig3["logP"])