In [None]:
# Environment setup: force JAX onto the CPU backend, make the parent
# directory importable, then pull in the autovar DSL used by the next cell.
import os
# Must be set before JAX is first imported for it to take effect
# (presumably `autovar` imports JAX -- confirm).
os.environ['JAX_PLATFORMS'] = 'cpu'  # Use CPU backend for JAX

import sys
# Prepend the parent of the kernel's working directory to the import path,
# presumably so a local, uninstalled `autovar` package is found there.
sys.path.insert(0, os.path.abspath('..'))

import random
# NOTE(review): the star import hides which names come from autovar (the next
# cell uses Dist, Sampler, NamedCallable, If, Add, Mul, Exact, Profile and
# pretty). Left as-is because cells outside this view may rely on other names.
from autovar import *

In [2]:
# Build a small probabilistic program from named samplers, then profile how
# often each sampler is called while estimating the program's mean and variance.

# Seed Python's global RNG: every sampler below draws from `random`, so this
# makes the printed estimates and call counts reproducible on Restart & Run All.
SEED = 0
random.seed(SEED)

N_SAMPLES = 1000       # Monte Carlo samples for the top-level mean/variance estimates
COIN_A_SAMPLES = 1000  # samples drawn by the Dist that wraps coin_A

# Each lambda is wrapped in NamedCallable so the profile summary can report a
# readable name (e.g. "coin_A") instead of a lambda repr.
# NOTE(review): unlike the other samplers, coin_A is wrapped in its own Dist.
# Judging by the saved profile output (5,000,000 coin_A calls vs. 2227 + 2773
# branch evaluations), every evaluation of the If condition appears to draw
# COIN_A_SAMPLES coin_A values. Confirm this nesting is intentional; a plain
# Sampler would avoid the extra draws.
coin_A = Dist(Sampler(NamedCallable(lambda: 1 if random.random() < 0.2 else 0, "coin_A")), COIN_A_SAMPLES)
coin_B = Sampler(NamedCallable(lambda: 1 if random.random() < 0.6 else 0, "coin_B"))
dice = Sampler(NamedCallable(lambda: random.randint(1, 6), "dice"))
gauss = Sampler(NamedCallable(lambda: random.gauss(0, 1), "gauss"))
bernoulli = Sampler(NamedCallable(lambda: 1 if random.random() < 0.8 else 0, "bernoulli"))

# then-branch: dice + coin_B * gauss;  else-branch: bernoulli + coin_B * 3.
# For reference, assuming the then-branch is taken with probability 0.2, the
# analytic mean is 0.2 * 3.5 + 0.8 * (0.8 + 0.6 * 3) = 2.78.
complicated_program = If(
    coin_A,  # branch condition (a Dist over coin_A, not a raw 0/1 draw -- see NOTE above)
    Add(
        dice,
        Mul(coin_B, gauss)
    ),
    Add(
        bernoulli,
        Mul(coin_B, Exact(3))
    )
)

# The Profile wrapper accumulates per-sampler statistics across every run
# below, so the summary covers both the mean and the variance passes. Entries
# look like (call count, presumably time per call in seconds -- confirm).
profiled = Profile(complicated_program)

estimate = Dist(profiled, N_SAMPLES).estimate()
variance = Dist(profiled.variance(), N_SAMPLES).estimate()

print(pretty(profiled))
print("Estimate:", estimate)
print("Variance:", variance)
print("Profile Summary:\n", profiled.summary())

Profile
└─ If
   ├─ Dist(n=1000)
   │  └─ Sampler(coin_A)
   ├─ Add
   │  ├─ Sampler(dice)
   │  └─ Mul
   │     ├─ Sampler(coin_B)
   │     └─ Sampler(gauss)
   └─ Add
      ├─ Sampler(bernoulli)
      └─ Mul
         ├─ Sampler(coin_B)
         └─ Exact(3)
Estimate: 2.7699775725791884
Variance: 2.783912712036759
Profile Summary:
 ('coin_A', 0): (5000000, 2.063553212210536e-07)
('bernoulli', 4): (2773, 2.2241732880091608e-07)
('coin_B', 1): (9000, 2.1873326558205816e-07)
('dice', 2): (2227, 7.02213457495737e-07)
('gauss', 3): (4227, 6.671197781163466e-07)
('<function Sampler.variance.<locals>.<lambda> at 0x10e1411c0>', 5): (1000, 1.2792404741048814e-06)
('<function Sampler.variance.<locals>.<lambda> at 0x10e141b20>', 6): (2000, 5.199983716011048e-07)
('<function Sampler.variance.<locals>.<lambda> at 0x10e19f560>', 7): (2000, 1.3226941227912903e-06)
('<function Sampler.variance.<locals>.<lambda> at 0x10e19f600>', 8): (1000, 5.490467883646488e-07)
('<function Sampler.variance.<locals>.<