In [13]:
import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import random
from src import *

In [16]:
# Build a small probabilistic program out of *named* samplers so the profiler
# can attribute calls (and time) to each sampler, then estimate the program's
# mean and variance.

# Seed the global `random` module so Restart & Run All reproduces the same
# numbers. Every lambda below draws from `random`; whether the `src` library
# also draws from it internally is not visible here.
SEED = 0
random.seed(SEED)

N_SAMPLES = 1000       # draws taken by the outer Dist around the whole program
N_COND_SAMPLES = 1000  # draws taken by the Dist that wraps coin_A

# Define several different lambda samplers so we can track their usage.
# NOTE(review): coin_A is the only sampler wrapped in a Dist. In the recorded
# profile it was called 6,000,000 times versus at most 14,000 for any other
# sampler, so this nesting dominates the run time. Confirm it is intentional
# (e.g. to showcase a profiling hotspot) rather than a leftover.
coin_A = Dist(Sampler(NamedCallable(lambda: 1 if random.random() < 0.2 else 0, "coin_A")), N_COND_SAMPLES)
coin_B = Sampler(NamedCallable(lambda: 1 if random.random() < 0.6 else 0, "coin_B"))
dice = Sampler(NamedCallable(lambda: random.randint(1, 6), "dice"))
gauss = Sampler(NamedCallable(lambda: random.gauss(0, 1), "gauss"))
bernoulli = Sampler(NamedCallable(lambda: 1 if random.random() < 0.8 else 0, "bernoulli"))

# Sanity reference, assuming If takes the first branch with probability 0.2
# (coin_A's success rate) and the second otherwise:
#   mean     = 0.2 * (3.5 + 0.6*0) + 0.8 * (0.8 + 0.6*3)   = 2.78
#   variance = 0.2 * 15.767 + 0.8 * 9.08 - 2.78**2          ~ 2.69
complicated_program = If(
    coin_A,  # condition (original intent: "if coin_A == 1"; see the note above)
    Add(
        dice,
        Mul(coin_B, gauss)
    ),
    Add(
        bernoulli,
        Mul(coin_B, Exact(3))
    )
)

profiled = Profile(Dist(complicated_program, N_SAMPLES))

estimate = profiled.estimate()
variance = profiled.variance().estimate()

print(pretty(profiled))
print("Estimate:", estimate)
print("Variance:", variance)
# Print the header on its own line: passing "...\n" and the summary as two
# print() arguments inserts the default " " separator, which indents the
# first summary row.
print("Profile Summary:")
print(profiled.summary())

Profile
└─ Dist(n=1000)
   └─ If
      ├─ Dist(n=1000)
      │  └─ Sampler(coin_A)
      ├─ Add
      │  ├─ Sampler(dice)
      │  └─ Mul
      │     ├─ Sampler(coin_B)
      │     └─ Sampler(gauss)
      └─ Add
         ├─ Sampler(bernoulli)
         └─ Mul
            ├─ Sampler(coin_B)
            └─ Exact(3)
Estimate: 2.8353977608549084
Variance: 2.6665584891699776
Profile Summary:
 ('dice', 42): (5000, 3.5701634478755295e-07)
('coin_B', 41): (14000, 1.1372509262790637e-07)
('gauss', 43): (7000, 3.266787139831909e-07)
('coin_A', 40): (6000000, 1.0303843128106867e-07)
('bernoulli', 44): (5000, 1.1799156200140715e-07)

