# TMC by hand, in raw PyTorch

e.g. Conjugate Gaussian chain

no `trace` or combine

Should be possible to do a chain model ("just a list of matrices")
 
 `a = N(0, s)`
 
 `b = N(a, s)`
 
 `c = N(b, s)`
 
 i.e. one known variable $c$ (the data), at the end of the chain. In the code below this observed variable is called `x`.


* TMC module 
  * P module
    * ParamNormal modules
  * Q module
    * ParamNormal modules

In [2]:
import math
import numpy as np
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN

In [3]:
# Represent a normal as two adaptive params and an eval method
# from https://github.com/anonymous-78913/tmc-anon/blob/master/param/lvm.py
class ParamNormal(nn.Module):
    """Normal distribution whose location and log-scale are both learnable.

    Args:
        shape: shape of the two parameter tensors (``()`` gives scalars).
        mean: initial value written into every entry of ``loc``.
        scale: initial standard deviation. It is stored as ``log(scale)``
            and exponentiated again in ``forward``.
    """

    def __init__(self, shape, mean=0, scale=1.):
        super().__init__()
        initial_loc = t.ones(size=shape) * mean
        initial_log_scale = t.ones(shape) * math.log(scale)
        # Registration order (loc first, then log_scale) is kept on purpose:
        # it fixes the order in which ``parameters()`` yields them.
        self.loc = nn.Parameter(initial_loc)
        self.log_scale = nn.Parameter(initial_log_scale)

    def forward(self):
        """Return ``Normal(loc, exp(log_scale))`` built from the current parameters."""
        std = self.log_scale.exp()
        return Normal(self.loc, std)


class LinearNormal(nn.Module):
    """Normal whose mean is supplied at call time; only the log-scale is learnable.

    Args:
        shape: shape of the ``log_scale`` parameter (scalar by default).
        scale: initial standard deviation, stored as ``log(scale)``.
    """

    def __init__(self, shape=t.Size([]), scale=1.):
        super().__init__()
        initial_log_scale = t.ones(shape) * math.log(scale)
        self.log_scale = nn.Parameter(initial_log_scale)

    def forward(self, input_):
        """Return ``Normal(input_, exp(log_scale))``, i.e. a Normal centred on ``input_``."""
        std = self.log_scale.exp()
        return Normal(input_, std)

# Hardcode a basic chain a -> b -> x

In [4]:
# generative_model
# TODO: let s vary
class ChainP(nn.Module):
    """Generative model for the chain ``a -> b -> x``.

    ``a`` has a ``ParamNormal`` prior; ``b`` is Normal around ``a`` and ``x``
    is Normal around ``b`` (both ``LinearNormal``). All three factors are
    created with scalar parameters and the same initial scale ``sigma``.

    Args:
        sigma: initial standard deviation shared by the three factors.
        mean: initial location of the prior on ``a``.
    """

    def __init__(self, sigma, mean=0):
        super().__init__()
        self.Pa = ParamNormal((), mean=mean, scale=sigma)
        self.Pb = LinearNormal((), scale=sigma)
        self.Px = LinearNormal((), scale=sigma)

    def sample(self, N):
        """Draw one ``a``, then ``N`` values of ``b`` given ``a``, then one ``x`` per ``b``.

        Returns:
            Tuple ``(x, a, b)``; ``a`` is returned with an extra trailing
            dimension (``a.unsqueeze(-1)``).
        """
        # Draws happen in the order a, b, x.
        latent_a = self.Pa().rsample()
        n_draws = t.Size([N])
        latent_b = self.Pb(latent_a).rsample(sample_shape=n_draws)
        observed = self.Px(latent_b).rsample()

        return observed, latent_a.unsqueeze(-1), latent_b

    def log_prob(self, samples):
        """Joint log-density of ``samples = (x, a, b)``.

        Each factor's log-density is summed over its last dimension before
        the three terms are added.
        """
        x, a, b = samples
        total_a = self.Pa().log_prob(a).sum(-1)
        total_b = self.Pb(a).log_prob(b).sum(-1)
        total_x = self.Px(b).log_prob(x).sum(-1)

        return total_a + total_b + total_x


# TODO: isotropic?
class ChainQ(nn.Module):
    """Factorised approximate posterior over the latents ``a`` and ``b``.

    Both factors are independent ``ParamNormal`` modules with scalar
    parameters and initial location 0.

    Args:
        s: initial standard deviation of ``Qa``. ``Qb`` starts at
            ``sqrt(s**2 + s**2)``.
    """

    def __init__(self, s):
        super().__init__()
        self.Qa = ParamNormal((), scale=s)
        # NOTE(review): presumably the prior marginal std of b (two stacked
        # steps of std s) -- confirm.
        two_step_scale = math.sqrt(s**2 + s**2)
        self.Qb = ParamNormal((), scale=two_step_scale)

    def sample(self, N, shape=t.Size([])):
        """Draw ``a`` with sample shape ``shape`` and ``b`` with sample shape ``(*shape, N)``.

        Returns:
            Tuple ``(a, b)``; ``a`` gets an extra trailing dimension.
        """
        # a is drawn before b.
        latent_a = self.Qa().rsample(sample_shape=shape)
        b_shape = t.Size([*shape, N])
        latent_b = self.Qb().rsample(sample_shape=b_shape)

        return (latent_a.unsqueeze(-1), latent_b)

    def log_prob(self, samples):
        """Log-density of ``samples = (a, b)``, each term summed over its last dimension."""
        a, b = samples
        total_a = self.Qa().log_prob(a).sum(-1)
        total_b = self.Qb().log_prob(b).sum(-1)

        return total_a + total_b

In [5]:
def logmeanexp(x, dim=0):
    """Return ``log(mean(exp(x)))`` along ``dim``, computed stably.

    Uses ``torch.logsumexp`` and subtracts ``log(n)``, where ``n`` is the
    size of the reduced dimension. Unlike a hand-rolled max-shift, this
    returns ``-inf`` (not NaN) when every entry along ``dim`` is ``-inf``,
    i.e. when all the averaged weights are exactly zero.

    Args:
        x: tensor of log-values.
        dim: dimension to average over; it is removed from the result.

    Returns:
        Tensor with ``dim`` reduced away.
    """
    # A 0-dim tensor has no sizes to query; treat it as a single element.
    n = x.size(dim) if x.dim() > 0 else 1

    return t.logsumexp(x, dim=dim) - math.log(n)


class SimpleTMC(nn.Module) :
    """Tensor Monte Carlo estimator for the hard-coded chain ``a -> b -> x``.

    Args:
        p: generative model exposing the factors ``Pa``, ``Pb`` and ``Px``.
        q: approximate posterior exposing the factors ``Qa`` and ``Qb``.
        k: number of samples drawn for each latent variable.
    """

    def __init__(self, p, q, k):
        super().__init__()
        self.p = p
        self.q = q
        # TODO: allow variable k
        self.K = k

    def get_error_on_a(self) :
        """Return ``p.Pa.loc - q.Qa.loc`` as a Python float."""
        gap = self.p.Pa.loc - self.q.Qa.loc
        return float(gap)

    def forward(self, x):
        """Estimate the log marginal likelihood of ``x``.

        ``K`` samples of ``a`` and ``K`` samples of every ``b`` are drawn
        from ``q``. ``b`` is averaged out over all ``K`` combinations with
        each ``a`` sample, and then ``a`` is averaged out.

        Shape notes below assume scalar factor parameters (as ``ChainP`` /
        ``ChainQ`` create them) and ``x`` of shape ``(N,)`` -- confirm for
        other models.
        """
        n_obs = x.size(0)
        proposal_a = self.q.Qa()
        proposal_b = self.q.Qb()

        # a is drawn before b.
        a = proposal_a.rsample(sample_shape=t.Size([self.K, 1, 1]))  # (K, 1, 1)
        b = proposal_b.rsample(sample_shape=t.Size([self.K, n_obs]))  # (K, N)

        # Log importance weights for the latents, and the log-likelihood.
        log_w_a = self.p.Pa().log_prob(a) - proposal_a.log_prob(a)
        # Broadcasting a against b pairs every a sample with every b sample.
        log_w_b = self.p.Pb(a).log_prob(b) - proposal_b.log_prob(b)
        log_lik = self.p.Px(b).log_prob(x)

        # Average over the b-sample axis, sum over the plate (last axis),
        # then attach the weight of the matching a sample.
        b_marginalised = logmeanexp(log_w_b + log_lik, dim=-2)
        per_a_sample = b_marginalised.sum(dim=-1) + log_w_a.view(-1)

        return logmeanexp(per_a_sample)


# optimise

In [6]:
# >10,000 epochs for serious eval
def setup_and_run(tmc, x, ep=2000, eta=0.2) :
    """Fit the approximate posterior ``tmc.q`` to the data ``x`` with Adam.

    Args:
        tmc: model whose ``q`` sub-module is optimised.
        x: observations, passed through ``t.Tensor``.
        ep: number of optimisation steps.
        eta: Adam learning rate.

    Returns:
        The same ``tmc`` object, after optimisation.
    """
    # Only q's parameters are handed to the optimiser, so p stays fixed.
    adam = t.optim.Adam(tmc.q.parameters(), lr=eta)
    # The data is wrapped as a Parameter that does not require gradients.
    observations = nn.Parameter(t.Tensor(x), requires_grad=False)

    optimise(tmc, observations, adam, ep)

    return tmc


def optimise(tmc, x, optimiser, eps) :
    """Run ``eps`` gradient steps that maximise ``tmc(x)``.

    The loss is ``-tmc(x)``, so minimising it maximises the model's output.

    Args:
        tmc: callable returning a scalar tensor for the input ``x``.
        x: input passed to ``tmc`` on every step.
        optimiser: optimiser holding the parameters to update.
        eps: number of steps to run.

    Returns:
        None. The parameters held by ``optimiser`` are updated in place.
    """
    for _ in range(eps):
        optimiser.zero_grad()
        loss = -tmc(x)
        # tmc(x) builds a fresh graph on every call, so there is nothing to
        # retain after the backward pass.
        loss.backward()
        optimiser.step()


def main(a_mu, sigma=1., N=100, k=5) :
    """Run one experiment: sample data from the chain, fit ``q``, report the error on ``a``.

    Args:
        a_mu: location of the prior on ``a`` in the generative model.
        sigma: initial standard deviation used for every factor of both
            ``ChainP`` and ``ChainQ``.
        N: number of observations sampled from the generative model.
        k: number of TMC samples per latent variable.

    Returns:
        The value of ``tmc.get_error_on_a()`` after optimisation.
    """
    p = ChainP(sigma, mean=a_mu)
    q = ChainQ(sigma)
    tmc = SimpleTMC(p, q, k)

    # Only the observations are kept; the sampled latents are discarded.
    x, _, _ = tmc.p.sample(N)

    tmc = setup_and_run(tmc, x)

    return tmc.get_error_on_a()

# Location of the prior on `a` used for the experiments in this notebook.
a_mu = 100
# Single run of the experiment; main returns the error on a's location.
main(a_mu)

tensor(493785.9062, grad_fn=<NegBackward>)
tensor(488125.5312, grad_fn=<NegBackward>)
tensor(482489., grad_fn=<NegBackward>)
tensor(475840.0312, grad_fn=<NegBackward>)
tensor(468980.1875, grad_fn=<NegBackward>)
tensor(454950.5625, grad_fn=<NegBackward>)
tensor(444112.0938, grad_fn=<NegBackward>)
tensor(431173.5000, grad_fn=<NegBackward>)
tensor(419242.0938, grad_fn=<NegBackward>)
tensor(405388.0625, grad_fn=<NegBackward>)
tensor(393437.3125, grad_fn=<NegBackward>)
tensor(370368.1250, grad_fn=<NegBackward>)
tensor(339497.3750, grad_fn=<NegBackward>)
tensor(307782.4375, grad_fn=<NegBackward>)
tensor(304594.3750, grad_fn=<NegBackward>)
tensor(268993.6875, grad_fn=<NegBackward>)
tensor(248115.6094, grad_fn=<NegBackward>)
tensor(230215.4062, grad_fn=<NegBackward>)
tensor(239558.8594, grad_fn=<NegBackward>)
tensor(191103.8750, grad_fn=<NegBackward>)
tensor(241810.3906, grad_fn=<NegBackward>)
tensor(173769.9219, grad_fn=<NegBackward>)
tensor(179350.7656, grad_fn=<NegBackward>)
tensor(289920.9

tensor(74220.0547, grad_fn=<NegBackward>)
tensor(78498.7422, grad_fn=<NegBackward>)
tensor(104399.4609, grad_fn=<NegBackward>)
tensor(91234.2031, grad_fn=<NegBackward>)
tensor(129641.3359, grad_fn=<NegBackward>)
tensor(82608.2734, grad_fn=<NegBackward>)
tensor(146638.6719, grad_fn=<NegBackward>)
tensor(68222.3125, grad_fn=<NegBackward>)
tensor(99912.1484, grad_fn=<NegBackward>)
tensor(103759.6797, grad_fn=<NegBackward>)
tensor(96066.6016, grad_fn=<NegBackward>)
tensor(75357.0312, grad_fn=<NegBackward>)
tensor(93240.6562, grad_fn=<NegBackward>)
tensor(104392.2578, grad_fn=<NegBackward>)
tensor(141412.5781, grad_fn=<NegBackward>)
tensor(95338.8203, grad_fn=<NegBackward>)
tensor(81558.3594, grad_fn=<NegBackward>)
tensor(61239.4492, grad_fn=<NegBackward>)
tensor(88713.1016, grad_fn=<NegBackward>)
tensor(75735.2812, grad_fn=<NegBackward>)
tensor(74221.4219, grad_fn=<NegBackward>)
tensor(81196.0703, grad_fn=<NegBackward>)
tensor(102484.6094, grad_fn=<NegBackward>)
tensor(89562.3750, grad_fn=

tensor(38798.7500, grad_fn=<NegBackward>)
tensor(94739.7578, grad_fn=<NegBackward>)
tensor(53793.2109, grad_fn=<NegBackward>)
tensor(46700.8711, grad_fn=<NegBackward>)
tensor(56929.1445, grad_fn=<NegBackward>)
tensor(37290.3984, grad_fn=<NegBackward>)
tensor(36375.0234, grad_fn=<NegBackward>)
tensor(39304.4453, grad_fn=<NegBackward>)
tensor(50278.4141, grad_fn=<NegBackward>)
tensor(33887.9062, grad_fn=<NegBackward>)
tensor(39755.3359, grad_fn=<NegBackward>)
tensor(40140.3516, grad_fn=<NegBackward>)
tensor(27605.5762, grad_fn=<NegBackward>)
tensor(55177.6172, grad_fn=<NegBackward>)
tensor(33122.5039, grad_fn=<NegBackward>)
tensor(91045.3984, grad_fn=<NegBackward>)
tensor(54909.6602, grad_fn=<NegBackward>)
tensor(34712.0234, grad_fn=<NegBackward>)
tensor(32823.9180, grad_fn=<NegBackward>)
tensor(106452.5000, grad_fn=<NegBackward>)
tensor(82607.3516, grad_fn=<NegBackward>)
tensor(35499.6211, grad_fn=<NegBackward>)
tensor(39921.5664, grad_fn=<NegBackward>)
tensor(39546.4609, grad_fn=<NegBa

tensor(19692.5391, grad_fn=<NegBackward>)
tensor(11599.7783, grad_fn=<NegBackward>)
tensor(32729.9980, grad_fn=<NegBackward>)
tensor(21424.9258, grad_fn=<NegBackward>)
tensor(15575.6006, grad_fn=<NegBackward>)
tensor(35071.0742, grad_fn=<NegBackward>)
tensor(64325.2148, grad_fn=<NegBackward>)
tensor(12023.1357, grad_fn=<NegBackward>)
tensor(26984.9277, grad_fn=<NegBackward>)
tensor(31596.0488, grad_fn=<NegBackward>)
tensor(27523.5332, grad_fn=<NegBackward>)
tensor(24160.7480, grad_fn=<NegBackward>)
tensor(25917.1133, grad_fn=<NegBackward>)
tensor(11438.3018, grad_fn=<NegBackward>)
tensor(58100.0469, grad_fn=<NegBackward>)
tensor(9305.0850, grad_fn=<NegBackward>)
tensor(10762.8008, grad_fn=<NegBackward>)
tensor(9304.8486, grad_fn=<NegBackward>)
tensor(43427.4141, grad_fn=<NegBackward>)
tensor(10562.3926, grad_fn=<NegBackward>)
tensor(9616.4062, grad_fn=<NegBackward>)
tensor(15979.9482, grad_fn=<NegBackward>)
tensor(17779.6074, grad_fn=<NegBackward>)
tensor(9951.9619, grad_fn=<NegBackwar

tensor(3514.2559, grad_fn=<NegBackward>)
tensor(3494.5139, grad_fn=<NegBackward>)
tensor(7465.7983, grad_fn=<NegBackward>)
tensor(4104.2886, grad_fn=<NegBackward>)
tensor(4521.6328, grad_fn=<NegBackward>)
tensor(3751.6772, grad_fn=<NegBackward>)
tensor(22906.3125, grad_fn=<NegBackward>)
tensor(23355.5352, grad_fn=<NegBackward>)
tensor(17931.3027, grad_fn=<NegBackward>)
tensor(15974.7305, grad_fn=<NegBackward>)
tensor(4045.5598, grad_fn=<NegBackward>)
tensor(4578.5449, grad_fn=<NegBackward>)
tensor(4637.7466, grad_fn=<NegBackward>)
tensor(2873.5156, grad_fn=<NegBackward>)
tensor(10218.0264, grad_fn=<NegBackward>)
tensor(9411.4834, grad_fn=<NegBackward>)
tensor(10946.9717, grad_fn=<NegBackward>)
tensor(10455.7041, grad_fn=<NegBackward>)
tensor(28323.3555, grad_fn=<NegBackward>)
tensor(20881.8867, grad_fn=<NegBackward>)
tensor(4118.3887, grad_fn=<NegBackward>)
tensor(22193.8320, grad_fn=<NegBackward>)
tensor(3920.7134, grad_fn=<NegBackward>)
tensor(3496.3765, grad_fn=<NegBackward>)
tensor

tensor(7843.8340, grad_fn=<NegBackward>)
tensor(3343.6960, grad_fn=<NegBackward>)
tensor(1241.2631, grad_fn=<NegBackward>)
tensor(1348.0588, grad_fn=<NegBackward>)
tensor(1018.0379, grad_fn=<NegBackward>)
tensor(3086.3586, grad_fn=<NegBackward>)
tensor(1033.7583, grad_fn=<NegBackward>)
tensor(2346.0298, grad_fn=<NegBackward>)
tensor(1403.2225, grad_fn=<NegBackward>)
tensor(5292.4907, grad_fn=<NegBackward>)
tensor(1921.9236, grad_fn=<NegBackward>)
tensor(928.6342, grad_fn=<NegBackward>)
tensor(25024.7559, grad_fn=<NegBackward>)
tensor(867.6439, grad_fn=<NegBackward>)
tensor(4586.4629, grad_fn=<NegBackward>)
tensor(1420.6211, grad_fn=<NegBackward>)
tensor(1051.3213, grad_fn=<NegBackward>)
tensor(4342.4785, grad_fn=<NegBackward>)
tensor(1181.0032, grad_fn=<NegBackward>)
tensor(1395.5952, grad_fn=<NegBackward>)
tensor(3095.9033, grad_fn=<NegBackward>)
tensor(3345.6858, grad_fn=<NegBackward>)
tensor(1196.7172, grad_fn=<NegBackward>)
tensor(1997.7634, grad_fn=<NegBackward>)
tensor(1732.4960,

tensor(727.1432, grad_fn=<NegBackward>)
tensor(3659.4631, grad_fn=<NegBackward>)
tensor(541.2255, grad_fn=<NegBackward>)
tensor(828.0511, grad_fn=<NegBackward>)
tensor(478.8044, grad_fn=<NegBackward>)
tensor(693.7191, grad_fn=<NegBackward>)
tensor(454.0412, grad_fn=<NegBackward>)
tensor(669.4799, grad_fn=<NegBackward>)
tensor(728.8568, grad_fn=<NegBackward>)
tensor(650.4921, grad_fn=<NegBackward>)
tensor(3373.3113, grad_fn=<NegBackward>)
tensor(870.1290, grad_fn=<NegBackward>)
tensor(529.1665, grad_fn=<NegBackward>)
tensor(906.8413, grad_fn=<NegBackward>)
tensor(1373.9738, grad_fn=<NegBackward>)
tensor(743.4361, grad_fn=<NegBackward>)
tensor(337.3846, grad_fn=<NegBackward>)
tensor(424.1455, grad_fn=<NegBackward>)
tensor(393.5352, grad_fn=<NegBackward>)
tensor(484.4401, grad_fn=<NegBackward>)
tensor(4132.4893, grad_fn=<NegBackward>)
tensor(371.6100, grad_fn=<NegBackward>)
tensor(773.1045, grad_fn=<NegBackward>)
tensor(719.4949, grad_fn=<NegBackward>)
tensor(347.8835, grad_fn=<NegBackwar

tensor(297.4888, grad_fn=<NegBackward>)
tensor(302.5374, grad_fn=<NegBackward>)
tensor(292.0438, grad_fn=<NegBackward>)
tensor(284.0302, grad_fn=<NegBackward>)
tensor(362.4503, grad_fn=<NegBackward>)
tensor(289.4213, grad_fn=<NegBackward>)
tensor(465.1627, grad_fn=<NegBackward>)
tensor(276.9112, grad_fn=<NegBackward>)
tensor(254.7585, grad_fn=<NegBackward>)
tensor(400.9377, grad_fn=<NegBackward>)
tensor(3203.1575, grad_fn=<NegBackward>)
tensor(2577.9309, grad_fn=<NegBackward>)
tensor(458.8666, grad_fn=<NegBackward>)
tensor(255.8298, grad_fn=<NegBackward>)
tensor(262.4509, grad_fn=<NegBackward>)
tensor(368.2512, grad_fn=<NegBackward>)
tensor(395.1582, grad_fn=<NegBackward>)
tensor(263.7575, grad_fn=<NegBackward>)
tensor(2034.4324, grad_fn=<NegBackward>)
tensor(277.4263, grad_fn=<NegBackward>)
tensor(266.8943, grad_fn=<NegBackward>)
tensor(344.9327, grad_fn=<NegBackward>)
tensor(587.2587, grad_fn=<NegBackward>)
tensor(458.4171, grad_fn=<NegBackward>)
tensor(362.7578, grad_fn=<NegBackward

tensor(247.6588, grad_fn=<NegBackward>)
tensor(226.5456, grad_fn=<NegBackward>)
tensor(238.2239, grad_fn=<NegBackward>)
tensor(223.4195, grad_fn=<NegBackward>)
tensor(708.5432, grad_fn=<NegBackward>)
tensor(310.0786, grad_fn=<NegBackward>)
tensor(381.7982, grad_fn=<NegBackward>)
tensor(236.3225, grad_fn=<NegBackward>)
tensor(701.1791, grad_fn=<NegBackward>)
tensor(212.0039, grad_fn=<NegBackward>)
tensor(841.4307, grad_fn=<NegBackward>)
tensor(346.3544, grad_fn=<NegBackward>)
tensor(519.9086, grad_fn=<NegBackward>)
tensor(257.1023, grad_fn=<NegBackward>)
tensor(220.9452, grad_fn=<NegBackward>)
tensor(228.5663, grad_fn=<NegBackward>)
tensor(243.6449, grad_fn=<NegBackward>)
tensor(592.5652, grad_fn=<NegBackward>)
tensor(238.1723, grad_fn=<NegBackward>)
tensor(228.6878, grad_fn=<NegBackward>)
tensor(520.5508, grad_fn=<NegBackward>)
tensor(405.9785, grad_fn=<NegBackward>)
tensor(434.4062, grad_fn=<NegBackward>)
tensor(282.5668, grad_fn=<NegBackward>)
tensor(538.3871, grad_fn=<NegBackward>)


4.8034820556640625

In [None]:
# Repeat the whole experiment 30 times and collect the error returned by each run.
errors = np.array([main(a_mu) for _ in range(30)])
avg_error = np.mean(errors)

In [None]:
# Mean error across runs, expressed as a percentage of a_mu.
print("Average error:", round(avg_error / a_mu * 100, 2), "%")
# NOTE(review): a variance divided by a_mu is not a dimensionless percentage;
# the standard deviation may have been intended -- confirm.
print("Error variance:", round(np.var(errors) / a_mu * 100, 2), "%")


# Other attempts (more manual)


In [None]:
# We sample from Q, an approximate posterior.
# It has to be an isotropic Gaussian,
# and also an nn.Module.
def sample_model(k, c, prior_mean=0, var=1.0):
    """Draw ``k`` samples of the latent pair ``(a, b)`` from a two-step Gaussian chain.

    ``a ~ Normal(prior_mean, sqrt(var))`` and ``b ~ Normal(a, sqrt(var))``.
    ``Normal`` takes a standard deviation as its second argument, so the
    variance is converted before use.

    Args:
        k: number of samples to draw.
        c: currently unused; kept so existing calls keep working.
        prior_mean: mean of the distribution of ``a``.
        var: variance shared by both steps of the chain.

    Returns:
        Tuple ``(z_a, z_b)``, each with ``k`` samples.
    """
    std = var ** 0.5

    z_a = Normal(prior_mean, std).rsample(t.Size([k]))
    # One b per sampled a: the batch shape already carries the k samples.
    z_b = Normal(z_a, std).rsample()

    return z_a, z_b


# TODO
def get_factors(P, x) :
    """Return the log-density of ``x`` under ``P``, as given by ``P.log_prob``."""
    log_density = P.log_prob(x)
    return log_density


# Number of samples drawn per latent variable.
BATCH_SIZE = 2
# N_VARS was referenced but never defined, which raised NameError here.
# NOTE(review): set to 2 for the two latents (a, b) that sample_model
# returns -- confirm this is the intended count.
N_VARS = 2
ks = [BATCH_SIZE] * N_VARS


# A single random observation standing in for real data.
data = t.randn(1)
samples = sample_model(BATCH_SIZE, data)

In [None]:
"""
    estimators
"""
def vae_marginal_likelihood(x,z) :
    """Return the ratio ``P(x, z) / Q(z)``.

    NOTE(review): ``P`` and ``Q`` are looked up as globals and are not
    defined in the visible code -- confirm where they are meant to come from.
    """
    joint = P(x, z)
    proposal = Q(z)
    return joint / proposal