three types of dimension:
* plate dims
* sample dimensions (usually indexed K)
* user dims; underlying dimensions (which the user gets to interact with). 

In [1]:
import torch.nn as nn
from torch.distributions import Normal
import numpy as np

import sys; sys.path.append("..")
from tpp_trace import *
import utils as u
import tensor_ops as tpp

## Run

In [2]:
kappa = 3
n = 2
tr = sample_and_eval(chain_dist, draws=kappa, nProtected=n, data={"a": 2})

#tr = sample_and_eval(plate_chain_dist, draws=kappa, nProtected=2)#, data={"a": 2})
#tr.trace.out_dicts

tr.trace.out_dicts['sample']['__a']

  return super(Tensor, self).rename(names)


tensor([[[ 0.6436, -0.9349,  0.1692]],

        [[-1.3581,  5.3876,  1.1136]],

        [[ 6.9046,  2.4368,  1.1325]]], names=('_k__a', 'pos_A', 'pos_B'))

## index-aware summing

We have: one factor corresponding to each variable (latent or observed)

e.g. the trace output for the 4 gaussians in our chain example

every time we sample, we add a new dimension. Need to delete these after eval

### No plate case

1. Take all indices `set(I)` in the tensors $T_a$ that depend on `__a` (that have `_k__a` in their names)
2. use pytorch names to order the dims the same in each tensor
3. multiply $T_a$ (as in `*`)
4. sum out `__a`


do the reduction as a for-loop (picking the first K dimension and combining all the tensors with that dimension).


In [3]:
import imp
imp.reload(tpp)
imp.reload(u)

kappa = 2
n = 2
data = {} # {"a": [4] * 100}
tr = sample_and_eval(chain_dist, draws=kappa, nProtected=n, data=data)
tensors = tr.trace.out_dicts['log_prob']

X = t.Tensor([[.3,.1],\
              [.1,.3]])
X = X.refine_names('_k__a', '_k__b')
Y = t.Tensor([[.3,.7],\
              [.2,.3]])
Y = Y.refine_names('_k__a', '_k__b')
tensors = {'__a' : X.log(), '__b' : Y.log() } 

tpp.combine_tensors(tensors, naive=True).values(), \
tpp.combine_tensors(tensors, naive=False).values(), \
u.logmmmeanexp(tensors["__a"], tensors["__b"]).sum()

(dict_values([tensor(7.4262)]),
 dict_values([tensor(-5.4262)]),
 tensor(-10.6475))

In [40]:
def named_example_2D() :
    X = t.Tensor([[.2,.1],\
                  [.1,.3]])
    X = X.refine_names('_k__a', '_k__b')
    Y = t.Tensor([[.3,.2],\
                  [.1,.1]])
    Y = Y.refine_names('_k__a', '_k__b')
    
    return X, Y, '_k__a'


def test_logmulmeanexp() :
    X, Y, dim = named_example_2D()
    log_mean_prod_exp = u.logmulmeanexp(X, Y, dim)
    
    reference = u.logmmmeanexp(X, Y) \
                .rename(None) \
                .diag()
    stripped = log_mean_prod_exp \
                .rename(None) \
                .squeeze()
    
    #assert(t.allclose(stripped, reference) )
    print(reference, "\n", stripped)
    
test_logmulmeanexp()

tensor([0.3612, 0.3512]) 
 tensor([0.3612, 0.3512])


### Testing VI

Set up a Gaussian graphical model:
$$z \sim N(0,1)$$
$$x \sim N(z, 1)$$
Now, we can get $P(x| z) = N(\mu_{x|z}, \Sigma_{xx|z})$ analytically 

$$\mu_{x|z} = \mathbf{\mu_x + \Sigma_{xz}\Sigma_{zz}^{-1}(z-\mu_z)}$$ 
$$\Sigma_{x|z} = \mathbf{\Sigma_{xx} - \Sigma_{xz}\Sigma_{zz}^{-1}\Sigma_{zx}}$$

---

We use the approximate posterior, $$Q(z) = N(\mu, \sigma^2)$$, where mu and sigma are learned parameters. 

When we learn using our ELBO-thing, do those parameters learn to match the true posterior?


In [4]:
# e.g. bivariate example
z_mean, z_var = 0, 1
x_var = 1
rho = 0.5

z = 0 #z.sample([n]).mean()
GROUND_TRUTH_POST_MU = u.biv_norm_conditional_mean(z, z_mean, np.sqrt(x_var), \
                                              np.sqrt(z_var), rho, z)
GROUND_TRUTH_POST_VAR = u.biv_norm_conditional_var(x_var, rho)

GROUND_TRUTH_POST_MU, GROUND_TRUTH_POST_VAR


# u.analytical_posterior_var(var, X)
# u.analytical_posterior_mean(prior_mean, var, X, Z) 

(0.0, 0.25)

# VI without plates

i.e. no repeating bits to abstract over

Optimise the params of an approx posterior over extended Z-space, but not K space

$$Q (Z|X) = \prod_k  Q(Z^k|X) = \prod_k \prod_i Q(Z^k_i \mid Z^k_{qa(i)})$$

and

$$\prod_j f_j^{\kappa_j} = \frac{P(x_, Z)}{\prod Q(z_i^{k_i})}$$

Writing out the target (log marginal likelihood) fully makes the computation clear:

$$ \mathcal{L}= E_{Q(Z|X)} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$
$$= E_{Q} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$

In [5]:
# https://github.com/anonymous-78913/tmc-anon/blob/master/non-fac/model.py
class Fac(TMC):
    def rsample_log_prob(self, Ks):
        if isinstance(Ks, int):
            Ks = self.T.item() * [Ks]

        zs = []
        log_probs = []

        for i in range(len(Ks)):
            Q = Normal(self.means[i], self.log_stds[i].exp())
            z = Q.rsample(sample_shape=t.Size([Ks[i]]))
            zs.append(z)
            log_probs.append(Q.log_prob(z))

        return zs, log_probs
    
    
class TMC(nn.Module):
    def __init__(self, T, res, log_std=0., like_std=1.):
        super().__init__()
        self.register_buffer('T', t.tensor(T))
        self.register_buffer('res', t.tensor(res))
        #self.res      = res
        self.means    = nn.Parameter(t.zeros(T))
        self.log_stds = nn.Parameter(t.sqrt((1.+t.arange(T).float())/T))#log_std*t.ones(T))
        self.like_std = like_std

    def logpqs(self, Ks):
        """
        Compute the series of tensors that we reduce over, by combining the universal generative model
        with the proposal propbabilities.
        """
        zs, log_qs = self.rsample_log_prob(Ks)

        log_ps = Normal(0., t.sqrt(1./self.T.float())).log_prob(zs[0])
        res = [(log_ps - log_qs[0]).unsqueeze(0)]

        for i in range(1, len(zs)):
            log_ps = Normal(zs[i-1].unsqueeze(1), t.sqrt(1/self.T.float())).log_prob(zs[i])
            res.append(log_ps - log_qs[i])

        res.append(Normal(zs[-1].unsqueeze(1), self.like_std).log_prob(self.res))

        return res

    def reduce(self, Ks):
        """
        Combine tensors
        """
        logpqs = self.logpqs(Ks)

        res = logpqs[0]
        for i in range(1, len(logpqs)):
            res = logmmmeanexp(res, logpqs[i])
        return res

NameError: name 'TMC' is not defined

In [None]:
class FactorisedDist(nn.Module):
    """
    Takes a list of distributions and allows e.g. sampling from them 
    """
    def __init__(self, dists):
        super().__init__()
        self.dists = dists

    def rsample(self):
        return [d.rsample() for d in self.dists]

    def log_prob(self, zs):
        return [d.log_prob(z) for (d, z) in zip(self.dists, zs)]

    def rsample_log_prob(self):
        zs = self.rsample()
        lps = self.log_prob(zs)
        return zs, lps
    
    
class FacRandomSequential(nn.Sequential):
    """
    Transforms input into a series of factorised distributions
    """
    def forward(self, input):
        mods = list(self._modules.values())
        z = input
        dists = []
        for i in range(len(mods)):
            z, dist = mods[i](z)
            dists.append(dist)
        
        return FactorisedDist(dists)

In [None]:
class VAE(nn.Module):
    """
    Usual single/multi-sample VAE
    """
    def __init__(self, p, q, K):
        super().__init__()
        self.p = p
        self.q = q
        self.K = K

    def forward(self, x):
        wz = self.q.sample(x.size(0), sample_shape=t.Size([self.K]))
        elbo = self.p.log_prob((x, wz)) - self.q.log_prob(wz)
        lme = logmeanexp(elbo)
        return lme

    def train(self, x):
        opt = t.optim.Adam(q.parameters())
        for i in range(100):
            opt.zero_grad()
            obj = self(x)
            (-obj).backward()
            opt.step()
            print(obj)

## VI with plates

$T$: underlying tensor list 

1. Get plate order
    - by looking for the first tensor with that plate
    - and using the index of that tensor?
2. Sum out plates in reverse order of definition

For p in reverse(plates):
* $T_{\mathrm{new}} = []$

* $T_p \leftarrow$ all tensors in p

* Remove $T_p$ from $T$

* $T_{\mathrm{new}} \leftarrow T_p$

* Sum out all sample indexes within the plate

* $T_{\backslash p} \leftarrow$ Sum out the plate

* $T += T_{\backslash p}$


In [None]:
def rearrange_by_plate(tensor_dict) :
    """
    :param tensor_dict: dict of log_prob tensors
    :return: dict of dicts of log_prob tensors, dividing by plates
    """
    return NotImplementedError()


# Sum:  sum over each plate
def plate_sum() :
    return NotImplementedError()

In [None]:
# Consider Figure 1 from TMC

"""
z2 | z1
z3 | z1
z4 | z2
x | z3, z4
"""
def simple_dist(trace):
    k1 = trace["i1"](WrappedDist(Normal, t.ones(3), 3))
    k2 = trace["i2"](WrappedDist(Normal, k1, 3))
    k3 = trace["i3"](WrappedDist(Normal, k2, 3))
    (k3,) = trace.delete_names(("i1", "i2"), (k3,))
    k4 = trace["i4"](WrappedDist(Normal, k3, 3))
    
    return k4

tr = sample_and_eval(simple_dist, draws=kappa, nProtected=2)#, data={"a": 2})

## index-aware sampling