In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import tqdm.auto as tqdm
from FrEIA.utils import force_to
from collections import namedtuple
from math import sqrt, prod
import os
import torch
from torch.distributions import Distribution
from torch.autograd import grad
from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual

from fff.utils.utils import sum_except_batch
from fff.utils.types import Transform

from fff.utils.func import (
    compute_volume_change,
    compute_jacobian
    )

from pinf.losses.utils import get_beta

Settings

---

In [None]:
# Presumably a plotting font size; not referenced by any other cell in this notebook — TODO confirm.
fs = 20

# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Set to False to skip the training/saving cells and only reload the stored checkpoint.
train_new = True

# Seed the torch and numpy RNGs so that runs are reproducible.
torch.manual_seed(7)
np.random.seed(7)

In [None]:
# --- Optimisation -----------------------------------------------------------
lr = 1e-4
lamba_weight_decay = 0.0
bs_nll = 512      # batch size of the maximum-likelihood (FFF) term
bs_TRADE = 512    # number of model samples drawn for the TRADE term
n_iter = 50_000
plot_freq = 5_000  # not referenced by any other cell in this notebook

# ExponentialLR multiplies the learning rate by gamma_lr_step per scheduler step,
# i.e. by r_final in total after n_iter steps.
r_final = 0.01
gamma_lr_step = r_final ** (1 / n_iter)

# --- Range of the condition c visited during training ------------------------
c_min = 1 / 3
c_max = 3.0

# Iteration counts handed to get_beta (its exact schedule is defined there).
t_burn_in = 0.0
t_full = int(0.8 * n_iter)

# --- Loss weights --------------------------------------------------------------
beta_recon = 10
beta_TRADE = 0.5
n_hutchinson_samples_PI = 1

Define the model

---



In [None]:
class Model(nn.Module):
    """Conditional one-dimensional free-form flow.

    The encoder ``f`` and decoder ``g`` are residual MLPs that receive the
    condition through :meth:`transform_condition` (``log c``)::

        f(x, c) = x + encoder_block([x, log c])
        g(z, c) = z + decoder_block([z, log c])

    The latent distribution ``p_0`` is a standard normal.

    :param d_hidden: Width of the hidden layers of both MLPs.
    :param activation_function: Activation module class inserted after every
        hidden linear layer.
    :param device: Device passed to ``force_to`` when creating ``p_0``.
    :param d_cond: Dimensionality of the condition ``c``.
    """

    def __init__(self, d_hidden = 512,activation_function = nn.SiLU, device = device,d_cond = 1):
        super().__init__()

        self.device = device

        # Both networks map [1-dimensional input, d_cond-dimensional condition]
        # to a 1-dimensional residual through d_hidden-wide hidden layers.
        self.encoder_block = nn.Sequential(
            nn.Linear(1+d_cond, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, 1)
        )

        self.decoder_block = nn.Sequential(
            nn.Linear(1+d_cond, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, 1)
        )

        # Xavier-normal initialisation of every linear weight in BOTH networks
        # (biases keep PyTorch's default initialisation).
        for block in (self.encoder_block, self.decoder_block):
            for module in block:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_normal_(module.weight)

        self.p_0 = force_to(torch.distributions.Normal(loc = 0.0,scale = 1.0),self.device)

    def transform_condition(self,c):
        """Embed the condition as ``log(c)``; ``c`` must therefore be positive."""
        return c.log()

    def decode(self,z,c):
        """Decoder ``g(z, c) = z + decoder_block([z, log c])``.

        :param z: Latent codes, shape (batch, 1).
        :param c: Conditions, shape (batch, d_cond), positive.
        :return: Reconstructions, same shape as ``z``.
        """
        x = self.decoder_block(torch.cat((z,self.transform_condition(c)),1)) + z

        return x

    def encode(self,x,c):
        """Encoder ``f(x, c) = x + encoder_block([x, log c])``.

        :param x: Data points, shape (batch, 1).
        :param c: Conditions, shape (batch, d_cond), positive.
        :return: Latent codes, same shape as ``x``.
        """
        z = self.encoder_block(torch.cat((x,self.transform_condition(c)),1)) + x

        return z

    def sample(self,n,c):
        """Draw ``n`` samples at the scalar condition ``c`` by decoding ``p_0`` samples.

        :param n: Number of samples.
        :param c: Condition as a Python float, shared by all samples.
        :return: Samples of shape (n, 1).
        """

        assert(isinstance(c,float))

        c_tensor = torch.ones([n,1]).to(self.device) * c

        z = self.p_0.sample([n]).reshape(-1,1)
        x = self.decode(z = z,c = c_tensor)

        return x

    def log_prob(self,x,c):
        """Log-density ``log p(x | c)`` via the change-of-variables formula.

        ``log p(x | c) = log p_0(f(x, c)) + log|det f'(x, c)|``. The encoder
        Jacobian is computed explicitly with ``compute_jacobian``, and its
        reduction by ``compute_volume_change`` is added as the log-volume term.

        :param x: Points at which to evaluate the density, shape (n, 1) in this notebook.
        :param c: Condition as a Python float, shared by all points.
        :return: Log-densities, shape (n, 1).
        :raises ValueError: if ``c`` is not a float.
        """

        if not isinstance(c,float):
            raise ValueError("Beta tensor must be a float")

        # Jacobian of the encoder. The condition is passed as a (1, 1) tensor;
        # presumably compute_jacobian evaluates `encode` one sample at a time — TODO confirm.
        z,jac = compute_jacobian(
            x_in = x,
            fn = self.encode,
            chunk_size=1000,
            c = torch.tensor(c).to(self.device).reshape(-1,1)
            )

        log_jac_det = compute_volume_change(jac).reshape(-1,1)

        # Log-likelihood of the latent code under the standard normal
        log_p_z = self.p_0.log_prob(z)

        assert log_p_z.shape == log_jac_det.shape
        log_p = log_p_z + log_jac_det

        return log_p


Free-form flow (FFF) utilities

---

In [None]:

# The code in this cell is adapted from the repository 'FFF'
# License: MIT License
# Source: https://github.com/vislearn/FFF

def sample_v(
    x: torch.Tensor, hutchinson_samples: int
) -> torch.Tensor:
    """Draw Hutchinson probe vectors with the same per-sample shape as ``x``.

    For every batch element, ``hutchinson_samples`` Gaussian vectors are
    orthonormalised with a QR decomposition and rescaled by
    ``sqrt(total_dim)``. Each probe therefore has squared norm ``total_dim``,
    probes of the same sample are mutually orthogonal, and ``E[v v^T] = I``.

    :param x: Reference tensor of shape (batch_size, ...). Only its shape,
        device and dtype are used.
    :param hutchinson_samples: Number of probe vectors per sample. It must not
        exceed the total non-batch dimension of ``x``.
    :return: Tensor of shape ``(*x.shape, hutchinson_samples)``.
    :raises ValueError: if ``hutchinson_samples`` exceeds the non-batch dimension.
    """
    batch_size, total_dim = x.shape[0], prod(x.shape[1:])

    if hutchinson_samples > total_dim:
        # Adjacent literals concatenate cleanly. A backslash continuation inside
        # the f-string would embed the next line's indentation in the message.
        raise ValueError(
            f"Too many Hutchinson samples: got {hutchinson_samples}, "
            f"expected <= {total_dim}"
        )

    v = torch.randn(
        batch_size, total_dim, hutchinson_samples, device=x.device, dtype=x.dtype
    )
    # Reduced QR gives orthonormal columns per batch element; the sqrt(total_dim)
    # rescaling restores the unit-covariance normalisation of the estimator.
    q = torch.linalg.qr(v).Q.reshape(*x.shape, hutchinson_samples)
    return q * sqrt(total_dim)

# Result container of volume_change_surrogate: surrogate value, latent code,
# reconstruction and a dict of extra metrics.
_SURROGATE_FIELDS = ("surrogate", "z", "x1", "regularizations")
SurrogateOutput = namedtuple("SurrogateOutput", _SURROGATE_FIELDS)

def volume_change_surrogate(
    x: torch.Tensor,
    c,
    encode: Transform,
    decode: Transform,
    hutchinson_samples: int = 1
) -> SurrogateOutput:
    r"""Computes the surrogate for the volume change term in the change of
    variables formula. The surrogate is given by:
    $$
    v^T f_\theta'(x) \texttt{SG}(g_\phi'(z) v).
    $$
    Its gradient approximates the gradient of the volume change term (the FFF
    estimator, which treats the decoder as an approximate inverse of the encoder).

    :param x: Input data. Shape: (batch_size, ...). ``requires_grad`` is
        switched on in place so that ``v^T f'(x)`` can be computed.
    :param c: Condition, passed unchanged as the second argument to both
        ``encode`` and ``decode``.
    :param encode: Encoder function ``encode(x, c) -> z``.
    :param decode: Decoder function ``decode(z, c) -> x1``.
    :param hutchinson_samples: Number of Hutchinson probe vectors. It must be at
        least 1 and at most the total non-batch dimension of ``z``.
    :return: ``SurrogateOutput`` holding the surrogate averaged over the probes
        (reduced with ``sum_except_batch``), the latent code ``z``, the
        reconstruction ``x1`` and an (always empty) ``regularizations`` dict.
    :raises ValueError: if ``hutchinson_samples`` is smaller than 1, or (raised
        by ``sample_v``) larger than the latent dimension.
    """
    if hutchinson_samples < 1:
        # With zero probes the loop below never runs and `x1` would be unbound.
        raise ValueError(
            f"hutchinson_samples must be >= 1, got {hutchinson_samples}"
        )

    regularizations = {}
    surrogate = 0

    x.requires_grad_()
    z = encode(x,c)

    vs = sample_v(z, hutchinson_samples)

    for k in range(hutchinson_samples):
        v = vs[..., k]

        # $ g'(z) v $ via forward-mode AD
        with dual_level():
            dual_z = make_dual(z, v)
            dual_x1 = decode(dual_z,c)

            x1, v1 = unpack_dual(dual_x1)

        # $ v^T f'(x) $ via backward-mode AD (graph kept so the surrogate is differentiable)
        (v2,) = grad(z, x, v, create_graph=True)

        # $ v^T f'(x) stop_grad(g'(z)) v $
        surrogate += sum_except_batch(v2 * v1.detach()) / hutchinson_samples

    return SurrogateOutput(surrogate, z, x1, regularizations)

def fff_loss(
    x: torch.Tensor,
    c,
    encode: Transform,
    decode: Transform,
    latent_distribution: Distribution,
    hutchinson_samples: int = 1,
) -> torch.Tensor:
    r"""Compute the per-sample FFF negative log-likelihood surrogate:
    $$
    \mathcal{L} = -\log p_Z(z) - \frac{1}{K}\sum_{k=1}^K v_k^T f'(x)\,
        \texttt{SG}(g'(z) v_k), \qquad z = f(x, c),
    $$
    where $f'(x)$ and $g'(z)$ are the Jacobians of `encode` and `decode`. The
    probe vectors $v_k$ come from `sample_v`, which scales them so that
    $E[v_k v_k^T] = I$. No reconstruction term is included; callers must add
    it separately.

    :param x: Input data. Shape: (batch_size, ...)
    :param c: Condition forwarded to ``encode`` and ``decode``.
    :param encode: Encoder ``encode(x, c) -> z``.
    :param decode: Decoder ``decode(z, c) -> x1``.
    :param latent_distribution: Latent distribution of the model. Its
        ``log_prob`` is evaluated at ``z`` and summed over non-batch dimensions.
    :param hutchinson_samples: Number of Hutchinson samples ``K`` used for the
        volume change estimator.
    :return: Per-sample loss, reduced over non-batch dimensions by
        ``sum_except_batch`` (shape (batch_size,) assuming that helper keeps
        only the batch axis)."""
    surrogate = volume_change_surrogate(x,c, encode, decode, hutchinson_samples)
    log_prob = latent_distribution.log_prob(surrogate.z)
    nll = -sum_except_batch(log_prob) - surrogate.surrogate
    return nll

Initialize the target distribution

---

In [None]:
# Target at the reference condition c_0: equal-weight mixture of two Gaussians.
p_1 = force_to(torch.distributions.Normal(loc = 2.0,scale = 0.5),device)
p_2 = force_to(torch.distributions.Normal(loc = -2.0,scale = 1.5),device)

def p_target(x):
    """Density (not log-density) of the equal-weight mixture of p_1 and p_2 at ``x``."""
    density_1 = p_1.log_prob(x).exp()
    density_2 = p_2.log_prob(x).exp()
    return (density_1 + density_2) / 2

def get_training_samples(n):
    """Draw ``n`` target samples as an (n, 1) tensor.

    The mixture is sampled in a stratified way. Exactly ``int(n / 2)`` points
    come from p_1 (drawn first) and the remaining ones from p_2.
    """
    n_first = int(n / 2)
    parts = [p_1.sample([n_first]), p_2.sample([n - n_first])]
    return torch.cat(parts, 0).reshape(-1, 1)

# Reference condition: at c = c_0 the target density is p_target itself.
c_0 = 1.0

Initialize the model

---

In [None]:
# Build the flow, move it to the selected device and put it in training mode.
model = Model().to(device)
model.train(mode = True)

# Adam with an exponential learning-rate schedule. The schedule only takes
# effect when scheduler.step() is called.
optimizer = torch.optim.Adam(
    params = model.parameters(),
    lr = lr,
    weight_decay = lamba_weight_decay,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = gamma_lr_step)

Training

---

In [None]:
if train_new:
    storage_loss_nll = []
    storage_loss_TRADE = []
    storage_loss_recon = []

    for t in tqdm.tqdm(range(n_iter)):

        ############################################
        # Maximum-likelihood (FFF) loss at the reference condition c_0
        ############################################

        x_target = get_training_samples(bs_nll)
        c_0_tensor = torch.ones_like(x_target).to(device) * c_0

        loss_nll = fff_loss(
            x = x_target,
            c = c_0_tensor,
            encode=model.encode,
            decode=model.decode,
            latent_distribution=model.p_0
        ).mean()

        # Per-sample squared reconstruction error at c_0, shape (bs_nll,)
        loss_recon = (x_target - model.decode(model.encode(x_target,c_0_tensor),c_0_tensor)).pow(2).sum(-1)

        if beta_TRADE > 0:

            ############################################
            # TRADE target: d/dc log q(x|c) at model samples. The target is
            # q(x|c) ∝ p_target(x)^(c/c_0). The derivative of its unknown
            # normaliser is estimated by self-normalised importance sampling.
            ############################################

            with torch.no_grad():

                c_k,_,_ = get_beta(
                    t = t,
                    t_burn_in=t_burn_in,
                    t_full=t_full,
                    beta_star=c_0,
                    beta_max=c_max,
                    beta_min=c_min,
                    mode = "log-linear"
                )

                x_eval = model.sample(n = bs_TRADE,c = c_k).to(device)

                log_p_target = p_target(x_eval).log()
                d_log_q_d_c = log_p_target / c_0

                # Unnormalised log importance weights log q(x|c_k) - log p_theta(x|c_k)
                log_q_target_c = c_k / c_0 * log_p_target
                log_p_model_c = model.log_prob(x = x_eval,c = float(c_k))

                assert (log_q_target_c.shape == log_p_model_c.shape)

                log_omega = (log_q_target_c - log_p_model_c)

                assert(log_omega.shape == d_log_q_d_c.shape)

                # Self-normalised weights. softmax over the batch equals
                # exp(log_omega) / exp(log_omega).sum(), but it cannot overflow
                # to inf (or underflow to 0/0) like the explicit exp() can.
                omega = torch.softmax(log_omega, dim = 0)
                EX = (omega * d_log_q_d_c).sum()

                target = (d_log_q_d_c - EX).detach()

            ############################################
            # Model prediction: d/dc log p_theta(x|c) at x_eval
            #   = d/dc log p_0(f(x,c)) + d/dc log|det f'(x,c)|.
            # The second term is estimated with a Hutchinson surrogate.
            ############################################

            c_k_tensor = torch.ones_like(x_eval).to(device) * c_k
            c_k_tensor.requires_grad_(True)

            vs = sample_v(torch.ones_like(x_eval), n_hutchinson_samples_PI)

            # Loop-invariant. The decoder Jacobian is evaluated at a condition
            # detached from c_k_tensor, so g'(z) is treated as constant w.r.t. c.
            c_k_tensor_no_grad = c_k_tensor.detach()
            z_no_c_grad = model.encode(x_eval,c_k_tensor_no_grad)
            x_rec_no_c_grad = model.decode(z_no_c_grad,c_k_tensor_no_grad)

            surrogate = 0

            for k in range(n_hutchinson_samples_PI):

                v = vs[..., k]

                # v^T g'(z, c) via reverse-mode AD (graph kept for the parameter update)
                (v1,) = grad(x_rec_no_c_grad, z_no_c_grad, v, create_graph=True,retain_graph=True)

                # f'(x, c) v via forward-mode AD; this term carries the c-dependence
                with dual_level():
                    dual_x = make_dual(x_eval, v)
                    dual_z = model.encode(dual_x,c_k_tensor)

                    _, v2 = unpack_dual(dual_z)

                assert(v1.shape == v2.shape)

                a = (v1 * v2).sum(-1,keepdim=True)

                # d/dc [v^T g'(z) f'(x, c) v]
                (v3,) = grad(a.sum(),c_k_tensor,create_graph=True,retain_graph=True)

                surrogate += v3 / n_hutchinson_samples_PI

            # d/dc log p_0(f(x, c))
            z_k = model.encode(x_eval,c_k_tensor)
            log_p_0 = model.p_0.log_prob(z_k)

            (d_log_p_0_d_c,) = grad(log_p_0.sum(),c_k_tensor, create_graph=True,retain_graph=True)

            d_log_p_theta_d_c = d_log_p_0_d_c + surrogate

            assert(d_log_p_theta_d_c.shape == target.shape)

            loss_TRADE = (target - d_log_p_theta_d_c).pow(2).mean()

            # Reconstruction error on the model samples at c_k (reuses z_k)
            recon_TRADE = (x_eval - model.decode(z_k,c_k_tensor)).pow(2).sum(-1)
            loss_recon = torch.cat((loss_recon,recon_TRADE),0).mean()

            # Combine the different loss contributions
            loss = loss_nll + beta_recon * loss_recon + beta_TRADE * loss_TRADE

            storage_loss_TRADE.append(loss_TRADE.item())

        else:

            # Reduce to a scalar here too so that .item() below works in both branches.
            loss_recon = loss_recon.mean()
            loss = loss_nll + beta_recon * loss_recon

            # Keep all loss histories equally long so they can be exported as columns.
            storage_loss_TRADE.append(float("nan"))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Multiply the lr by gamma_lr_step each iteration (factor r_final after n_iter steps).
        scheduler.step()

        storage_loss_nll.append(loss_nll.item())
        storage_loss_recon.append(loss_recon.item())

In [None]:
folder = "../../results/FFF_TRADE_1D_Proof_of_concept"

if train_new:

    # exist_ok makes this a no-op when the folder already exists (no check-then-create race).
    os.makedirs(folder, exist_ok = True)

    # One row per training iteration; columns follow the header below.
    # np.column_stack raises a ValueError if the loss histories differ in length.
    data = np.column_stack((storage_loss_nll, storage_loss_recon, storage_loss_TRADE))

    np.savetxt(
        fname = os.path.join(folder,"loss.txt"),
        X = data,
        header = "loss_nll\tloss_reconstruction\tloss_TRADE"
    )

    torch.save(
        model.state_dict(),
        f = os.path.join(folder,"final.ckpt")
        )

Load the pretrained model

---

In [None]:
# map_location lets a checkpoint written by a GPU run be restored on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(f = os.path.join(folder,"final.ckpt"), map_location = device))

Plotting

---

In [None]:
# Training curves: one panel per loss term recorded in the training loop.
loss_panels = (
    (r"$\mathcal{L}_{nll}$", storage_loss_nll),
    (r"$\mathcal{L}_{recon}$", storage_loss_recon),
    (r"$\mathcal{L}_{TRADE}$", storage_loss_TRADE),
)

fig, axes = plt.subplots(len(loss_panels), 1, figsize = (10, 10))

for panel_ax, (panel_title, panel_curve) in zip(axes, loss_panels):
    panel_ax.set_title(panel_title)
    panel_ax.plot(panel_curve)
    panel_ax.set_xlabel("t [iter.]")

plt.tight_layout()

In [None]:
c_eval = [0.5,1.0,2.0]
n_panels = len(c_eval)

fig,axes = plt.subplots(n_panels,1,figsize = (8 ,2.5 * n_panels))

# Evaluation grid. The target at condition c is p_target(x)^(c / c_0), normalised numerically on this grid.
x_eval = torch.linspace(-10,6,1000).reshape(-1,1).to(device)
x_grid = x_eval.squeeze().detach().cpu()
p_target_c_0 = p_target(x_eval.squeeze().detach()).squeeze().detach().cpu()

with torch.no_grad():
    for ax, c_i in zip(axes, c_eval):
        ax.set_title(f"c = {c_i}")

        # Histogram of model samples
        x_i = model.sample(500000,c = c_i).detach().cpu().squeeze()
        ax.hist(x_i,density=True,bins = 100,edgecolor = "orange",histtype = "step",linewidth = 2,label = "model data")

        # Model density from the change-of-variables formula
        p_theta = model.log_prob(x = x_eval,c = c_i).detach().cpu().exp()
        ax.plot(x_grid,p_theta,lw = 2,c = "b",label = "model density")

        # Tempered target, normalised by a Riemann sum with the grid spacing
        p_target_c = p_target_c_0 ** (c_i / c_0)
        Z_i = p_target_c.sum() * (x_eval.squeeze()[1] - x_eval.squeeze()[0])
        ax.plot(x_grid,p_target_c.cpu() / Z_i.cpu(),color = "k",lw = 2,label = "target")
        ax.set_xlabel("x")
        ax.set_ylabel(f"p(x|c = {c_i})")

# Every panel draws the same three artists, so the last panel provides the legend entries.
handles, labels = axes[n_panels - 1].get_legend_handles_labels()

# Add a single legend below all subplots
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.01), ncol=4)

plt.tight_layout()
plt.savefig(os.path.join(folder,"densities.pdf"),bbox_inches='tight')