# Code associated with lecture 6

In [1]:
using Gen
using Plots

In [2]:
# Generative model of an urn holding an unknown proportion θ of red marbles,
# written with the explicit @trace(expression, address) notation.
@gen function unknown_urn()
    # prior: θ ~ Uniform(0, 1), traced at :theta
    theta = @trace(uniform(0, 1), :theta)
    # likelihood: ten draws, each red (true) with probability θ
    for draw_idx in 1:10
        @trace(bernoulli(theta), :data => draw_idx => :y)
    end
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##unknown_urn#292", Bool[], false)

In [3]:
# The same urn model, written with the `~` shorthand instead of @trace.
@gen function unknown_urn()
    # prior: θ ~ Uniform(0, 1); `theta ~ ...` stores the choice at address :theta
    theta ~ uniform(0, 1)
    # likelihood: y_i ~ Bernoulli(θ) for i = 1, …, 10, stored at :data => i => :y
    for draw_idx in 1:10
        {:data => draw_idx => :y} ~ bernoulli(theta)
    end
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##unknown_urn#293", Bool[], false)

In [4]:
# Run the model forward once; the trace records every random choice made.
trace = Gen.simulate(unknown_urn, ())
# Show the choice map: θ plus the ten draws (one sample from the joint model).
get_choices(trace)

│
├── :theta : 0.6299190921567404
│
└── :data
    │
    ├── 5
    │   │
    │   └── :y : true
    │
    ├── 4
    │   │
    │   └── :y : false
    │
    ├── 6
    │   │
    │   └── :y : true
    │
    ├── 7
    │   │
    │   └── :y : true
    │
    ├── 2
    │   │
    │   └── :y : true
    │
    ├── 10
    │   │
    │   └── :y : true
    │
    ├── 9
    │   │
    │   └── :y : true
    │
    ├── 8
    │   │
    │   └── :y : false
    │
    ├── 3
    │   │
    │   └── :y : true
    │
    └── 1
        │
        └── :y : true


In [1]:
"""
    make_constraints(ys)

Build a `Gen.choicemap` that constrains the urn model's observations: the k-th
element of `ys` is stored at address `:data => k => :y`, for k = 1, …, length(ys).
Accepts any `AbstractVector{Bool}` (e.g. `Vector{Bool}` or `BitVector`).
"""
function make_constraints(ys::AbstractVector{Bool})
    # choice map that will hold the observations
    constraints = Gen.choicemap()
    # enumerate numbers the draws 1, 2, … regardless of how `ys` itself is
    # indexed, matching the 1-based :data addresses used by `unknown_urn`
    for (k, y) in enumerate(ys)
        constraints[:data => k => :y] = y
    end
    return constraints
end

# **** Start reading this codeblock from here: ****
# Observed draws, stored as a Boolean vector: true = red marble; false = blue marble.
# Typically this data would be read from a file; here it is hand-coded.
draws = Bool[true, true, true, true, true, true, false, true, false, true]
# turn the draws into constraints on the model's :data => i => :y addresses
observations = make_constraints(draws)

LoadError: UndefVarError: `Gen` not defined

In [6]:
# Approximate the posterior p(θ | observations) with Gen's importance resampling.
# Each call weighs 1000 samples and returns (trace, log marginal likelihood
# estimate); `first` keeps the trace. Repeating 100 times yields 100 approximate
# posterior samples of θ.
traces = [first(Gen.importance_resampling(unknown_urn, (), observations, 1000))
          for _ in 1:100]

# collect θ from each returned trace and plot the histogram
thetas = [t[:theta] for t in traces]
# use the Unicode θ: in "\theta" the `\t` is a TAB escape, giving TAB * "heta"
p = histogram(thetas, label="θ")
savefig(p, "urn-posterior.svg")

In [7]:
# Single-draw urn: one proportion θ and a single marble y.
@gen function unknown_urn_single()
    # prior: θ ~ Uniform(0, 1)
    theta ~ uniform(0, 1)
    # likelihood: y ~ Bernoulli(θ)
    y ~ bernoulli(theta)
    # the drawn marble is the return value
    return y
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##unknown_urn_single#295", Bool[], false)

In [8]:
# Draw one joint sample (θ, y). The argument tuple is empty because
# unknown_urn_single takes no arguments.
trace = simulate(unknown_urn_single, ())
# inspect the random choices recorded in the trace
get_choices(trace)

│
├── :y : true
│
└── :theta : 0.8015938986727597


In [9]:
# Draw 1000 joint samples (θ, y) from unknown_urn_single and plot them.
traces = [simulate(unknown_urn_single, ()) for _ in 1:1000]
# read the choices straight from each trace instead of indexing by position
thetas = [tr[:theta] for tr in traces]
ys = [tr[:y] for tr in traces]
# histogram of the sampled θ values (built but not saved to disk)
p = histogram(thetas, label="θ")
# joint samples: y (false/true) plotted against θ
p1 = scatter(ys, thetas, label=:none)
savefig(p1, "single-urn.svg")

In [10]:
# Toy generative function: flips a fair coin `z` and returns a + b, plus one
# when `z` is true. `b` is optional and defaults to 0.
@gen function foo(a, b=0)
    z ~ bernoulli(0.5)
    base = a + b
    return z ? base + 1 : base
end
# run foo(3, 5) forward once
trace = simulate(foo, (3, 5))
# the choice map holds the single coin flip :z
choices = get_choices(trace)

│
└── :z : true


In [11]:
# the value foo returned on this run: a + b, plus 1 if the coin z came up true
return_value = get_retval(trace)

9

In [12]:
# Generate a trace with the observation y constrained to true; θ is sampled from
# its prior. `weight` is the log probability of the constrained choice,
# log p(y = true | θ). It coincides with the trace's log score here only because
# the Uniform(0, 1) prior has log density 0.
(trace, weight) = Gen.generate(unknown_urn_single, (), choicemap((:y, true)))

(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##unknown_urn_single#295", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(true, -3.6998986620853485, NaN, true), :theta => Gen.ChoiceOrCallRecord{Float64}(0.024726032027906575, -0.0, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -3.6998986620853485, 0.0, (), true), -3.6998986620853485)

## Illustrating Gen.update

In [14]:
# Model with data-dependent structure: the choice `b` decides whether :c or :d
# is sampled. Returns e && (c or d, depending on b) && a.
@gen function bar()
    a ~ bernoulli(0.3)
    b ~ bernoulli(0.4)
    # value of the branch taken: (c && a) when b is true, otherwise (d && a)
    branch_val = if b
        c ~ bernoulli(0.6)
        c && a
    else
        d ~ bernoulli(0.1)
        d && a
    end
    e ~ bernoulli(0.7)
    return e && branch_val
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##bar#300", Bool[], false)

In [15]:
# sample bar once; the choice map contains :c or :d depending on the value of :b
trace = simulate(bar, ())
get_choices(trace)

│
├── :a : true
│
├── :b : true
│
├── :e : false
│
└── :c : true


In [16]:
# Constraints for the update (not a full trace): force b = false, and supply
# d = true for the branch that b = false selects.
constraints = choicemap((:b, false), (:d, true))
# Apply them to `trace`. The arguments are unchanged, so the args tuple and the
# argdiffs are both empty. Returns the new trace, the log weight, the retdiff
# (ignored here) and the discarded choices.
(new_trace, w, _, discard) = Gen.update(trace, (), (), constraints)

(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##bar#300", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(true, -1.2039728043259361, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :d => Gen.ChoiceOrCallRecord{Bool}(true, -2.3025850929940455, NaN, true), :e => Gen.ChoiceOrCallRecord{Bool}(false, -1.203972804325936, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -5.221356325411908, 0.0, (), false), -1.3862943611198906, UnknownChange(), DynamicChoiceMap(Dict{Any, Any}(:b => true, :c => true), Dict{Any, Any}()))

In [17]:
# inspect the updated trace: :b and :d take the constrained values, :a and :e are carried over
get_choices(new_trace)

│
├── :a : true
│
├── :b : false
│
├── :d : true
│
└── :e : false


In [156]:
# Choices of the original `trace` that the update overwrote or removed
# (e.g. the old :b value, and :c once the b = false branch no longer uses it).
# NOTE(review): the output shown below comes from a different run (cell 156)
# and does not match the discard returned by the update cell above.
discard

│
├── :b : false
│
└── :d : true


In [18]:
# Log weight returned by `update`, NOT the log score of the adjusted trace.
# In this update every newly added choice (:d) is constrained, so it equals
# get_score(new_trace) - get_score(trace).
w

-1.3862943611198906

In [19]:
# Regenerate (resample) the random choices at addresses :a and :b, starting from
# the original `trace` (not `new_trace`). If :b changes, the choice for the other
# branch (:c or :d) is sampled anew and the old one is dropped.
(new_trace, w, _) = Gen.regenerate(trace, (), (), select(:a, :b))

(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##bar#300", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.35667494393873245, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :d => Gen.ChoiceOrCallRecord{Bool}(false, -0.10536051565782628, NaN, true), :e => Gen.ChoiceOrCallRecord{Bool}(false, -1.203972804325936, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -2.1768338876884856, 0.0, (), false), 0.0, UnknownChange())

In [20]:
# choices of the trace produced by regenerate
Gen.get_choices(new_trace)

│
├── :a : false
│
├── :b : false
│
├── :d : false
│
└── :e : false


In [21]:
# Log weight returned by `regenerate`. Gen's `mh` uses this value as the log
# Metropolis-Hastings acceptance ratio for resampling the selected choices
# (:a and :b) from their prior.
w

0.0

## Static annotation

In [169]:
# Dynamic-DSL model: a fair coin chooses the mean (x or 1/x) of a unit-variance normal.
@gen function foo_dynamic(x::Float64)
    # :branch is traced first; the coin picks the mean
    mu = @trace(bernoulli(0.5), :branch) ? x : 1 / x
    # draw z ~ Normal(mu, 1) at :z; it is also the return value
    z = @trace(normal(mu, 1), :z)
    return z
end
trace = simulate(foo_dynamic, (3.0,))
get_choices(trace)

│
├── :branch : true
│
└── :z : 2.6382706573019994


In [22]:
# Static-DSL version of foo_dynamic. The static modeling language does not allow
# if/else statements, so the branch is written as a ternary expression on the
# right-hand side of an assignment.
@gen (static) function foo_static(x::Float64)
    # mean is x or 1/x depending on a fair coin flip traced at :branch
    mu = @trace(bernoulli(0.5), :branch) ? (x) : (1/x)
    # draw z ~ Normal(mu, 1), using the mean selected above (as foo_dynamic does)
    @trace(normal(mu, 1), :z)
end
# static generative functions must be loaded/compiled before they are called
@load_generated_functions
trace = Gen.simulate(foo_static, (3.0,))
get_choices(trace)

│
├── :branch : true
│
└── :z : 4.067046879558006


## Combinator example using Map

In [25]:
# The urn model again, now generating 100 draws.
@gen function unknown_urn()
    # prior: θ ~ Uniform(0, 1)
    theta ~ uniform(0, 1)
    # likelihood: y_i ~ Bernoulli(θ) for i = 1, …, 100, stored at :data => i => :y
    for draw_idx in 1:100
        {:data => draw_idx => :y} ~ bernoulli(theta)
    end
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##unknown_urn#327", Bool[], false)

In [26]:
# Kernel generative function G_k: a single coin flip with bias `theta`.
# Written in the static modeling language, so it has to be loaded with
# @load_generated_functions before it is called.
@gen (static) function observe_flip(theta::Float64)
    # likelihood: y ~ Bernoulli(θ); the flip is traced at :y and returned
    y ~ bernoulli(theta)
    return y
end

# Static version of the urn model; `k` is the number of draws.
@gen (static) function unknown_urn_static(k::Int64)
    # prior over the proportion θ: θ ~ Uniform(0, 1)
    theta ~ uniform(0,1)
    # Map(observe_flip) calls the kernel once per element of fill(theta, k): k flips
    # sharing the same θ, traced at :data => i => :y for i = 1..k (the same
    # addresses unknown_urn uses).
    # NOTE(review): there is no explicit `return`; confirm whether the flips are
    # meant to be returned.
    data ~ Gen.Map(observe_flip)(fill(theta, k))
end

var"##StaticGenFunction_unknown_urn_static#353"(Dict{Symbol, Any}(), Dict{Symbol, Any}())

In [27]:
# load the static generative functions defined above before calling them
@load_generated_functions

# generate an unconstrained trace with k = 3 flips and inspect its choices
(trace, _) = Gen.generate(unknown_urn_static, (3,))
Gen.get_choices(trace)

│
├── :theta : 0.917214885808224
│
└── :data
    │
    ├── 1
    │   │
    │   └── :y : false
    │
    ├── 2
    │   │
    │   └── :y : true
    │
    └── 3
        │
        └── :y : true


In [28]:
# Kernel G_k for the Map example: one normal draw centred at x1 + x2.
@gen function foo(x1::Float64, x2::Float64)
    # z ~ Normal(x1 + x2, 1), traced at :z and returned
    return @trace(normal(x1 + x2, 1.0), :z)
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#365", Bool[0, 0], false)

In [29]:
# Map(foo) calls foo once per position of its vector arguments
bar = Gen.Map(foo)

Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#365", Bool[0, 0], false))

In [30]:
# two kernel calls: (x1, x2) = (0.0, 0.5) and (0.5, 1.0); choices appear under keys 1 and 2
(trace, _) = Gen.generate(bar, ([0.0, 0.5], [0.5, 1.0]))
Gen.get_choices(trace)

│
├── 1
│   │
│   └── :z : 0.23628144708105397
│
└── 2
    │
    └── :z : 1.9173202290293534


In [34]:
# Posterior over θ for the 100-draw unknown_urn, conditioned on the same 10
# observations (draws 11-100 stay unconstrained). Each importance_resampling call
# weighs 1000 samples and returns (trace, log marginal likelihood estimate);
# `first` keeps the trace. The call is repeated 10 times.
# NOTE(review): the output file name says "static" but this runs the dynamic
# unknown_urn; confirm whether unknown_urn_static was intended.
traces = [first(Gen.importance_resampling(unknown_urn, (), observations, 1000))
          for _ in 1:10]
# collect θ from each of the 10 returned traces and plot
thetas = [t[:theta] for t in traces]
# use the Unicode θ: in "\theta" the `\t` is a TAB escape, giving TAB * "heta"
p2 = histogram(thetas, label="θ")
savefig(p2, "urn-posterior-static.svg")