In [65]:
import re
import string
import random
import torch as t
from torch.distributions import Normal
from functools import reduce

import sys; sys.path.append("..")
from tpp_trace import *
import utils as u

In [2]:
# example directed graph without plate repeats
# the real stuff happens as side effects in trace's dicts
def chain_dist(trace):
    a = trace["a"](WrappedDist(Normal, t.ones(3), 3))
    b = trace["b"](WrappedDist(Normal, a, 3))
    c = trace["c"](WrappedDist(Normal, b, 3))
    (c,) = trace.delete_names(("a", "b"), (c,))
    d = trace["d"](WrappedDist(Normal, c, 3))
    
    return d


# example directed graph with plate repeats
# 3(a) -> 4(b) -> c -> d
def plate_chain_dist(trace):
    a = trace["a"](WrappedDist(Normal, t.ones(3), 3), plate_name="A", plate_shape=3)
    b = trace["b"](WrappedDist(Normal, a, 3),         plate_name="B", plate_shape=4)
    c = trace["c"](WrappedDist(Normal, b, 3))
    (c,) = trace.delete_names(("a", "b"), (c,))
    d = trace["d"](WrappedDist(Normal, c, 3))
    
    return d


def sampler(draws, nProtected, data={}):
    return trace({"data":data}, \
                 SampleLogProbK(K=draws, protected_dims=nProtected) \
                )

def evaluator(sampler, nProtected, data={}) :
    tr = sampler.trace
    d = {"data": data, "sample": tr.out_dicts["sample"]} # error if called before sampler
    lpk = LogProbK(plate_names=tr.fn.plate_names, protected_dims=nProtected)
    
    return trace(d, lpk)


def sample_and_eval(model, draws, nProtected, data={}) :
    tr1 = sampler(draws, nProtected, data=data)
    val = model(tr1)
    print("Plates:", tr1.trace.fn.plate_names)
    print(val.names)
    tr2 = evaluator(tr1, nProtected, data=data)
    val = model(tr2)
    print(val.names)
    
    return tr2

## Run

In [3]:
kappa = 3

tr = sample_and_eval(chain_dist, draws=kappa, nProtected=2)#, data={"a": 2})
#tr.trace.out_dicts

tr = sample_and_eval(plate_chain_dist, draws=kappa, nProtected=2)#, data={"a": 2})
#tr.trace.out_dicts



Plates: []
('_K', 'pos_A', 'pos_B')
('_k__d', '_k__c', 'pos_A', 'pos_B')
Plates: ['_plate_A', '_plate_B']
('_plate_B', '_plate_A', '_K', 'pos_A', 'pos_B')
('_k__d', '_k__c', '_plate_B', '_plate_A', 'pos_A', 'pos_B')


  return super(Tensor, self).rename(names)


## index-aware summing

In [70]:
# We have: one factor corresponding to each variable (latent or observed)
# e.g. the trace output for the 4 gaussians in our example

# Combine: 
# list of log_prob tensors -> list of log_prob tensors, summed across groups
# TODO: add data handling
# TODO: add plates
def combine_tensors(tensors) :
    names = list(tensors.keys()) # variable names
    tensors = tensors.values()
    print(names)
    
    while len(names) > 2:
        # choose a random index (e.g. "a")
        i = names.pop(random.randrange(len(names)))

        # remove all factors which depend on "a"
        tensors, i_tensors = get_dependent_factors(tensors, i)
        
        #Take the tensor product of all these factors 𝑡_a = ∏ 𝑡|𝐾_a
        T = reduce(t.matmul, i_tensors)
        # need a `view` here?
        # Is opt_einsum necessary for controlling dim explosion?
        # squeeze?
        
        #Sum out 𝑘^a  for that variable →𝑡∖a
        T = T.sum(i)
        #TODO: plate_sum()
        
        # Put 𝑡∖a back in the list
        tensors.append(T)
        
        # TODO: when does the split by group happen?
        # groupby(tensors)
        
    # Then the last two tensors are simple
    #T = combine_two_tensors(tensors)
    
    
    
    return T


# factors T∣K_i that depend on K_i
def get_dependent_factors(tensors, i):
    i_dim = k_dim_name(i)
    dependents = [tensor for tensor in tensors \
                    if i_dim in tensor.names]
    nondependents = [tensor for tensor in tensors \
                        if i_dim not in tensor.names]
    
    return nondependents, dependents


def combine_two_tensors(tensors) :
    assert len(tensors) == 2
    return u.logmmmeanexp(*tensors)

# groups 
# i.e. split the arg_names into indices and latents
# (otherwise, we will conflate _k_dims and latent variables)
def groupby(tensors) :
    return NotImplementedError()


# Rearrange:  
# dict of log_prob tensors -> dict of dicts of log_prob tensors, dividing by plates
def rearrange_by_plate(tensor_dict) :
    return NotImplementedError()


# Sum:  sum over each plate
def plate_sum() :
    return NotImplementedError()

In [71]:
kappa = 2

tr = sample_and_eval(chain_dist, draws=kappa, nProtected=2)#, data={"a": 2})
combine_tensors(tr.trace.out_dicts["log_prob"])


tr.trace.out_dicts["log_prob"]

Plates: []
('_K', 'pos_A', 'pos_B')
('_k__d', '_k__c', 'pos_A', 'pos_B')
['__a', '__b', '__c', '__d']


RuntimeError: Expected tensor to have size 3 at dimension 1, but got size 1 for argument #2 'batch2' (while checking arguments for bmm)

# VI without plates

i.e. no repeating bits to abstract over

$$Q (Z|X) = \prod_k  Q(Z^k|X) = \prod_k \prod_i Q(Z^k_i \mid Z^k_{qa(i)})$$

and

$$\prod_j f_j^{\kappa_j} = \frac{P(x_, Z)}{\prod Q(z_i^{k_i})}$$

Writing out the target (log marginal likelihood) fully makes the computation clear:

$$ \mathcal{L}= E_{Q(Z|X)} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$
$$= E_{Q} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$

1. get factors 
2. sum out indices until you get a scalar log prob
3. gradient ascent


### 2. the sum (combine)

a. until the last two, do clever thing

b. and we don't need to be clever when there's only two tensors

### The clever thing

1. choose a random index (e.g. "a")
2. take all vectors which use "a"
3. matmul + flatten
    - this gives you a giant tensor so be clever (maybe `opt.einsum`, but probably not)
4. sum out "a"


To sum out $K_5$

1. Remove all tensors / factors $T \mid K_5$ from the list that depend on $K_5$
2. Take the tensor product of all these factors $t_5 = \prod t |K_5$
3. Sum out $k_i^{(5)}$ for that variable $\to t_{\backslash 5}$
4. Put $t_{\backslash 5}$ back in the list

# 3.
a. "sum(I)" =  we  multiply  all  factors  in  the  list  that  depend  on I, 

b. sum I out, 

c. and put the result back in the list.

In [6]:
# Test case: we want to compute the true posterior, especially its variance

def P(trace):
    a = WrappedDist(Normal, t.ones(2), 3)
    b = WrappedDist(Normal, t.ones(2), 3)
    #return (a+b)

#class Q(nn.Module):
#    def __init__(self)


tr1 = trace({"data": {}}, SampleLogProbK(K=4, protected_dims=2))
P(tr1)

# needs a prior and approximate posterior
# Lookup Gaussian conditioning in Bishop for this

## VI with plates

$T$: underlying tensor list 

1. Get plate order
    - by looking for the first tensor with that plate
    - and using the index of that tensor?
2. Sum out plates in reverse order of definition

For p in reverse(plates):
* $T_{\mathrm{new}} = []$

* $T_p \leftarrow$ all tensors in p

* Remove $T_p$ from $T$

* $T_{\mathrm{new}} \leftarrow T_p$

* Sum out all sample indexes within the plate

* $T_{\backslash p} \leftarrow$ Sum out the plate

* $T += T_{\backslash p}$


## index-aware sampling

## Laurence's example

In [7]:
nProtected = 2
draws = 4 # kappa

tr1 = trace({"data": {}}, SampleLogProbK(K=draws, protected_dims=nProtected))
val = chain_dist(tr1)
print(val.names)
tr2 = trace({"data": {}, "sample": tr1.trace.out_dicts["sample"]}, \
            LogProbK(tr1.trace.fn.plate_names, nProtected))
val = chain_dist(tr2)
print(val.names)
#tr2.trace.out_dicts#["sample"]#["__a"]

('_K', 'pos_A', 'pos_B')
('_k__d', '_k__c', 'pos_A', 'pos_B')


In [8]:
X = t.ones(2,4) *3
Y = t.ones(4,2) *2
u.logmmmeanexp(X, Y)

tensor([[5., 5.],
        [5., 5.]])