In [24]:
import torch.nn as nn
import torch as t

from torch.distributions import Normal, Categorical
from torch.distributions import MultivariateNormal as MVN
import math
import numpy as np

In [4]:
class P(nn.Module):
    """Generative model p(w) p(z | w) p(x | z) assembled from three factor modules.

    ``Pw`` is called with no arguments, ``PzGw`` with ``w`` and ``PxGz`` with
    ``z``; each call must return an object offering ``sample`` and ``log_prob``.
    """

    def __init__(self, Pw, PzGw, PxGz):
        super().__init__()
        self.Pw, self.PzGw, self.PxGz = Pw, PzGw, PxGz

    def sample(self, N):
        """Draw one ``w``, then ``N`` draws of ``z`` given that ``w``, then ``x`` given ``z``.

        Returns ``(x, (w, z))`` where ``w`` carries an extra trailing
        singleton dimension.
        """
        prior_draw = self.Pw().sample()
        latent_draw = self.PzGw(prior_draw).sample(sample_shape=t.Size([N]))
        observed_draw = self.PxGz(latent_draw).sample()
        return (observed_draw, (prior_draw.unsqueeze(-1), latent_draw))

    def log_prob(self, xwz):
        """Joint log density of ``xwz = (x, (w, z))``.

        Each of the three factor log-probabilities is summed over its last
        dimension before the factors are added together.
        """
        x, (w, z) = xwz
        factor_terms = (
            self.Pw().log_prob(w),
            self.PzGw(w).log_prob(z),
            self.PxGz(z).log_prob(x),
        )
        w_term, z_term, x_term = (term.sum(-1) for term in factor_terms)
        return w_term + z_term + x_term


class Q(nn.Module):
    """Factorised approximate posterior q(w) q(z) built from two zero-argument modules.

    ``Qw()`` and ``Qz()`` must each return an object offering ``sample`` and
    ``log_prob``.
    """

    def __init__(self, Qw, Qz):
        super().__init__()
        self.Qw, self.Qz = Qw, Qz

    def sample(self, N, sample_shape=t.Size([])):
        """Draw ``w`` with ``sample_shape`` and ``z`` with ``sample_shape + [N]``.

        Returns ``(w, z)`` where ``w`` carries an extra trailing singleton
        dimension.
        """
        z_shape = t.Size([*sample_shape, N])
        w_draw = self.Qw().sample(sample_shape=sample_shape)
        z_draw = self.Qz().sample(sample_shape=z_shape)
        return (w_draw.unsqueeze(-1), z_draw)

    def log_prob(self, wz):
        """Log density of ``wz = (w, z)``; each factor is summed over its last dimension."""
        w, z = wz
        w_term = self.Qw().log_prob(w).sum(-1)
        z_term = self.Qz().log_prob(z).sum(-1)
        return w_term + z_term

In [14]:
def pqx(N, sw, sz, sx):
    """Build a generative model, a matching approximate posterior, and a data sample.

    The generative model is ``P`` with a zero-initialised ``ParamNormal`` of
    scale ``sw`` for w and ``LinearNormal`` factors of scale ``sz`` (z given w)
    and ``sx`` (x given z).  The approximate posterior is ``Q`` with two
    ``ParamNormal`` factors of scale ``sw`` and ``sqrt(sw**2 + sz**2)``.

    Returns ``(p, q, x)`` where ``x`` is the observed part of ``p.sample(N)``.
    """
    generative = P(
        ParamNormal((), scale=sw),
        LinearNormal((), scale=sz),
        LinearNormal((), scale=sx),
    )
    # Keep only the observations; the latents that generated them are discarded.
    data = generative.sample(N)[0]

    z_factor_scale = math.sqrt(sw ** 2 + sz ** 2)
    approximate = Q(
        ParamNormal((), scale=sw),
        ParamNormal((), scale=z_factor_scale),
    )
    return (generative, approximate, data)

    
class ParamNormal(nn.Module):
    """Normal distribution whose mean and log-scale are learnable parameters.

    Both parameters have shape ``sample_shape``.  The mean starts at zero and
    the log-scale starts at ``log(scale)``, so the initial standard deviation
    equals ``scale``.
    """

    def __init__(self, sample_shape, scale=1.):
        super().__init__()
        initial_log_scale = math.log(scale)
        self.loc = nn.Parameter(t.zeros(sample_shape))
        self.log_scale = nn.Parameter(t.ones(sample_shape) * initial_log_scale)

    def forward(self):
        """Return the current ``Normal(loc, exp(log_scale))``."""
        std = self.log_scale.exp()
        return Normal(self.loc, std)

    

class LinearNormal(nn.Module):
    """Normal distribution centred on its input, with a learnable log-scale.

    The log-scale parameter has shape ``sample_shape`` and starts at
    ``log(scale)``, so the initial standard deviation equals ``scale``.
    """

    def __init__(self, sample_shape=t.Size([]), scale=1.):
        super().__init__()
        initial_log_scale = math.log(scale)
        self.log_scale = nn.Parameter(t.ones(sample_shape) * initial_log_scale)

    def forward(self, input):
        """Return ``Normal(input, exp(log_scale))``."""
        std = self.log_scale.exp()
        return Normal(input, std)




In [15]:
# Number of observations and the three noise scales passed to pqx.
N = 100
sw = 1.
sz = 1.
sx = 0.1
# NOTE: no seed is set before this call, so the sampled x (and the value
# printed below) is not pinned by this cell.
p, q, x = pqx(N, sw=sw, sz=sz, sx=sx)
# Covariance of x with w and z integrated out: every pair of observations
# shares the same w (sw**2 in every entry), and each observation has its own
# z and observation noise (sz**2 + sx**2 added on the diagonal).
var = sw**2 * t.ones(N, N) + (sz**2+sx**2)*t.eye(N)
m = MVN(t.zeros(N), var)

# Log density of the sampled data under that zero-mean multivariate normal.
print(m.log_prob(x.cpu()))

tensor(-132.7456)


In [None]:
class VAE(nn.Module):
    """
    Usual single/multi-sample VAE

    Holds a generative model ``p`` (must provide ``log_prob``) and an
    approximate posterior ``q`` (must provide ``sample`` and ``log_prob``),
    and evaluates the objective with ``K`` joint draws from ``q``.
    """
    def __init__(self, p, q, K):
        super().__init__()
        self.p = p
        self.q = q
        self.K = K

    def forward(self, x):
        """Return the log-mean-exp over K draws of log p(x, w, z) - log q(w, z).

        Relies on ``logmeanexp``, which is defined elsewhere in this notebook.
        """
        wz = self.q.sample(x.size(0), sample_shape=t.Size([self.K]))
        elbo = self.p.log_prob((x, wz)) - self.q.log_prob(wz)
        return logmeanexp(elbo)

    def train(self, x=True):
        """Evaluate and print the objective on ``x`` 100 times.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` and parent modules call with a boolean.  A
        boolean argument is therefore forwarded to the base class (returning
        ``self``) so that mode switching keeps working; any other argument is
        treated as data.
        """
        if isinstance(x, bool):
            return super().train(x)

        # Build the optimiser from this instance's own q rather than from
        # whatever notebook-level variable happens to be called `q`.
        opt = t.optim.Adam(self.q.parameters())
        for i in range(100):
            # NOTE: the parameter update is deliberately left disabled, as in
            # the original cell; only the objective is evaluated and printed.
            #opt.zero_grad()
            obj = self(x)
            #(-obj).backward()
            #opt.step()
            print(obj)

In [16]:
class TMC(nn.Module):
    """Estimator that averages over separate sets of w-samples and z-samples.

    ``p`` must expose the factor modules ``Pw``, ``PzGw`` and ``PxGz``; ``q``
    must expose ``Qw`` and ``Qz``.  ``Kw`` and ``Kz`` are the number of draws
    of w and of z; ``Kz`` defaults to ``Kw``.

    The class is defined once here (constructor, ``forward`` and ``train``
    together) instead of being re-declared as a subclass of itself.
    """
    def __init__(self, p, q, Kw, Kz=None):
        super().__init__()
        self.p = p
        self.q = q
        self.Kw = Kw
        self.Kz = Kw if Kz is None else Kz

    def forward(self, x):
        """Return the estimate of log p(x) for the observations ``x``.

        A separate set of ``Kz`` z-draws is taken for each of the ``x.size(0)``
        observations.  Relies on ``logmeanexp``, which is defined elsewhere in
        this notebook.
        """
        # w gets two trailing singleton dims so it broadcasts against z below.
        w  = self.q.Qw().sample(sample_shape=t.Size([self.Kw, 1, 1]))
        z  = self.q.Qz().sample(sample_shape=t.Size([self.Kz, x.size(0)]))
        # Log importance weights of each factor (model minus proposal).
        fw = self.p.Pw().log_prob(w) - self.q.Qw().log_prob(w)
        fz = self.p.PzGw(w).log_prob(z) - self.q.Qz().log_prob(z)
        fx = self.p.PxGz(z).log_prob(x)
        # Average out the z-draws (second-to-last dim), separately per observation.
        f_int_z = logmeanexp(fz + fx, -2)
        # Combine the observations and add the weight of each w-draw.
        f_int_z = f_int_z.sum(-1) + fw.view(-1)
        # Average out the w-draws.
        f_int_w = logmeanexp(f_int_z)

        return f_int_w

    def train(self, x=True):
        """Evaluate and print the objective on ``x`` 100 times.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` and parent modules call with a boolean.  A
        boolean argument is therefore forwarded to the base class (returning
        ``self``) so that mode switching keeps working; any other argument is
        treated as data.
        """
        if isinstance(x, bool):
            return super().train(x)

        # Build the optimiser from this instance's own q rather than from
        # whatever notebook-level variable happens to be called `q`.
        opt = t.optim.Adam(self.q.parameters())
        for i in range(100):
            # NOTE: the parameter update is deliberately left disabled, as in
            # the original cell; only the objective is evaluated and printed.
            #opt.zero_grad()
            obj = self(x)
            #(-obj).backward()
            #opt.step()
            print(obj)

class TMC_Shared(TMC):
    """TMC variant that draws one set of ``Kz`` z-samples and reuses it for every observation."""

    def forward(self, x):
        """Return the estimate of log p(x) for the observations ``x``.

        The average over z-draws is computed with ``logmmmeanexp`` and the
        average over w-draws with ``logmeanexp``; both are defined elsewhere
        in this notebook.
        """
        w_draws = self.q.Qw().sample(sample_shape=t.Size([self.Kw]))
        z_draws = self.q.Qz().sample(sample_shape=t.Size([self.Kz]))

        # Log importance weights: model factor minus proposal factor.
        log_ratio_w = self.p.Pw().log_prob(w_draws) - self.q.Qw().log_prob(w_draws)
        log_ratio_z = (
            self.p.PzGw(w_draws.unsqueeze(1)).log_prob(z_draws)
            - self.q.Qz().log_prob(z_draws)
        )
        # Likelihood of x under each z-draw; the new axis broadcasts against x.
        log_lik = self.p.PxGz(z_draws.unsqueeze(1)).log_prob(x)

        # Average out the shared z-draws, then combine observations per w-draw.
        per_observation = logmmmeanexp(log_ratio_z, log_lik)
        per_w_draw = per_observation.sum(-1) + log_ratio_w.view(-1)
        return logmeanexp(per_w_draw)

In [20]:


def logmeanexp(x, dim=0):
    """Return ``log(mean(exp(x)))`` along ``dim``, with that dimension removed.

    Uses torch's built-in ``logsumexp`` instead of a hand-rolled max-shift.
    The hand-rolled version produced NaN whenever a whole slice was ``-inf``
    (``-inf - -inf``); this one returns ``-inf`` for such a slice.

    Args:
        x: input tensor.
        dim: dimension to reduce over (default 0).
    """
    # A 0-dim tensor holds a single element; size(dim) would raise on it.
    count = x.size(dim) if x.dim() > 0 else 1
    return t.logsumexp(x, dim) - math.log(count)

def logsumexp(x, dim=0):
    """Return ``log(sum(exp(x)))`` along ``dim``, with that dimension removed.

    Delegates to torch's built-in ``logsumexp``.  The previous hand-rolled
    max-shift produced NaN whenever a whole slice was ``-inf``
    (``-inf - -inf``); the built-in returns ``-inf`` for such a slice.

    Args:
        x: input tensor.
        dim: dimension to reduce over (default 0).
    """
    return t.logsumexp(x, dim)

def logmmmeanexp(X, Y):
    """Log of the mean over the shared index of ``exp(X[i, k] + Y[k, j])``.

    Computed as a matrix product of shifted exponentials, so the full
    three-index tensor is never materialised.

    Args:
        X: 2-D tensor whose second dimension is the one averaged over.
        Y: 2-D tensor whose first dimension matches ``X.size(1)``.

    Returns:
        Tensor of shape ``(X.size(0), Y.size(1))``.

    A row of ``X`` or a column of ``Y`` consisting only of ``-inf`` now
    yields ``-inf`` in the result; shifting by a ``-inf`` maximum used to
    turn it into NaN.
    """
    row_shift = X.max(dim=1, keepdim=True)[0]
    col_shift = Y.max(dim=0, keepdim=True)[0]
    # Shifting by a non-finite maximum would compute inf - inf; use 0 there.
    row_shift = t.where(t.isfinite(row_shift), row_shift, t.zeros_like(row_shift))
    col_shift = t.where(t.isfinite(col_shift), col_shift, t.zeros_like(col_shift))

    summed = t.mm((X - row_shift).exp(), (Y - col_shift).exp())
    # Subtracting log(K) converts the sum over the shared index into a mean.
    return row_shift + col_shift + summed.log() - math.log(X.size(1))

In [25]:
# Number of independent evaluations of the estimator.
iters = 10

# 501 draws of w and 502 draws of z (the Kw and Kz constructor arguments).
tmc = TMC(p, q, 501, 502)
tmcs = []

for i in range(iters):
    # Re-seed on every repetition so each evaluation is reproducible on its own.
    t.manual_seed(i)
    res = tmc(x)
    tmcs.append(res.detach().cpu().numpy())
    
    
# Average of the estimates over the repetitions.
tmcs = np.array(tmcs)
print(tmcs.mean())

-133.81053


# Markov model


In [1]:


# https://pyro.ai/examples/vae.html
# https://pyro.ai/examples/dmm.html
def model():
    """Sketch of a sequential latent-variable model, one time step at a time.

    At each step the next latent is drawn from a Normal whose parameters come
    from ``self.trans(z_prev)``, and the observation for that step is scored
    under a Bernoulli parameterised by ``self.emitter(z_t)``.

    NOTE(review): this cell is not runnable as written.  ``self``, ``T_max``,
    ``mini_batch``, ``pyro`` and ``dist`` are neither parameters of this
    function nor defined or imported in the visible part of the notebook, and
    the loop variable ``t`` hides the notebook's ``torch`` alias ``t`` inside
    this function.
    """
    z_prev = self.z_0

    # sample the latents z and observed x's one time step at a time
    for t in range(1, T_max + 1):
        # the next two lines of code sample z_t ~ p(z_t | z_{t-1}).
        # first compute the parameters of the diagonal gaussian
        # distribution p(z_t | z_{t-1})
        z_loc, z_scale = self.trans(z_prev)
        # then sample z_t according to dist.Normal(z_loc, z_scale)
        normal = Normal(z_loc, z_scale)
        z_t = normal.rsample()

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = self.emitter(z_t)
        # the next statement instructs pyro to observe x_t according to the
        # bernoulli distribution p(x_t|z_t)
        pyro.sample("obs_x_%d" % t,
                    dist.Bernoulli(emission_probs_t),
                    obs=mini_batch[:, t - 1, :])
        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t