# Implementations

So far, we have been using Gen's dynamic modelling language, which provides the `@gen` macro for defining generative functions. In this lecture, we will examine how these work under the hood. First, let's load Gen and the `Distributions` Julia library, which provides basic sampling and scoring support for commonly used probability distributions.

```julia
using Gen
using Distributions
using PyPlot
```

In order to understand how Gen works, we will introduce a simplified variant of the language, which we will call GenLite. First, instead of using the `@gen` macro to define generative functions, we will define them directly as ordinary Julia functions. Among other things, the macro can be understood as desugaring the `~` operator. Instead of using this operator, we will assume that we have access to a `sample` function which takes three arguments: (i) the name (address) of the random choice, (ii) the distribution to sample from, and (iii) the arguments of that distribution, as a tuple. We will assume that the implementation of `sample` is passed as the first argument to our generative function (we can think of the `@gen` macro as adding this argument to the function definition). Here are implementations of `flip_biased_coin` in both Gen and GenLite.

```julia
@gen function flip_biased_coin(N)
    θ ~ beta(1, 1)
    [{:flip => i} ~ bernoulli(θ) for i in 1:N]
end;

function flip_biased_coin_lite(sample, N)
    θ = sample(:θ, Beta, (1, 1))
    # Distribution arguments are always passed as a tuple, even when there is only one
    [sample(:flip => i, Bernoulli, (θ,)) for i = 1:N]
end;
```
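Because a GenLite generative function is an ordinary Julia function, it can be run with any `sample` we choose. As a minimal sketch, assuming a hypothetical `sample_forward` that ignores the address and simply draws from the requested distribution, the coin model can be run forward from the prior:

```julia
using Distributions

# GenLite coin model, copied from above
function flip_biased_coin_lite(sample, N)
    θ = sample(:θ, Beta, (1, 1))
    [sample(:flip => i, Bernoulli, (θ,)) for i = 1:N]
end

# Hypothetical minimal `sample`: ignore the address, build the
# distribution from its arguments and draw a single value from it
function sample_forward(name, distribution, dist_args)
    return rand(distribution(dist_args...))
end

flips = flip_biased_coin_lite(sample_forward, 10)
```

Nothing is recorded here, so this is plain forward sampling; the implementations that follow differ only in what their `sample` does with the address and the sampled value.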



Here is the line model translated into our new idiom.

```julia
@gen function line_model(xs::Vector{Float64})
    n = length(xs)

    m ~ normal(0, 1)
    b ~ normal(0, 2)
    ϵ ~ normal(0, 2.5)

    # ϵ can be negative, so ϵ^2 is used as the (non-negative) noise standard deviation
    ys = [{:y => i} ~ normal(m * x + b, ϵ^2) for (i, x) in enumerate(xs)]
end;

function line_model_lite(sample, xs)
    n = length(xs)

    m = sample(:m, Normal, (0, 1))
    b = sample(:b, Normal, (0, 2))
    ϵ = sample(:ϵ, Normal, (0, 2.5))

    ys = [sample(:y => i, Normal, (m * x + b, ϵ^2)) for (i, x) in enumerate(xs)]
end;

t = Gen.simulate(line_model, ([1., 2., 3., 4., 5.],))
```

Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol,Any}(), Dict{Symbol,Any}(), Type[Array{Float64,1}], false, Union{Nothing, Some{Any}}[nothing], ##line_model#254, Bool[0], false), Trie{Any,Gen.ChoiceOrCallRecord}(Dict{Any,Gen.ChoiceOrCallRecord}(:ϵ => Gen.ChoiceOrCallRecord{Float64}(-2.1502884475932325, -2.2051284977070607, NaN, true),:m => Gen.ChoiceOrCallRecord{Float64}(-1.8081429877652804, -2.5536290653070504, NaN, true),:b => Gen.ChoiceOrCallRecord{Float64}(1.1151036782510546, -1.767517740420747, NaN, true)), Dict{Any,Trie{Any,Gen.ChoiceOrCallRecord}}(:y => Trie{Any,Gen.ChoiceOrCallRecord}(Dict{Any,Gen.ChoiceOrCallRecord}(4 => Gen.ChoiceOrCallRecord{Float64}(-10.907533725375393, -2.986761482679123, NaN, true),2 => Gen.ChoiceOrCallRecord{Float64}(-4.714461072848758, -2.5647084189035767, NaN, true),3 => Gen.ChoiceOrCallRecord{Float64}(-10.572826487446568, -3.3676665541097575, NaN, true),5 => Gen.ChoiceOrCallRecord{Float64}(-7.21712972299034, -2.46188176
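In the dynamic language, the addresses of a trace are only discovered by running the function, which is why the trace printed above contains entries for `:m`, `:b`, `:ϵ` and a nested `:y` entry. A minimal sketch of this, assuming a hypothetical helper `record_addresses` whose `sample` logs each address before drawing a value:

```julia
using Distributions

# GenLite line model, copied from above
function line_model_lite(sample, xs)
    m = sample(:m, Normal, (0, 1))
    b = sample(:b, Normal, (0, 2))
    ϵ = sample(:ϵ, Normal, (0, 2.5))
    ys = [sample(:y => i, Normal, (m * x + b, ϵ^2)) for (i, x) in enumerate(xs)]
end

# Hypothetical helper: run a GenLite function and return the addresses
# it visits, in the order in which they are visited
function record_addresses(gen_func, args)
    addresses = Any[]
    function sample_(name, distribution, dist_args)
        push!(addresses, name)
        return rand(distribution(dist_args...))
    end
    gen_func(sample_, args...)
    return addresses
end

record_addresses(line_model_lite, ([1.0, 2.0, 3.0],))
```

With three inputs the model visits `:m`, `:b` and `:ϵ`, followed by `:y => 1`, `:y => 2` and `:y => 3`; with five inputs it would visit five `:y` addresses, so the set of addresses depends on the arguments.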

To make use of these implementations of generative functions, we need to define a `sample` function and pass it in. Let's see how we can implement a version of `simulate` using this technique.

Recall that `simulate` takes a generative function and its arguments and returns a trace representing a sample from the generative function with those arguments. We will need to represent a few things in this program.

 - **The set of random choices**. For this, we will simply use a dictionary (hashtable) whose keys are addresses: either a symbol such as `:θ` or a pair such as `:flip => 1`.
 - **The score of the sample**. This will be the sum of the log densities (or log probabilities, for discrete choices) of all the random choices made during sampling, i.e. the log joint probability of the trace.
 - **The trace**. For this, we will use a named tuple with the following elements:
     1. The generative function.
     2. The arguments the function was called on.
     3. The return value of the function.
     4. The set of choices made during sampling.
     5. The score (log probability) of the sampled trace.
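As a worked example of this representation, here are hand-written choices for `flip_biased_coin` with `N = 2`, θ = 0.25 and outcomes `true`, `false`. The score is the log density of θ under Beta(1, 1) plus the log probabilities of the two flips; since the Beta(1, 1) density is 1 on (0, 1), this comes to log 0.25 + log 0.75 ≈ -1.674. The values are illustrative only:

```julia
using Distributions

# Hand-written choices, keyed by address (a symbol or a pair)
choices = Dict{Any,Any}(:θ => 0.25, (:flip => 1) => true, (:flip => 2) => false)

# Score = sum of the log densities of every choice
# (qualified, since Gen also exports a `logpdf`)
θ = choices[:θ]
score = Distributions.logpdf(Beta(1, 1), θ) +
        Distributions.logpdf(Bernoulli(θ), choices[:flip => 1]) +
        Distributions.logpdf(Bernoulli(θ), choices[:flip => 2])
```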
     
     

```julia
function simulate_lite(gen_func, args)

    # Initialise the set of choices to an empty dictionary
    choices = Dict()

    # Initialise the log density (score) at 0, i.e. a density of 1
    score = 0.0

    # An implementation of the sample function; it closes over
    # `choices` and `score` and updates them as the model runs
    function sample_(name, distribution, dist_args)

        # Create an instance of the relevant distribution from the Distributions library
        dist = distribution(dist_args...)

        # Sample the value
        value = rand(dist)

        # Score the value (qualified, since Gen also exports `logpdf`)
        log_density = Distributions.logpdf(dist, value)

        # Add this choice's log density to the running score
        score += log_density

        # Record the sampled value under its address
        choices[name] = value
        return value
    end

    # Call the generative function with the sample function defined
    retval = gen_func(sample_, args...)

    # Return the trace as a named tuple
    (gen_func=gen_func, args=args, retval=retval, choices=choices, score=score)

end;


simulate_lite(flip_biased_coin_lite, (1000,))
```

(gen_func = flip_biased_coin_lite, args = (1000,), retval = Bool[0, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  1, 1, 1, 1, 0, 1, 1, 1, 1, 1], choices = Dict{Any,Any}((:flip => 899) => true,(:flip => 930) => true,(:flip => 659) => true,(:flip => 298) => false,(:flip => 706) => true,(:flip => 10) => true,(:flip => 176) => true,(:flip => 686) => false,(:flip => 523) => true,(:flip => 467) => true…), score = -431.6268812215542)
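A useful companion to `simulate_lite` scores a *complete* set of choices without sampling anything: its `sample` looks each address up instead of drawing a value. The sketch below assumes a hypothetical `assess_lite`, loosely modelled on Gen's `assess` (which also returns a log weight and the return value); it is not part of the lecture's code:

```julia
using Distributions

# GenLite coin model, copied from above
function flip_biased_coin_lite(sample, N)
    θ = sample(:θ, Beta, (1, 1))
    [sample(:flip => i, Bernoulli, (θ,)) for i = 1:N]
end

# Hypothetical `assess_lite`: every choice must be supplied, so the
# model is replayed with the given values and each one is scored
function assess_lite(gen_func, args, choices)
    score = 0.0
    function sample_(name, distribution, dist_args)
        dist = distribution(dist_args...)
        value = choices[name]   # throws a KeyError if a choice is missing
        score += Distributions.logpdf(dist, value)
        return value
    end
    retval = gen_func(sample_, args...)
    return score, retval
end

choices = Dict{Any,Any}(:θ => 0.25, (:flip => 1) => true, (:flip => 2) => false)
score, retval = assess_lite(flip_biased_coin_lite, (2,), choices)
```

Applied to a trace `t` returned by `simulate_lite`, `assess_lite(t.gen_func, t.args, t.choices)` replays the same choices in the same order, so its score should reproduce `t.score`, which makes it a handy sanity check.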

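Perhaps a more interesting function is `generate`. In addition to the generative function and its arguments, it takes a dictionary of constraints (here called `condition`) mapping addresses to values. Constrained choices take the given values instead of being sampled, and `generate` returns an importance weight along with the trace.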

```julia
function generate_lite(gen_func, args, condition)

    choices = Dict()
    score = 0.0

    # The log importance weight: the sum of the log densities of the constrained
    # choices. When the constraints are the observations and every other choice
    # is drawn from the prior, this is the log likelihood.
    weight = 0.0

    function sample_(name, distribution, dist_args)

        dist = distribution(dist_args...)

        # Check to see if the random choice is already in the constraints
        if name in keys(condition)

            # If the random choice is constrained, use the given value and add its
            # log density to both the score and the weight
            # (qualified, since Gen also exports `logpdf`)
            value = condition[name]
            log_density = Distributions.logpdf(dist, value)
            score += log_density
            weight += log_density
        else

            # Otherwise, sample it as in simulate and do NOT update the weight
            value = rand(dist)

            # Critically, the score is updated in both branches
            score += Distributions.logpdf(dist, value)
        end

        # Record the choice and return it so the model can use it
        choices[name] = value
        return value
    end

    retval = gen_func(sample_, args...)
    trace = (gen_func=gen_func, args=args, retval=retval, choices=choices, score=score)
    return weight, trace
end;
```
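With `generate`, we can already do inference by importance sampling, using the prior as the proposal: constrain the observed flips, generate many traces, and weight each sampled θ by `exp(weight)`. The sketch below assumes a hypothetical helper `posterior_mean_θ` and includes a condensed copy of `generate_lite` with the same behaviour. Because the Beta(1, 1) prior is conjugate, the exact answer is available for comparison: three heads and one tail give a Beta(4, 2) posterior with mean 4/6 ≈ 0.667.

```julia
using Distributions
using Random

# GenLite coin model, copied from above
function flip_biased_coin_lite(sample, N)
    θ = sample(:θ, Beta, (1, 1))
    [sample(:flip => i, Bernoulli, (θ,)) for i = 1:N]
end

# Condensed copy of generate_lite (same behaviour)
function generate_lite(gen_func, args, condition)
    choices = Dict()
    score = 0.0
    weight = 0.0
    function sample_(name, distribution, dist_args)
        dist = distribution(dist_args...)
        if haskey(condition, name)
            value = condition[name]
            weight += Distributions.logpdf(dist, value)
        else
            value = rand(dist)
        end
        score += Distributions.logpdf(dist, value)
        choices[name] = value
        return value
    end
    retval = gen_func(sample_, args...)
    return weight, (gen_func=gen_func, args=args, retval=retval, choices=choices, score=score)
end

# Hypothetical helper: self-normalised importance sampling estimate of
# E[θ | observed flips], with the prior as the proposal
function posterior_mean_θ(observed, num_samples)
    condition = Dict{Any,Any}((:flip => i) => y for (i, y) in enumerate(observed))
    log_ws = Float64[]
    θs = Float64[]
    for _ in 1:num_samples
        w, trace = generate_lite(flip_biased_coin_lite, (length(observed),), condition)
        push!(log_ws, w)
        push!(θs, trace.choices[:θ])
    end
    # Subtract the maximum log weight before exponentiating, for numerical stability
    ws = exp.(log_ws .- maximum(log_ws))
    return sum(ws .* θs) / sum(ws)
end

Random.seed!(1)
observed = [true, true, true, false]   # 3 heads, 1 tail
posterior_mean_θ(observed, 10_000)
```

Note that Gen's own `generate` returns `(trace, weight)`, the reverse of the order used by `generate_lite`.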