In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import tqdm.auto as tqdm
from FrEIA.utils import force_to
import os
from pinf.losses.utils import get_beta

Settings

---

In [None]:
# Seed both RNGs used in this notebook so that Restart & Run All is reproducible.
SEED = 7
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the first CUDA device when available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# --- Optimisation ------------------------------------------------------------
n_iter = 50_000                          # number of training iterations
lr = 1e-5                                # initial Adam learning rate
r_final = 0.1                            # lr after n_iter steps, relative to the initial lr
gamma_lr_step = r_final ** (1 / n_iter)  # per-iteration factor for the ExponentialLR scheduler
lamba_weight_decay = 0.0                 # Adam weight decay (identifier spelling is used by later cells)
bs_nll = 512                             # batch size of the NLL term
save_freq = 5000                         # save a checkpoint every `save_freq` iterations

# --- Range of the condition beta ---------------------------------------------
beta_0 = 1.0                             # condition at which the NLL (training data) term is evaluated
beta_min, beta_max = 1 / 3, 3.0

# --- Arguments of the beta schedule (see get_beta) ----------------------------
t_burn_in = 0.0
t_full = int(0.8 * n_iter)

# --- TRADE term ---------------------------------------------------------------
bs_TRADE = 512                           # number of model samples per TRADE evaluation
lambda_TS = 0.1                          # weight of the TRADE loss in the total objective

# --- Plotting -----------------------------------------------------------------
fs = 20                                  # font size

Networks that model the mean and the standard deviation of a one-dimensional normal distribution as functions of $\beta$

---

In [None]:
class Model(nn.Module):
    """Conditional one-dimensional normal distribution N(mu(beta), sigma(beta)^2).

    Two independent MLPs take log(beta) as input. One returns the mean directly.
    The other returns log(sigma), so sigma = exp(.) is strictly positive.

    Args:
        d_hidden: width of the three hidden layers of each MLP.
        activation_function: activation *class*; a fresh instance follows every
            hidden linear layer.
        device: stored as ``self.device``. ``sample`` uses the device of the model
            parameters instead, so it keeps working after ``.to(...)``.
    """

    def __init__(self, d_hidden=128, activation_function=nn.SiLU, device=device):
        super().__init__()

        self.device = device

        # Both heads share one layout, so the state_dict keys are
        # mean.{0,2,4,6}.* and sigma.{0,2,4,6}.*.
        self.mean = self._make_mlp(d_hidden, activation_function)
        self.sigma = self._make_mlp(d_hidden, activation_function)

        # Xavier-normal weights; biases keep PyTorch's default init. The sigma head is
        # re-initialised before the mean head; this order matters for seeded runs.
        for head in (self.sigma, self.mean):
            for module in head:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_normal_(module.weight)

    @staticmethod
    def _make_mlp(d_hidden, activation_function):
        """Build a 1 -> d_hidden -> d_hidden -> d_hidden -> 1 MLP."""
        return nn.Sequential(
            nn.Linear(1, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, d_hidden),
            activation_function(),
            nn.Linear(d_hidden, 1),
        )

    def forward(self, x, beta_tensor):
        """Elementwise log-density log N(x; mu(beta), sigma(beta)^2).

        Args:
            x: samples; must have the same shape as ``beta_tensor``.
            beta_tensor: strictly positive conditions (``log`` is applied); the last
                dimension must be 1 because the MLPs start with ``nn.Linear(1, ...)``.

        Returns:
            Tensor of log-densities with the same shape as ``x``.
        """
        assert x.shape == beta_tensor.shape

        log_sigma = self._get_log_sigma(beta_tensor=beta_tensor)
        mean = self.get_mean(beta_tensor=beta_tensor)

        assert log_sigma.shape == x.shape
        assert mean.shape == x.shape

        # Equivalent to -(x-mu)^2/(2 sigma^2) - 0.5*log(2*pi*sigma^2), but it uses
        # log(sigma) directly instead of squaring exp(log sigma), so it cannot
        # overflow or underflow for large |log sigma|.
        z = (x - mean) * torch.exp(-log_sigma)
        return -0.5 * z.pow(2) - log_sigma - 0.5 * np.log(2 * np.pi)

    def _get_log_sigma(self, beta_tensor):
        """Return log(sigma(beta)), i.e. the raw output of the sigma head."""
        return self.sigma(beta_tensor.log())

    def get_sigma(self, beta_tensor):
        """Return the standard deviation sigma(beta) > 0."""
        return self._get_log_sigma(beta_tensor).exp()

    def get_mean(self, beta_tensor):
        """Return the mean mu(beta).

        The mean of a normal distribution is unconstrained, so the network output is
        used as is. Exponentiating it would restrict mu to (0, inf), which makes a
        zero or negative mean unrepresentable.
        """
        return self.mean(beta_tensor.log())

    @torch.no_grad()
    def sample(self, n, beta):
        """Draw ``n`` samples from N(mu(beta), sigma(beta)^2).

        Args:
            n: number of samples.
            beta: condition; must be a Python float (asserted).

        Returns:
            Tensor of shape (n, 1). It is created by ``torch.randn`` on its default
            device, so move it to the model's device as needed.
        """
        assert isinstance(beta, float)

        # Evaluate the heads on whatever device the parameters currently live on.
        param_device = next(self.parameters()).device
        beta_tensor = torch.full((1, 1), beta, device=param_device)

        sigma = self.get_sigma(beta_tensor=beta_tensor).item()
        mean = self.get_mean(beta_tensor=beta_tensor).item()

        return torch.randn(n).reshape(-1, 1) * sigma + mean

Initialize the target distribution

---

In [None]:
# Target distribution: standard normal N(0, 1). force_to (FrEIA.utils) is
# presumably what moves the distribution's tensors onto `device`.
target_loc, target_scale = 0.0, 1.0
p_target = force_to(
    torch.distributions.Normal(loc=target_loc, scale=target_scale),
    device,
)

Initialize the model

---

In [None]:
# Model in training mode on the selected device.
model = Model().to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=lamba_weight_decay)
# Multiplies the lr by gamma_lr_step on every scheduler.step() (once per training iteration).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=gamma_lr_step)

# Output directory for checkpoints and figures. exist_ok=True avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
folder = "../../results/TRADE_1D_Proof_of_concept/"
os.makedirs(folder, exist_ok=True)

Train the model

---

In [None]:
# Evaluation grid and ground truth for the parameter plot further below:
# for the N(0, 1) target, sigma(beta) = 1/sqrt(beta) and mu(beta) = 0.
beta_eval = torch.linspace(beta_min, beta_max, 1000).to(device)
sigma_target = 1 / beta_eval.sqrt().cpu()
mu_target = torch.zeros_like(beta_eval).cpu()
beta_eval = beta_eval.reshape(-1, 1)

loss_nll_storage = torch.zeros(n_iter)
loss_TRADE_storage = torch.zeros(n_iter)
loss_total_storage = torch.zeros(n_iter)

# The condition tensor of the NLL term never changes, so build it once.
beta_0_tensor = torch.full((bs_nll, 1), beta_0, device=device)

for t in tqdm.tqdm(range(n_iter)):

    # --- NLL term: target samples are only used at beta_0 ---
    x_target = p_target.sample([bs_nll]).reshape(-1, 1)
    nll = -model(x=x_target, beta_tensor=beta_0_tensor).mean()

    with torch.no_grad():
        # Condition for this iteration's TRADE term; the 2nd/3rd return values are unused.
        beta_k, _, _ = get_beta(
            t=t,
            t_burn_in=t_burn_in,
            t_full=t_full,
            beta_star=beta_0,
            beta_max=beta_max,
            beta_min=beta_min,
            mode="log-linear",
        )

        x_eval = model.sample(n=bs_TRADE, beta=beta_k).to(device)
        beta_k_tensor_TRADE = torch.full((bs_TRADE, 1), beta_k, device=device)

        # Unnormalised tempered target: log q(x | c) = (c / beta_0) * log p(x),
        # hence d/dc log q(x | c) = log p(x) / beta_0.
        d_log_q_d_c = p_target.log_prob(x_eval) / beta_0

        # Log importance weights of the tempered target w.r.t. the model at beta_k.
        log_q_target_c = beta_k / beta_0 * p_target.log_prob(x_eval)
        log_p_model_c = model(x=x_eval, beta_tensor=beta_k_tensor_TRADE)

        assert log_q_target_c.shape == log_p_model_c.shape

        log_omega = log_q_target_c - log_p_model_c

        assert log_omega.shape == d_log_q_d_c.shape

        # Self-normalised importance-sampling estimate of E_q[d log q / dc].
        # softmax over the batch equals exp(log_omega) / sum(exp(log_omega)) but
        # cannot overflow or underflow for large |log_omega|.
        weights = torch.softmax(log_omega, dim=0)
        EX = (weights * d_log_q_d_c).sum()

        # Centred target score (computed under no_grad, so it carries no graph).
        target = d_log_q_d_c - EX

    # --- TRADE term: match d/dc log p_theta(x | c) to the centred target score ---
    # Each sample's log-density depends only on its own beta entry, so the gradient
    # of the sum w.r.t. beta_k_tensor_TRADE gives the per-sample derivatives.
    beta_k_tensor_TRADE.requires_grad_(True)
    d_log_p_theta_d_c = torch.autograd.grad(
        model(x_eval, beta_k_tensor_TRADE).sum(),
        beta_k_tensor_TRADE,
        create_graph=True,  # keep the graph so loss_TRADE can be backpropagated into theta
    )[0]

    assert d_log_p_theta_d_c.shape == target.shape

    loss_TRADE = (target - d_log_p_theta_d_c).pow(2).mean()
    loss = nll + lambda_TS * loss_TRADE

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    loss_nll_storage[t] = nll.item()
    loss_TRADE_storage[t] = loss_TRADE.item()
    loss_total_storage[t] = loss.item()

    # Save a checkpoint after the first step and every `save_freq` steps.
    if t == 0 or (t + 1) % save_freq == 0:
        torch.save(model.state_dict(), os.path.join(folder, f"step_{t+1}_model.ckpt"))

Plot the training objective

---

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(10, 15))
start = 2      # skip the first iterations (presumably dominated by the initial loss)
end = len(loss_TRADE_storage)
stepsize = 1
n_iter_tensor = torch.arange(0, n_iter)

# About five evenly spaced x-ticks. A fixed stride of 50000 gives a single tick
# at 0 whenever n_iter <= 50000, which includes n_iter = 50000 used here.
tick_step = max(n_iter // 5, 1)
ticks = np.arange(0, n_iter + 1, tick_step)
tick_labels = [str(tick) for tick in ticks]

# (stored losses, panel title, y-axis label) for each subplot, top to bottom.
panels = [
    (loss_nll_storage, 'NLL', r"$\mathcal{L}_{nll}$"),
    (loss_TRADE_storage, 'TS', r"$\mathcal{L}_{TS}$"),
    (loss_total_storage,
     r'Objective $\mathcal{L}= \mathcal{L}_{nll} + \lambda \cdot \mathcal{L}_{TS}$',
     r"$\mathcal{L}$"),
]

for ax, (storage, title, ylabel) in zip(axes, panels):
    ax.plot(n_iter_tensor[start:end:stepsize], storage.detach().numpy()[start:end:stepsize])
    ax.set_title(title, fontsize=fs)
    ax.set_xlabel('iteration', fontsize=fs)
    ax.set_ylabel(ylabel, fontsize=fs)
    ax.tick_params(axis='x', labelsize=fs)
    ax.tick_params(axis='y', labelsize=fs)
    ax.set_xticks(ticks, tick_labels)

plt.tight_layout()

Compare the learned parameters to the ground truth

---

In [None]:
# Evaluate the learned parameter networks on the beta grid; no autograd graph is needed.
with torch.no_grad():
    sigma_pred = model.get_sigma(beta_eval.reshape(-1, 1)).flatten().abs().cpu()
    mu_pred = model.get_mean(beta_eval.reshape(-1, 1)).flatten().cpu()

# NOTE: rebinds the global `fs`; cells executed after this one use the smaller size.
fs = 15
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
beta_eval_cpu = beta_eval.cpu()

# The training data (samples of p_target) lives at beta_0; mark its std and mean.
y_train_sigma = float(p_target.stddev)
y_train_mu = float(p_target.mean)

for ax, target, pred, ylabel, y_train in (
    (axes[0], sigma_target, sigma_pred, r"$\sigma(c$)", y_train_sigma),
    (axes[1], mu_target, mu_pred, r"$\mu(c$)", y_train_mu),
):
    ax.set_xlabel(r"$c$", fontsize=fs)
    ax.set_ylabel(ylabel, fontsize=fs)

    ax.plot(beta_eval_cpu, target, label="target", ls="-", c="k", lw=3)
    ax.plot(beta_eval_cpu, pred, label="prediction", c="b", ls="-.", lw=4)

    ax.tick_params(axis='x', labelsize=fs)
    ax.tick_params(axis='y', labelsize=fs)
    ax.legend(fontsize=fs)

    # Mark the position of the training data (unlabelled, so it stays out of the legend).
    ax.plot([beta_0], [y_train], marker="o", ms=8, c="r")

plt.tight_layout()
plt.savefig(os.path.join(folder, 'Learned_parameters_1D_normal.pdf'))

Plot the learned probability density functions and compare them to the ground truth

---

In [None]:
beta_list = [0.2, 0.4, 0.6, 0.8, 1.0, 2.0, 3.0, 4.0]
x_eval = torch.linspace(-7, 7, 200).reshape(-1, 1).to(device)
x_plot = x_eval.flatten().cpu()

# Two columns and as many rows as needed for every entry of beta_list.
n_cols = 2
n_rows = (len(beta_list) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5 * n_rows), squeeze=False)

model.eval()

with torch.no_grad():
    for ax, beta in zip(axes.flat, beta_list):
        beta_tensor_i = torch.full_like(x_eval, beta)
        y_est_i = model(x_eval, beta_tensor_i).exp().flatten().cpu()

        # Ground truth for the N(0, 1) target at condition beta: N(0, 1/beta).
        p_ref = torch.distributions.Normal(0, 1 / np.sqrt(beta))
        y_target = p_ref.log_prob(x_eval.cpu()).exp().flatten().numpy()

        ax.plot(x_plot, y_target, label='Target', linewidth=3)
        ax.plot(x_plot, y_est_i, label='Estimated', ls="-.", linewidth=3)
        ax.set_title(r"$\beta$" + f' = {beta}', fontsize=fs)
        ax.set_ylim([0, 1.0])

        ax.set_xlabel('x', fontsize=fs)
        ax.set_ylabel(r'$p(x|\beta)$', fontsize=fs)

        ax.legend(fontsize=fs * 0.75)
        ax.tick_params(axis='x', labelsize=fs)
        ax.tick_params(axis='y', labelsize=fs)

# Hide unused panels when len(beta_list) is odd.
for ax in axes.flat[len(beta_list):]:
    ax.axis("off")

plt.tight_layout()
plt.savefig(os.path.join(folder, 'Learned_likelihoods_1D_normal.pdf'))