# VI without plates

i.e. no repeating bits to abstract over

Optimise the params of an approx posterior over extended Z-space, but not K space

$$Q (Z|X) = \prod_k  Q(Z^k|X) = \prod_k \prod_i Q(Z^k_i \mid Z^k_{qa(i)})$$

and

$$\prod_j f_j^{\kappa_j} = \frac{P(x_, Z)}{\prod Q(z_i^{k_i})}$$

Writing out the target (log marginal likelihood) fully makes the computation clear:

$$ \mathcal{L}= E_{Q(Z|X)} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$
$$= E_{Q} \left[ \log \frac{∑_K  P(Z,K,X)}{Q (Z|X)} \right]$$

with

$$
  \frac{P({Z, K, X})}{Q({Z|X})} = 
  P({K}) 
  P \left({X| Z_{\mathrm{pa}{X}}^{K_{\mathrm{pa}{X}}}} \right)  
  \prod_i 
  \frac{P\left({Z_i^{K_i}| Z^{K_{\mathrm{pa}(i)}}_{\mathrm{pa}(i)}} \right)}
  {Q \left(  
    Z_i^{K_i}| Z^{K_i}_{\mathrm{qa}(i)}
  \right)}$$

## The computation

1. Form joint P/Q: index prior * lik * product of latent P/Qs
2. $\mathcal{L}$: sum out K, then log P - log Q, then average

In [2]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

import sys; sys.path.append("..")
from tpp_trace import *
import utils as u
import tensor_ops as tpp
from tvi import *

import imp

## First with no plates

In [205]:
# a factorised approx posterior. generate 3 simple variables
# sample along the chain

# a ~ N([1],[3])
# b ~ N(a,[3])
# c ~ N(b,[3])

n = 3
scale = n
Norm = lambda mu : WrappedDist(Normal, mu, scale)

TRUE_MEAN_A = 10

# a -> b -> c observed
def chain_dist(trace, n=3):
    a = trace["a"](Norm(t.ones(n) * TRUE_MEAN_A))
    b = trace["b"](Norm(a))
    c = trace["c"](Norm(b))
    
    return c

# a param placeholder
# Hardcoding 2 params for each var, for now
# factorised Gaussian with learned means and covs
class ChainQ(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        self.mean_a = nn.Parameter(t.zeros(n))
        self.mean_b = nn.Parameter(t.zeros(n))
        self.logscale_a = nn.Parameter(t.zeros(n))
        self.logscale_b = nn.Parameter(t.zeros(n))
    
    def sample(self, trace) :
        a = trace["a"](Norm(t.ones(n)))
        b = trace["b"](Norm(a)) 


## Dryrun

In [None]:
draws = 2
nProtected = 2 # (number of vars in chain dist) minus observed nodes

P = chain_dist
# TODO: Why pos_B?
data_tensor = (t.ones(n) * 1).refine_names('pos_B')
data = {"__c": data_tensor}

trq = sampler(draws, nProtected, data=data)
Q = ChainQ()
Q.sample(trq)
trp = evaluator(trq, nProtected, data=data)
P(trp)
print(trp.trace.out_dicts)
P(trq)
trq.sum_out_pos()
trp.sum_out_pos()

trq.trace.out_dicts = rename_placeholders(trp, trq)

#print(trp.trace.out_dicts["sample"])
#print()
#print(trq.trace.out_dicts["sample"])
#print()
tensors = subtract_latent_log_probs(trp, trq)
loss_dict = tpp.combine_tensors(tensors)
key = next(iter(loss_dict))
loss_dict[key]
    

In [240]:
class TVI(nn.Module) :
    def __init__(self, p, q, k, x, nProtectedDims):
        super().__init__()
        self.p = p
        self.q = q
        self.k = k
        self.nProtected = nProtectedDims
        
        self.data_dict = {}
        self.data_dict["__c"] = []
        self.data = nn.Parameter(x, requires_grad=False) 
        
    # 1. s = sample Q
    # 2. lp_Q = eval Q.logprob(s)
    # 3. lp_P = eval P.logprob(s)
    # 4. f = lp_P - lp_Q
    # 5. loss = combine fs
    def forward(self):
        self.data_dict["__c"].append(self.data)
        
        # init traces at each step
        sample_trace = sampler(self.k, self.nProtected, data={"__c": self.data})
        # sample recognition model Q -> Q-sample and Q-logprobs
        self.q.sample(sample_trace)
        # compute P logprobs under prior
        #self.p(sample_trace)
        
        # Pass Q samples to new trace
        eval_trace = evaluator(sample_trace, self.nProtected, data={"__c": self.data})
        # compute P logprobs under ...
        self.p(eval_trace)
        
        eval_trace.sum_out_pos()
        sample_trace.sum_out_pos()
        # align dims in Q
        sample_trace.trace.out_dicts = rename_placeholders(eval_trace, sample_trace)
        
        # to ratio land: P.log_probs - Q.log_probs (just the latents)
        tensors = subtract_latent_log_probs(eval_trace, sample_trace)
        # reduce gives loss
        loss_dict = tpp.combine_tensors(tensors)
        assert(len(loss_dict.keys()) == 1)
        key = next(iter(loss_dict))

        return loss_dict[key]


def setup_and_run(tvi, ep=2000, eta=1) :
    print("Q params init:", [ p for p in tvi.q.parameters() ])
    optimiser = t.optim.Adam(tvi.q.parameters(), lr=eta) # optimising q only    
    optimise(tvi, optimiser, ep)
    
    return tvi


def optimise(tvi, optimiser, eps) :
    for i in range(eps):
        optimiser.zero_grad()
        loss = - tvi() 
        loss.backward(retain_graph=True)
        optimiser.step()


def sample_generator(nProtected, P, dataName="__c") :
    k = 1
    trp = sampler(k, nProtected, data={})
    P(trp)
    return trp.trace.out_dicts["sample"][dataName] \
            .squeeze(0)
        

def get_error_on_a(a_mean, n, tvi) :
    a_mean = t.ones(n) * a_mean
    return a_mean - tvi.q.mean_a


# Recovering mean of first var
def main(nvars=3, nProtected=2, k=2, epochs=2000, true_mean=10, lr=0.2) :
    Q = ChainQ()
    P = chain_dist
    
    # Get _c data by sampling generator
    x = sample_generator(nProtected, P, dataName="__c")
    tvi = setup_and_run(TVI(P, Q, k, x, nProtected), epochs, eta=lr)
    
    return get_error_on_a(true_mean, nvars, tvi), tvi

In [239]:
error, tvi = main(nvars=3, k=2, epochs=1000, true_mean=TRUE_MEAN_A, lr=0.2)
print(tvi.q.mean_a)
#, tvi.q.mean_b, tvi.q.logscale_a, tvi.q.logscale_b#, tvi.data_dict

#print(f"{error / TRUE_MEAN_A *100}%")

Q params init: [Parameter containing:
tensor([0., 0., 0.], requires_grad=True), Parameter containing:
tensor([0., 0., 0.], requires_grad=True), Parameter containing:
tensor([0., 0., 0.], requires_grad=True), Parameter containing:
tensor([0., 0., 0.], requires_grad=True)]
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)


In [237]:
Q = ChainQ()
P = chain_dist
k = 2
nProtected = 2 

# Get _c data by sampling generator
x = sample_generator(nProtected, P, dataName="__c")

trq = sampler(k, nProtected, data={"__c": x})
#print(trq.trace.in_dicts)
Q.sample(trq)
print(trq.trace.out_dicts)
trp = evaluator(trq, nProtected, data={"__c": x})
#P(trp) # evaluate under prior!
print(trp.trace.in_dicts)
# compute P logprobs
P(trp)
print(trp.trace.out_dicts)
#print(trq.trace.fn.plate_names)

trp.sum_out_pos()
trq.sum_out_pos()
# aligning dims in Q
trq.trace.out_dicts = rename_placeholders(trp, trq)

# to ratio land: P.log_probs - Q.log_probs (just the latents)
tensors = subtract_latent_log_probs(trp, trq)
print("\n", tensors)
# reduce gives loss
loss_dict = tpp.combine_tensors(tensors)
assert(len(loss_dict.keys()) == 1)
key = next(iter(loss_dict))

loss_dict[key]


{'sample': {'__a': tensor([[[ 0.9192,  3.0368, -3.8186]],

        [[ 3.4921, -2.6384,  4.6960]]], names=('_K', 'pos_A', 'pos_B')), '__b': tensor([[[-0.2439,  1.9266, -4.3716]],

        [[ 4.5765, -3.4173, -2.1387]]], names=('_K', 'pos_A', 'pos_B'))}, 'log_prob': {'__a': tensor([[[-2.0179, -2.2480, -3.3075]],

        [[-2.3626, -2.7530, -2.7764]]], names=('_K', 'pos_A', 'pos_B')), '__b': tensor([[[-2.0927, -2.0860, -2.0345]],

        [[-2.0829, -2.0513, -4.6127]]], names=('_K', 'pos_A', 'pos_B'))}}
{'data': {'__c': tensor([[14.0023, 14.0198, 11.2501]], names=('pos_A', 'pos_B'))}, 'sample': {'__a': tensor([[[ 0.9192,  3.0368, -3.8186]],

        [[ 3.4921, -2.6384,  4.6960]]], names=('_K', 'pos_A', 'pos_B')), '__b': tensor([[[-0.2439,  1.9266, -4.3716]],

        [[ 4.5765, -3.4173, -2.1387]]], names=('_K', 'pos_A', 'pos_B'))}}
{'sample': {'__a': tensor([[[ 0.9192,  3.0368, -3.8186]],

        [[ 3.4921, -2.6384,  4.6960]]], names=('_k__a', 'pos_A', 'pos_B')), '__b': tensor([[[[-0.24

tensor(-12.3278, requires_grad=True)

In [186]:
tr1 = trace({"data": {}}, SampleLogProbK(K=2, protected_dims=2))
d = WrappedDist(Normal, t.ones(3), 3)
a = tr1["a"](d) 
val = chain_dist(tr1)
tr2 = trace({"data": {}, "sample": tr1.trace.out_dicts["sample"]}, LogProbK(tr1.trace.fn.plate_names, 2))
val = chain_dist(tr2)

tr2.trace.out_dicts

{'sample': {'__a': tensor([[[11.0944, 10.5638,  9.4949]],
  
          [[ 5.8477,  8.7341,  9.2536]]], names=('_k__a', 'pos_A', 'pos_B')),
  '__b': tensor([[[[13.9619, 16.5348,  3.4176]]],
  
  
          [[[ 7.3123,  8.6786,  9.7241]]]],
         names=('_k__b', '_k__a', 'pos_A', 'pos_B')),
  '__c': tensor([[[[[ 8.9255, 16.4849,  2.0058]]]],
  
  
  
          [[[[ 7.1843,  4.3577, 11.1260]]]]],
         names=('_k__c', '_k__b', '_k__a', 'pos_A', 'pos_B'))},
 'log_prob': {'__a': tensor([[[-2.0841, -2.0352, -2.0317]],
  
          [[-2.9754, -2.1066, -2.0485]]], names=('_k__a', 'pos_A', 'pos_B')),
  '__b': tensor([[[[-2.4744, -3.9983, -4.0694]],
  
           [[-5.6753, -5.3981, -3.9097]]],
  
  
          [[[-2.8122, -2.2150, -2.0205]],
  
           [[-2.1367, -2.0177, -2.0298]]]],
         names=('_k__b', '_k__a', 'pos_A', 'pos_B')),
  '__c': tensor([[[[[ -3.4267,  -2.0177,  -2.1283]]],
  
  
           [[[ -2.1621,  -5.4030,  -5.3271]]]],
  
  
  
          [[[[ -4.5696, -10.2555, 

## VI, No plates but including deletes


In [None]:
def chain_dist_del(trace):
    a = trace["a"](Norm(t.ones(n)))
    b = trace["b"](Norm(a))
    c = trace["c"](Norm(b))
    (c,) = trace.delete_names(("a", "b"), (c,))
    d = trace["d"](Norm(c))
    
    return c


In [None]:
# call sampler on Q. 
# gives you the samples and a log Q tensor `log_prob`
tr1 = sampler(draws, nProtected, data=data)

val = P(tr1)
log_q = tr1.trace.out_dicts["log_prob"]

# compute the log_probs

# pass these to evaluator, which does a lookup for all the latents 
# gives you log P for each latent
tr2 = evaluator(tr1, nProtected, data=data)
val = P(tr2)

#tr2.trace.out_dicts["log_prob"]
#log_p = 

#Q = pytorch.module
#    - `q.forward()` will look like chain_dist
    

#- optimise it


## plate VI

- For plates, we just don't filter [@17](https://github.com/LaurenceA/tpp/blob/bd1fe20dcf86a1c02cc0424632571fba998d104f/utils.py#L17)
- Painful stuff: need to keep the generative order (e.g. a, b, c, d)
    - because we start by summing the lowest-level plates
        - solution: enforce that the last variable is a leaf e.g. `d`
- Careful when combining P & Q tensors: maintain the ordering!

- Plates: doing the summation backwards through the plates, yeah?
    - This implies tricky implementation blah
    - Py 3.6 dicts are ordered by insertion though, so use that


In [None]:
# example directed graph with plate repeats
# 3(a) -> 4(b) -> c -> d
def plate_dist(trace):
    Na = Norm(t.ones(n))
    a = trace["a"](Na, plate_name="A", plate_shape=3)
    Nb = Norm(a)
    b = trace["b"](Nb, plate_name="B", plate_shape=4)
    Nc = Norm(b)
    c = trace["c"](Nc)
    
    (c,) = trace.delete_names(("a", "b"), (c,))
    
    Nd = Norm(c)
    d = trace["d"](Nd)
    
    return d

tr = sample_and_eval(plate_dist, draws=1, nProtected=2)
tr.trace.out_dicts
tr.trace.in_dicts