# VI without plates

i.e. no repeating bits to abstract over

Optimise the params of an approx posterior over extended Z-space, but not K space

$$Q (Z|X) = \prod_k  Q(Z^k|X) = \prod_k \prod_i Q(Z^k_i \mid Z^k_{\mathrm{qa}(i)})$$

and

$$\prod_j f_j^{\kappa_j} = \frac{P(x, Z)}{\prod_i Q(z_i^{k_i})}$$

Writing out the target (a lower bound on the log marginal likelihood) in full makes the computation clear:

$$ \mathcal{L}= E_{Q(Z|X)} \left[ \log \frac{\sum_K  P(Z,K,X)}{Q (Z|X)} \right]$$
$$= E_{Q} \left[ \log \frac{\sum_K  P(Z,K,X)}{Q (Z|X)} \right]$$

with

$$
  \frac{P({Z, K, X})}{Q({Z|X})} = 
  P({K}) 
  P \left({X| Z_{\mathrm{pa}(X)}^{K_{\mathrm{pa}(X)}}} \right)  
  \prod_i 
  \frac{P\left({Z_i^{K_i}| Z^{K_{\mathrm{pa}(i)}}_{\mathrm{pa}(i)}} \right)}
  {Q \left(  
    Z_i^{K_i}| Z^{K_i}_{\mathrm{qa}(i)}
  \right)}$$

## The computation

1. Form joint P/Q: index prior * lik * product of latent P/Qs
2. $\mathcal{L}$: sum out K, then log P - log Q, then average

In [1]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

import sys; sys.path.append("..")
from tpp_trace import *
import utils as u
import tensor_ops as tpp
from tvi import *

import imp


## First with no plates

In [2]:
# a factorised approx posterior. generate 3 simple variables
# sample along the chain

# a ~ N([1],[3])
# b ~ N(a,[3])
# c ~ N(b,[3])

# Length of each chain variable; the same number is reused as the scale
# handed to every Normal in the generative chain.
n = 3
scale = n

# Mean of the prior on the first chain variable "a"; the experiment below
# measures how well the approximate posterior recovers it.
TRUE_MEAN_A = 10


def Norm(mu, var):
    """Wrap a torch ``Normal`` built from ``mu`` and ``var`` in a ``WrappedDist``.

    NOTE(review): torch's ``Normal`` takes a standard deviation as its second
    argument, so ``var`` is presumably used as a scale rather than a
    variance -- confirm against ``WrappedDist``.
    """
    return WrappedDist(Normal, mu, var)

# Prior
# a -> b -> c observed
def chain_dist(trace, n=3):
    """Generative model P: the Gaussian chain a -> b -> c.

    Each variable is registered on ``trace`` under its own name. "a" is
    centred on ``TRUE_MEAN_A``; every later variable is centred on the value
    the trace returned for its predecessor. All three use the module-level
    ``scale``.

    Args:
        trace: indexed by variable name; each entry is called with the
            distribution for that variable.
        n: length of the mean vector of "a".

    Returns:
        Whatever ``trace["c"]`` returns for the last variable in the chain.
    """
    value = TRUE_MEAN_A * t.ones(n)
    for name in ("a", "b", "c"):
        value = trace[name](Norm(value, scale))

    return value

# a param placeholder
# Hardcoding 2 params for each var, for now
# factorised Gaussian with learned means and covs
class ChainQ(nn.Module):
    """Factorised Gaussian approximate posterior over the latents "a" and "b".

    Holds one mean vector and one log-scale vector per latent, all
    initialised to ones of length ``n``.
    """

    def __init__(self, n=3):
        super().__init__()

        def unit_param():
            # A fresh learnable vector of ones for every attribute.
            return nn.Parameter(t.ones(n))

        # Assignment order fixes the order of .parameters(): means first,
        # then log-scales.
        self.mean_a = unit_param()
        self.mean_b = unit_param()
        # Log-scales start at 1 (a scale of e); zeros, i.e. t.log(t.ones(n)),
        # would give a unit scale instead.
        self.logscale_a = unit_param()
        self.logscale_b = unit_param()

    # TODO: make this actually depend on the params
    # NOTE(review): the earlier draft used Norm(a * self.mean_b, ...) for "b",
    # so the TODO presumably means letting b depend on the sampled a -- confirm.
    def sample(self, trace):
        """Register independent Normals for "a" and "b" on ``trace``.

        The scales are the exponentials of the stored log-scales. Nothing is
        returned; the effect is whatever the trace records.
        """
        site_a = trace["a"]
        site_a(WrappedDist(Normal, self.mean_a, t.exp(self.logscale_a)))
        site_b = trace["b"]
        site_b(WrappedDist(Normal, self.mean_b, t.exp(self.logscale_b)))


In [11]:
class TVI(nn.Module):
    """Objective module pairing a generative model P with an approximate posterior Q.

    Args:
        p: callable taking a trace; run on the evaluator trace.
        q: object with a ``sample(trace)`` method; run on the sampler trace.
        k: first argument passed to ``sampler``.
        x: observed data, stored as a non-trainable parameter and handed to
            both traces under the key "__c".
        nProtectedDims: second argument passed to ``sampler`` and
            ``evaluator``.
    """

    def __init__(self, p, q, k, x, nProtectedDims):
        super().__init__()
        self.p = p
        self.q = q
        self.k = k
        self.nProtected = nProtectedDims

        # Bookkeeping only: forward() appends the data here on every call,
        # so the list grows by one entry per optimisation step.
        self.data_dict = {"__c": []}
        self.data = nn.Parameter(x, requires_grad=False)

    def forward(self):
        """Evaluate the objective once.

        Steps:
            1. sample from Q into a fresh sampler trace
            2. run P on an evaluator trace built from that sampler trace
            3. call ``sum_out_pos`` on both traces, then rename Q's
               placeholders against P's
            4. subtract the latent log-probs (P minus Q)
            5. combine the resulting tensors into a single value

        Returns:
            The single entry of the dict produced by ``tpp.combine_tensors``.
        """
        self.data_dict["__c"].append(self.data)

        # Fresh traces on every step; each gets its own data dict.
        q_trace = sampler(self.k, self.nProtected, data={"__c": self.data})
        self.q.sample(q_trace)

        p_trace = evaluator(q_trace, self.nProtected, data={"__c": self.data})
        self.p(p_trace)

        p_trace.sum_out_pos()
        q_trace.sum_out_pos()
        # Align Q's dims with P's before taking the ratio.
        q_trace.trace.out_dicts = rename_placeholders(p_trace, q_trace)

        # To ratio land: P.log_probs - Q.log_probs (just the latents).
        log_ratios = subtract_latent_log_probs(p_trace, q_trace)

        # Reducing the ratios must leave exactly one tensor behind.
        combined = tpp.combine_tensors(log_ratios)
        assert len(combined.keys()) == 1

        return combined[next(iter(combined))]


def setup_and_run(tvi, ep=2000, eta=1):
    """Print Q's initial parameters, then fit them with Adam.

    Only ``tvi.q``'s parameters are given to the optimiser.

    Args:
        tvi: objective module exposing ``q``; called by ``optimise``.
        ep: number of optimisation steps.
        eta: Adam learning rate.

    Returns:
        The same ``tvi`` object, after optimisation.
    """
    # parameters() returns a fresh iterator on each call, so it is requested
    # once for printing and once for the optimiser.
    print("Q params init:", list(tvi.q.parameters()))
    adam = t.optim.Adam(tvi.q.parameters(), lr=eta)
    optimise(tvi, adam, ep)

    return tvi


def optimise(tvi, optimiser, eps):
    """Run ``eps`` gradient steps that maximise the value returned by ``tvi()``.

    Args:
        tvi: zero-argument callable returning a scalar tensor objective.
        optimiser: torch optimiser over the parameters to update.
        eps: number of steps to take.
    """
    for _ in range(eps):
        optimiser.zero_grad()
        objective = tvi()
        # The optimiser minimises, so back-propagate the negated objective.
        (-objective).backward()
        optimiser.step()
        
        


def sample_generator(nProtected, P, dataName="__c"):
    """Run the generative model once and return the sample stored under ``dataName``.

    Args:
        nProtected: passed through to ``sampler``.
        P: generative model; called with the fresh sampler trace.
        dataName: key to read from the trace's "sample" output dict.

    Returns:
        The selected sample with its leading dimension squeezed out.
    """
    # The sampler is built with 1 as its first argument and no observed data.
    prior_trace = sampler(1, nProtected, data={})
    P(prior_trace)
    samples = prior_trace.trace.out_dicts["sample"]
    return samples[dataName].squeeze(0)
        

def get_error_on_a(a_mean, n, tvi):
    """Return the elementwise gap between the true mean of "a" and Q's learned mean.

    Args:
        a_mean: true mean, broadcast to a length-``n`` vector.
        n: length of that vector.
        tvi: object exposing the learned mean as ``tvi.q.mean_a``.

    Returns:
        Tensor ``a_mean - tvi.q.mean_a`` (positive where Q undershoots).
    """
    target = t.ones(n) * a_mean
    learned = tvi.q.mean_a
    return target - learned


# Recovering mean of first var
def main(nvars=3, nProtected=2, k=2, epochs=2000, true_mean=10, lr=0.2):
    """Fit ChainQ to data drawn from the chain model and report the error on "a".

    ``nvars`` is forwarded to both the approximate posterior and the
    generative model, so the two agree with the length-``nvars`` target used
    in the error computation. Previously both fell back to their default of
    3 whatever ``nvars`` was.

    Args:
        nvars: length of each chain variable.
        nProtected: passed through to the sampler/evaluator traces.
        k: passed to ``TVI`` as its ``k``.
        epochs: number of optimisation steps.
        true_mean: reference mean for the error on "a". The generative model
            itself uses the module-level ``TRUE_MEAN_A``, so this should
            match it for the reported error to be meaningful.
        lr: Adam learning rate.

    Returns:
        Tuple ``(error, tvi)``: the per-element error on the mean of "a" and
        the fitted objective module.
    """
    Q = ChainQ(nvars)

    def P(trace):
        # Bind the chain length so P and Q have matching shapes.
        return chain_dist(trace, n=nvars)

    # Get _c data by sampling the generator.
    x = sample_generator(nProtected, P, dataName="__c")
    tvi = setup_and_run(TVI(P, Q, k, x, nProtected), epochs, eta=lr)

    return get_error_on_a(true_mean, nvars, tvi), tvi

In [15]:
# Run the experiment and report the error on the mean of "a" as a
# percentage of its true value.
ep = 5000
error, tvi = main(
    nvars=3,
    k=2,
    epochs=ep,
    true_mean=TRUE_MEAN_A,
    lr=0.1,
)

# For inspection: tvi.q.mean_a, tvi.q.mean_b, tvi.q.logscale_a, tvi.q.logscale_b
print()
percent_error = error / TRUE_MEAN_A * 100
print(f"{percent_error}%")

Q params init: [Parameter containing:
tensor([1., 1., 1.], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True), Parameter containing:
tensor([1., 1., 1.], requires_grad=True)]

tensor([-1.3009, -4.8312, -1.0335], grad_fn=<MulBackward0>)%


In [None]:
# Step-by-step version of one objective evaluation, using the same calls in
# the same order as TVI.forward, kept at cell level so intermediate traces
# can be inspected.
Q = ChainQ()
P = chain_dist
k = 2
nProtected = 2 

# Get _c data by sampling generator
x = sample_generator(nProtected, P, dataName="__c")

# Sampler trace: filled by Q.
trq = sampler(k, nProtected, data={"__c": x})
#print(trq.trace.in_dicts)
Q.sample(trq)
#print(trq.trace.out_dicts)
# Evaluator trace: built from the sampler trace, then filled by P.
trp = evaluator(trq, nProtected, data={"__c": x})
#P(trp) # evaluate under prior!
#print(trp.trace.in_dicts)
# compute P logprobs
P(trp)
#print(trp.trace.out_dicts)
#print(trq.trace.fn.plate_names)

trp.sum_out_pos()
trq.sum_out_pos()
# aligning dims in Q
trq.trace.out_dicts = rename_placeholders(trp, trq)

# to ratio land: P.log_probs - Q.log_probs (just the latents)
tensors = subtract_latent_log_probs(trp, trq)
#print("\n", tensors)
# reduce gives loss
loss_dict = tpp.combine_tensors(tensors)
# The reduction is expected to leave exactly one entry.
assert(len(loss_dict.keys()) == 1)
key = next(iter(loss_dict))

# Bare expression: displayed as the cell's output in a notebook.
loss_dict[key]


In [None]:
import torch
import numpy as np

# Five observations; columns are temperature, rainfall and humidity.
inputs = torch.from_numpy(
    np.array(
        [
            [73, 67, 43],
            [91, 88, 64],
            [87, 134, 58],
            [102, 43, 37],
            [69, 96, 70],
        ],
        dtype="float32",
    )
)

# Matching targets; columns are apples and oranges.
targets = torch.from_numpy(
    np.array(
        [
            [56, 70],
            [81, 101],
            [119, 133],
            [22, 37],
            [103, 119],
        ],
        dtype="float32",
    )
)
print(inputs)
print(targets)

# Randomly initialised weights (2 outputs x 3 inputs) and biases, both
# tracked by autograd.
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
print(w)
print(b)


#  Define model
def model(x):
    """Apply the affine map defined by the module-level weights ``w`` and biases ``b``."""
    weighted = torch.matmul(x, w.t())
    return weighted + b

# Forward pass over all inputs; "predit" holds the model's predictions.
predit = model(inputs)

def mse(t1, t2):
    """Return the mean squared difference between ``t1`` and ``t2`` over all elements."""
    residual = t1 - t2
    squared_total = torch.mul(residual, residual).sum()
    return squared_total / residual.numel()

# Loss of the untrained model on the full data set.
loss = mse(predit, targets)
# Bare expression; in a notebook only the cell's last expression is displayed.
loss


# Compute gradient
# NOTE(review): w and b are created with requires_grad=True and are leaf
# tensors, for which the PyTorch docs describe retain_grad() as a no-op --
# these two calls look redundant; confirm before removing.
w.retain_grad()
b.retain_grad()
loss.backward()

# Weights and the gradient of the loss with respect to them.
print(w)
print(w.grad)


## VI, No plates but including deletes


In [None]:
def chain_dist_del(trace):
    """Chain a -> b -> c -> d where "a" and "b" are deleted before "d" is drawn.

    Every Normal uses the module-level ``scale``, as in ``chain_dist``.
    ``Norm`` takes two positional arguments, so the earlier one-argument
    calls would have raised a TypeError.

    Args:
        trace: indexed by variable name; must also provide ``delete_names``.

    Returns:
        ``c`` as handed back by ``trace.delete_names``.
        NOTE(review): ``plate_dist`` returns its last variable ``d`` --
        confirm that returning ``c`` here is intended.
    """
    a = trace["a"](Norm(t.ones(n), scale))
    b = trace["b"](Norm(a, scale))
    c = trace["c"](Norm(b, scale))
    # Drop "a" and "b" from the trace; c is passed through and rebound to
    # whatever delete_names returns for it.
    (c,) = trace.delete_names(("a", "b"), (c,))
    # "d" is registered on the trace; its value is not used further here.
    trace["d"](Norm(c, scale))

    return c


In [None]:
# Draft of the sample-then-evaluate workflow.
# NOTE(review): `draws` and `data` are not defined in the visible cells, so
# this cell cannot run as it stands -- confirm where they should come from.

# call sampler on Q. 
# gives you the samples and a log Q tensor `log_prob`
tr1 = sampler(draws, nProtected, data=data)

# NOTE(review): the comment above talks about Q, but the sampler trace is
# filled by calling P here -- confirm which model is intended.
val = P(tr1)
log_q = tr1.trace.out_dicts["log_prob"]

# compute the log_probs

# pass these to evaluator, which does a lookup for all the latents 
# gives you log P for each latent
tr2 = evaluator(tr1, nProtected, data=data)
val = P(tr2)

#tr2.trace.out_dicts["log_prob"]
#log_p = 

#Q = pytorch.module
#    - `q.forward()` will look like chain_dist
    

#- optimise it


## plate VI

- For plates, we just don't filter [@17](https://github.com/LaurenceA/tpp/blob/bd1fe20dcf86a1c02cc0424632571fba998d104f/utils.py#L17)
- Painful stuff: need to keep the generative order (e.g. a, b, c, d)
    - because we start by summing the lowest-level plates
        - solution: enforce that the last variable is a leaf e.g. `d`
- Careful when combining P & Q tensors: maintain the ordering!

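- Plates: the summation should run backwards through the plates (lowest-level plate first).
    - This makes the implementation tricky, because the generative order has to be tracked
    - Python dicts preserve insertion order (an implementation detail in CPython 3.6, guaranteed by the language from 3.7), so rely on that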


In [None]:
# example directed graph with plate repeats
# 3(a) -> 4(b) -> c -> d
def plate_dist(trace):
    """Chain a -> b -> c -> d with plates on "a" (shape 3) and "b" (shape 4).

    "a" and "b" are deleted from the trace before "d" is drawn. Every Normal
    uses the module-level ``scale``, as in ``chain_dist``. ``Norm`` takes two
    positional arguments, so the earlier one-argument calls would have
    raised a TypeError.

    Args:
        trace: indexed by variable name; entries accept ``plate_name`` and
            ``plate_shape`` keywords, and the trace provides
            ``delete_names``.

    Returns:
        Whatever ``trace["d"]`` returns for the last variable.
    """
    Na = Norm(t.ones(n), scale)
    a = trace["a"](Na, plate_name="A", plate_shape=3)
    Nb = Norm(a, scale)
    b = trace["b"](Nb, plate_name="B", plate_shape=4)
    Nc = Norm(b, scale)
    c = trace["c"](Nc)

    # Drop "a" and "b" from the trace; c is passed through and rebound to
    # whatever delete_names returns for it.
    (c,) = trace.delete_names(("a", "b"), (c,))

    Nd = Norm(c, scale)
    d = trace["d"](Nd)

    return d

# NOTE(review): sample_and_eval is not defined in the visible cells --
# presumably it comes from one of the star imports; confirm.
tr = sample_and_eval(plate_dist, draws=1, nProtected=2)
# In a notebook only the last bare expression is displayed, so out_dicts is
# evaluated here but only in_dicts is shown.
tr.trace.out_dicts
tr.trace.in_dicts