In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

In [None]:
def desc(t):
    """Print the mean and (population, ddof=0) standard deviation of `t`.

    Args:
      t: an eager tensor (anything exposing ``.numpy()``) or any array-like
        accepted by ``np.asarray``.
    """
    # Tensors expose .numpy(); plain lists/arrays go through np.asarray.
    t_ = t.numpy() if hasattr(t, 'numpy') else np.asarray(t)
    print(f'mean={t_.mean():.5f}, std={t_.std():.5f}')
    
def hist(xs, xlim=None, bins=64):
    """Plot a density-normalised histogram of `xs` in a new 14x4 figure.

    Args:
      xs: values to histogram (anything ``plt.hist`` accepts).
      xlim: optional ``(lo, hi)`` x-axis limits; left to matplotlib when None.
      bins: number of histogram bins (default 64, as before).
    """
    plt.figure(figsize=(14, 4))
    plt.hist(xs, bins=bins, density=True, color='c')

    if xlim is not None:
        plt.xlim(xlim)

In [None]:
class Rejector(tfd.Distribution):
    """Rejection sampler layered on top of another distribution.

    Draws come from ``underlying``; every draw for which ``condition`` is
    False is redrawn (elementwise) until all draws are accepted.

    Args:
      underlying: proposal distribution; its dtype, batch/event shapes,
        ``validate_args`` and ``allow_nan_stats`` are inherited.
      condition: predicate called on a batch of samples that returns a
        boolean mask usable as the ``tf.where`` condition against those
        samples (presumably the same shape as the samples -- TODO confirm
        for non-scalar events).
      name: optional name; defaults to ``'rejection_' + underlying.name``.

    Only sampling is implemented here; no density methods are overridden.
    Sampling loops until every draw is accepted, so a condition with a tiny
    (or zero) acceptance probability can run for a very long time or forever.
    """

    def __init__(self, underlying, condition, name=None):
        self._u = underlying
        self._c = condition
        super().__init__(dtype=underlying.dtype,
                         # Use the underlying's short name rather than its full
                         # repr, which contains spaces, quotes and brackets.
                         name=name or f'rejection_{underlying.name}',
                         reparameterization_type=tfd.NOT_REPARAMETERIZED,
                         validate_args=underlying.validate_args,
                         allow_nan_stats=underlying.allow_nan_stats)

    # Shapes are those of the underlying distribution: rejection changes
    # which values are returned, not their shape.
    def _batch_shape(self):
        return self._u.batch_shape

    def _batch_shape_tensor(self):
        return self._u.batch_shape_tensor()

    def _event_shape(self):
        return self._u.event_shape

    def _event_shape_tensor(self):
        return self._u.event_shape_tensor()

    def _sample_n(self, n, seed=None):
        # NOTE(review): the same `seed` is passed to every redraw. Stateful
        # (int/None) seeds presumably still yield fresh draws per call, but a
        # stateless Tensor seed would likely redraw identical values and never
        # terminate -- TODO split the seed per iteration if that is needed.
        def keep_going(samples):
            # tf.logical_not instead of Python `not`: `not tensor` only works
            # eagerly and raises under tf.function / graph mode.
            return tf.logical_not(tf.reduce_all(self._c(samples)))

        def redraw(samples):
            # Keep accepted samples; replace rejected ones with fresh proposals.
            return [tf.where(self._c(samples),
                             x=samples,
                             y=self._u.sample(n, seed=seed))]

        return tf.while_loop(
            cond=keep_going,
            body=redraw,
            loop_vars=[self._u.sample(n, seed=seed)]
        )[0]

In [None]:
# Condition written in `numpy` form: it receives a batch of samples and
# returns a boolean mask (True = accept).
cond = lambda x: x > 0
# Alternatives -- uncomment one to replace the condition above:
# cond = lambda x: np.isclose(x, 1.4, atol=1e-3)  # accepts only a ~2e-3-wide window, so rejection needs very many redraws
# cond = lambda x: (-3 < x) & (x < 3)
# cond = lambda x: True  # identity (accept everything)

N = 500000
xs = Rejector(tfd.Normal(0,1), cond).sample(N)

hist(xs.numpy(), xlim=[-3,3])

In [None]:
class Wrapper:
    """A random variable backed by a distribution, optionally conditioned.

    ``dist`` is the unconditioned distribution; any extra keyword arguments
    are stored verbatim as instance attributes.
    """

    def __init__(self, dist, **kwargs):
        self.dist = dist
        # Extra keyword arguments become plain instance attributes.
        vars(self).update(kwargs)

    def observe(self, cond):
        """
        Condition this variable on the unary predicate @cond.

        Builds a Rejector over ``self.dist`` and stores it as ``self.rej``,
        replacing any rejector installed by an earlier call.
        """
        assert callable(cond)
        self.rej = Rejector(self.dist, cond)

    def sample(self, N, seed=None):
        """Draw ``N`` samples, honouring the observed condition if one is set."""
        if hasattr(self, 'rej'):
            return self.rej.sample(N, seed=seed)
        return self.dist.sample(N, seed=seed)

In [None]:
# A standard normal conditioned on being positive: observe(x > 0).
x = Wrapper(dist=tfd.Normal(loc=0, scale=1))
x.observe(lambda value: value > 0)

hist(x.sample(10000).numpy())

In [None]:
# One draw from the conditioned variable, shown as a numpy scalar.
x.sample(1)[0].numpy()

In [None]:
b1 = Wrapper(dist=tfd.Bernoulli(probs=0.4))
b2 = Wrapper(dist=tfd.Bernoulli(probs=0.2))

# observe(b1 and b2)  -- logical conjunction (NOT Python's `^`, which is XOR).
# Each conjunct mentions a single variable, so it can be observed per variable:
# b1.observe(lambda x: x == 1)
# b2.observe(lambda x: x == 1)

# observe(b1 or b2)
# TODO: a disjunction couples both variables, but `Wrapper.observe` only takes
# a unary predicate over one distribution, so it cannot express this yet.

## Inserting `observe` statements by rewriting the program's AST

In [None]:
import ast
import astor
import showast

In [None]:
# %%showast
# def f():
#     x = Wrapper(dist=tfd.Normal(0,1))
#     return x.sample(1)

In [None]:
# Toy program, kept as source text, used as input for the AST experiments.
prog = """
def f():
    x = Wrapper(dist=tfd.Normal(0,1))
    return x.sample(1)
"""

class my(ast.NodeVisitor):
    """Exploratory AST visitor: records assignment targets and dumps returns.

    Keyword arguments become instance attributes (e.g. ``tree=...`` for
    ``transform``); ``vars`` is always reset to an empty list.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        self.vars = []  # ast.Name nodes seen in Store context, in visit order

    def transform(self, tree=None):
        """Visit ``tree`` (default: ``self.tree``) and return it as source text."""
        if tree is None:
            tree = self.tree
        self.visit(tree)
        return astor.to_source(tree)

    def visit_Name(self, node):
        # A Name's only child is its ctx, so no generic_visit is needed here.
        if isinstance(node.ctx, ast.Store):
            self.vars.append(node)

    def visit_Return(self, node):
        # Debug output: ast.dump shows the node's structure, unlike __dict__,
        # which only shows opaque child-object reprs.
        print(ast.dump(node))
        # Keep walking so names bound inside the return value
        # (e.g. `return (y := f())`) are still recorded.
        self.generic_visit(node)

        # TODO (planned rewrite, presumably wrapping the return in a pending
        # `if` from an `observe`): returning a replacement node requires
        # ast.NodeTransformer -- NodeVisitor ignores return values -- and
        # `self.__to_resolve` is not defined anywhere yet.
        # last = self.__to_resolve.pop()
        #
        # if not isinstance(last, ast.If):
        #    return node
        #
        # last.body = [node]
        # return None
        
# Parse the toy program, then walk it with the exploratory visitor.
t = ast.parse(prog)
v = my()
v.visit(t)