In [1]:
import copy
import torch
import warnings
import torchvision
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.distributions as D
import torch.nn.functional as F

from pathlib import Path
from tqdm.auto import tqdm
from torchvision import transforms
from sklearn.manifold import TSNE
# from torchvision.datasets import MNIST
from torch.utils.data import TensorDataset, DataLoader

warnings.filterwarnings('ignore')


In [2]:
# @title Set random seed

# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL it's critical to set the random seed so that students have a
# baseline against which to compare their results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch

def set_seed(seed=None, seed_torch=True):
  """
  Seed every random number generator used in this notebook.

  Seeds Python's `random` module and NumPy's global generator, and optionally
  PyTorch (CPU and CUDA). When PyTorch is seeded, cuDNN benchmarking is turned
  off and deterministic mode is turned on.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`,
      in which case a seed in [0, 2**32) is drawn from NumPy.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing. The seed that was used is printed.
  """
  if seed is None:
    # Draw with an explicit 64-bit dtype so the upper bound 2**32 is valid on
    # every platform, and convert to a built-in int (see below).
    seed = int(np.random.randint(0, 2 ** 32, dtype=np.int64))
  elif isinstance(seed, np.integer):
    # `random.seed` only accepts built-in types (int, float, str, bytes, ...);
    # NumPy integer scalars raise TypeError on Python >= 3.11.
    seed = int(seed)

  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  Re-seed NumPy and `random` from PyTorch's initial seed.

  Intended as a `worker_init_fn` for `DataLoader`, see
  https://pytorch.org/docs/stable/data.html#data-loading-randomness

  Args:
    worker_id: integer
      ID of the subprocess being seeded. It is not used by this function; the
      seed is taken from `torch.initial_seed()` reduced modulo 2**32.

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  for reseed in (np.random.seed, random.seed):
    reseed(worker_seed)

In [3]:
# Fix the global seed once, before any sampling or model initialisation below.
set_seed(2024)

Random seed 2024 has been set.


In [4]:
# Folder that receives every artefact written by this notebook (figures, cached
# noise, cached data).
OUTPUTS_DIR = 'outputs_toy_final'

_outputs_path = Path(OUTPUTS_DIR)
if not _outputs_path.exists():
    _outputs_path.mkdir(parents=True, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Gaussian Mixture Class

Reference: https://deeplearning.neuromatch.io/tutorials/W2D4_GenerativeModels/student/W2D4_Tutorial2.html#custom-gaussian-mixture-class

In [5]:
class GaussianMixture:
  """
  Mixture of K axis-aligned Gaussians built on `torch.distributions`.

  Every component is `Independent(Normal(mu, cov), 1)`, so each entry of `covs`
  is used as a vector of per-dimension *standard deviations* (the `scale` of
  `Normal`), not as a covariance matrix. All derived quantities (precision
  matrices, variance, score) follow that convention so that `score` is the
  gradient of the log of `pdf`.

  Methods that accept points take either a torch tensor or a NumPy array whose
  last axis has length D; NumPy in gives NumPy out.
  """

  def __init__(self, mus, covs, weights, device='cuda'):
    """
    mus: a list of K 1d torch tensors (D,) with the component means.
    covs: a list of K 1d torch tensors (D,) with the per-dimension standard
      deviations of each component, on the same device as the matching mean.
    weights: a list or array of K unnormalized non-negative weights, signifying the possibility of sampling from each branch.
      They will be normalized to sum to 1. If they do not have a positive sum, a ValueError is raised.
    device: only stored on the instance; computation runs on the device of
      the tensors in `mus` / `covs`.
    """
    if len(mus) == 0:
      raise ValueError("at least one component is required")
    if not (len(mus) == len(covs) == len(weights)):
      raise ValueError("mus, covs and weights must have the same length")

    self.n_component = len(mus)
    # Copy the lists so that `add_component` never mutates the caller's lists.
    self.mus = list(mus)
    self.covs = list(covs)
    self.a = self.covs  # legacy attribute name, kept as an alias of `covs`
    self.device = device
    self.precs = [self._precision(cov) for cov in self.covs]
    self.weights = np.asarray(weights, dtype=float)
    self.norm_weights = self._normalize(self.weights)
    self.RVs = [D.Independent(D.Normal(mu, cov), 1)
                for mu, cov in zip(self.mus, self.covs)]
    self.dim = len(mus[0])

  # ---------------------------------------------------------------- helpers

  @staticmethod
  def _precision(cov):
    """Precision matrix (D, D) of one component: the inverse of diag(cov ** 2)."""
    return torch.diag(1.0 / cov ** 2)

  @staticmethod
  def _normalize(weights):
    """Return `weights` scaled to sum to one; raise if the sum is not positive."""
    total = weights.sum()
    if not total > 0:
      raise ValueError("weights must have a positive sum")
    return weights / total

  def _to_tensor(self, x, dtype=None):
    """Return `(tensor, was_numpy)`; NumPy input is moved to the device of the means."""
    if isinstance(x, np.ndarray):
      x = torch.from_numpy(x)
      if dtype is not None:
        x = x.to(dtype)
      return x.to(self.mus[0].device), True
    return x, False

  def _weighted_pdfs(self, x):
    """List with one tensor per component: weight_k * N(x; mu_k, diag(cov_k ** 2))."""
    return [float(weight) * rv.log_prob(x).exp()
            for weight, rv in zip(self.norm_weights, self.RVs)]

  @staticmethod
  def _accumulate(terms):
    """Sum a non-empty list of tensors in order."""
    total = None
    for term in terms:
      total = term if total is None else total + term
    return total

  def _responsibilities(self, x):
    """
    Posterior probability of each component given x, shape (..., K).

    Computed as a softmax over log-weights plus log-densities, so points where
    every component density underflows still get finite values.
    """
    log_pdf = torch.stack([rv.log_prob(x) for rv in self.RVs], dim=-1)
    with np.errstate(divide='ignore'):  # a zero weight becomes -inf, which softmax handles
      log_w = np.log(self.norm_weights)
    log_w = torch.as_tensor(log_w, dtype=log_pdf.dtype, device=log_pdf.device)
    return torch.softmax(log_pdf + log_w, dim=-1)

  # ------------------------------------------------------------- statistics

  def compute_mean(self):
    """Mean of the mixture: the weighted sum of the component means, shape (D,)."""
    mu_bar = 0.0
    for weight, mu in zip(self.norm_weights, self.mus):
      mu_bar = mu_bar + float(weight) * mu

    return mu_bar

  def compute_variance(self):
    """
    Covariance matrix (D, D) of the mixture:
    sum_k w_k * (diag(cov_k ** 2) + (mu_k - mu_bar)(mu_k - mu_bar)^T).
    """
    mu_bar = self.compute_mean()
    var = None
    for weight, mu, cov in zip(self.norm_weights, self.mus, self.covs):
      centered = mu - mu_bar
      term = float(weight) * (torch.diag(cov ** 2) + torch.outer(centered, centered))
      var = term if var is None else var + term

    return var

  def add_component(self, mu, cov, weight=1):
    """
    Append one component (mean `mu`, per-dimension std `cov`, unnormalized
    `weight`) and renormalize the mixture weights.
    """
    # Validate the new weights before touching any state.
    new_weights = np.append(self.weights, float(weight))
    new_norm_weights = self._normalize(new_weights)

    self.mus.append(mu)
    self.covs.append(cov)
    self.precs.append(self._precision(cov))
    self.RVs.append(D.Independent(D.Normal(mu, cov), 1))
    self.weights = new_weights
    self.norm_weights = new_norm_weights
    self.n_component += 1

  # ---------------------------------------------------------------- density

  def pdf_decompose(self, x):
    """
    Mixture density at $x$ together with the weighted density of each component.

    Returns `(prob, component_pdf)`: `prob` has the type of `x` (tensor or
    NumPy array); `component_pdf` is always a NumPy array of shape (K, ...).
    """
    x, was_numpy = self._to_tensor(x)
    parts = self._weighted_pdfs(x)
    prob = self._accumulate(parts)
    component_pdf = np.stack([part.detach().cpu().numpy() for part in parts])
    if was_numpy:
      prob = prob.detach().cpu().numpy()
    return prob, component_pdf

  def pdf(self, x):
    """
      probability density (PDF) at $x$.
    """
    x, was_numpy = self._to_tensor(x)
    prob = self._accumulate(self._weighted_pdfs(x))

    if was_numpy:
      prob = prob.detach().cpu().numpy()

    return prob

  def scaled_pdf(self, x, scale):
    """
    Density at sqrt(scale) * x of the mixture whose component means are
    multiplied by sqrt(scale) while the component stds are left unchanged.
    """
    x, was_numpy = self._to_tensor(x)

    sqrt_scale = torch.sqrt(torch.as_tensor(scale, device=self.mus[0].device))

    scaled_x = sqrt_scale * x
    prob = None
    for idx in range(self.n_component):
      component = D.Independent(D.Normal(sqrt_scale * self.mus[idx], self.covs[idx]), 1)
      pdf = float(self.norm_weights[idx]) * component.log_prob(scaled_x).exp()
      prob = pdf if prob is None else (prob + pdf)

    if was_numpy:
      prob = prob.detach().cpu().numpy()

    return prob

  # ------------------------------------------------------------------ score

  def score(self, x):
    """
    Compute the score $\\nabla_x \\log p(x)$ for the given $x$.

    The result has the shape of `x`. NumPy input is cast to float32 and a
    NumPy array is returned.
    """
    x, was_numpy = self._to_tensor(x, dtype=torch.float)

    participance = self._responsibilities(x)

    scores = torch.zeros_like(x)
    for i in range(self.n_component):
      # Gradient of log N(x; mu_i, diag(cov_i ** 2)) with respect to x.
      gradvec = - (x - self.mus[i]) @ self.precs[i].to(x.dtype)
      scores = scores + participance[..., i:i+1] * gradvec

    if was_numpy:
      scores = scores.detach().cpu().numpy()

    return scores

  def score_decompose(self, x):
    """
    Compute the grad to each branch for the score $\\nabla_x \\log p(x)$ for the given $x$.

    Returns `(gradvec_list, participance)`: a list with the K per-component
    gradients (same type as `x`) and the component responsibilities as a NumPy
    array of shape (..., K). The score is sum_k participance[..., k:k+1] * gradvec_list[k].
    """
    x, was_numpy = self._to_tensor(x, dtype=torch.float)

    participance = self._responsibilities(x).detach().cpu().numpy()

    gradvec_list = [- (x - mu) @ prec.to(x.dtype)
                    for mu, prec in zip(self.mus, self.precs)]
    if was_numpy:
      gradvec_list = [gradvec.detach().cpu().numpy() for gradvec in gradvec_list]

    return gradvec_list, participance

  # --------------------------------------------------------------- sampling

  def sample(self, N):
    """ Draw N samples from Gaussian mixture
    Procedure:
      Draw N samples from each Gaussian
      Draw N indices, according to the weights.
      Choose sample between the branches according to the indices.

    Returns `(gmm_samps, rand_component, all_samples)` as NumPy arrays of shape
    (N, D), (N,) and (K, N, D). Component indices come from NumPy's global
    generator, the Gaussian draws from PyTorch's.
    """
    rand_component = np.random.choice(self.n_component, size=N, p=self.norm_weights)
    all_samples = np.array([rv.sample((N,)).cpu().numpy() for rv in self.RVs])
    gmm_samps = all_samples[rand_component, np.arange(N),:]
    return gmm_samps, rand_component, all_samples

In [6]:
# Gaussian mixture
# Tensors are created on DEVICE so the cell also runs without a GPU.
# The `Cov*` vectors are handed to `D.Normal` as its scale inside
# `GaussianMixture`, i.e. they act as per-dimension standard deviations.
mu1 = torch.tensor([2.0, 8.0], device=DEVICE)
Cov1 = torch.tensor([1.0, 1.0], device=DEVICE)

mu2 = torch.tensor([5.0, 2.0], device=DEVICE)
Cov2 = torch.tensor([1.0, 1.0], device=DEVICE)

mu3 = torch.tensor([8.0, 8.0], device=DEVICE)
Cov3 = torch.tensor([1.0, 1.0], device=DEVICE)

gmm = GaussianMixture([mu1, mu2, mu3],[Cov1, Cov2, Cov3], [1.0, 1.0, 1.0], device=DEVICE)

In [None]:
def gmm_pdf_contour_plot(gmm, xlim=None, ylim=None, ticks=100, logprob=False, label=None, **kwargs):
    """
    Draw contour lines of a mixture density on the current Matplotlib axes.

    Args:
        gmm: object with a `pdf` method that accepts an array of stacked
            (x, y) grid points.
        xlim, ylim: axis ranges to evaluate on; each defaults to the current
            axes limit returned by `plt.xlim()` / `plt.ylim()`.
        ticks: number of grid points along each axis.
        logprob: if True, contour the log of the density instead.
        label: accepted but not used by this function.
        **kwargs: forwarded to `plt.contour`.
    """
    x_range = plt.xlim() if xlim is None else xlim
    y_range = plt.ylim() if ylim is None else ylim

    grid_x, grid_y = np.meshgrid(np.linspace(*x_range, ticks),
                                 np.linspace(*y_range, ticks))
    density = gmm.pdf(np.dstack((grid_x, grid_y)))
    values = np.log(density) if logprob else density
    plt.contour(grid_x, grid_y, values, **kwargs)

# @title Visualize log density
show_samples = True  # @param {type:"boolean"}
np.random.seed(42)
gmm_samples, _, _ = gmm.sample(2000)
plt.figure(figsize=[8, 8])
plt.axis([0,10,0,10])
# alpha=0.0 keeps the scatter call but makes the points invisible.
plt.scatter(gmm_samples[:, 0],
            gmm_samples[:, 1],
            s=10,
            alpha=0.4 if show_samples else 0.0)
gmm_pdf_contour_plot(gmm, cmap="Greys", levels=20, logprob=True)
# Raw string: "\l" is not a valid escape sequence in a normal string literal.
plt.title(r"log density of gaussian mixture $\log p(x)$")
# plt.axis("image")
plt.show()

In [None]:
# @title Visualize Score
# Re-seed so the points drawn below are the same on every run of this cell.
set_seed(2024)

# Unit-scale Gaussian centred at (5, 6); its samples are appended to the
# mixture samples below.
basedist = D.Independent(D.Normal(torch.tensor([5.0, 6.0]), torch.ones(2)), 1)
# 300 points drawn from the mixture itself ...
gmm_samps_few, _, _ = gmm.sample(300)
# ... plus 100 points from `basedist`, stacked into one NumPy array.
gmm_samps_few = np.concatenate([gmm_samps_few, basedist.sample((100,)).cpu().numpy()], axis=0)
# Score of the mixture evaluated at every collected point.
scorevecs_few = gmm.score(gmm_samps_few)
# gradvec_list, participance = gmm.score_decompose(gmm_samps_few)

In [None]:
def quiver_plot(pnts, vecs, *args, **kwargs):
  """
  Draw one arrow per row of `pnts` with the components given by the matching
  row of `vecs`; extra arguments are forwarded to `plt.quiver`.
  """
  x_pos, y_pos = pnts[:, 0], pnts[:, 1]
  u_comp, v_comp = vecs[:, 0], vecs[:, 1]
  plt.quiver(x_pos, y_pos, u_comp, v_comp, *args, **kwargs)

# @title Score for Gaussian mixture
plt.figure(figsize=[8, 8])
quiver_plot(gmm_samps_few, scorevecs_few,
            color="black", scale=25, alpha=0.7, width=0.003,
            label="score of GMM")
gmm_pdf_contour_plot(gmm, cmap="Greys")
# Raw string: "\l" is not a valid escape sequence in a normal string literal.
plt.title(r"Score vector field $\nabla\log p(x)$ for Gaussian Mixture")
plt.axis("image")
plt.legend()
plt.show()

## Scenarios - Gaussian Mixture Model

In [None]:
# @title Combined plots

from matplotlib.lines import Line2D

# Tensors are created on DEVICE so the cell also runs without a GPU.
# The `Cov_*` vectors are handed to `D.Normal` as its scale inside
# `GaussianMixture`, i.e. they act as per-dimension standard deviations.

# Prior
mu_b_1 = torch.tensor([5.0, 3.0], device=DEVICE)
Cov_b_1 = torch.tensor([2.0, 2.0], device=DEVICE)

mu_b_2 = torch.tensor([3.0, 7.0], device=DEVICE)
Cov_b_2 = torch.tensor([2.0, 2.0], device=DEVICE)

mu_b_3 = torch.tensor([7.0, 7.0], device=DEVICE)
Cov_b_3 = torch.tensor([2.0, 2.0], device=DEVICE)

gmm_b = GaussianMixture([mu_b_1, mu_b_2, mu_b_3],[Cov_b_1, Cov_b_2, Cov_b_3], [1.0, 1.0, 1.0], device=DEVICE)

# Reward: Setup I
mu_r1 = torch.tensor([5.0, 3.0], device=DEVICE)
Cov_r1 = torch.tensor([1.0, 1.0], device=DEVICE)

gmm_r1 = GaussianMixture([mu_r1],[Cov_r1], [1.0], device=DEVICE)

# Reward: Setup II
mu_r2 = torch.tensor([14.0, 3.0], device=DEVICE)
Cov_r2 = torch.tensor([1.0, 1.0], device=DEVICE)

gmm_r2 = GaussianMixture([mu_r2],[Cov_r2], [1.0], device=DEVICE)

def z_func(x, y, func, scale=None, posterior=None):
    """
    Evaluate `func` at every point of a coordinate grid.

    Args:
        x, y: arrays of identical shape holding the x and y coordinate of each
            grid point (e.g. the two outputs of `np.meshgrid`).
        func: callable applied to one point at a time; it receives a 1-D array
            `[x, y]`.
        scale, posterior: when `scale` is not None, both are forwarded to
            `func` as the keyword arguments `scale=` and `posterior=`.

    Returns:
        Array with the shape of `x` where entry [i, j] is
        `func([x[i, j], y[i, j]])`, directly usable as the Z argument of
        `contour(x, y, Z)`.
    """
    # One row per grid point, ordered (x, y). The result is reshaped without a
    # transpose, so it lines up with `x`/`y` even when the two axes use
    # different grids or lengths.
    points = np.column_stack((x.reshape(-1), y.reshape(-1)))

    extra = {} if scale is None else {'scale': scale, 'posterior': posterior}
    results = np.apply_along_axis(func, 1, points, **extra)

    return results.reshape(x.shape)

def posterior(x, scale, posterior):
    """
    Unnormalized product of the prior density and a scaled reward density at x.

    Multiplies `gmm_b.pdf(x)` (the module-level prior mixture) by
    `posterior.scaled_pdf(x, scale=scale)`, where the `posterior` argument is
    the mixture object that supplies the second factor.
    """
    prior_density = gmm_b.pdf(x)
    reward_density = posterior.scaled_pdf(x, scale=scale)
    return prior_density * reward_density

# x = np.arange(15.5, -6, -0.1) # -5, 15
# Both axes use the same grid: -5 to 16.5 (exclusive) in steps of 0.1.
x = np.arange(-5, 16.5, 0.1)
y = np.arange(-5, 16.5, 0.1)
X,Y = np.meshgrid(x, y) # grid of points

# Selects which reward mixture is plotted: gmm_r1 (False) or gmm_r2 (True).
IS_OOD = True # Setup I: False, Setup II: True

# Left panel: prior and reward contours. Right panel: their product.
fig, ax = plt.subplots(1,2,figsize=(10,5))

# Passed as `scale` to `posterior`, which forwards it to `scaled_pdf`.
scale_factor = 0.1

# Prior density on the grid.
out_b = z_func(X, Y, gmm_b.pdf)
# NOTE(review): the legend below is built from manual Line2D handles, so the
# `label=` arguments given to `contour` do not appear to be used — confirm.
cnt1 = ax[0].contour(X,Y,out_b, label=r'$p(x)$', cmap='Blues')

if not IS_OOD:
    # Setup I: reward mixture gmm_r1.
    out_pr1 = z_func(X, Y, gmm_r1.pdf)
    cnt2 = ax[0].contour(X,Y,out_pr1, label=r'$exp(r(x))$', cmap='Reds')
    # One tick every 50 grid steps.
    ax[0].set_xticks(x[0::50])
    ax[0].set_yticks(y[0::50])

    out_post1 = z_func(X, Y, posterior, scale=scale_factor, posterior=gmm_r1)
    # Normalize so the grid values sum to one.
    out_post1 /= np.sum(out_post1)
    cnt3 = ax[1].contour(X,Y,out_post1, label=r'$p^*(x) \propto p(x)exp(r(x))$', cmap='Greens')
    ax[1].set_xticks(x[0::50])
    ax[1].set_yticks(y[0::50])

else:
    # Setup II: reward mixture gmm_r2.
    out_pr2 = z_func(X, Y, gmm_r2.pdf)
    cnt2 = ax[0].contour(X,Y,out_pr2, label=r'$exp(r(x))$', cmap='Reds')
    # One tick every 50 grid steps.
    ax[0].set_xticks(x[0::50])
    ax[0].set_yticks(y[0::50])

    out_post2 = z_func(X, Y, posterior, scale=scale_factor, posterior=gmm_r2)
    # Normalize so the grid values sum to one.
    out_post2 /= np.sum(out_post2)
    cnt3 = ax[1].contour(X,Y,out_post2, label=r'$p^*(x) \propto p(x)exp(r(x))$', cmap='Greens')
    ax[1].set_xticks(x[0::50])
    ax[1].set_yticks(y[0::50])

# Handles/labels already registered on the current axes.
handles, labels = plt.gca().get_legend_handles_labels()

# Proxy lines for the legend, each coloured like one contour level of the
# corresponding plot.
# NOTE(review): `ContourSet.collections` is believed to be deprecated in recent
# Matplotlib releases — confirm against the installed version.
line_b = Line2D([0], [0], label=r'Prior: $p(x)$', color=cnt1.collections[8].get_edgecolor()[0], lw=2)
line_pr = Line2D([0], [0], label=r'Reward: $p(r|x)$', color=cnt2.collections[7].get_edgecolor()[0], lw=2)
line_post = Line2D([0], [0], label=r'Posterior: $p(x|r) \propto p(x)p(r|x)$', color=cnt3.collections[6].get_edgecolor()[0], lw=2)

handles.extend([line_b, line_pr, line_post])
fig.legend(handles=handles, bbox_to_anchor=(0.97, 1.09), ncols=3, fontsize=16, frameon=False)
fig.tight_layout()

# File name suffix: "sc1" for Setup I, "sc2" for Setup II.
fig.savefig(f'{OUTPUTS_DIR}/toysetup_{"sc1" if not IS_OOD else "sc2"}.png', dpi=300, bbox_inches='tight')

## Diffusion Setup

Reference: https://github.com/tanelp/tiny-diffusion

In [None]:
NUM_SAMPLES = 1000
RESAMPLE_NOISE = False # CAUTION: IT WILL OVERWRITE THE NOISE SAMPLES!!!!

# Cached standard-normal noise of shape (NUM_SAMPLES, 2).
NOISE_PATH = Path(OUTPUTS_DIR) / 'noise.pt'

# Sample when explicitly requested, and also when no cache exists yet, so the
# notebook runs on a fresh checkout. An existing cache is only overwritten when
# RESAMPLE_NOISE is True.
if RESAMPLE_NOISE or not NOISE_PATH.exists():
    print('Sampling noise')
    NOISE_SAMPLES = torch.randn(NUM_SAMPLES, 2)
    torch.save(NOISE_SAMPLES, NOISE_PATH)
else:
    print('Loading noise samples')
    # Load onto the CPU first; the tensor is moved to DEVICE below.
    NOISE_SAMPLES = torch.load(NOISE_PATH, map_location='cpu')

NOISE_SAMPLES = NOISE_SAMPLES.to(DEVICE)

In [7]:
class SinusoidalEmbedding(nn.Module):
    """
    Fixed sine/cosine embedding of a tensor of scalars.

    The input is multiplied by `scale`, then by `size // 2` frequencies spaced
    geometrically from 1 down to 1/10000; the sines and cosines are
    concatenated on a new last axis of length `2 * (size // 2)`.
    """

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size = size
        self.scale = scale

    def forward(self, x: torch.Tensor):
        x = x * self.scale
        half_size = self.size // 2
        # Build the frequency table directly on x's device. The denominator is
        # clamped to 1 so that half_size == 1 gives the single frequency 1
        # instead of a 0-division that turns the whole output into NaN.
        log_step = torch.log(torch.tensor([10000.0], device=x.device)) / max(half_size - 1, 1)
        freqs = torch.exp(-log_step * torch.arange(half_size, device=x.device))
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

    def __len__(self):
        # Nominal embedding width, used by the MLPs to size their first layer.
        return self.size


class LinearEmbedding(nn.Module):
    """Embed each scalar as the single feature `x / size * scale`."""

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size, self.scale = size, scale

    def forward(self, x: torch.Tensor):
        rescaled = (x / self.size) * self.scale
        # Add a trailing feature axis of length 1.
        return rescaled[..., None]

    def __len__(self):
        # Width of the embedding produced by `forward`.
        return 1


class LearnableEmbedding(nn.Module):
    """Embed each scalar with a trainable linear map from 1 to `size` features."""

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        self.linear = nn.Linear(1, size)

    def forward(self, x: torch.Tensor):
        # Cast to float and divide by `size` before the linear layer.
        feature = x.unsqueeze(-1).float() / self.size
        return self.linear(feature)

    def __len__(self):
        # Width of the embedding produced by `forward`.
        return self.size


class IdentityEmbedding(nn.Module):
    """Embed each scalar as itself, on a new trailing axis of length 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        return x[..., None]

    def __len__(self):
        # Width of the embedding produced by `forward`.
        return 1


class ZeroEmbedding(nn.Module):
    """Embed each scalar as `x * 0` on a new trailing axis of length 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        expanded = x[..., None]
        # Multiplication (not `zeros_like`) keeps NaN/inf inputs as NaN.
        return expanded * 0

    def __len__(self):
        # Width of the embedding produced by `forward`.
        return 1


class PositionalEmbedding(nn.Module):
    """
    Wrapper that builds one of the embedding layers by name and stores it in
    `self.layer`.

    `type` is one of "sinusoidal", "linear", "learnable", "zero", "identity".
    Extra keyword arguments are forwarded only to the sinusoidal and linear
    layers; the other layers ignore them. Any other name raises ValueError.
    """

    def __init__(self, size: int, type: str, **kwargs):
        super().__init__()

        # (name, factory) pairs; a factory runs only for the matching name.
        builders = (
            ("sinusoidal", lambda: SinusoidalEmbedding(size, **kwargs)),
            ("linear", lambda: LinearEmbedding(size, **kwargs)),
            ("learnable", lambda: LearnableEmbedding(size)),
            ("zero", lambda: ZeroEmbedding()),
            ("identity", lambda: IdentityEmbedding()),
        )
        for name, build in builders:
            if type == name:
                self.layer = build()
                break
        else:
            raise ValueError(f"Unknown positional embedding type: {type}")

    def forward(self, x: torch.Tensor):
        return self.layer(x)
class Block(nn.Module):
    """Residual block: `x + GELU(Linear(x))` with equal input and output width."""

    def __init__(self, size: int):
        super().__init__()

        self.ff = nn.Linear(size, size)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor):
        residual = self.act(self.ff(x))
        return x + residual


class MLP(nn.Module):
    """
    Network mapping a batch of 2-D points and timesteps to a 2-D output.

    Each input coordinate and the timestep are embedded separately, the three
    embeddings are concatenated and passed through a linear layer with GELU,
    `hidden_layers` residual `Block`s and a final linear layer of width 2.
    """

    def __init__(self, hidden_size: int = 128, hidden_layers: int = 3, emb_size: int = 128,
                 time_emb: str = "sinusoidal", input_emb: str = "sinusoidal"):
        super().__init__()

        self.time_mlp = PositionalEmbedding(emb_size, time_emb)
        self.input_mlp1 = PositionalEmbedding(emb_size, input_emb, scale=25.0)
        self.input_mlp2 = PositionalEmbedding(emb_size, input_emb, scale=25.0)

        # Width of the concatenated embeddings, as reported by each wrapped layer.
        concat_size = sum(
            len(embedding.layer)
            for embedding in (self.time_mlp, self.input_mlp1, self.input_mlp2)
        )
        self.joint_mlp = nn.Sequential(
            nn.Linear(concat_size, hidden_size),
            nn.GELU(),
            *(Block(hidden_size) for _ in range(hidden_layers)),
            nn.Linear(hidden_size, 2),
        )

    def forward(self, x, t):
        """`x` is indexed as `x[:, 0]` / `x[:, 1]`; `t` is embedded by `time_mlp`."""
        device = x.device
        embeddings = (
            self.input_mlp1(x[:, 0]).to(device),
            self.input_mlp2(x[:, 1]).to(device),
            self.time_mlp(t).to(device),
        )
        return self.joint_mlp(torch.cat(embeddings, dim=-1))

class Conditional_MLP(nn.Module):
    """
    Variant of `MLP` with an additional conditioning input `y`.

    The two input coordinates, the timestep and the first column of `y` are
    embedded separately, concatenated and passed through a linear layer with
    GELU, `hidden_layers` residual `Block`s and a final linear layer of width 2.
    """

    def __init__(self, hidden_size: int = 128, hidden_layers: int = 3, emb_size: int = 128,
                 time_emb: str = "sinusoidal", input_emb: str = "sinusoidal"):
        super().__init__()

        self.time_mlp = PositionalEmbedding(emb_size, time_emb)
        self.input_mlp1 = PositionalEmbedding(emb_size, input_emb, scale=25.0)
        self.input_mlp2 = PositionalEmbedding(emb_size, input_emb, scale=25.0)
        self.class_mlp = PositionalEmbedding(emb_size, input_emb, scale=25.0)

        # Width of the concatenated embeddings, as reported by each wrapped layer.
        concat_size = sum(
            len(embedding.layer)
            for embedding in (self.time_mlp, self.input_mlp1,
                              self.input_mlp2, self.class_mlp)
        )
        self.joint_mlp = nn.Sequential(
            nn.Linear(concat_size, hidden_size),
            nn.GELU(),
            *(Block(hidden_size) for _ in range(hidden_layers)),
            nn.Linear(hidden_size, 2),
        )

    def forward(self, x, y, t):
        """`x` is indexed as `x[:, 0]` / `x[:, 1]`; only `y[:, 0]` is used from `y`."""
        device = x.device
        embeddings = (
            self.input_mlp1(x[:, 0]).to(device),
            self.input_mlp2(x[:, 1]).to(device),
            self.time_mlp(t).to(device),
            self.class_mlp(y[:, 0]).to(device),
        )
        return self.joint_mlp(torch.cat(embeddings, dim=-1))

class NoiseScheduler():
    def __init__(self,
                 num_timesteps=1000,
                 beta_start=0.0001,
                 beta_end=0.02,
                 beta_schedule="linear"):

        self.num_timesteps = num_timesteps
        if beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32)
        elif beta_schedule == "quadratic":
            self.betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2

        self.alphas = 1.0 - self.betas
        self.sqrt_alphas =  self.alphas ** 0.5
        self.sqrt_betas = self.betas ** 0.5
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.)

        # required for self.add_noise
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstruct_x0
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(
            1 / self.alphas_cumprod - 1)

        # required for q_posterior
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)

    def reconstruct_x0(self, x_t, t, noise):
        s1 = self.sqrt_inv_alphas_cumprod[t]
        s2 = self.sqrt_inv_alphas_cumprod_minus_one[t]
        s1 = s1.reshape(-1, 1).to(x_t.device)
        s2 = s2.reshape(-1, 1).to(x_t.device)
        return s1 * x_t - s2 * noise

    def q_posterior(self, x_0, x_t, t):
        s1 = self.posterior_mean_coef1[t]
        s2 = self.posterior_mean_coef2[t]
        s1 = s1.reshape(-1, 1).to(x_t.device)
        s2 = s2.reshape(-1, 1).to(x_t.device)
        mu = s1 * x_0 + s2 * x_t
        return mu

    def get_variance(self, t):
        if t == 0:
            return 0

        variance = self.betas[t] * (1. - self.alphas_cumprod_prev[t]) / (1. - self.alphas_cumprod[t])
        variance = variance.clip(1e-20)
        return variance

    def step(self, model_output, timestep, sample):
        t = timestep
        pred_original_sample = self.reconstruct_x0(sample, t, model_output)
        pred_prev_sample = self.q_posterior(pred_original_sample, sample, t)

        variance = 0
        if t > 0:
            noise = torch.randn_like(model_output).to(model_output.device)
            variance = (self.get_variance(t) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return pred_original_sample, pred_prev_sample

    def step_wgrad(self, model_output, timestep, sample, grad, scale = 0.5):

        t = timestep

        ## Change
        res = (scale * self.sqrt_one_minus_alphas_cumprod[t] * grad.float())
        model_output = model_output - (scale * self.sqrt_one_minus_alphas_cumprod[t] * grad.float())

        pred_original_sample = self.reconstruct_x0(sample, t, model_output)
        pred_prev_sample = self.q_posterior(pred_original_sample, sample, t)

        variance = 0
        if t > 0:
            noise = torch.randn_like(model_output).to(model_output.device)
            variance = (self.get_variance(t) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return res, pred_prev_sample

    def add_noise(self, x_start, x_noise, timesteps):
        s1 = self.sqrt_alphas_cumprod[timesteps]
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps]

        s1 = s1.reshape(-1, 1).to(x_start.device)
        s2 = s2.reshape(-1, 1).to(x_start.device)

        return s1 * x_start + s2 * x_noise

    def __len__(self):
        return self.num_timesteps
# Define beta schedule: T diffusion steps with linearly spaced betas.
T = 1000
noise_scheduler = NoiseScheduler(num_timesteps=T, beta_schedule="linear")

def forward_diffusion_sample(x_0, t):
    """
    Return a noised version of `x_0` at timestep(s) `t`.

    Fresh standard-normal noise with the shape of `x_0` is drawn and combined
    with `x_0` by the module-level `noise_scheduler.add_noise`.
    """
    eps = torch.randn_like(x_0)
    noisy = noise_scheduler.add_noise(x_0, eps, t)
    return noisy

## Benchmarking Utils

In [6]:
# @title Compute KL Divergence

def kl_mvn(m0, S0, m1, S1):
    """
    KL divergence KL( N(m0, S0) || N(m1, S1) ) between two multivariate
    normal distributions, in nats.

    Adapted from
    https://stackoverflow.com/questions/44549369/kullback-leibler-divergence-from-gaussian-pm-pv-to-gaussian-qm-qv

    Uses the closed form
        KL = 0.5 * ( tr(S1^{-1} S0) + log(|S1| / |S0|)
                     + (m1 - m0)^T S1^{-1} (m1 - m0) - N )
    The covariance matrices do not need to be diagonal.

    Args:
        m0, m1: 1-D mean vectors of length N.
        S0, S1: (N, N) covariance matrices, expected to be positive definite.

    Returns:
        The divergence as a scalar.

    Raises:
        ValueError: if S1 cannot be inverted.
    """
    N = m0.shape[0]
    try:
        iS1 = np.linalg.inv(S1)
    except np.linalg.LinAlgError as err:
        raise ValueError(f"S1 is not invertible:\n{S1}") from err

    diff = m1 - m0

    # kl is made of three terms
    tr_term = np.trace(iS1 @ S0)
    # Log-determinants instead of log(det/det): the raw determinants under- or
    # overflow for large N even when their ratio is well defined.
    _, logdet0 = np.linalg.slogdet(S0)
    _, logdet1 = np.linalg.slogdet(S1)
    det_term = logdet1 - logdet0
    quad_term = diff.T @ iS1 @ diff
    return .5 * (tr_term + det_term + quad_term - N)

In [7]:
# @title Compute MMD (maximum mean discrepancy)

import numpy as np
from sklearn import metrics


def mmd_linear(X, Y):
    """MMD with a linear kernel, k(x, y) = <x, y>.

    Computed as the squared norm of the difference between the two sample
    means, which is a reformulation of

        XX.mean() + YY.mean() - 2 * XY.mean()

    with XX = X X^T, YY = Y Y^T and XY = X Y^T.

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Returns:
        [scalar] -- [MMD value]
    """
    mean_gap = X.mean(0) - Y.mean(0)
    squared_norm = mean_gap.dot(mean_gap.T)
    return squared_norm


def mmd_rbf(X, Y, gamma=0.1):
    """MMD using the rbf (gaussian) kernel from `sklearn.metrics.pairwise.rbf_kernel`.

    Each term is the mean over the full kernel matrix, including the diagonal
    of the X-X and Y-Y matrices.

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        gamma {float} -- [kernel parameter] (default: {0.1})

    Returns:
        [scalar] -- [MMD value]
    """
    rbf = metrics.pairwise.rbf_kernel
    within_x = rbf(X, X, gamma=gamma).mean()
    within_y = rbf(Y, Y, gamma=gamma).mean()
    across = rbf(X, Y, gamma=gamma).mean()
    return within_x + within_y - 2 * across


def mmd_poly(X, Y, degree=2, gamma=1, coef0=0):
    """MMD using polynomial kernel (i.e., k(x,y) = (gamma <X, Y> + coef0)^degree)

    Each term is the mean over the full kernel matrix returned by
    `sklearn.metrics.pairwise.polynomial_kernel`.

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        degree {int} -- [degree] (default: {2})
        gamma {int} -- [gamma] (default: {1})
        coef0 {int} -- [constant item] (default: {0})

    Returns:
        [scalar] -- [MMD value]
    """
    kernel_args = dict(degree=degree, gamma=gamma, coef0=coef0)
    poly = metrics.pairwise.polynomial_kernel
    within_x = poly(X, X, **kernel_args).mean()
    within_y = poly(Y, Y, **kernel_args).mean()
    across = poly(X, Y, **kernel_args).mean()
    return within_x + within_y - 2 * across

In [8]:
# @title Calculate Frechet Distance (FD)

from scipy.linalg import sqrtm

# calculate frechet inception distance
def calculate_fid(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians given their means and covariances.

    Evaluates ||mu1 - mu2||^2 + trace(sigma1 + sigma2 - 2 * sqrtm(sigma1 . sigma2)).

    Args:
        mu1, mu2: mean vectors (numpy arrays).
        sigma1, sigma2: covariance matrices (numpy arrays).

    Returns:
        The distance as a scalar.
    """
    # Squared Euclidean distance between the means.
    mean_term = np.sum((mu1 - mu2)**2.0)

    # Matrix square root of the covariance product; sqrtm may return a complex
    # array, in which case only the real part is kept.
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return mean_term + np.trace(sigma1 + sigma2 - 2.0 * covmean)

## Experiments - Non-Gaussian

In [None]:
# Switch: reuse the dataset cached on disk, or resample it and refresh the cache.
LOAD_DATA_FROM_FILE = True

if not LOAD_DATA_FROM_FILE:
    # Three components, each built with the same [2.0, 2.0] covariance argument
    # and a 1.0 entry in the last list (presumably the mixture weights).
    mu_b_1, mu_b_2, mu_b_3 = (
        torch.tensor(center).to('cuda')
        for center in ([5.0, 3.0], [3.0, 7.0], [7.0, 7.0])
    )
    Cov_b_1, Cov_b_2, Cov_b_3 = (
        torch.tensor(spread).to('cuda')
        for spread in ([2.0, 2.0], [2.0, 2.0], [2.0, 2.0])
    )

    gmm_b = GaussianMixture(
        [mu_b_1, mu_b_2, mu_b_3],
        [Cov_b_1, Cov_b_2, Cov_b_3],
        [1.0, 1.0, 1.0],
    )

    # Only the first element of the returned triple is used.
    dataX, _, _ = gmm_b.sample(3000)

    np.save(f'{OUTPUTS_DIR}/dataX_newset.npy', dataX)
else:
    dataX = np.load(f'{OUTPUTS_DIR}/dataX_newset.npy')


# Density estimate with the raw points drawn on top.
data_df = pd.DataFrame({'x1': dataX[:, 0], 'x2': dataX[:, 1]})

fig, ax = plt.subplots()

sns.kdeplot(data=data_df, x='x1', y='x2', fill=True, cmap="Blues", ax=ax)
ax.scatter(dataX[:, 0], dataX[:, 1], s=10, alpha=0.5, c='b')
ax.axis([-2.5, 12.5, -2.5, 12.5])

plt.grid()
plt.tight_layout()

plt.show()

In [None]:
# Train the unconditional noise-prediction model on the 2-D dataset and save it.
data_x = torch.from_numpy(dataX.astype(np.float32))

simpledataset = TensorDataset(data_x)  # single-tensor dataset: each item is a 1-tuple
dataloader = DataLoader(simpledataset, shuffle=True, batch_size=32)
model = MLP(
        hidden_size=128,
        hidden_layers=3,
        emb_size=128,
        time_emb="sinusoidal",
        input_emb="identity")

optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-3,
    )

num_epochs = 200

print("Training model...")

progress_bar = tqdm(total=num_epochs)
for epoch in range(num_epochs):
    model.train()
    progress_bar.set_description(f"Epoch {epoch}")

    losses = []
    for batch in dataloader:

        # Named `clean_batch` rather than `input` so the builtin is not
        # rebound in the notebook's global namespace.
        clean_batch = batch[0]

        # One random timestep per sample, then noise the batch to that level.
        timesteps = torch.randint(
            0, noise_scheduler.num_timesteps, (clean_batch.shape[0],)
        ).long()
        noise = torch.randn(clean_batch.shape)
        noisy = noise_scheduler.add_noise(clean_batch, noise, timesteps)

        # The model is trained to predict the noise that was added.
        noise_pred = model(noisy, timesteps)
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        # Plain backward(): passing the loss itself as the `gradient` argument
        # would multiply every parameter gradient by the current loss value.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        losses.append(loss.detach().item())

    progress_bar.update(1)
    # Report the mean batch loss of the epoch on the progress bar.
    logs = {"loss": sum(losses)/len(losses)}
    progress_bar.set_postfix(**logs)

progress_bar.close()

# The whole module object is saved (not just a state_dict); the later cells
# reload it with torch.load.
torch.save(model, f'{OUTPUTS_DIR}/model_uncond_newset.pt')

In [None]:
# @title Unconditional sampling

# Folder that receives the unguided baseline samples and their plot.
export_path = Path(OUTPUTS_DIR).joinpath('uncond')
if not export_path.exists():
    export_path.mkdir(parents=True, exist_ok=True)

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt').to(DEVICE)
model.eval()

# Start from a private copy of the shared noise so NOISE_SAMPLES stays untouched.
curr_sample = copy.deepcopy(NOISE_SAMPLES)

# Reverse process: iterate the schedule from its last index down to 0.
timesteps = list(reversed(range(len(noise_scheduler))))
for i, t in enumerate(tqdm(timesteps)):
    # Same timestep for every sample in the batch.
    t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()
    with torch.no_grad():
        residual = model(curr_sample, t)

    # step() returns a pair; its second element is the updated sample.
    _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

curr_sample = curr_sample.detach().cpu()

np.save(export_path / 'gensample.npy', curr_sample.numpy())

data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

fig, ax = plt.subplots()

sns.kdeplot(data=data_df, x='x1', y='x2', fill=True, cmap="Blues", ax=ax)
ax.scatter(curr_sample[:, 0], curr_sample[:, 1], s=10, alpha=0.5, c='b')
ax.axis([-2.5, 12.5, -2.5, 12.5])

plt.grid()
plt.tight_layout()
plt.savefig(export_path / 'gensample.png', dpi=300)

In [77]:
# @title Gradient guidance

# Sweeps the guidance scale: for each value, runs the reverse process with the
# reward gradient injected through noise_scheduler.step_wgrad, then saves the
# generated samples and a plot under <export_path>/scale_<value>/.

IS_OOD = False

# Reward distribution: a single Gaussian wrapped in GaussianMixture. The two
# settings differ in its centre and covariance argument.
if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt')
model = model.to(DEVICE)
model.eval()

if not IS_OOD:
    export_path = Path(OUTPUTS_DIR).joinpath('gg_id')
else:
    export_path = Path(OUTPUTS_DIR).joinpath('gg_ood')

if not Path.exists(export_path):
    Path.mkdir(export_path, exist_ok=True, parents=True)

# Scales 1..19 in steps of 1, then 20..50 in steps of 5.
guidance_scales = np.arange(1,20,1).tolist()
guidance_scales.extend(np.arange(20, 55, 5).tolist())

for guidance_scale in tqdm(guidance_scales, total=len(guidance_scales)):

    curr_path = export_path.joinpath(f'scale_{guidance_scale}')
    if not Path.exists(curr_path):
        Path.mkdir(curr_path, exist_ok=True, parents=True)

    # Every scale starts from the same noise so runs are comparable.
    curr_sample = copy.deepcopy(NOISE_SAMPLES)

    timesteps = list(range(len(noise_scheduler)))[::-1]
    for i, t in enumerate(timesteps):
        t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()
        with torch.no_grad():
            residual = model(curr_sample, t)

        # Guidance signal from the reward model (presumably the gradient of
        # its log-density at the current sample -- confirm in GaussianMixture).
        grad = gmm.score(curr_sample)

        res, curr_sample = noise_scheduler.step_wgrad(residual, t[0], curr_sample, grad, scale=guidance_scale)

    curr_sample = curr_sample.detach().cpu()

    np.save(curr_path.joinpath('gensample.npy'), curr_sample.numpy())

    data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

    fig, ax = plt.subplots()

    sns.kdeplot(data_df, x='x1', y='x2', fill=True, ax=ax, cmap="Blues")
    ax.scatter(curr_sample[:, 0], curr_sample[:, 1], c='b', alpha=0.5, s=10)
    ax.axis([0, 10 , 0, 10])

    plt.grid()
    plt.tight_layout()
    plt.savefig(curr_path.joinpath('gensample.png'), dpi=300)
    # The figure is already on disk; close it so the sweep does not keep one
    # open figure per guidance scale.
    plt.close(fig)

In [43]:
# @title Universal guidance

# Sweeps the guidance scale. At every timestep the predicted noise is adjusted
# by a forward-guidance term (reward score at the current sample) and a
# backward-guidance term (displacement produced by a small ascent on the
# reward score), then the scheduler takes a regular step. Results are saved
# under <export_path>/scale_<value>/.

IS_OOD = True

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt')
model = model.to(DEVICE)
model.eval()

if not IS_OOD:
    export_path = Path(OUTPUTS_DIR).joinpath('unvg_id')
else:
    export_path = Path(OUTPUTS_DIR).joinpath('unvg_ood')

if not Path.exists(export_path):
    Path.mkdir(export_path, exist_ok=True, parents=True)

# Scales 1..19 in steps of 1, then 20..50 in steps of 5.
guidance_scales = np.arange(1,20,1).tolist()
guidance_scales.extend(np.arange(20, 55, 5).tolist())

# Number of passes per timestep; passes before the last one re-noise the sample.
rsteps = 1

for guidance_scale in tqdm(guidance_scales, total=len(guidance_scales)):

    curr_path = export_path.joinpath(f'scale_{guidance_scale}')
    if not Path.exists(curr_path):
        Path.mkdir(curr_path, exist_ok=True, parents=True)

    sample = copy.deepcopy(NOISE_SAMPLES)

    timesteps = list(range(len(noise_scheduler)))[::-1]
    for i, t in enumerate(tqdm(timesteps)):
        t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()

        for k in range(rsteps):

            with torch.no_grad():
                residual = model(sample, t)

            residual = copy.deepcopy(residual)

            # forward guidance
            grad = gmm.score(sample)
            residual = residual - (guidance_scale * noise_scheduler.sqrt_one_minus_alphas_cumprod[t[0]] * grad.float())

            # SGD: one ascent step of size 0.01 along the reward score,
            # starting from a copy of the current sample.
            param = copy.deepcopy(sample)
            for sgd_iter in range(1):  # own loop variable so the outer `i` is not overwritten
                grad_sgd = gmm.score(param)
                param = param + (0.01 * grad_sgd)

            deltas = param - sample

            # backward guidance
            residual = residual - (noise_scheduler.sqrt_alphas_cumprod[t[0]] / noise_scheduler.sqrt_one_minus_alphas_cumprod[t[0]]) * deltas

            _, sample = noise_scheduler.step(residual, t[0], sample)

            # Re-noise between passes. This has to sit inside the `k` loop:
            # placed after it, `k` always equals rsteps - 1 and the branch can
            # never run. With rsteps == 1 it is skipped either way.
            # NOTE(review): relies on noise_scheduler exposing `sqrt_alphas`
            # and `sqrt_betas` -- confirm before raising rsteps above 1.
            if k < (rsteps - 1):
                noise = torch.randn_like(sample).to(sample.device)
                sample = (noise_scheduler.sqrt_alphas[t[0]]) * sample + (noise_scheduler.sqrt_betas[t[0]]) * noise

    curr_sample = sample.detach().cpu()

    np.save(curr_path.joinpath('gensample.npy'), curr_sample.numpy())

    data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

    fig, ax = plt.subplots()

    sns.kdeplot(data_df, x='x1', y='x2', fill=True, ax=ax, cmap="Blues")
    ax.scatter(curr_sample[:, 0], curr_sample[:, 1], c='b', alpha=0.5, s=10)
    ax.axis([0, 10 , 0, 10])

    plt.grid()
    plt.tight_layout()
    plt.savefig(curr_path.joinpath('gensample.png'), dpi=300)
    # Saved to disk already; close so the sweep does not accumulate figures.
    plt.close(fig)

In [None]:
# @title Controlled Decoding (Block-wise)

IS_OOD = True

# Reward distribution: a single Gaussian wrapped in GaussianMixture, chosen by
# the in-distribution / out-of-distribution switch above.
if IS_OOD:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([3.0, 3.0]).to('cuda')

    gmm = GaussianMixture([mu_r2], [Cov_r2], [1.0])

else:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1], [Cov_r1], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt').to(DEVICE)
model.eval()

# Block length: a selection round runs whenever the timestep is a multiple of it.
bsize = 1

export_path = Path(OUTPUTS_DIR).joinpath(f'CD_b{bsize}_ood' if IS_OOD else f'CD_b{bsize}_id')

if not export_path.exists():
    export_path.mkdir(parents=True, exist_ok=True)

# Candidate counts: 2..18 in steps of 2, then 20..100 in steps of 10.
n_vals = np.arange(2, 20, 2).tolist() + np.arange(20, 110, 10).tolist()

for n_val in tqdm(n_vals, total=len(n_vals)):

    curr_path = export_path / f'n_{n_val}'
    if not curr_path.exists():
        curr_path.mkdir(parents=True, exist_ok=True)

    # n_val copies of the shared starting noise, concatenated along dim 0.
    curr_sample = copy.deepcopy(NOISE_SAMPLES)
    curr_sample = curr_sample.repeat(n_val, 1)

    timesteps = list(reversed(range(len(noise_scheduler))))
    for i, t in enumerate(timesteps):
        t = torch.from_numpy(np.repeat(t, n_val * NUM_SAMPLES)).long().to(DEVICE)

        with torch.no_grad():
            residual = model(curr_sample, t)

        _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

        if t[0] % bsize == 0:  # end of a block: keep the best candidate per slot

            if t[0] > 0:
                # Not the final step yet: score the x0 estimate reconstructed
                # from the freshly stepped sample at the next timestep.
                prev_t = torch.from_numpy(np.repeat(timesteps[i + 1], n_val * NUM_SAMPLES)).long().to(DEVICE)
                with torch.no_grad():
                    residual = model(curr_sample, prev_t)

                pred_x0 = noise_scheduler.reconstruct_x0(curr_sample, timesteps[i + 1], residual)
                reward = gmm.pdf(pred_x0)
            else:
                # Final step: the sample itself is scored.
                reward = gmm.pdf(curr_sample)

            # Regroup into n_val equal chunks stacked along a new leading dim,
            # then take the index of the highest reward across that dim.
            reward = torch.stack(reward.chunk(n_val), dim=0)
            select_ind = reward.max(dim=0)[1]

            gen_sample = copy.deepcopy(curr_sample)
            gen_sample = torch.stack(gen_sample.chunk(n_val), dim=0)
            gen_sample = gen_sample.permute(1, 0, 2)
            curr_sample = copy.deepcopy(torch.stack([x[select_ind[idx]] for idx, x in enumerate(gen_sample)], dim=0)) # TODO: Make it efficient

            if t[0] > 0:
                # More steps to go: fan the winners back out to n_val copies.
                curr_sample = curr_sample.repeat(n_val, 1)

    curr_sample = curr_sample.detach().cpu()

    np.save(curr_path / 'gensample.npy', curr_sample.numpy())

    data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

    fig, ax = plt.subplots()

    # Scatter only in this cell (no kde overlay).
    ax.scatter(curr_sample[:, 0], curr_sample[:, 1], s=10, alpha=0.5, c='b')
    ax.axis([0, 20, 0, 20])

    plt.grid()
    plt.tight_layout()
    plt.savefig(curr_path / 'gensample.png', dpi=300)
    plt.close()

In [None]:
# @title BoN Sampling

IS_OOD = True

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if IS_OOD:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2], [Cov_r2], [1.0])

else:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1], [Cov_r1], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt').to(DEVICE)
model.eval()

export_path = Path(OUTPUTS_DIR).joinpath('BoN_ood' if IS_OOD else 'BoN_id')

if not export_path.exists():
    export_path.mkdir(parents=True, exist_ok=True)

# Alternative (smaller) candidate grid:
# n_vals = np.arange(2, 20, 2).tolist()
# n_vals.extend(np.arange(20, 110, 10).tolist())
n_vals = np.arange(200, 400, 50).tolist()

for n_val in tqdm(n_vals, total=len(n_vals)):

    curr_path = export_path / f'n_{n_val}'
    if not curr_path.exists():
        curr_path.mkdir(parents=True, exist_ok=True)

    # n_val copies of the shared starting noise, concatenated along dim 0.
    curr_sample = copy.deepcopy(NOISE_SAMPLES)
    curr_sample = curr_sample.repeat(n_val, 1)

    # Plain unguided reverse process over all candidates at once.
    timesteps = list(reversed(range(len(noise_scheduler))))
    for i, t in enumerate(timesteps):
        t = torch.from_numpy(np.repeat(t, n_val * NUM_SAMPLES)).long().to(DEVICE)
        with torch.no_grad():
            residual = model(curr_sample, t)

        _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

    # Keep the full candidate pool before any selection.
    np.save(curr_path / 'gen_uncond.npy', curr_sample.detach().cpu().numpy())

    reward = gmm.pdf(curr_sample)
    reward = torch.stack(reward.chunk(n_val), dim=0)

    # Per slot, index of the candidate with the highest reward.
    select_ind = reward.max(dim=0)[1]
    curr_sample = torch.stack(curr_sample.chunk(n_val), dim=0)
    curr_sample = curr_sample.permute(1, 0, 2)
    result = copy.deepcopy(torch.stack([x[select_ind[idx]] for idx, x in enumerate(curr_sample)], dim=0)) # TODO: Make it efficient
    result = result.detach().cpu()

    np.save(curr_path / 'gensample.npy', result.numpy())

    data_df = pd.DataFrame({'x1': result[:, 0], 'x2': result[:, 1]})

    fig, ax = plt.subplots()

    sns.kdeplot(data=data_df, x='x1', y='x2', fill=True, cmap="Blues", ax=ax)
    ax.scatter(result[:, 0], result[:, 1], s=10, alpha=0.5, c='b')
    ax.axis([0, 10, 0, 10])

    plt.grid()
    plt.tight_layout()
    plt.savefig(curr_path / 'gensample.png', dpi=300)
    plt.close()

In [22]:
# @title SDEdit

# Sweeps the SDEdit strength: samples drawn from the reward distribution are
# noised with forward_diffusion_sample up to int(strength * num_timesteps),
# then denoised with the unconditional model over the timesteps at or below
# that level. Results are saved under <export_path>/strength_<k>/.

IS_OOD = False

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

if not IS_OOD:
    export_path = Path(OUTPUTS_DIR).joinpath(f'SDEdit_id')
else:
    export_path = Path(OUTPUTS_DIR).joinpath(f'SDEdit_ood')

if not Path.exists(export_path):
    Path.mkdir(export_path, exist_ok=True, parents=True)

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt')
model = model.to(DEVICE)
model.eval()

# strengths = np.arange(0.4,1.0,0.1).tolist()
# NOTE(review): values above 1.0 request a forward-diffusion step beyond
# num_timesteps -- confirm forward_diffusion_sample tolerates that.
strengths = np.arange(0.1,1.4,0.1).tolist()

# One fixed set of target samples, shared by every strength.
target_samples, _, _ = gmm.sample(NUM_SAMPLES)
target_samples = torch.from_numpy(target_samples)

for strength in tqdm(strengths, total=len(strengths)):

    # Folder index is the strength in tenths (0.1 -> strength_1, ...).
    curr_path = export_path.joinpath(f'strength_{round(strength * 10)}')
    if not Path.exists(curr_path):
        Path.mkdir(curr_path, exist_ok=True, parents=True)

    curr_sample = copy.deepcopy(target_samples)

    curr_sample = forward_diffusion_sample(curr_sample, int(strength*noise_scheduler.num_timesteps))
    curr_sample = curr_sample.to(DEVICE)

    timesteps = list(range(len(noise_scheduler)))[::-1]
    for i, t in enumerate(tqdm(timesteps)):

        # Skip the part of the schedule above the chosen noise level.
        if t > strength*noise_scheduler.num_timesteps:
            continue

        t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()
        with torch.no_grad():
            residual = model(curr_sample, t)

        _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

    curr_sample = curr_sample.detach().cpu()

    np.save(curr_path.joinpath('gensample.npy'), curr_sample.numpy())

    data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

    fig, ax = plt.subplots()

    sns.kdeplot(data_df, x='x1', y='x2', fill=True, ax=ax, cmap="Blues")
    ax.scatter(curr_sample[:, 0], curr_sample[:, 1], c='b', alpha=0.5, s=10)
    ax.axis([0, 10 , 0, 10])

    plt.grid()
    plt.tight_layout()
    plt.savefig(curr_path.joinpath('gensample.png'), dpi=300)
    # Saved to disk already; close so the sweep does not keep one open figure
    # per strength.
    plt.close(fig)

In [None]:
# @title SDEdit + Gradient guidance

IS_OOD = True

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if IS_OOD:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2], [Cov_r2], [1.0])

else:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1], [Cov_r1], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt').to(DEVICE)
model.eval()

export_path = Path(OUTPUTS_DIR).joinpath('SDEdit_gg_ood' if IS_OOD else 'SDEdit_gg_id')

if not export_path.exists():
    export_path.mkdir(parents=True, exist_ok=True)

# Grid: strengths 0.1..0.9 crossed with scales 1..19 (step 1) and 20..50 (step 5).
strengths = np.arange(0.1, 1.0, 0.1).tolist()

guidance_scales = np.arange(1, 20, 1).tolist() + np.arange(20, 55, 5).tolist()

# One fixed set of target samples, shared by every grid point.
target_samples, _, _ = gmm.sample(NUM_SAMPLES)
target_samples = torch.from_numpy(target_samples)

for strength in tqdm(strengths, total=len(strengths)):

    for guidance_scale in guidance_scales:

        curr_path = export_path / f'strength_{round(strength * 10)}_scale_{guidance_scale}'
        if not curr_path.exists():
            curr_path.mkdir(parents=True, exist_ok=True)

        # Noise the targets up to the level set by `strength`.
        curr_sample = copy.deepcopy(target_samples)

        curr_sample = forward_diffusion_sample(curr_sample, int(strength*noise_scheduler.num_timesteps))
        curr_sample = curr_sample.to(DEVICE)

        timesteps = list(reversed(range(len(noise_scheduler))))
        for i, t in enumerate(timesteps):

            # Only the timesteps at or below the chosen noise level are run.
            if t > strength*noise_scheduler.num_timesteps:
                continue

            t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()
            with torch.no_grad():
                residual = model(curr_sample, t)

            grad = gmm.score(curr_sample)

            res, curr_sample = noise_scheduler.step_wgrad(residual, t[0], curr_sample, grad, scale=guidance_scale)

        curr_sample = curr_sample.detach().cpu()

        np.save(curr_path / 'gensample.npy', curr_sample.numpy())

        data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

        fig, ax = plt.subplots()

        sns.kdeplot(data=data_df, x='x1', y='x2', fill=True, cmap="Blues", ax=ax)
        ax.scatter(curr_sample[:, 0], curr_sample[:, 1], s=10, alpha=0.5, c='b')
        ax.axis([0, 10, 0, 10])

        plt.grid()
        plt.tight_layout()
        plt.savefig(curr_path / 'gensample.png', dpi=300)
        plt.close()

In [None]:
# @title SDEdit + Universal guidance

# Grid over SDEdit strength and guidance scale. Targets drawn from the reward
# distribution are noised up to int(strength * num_timesteps), then denoised
# with forward and backward guidance applied to the predicted noise at every
# step. Results go to <export_path>/strength_<k>_scale_<value>/.

IS_OOD = False

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt')
model = model.to(DEVICE)
model.eval()

if not IS_OOD:
    export_path = Path(OUTPUTS_DIR).joinpath('SDEdit_unvg_id')
else:
    export_path = Path(OUTPUTS_DIR).joinpath('SDEdit_unvg_ood')

if not Path.exists(export_path):
    Path.mkdir(export_path, exist_ok=True, parents=True)

# Scales 1..19 in steps of 1, then 20..50 in steps of 5.
guidance_scales = np.arange(1,20,1).tolist()
guidance_scales.extend(np.arange(20, 55, 5).tolist())

# Number of passes per timestep; passes before the last one re-noise the sample.
rsteps = 1

strengths = np.arange(0.1,1.0,0.1).tolist()
# strengths = np.arange(0.1,1.4,0.1).tolist()

# One fixed set of target samples, shared by every grid point.
target_samples, _, _ = gmm.sample(NUM_SAMPLES)
target_samples = torch.from_numpy(target_samples)

for strength in tqdm(strengths, total=len(strengths)):

    for guidance_scale in tqdm(guidance_scales, total=len(guidance_scales)):

        curr_path = export_path.joinpath(f'strength_{round(strength * 10)}_scale_{guidance_scale}')
        if not Path.exists(curr_path):
            Path.mkdir(curr_path, exist_ok=True, parents=True)

        sample = copy.deepcopy(target_samples)

        sample = forward_diffusion_sample(sample, int(strength*noise_scheduler.num_timesteps))
        sample = sample.to(DEVICE)

        timesteps = list(range(len(noise_scheduler)))[::-1]
        for i, t in enumerate(timesteps):

            # Only the timesteps at or below the chosen noise level are run.
            if t > strength*noise_scheduler.num_timesteps:
                continue

            t = torch.from_numpy(np.repeat(t, NUM_SAMPLES)).long()

            for k in range(rsteps):

                with torch.no_grad():
                    residual = model(sample, t)

                residual = copy.deepcopy(residual)

                # forward guidance
                grad = gmm.score(sample)
                residual = residual - (guidance_scale * noise_scheduler.sqrt_one_minus_alphas_cumprod[t[0]] * grad.float())

                # SGD: one ascent step of size 0.01 along the reward score,
                # starting from a copy of the current sample.
                param = copy.deepcopy(sample)
                for sgd_iter in range(1):  # own loop variable so the outer `i` is not overwritten
                    grad_sgd = gmm.score(param)
                    param = param + (0.01 * grad_sgd)

                deltas = param - sample

                # backward guidance
                residual = residual - (noise_scheduler.sqrt_alphas_cumprod[t[0]] / noise_scheduler.sqrt_one_minus_alphas_cumprod[t[0]]) * deltas

                _, sample = noise_scheduler.step(residual, t[0], sample)

                # Re-noise between passes. This has to sit inside the `k`
                # loop: placed after it, `k` always equals rsteps - 1 and the
                # branch can never run. With rsteps == 1 it is skipped either way.
                # NOTE(review): relies on noise_scheduler exposing `sqrt_alphas`
                # and `sqrt_betas` -- confirm before raising rsteps above 1.
                if k < (rsteps - 1):
                    noise = torch.randn_like(sample).to(sample.device)
                    sample = (noise_scheduler.sqrt_alphas[t[0]]) * sample + (noise_scheduler.sqrt_betas[t[0]]) * noise

        curr_sample = sample.detach().cpu()

        np.save(curr_path.joinpath('gensample.npy'), curr_sample.numpy())

        data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

        fig, ax = plt.subplots()

        sns.kdeplot(data_df, x='x1', y='x2', fill=True, ax=ax, cmap="Blues")
        ax.scatter(curr_sample[:, 0], curr_sample[:, 1], c='b', alpha=0.5, s=10)
        ax.axis([0, 10 , 0, 10])

        plt.grid()
        plt.tight_layout()
        plt.savefig(curr_path.joinpath('gensample.png'), dpi=300)
        plt.close()

In [34]:
# @title SDEdit + Controlled Decoding (Block-wise)

# Grid over block size, SDEdit strength and candidate count n_val. Targets
# drawn from the reward distribution are noised up to
# int(strength * num_timesteps), replicated n_val times and denoised; at each
# block boundary the highest-reward candidate per slot is kept and replicated
# again. Results go to SDEdit_CD_b<bsize>_<id|ood>/strength_<k>_n_<n_val>/.

IS_OOD = False

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt')
model = model.to(DEVICE)
model.eval()

for bsize in [1, 10, 20, 40, 80, 100]:

    if not IS_OOD:
        export_path = Path(OUTPUTS_DIR).joinpath(f'SDEdit_CD_b{bsize}_id')
    else:
        export_path = Path(OUTPUTS_DIR).joinpath(f'SDEdit_CD_b{bsize}_ood')

    if not Path.exists(export_path):
        Path.mkdir(export_path, exist_ok=True, parents=True)

    # Strengths 0.1..0.9; candidate counts 2..18 (step 2) then 20..100 (step 10).
    strengths = np.arange(0.1,1.0,0.1).tolist()

    n_vals = np.arange(2, 20, 2).tolist()
    n_vals.extend(np.arange(20, 110, 10).tolist())

    # Target samples are redrawn for every block size.
    target_samples, _, _ = gmm.sample(NUM_SAMPLES)
    target_samples = torch.from_numpy(target_samples)

    for strength in tqdm(strengths, total=len(strengths)):

        for n_val in n_vals:

            curr_path = export_path.joinpath(f'strength_{round(strength * 10)}_n_{n_val}')
            if not Path.exists(curr_path):
                Path.mkdir(curr_path, exist_ok=True, parents=True)

            # Noise the targets up to the level set by `strength`, then make
            # n_val copies concatenated along dim 0.
            curr_sample = copy.deepcopy(target_samples)
            curr_sample = forward_diffusion_sample(curr_sample, int(strength*noise_scheduler.num_timesteps))
            curr_sample = curr_sample.to(DEVICE)

            curr_sample = curr_sample.repeat(n_val, 1)

            # Counts the denoising steps actually executed (skipped ones excluded).
            counter = 0
            timesteps = list(range(len(noise_scheduler)))[::-1]
            for i, t in enumerate(timesteps):
                # Only the timesteps at or below the chosen noise level are run.
                if t > strength*noise_scheduler.num_timesteps:
                    continue

                counter += 1

                # Same timestep for every candidate, on the model's device.
                t = torch.from_numpy(np.repeat(t, n_val * NUM_SAMPLES)).long().to(DEVICE)

                with torch.no_grad():
                    residual = model(curr_sample, t)

                _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

                # Selection runs at block boundaries and always at the last
                # timestep (timesteps[-1], i.e. 0).
                # NOTE(review): counter is incremented before this test, so for
                # bsize > 1 the first selection fires after bsize - 1 executed
                # steps -- confirm that offset is intended.
                if (counter + 1) % bsize == 0 or t[0] == timesteps[-1]: # at the end of block do BoN

                    if t[0] > timesteps[-1]: # If not final step use estimates x0

                        # Score the x0 estimate reconstructed from the freshly
                        # stepped sample at the next timestep.
                        prev_t = torch.from_numpy(np.repeat(timesteps[i + 1], n_val * NUM_SAMPLES)).long().to(DEVICE)
                        with torch.no_grad():
                            residual = model(curr_sample, prev_t)

                        pred_x0 = noise_scheduler.reconstruct_x0(curr_sample, timesteps[i + 1], residual)
                        reward = gmm.pdf(pred_x0)
                    else:
                        # Final step: the sample itself is scored.
                        reward = gmm.pdf(curr_sample)

                    # Regroup rewards into n_val chunks along a new leading dim
                    # and take the index of the highest reward across it.
                    reward = torch.cat([x.unsqueeze(0) for x in reward.chunk(n_val)], dim=0)
                    select_ind = torch.max(reward, dim=0)[1]

                    # Same regrouping for the samples, permuted so the chunk
                    # dim comes second, then pick the winning chunk per row.
                    gen_sample = copy.deepcopy(curr_sample)
                    gen_sample = torch.cat([x.unsqueeze(0) for x in gen_sample.chunk(n_val)], dim=0)
                    gen_sample = gen_sample.permute(1,0,2)
                    curr_sample = copy.deepcopy(torch.cat([x[select_ind[idx]].unsqueeze(0) for idx, x in enumerate(gen_sample)], dim=0)) # TODO: Make it efficient

                    if t[0] > timesteps[-1]: # If not the end replicate n times
                        curr_sample = curr_sample.repeat(n_val, 1)

            # print(curr_sample.shape)

            # raise ValueError()
            curr_sample = curr_sample.detach().cpu()

            np.save(curr_path.joinpath('gensample.npy'), curr_sample.numpy())

            data_df = pd.DataFrame({'x1': curr_sample[:, 0], 'x2': curr_sample[:, 1]})

            fig, ax = plt.subplots()

            sns.kdeplot(data_df, x='x1', y='x2', fill=True, ax=ax, cmap="Blues")
            ax.scatter(curr_sample[:, 0], curr_sample[:, 1], c='b', alpha=0.5, s=10)
            ax.axis([0, 10 , 0, 10])

            plt.grid()
            plt.tight_layout()
            plt.savefig(curr_path.joinpath('gensample.png'), dpi=300)
            # Close after saving so the grid sweep does not accumulate figures.
            plt.close()

In [36]:
# @title SDEdit + BoN Sampling

IS_OOD = True

# Reward distribution: a single Gaussian wrapped in GaussianMixture.
if IS_OOD:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2], [Cov_r2], [1.0])

else:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1], [Cov_r1], [1.0])

model = torch.load(f'{OUTPUTS_DIR}/model_uncond_newset.pt').to(DEVICE)
model.eval()

export_path = Path(OUTPUTS_DIR).joinpath('SDEdit_BoN_ood' if IS_OOD else 'SDEdit_BoN_id')

if not export_path.exists():
    export_path.mkdir(parents=True, exist_ok=True)


# Grid: strengths from np.arange(0.1, 0.4, 0.1); candidate counts 2..10
# (step 2) then 20..100 (step 10).
strengths = np.arange(0.1, 0.4, 0.1).tolist()

n_vals = np.arange(2, 12, 2).tolist() + np.arange(20, 110, 10).tolist()

# One fixed set of target samples, shared by every grid point.
target_samples, _, _ = gmm.sample(NUM_SAMPLES)
target_samples = torch.from_numpy(target_samples)

for strength in tqdm(strengths, total=len(strengths)):

    for n_val in tqdm(n_vals, total=len(n_vals)):

        curr_path = export_path / f'strength_{round(strength * 10)}_n_{n_val}'
        if not curr_path.exists():
            curr_path.mkdir(parents=True, exist_ok=True)

        # Noise the targets up to the level set by `strength`, then make
        # n_val copies concatenated along dim 0.
        curr_sample = copy.deepcopy(target_samples)

        curr_sample = forward_diffusion_sample(curr_sample, int(strength*noise_scheduler.num_timesteps))
        curr_sample = curr_sample.to(DEVICE)

        curr_sample = curr_sample.repeat(n_val, 1)

        timesteps = list(reversed(range(len(noise_scheduler))))
        for i, t in enumerate(timesteps):

            # Only the timesteps at or below the chosen noise level are run.
            if t > strength*noise_scheduler.num_timesteps:
                continue

            t = torch.from_numpy(np.repeat(t, n_val * NUM_SAMPLES)).long().to(DEVICE)
            with torch.no_grad():
                residual = model(curr_sample, t)

            _, curr_sample = noise_scheduler.step(residual, t[0], curr_sample)

        # Keep the full candidate pool before any selection.
        np.save(curr_path / 'gen_uncond.npy', curr_sample.detach().cpu().numpy())

        reward = gmm.pdf(curr_sample)
        reward = torch.stack(reward.chunk(n_val), dim=0)

        # Per slot, index of the candidate with the highest reward.
        select_ind = reward.max(dim=0)[1]
        curr_sample = torch.stack(curr_sample.chunk(n_val), dim=0)
        curr_sample = curr_sample.permute(1, 0, 2)
        result = copy.deepcopy(torch.stack([x[select_ind[idx]] for idx, x in enumerate(curr_sample)], dim=0)) # TODO: Make it efficient
        result = result.detach().cpu()

        np.save(curr_path / 'gensample.npy', result.numpy())

        data_df = pd.DataFrame({'x1': result[:, 0], 'x2': result[:, 1]})

        fig, ax = plt.subplots()

        sns.kdeplot(data=data_df, x='x1', y='x2', fill=True, cmap="Blues", ax=ax)
        ax.scatter(result[:, 0], result[:, 1], s=10, alpha=0.5, c='b')
        ax.axis([0, 10, 0, 10])

        plt.grid()
        plt.tight_layout()
        plt.savefig(curr_path / 'gensample.png', dpi=300)
        plt.close()

In [61]:
# @title Reward vs Divergence Plots

IS_OOD = False

if not IS_OOD:
    mu_r1 = torch.tensor([5.0, 3.0]).to('cuda')
    Cov_r1 = torch.tensor([1.0, 1.0]).to('cuda')

    gmm = GaussianMixture([mu_r1],[Cov_r1], [1.0])

    setting = 'id'

else:
    mu_r2 = torch.tensor([9.0, 3.0]).to('cuda')
    Cov_r2 = torch.tensor([2.0, 2.0]).to('cuda')

    gmm = GaussianMixture([mu_r2],[Cov_r2], [1.0])

    setting = 'ood'

export_path = Path(OUTPUTS_DIR).joinpath(f'plots_{setting}')
if not Path.exists(export_path):
    Path.mkdir(export_path, exist_ok=True, parents=True)

p_samples = np.load(f'{OUTPUTS_DIR}/uncond/gensample.npy')
p_mean = np.mean(p_samples, axis=0)
p_cov = np.cov(p_samples.T)

rew_p = gmm.pdf(torch.from_numpy(p_samples).to('cuda')).cpu().numpy()

results = {
    'win_rate': [],
    'rewards': [],
    'fid': [],
    'mmd': [],
    'kl': [],
    'method': []
}

#####################
# Sweep definitions
#####################

# Sample counts shared by the search-based methods; Best-of-N additionally has
# runs at larger N.
n_vals = np.arange(2, 12, 2).tolist()
n_vals.extend(np.arange(20, 50, 10).tolist())
bon_n_vals = n_vals + np.arange(200, 400, 50).tolist()

# Guidance scales shared by the gradient-based methods.
guidance_scales = np.arange(1, 20, 1).tolist()
guidance_scales.extend(np.arange(20, 55, 5).tolist())

block_sizes = [10, 20, 40, 80]

# SDEdit strengths; run directories are named by round(strength * 10).
strengths = np.arange(0.4, 1.0, 0.1).tolist()


#####################
# Helpers
#####################

def _append_base_row(method):
    """
    Append the anchor row for `method`.

    The row holds the mean reward of the reference samples, zero for every
    divergence and a win rate of 0.5, so each method's curve starts from the
    same point.

    Args:
      method : String
        Label stored in the 'method' column.

    Returns:
      Nothing.
    """
    results['rewards'].append(rew_p.mean())
    results['method'].append(method)
    results['fid'].append(0.0)
    results['kl'].append(0.0)
    results['mmd'].append(0.0)
    results['win_rate'].append(0.5)


def _append_run_row(method, sample_file):
    """
    Score one saved run against the reference samples and append its row.

    Args:
      method : String
        Label stored in the 'method' column.
      sample_file : String
        Path of the `.npy` file holding the generated samples (one per row).

    Returns:
      Nothing.
    """
    pi_samples = np.load(sample_file)

    # Rows containing NaN are dropped before any statistic is computed.
    valid_rows = ~np.any(np.isnan(pi_samples), axis=1)
    pi_samples_filtered = pi_samples[valid_rows]
    pi_mean = np.mean(pi_samples_filtered, axis=0)
    pi_cov = np.cov(pi_samples_filtered.T)

    rew = gmm.pdf(torch.from_numpy(pi_samples_filtered).to('cuda')).cpu().numpy()

    results['rewards'].append(rew.mean())
    results['method'].append(method)
    results['fid'].append(calculate_fid(pi_mean, pi_cov, p_mean, p_cov))
    results['kl'].append(kl_mvn(pi_mean, pi_cov, p_mean, p_cov))
    # MMD uses the NaN-filtered samples, like every other metric in this row.
    results['mmd'].append(mmd_rbf(pi_samples_filtered, p_samples))
    # Index-wise comparison against the reference rewards; the mask is applied
    # to `rew_p` as well, so both sample files must have the same number of rows.
    results['win_rate'].append((rew > rew_p[valid_rows]).astype(int).sum() / len(rew))


def _evaluate_sweep(method, run_dirs):
    """
    Append the anchor row for `method`, then one row per run directory.

    Args:
      method : String
        Label stored in the 'method' column.
      run_dirs : List of strings
        Run directories relative to `OUTPUTS_DIR`, in the order the rows
        should be appended; each must contain `gensample.npy`.

    Returns:
      Nothing.
    """
    _append_base_row(method)
    for run_dir in run_dirs:
        _append_run_row(method, f'{OUTPUTS_DIR}/{run_dir}/gensample.npy')


#####################
# Best-of-N
#####################

_evaluate_sweep('Best-of-N', [f'BoN_{setting}/n_{n}' for n in bon_n_vals])

#####################
# Gradient guidance
#####################

_evaluate_sweep('DPS (Chung, 2023)',
                [f'gg_{setting}/scale_{scale}' for scale in guidance_scales])

#####################
# Universal guidance
#####################

_evaluate_sweep('UG (Bansal, 2024)',
                [f'unvg_{setting}/scale_{scale}' for scale in guidance_scales])

#######################
# CoDe (block-wise)
#######################

for blocksize in block_sizes:
    _evaluate_sweep(f'CoDe (Ours) [block: {blocksize}]',
                    [f'CD_b{blocksize}_{setting}/n_{n}' for n in n_vals])

#######################
# SVDD-PM
#######################

_evaluate_sweep('SVDD-PM (Li, 2024)',
                [f'CD_b1_{setting}/n_{n}' for n in n_vals])

#######################
# SDEdit
#######################

# Strengths are visited from largest to smallest.
_evaluate_sweep('SDEdit (Meng, 2021)',
                [f'SDEdit_{setting}/strength_{round(strength * 10)}'
                 for strength in reversed(strengths)])

#####################
# SDEdit + Best-of-N
#####################

_evaluate_sweep('SDEdit + BoN',
                [f'SDEdit_BoN_{setting}/strength_4_n_{n}' for n in n_vals])

#####################
# SDEdit + Gradient guidance
#####################

_evaluate_sweep('SDEdit + DPS',
                [f'SDEdit_gg_{setting}/strength_4_scale_{scale}' for scale in guidance_scales])

#####################
# SDEdit + Universal guidance
#####################

_evaluate_sweep('SDEdit + Universal Guidance',
                [f'SDEdit_unvg_{setting}/strength_4_scale_{scale}' for scale in guidance_scales])

#######################
# SDEdit + CoDe
#######################

for blocksize in block_sizes:
    _evaluate_sweep(f'SDEdit + CoDe (Ours) [block: {blocksize}]',
                    [f'SDEdit_CD_b{blocksize}_{setting}/strength_4_n_{n}' for n in n_vals])

################################
# SDEdit + SVDD-PM
################################

_evaluate_sweep('SDEdit + SVDD-PM',
                [f'SDEdit_CD_b1_{setting}/strength_4_n_{n}' for n in n_vals])

results_df = pd.DataFrame(results)


In [62]:
# Methods drawn in the figure. Entries must match the labels stored in
# results_df['method']; uncomment a line to add that curve.
methods_to_plot = [
    'Best-of-N',
    'DPS (Chung, 2023)',
    'UG (Bansal, 2024)',
    # 'CoDe (Ours) [block: 10]',
    # 'CoDe (Ours) [block: 20]',
    # 'CoDe (Ours) [block: 40]',
    'CoDe (Ours) [block: 80]',
    'SVDD-PM (Li, 2024)',
    # 'SDEdit (Meng, 2021)',
    # 'SDEdit + BoN',
    # 'SDEdit + DPS',
    # 'SDEdit + Universal Guidance',
    # 'SDEdit + SVDD-PM',
    # 'SDEdit + CoDe (Ours) [block: 10]',
    # 'SDEdit + CoDe (Ours) [block: 20]',
    # 'SDEdit + CoDe (Ours) [block: 40]',
    'SDEdit + CoDe (Ours) [block: 80]',
]

# .copy() gives an independent frame, so the column assignments below never
# write through a slice of results_df.
plot_results = results_df.loc[results_df['method'].isin(methods_to_plot)].copy()

# Regex -> legend label. Raw strings keep `\+` a regex escape rather than an
# invalid string escape. The CoDe patterns are anchored so the plain pattern
# cannot fire inside an 'SDEdit + CoDe ...' label.
method_labels = {
    r'Best-of-N': 'BoN',
    r'SDEdit \+ BoN': 'C-BoN',
    r'^CoDe.*': 'CoDe (Ours)',
    r'^SDEdit \+ CoDe.*': 'C-CoDe (Ours)',
}

plot_results['method'] = plot_results['method'].replace(regex=method_labels)

# Express rewards relative to the mean reward of the reference samples.
normalize = True
if normalize:
    base_rew = rew_p.mean()
    plot_results['rewards'] = plot_results['rewards'] / base_rew

In [11]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

In [63]:
# Marker glyph for each legend label.
markers = {
    'BoN': '^',
    'CoDe (Ours)': 'o',
    'C-CoDe (Ours)': '^',
    'SVDD-PM (Li, 2024)': 'o',
    'UG (Bansal, 2024)': 's',
    'DPS (Chung, 2023)': '^',
}

# Colour for each legend label, taken from seaborn's "Paired" palette.
# The key order matters: the plotting code derives `style_order` from it.
_paired = sns.color_palette("Paired")
palette = {
    'BoN': _paired[5],
    'CoDe (Ours)': _paired[3],
    'C-CoDe (Ours)': _paired[1],
    'SVDD-PM (Li, 2024)': _paired[9],
    'UG (Bansal, 2024)': _paired[7],
    'DPS (Chung, 2023)': _paired[11],
}

# Axis / legend titles keyed by results column name.
normalize = True
labels = {
    "rewards": f"{'(Normalized)' if normalize else ''} Exp. Rewards",
    "kl": "KL Divergence",
    "fid": "Source FID",
    "cmmd": "Source CMMD",
    "ref_fid": "Reference FID",
    "ref_cmmd": "Reference CMMD",
    "win_rate": "Win Rate",
    "method": "Guidance Method",
    "clipscore": "Text Alignment",
    "clipwinrate": "Win Rate (Text Align)",
}

# Line style for each legend label.
linestyles = {
    'BoN': '--',
    'CoDe (Ours)': '-',
    'C-CoDe (Ours)': '-',
    'SVDD-PM (Li, 2024)': '-',
    'UG (Bansal, 2024)': '--',
    'DPS (Chung, 2023)': '-',
}

In [None]:
# for perf in ['rewards', 'win_rate']:
#     for div in ['fid', 'mmd', 'kl']:

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Column plotted on the y-axis (perf) and x-axis (div) of each panel.
perf = ['rewards', 'win_rate']
div = ['kl', 'kl']

fig, axes = plt.subplots(1, 2, figsize=(10,5))

# Styling shared by the main curves and their zoomed copies.
curve_style = dict(
    hue='method',
    style='method',
    linestyle='-',
    markers=markers,
    palette=palette,
    fillstyle='none',
    markeredgecolor=None,
    markeredgewidth=2,
    style_order=list(palette.keys()),
    sort=False,
)

for i, (ax, x_col, y_col) in enumerate(zip(axes.flatten(), div, perf)):

    g = sns.lineplot(plot_results, x=x_col, y=y_col, markersize=6, ax=ax, **curve_style)

    # The per-axes legend is replaced by one figure-level legend below.
    ax.get_legend().remove()

    # Data-coordinate window magnified in this panel's inset.
    x1, x2, y1, y2 = (2, 6, 0.8, 1.0) if i == 0 else (1.5, 5.5, 0.9, 1.01)

    axins = ax.inset_axes([0.55, 0.1, .4, .4], xlim=(x1, x2), ylim=(y1, y2))
    ax.indicate_inset_zoom(axins, edgecolor="black")

    # Same curves again inside the inset, without a legend.
    g = sns.lineplot(plot_results, x=x_col, y=y_col, markersize=7, ax=axins,
                     legend=False, **curve_style)

    axins.set_xlabel('')
    axins.set_ylabel('')

    ax.set_xlabel(labels[x_col], fontsize=14)
    ax.set_ylabel(labels[y_col], fontsize=14)
    ax.tick_params(axis='both', labelsize=12)

# Legend entries are read from the last panel.
handles, l_labels = ax.get_legend_handles_labels()
fig.legend(handles, l_labels, bbox_to_anchor=(0.94, 1.03), ncols=3, fontsize=14, frameon=False)
plt.savefig(export_path.joinpath(f'result_{setting}.png'), dpi=300, bbox_inches='tight')