<a href="https://colab.research.google.com/github/HugoSenetaire/LatentEBMXSNL/blob/main/Copie_de_LEBM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Intro

In [None]:
import torch as t
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tfm

In [None]:
# Install Weights & Biases (wandb); it is used for logging (wandb.init / wandb.Image) further down.
!pip install wandb

# Data

In [None]:
import torch as t, torch.nn as nn
import torchvision as tv, torchvision.transforms as tfm

dataset = "SVHN"

save_image_every = 50
log_every = 10

# Every configuration trains on the same device, so choose it once.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

if dataset == "SVHN_original":
  # Longer Langevin chains (60 prior / 40 posterior steps).
  img_size, batch_size = 32, 256
  nz, nc, ndf, ngf = 100, 3, 200, 64
  K_0, a_0, K_1, a_1 = 60, 0.4, 40, 0.1
  llhd_sigma = 0.3
  n_iter = 70000
elif dataset == "SVHN":
  img_size, batch_size = 32, 256
  nz, nc, ndf, ngf = 100, 3, 200, 64
  K_0, a_0, K_1, a_1 = 20, 0.4, 20, 0.1
  llhd_sigma = 0.3
  n_iter = 70000
elif dataset == "MNIST":
  img_size, batch_size = 28, 256
  nz, nc, ndf, ngf = 16, 1, 200, 16
  K_0, a_0, K_1, a_1 = 20, 0.4, 20, 0.1
  llhd_sigma = 0.3
  n_iter = 70000
elif dataset == "CIFAR_10":
  # CIFAR-10 images are 32x32 RGB (the previous values were MNIST's 28px / 1 channel).
  # The other hyper-parameters are still identical to the MNIST configuration.
  img_size, batch_size = 32, 256
  nz, nc, ndf, ngf = 16, 3, 200, 16
  K_0, a_0, K_1, a_1 = 20, 0.4, 20, 0.1
  llhd_sigma = 0.3
  n_iter = 70000
else:
  # Fail here instead of with a NameError while building `cfg`.
  raise ValueError(f"Unknown dataset {dataset!r}")

# Snapshot of the run configuration (later extended and passed to wandb).
cfg = {
    "dataset": dataset,
    "img_size": img_size,
    "batch_size": batch_size,
    "nz": nz,
    "nc": nc,
    "ndf": ndf,
    "ngf": ngf,
    "K_0": K_0,
    "a_0": a_0,
    "K_1": K_1,
    "a_1": a_1,
    "llhd_sigma": llhd_sigma,
    "n_iter": n_iter,
    "device": device,
}



In [None]:
# Load the dataset's default split into a single tensor on `device`; ToTensor + Normalize(0.5, 0.5)
# maps pixels to [-1, 1]. Labels are discarded.
if dataset.startswith('SVHN'):
  transform = tfm.Compose([tfm.Resize(img_size), tfm.ToTensor(), tfm.Normalize(([0.5]*3), ([0.5]*3)),])
  train_set = tv.datasets.SVHN(download=True, root='data/svhn', transform=transform)
elif dataset == "MNIST":
  transform = tfm.Compose([tfm.Resize(img_size), tfm.ToTensor(), tfm.Normalize((0.5), (0.5),)])
  train_set = tv.datasets.MNIST(download=True, root='data/mnist', transform=transform)
else:
  # Previously `data` was silently left undefined and sampling failed much later.
  raise NotImplementedError(f"No data loader for dataset {dataset!r}")
data = t.stack([img for img, _ in train_set]).to(device)


# Network

In [None]:

class _G_MNIST(nn.Module):
    """Decoder from latents shaped (batch, nz, 1, 1) to (batch, nc, 28, 28) images in [-1, 1].

    Spatial size through the transposed convolutions: 1 -> 4 -> 7 -> 14 -> 28.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) of each upsampling layer.
        specs = [
            (nz, ngf * 8, 4, 1, 0),
            (ngf * 8, ngf * 4, 3, 2, 1),
            (ngf * 4, ngf * 2, 4, 2, 1),
            (ngf * 2, nc, 4, 2, 1),
        ]
        layers = []
        for idx, (c_in, c_out, kernel, stride, pad) in enumerate(specs):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad))
            # Hidden layers use LeakyReLU; the output layer squashes to [-1, 1].
            layers.append(nn.Tanh() if idx == len(specs) - 1 else nn.LeakyReLU())
        self.gen = nn.Sequential(*layers)

    def forward(self, z):
        return self.gen(z)

class _G_SVHN(nn.Module):
    """Decoder from latents shaped (batch, nz, 1, 1) to (batch, nc, 32, 32) images in [-1, 1].

    Spatial size through the transposed convolutions: 1 -> 4 -> 8 -> 16 -> 32.
    """
    def __init__(self):
        super().__init__()
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, nc]
        layers = [nn.ConvTranspose2d(widths[0], widths[1], 4, 1, 0), nn.LeakyReLU()]
        # Each following layer doubles the spatial size (kernel 4, stride 2, padding 1).
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            layers.append(nn.LeakyReLU() if c_out is not widths[-1] else nn.Tanh())
        self.gen = nn.Sequential(*layers)

    def forward(self, z):
        return self.gen(z)


class _Encoder_SVHN(nn.Module):
    """Convolutional encoder returning 2*nz values per image (later split into mean / log-variance).

    Three stride-2 convolutions reduce a 32x32 input to 4x4, hence the 16*ngf*8 flattened size.
    """
    def __init__(self,):
        super().__init__()

        channels = [nc, ngf * 2, ngf * 4, ngf * 8]
        blocks = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            blocks += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.3),
            ]
        self.conv_net = nn.Sequential(*blocks)
        # Single linear head (a deeper MLP head was tried and left out).
        self.fc = nn.Sequential(nn.Linear(16 * ngf * 8, 2 * nz))

    def forward(self, x):
        features = self.conv_net(x).flatten(1)
        return self.fc(features)

class _E_MNIST(nn.Module):
    """MLP energy over latent vectors; returns one scalar per example shaped (batch, 1, 1, 1)."""
    def __init__(self):
        super().__init__()
        self.ebm = nn.Sequential(nn.Linear(nz, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, 1))

    def forward(self, z):
        # flatten(1) keeps the batch axis; squeeze() also removed the latent axis when
        # nz == 1 and then fed the whole batch to the first Linear as one vector.
        return self.ebm(z.flatten(1)).view(-1, 1, 1, 1)

class _E_SVHN(nn.Module):
    """Latent EBM whose output also includes a fixed Gaussian base term.

    forward(z) returns ``f(z) - sum_i log N(z_i; mean, std)`` shaped (batch, 1, 1, 1),
    where ``f`` is the MLP ``self.ebm``. The base term is detached, so gradients with
    respect to ``z`` (and to the parameters) flow only through ``f``.
    """
    def __init__(self):
        super().__init__()
        # Non-trainable parameters of the base distribution (stored as Parameters, so they
        # also appear in parameters()/state_dict()).
        self.mean = nn.Parameter(t.tensor(0.,), requires_grad=False)
        self.std = nn.Parameter(t.tensor(1.,), requires_grad=False)
        self.ebm = nn.Sequential(nn.Linear(nz, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, ndf), nn.LeakyReLU(0.2),
            nn.Linear(ndf, 1))

    def forward(self, z):
        # flatten(1) keeps the batch axis even when nz == 1 (squeeze() did not).
        z_flat = z.flatten(1)
        energy = self.ebm(z_flat)
        log_base = t.distributions.Normal(self.mean, self.std).log_prob(z_flat).detach()
        log_base = log_base.sum(1, keepdim=True)
        return (energy - log_base).view(-1, 1, 1, 1)


class _Encoder_MNIST(nn.Module):
    """Convolutional encoder returning 2*nz values per image (later split into mean / log-variance).

    Three stride-2 convolutions reduce a 28x28 input to 4x4 (28 -> 14 -> 7 -> 4),
    hence the 16*ngf*8 flattened size.
    """
    def __init__(self,):
        super().__init__()

        widths = (nc, ngf * 2, ngf * 4, ngf * 8)
        modules = []
        for i in range(len(widths) - 1):
            modules.append(nn.Conv2d(in_channels=widths[i], out_channels=widths[i + 1],
                                     kernel_size=5, stride=2, padding=2))
            modules.append(nn.LeakyReLU(0.3))
        self.conv_net = nn.Sequential(*modules)
        # Single linear head (a deeper MLP head was tried and left out).
        self.fc = nn.Sequential(nn.Linear(16 * ngf * 8, 2 * nz))

    def forward(self, x):
        h = self.conv_net(x)
        return self.fc(h.flatten(1))

# Bind the generic network names to the family matching the configured dataset.
if dataset.startswith("SVHN"):
  _G, _Encoder, _E = _G_SVHN, _Encoder_SVHN, _E_SVHN
elif dataset == "MNIST":
  _G, _Encoder, _E = _G_MNIST, _Encoder_MNIST, _E_MNIST
else:
  raise NotImplementedError()


# Utils

## Sample

In [None]:

def sample_p_data():
    """Return ``batch_size`` rows of ``data`` drawn uniformly with replacement, detached."""
    n_rows = data.size(0)
    picks = t.LongTensor(batch_size).random_(0, n_rows)
    batch = data[picks]
    return batch.detach()

def sample_p_0(n=batch_size):
    """Draw ``n`` latents from N(0, I), shaped (n, nz, 1, 1) and moved to ``device``.

    The default ``n`` is the value ``batch_size`` had when this cell was executed.
    """
    latent_shape = (n, nz, 1, 1)
    return t.randn(latent_shape).to(device)

def sample_langevin_prior(z, E, n_steps=None, step_size=None):
    """Run short-run Langevin dynamics for the latent prior, starting from ``z``.

    Each step moves ``z`` along ``-(dE/dz + z)`` plus Gaussian noise, where ``z`` is the
    gradient of the standard-normal negative log-density (the base ``sample_p_0`` draws from).

    Args:
        z: initial latents; they are cloned, not modified.
        E: energy network; ``E(z).sum()`` must be differentiable with respect to ``z``.
        n_steps: number of Langevin steps; defaults to the global ``K_0``.
        step_size: Langevin step size; defaults to the global ``a_0``.

    Returns:
        The detached latents after ``n_steps`` updates.
    """
    n_steps = K_0 if n_steps is None else n_steps
    step_size = a_0 if step_size is None else step_size
    z = z.clone().detach().requires_grad_(True)
    for _ in range(n_steps):
        en = E(z)
        z_grad = t.autograd.grad(en.sum(), z)[0]
        # Prior drift is +z (gradient of -log N(z; 0, I)); the previous 1/z diverged near 0.
        z.data = z.data - 0.5 * step_size * step_size * (z_grad + z.data) + step_size * t.randn_like(z).data
    return z.detach()

def sample_langevin_posterior(z, x, G, E, n_steps=None, step_size=None, sigma=None):
    """Run short-run Langevin dynamics for the latent posterior given observations ``x``.

    The drift is ``-(dL/dz + dE/dz + z)``, where ``L = ||G(z) - x||^2 / (2 sigma^2)`` is the
    Gaussian negative log-likelihood summed over the batch, and ``z`` is the gradient of the
    standard-normal negative log-density of the base distribution.

    Args:
        z: initial latents; they are cloned, not modified.
        x: observed batch that ``G(z)`` is compared against.
        G: decoder; E: energy network (both must be differentiable in ``z``).
        n_steps: number of steps; defaults to the global ``K_1``.
        step_size: Langevin step size; defaults to the global ``a_1``.
        sigma: likelihood standard deviation; defaults to the global ``llhd_sigma``.

    Returns:
        The detached latents after ``n_steps`` updates.
    """
    n_steps = K_1 if n_steps is None else n_steps
    step_size = a_1 if step_size is None else step_size
    sigma = llhd_sigma if sigma is None else sigma
    z = z.clone().detach().requires_grad_(True)
    for _ in range(n_steps):
        x_hat = G(z)
        # Same value as MSELoss(reduction='sum') scaled by 1 / (2 sigma^2).
        g_log_lkhd = 1.0 / (2.0 * sigma * sigma) * ((x_hat - x) ** 2).sum()
        grad_g = t.autograd.grad(g_log_lkhd, z)[0]
        en = E(z)
        grad_e = t.autograd.grad(en.sum(), z)[0]
        # Prior drift is +z (gradient of -log N(z; 0, I)); the previous 1/z diverged near 0.
        z.data = z.data - 0.5 * step_size * step_size * (grad_g + grad_e + z.data) + step_size * t.randn_like(z).data
    return z.detach()






## Clipping Gradients

In [None]:
import torch



def clip_grad_adam(parameters, optimizer, nb_sigmas = 3):
    """Clamp each gradient element to +/- (nb_sigmas * Adam's bias-corrected RMS + 0.1).

    The bound uses the optimizer's ``exp_avg_sq`` state. Parameters without a gradient,
    or that the optimizer has not stepped yet, are left untouched. ``parameters`` is not
    read: the parameter list comes from ``optimizer.param_groups``.
    """
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group['params']:
                grad = param.grad
                if grad is None or grad.data is None:
                    continue
                state = optimizer.state[param]
                has_history = 'step' in state and not state['step'] < 1
                if not has_history:
                    continue
                _, beta2 = group['betas']
                n_steps = state['step']
                # Bias-corrected second-moment estimate -> per-element standard deviation.
                rms = torch.sqrt(state['exp_avg_sq'] / (1 - beta2 ** n_steps))
                limit = nb_sigmas * rms + 0.1
                clipped = torch.max(torch.min(grad.data, limit), -limit)
                grad.data.copy_(clipped)


def grad_clipping(net, net_name, cfg, current_optim, logger, step):
    """Clip the gradients of ``net`` in place according to ``cfg[net_name + "_clip_grad_*"]``.

    Modes (``cfg[net_name + "_clip_grad_type"]``):
      * "norm": clip_grad_norm_ to ``_clip_grad_value`` (skipped when the value is None);
      * "abs":  clip_grad_value_ to ``_clip_grad_value`` (skipped when the value is None);
      * "adam": clip_grad_adam with ``_nb_sigma`` sigmas (skipped when it is None);
      * None:   do nothing.
    The threshold used is logged at ``step``. Any other mode raises NotImplementedError.
    """
    mode = cfg[net_name + "_clip_grad_type"]
    threshold = cfg[net_name + "_clip_grad_value"]
    n_sigmas = cfg[net_name + "_nb_sigma"]

    if mode is None:
        return
    if mode == "norm":
        if threshold is None:
            return
        logger.log({f"train/{net_name}_clip_grad_norm": threshold}, step=step)
        torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=threshold)
    elif mode == "abs":
        if threshold is None:
            return
        logger.log({f"train/{net_name}_clip_grad_abs": threshold}, step=step)
        torch.nn.utils.clip_grad_value_(parameters=net.parameters(), clip_value=threshold)
    elif mode == "adam":
        if n_sigmas is None:
            return
        logger.log({f"train/{net_name}_clip_grad_adam_nb_sigmas": n_sigmas}, step=step)
        clip_grad_adam(net.parameters(), current_optim, nb_sigmas=n_sigmas)
    else:
        raise NotImplementedError

def grad_clipping_all_net(liste_network = None, liste_name = None, liste_optim = None, logger = None, cfg =None, step=None):
    """Apply ``grad_clipping`` to each (network, name, optimizer) triple, in order.

    Args:
        liste_network: networks whose gradients are clipped in place.
        liste_name: config prefixes ("E", "G", "Encoder", ...) matching ``liste_network``.
        liste_optim: optimizers matching ``liste_network`` (used by the "adam" mode).
        logger, cfg, step: forwarded unchanged to ``grad_clipping``.

    Raises:
        ValueError: if the three lists have different lengths (zip would silently drop entries).
    """
    # None defaults instead of shared mutable [] defaults.
    liste_network = [] if liste_network is None else liste_network
    liste_name = [] if liste_name is None else liste_name
    liste_optim = [] if liste_optim is None else liste_optim
    if not (len(liste_network) == len(liste_name) == len(liste_optim)):
        raise ValueError(
            "grad_clipping_all_net: got {} networks, {} names and {} optimizers".format(
                len(liste_network), len(liste_name), len(liste_optim)))
    for net, net_name, optim in zip(liste_network, liste_name, liste_optim):
        grad_clipping(net, net_name, cfg, optim, logger, step=step)





## Regularization

In [None]:
from torch.autograd import grad as torch_grad

def wgan_gradient_penalty(ebm, x, x_gen,):
    """Return the batch-mean squared L2 norm of the energy gradient at random interpolates.

    The first ``min(len(x), len(x_gen))`` pairs are mixed with one uniform weight per
    example. The energy is evaluated at those points and summed per example. The result
    is the mean over examples of the squared gradient norm (no square root is taken). The
    graph is kept (``create_graph=True``) so the penalty can be backpropagated into ``ebm``.

    ``ebm.f_theta`` is used when the model defines it. Otherwise the model itself is called:
    the energy networks in this notebook only define ``forward``, so the previous
    unconditional ``ebm.f_theta`` raised AttributeError.
    """
    n = min(x.size(0), x_gen.size(0))
    # One mixing weight per example, broadcast over all remaining dimensions.
    epsilon = torch.rand(n, device=x.device).view(n, *([1] * (x.dim() - 1)))
    interpolated = epsilon * x.detach()[:n] + (1 - epsilon) * x_gen.detach()[:n]
    interpolated = interpolated.detach().requires_grad_(True)

    energy_fn = getattr(ebm, "f_theta", ebm)
    energy = energy_fn(interpolated).flatten(1).sum(1)

    gradients = torch_grad(outputs=energy, inputs=interpolated,
                           grad_outputs=torch.ones_like(energy),
                           create_graph=True, retain_graph=True)[0]
    return gradients.reshape(n, -1).pow(2).sum(dim=1).mean()

def regularization(ebm, x, x_gen, energy_data, energy_samples, cfg, logger, step):
    """Compute the enabled EBM penalties and return them, already weighted, in a dict.

    A term is enabled when its weight in ``cfg`` is not None and strictly positive:
      * ``cfg["l2_grad"]``   -> key "l2_grad": weight * wgan_gradient_penalty(ebm, x, x_gen)
      * ``cfg["l2_output"]`` -> key "loss_l2_output":
        weight * (mean(energy_data**2) + mean(energy_samples**2))
      * ``cfg["l2_param"]``  -> key "loss_l2_param": weight * mean squared value of all
        entries of ``ebm.parameters()``
    Raw and weighted values are logged through ``logger.log`` at ``step``. Nothing is
    added to any loss here; the caller decides what to do with the returned terms.
    """
    def enabled(name):
        weight = cfg[name]
        return weight is not None and weight > 0

    terms = {}

    if enabled("l2_grad"):
        grad_norm = wgan_gradient_penalty(ebm, x, x_gen)
        terms["l2_grad"] = cfg["l2_grad"] * grad_norm
        logger.log({"penalty/grad_norm": grad_norm}, step=step)
        logger.log({"penalty/regularization_l2_grad": terms["l2_grad"]}, step=step)

    if enabled("l2_output"):
        output_sq = (energy_data ** 2).mean() + (energy_samples ** 2).mean()
        terms["loss_l2_output"] = cfg["l2_output"] * output_sq
        logger.log({"penalty/l2_output": output_sq}, step=step)
        logger.log({"penalty/regularization_l2_output": terms["loss_l2_output"]}, step=step)

    if enabled("l2_param"):
        params = list(ebm.parameters())
        n_entries = float(sum(p.numel() for p in params))
        param_sq = sum(torch.sum(p ** 2) for p in params) / n_entries
        terms["loss_l2_param"] = cfg["l2_param"] * param_sq
        logger.log({"penalty/l2_param": param_sq}, step=step)
        logger.log({"penalty/regularization_l2_param": terms["loss_l2_param"]}, step=step)

    return terms



## Logger

In [None]:


from IPython import display
import torch

# Module-level scratch dictionary; it appears unused in the code shown here.
global_dic_error = {}

def log(step, dic_loss, logger):
  """Send each ``dic_loss`` entry to ``logger.log`` as its own ``{key: value}`` record at ``step``."""
  for key in dic_loss:
    logger.log({key: dic_loss[key]}, step=step)


def draw_samples(fig, axs, prior_0, langevin_prior, posterior, approximate_posterior, step, dic_loss, logger):
  """Plot image grids, log the figure to wandb and refresh the notebook display.

  Images are mapped with ``x / 2 + 0.5`` before plotting (presumably decoder outputs
  in [-1, 1]). Panels: axs[0] base-distribution samples, axs[1] Langevin prior samples,
  axs[2] posterior samples and, when ``approximate_posterior`` is not None, axs[3]
  approximate-posterior samples. Every ``dic_loss`` entry is also logged through
  ``logger`` and appended to the figure title.
  """
  panels = [(prior_0, "BaseDistribution"), (langevin_prior, "Prior"), (posterior, "Posterior")]
  if approximate_posterior is not None:
    panels.append((approximate_posterior, "Approximate Posterior"))
  for ax, (images, title) in zip(axs, panels):
    # Clear first: otherwise each call stacks another image on the axes, so memory
    # use and redraw time grow over a long run.
    ax.clear()
    grid = tv.utils.make_grid(images / 2 + 0.5)
    ax.imshow(grid.detach().cpu().permute(1, 2, 0).numpy())
    ax.set_title(title)

  img = wandb.Image(fig, caption=f"Step {step}")
  logger.log({"All_samples.png": img}, step=step)

  # Average wall-clock time per step since the previous call (or since `start_time` on
  # the first call). The old value was the total elapsed time divided by 10.
  now = time.time()
  last_time, last_step = getattr(draw_samples, "_last_call", (start_time, 0))
  n_steps = max(step - last_step, 1)
  draw_samples._last_call = (now, step)

  title = "Step:{}, time per step : {:2.3f}".format(step, (now - last_time) / n_steps)
  for key, value in dic_loss.items():
    title += ", {} : {:2.3f}".format(key, value)
    logger.log({key: value}, step=step)
  fig.suptitle(title)
  display.display(fig)
  display.clear_output(wait=True)





# Training procedures

In [None]:

def train_cd(x, G,E, loss, optE, optG, logger, cfg):
  """Train decoder ``G`` and latent EBM ``E`` with short-run Langevin MCMC.

  Each of the global ``n_iter`` iterations:
    * draws a data batch and two N(0, I) latent batches;
    * runs prior Langevin (E only) and posterior Langevin (G and E) from them;
    * updates G on the summed squared reconstruction error of the posterior latents,
      divided by ``batch_size``;
    * updates E on ``mean E(posterior) - mean E(prior)`` plus the penalties returned by
      ``regularization``.

  ``x`` and ``loss`` are unused: a fresh batch is drawn every iteration and the
  module-level ``mse`` is used for reconstruction.
  """
  for i in range(n_iter):
      x = sample_p_data()
      z_e_0, z_g_0 = sample_p_0(), sample_p_0()
      z_e_k, z_g_k = sample_langevin_prior(z_e_0, E), sample_langevin_posterior(z_g_0, x, G, E)

      optG.zero_grad()
      x_hat = G(z_g_k.detach())
      loss_g = mse(x_hat, x) / batch_size
      loss_g.backward()
      grad_clipping_all_net([G], ["G"], [optG], logger, cfg, i)
      optG.step()

      optE.zero_grad()
      en_pos, en_neg = E(z_g_k.detach()).mean(), E(z_e_k.detach()).mean()
      # regularization() only returns its weighted penalties; they must be added here,
      # otherwise the configured l2_* weights have no effect on training.
      reg_terms = regularization(E, z_g_k, z_e_k, en_pos, en_neg, cfg, logger, i)
      loss_e = en_pos - en_neg + sum(reg_terms.values())
      loss_e.backward()
      grad_clipping_all_net([E], ["E"], [optE], logger, cfg, i)
      optE.step()
      dic_loss = {
          "loss_e": loss_e.item(),
          "loss_g": loss_g.item(),
          "en_pos": en_pos.item(),
          "en_neg": en_neg.item(),
        }

      if i % log_every == 0:
        log(i, dic_loss, logger)
      if i % save_image_every == 0:
        # Decoding only for visualization: no autograd graph is needed.
        with t.no_grad():
          x_prior_langevin = G(z_e_k)
          x_prior = G(z_e_0)
        draw_samples(fig, axs, x_prior[:64], x_prior_langevin[:64], x_hat[:64], None, i, dic_loss, logger)




In [None]:

def train_cd_trick(n_iter, G,E, loss, optE, optG, logger):
  """Train G with posterior Langevin and E with a proposal-based normalizer estimate.

  Per iteration, G is updated exactly as in ``train_cd``. E is updated on
      mean(E(z_post) - log proposal(z_post)) + Z_hat - 1 + regularization terms,
  where ``log Z_hat`` is the log-mean-exp of ``-E`` over samples from the Gaussian
  proposal N(cfg["proposal_mean"], cfg["proposal_std"]). Reads the globals ``cfg``
  and ``device``; ``loss`` is unused (the module-level ``mse`` is used).
  """
  import math  # otherwise only imported by a later notebook cell

  # Explicit float32: integer cfg values (e.g. proposal_mean=0) would give int64
  # loc/scale tensors, which Normal.sample (torch.normal) cannot draw from.
  proposal = torch.distributions.normal.Normal(
      torch.tensor(cfg["proposal_mean"], device=device, dtype=torch.float32),
      torch.tensor(cfg["proposal_std"], device=device, dtype=torch.float32))
  base_dist = torch.distributions.normal.Normal(
      torch.tensor(0.0, device=device), torch.tensor(1.0, device=device))

  for i in range(n_iter):
      x = sample_p_data()
      z_e_0, z_g_0 = sample_p_0(), sample_p_0()
      z_e_k, z_g_k = sample_langevin_prior(z_e_0, E), sample_langevin_posterior(z_g_0, x, G, E)

      optG.zero_grad()
      x_hat = G(z_g_k.detach())
      loss_g = mse(x_hat, x) / batch_size
      loss_g.backward()
      grad_clipping_all_net([G], ["G"], [optG], logger, cfg, i)
      optG.step()

      optE.zero_grad()
      energy_posterior = E(z_g_k.detach()).flatten(1).sum(1)
      z_proposal = proposal.sample(z_g_k.shape)
      energy_proposal = E(z_proposal.detach()).flatten(1).sum(1)
      # Log-densities under the base and proposal distributions; apart from
      # proposal_z_posterior (used in loss_e) they are only logged.
      base_dist_z_proposal = base_dist.log_prob(z_proposal.flatten(1)).sum(1)
      base_dist_z_posterior = base_dist.log_prob(z_g_k.flatten(1)).sum(1)
      base_dist_z_base_dist = base_dist.log_prob(z_e_0.flatten(1)).sum(1)
      proposal_z_proposal = proposal.log_prob(z_proposal.flatten(1)).sum(1)
      proposal_z_posterior = proposal.log_prob(z_g_k.flatten(1)).sum(1)
      proposal_z_base_dist = proposal.log_prob(z_e_0.flatten(1)).sum(1)

      log_partition_estimate = torch.logsumexp(-energy_proposal, 0) - math.log(energy_proposal.shape[0])
      # regularization() only returns its penalties; add them so they affect training.
      reg_terms = regularization(E, z_g_k, z_proposal, energy_posterior, energy_proposal, cfg, logger, i)
      loss_e = ((energy_posterior - proposal_z_posterior).mean() + log_partition_estimate.exp() - 1
                + sum(reg_terms.values()))
      loss_e.backward()
      grad_clipping_all_net([E], ["E"], [optE], logger, cfg, i)
      optE.step()
      dic_loss = {
          "loss_e": loss_e.mean().item(),
          "loss_g": loss_g.mean().item(),
          "base_dist_z_proposal": base_dist_z_proposal.mean().item(),
          "base_dist_z_posterior": base_dist_z_posterior.mean().item(),
          "base_dist_z_base_dist": base_dist_z_base_dist.mean().item(),
          "proposal_z_proposal": proposal_z_proposal.mean().item(),
          "proposal_z_posterior": proposal_z_posterior.mean().item(),
          "proposal_z_base_dist": proposal_z_base_dist.mean().item(),
          "en_pos": energy_posterior.mean().item(),
          "en_neg": energy_proposal.mean().item(),
          "log_z": log_partition_estimate.item(),
      }

      if i % log_every == 0:
        log(i, dic_loss, logger)
      if i % save_image_every == 0:
        # Decoding only for visualization: no autograd graph is needed.
        with torch.no_grad():
          x_prior_langevin = G(z_e_k)
          x_prior = G(z_e_0)
        draw_samples(fig, axs, x_prior[:64], x_prior_langevin[:64], x_hat[:64], None, i, dic_loss, logger)




In [None]:
def train_auto_encoder(n_tier, G, E, Encoder, loss, optE, optG, optEncoder, logger):
  """Train ``Encoder`` and ``G`` as a deterministic auto-encoder for ``n_tier`` iterations.

  ``n_tier`` is the number of iterations (the name is kept for compatibility; previously
  it was ignored and the global ``n_iter`` was used). Each iteration encodes a data
  batch, decodes the encoder mean ``mu_q`` and minimizes the summed squared error
  divided by ``batch_size``. ``E``/``optE`` are not part of this objective, and
  ``loss`` is unused (the module-level ``mse`` is used).
  """
  for i in range(n_tier):
    optG.zero_grad()
    optE.zero_grad()
    optEncoder.zero_grad()

    x = sample_p_data()
    mu_q, _ = Encoder(x).chunk(2, 1)
    x_hat = G(mu_q.reshape(-1, nz, 1, 1))

    loss_g = mse(x_hat, x) / batch_size
    loss_g.backward()
    optG.step()
    # The encoder gets gradients through mu_q; without this step it was never updated.
    optEncoder.step()
    dic_loss = {"mse": loss_g.item(),
                "mu_q_std": mu_q.flatten(1).std(1).mean().item()}

    if i % log_every == 0:
        log(i, dic_loss, logger)
    if i % save_image_every == 0:
        draw_samples(fig, axs, x_hat[:64], x_hat[:64], x_hat[:64], x_hat[:64], i, dic_loss, logger)










In [None]:


def train_elbo_notrick(n_iter, G, E, Encoder, loss, optE, optG, optEncoder, logger):
  """Train ``Encoder`` and ``G`` as a Gaussian VAE with a fixed diagonal Gaussian prior.

  Minimizes reconstruction (summed squared error / batch_size / llhd_sigma**2) plus the
  closed-form KL(q(z|x) || N(0, exp(log_var_p))). ``E`` is not part of the loss; it is
  only used by the Langevin samplers when images are drawn. ``loss`` is unused (the
  module-level ``mse`` is used).
  """
  log_var_p = None
  for i in range(n_iter):
      optG.zero_grad()
      optE.zero_grad()
      optEncoder.zero_grad()

      x = sample_p_data()
      z_e_0, z_g_0 = sample_p_0(), sample_p_0()

      mu_q, log_var_q = Encoder(x).chunk(2, 1)
      if log_var_p is None:
        # log(sigma^2); the previous torch.log(...).pow(2) computed (log sigma)^2.
        # NOTE(review): latents elsewhere come from N(0, I) (sample_p_0), while this prior
        # uses llhd_sigma as its std; confirm which prior is intended.
        log_var_p = torch.log(torch.full_like(log_var_q, llhd_sigma).pow(2))

      # Reparameterization trick.
      std_q = torch.exp(0.5 * log_var_q)
      eps = torch.randn_like(mu_q)
      z_q = (eps.mul(std_q).add_(mu_q)).reshape(-1, nz, 1, 1)

      x_hat = G(z_q)
      loss_g = (mse(x_hat, x) / batch_size) / (llhd_sigma ** 2)
      KL_loss = 0.5 * (log_var_p - log_var_q - 1 + (log_var_q.exp() + mu_q.pow(2)) / log_var_p.exp())
      KL_loss = KL_loss.sum(dim=1).mean(dim=0)

      loss_total = loss_g + KL_loss
      loss_total.backward()

      dic_loss = {
          "loss_g": loss_g.item(),
          "KL_loss": KL_loss.item(),
          "elbo": -loss_total.item(),
      }
      optE.step()
      optG.step()
      optEncoder.step()

      if i % log_every == 0:
        log(i, dic_loss, logger)
      if i % save_image_every == 0:
        # Langevin sampling needs autograd; the decoding that follows does not.
        z_e_k, z_g_k = sample_langevin_prior(z_e_0, E), sample_langevin_posterior(z_g_0, x, G, E)
        with torch.no_grad():
          x_prior_langevin = G(z_e_k)
          x_prior = G(z_e_0)
          x_posterior = G(z_g_k)
          x_approximate_posterior = G(z_q)
        draw_samples(fig, axs, x_prior[:64], x_prior_langevin[:64], x_posterior[:64], x_approximate_posterior[:64], i, dic_loss, logger)




In [None]:
import math
def train_elbo_withtrick(n_iter, G, E, Encoder, loss, optE, optG, optEncoder, cfg, logger= None):
  """Train Encoder, G and E on an ELBO surrogate that uses the base distribution as proposal.

  Because the N(0, I) base distribution acts as the proposal, the posterior's entropy can
  be computed in closed form and the EBM loss stays on the same scale as the other terms.
  Per iteration the minimized objective is
      reconstruction - entropy(q) + [mean(E(z_q) - log N(z_q)) + Z_hat - 1] + regularization,
  where ``log Z_hat`` is the log-mean-exp of ``-E(z) - log N(z)`` over base samples.
  E is always stepped; G unless ``cfg["fix_decoder"]``, Encoder unless ``cfg["fix_encoder"]``.
  ``KL_loss`` (Gaussian prior, no EBM) is only logged. ``loss`` is unused (``mse`` is used).
  """
  fix_encoder = cfg["fix_encoder"]
  fix_decoder = cfg["fix_decoder"]
  device = next(G.parameters()).device
  base_dist = torch.distributions.normal.Normal(torch.tensor(0, device=device, dtype=torch.float32), torch.tensor(1, device=device, dtype=torch.float32))

  for i in range(n_iter):
      optG.zero_grad()
      optE.zero_grad()
      optEncoder.zero_grad()
      x = sample_p_data()

      z_e_0, z_g_0 = sample_p_0(), sample_p_0()
      mu_q, log_var_q = Encoder(x).chunk(2, 1)
      # log(sigma^2); the previous torch.log(...).pow(2) computed (log sigma)^2.
      log_var_p = torch.log(torch.full_like(log_var_q, llhd_sigma).pow(2))
      std_q = torch.exp(0.5 * log_var_q)

      # Reparameterization trick.
      eps = torch.randn_like(mu_q)
      z_q = (eps.mul(std_q).add_(mu_q)).reshape(-1, nz, 1, 1)
      x_hat = G(z_q)

      # Gaussian reconstruction loss.
      loss_g = (mse(x_hat, x) / batch_size) / (llhd_sigma ** 2)

      # KL to a Gaussian prior without the EBM (logged only).
      KL_loss = 0.5 * (log_var_p - log_var_q - 1 + (log_var_q.exp() + mu_q.pow(2)) / log_var_p.exp())
      KL_loss = KL_loss.sum(dim=1).mean(dim=0)

      # Closed-form entropy of the diagonal Gaussian posterior.
      entropy_posterior = torch.sum(0.5 * (math.log(2 * math.pi) + log_var_q + 1), dim=1).mean()

      # Energies.
      energy_approximate = E(z_q).flatten(1).sum(1)
      energy_base_dist = E(z_e_0).flatten(1).sum(1)

      base_dist_z_approximate = base_dist.log_prob(z_q.flatten(1)).sum(1)
      base_dist_z_base_dist = base_dist.log_prob(z_e_0.flatten(1)).sum(1)

      log_partition_estimate = torch.logsumexp(-energy_base_dist - base_dist_z_base_dist, 0) - math.log(energy_base_dist.shape[0])
      loss_ebm = (energy_approximate - base_dist_z_approximate).mean() + log_partition_estimate.exp() - 1

      surrogate = loss_g - entropy_posterior + loss_ebm
      # regularization() only returns its penalties; add them so they affect training.
      reg_terms = regularization(E, z_q, z_e_0, energy_approximate, energy_base_dist, cfg, logger, i)
      loss_total = surrogate + sum(reg_terms.values())
      loss_total.backward()
      grad_clipping_all_net([E, G, Encoder], ["E", "G", "Encoder"], [optE, optG, optEncoder], logger, cfg, i)

      dic_loss = {
          "loss_g": loss_g.item(),
          "entropy_posterior": entropy_posterior.item(),
          "loss_ebm": loss_ebm.item(),
          "base_dist_z_approximate": base_dist_z_approximate.mean().item(),
          "base_dist_z_base_dist": base_dist_z_base_dist.mean().item(),
          "log_Z": log_partition_estimate.item(),
          "KL_loss_no_ebm": KL_loss.item(),
          "energy_approximate": energy_approximate.mean().item(),
          "energy_base_dist": energy_base_dist.mean().item(),
          # Surrogate ELBO without the regularization penalties.
          "approx_elbo": -surrogate.item(),
      }

      optE.step()
      if not fix_decoder:
        optG.step()
      if not fix_encoder:
        optEncoder.step()

      if i % log_every == 0:
        log(i, dic_loss, logger)
      if i % save_image_every == 0:
        # Langevin sampling needs autograd; the decoding that follows does not.
        z_e_k, z_g_k = sample_langevin_prior(z_e_0, E), sample_langevin_posterior(z_g_0, x, G, E)
        with torch.no_grad():
          x_prior_langevin = G(z_e_k)
          x_prior = G(z_e_0)
          x_posterior = G(z_g_k)
          x_approximate_posterior = G(z_q)
        draw_samples(fig, axs, x_prior[:64], x_prior_langevin[:64], x_posterior[:64], x_approximate_posterior[:64], i, dic_loss, logger)


In [None]:
import math
def train_elbo_withtrick_v2(n_iter, G, E, Encoder, loss, optE, optG, optEncoder, cfg, logger= None):
  """ELBO trainer whose EBM log-partition is estimated from Gaussian proposal samples.

  Per iteration the minimized objective is
      reconstruction + KL(q || N(0, exp(log_var_p))) + [mean E(z_q) + Z_hat - 1] + regularization,
  where ``log Z_hat`` is the log-mean-exp of ``-E(z) - log proposal(z)`` over samples from
  N(cfg["proposal_mean"], cfg["proposal_std"]). E is always stepped; G unless
  ``cfg["fix_decoder"]``, Encoder unless ``cfg["fix_encoder"]``. ``loss`` is unused (``mse`` is used).
  """
  fix_encoder = cfg["fix_encoder"]
  fix_decoder = cfg["fix_decoder"]
  device = next(G.parameters()).device
  # float32 for both distributions: integer values would create int64 loc/scale tensors.
  base_dist = torch.distributions.normal.Normal(torch.tensor(0, device=device, dtype=torch.float32), torch.tensor(1, device=device, dtype=torch.float32))
  proposal = torch.distributions.normal.Normal(torch.tensor(cfg["proposal_mean"], device=device, dtype=torch.float32), torch.tensor(cfg["proposal_std"], device=device, dtype=torch.float32))

  for i in range(n_iter):
      optG.zero_grad()
      optE.zero_grad()
      optEncoder.zero_grad()
      x = sample_p_data()

      z_e_0, z_g_0 = sample_p_0(), sample_p_0()
      mu_q, log_var_q = Encoder(x).chunk(2, 1)
      # log(sigma^2); the previous torch.log(...).pow(2) computed (log sigma)^2,
      # which entered the optimized KL below.
      log_var_p = torch.log(torch.full_like(log_var_q, llhd_sigma).pow(2))
      std_q = torch.exp(0.5 * log_var_q)

      # Reparameterization trick.
      eps = torch.randn_like(mu_q)
      z_q = (eps.mul(std_q).add_(mu_q)).reshape(-1, nz, 1, 1)
      x_hat = G(z_q)

      # Gaussian reconstruction loss.
      loss_g = (mse(x_hat, x) / batch_size) / (llhd_sigma ** 2)

      # KL to a Gaussian prior without the EBM.
      KL_loss = 0.5 * (log_var_p - log_var_q - 1 + (log_var_q.exp() + mu_q.pow(2)) / log_var_p.exp())
      KL_loss = KL_loss.sum(dim=1).mean(dim=0)

      # Closed-form entropy of the diagonal Gaussian posterior (logged only).
      entropy_posterior = torch.sum(0.5 * (math.log(2 * math.pi) + log_var_q + 1), dim=1).mean()

      # Energies on approximate-posterior, proposal and base samples.
      z_proposal = proposal.sample((z_e_0.shape[0], nz, 1, 1))
      energy_approximate = E(z_q).flatten(1).sum(1)
      energy_proposal = E(z_proposal).flatten(1).sum(1)
      energy_prior = E(z_e_0).flatten(1).sum(1)

      base_dist_z_proposal = base_dist.log_prob(z_proposal.flatten(1)).sum(1)
      base_dist_z_posterior = base_dist.log_prob(z_q.flatten(1)).sum(1)
      base_dist_z_base_dist = base_dist.log_prob(z_e_0.flatten(1)).sum(1)
      proposal_z_proposal = proposal.log_prob(z_proposal.flatten(1)).sum(1)
      proposal_z_posterior = proposal.log_prob(z_q.flatten(1)).sum(1)
      proposal_z_base_dist = proposal.log_prob(z_e_0.flatten(1)).sum(1)

      log_partition_estimate = torch.logsumexp(-energy_proposal - proposal_z_proposal, 0) - math.log(energy_proposal.shape[0])
      loss_ebm = energy_approximate.mean() + log_partition_estimate.exp() - 1

      surrogate = loss_g + KL_loss + loss_ebm
      # regularization() only returns its penalties; add them so they affect training.
      reg_terms = regularization(E, z_q, z_e_0, energy_approximate, energy_proposal, cfg, logger, i)
      loss_total = surrogate + sum(reg_terms.values())
      loss_total.backward()
      grad_clipping_all_net([E, G, Encoder], ["E", "G", "Encoder"], [optE, optG, optEncoder], logger, cfg, i)

      dic_loss = {
          "loss_g": loss_g.item(),
          "entropy_posterior": entropy_posterior.item(),
          "loss_ebm": loss_ebm.item(),
          "log_Z": log_partition_estimate.item(),
          "KL_loss_no_ebm": KL_loss.item(),
          "energy_approximate": energy_approximate.mean().item(),
          "energy_proposal": energy_proposal.mean().item(),
          "energy_prior": energy_prior.mean().item(),
          "base_dist_z_proposal": base_dist_z_proposal.mean().item(),
          "base_dist_z_posterior": base_dist_z_posterior.mean().item(),
          "base_dist_z_base_dist": base_dist_z_base_dist.mean().item(),
          "proposal_z_proposal": proposal_z_proposal.mean().item(),
          "proposal_z_posterior": proposal_z_posterior.mean().item(),
          "proposal_z_base_dist": proposal_z_base_dist.mean().item(),
          # Surrogate ELBO without the regularization penalties.
          "approx_elbo": -surrogate.item(),
      }

      optE.step()
      if not fix_decoder:
        optG.step()
      if not fix_encoder:
        optEncoder.step()

      if i % log_every == 0:
        log(i, dic_loss, logger)
      if i % save_image_every == 0:
        # Langevin sampling needs autograd; the decoding that follows does not.
        z_e_k, z_g_k = sample_langevin_prior(z_e_0, E), sample_langevin_posterior(z_g_0, x, G, E)
        with torch.no_grad():
          x_prior_langevin = G(z_e_k)
          x_prior = G(z_e_0)
          x_posterior = G(z_g_k)
          x_approximate_posterior = G(z_q)
        draw_samples(fig, axs, x_prior[:64], x_prior_langevin[:64], x_posterior[:64], x_approximate_posterior[:64], i, dic_loss, logger)




# Train

In [None]:

# Per-network gradient clipping: type is "norm", "abs", "adam" or None (see grad_clipping).
cfg_clipping = dict(
    E_clip_grad_type="norm",
    E_clip_grad_value=0.5,
    E_nb_sigma=3,

    G_clip_grad_type="norm",
    G_clip_grad_value=0.5,
    G_nb_sigma=3,

    Encoder_clip_grad_type="norm",
    Encoder_clip_grad_value=0.5,
    Encoder_nb_sigma=3,
)

# Weights of the EBM penalties computed by regularization (0 or None disables a term).
cfg_regularization = dict(l2_grad=0.0, l2_param=1.0, l2_output=0.0)

# Mean / std of the Gaussian proposal read by the proposal-based trainers.
cfg_proposal = dict(proposal_mean=0.0, proposal_std=1.0)

# The three dicts have disjoint keys, so one merged update is equivalent.
cfg.update({**cfg_proposal, **cfg_clipping, **cfg_regularization})

In [None]:
import time
import tqdm
import matplotlib.pyplot as plt
import wandb

# One row of four panels, passed to draw_samples by the trainers.
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
# Reference time for the "time per step" shown in the figure titles.
start_time = time.time()

# Adam hyper-parameters of this first configuration (recorded in cfg, hence in wandb).
cfg.update({
    "lr_E": 0.00005, "beta_1_E": 0.5, "beta_2_E": 0.999,
    "lr_G": 0.0001, "beta_1_G": 0.5, "beta_2_G": 0.999,
    "lr_Enc": 0.0001, "beta_1_Enc": 0.5, "beta_2_Enc": 0.999,
})
logger = wandb.init(project="LatentEBM", config=cfg)

G, E, Encoder = _G().to(device), _E().to(device), _Encoder().to(device)
# Summed (not averaged) squared error; the trainers divide by batch_size themselves.
mse = nn.MSELoss(reduction='sum')
optE, optG, optEncoder = (
    t.optim.Adam(net.parameters(), lr=cfg[f"lr_{tag}"], betas=(cfg[f"beta_1_{tag}"], cfg[f"beta_2_{tag}"]))
    for net, tag in ((E, "E"), (G, "G"), (Encoder, "Enc"))
)
# Available trainers (uncomment one):
# train_elbo_notrick(500, G, E, Encoder, mse, optE, optG, optEncoder, logger)
# train_auto_encoder(1000, G, E, Encoder, mse, optE, optG, optEncoder, logger)
# train_cd(None, G, E, mse, optE, optG, logger, cfg)
# train_cd_trick(n_iter, G, E, mse, optE, optG, logger)

In [None]:

# Second configuration: lower learning rates and the ELBO trainer with the EBM prior.
cfg.update({"lr_E" : 0.00002,
            "beta_1_E": 0.5,
            "beta_2_E":0.999,
            "lr_G" : 0.00002,
            "beta_1_G": 0.5,
            "beta_2_G":0.999,
            "lr_Enc" : 0.00002,
            "beta_1_Enc": 0.5,
            "beta_2_Enc":0.999,
            "trainer": "tricky_elbo",
            "fix_encoder": False,
            "fix_decoder": False,
            "proposal_mean": 0,
            "proposal_std":1,

})


logger = wandb.init(project="LatentEBM", config=cfg)

# Independent copies of G, E and Encoder, built directly on `device`.
# load_state_dict copies every parameter and buffer. The previous
# `p2.data = p1.data.clone()` loop built the copies on the CPU, moved them only
# by rebinding `.data`, and skipped any registered buffers.
G2 = _G().to(device)
E2 = _E().to(device)
Encoder2 = _Encoder().to(device)
G2.load_state_dict(G.state_dict())
E2.load_state_dict(E.state_dict())
Encoder2.load_state_dict(Encoder.state_dict())

optE2 = t.optim.Adam(E2.parameters(), lr=cfg["lr_E"], betas=(cfg["beta_1_E"], cfg["beta_2_E"]))
optG2 = t.optim.Adam(G2.parameters(), lr=cfg["lr_G"], betas=(cfg["beta_1_G"], cfg["beta_2_G"]))
optEncoder2 = t.optim.Adam(Encoder2.parameters(), lr=cfg["lr_Enc"], betas=(cfg["beta_1_Enc"], cfg["beta_2_Enc"]))

train_elbo_withtrick(n_iter, G2, E2, Encoder2, mse, optE2, optG2, optEncoder2, cfg, logger)
# train_elbo_withtrick_v2(n_iter, G2, E2, Encoder2, mse, optE2, optG2, optEncoder2, cfg, logger)