A demo explaining the `Sampler` implementation of the `IntractableReal` class.

In [None]:
import os
os.environ['JAX_PLATFORMS'] = 'cpu'  # Use CPU backend for JAX

import sys
sys.path.insert(0, os.path.abspath('..'))

import random
from autovar import *

Let's define a coin that returns 1 with probability `p` and 0 otherwise. With `p = 0.5` the coin is fair: it outputs 0 half the time and 1 the other half.

In [10]:
def coin(p: float) -> int:
    # Return 1 with probability p, otherwise 0.
    return 1 if random.random() < p else 0

print("Sample coin flips:", [coin(0.5) for _ in range(5)])

Sample coin flips: [1, 0, 0, 1, 1]


Now, we introduce the notion of a `Sampler`, which is an implementation of the `IntractableReal` class. Samplers are relevant when we can run a program, but have limited knowledge of its distribution over outputs.

Our `Sampler(f).estimate()` implements the simplest possible unbiased estimator of the program's mean: it runs `f()` once and returns that single output as the estimate.

In [11]:
sampler = Sampler(lambda: coin(0.5))
print("Coin flips from sampler:", [sampler.estimate() for _ in range(5)])

Coin flips from sampler: [0, 0, 0, 1, 1]
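To make the "run it once" idea concrete, here is a minimal sketch of a single-sample estimator. The class `OneShotSampler` is a hypothetical stand-in written for this illustration; it is not the `autovar` implementation and only assumes the `estimate()` behavior described above. Averaging many single-sample estimates should land near the true mean of 0.5.

```python
import random


class OneShotSampler:
    """Hypothetical stand-in for illustration; not autovar's Sampler."""

    def __init__(self, f):
        self.f = f

    def estimate(self):
        # A single call to f is an unbiased, but noisy, estimate of E[f()].
        return self.f()


def coin(p: float) -> int:
    # Return 1 with probability p, otherwise 0.
    return 1 if random.random() < p else 0


random.seed(0)  # fixed seed so the run is repeatable
one_shot = OneShotSampler(lambda: coin(0.5))
draws = [one_shot.estimate() for _ in range(10_000)]
average = sum(draws) / len(draws)
print("average of 10,000 single-sample estimates:", average)
```

Each individual estimate is still just 0 or 1; only the average over many calls approaches 0.5.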


To estimate $\mathrm{Var}(f)$, our `Sampler` defines its own internal `Sampler`, which takes 2 samples from our lambda, $f$, each time it is evaluated.

This looks like: `Sampler(lambda: 0.5 * (self.f() - self.f())**2)`
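A short derivation shows why this two-sample expression is an unbiased estimate of the variance. The symbols $X_1$, $X_2$ and $\mu$ are introduced here for the derivation: $X_1$ and $X_2$ are the two calls to $f$, assumed independent, and $\mu$ is their common mean.

$$
\mathbb{E}\left[\tfrac{1}{2}(X_1 - X_2)^2\right]
= \tfrac{1}{2}\left(\mathbb{E}[X_1^2] - 2\,\mathbb{E}[X_1]\,\mathbb{E}[X_2] + \mathbb{E}[X_2^2]\right)
= \mathbb{E}[X_1^2] - \mu^2
= \mathrm{Var}(f)
$$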

Note that there are only two possible output values: $f() \in \{0, 1\}$.

So $f() - f() \in \{-1, 0, 1\}$, and $(f() - f())^2 \in \{0, 1\}$.

For a fair coin, the two draws agree (difference $0$) with probability $0.5$ and disagree (difference $\pm 1$) with probability $0.5$, so we can expect our variance estimate to be 0 half the time, and 0.5 the other half of the time. Let's sample and see!

In [12]:
results = [sampler.variance().estimate() for _ in range(1000)]
count_0 = results.count(0)
count_05 = results.count(0.5)
print(f"0 occurs {count_0} times, 0.5 occurs {count_05} times out of 1000 samples")
print(f"Fraction of 0: {count_0/1000:.2f}, Fraction of 0.5: {count_05/1000:.2f}")

0 occurs 506 times, 0.5 occurs 494 times out of 1000 samples
Fraction of 0: 0.51, Fraction of 0.5: 0.49
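The roughly 50/50 split can also be checked exactly by enumerating the four possible pairs of flips. This is a sketch that does not use `autovar`; the helper name `pair_estimator_distribution` is made up for this example, and the two flips are assumed to be independent.

```python
from fractions import Fraction
from itertools import product


def pair_estimator_distribution(p):
    """Exact distribution of 0.5 * (a - b)**2 for two independent Bernoulli(p) flips."""
    prob = {0: 1 - p, 1: p}
    dist = {}
    for a, b in product((0, 1), repeat=2):
        value = Fraction(1, 2) * (a - b) ** 2
        dist[value] = dist.get(value, 0) + prob[a] * prob[b]
    return dist


p = Fraction(1, 2)
dist = pair_estimator_distribution(p)
expected = sum(value * weight for value, weight in dist.items())
print("distribution:", dist)
print("expected value:", expected, "and p * (1 - p):", p * (1 - p))
```

For the fair coin, the estimator takes the values 0 and 1/2 with probability 1/2 each, so its expected value is 1/4, which equals the coin's variance $p(1-p)$.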


We can use the `Dist(d, n)` implementation of the `IntractableReal` class to estimate the mean and variance of an underlying distribution, `d`, using `n` samples.

Say we don't know our coin is fair. We can estimate the true mean and variance of the distribution over coin flips by taking many samples using our `Sampler()`.

In [13]:
print(f"mean: {Dist(sampler, 1000).estimate()}, variance: {Dist(sampler, 1000).variance().estimate()}")

mean: 0.523, variance: 0.000254


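It seems this works! The mean is close to 0.5, as expected for a fair coin. Note that the reported variance is far smaller than the per-flip variance of 0.25; it is consistent with the variance of the 1000-sample mean, $0.25 / 1000 = 0.00025$. We can do the same to discover unfair coins.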

In [14]:
new_rate = 0.8
sampler = Sampler(lambda: coin(new_rate))
print(f"mean: {Dist(sampler, 100).estimate()}, variance: {Dist(sampler, 100).variance().estimate()}")

mean: 0.82, variance: 0.0021
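As a rough sanity check on the two runs above, the sketch below computes the variance of an `n`-flip average from the textbook Bernoulli formula $p(1-p)/n$. It assumes that `Dist(d, n).variance()` reports the variance of the `n`-sample mean; that reading is inferred from the printed values and is not confirmed by the library. The helper `variance_of_mean` is a hypothetical name used only here.

```python
def variance_of_mean(p: float, n: int) -> float:
    """Variance of the average of n independent Bernoulli(p) flips: p * (1 - p) / n."""
    return p * (1 - p) / n


fair = variance_of_mean(0.5, 1000)    # compare with the printed 0.000254
unfair = variance_of_mean(0.8, 100)   # compare with the printed 0.0021
print(f"fair coin, n=1000: {fair:.6f}")
print(f"p=0.8 coin, n=100: {unfair:.6f}")
```

The fair-coin value, 0.00025, is close to the printed estimate. The unfair-coin value, 0.0016, is of the same order as the printed 0.0021, which is itself a noisy estimate built from only 100 samples.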


`Sampler` is well-defined over arbitrary lambda functions, so long as they produce a single real number. Let's consider the behavior of a more complex probabilistic program.

In [15]:
def weird_tricoin(p: float) -> float:
    if coin(p) == 0:
        return 0
    else:
        if coin(1-p) == 0:
            return 1
        else:
            return 100

sampler = Sampler(lambda: weird_tricoin(0.5))
print(f"mean: {Dist(sampler, 10000).estimate()}, variance: {Dist(sampler, 10000).variance().estimate()}")


mean: 25.688, variance: 0.18597102000000001
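Because `weird_tricoin` has only three outcomes, its exact mean and variance can be worked out by hand or by enumeration, which gives a reference point for the sampled numbers. The sketch below reads the branch probabilities directly off the function body; `tricoin_distribution` is a hypothetical helper for this example and does not use `autovar`.

```python
from fractions import Fraction


def tricoin_distribution(p):
    """Exact outcome probabilities of weird_tricoin(p)."""
    # coin(p) == 0 happens with probability 1 - p, and the program returns 0.
    # Otherwise (probability p), coin(1 - p) == 0 happens with probability p,
    # returning 1, and coin(1 - p) == 1 happens with probability 1 - p,
    # returning 100.
    return {0: 1 - p, 1: p * p, 100: p * (1 - p)}


p = Fraction(1, 2)
dist = tricoin_distribution(p)
mean = sum(x * w for x, w in dist.items())
variance = sum((x - mean) ** 2 * w for x, w in dist.items())
print("exact mean:", float(mean))
print("exact per-sample variance:", float(variance))
print("variance of a 10,000-sample mean:", float(variance / 10_000))
```

For `p = 0.5` the exact mean is 25.25 and the exact per-sample variance is 1862.6875. Dividing by 10,000 gives about 0.186, in line with the variance printed above if that value is read as the variance of the 10,000-sample mean.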


Say you've done the math, and you know that the expected value of this program is 25.25 ($0.5 \cdot 0 + 0.25 \cdot 1 + 0.25 \cdot 100$). You can pass this to `Sampler` as `known_mean`, and the variance is then estimated by sampling against that mean using a refined variance sampler:

`return Sampler(lambda: (self.f() - self.known_mean)**2)`.
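The known-mean expression and the earlier two-sample expression target the same quantity. The sketch below checks this exactly for `weird_tricoin(0.5)` by enumeration, without using `autovar`. The outcome probabilities are hard-coded from the function body, and the variable names are chosen for this example only.

```python
from fractions import Fraction
from itertools import product

# Outcome probabilities of weird_tricoin(0.5), read off its branches.
dist = {0: Fraction(1, 2), 1: Fraction(1, 4), 100: Fraction(1, 4)}
mu = sum(x * w for x, w in dist.items())  # the known mean, 25.25

# Expected value of (f() - known_mean)**2, which needs one call to f.
known_mean_expectation = sum((x - mu) ** 2 * w for x, w in dist.items())

# Expected value of 0.5 * (f() - f())**2, which needs two independent calls to f.
pair_expectation = sum(
    Fraction(1, 2) * (a - b) ** 2 * wa * wb
    for (a, wa), (b, wb) in product(dist.items(), repeat=2)
)

print("known-mean estimator expectation:", float(known_mean_expectation))
print("two-sample estimator expectation:", float(pair_expectation))
```

Both expectations equal the true variance, so either expression is unbiased; the known-mean version gets there with a single call to `f` per estimate.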

In [16]:
sampler_with_known_mean = Sampler(lambda: weird_tricoin(0.5), known_mean=25.25)
print(f"mean: {Dist(sampler_with_known_mean, 10000).estimate()}, variance: {Dist(sampler_with_known_mean, 10000).variance().estimate()}")

profile = Profile(Sampler(lambda: weird_tricoin(0.5)))
print("Mean: ", profile.estimate())
print("Variance: ", profile.variance().estimate())
print("Summary: ", profile.summary())

mean: 25.25, variance: 0.18712708
Mean:  0
Variance:  0.5
Summary:  ('<function <lambda> at 0x108e11080>', 2018): (1, 5.829497240483761e-07)
('<function <lambda> at 0x108e11080>.variance', 2019): (1, 3.4580007195472717e-06)

