# TMC by hand, in raw PyTorch

Example: a conjugate Gaussian chain.

No `trace` or `combine` helpers are used.

It should be possible to express a chain model as "just a list of matrices":
 
 $a \sim \mathcal{N}(0, s)$
 
 $b \sim \mathcal{N}(a, s)$
 
 $c \sim \mathcal{N}(b, s)$
 
 i.e. there is a single observed variable, $c$ (the data, called `x` in the code below), at the end of the chain.


* TMC module 
  * P module
    * ParamNormal modules
  * Q module
    * ParamNormal modules

In [1]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

In [2]:
# A Normal distribution whose mean and log standard deviation are both learnable.
# Source: https://github.com/anonymous-78913/tmc-anon/blob/master/param/lvm.py
class ParamNormal(nn.Module):
    """Normal distribution with trainable ``loc`` and ``log_scale`` parameters.

    Args:
        shape: shape of both parameter tensors.
        mean: initial value for every entry of ``loc``.
        scale: initial standard deviation. It is stored as ``log(scale)`` so
            that the scale recovered by ``exp`` is always positive.
    """

    def __init__(self, shape, mean=0, scale=1.):
        super().__init__()
        initial_loc = t.ones(shape) * mean
        initial_log_scale = t.ones(shape) * math.log(scale)
        # Registration order (loc, then log_scale) is kept for parameters()/state_dict.
        self.loc = nn.Parameter(initial_loc)
        self.log_scale = nn.Parameter(initial_log_scale)

    def forward(self):
        """Return ``Normal(loc, exp(log_scale))`` built from the current parameters."""
        std = self.log_scale.exp()
        return Normal(self.loc, std)


class LinearNormal(nn.Module):
    """Normal distribution centred on a caller-supplied mean, with a learnable scale.

    The mean passed to ``forward`` is used as-is (no linear map is applied);
    only the log standard deviation is a parameter.

    Args:
        shape: shape of the ``log_scale`` parameter (scalar by default).
        scale: initial standard deviation, stored as ``log(scale)``.
    """

    def __init__(self, shape=t.Size([]), scale=1.):
        super().__init__()
        initial_log_scale = t.ones(shape) * math.log(scale)
        self.log_scale = nn.Parameter(initial_log_scale)

    def forward(self, input_):
        """Return ``Normal(input_, exp(log_scale))``."""
        std = self.log_scale.exp()
        return Normal(input_, std)

# Hardcode a basic chain a -> b -> x

In [3]:
# Generative model p(a) p(b | a) p(x | b).
# TODO: let s vary
class ChainP(nn.Module):
    """Generative chain ``a -> b -> x`` sharing one noise scale ``sigma``.

    ``Pa`` is a scalar learnable Normal (mean ``mean``, scale ``sigma``).
    ``Pb`` and ``Px`` are Normals centred on their parent, each with its own
    learnable scale initialised to ``sigma``.
    """

    def __init__(self, sigma, mean=0):
        super().__init__()
        self.Pa = ParamNormal((), mean=mean, scale=sigma)
        self.Pb = LinearNormal((), scale=sigma)
        self.Px = LinearNormal((), scale=sigma)

    def sample(self, N):
        """Ancestrally sample one scalar ``a``, ``N`` values of ``b`` and one ``x`` per ``b``.

        Returns:
            ``(x, a, b)``: ``x`` and ``b`` have shape ``(N,)``; ``a`` is
            returned with a trailing dimension of size 1.
        """
        a = self.Pa().rsample()
        b = self.Pb(a).rsample(sample_shape=t.Size([N]))
        x = self.Px(b).rsample()
        return x, a.unsqueeze(-1), b

    def log_prob(self, samples):
        """Joint log density of ``(x, a, b)``; each factor is summed over its last dim."""
        x, a, b = samples
        factors = (
            self.Pa().log_prob(a),
            self.Pb(a).log_prob(b),
            self.Px(b).log_prob(x),
        )
        total = factors[0].sum(-1)
        for factor in factors[1:]:
            total = total + factor.sum(-1)
        return total


# TODO: isotropic?
class ChainQ(nn.Module):
    """Approximate posterior q(a) q(b) for the chain, with independent factors.

    ``Qa`` starts with scale ``s``. ``Qb`` starts with scale
    ``sqrt(s**2 + s**2)`` (= ``sqrt(2) * s``); both factors start with mean 0.
    """

    def __init__(self, s):
        super().__init__()
        self.Qa = ParamNormal((), scale=s)
        # s**2 + s**2 == 2 * s**2 exactly, so the initial scale is unchanged.
        self.Qb = ParamNormal((), scale=math.sqrt(2 * s ** 2))

    def sample(self, N, shape=t.Size([])):
        """Draw ``a`` with batch ``shape`` and ``b`` with batch ``(*shape, N)``.

        Returns:
            ``(a, b)`` where ``a`` has an extra trailing dimension of size 1.
        """
        a = self.Qa().rsample(sample_shape=shape)
        b = self.Qb().rsample(sample_shape=t.Size([*shape, N]))
        return a.unsqueeze(-1), b

    def log_prob(self, samples):
        """Return ``log q(a) + log q(b)``, each summed over its last dimension."""
        a, b = samples
        return self.Qa().log_prob(a).sum(-1) + self.Qb().log_prob(b).sum(-1)

In [5]:
def logmeanexp(x, dim=0):
    """Compute ``log(mean(exp(x)))`` along ``dim`` in a numerically stable way.

    Args:
        x: input tensor.
        dim: dimension to average over; it is removed from the result.

    Returns:
        Tensor with the shape of ``x`` minus dimension ``dim``.

    ``torch.logsumexp`` already shifts by the maximum internally. Unlike a
    hand-written ``x - max`` shift, it returns ``-inf`` rather than ``nan``
    when every entry of a slice is ``-inf`` (and ``inf`` rather than ``nan``
    when the slice contains ``inf``).
    """
    return t.logsumexp(x, dim=dim) - math.log(x.size(dim))


class SimpleTMC(nn.Module):
    """Tensor Monte Carlo estimate of ``log p(x)`` for the hard-coded chain a -> b -> x.

    Args:
        p: generative model exposing ``Pa``, ``Pb`` and ``Px``.
        q: proposal exposing ``Qa`` and ``Qb``.
        k: number of samples ``K`` drawn for each latent variable.
    """

    def __init__(self, p, q, k):
        super().__init__()
        self.p = p
        self.q = q
        self.K = k  # TODO: allow a different number of samples per variable

    def get_error_on_a(self):
        """Return ``p.Pa.loc - q.Qa.loc`` as a (signed) Python float."""
        prior_loc = self.p.Pa.loc
        proposal_loc = self.q.Qa.loc
        return float(prior_loc - proposal_loc)

    def forward(self, x):
        """Return the TMC log-marginal-likelihood estimate for data ``x`` (a scalar tensor).

        NOTE(review): the shape comments below assume ``x`` is 1-D with one
        entry per data point (``x.size(0)`` of them) — confirm for other inputs.
        """
        p, q = self.p, self.q
        n_data = x.size(0)

        # a: (K, 1, 1) so that it broadcasts against b's (K, N) as a new leading dim.
        a = q.Qa().rsample(sample_shape=t.Size([self.K, 1, 1]))
        b = q.Qb().rsample(sample_shape=t.Size([self.K, n_data]))

        log_w_a = p.Pa().log_prob(a) - q.Qa().log_prob(a)   # (K, 1, 1)
        log_w_b = p.Pb(a).log_prob(b) - q.Qb().log_prob(b)  # (K_a, K_b, N)
        log_lik = p.Px(b).log_prob(x)                       # (K_b, N)

        # For each a-sample and data point, average over the b-samples; the
        # data points form a plate, so their log terms add up.
        per_a_sample = logmeanexp(log_w_b + log_lik, dim=-2).sum(dim=-1)
        per_a_sample = per_a_sample + log_w_a.view(-1)
        return logmeanexp(per_a_sample)


# optimise

In [6]:
# Use >10,000 epochs for a serious evaluation.
def setup_and_run(tmc, x, ep=2000, eta=0.2):
    """Fit the proposal ``tmc.q`` to data ``x`` with Adam and return ``tmc``.

    Only ``tmc.q``'s parameters are given to the optimiser, so ``tmc.p`` is
    never updated.

    Args:
        tmc: module exposing ``q``; calling it on the data gives the objective.
        x: observed data, converted with ``t.Tensor`` and wrapped in a
            non-trainable ``nn.Parameter``.
        ep: number of optimisation steps.
        eta: Adam learning rate.
    """
    q_optimiser = t.optim.Adam(tmc.q.parameters(), lr=eta)
    frozen_data = nn.Parameter(t.Tensor(x), requires_grad=False)
    optimise(tmc, frozen_data, q_optimiser, ep)
    return tmc


def optimise(tmc, x, optimiser, eps):
    """Take ``eps`` optimiser steps that maximise ``tmc(x)``.

    Args:
        tmc: callable returning a scalar objective to maximise.
        x: data passed unchanged to ``tmc`` on every step.
        optimiser: torch optimiser over the parameters to update.
        eps: number of steps.
    """
    for _ in range(eps):
        objective = tmc(x)
        optimiser.zero_grad()
        # Maximise the objective by minimising its negation.
        (-objective).backward(retain_graph=True)
        optimiser.step()


def main(a_mu, N=100, k=5, sigma=1.):
    """Simulate data from the chain model, fit ``q`` with TMC and report the error on ``a``.

    Args:
        a_mu: mean of ``a`` in the generative model ``p``.
        N: number of data points to simulate.
        k: number of TMC samples per latent variable.
        sigma: noise scale used for every link of ``p`` and to initialise ``q``.

    Returns:
        ``tmc.get_error_on_a()`` after optimisation, i.e. ``p.Pa.loc - q.Qa.loc``.
    """
    p = ChainP(sigma, mean=a_mu)
    q = ChainQ(sigma)
    tmc = SimpleTMC(p, q, k)

    # The simulated data are treated as fixed observations, so there is no
    # reason to record an autograd graph back to p's parameters here.
    with t.no_grad():
        x, _, _ = tmc.p.sample(N)

    tmc = setup_and_run(tmc, x)
    return tmc.get_error_on_a()

a_mu = 100  # mean of a in the generative model used by the experiments below
main(a_mu=a_mu)

3.501129150390625

In [None]:
N = 2  # number of repeated runs; raise (e.g. to 30) for a more reliable average
errors = np.array([main(a_mu) for _ in range(N)])
avg_error = errors.mean()

In [None]:
# `errors` is already a NumPy array. The spread is reported as a standard
# deviation: unlike the variance it has the same units as a_mu, so expressing
# it as a percentage of a_mu is meaningful.
print("Average error:", round(avg_error / a_mu * 100, 2), "%")
print("Error std:", round(errors.std() / a_mu * 100, 2), "%")


# Other attempts (more manual)


In [None]:
# NOTE(review): an earlier comment said this should sample from Q, an
# approximate posterior that has to be an isotropic Gaussian and an
# nn.Module. As written, it ancestrally samples the chain a -> b directly.
def sample_model(k, c, prior_mean=0, var=1.0):
    """Draw ``k`` joint samples ``(z_a, z_b)`` from the chain a -> b.

    Args:
        k: number of samples.
        c: unused here (presumably the observed data; kept for the signature).
        prior_mean: mean of ``a``.
        var: passed to ``Normal`` as its *scale*, so it acts as the standard
            deviation (not the variance) of both links.

    Returns:
        ``(z_a, z_b)``, each of shape ``(k,)`` for scalar ``prior_mean``/``var``.
    """
    z_a = Normal(prior_mean, var).rsample(t.Size([k]))
    z_b = Normal(z_a, var).rsample()
    return z_a, z_b


# TODO
def get_factors(P, x):
    """Return the log-density factor ``P.log_prob(x)`` of ``x`` under distribution ``P``."""
    log_factor = P.log_prob(x)
    return log_factor


BATCH_SIZE = 2
# N_VARS was used below without being defined, which raised NameError.
# NOTE(review): assumed to be the number of latent variables in the chain
# (a and b), giving one sample count per variable in `ks` — confirm.
N_VARS = 2
ks = [BATCH_SIZE] * N_VARS


data = t.randn(1)
samples = sample_model(BATCH_SIZE, data)

In [None]:
"""
    estimators
"""
def vae_marginal_likelihood(x,z) :
    return P(x,z) / Q(z)