# TMC by hand, in raw Pytorch 

e.g. Conjugate Gaussian chain

no `trace` or combine

Should be possible to do a chain model ("just a list of matrices")
 
 `a = N(0, s)`
 
 `b = N(a, s)`
 
 `c = N(b, s)`
 
 i.e. one known variable $c$: the data, at the end of the chain



<img src="tmc.png" style="width: 60%;"/>

In [3]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

In [20]:
DATA_DIM = t.Size([1])
N_VARS = 3

# TODO: allow this to vary across variables
BATCH_SIZE = 2 # samples per latent
N_REC_LAYERS = 4

In [54]:
# Represent a normal as two adaptive params and an eval method
# from https://github.com/anonymous-78913/tmc-anon/blob/master/param/lvm.py
class ParamNormal(nn.Module):
    def __init__(self, sample_shape, scale=1.):
        super().__init__()
        self.loc = nn.Parameter(t.zeros(sample_shape))
        self.log_scale = nn.Parameter(math.log(scale)*t.ones(sample_shape))

    def forward(self):
        return Normal(self.loc, self.log_scale.exp())


class LinearNormal(nn.Module):
    def __init__(self, sample_shape=t.Size([]), scale=1.):
        super().__init__()
        self.log_scale = nn.Parameter(math.log(scale)*t.ones(sample_shape))

    def forward(self, input_):
        return Normal(input_, self.log_scale.exp())

# Hardcode a basic chain a -> b -> x

In [59]:
# generative_model
# TODO: let s vary
class ChainP(nn.Module):
    def __init__(self, s):
        super().__init__()
        self.Pa = ParamNormal((), scale=s)
        self.Pb = LinearNormal((), scale=s)
        self.Px = LinearNormal((), scale=s)

    def sample(self, N):
        a = self.Pa().sample()
        b = self.Pb(a).sample(sample_shape=t.Size([N]))
        x = self.Px(b).sample()
        
        return (x, (a.unsqueeze(-1), b))

    def log_prob(self, samples):
        x, (a, b) = samples
        logPa   = self.Pa().log_prob(a)
        logPb = self.Pb(a).log_prob(b)
        logPx = self.Px(b).log_prob(x)
        
        return logPa.sum(-1) \
                + logPb.sum(-1) \
                + logPx.sum(-1)


# TODO: isotropic?
class ChainQ(nn.Module):
    def __init__(self, s):
        super().__init__()
        self.Qa = ParamNormal((), scale=s)
        sigma_plus = math.sqrt(s**2 + s**2)
        self.Qb = ParamNormal((), scale=sigma_plus)

    def sample(self, N, sample_shape=t.Size([])):
        a = self.Qa().sample(sample_shape=sample_shape)
        b = self.Qb().sample(sample_shape=t.Size([*sample_shape, N]))
        
        return (a.unsqueeze(-1), b)

    def log_prob(self, samples):
        a, b = samples
        logQa = self.Qa().log_prob(a)
        logQb = self.Qb().log_prob(b)
        
        return logQa.sum(-1) + logQb.sum(-1)

In [84]:
def logmeanexp(x, dim=0):
    max_ = x.max(dim=dim, keepdim=True)[0]
    normed = x - max_
    lme = normed.exp() \
            .mean(dim, keepdim=True) \
            .log() 

    return (lme + max_).squeeze(dim)


class SimpleTMC(nn.Module) :
    def __init__(self, p, q, k):
        super().__init__()
        self.p = p
        self.q = q
        # TODO: allow variable k 
        self.K = k
    
    """
    def reduce_factors(self, log_ps) :
        return
    
    
    def tmc_marginal_likelihood(self, log_probs, ks) :
        norm = 1 / np.prod(ks)

        return norm * reduce_factors(log_probs)
    """

    def forward(self, x):
        a  = self.q.Qa().sample(sample_shape=t.Size([self.K, 1, 1]))
        b  = self.q.Qb().sample(sample_shape=t.Size([self.K, x.size(0)]))
        
        fa = self.p.Pa().log_prob(a) - self.q.Qa().log_prob(a)
        fb = self.p.Pb(a).log_prob(b) - self.q.Qb().log_prob(b)
        fx = self.p.Px(b).log_prob(x)
        
        f_int_b = logmeanexp(fb + fx, -2)
        f_int_b = f_int_b.sum(-1) + fa.view(-1)
        f_int_a = logmeanexp(f_int_b)

        return f_int_a
    
    
    def loss(self, **kwargs) :
        # log_p_x_z = 
        # kld_loss = 
        # importance weights
        # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
        # scaled = torch.sum(weight * log_weight, dim=-1)
        #loss_ = torch.mean()

        return #loss_

In [79]:
N = 100
k = 2
n_vars = 3
iters = 100

# TODO: separate
sa = 1.
sb = sa
sx = sa


p = ChainP(sa)
q = ChainQ(sb)
tmc = SimpleTMC(p, q, k)
x, _ = tmc.p.sample(N)

tmcs = []
for i in range(iters):
    t.manual_seed(i)
    res = tmc(x)
    tmcs.append(res.detach().numpy())
    

np.array(tmcs).mean()

-219.23984

# optimise

In [82]:
EPOCHS = 3000

def run(x, ep=5000, eta=0.2) :
    tmc = SimpleTMC(p, q, k)
    optimiser = t.optim.Adam(tmc.parameters(), lr=eta)
    X = Variable(t.Tensor(x), requires_grad=False) 

    optimise(tmc, x, optimiser, ep)
    
    return tmc


def optimise(tmc, x, optimiser, eps, verbose=False) :
    for i in range(eps):
        loss = - tmc.loss(x)
        optimiser.zero_grad()
        loss.backward(retain_graph=True)
        optimiser.step()

        if verbose :
            if i % 500 == 0:
                print(q.get_mean(), q.get_var())


tmc = run(x, ep=EPOCHS)
tmc

NameError: name 'Variable' is not defined

# Other attempts (more manual)


In [None]:
# We sample from Q, an approx posterior
# Has to be an isotropic Gaussian, 
# also a nn.module
def sample_model(k, c, prior_mean=0, var=1.0):
    s = t.Size([k])
    
    a = Normal(prior_mean, var)
    z_a = a.rsample(s)

    b = Normal(z_a, var)
    z_b = b.rsample()
    
    # TODO: ?
    c 
    
    return z_a, z_b


# TODO
def get_factors(P, x) :
    return P.log_prob(x)


ks = [BATCH_SIZE] * N_VARS


data = t.randn(1)
samples = sample_model(BATCH_SIZE, data)

In [28]:
class BaseClassAE(nn.Module):
    
    def __init__(self) -> None:
        super(BaseClassAE, self).__init__()

    def encode(self, input: t.Tensor) -> list:
        raise NotImplementedError

    def decode(self, input: t.Tensor) :
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> t.Tensor:
        raise RuntimeWarning()

    def generate(self, x: t.Tensor, **kwargs) -> t.Tensor:
        raise NotImplementedError

    def forward(self, *inputs: t.Tensor) -> t.Tensor:
        pass

    def loss_function(self, *inputs, **kwargs) -> t.Tensor:
        pass


tensor(-146.6355)


In [None]:
"""
    estimators
"""
def vae_marginal_likelihood(x,z) :
    return P(x,z) / Q(z)