# VSMC with Normalizing Flows

### To do:
- [ ] Think about how to take into consideration `x_prev`
- [ ] Implement it!

In [6]:
# Import normalizing flows module
import sys
sys.path.append("./src")
from flows import *

In [7]:
# Import other libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter

In [8]:
a = torch.randn(4, 2)
b = torch.randn(4, 2)
torch.cat([a, b], dim=0)

tensor([[ 0.0280,  0.4885],
        [-0.3894, -0.9720],
        [ 1.4886,  0.6052],
        [ 0.8581, -0.0915],
        [-1.8067, -1.3981],
        [-0.2618, -0.3907],
        [ 1.0203, -0.9154],
        [-0.7010,  1.5876]])

In [9]:
torch.randn(4, 2)

tensor([[ 0.6002, -1.4475],
        [ 0.5411, -1.4240],
        [ 0.1398, -0.1670],
        [ 0.2691,  0.0152]])

In [10]:
import torch
from torch import nn
from torch.distributions import Normal, MultivariateNormal

# Classes and functions necessary for Real NVP

class MLP(nn.Module):
    '''
    Multilayer perceptron module.
    '''
    def __init__(self, in_dim, out_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)


class RealNVPLayer(nn.Module):
    '''
    Real NVP layer module.
    '''
    def __init__(self, dim, mask, context_dim):
        super().__init__()
        
        # Uses MLP module as scale and
        # translate neural networks
        self.scale_net = MLP(dim + context_dim, dim)
        self.translate_net = MLP(dim + context_dim, dim)

        # Defines mask to implement
        # coupling layers
        self.mask = mask

    def forward(self, x, context):
        # Forward transformation
        x_masked = x * self.mask
        # Adding context
        print(x_masked.shape,'\n',context.shape)
        x_masked_with_context = torch.cat([x_masked, context], dim=-1)
        s = self.scale_net(x_masked_with_context) * (1 - self.mask)
        t = self.translate_net(x_masked_with_context) * (1 - self.mask)
        z = x_masked + (1 - self.mask) * (x * torch.exp(s) + t)
        log_det_J = torch.sum(s, dim=1)
        return z, log_det_J

    def inverse(self, y, context):
        # Inverse transformation
        y_masked = y * self.mask
        # Adding context
        y_masked_with_context = torch.cat([y_masked, context], dim=-1)
        s = self.scale_net(y_masked_with_context) * (1 - self.mask)
        t = self.translate_net(y_masked_with_context) * (1 - self.mask)
        x = y_masked + (1 - self.mask) * ((y - t) * torch.exp(-s))
        log_det_J = torch.sum(s, dim=1)
        return x, log_det_J


class RealNVP(nn.Module):
    '''
    Real NVP module.
    Consists of multiple Real NVP layers.
    '''
    def __init__(self, dim, n_layers, base_dist, context_dim):
        super().__init__()
        self.dim = dim
        self.base_dist = base_dist
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            # Create masks (alternating 0s and 1s)
            mask_list = [(i + j) % 2 for j in range(dim)]
            mask = torch.tensor(mask_list, dtype=torch.float32)

            # Add Real NVP layer
            self.layers.append(RealNVPLayer(self.dim, mask, context_dim))

    def forward(self, x, context):
        # Forward transformation
        log_det_J = torch.zeros(x.size(0), device=x.device)
        for layer in self.layers:
            x, log_det = layer.forward(x, context)
            log_det_J += log_det
        return x, log_det_J

    def inverse(self, z, context):
        # Inverse transformation
        log_det_J = torch.zeros(z.size(0), device=z.device)
        for layer in reversed(self.layers):
            z, log_det = layer.inverse(z, context)
            log_det_J -= log_det
        return z, log_det_J

    def log_prob(self, x, context):
        # Computes the log pdf of the final samples
        z, log_det = self.inverse(x, context)
        return self.base_dist.log_prob(z) + log_det
        
    def sample(self, n_samples, context):
        # Sample from the final distribution
        z = self.base_dist.sample((n_samples,))
        x, _ = self.forward(z.view(n_samples, self.dim), context)
        return x





In [None]:
# Functions to be used
def prior_VSMC(x, x_prev):
    return prior(x_prev).log_prob(x)

def likelihood_VSMC(y_t, x_t):
    return likelihood(x_t).log_prob(y_t)

def proposal_logpdf(x, x_prev, lambda_):
    if x_prev is None:
        return Normal(lambda_, 1).log_prob(x)
    else:
        return Normal(lambda_ * x_prev, 1).log_prob(x)

def proposal(x_prev, eps, lambda_):
    if x_prev is None:
        mu = lambda_
    else:
        mu = lambda_ * x_prev
    return eps + mu

# VSMC
def run_vsmc(y_seq, lambda_, N, T):
    latent_dim = y_seq[0].size(0)
    log_w = torch.zeros(T, N)
    x_particles = torch.zeros(T, N, latent_dim)
    ancestors = torch.zeros(T, N, dtype=torch.long)

    eps = torch.randn(N, latent_dim)
    x_t = proposal(None, eps, lambda_)
    x_particles[0] = x_t
    log_w[0] = (
        prior_VSMC(x_t, None) + likelihood_VSMC(y_seq[0], x_t) - proposal_logpdf(x_t, None, lambda_)
    ).squeeze()

    for t in range(1, T):
        w_prev = torch.softmax(log_w[t - 1], dim=0)
        # print(w_prev)
        a_t = torch.multinomial(w_prev, N, replacement=True)
        ancestors[t] = a_t

        x_prev = x_particles[t - 1, a_t]
        eps = torch.randn(N, latent_dim)
        x_t = proposal(x_prev, eps, lambda_)
        x_particles[t] = x_t

        log_w[t] = (
            prior_VSMC(x_t, x_prev) + likelihood_VSMC(y_seq[t], x_t) - proposal_logpdf(x_t, x_prev, lambda_)
        ).squeeze()

    # Sample final trajectory
    w_T = torch.softmax(log_w[-1], dim=0)
    b_T = torch.multinomial(w_T, 1).item()
    
    # Trace back trajectory
    x_estimates = torch.zeros_like(x_particles[:, 0])
    i = b_T
    for t in reversed(range(T)):
        x_estimates[t] = x_particles[t, i]
        if t > 0:
            i = ancestors[t, i]
    
    return log_w, x_particles, ancestors, x_estimates

def surrogate_elbo(log_w):
    logZ = torch.logsumexp(log_w, dim=1) - torch.log(torch.tensor(log_w.size(1), dtype=torch.float32))
    return logZ.sum()

def train_vsmc(y_seq, T, N, n_steps=100, lr=1e-3):
    # Parameter to learn
    lambda_ = torch.nn.Parameter(torch.randn((1,)))
    optimizer = torch.optim.Adam([lambda_], lr=lr)
    history = []
    
    # Training loop
    for step in range(n_steps):
        optimizer.zero_grad()
        log_w,_,_,_ = run_vsmc(y_seq, lambda_, N, T=T)
        elbo = surrogate_elbo(log_w)
        loss = -elbo
        loss.backward()
        if step % 50 == 0:
            print(f"Step {step:03d} | ELBO: {elbo.item():.2f} | grad: {lambda_.grad.item():.4f}")
        optimizer.step()
        history.append(loss)
    print(f"Training finished!\nFinal loss: {loss.item()}")
    return history, lambda_



In [12]:
# Define a base distribution
base_dist = MultivariateNormal(torch.zeros(2), torch.eye(2))

# Example usage
dim = 2  # Dimensionality of data
n_layers = 4  # Number of RealNVP layers
context_dim = 2  # Dimensionality of context

realnvp = RealNVP(dim, n_layers, base_dist, context_dim)

# Sample data
context = torch.randn(5, context_dim)  # 10 samples, 2-dimensional context
x = torch.randn(5, dim)  # 10 samples, 3-dimensional data

# Forward pass
print('forward pass')
z, log_det_J = realnvp(x, context)

# Inverse pass
print('inverse pass')
x_reconstructed, _ = realnvp.inverse(z, context)

# Log probability
print('log prob')
log_prob = realnvp.log_prob(x, context)

# Sampling from the model
print('sample')
samples = realnvp.sample(5, context)


forward pass
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])
inverse pass
log prob
sample
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])
torch.Size([5, 2]) 
 torch.Size([5, 2])


In [15]:
MultivariateNormal(torch.zeros(2), torch.eye(2)).log_prob(x)

tensor([-2.4961, -2.7053, -2.0743, -2.2819, -2.2937])

In [16]:
x[0]

tensor([-0.0368,  1.1468])

In [17]:
normal = distributions.Normal(0, 1)
flow = RealNVP(dim=2, n_layers=5, base_dist=normal, context_dim=2)

In [18]:
context = torch.tensor([[1.0, 2.1]])
context

tensor([[1.0000, 2.1000]])

In [19]:
# Not trained yet
flow.log_prob(torch.tensor([[0.0]]),context)

tensor([[-0.6514, -0.6892]], grad_fn=<AddBackward0>)

In [9]:
# trying to fix it (by gepeto)
import torch
from torch import nn

# ---------------------
# Helper MLP class
# ---------------------
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)


# ---------------------
# RealNVP Layer
# ---------------------
class RealNVPLayer(nn.Module):
    def __init__(self, dim, mask, context_dim):
        super().__init__()
        self.scale_net = MLP(dim + context_dim, dim)
        self.translate_net = MLP(dim + context_dim, dim)
        self.mask = nn.Parameter(mask, requires_grad=False)

    def forward(self, x, context):
        x_masked = x * self.mask
        print(x_masked.size(), context.size())
        x_input = torch.cat([x_masked, context], dim=-1)
        s = self.scale_net(x_input) * (1 - self.mask)
        t = self.translate_net(x_input) * (1 - self.mask)
        z = x_masked + (1 - self.mask) * (x * torch.exp(s) + t)
        log_det_J = torch.sum(s, dim=1)
        return z, log_det_J

    def inverse(self, y, context):
        y_masked = y * self.mask
        y_input = torch.cat([y_masked, context], dim=-1)
        s = self.scale_net(y_input) * (1 - self.mask)
        t = self.translate_net(y_input) * (1 - self.mask)
        x = y_masked + (1 - self.mask) * ((y - t) * torch.exp(-s))
        log_det_J = torch.sum(s, dim=1)
        return x, log_det_J


# ---------------------
# RealNVP Model
# ---------------------
class RealNVP(nn.Module):
    def __init__(self, dim, n_layers, base_dist, context_dim):
        super().__init__()
        self.dim = dim
        self.base_dist = base_dist
        self.layers = nn.ModuleList()

        for i in range(n_layers):
            mask = torch.tensor([(i + j) % 2 for j in range(dim)], dtype=torch.float32)
            self.layers.append(RealNVPLayer(dim, mask, context_dim))

    def forward(self, x, context):
        log_det_J = torch.zeros(x.size(0), device=x.device)
        for layer in self.layers:
            x, log_det = layer(x, context)
            log_det_J += log_det
        return x, log_det_J

    def inverse(self, z, context):
        log_det_J = torch.zeros(z.size(0), device=z.device)
        for layer in reversed(self.layers):
            z, log_det = layer.inverse(z, context)
            log_det_J -= log_det
        return z, log_det_J

    def log_prob(self, x, context):
        z, log_det = self.inverse(x, context)
        return self.base_dist.log_prob(z).sum(dim=1) + log_det

    def sample(self, n_samples, context):
        z = self.base_dist.sample((n_samples, self.dim)).to(context.device)
        x, _ = self.forward(z, context)
        return x


In [12]:
if __name__ == "__main__":
    dim = 3
    context_dim = 2
    n_layers = 4

    base_dist = torch.distributions.Normal(torch.zeros(dim), torch.ones(dim))
    model = RealNVP(dim, n_layers, base_dist, context_dim)

    x = torch.randn(10, dim)
    context = torch.randn(10, context_dim)

    logp = model.log_prob(x, context)
    print("Log prob:", logp)

    samples = model.sample(10, context)
    print("Samples:", samples)


RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 1

In [None]:
def proposal_logpdf(x_t, x_prev, flow):
    flow.log_prob(x_t)
    # finish this

def vsmc_nf(y_seq, lambda_, N, T):
    latent_dim = y_seq[0].size(0)
    log_w = torch.zeros(T, N)
    x_particles = torch.zeros(T, N, latent_dim)
    ancestors = torch.zeros(T, N, dtype=torch.long)

    eps = torch.randn(N, latent_dim)
    x_t = proposal(None, eps, lambda_)
    x_particles[0] = x_t
    log_w[0] = (
        prior(x_t, None) + likelihood(y_seq[0], x_t) - proposal_logpdf(x_t, None, flow)
    ).squeeze()

    for t in range(1, T):
        w_prev = torch.softmax(log_w[t - 1], dim=0)
        # print(w_prev)
        a_t = torch.multinomial(w_prev, N, replacement=True)
        ancestors[t] = a_t

        x_prev = x_particles[t - 1, a_t]
        eps = torch.randn(N, latent_dim)
        x_t = proposal(x_prev, eps, lambda_)
        x_particles[t] = x_t

        log_w[t] = (
            prior(x_t, x_prev) + likelihood(y_seq[t], x_t) - proposal_logpdf(x_t, x_prev, flow)
        ).squeeze()

    # Sample final trajectory
    w_T = torch.softmax(log_w[-1], dim=0)
    b_T = torch.multinomial(w_T, 1).item()
    
    # Trace back trajectory
    x_estimates = torch.zeros_like(x_particles[:, 0])
    # i = b_T
    for t in reversed(range(T)):
        x_estimates[t] = x_particles[t, b_T]
        if t > 0:
            i = ancestors[t, b_T]
    
    return log_w, x_particles, ancestors, x_estimates
    # return log_w, x_particles



função run_vsmc-> olhar pra ela e tudo que ta na mesma celula. tem que fazer o treinamento dela mas a proposta deve ser substituida pela distribuição aprendida pelo normalizing flow --- pedir pro manus se sobrar processamento

também é interessante configurar github nesse compiuter pra subir as coisas
