 # Part I: Understanding KL with Mixtures

 The goal of this section is to understand practically the difference between the following two versions of the variational inference problem:

 $$ \text{argmin}_{p \in \mathcal{F}} \operatorname{KL}(p || p_{\mathcal{D}})\;,$$

 and

 $$\text{argmin}_{p \in \mathcal{F}} \operatorname{KL}(p_{\mathcal{D}} || p)\;.$$

 In this section, we consider one of the simplest possible scenarios

 $$ p_{\mathcal{D}} = \frac{1}{2}\mathcal{N}(-1, 10^{-2}) + \frac{1}{2}\mathcal{N}(+1, 10^{-2}) \;,$$

 and

 $$ \mathcal{F} = \{\mathcal{N}(\mu, \exp(2\log\sigma)) | \mu \in \mathbb{R}; \log \sigma \in \mathbb{R} \} \;. $$

 ## Forward KL:

 We consider first the case of the forward KL:

 $$\text{argmin}_{p \in \mathcal{F}} \text{KL}(p_{\mathcal{D}} || p)\;, \quad \text{where} \quad \text{KL}(p_{\mathcal{D}} || p) = \int \log\frac{p_{\mathcal{D}}(x)}{p(x)}p_{\mathcal{D}}(dx)\;. $$

In [1]:
import numpy as np

**Unzipping the dataset**

In [None]:
!unzip chemin/vers/fichier.zip -d dossier_de_destination

unzip:  cannot find or open chemin/vers/fichier.zip, chemin/vers/fichier.zip.zip or chemin/vers/fichier.zip.ZIP.


In [None]:
import torch.distributions as dist
from torch import optim
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import math

# Target density p_D: an equal-weight mixture of two Gaussians centred at -1 and +1,
# each with standard deviation 0.1.
p_data = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(
        probs=torch.full((2,), 0.5), validate_args=False
    ),
    component_distribution=dist.Normal(
        loc=torch.tensor([-1.0, 1.0]),
        scale=torch.full((2,), 0.1),
        validate_args=False,
    ),
    validate_args=False,
)
# Trainable parameters of the fitted Gaussian, stored as (mean, log standard deviation).
param = torch.tensor([0.1, math.log(0.1)], requires_grad=True)

def get_dist_from_params(params: torch.Tensor) -> dist.Distribution:
    """Build the Gaussian encoded by ``params``.

    ``params[0]`` is used as the mean and ``exp(params[1])`` as the standard
    deviation, so the second entry is a log standard deviation.
    """
    mean = params[0]
    std = torch.exp(params[1])
    return dist.Normal(mean, std, validate_args=False)

def kl_monte_carlo_estimator(p_0: dist.Distribution, p_1: dist.Distribution, n_mc=1_000) -> torch.Tensor:
    """Monte Carlo estimate of KL(p_0 || p_1).

    Draws ``n_mc`` samples from ``p_0`` with ``sample`` and averages the
    log-density ratio ``log p_0 - log p_1`` over them.
    """
    samples = p_0.sample((n_mc,))
    log_ratio = p_0.log_prob(samples) - p_1.log_prob(samples)
    return log_ratio.mean()


def plot_densities(p_0: dist.Distribution, p_1:dist.Distribution, ax, xr=(-3, 3)):
    """Plot the densities of ``p_0`` and ``p_1`` on ``ax`` and return both line handles.

    The densities are evaluated on 100 evenly spaced points spanning ``xr``.
    The returned tuple is ``(line for p_0, line for p_1)``.
    """
    grid = torch.linspace(*xr, 100)
    # Evaluate both log-densities before drawing anything.
    log_densities = [(p_0.log_prob(grid), r"$p_0$"), (p_1.log_prob(grid), r"$p_1$")]
    handles = []
    for log_density, tag in log_densities:
        (handle,) = ax.plot(grid, log_density.exp(), label=tag)
        handles.append(handle)
    ax.set_ylabel("Density")
    ax.set_xlabel("x")
    ax.legend()
    return tuple(handles)

# Draw the target density and the initial Gaussian; `lines` keeps both curve
# handles so the animation can update the fitted curve in place.
fig, ax = plt.subplots(nrows=1, ncols=1)
lines = plot_densities(p_0=p_data, p_1=get_dist_from_params(param.detach()), ax=ax)
# History of the loss values and parameter snapshots, filled in by `update`.
losses, params = [], []

# Optimization loop
n_epochs = 1000
opt = optim.Adam(params=[param], lr=5e-2)


def init():
    """Return the line artists unchanged; passed to ``FuncAnimation`` as ``init_func``."""
    return lines

def update(frame):
    """Run one optimiser step on the KL(p_data || fit) estimate and redraw the fit.

    ``frame`` is supplied by ``FuncAnimation`` and is not used. Returns the
    line artists so they can be redrawn.
    """
    opt.zero_grad()
    candidate = get_dist_from_params(param)
    loss = kl_monte_carlo_estimator(p_0=p_data, p_1=candidate)
    loss.backward()
    # Record the loss and the parameters as they were before this step is applied.
    losses.append(loss.item())
    params.append(param.detach().clone())
    opt.step()
    # Re-evaluate the updated density on the same grid plot_densities uses by default.
    grid = torch.linspace(-3, 3, 100)
    fitted = get_dist_from_params(param.detach())
    lines[1].set_ydata(fitted.log_prob(grid).exp())
    return lines

# Each animation frame calls `update` once, i.e. one optimisation step per frame.
# The frames run over torch.arange(1, 100), so 99 steps are taken
# (NOTE(review): `n_epochs` is not used to size this animation).
ani = FuncAnimation(fig, update, frames=torch.arange(1, 100),
                init_func=init, blit=True) 
# Close the figure, then display the animation as the cell's output.
plt.close()
HTML(ani.to_jshtml())

# Backward KL and more

 We now try to solve:

 $$
 \text{argmin}_{p \in \mathcal{F}} \text{KL}(p || p_{\mathcal{D}})\;.
 $$

In [None]:
# Fresh (mean, log standard deviation) parameters for the backward-KL experiment.
param = torch.tensor([0.1, math.log(0.1)], requires_grad=True)
# Draw the target density and the initial Gaussian; `lines` keeps both curve handles.
fig, ax = plt.subplots(nrows=1, ncols=1)
lines = plot_densities(p_0=p_data, p_1=get_dist_from_params(param.detach()), ax=ax)
# History of the loss values and parameter snapshots, filled in by `update`.
losses, params = [], []


def better_kl_monte_carlo_estimator(param, p_1: dist.Distribution, n_mc=1_000) -> torch.Tensor:
    """Reparameterized Monte Carlo estimate of KL(N(mu, sigma^2) || p_1).

    Args:
        param: tensor whose first entry is the mean ``mu`` and whose second
            entry is ``log(sigma)``.
        p_1: distribution the Gaussian is compared against.
        n_mc: number of Monte Carlo samples.

    Returns:
        Scalar tensor ``-E_q[log p_1(x)] - H(q)``. The cross-entropy term is
        estimated by sampling; the Gaussian entropy
        ``H(q) = log(sigma) + 0.5 * (1 + log(2 * pi))`` is used in closed form.
    """
    mu, log_sigma = param[0], param[1]
    # Reparameterization x = mu + sigma * eps keeps the samples differentiable
    # with respect to ``param``.
    mc_nodes = torch.randn((n_mc,)) * log_sigma.exp() + mu
    # Full closed-form entropy: the additive constant does not change the
    # gradients but makes the returned value an actual KL estimate.
    entropy_p_0 = log_sigma + 0.5 * (1.0 + math.log(2.0 * math.pi))
    return -p_1.log_prob(mc_nodes).mean() - entropy_p_0

# Optimization loop: `plot_every` optimisation steps are taken per animation frame.
n_epochs, plot_every = 100, 1
# weight_decay is 0, so no decay is applied to the parameters.
opt = optim.AdamW(params=[param], lr=5e-2, betas=(0.7, 0.999), weight_decay=0)


def init():
    """Return the line artists unchanged; passed to ``FuncAnimation`` as ``init_func``."""
    return lines

def update(frame):
    """Run ``plot_every`` optimiser steps on the KL(fit || p_data) estimate, then redraw.

    ``frame`` is supplied by ``FuncAnimation`` and is not used. Returns the
    line artists so they can be redrawn.
    """
    for _ in range(plot_every):
        opt.zero_grad()
        loss = better_kl_monte_carlo_estimator(param=param, p_1=p_data, n_mc=10_000)
        loss.backward()
        # Record the loss and the parameters as they were before this step is applied.
        losses.append(loss.item())
        params.append(param.detach().clone())
        opt.step()
    # Re-evaluate the updated density on the same grid plot_densities uses by default.
    grid = torch.linspace(-3, 3, 100)
    fitted = get_dist_from_params(param.detach())
    lines[1].set_ydata(fitted.log_prob(grid).exp())
    return lines

# Each animation frame calls `update` once; the frames run over
# torch.arange(1, n_epochs // plot_every).
ani = FuncAnimation(fig, update, frames=torch.arange(1, n_epochs // plot_every),
                init_func=init, blit=True) 
# Close the figure, then display the animation as the cell's output.
plt.close()
HTML(ani.to_jshtml())

 # Part II: Evaluating generative models from samples.

 The goal of this notebook is to evaluate the generative capabilities of three different generated datasets, namely "generated_data_1", "generated_data_2" and "generated_data_3". The task behind each dataset is to generate

 data that emulates samples from the distribution underlying the "real_images" dataset.



In [None]:
# This cell implement imports and define some functions that are useful.
from datasets import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import platform
from ot.sliced import max_sliced_wasserstein_distance, sliced_wasserstein_distance
from tqdm import tqdm
import numpy as np
from torchvision.models import inception_v3
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import time
import pandas as pd
from scipy import linalg


def get_device(cuda_index: int = 1):
    """Pick the torch device used for computation and print the choice.

    Preference order: CUDA, then Apple's Metal backend (MPS), then CPU.

    Args:
        cuda_index: index of the CUDA device to prefer. The default keeps the
            notebook's original choice of the second GPU; when that index does
            not exist on the machine, device 0 is used instead.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        # Fall back to the first GPU when the requested one does not exist
        # (e.g. a single-GPU machine).
        if cuda_index >= torch.cuda.device_count():
            cuda_index = 0
        device = torch.device(f"cuda:{cuda_index}")
        print(f"Using {device} for computation.")
        return device

    # platform.system() reports "Darwin" on macOS, so testing it for "Apple"
    # never matched; ask torch directly whether the MPS backend is usable.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")  # Use Apple Metal Performance Shaders
        print(f"Using {device} for computation.")
        return device

    device = torch.device("cpu")
    print("Using CPU for computation.")
    return device


# Select the computation device once; later cells move tensors and the network onto it.
device = get_device()


def plot_image_from_tensor(ax, img):
    """Show ``img`` on ``ax`` with the axis decorations hidden, and return ``ax``.

    The tensor's first dimension is moved last before display and the values
    are mapped through ``(v + 1) / 2``.
    NOTE(review): presumably ``img`` is a (channels, height, width) image
    scaled to [-1, 1] - confirm against the loading transform.
    """
    channels_last = img.permute(1, 2, 0)
    rescaled = (channels_last + 1) / 2
    ax.imshow(rescaled)
    ax.set_axis_off()
    return ax

In [None]:
# `datasets` maps each dataset name to the folder where that dataset lives
# (after unpacking) and to a slot that will later hold its dataloader.
datasets = {
    name: {"path": folder, "dataloader": None}
    for name, folder in [
        ("real", "real_images"),
        ("gen1", "generated_data_1"),
        ("gen2", "generated_data_2"),
        ("gen3", "generated_data_3"),
    ]
}
# Preprocessing applied to every image: conversion to a tensor, resizing to
# 64x64, then per-channel normalization with mean 0.5 and standard deviation 0.5.
resize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

In [None]:
# This cell is effectively where we load the data. We create a dataloader class which will randomly read files from the disk and return batches of BATCH_SIZE, which is a variable you should choose depending on the configuration you
# are using.
def f(x):
    """Batch transform: convert every image in ``x["image"]`` to RGB and apply ``resize``."""
    converted = []
    for image in x["image"]:
        converted.append(resize(image.convert("RGB")))
    return {"image": converted}
BATCH_SIZE = 8
for name, infos in datasets.items():
    ds = Dataset.load_from_disk(infos["path"]).with_format("torch")
    # Images are preprocessed by `f` when they are accessed.
    ds.set_transform(f)
    # Store a shuffling dataloader over this dataset in its `datasets` entry.
    infos["dataloader"] = DataLoader(
        dataset=ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
    )

In [None]:
# Show one image per dataset. The dataloaders shuffle, so each run of this
# cell draws a different set.
fig, axes = plt.subplots(1, 4)
for ax, (label, info) in zip(axes, datasets.items()):
    # Only the first batch is needed; keep its first image.
    batch = next(iter(info["dataloader"]))
    img = batch["image"][0]
    ax = plot_image_from_tensor(ax, img)
    ax.set_title(label)
fig.tight_layout()
fig.show()

 ## Sliced Wasserstein(s) in ambient space.

 We are going to use both max sliced wasserstein and mean sliced wasserstein to evaluate the generative models.

 For a finite number of samples, the empirical estimator is biased. The first thing one can do is to evaluate the bias and variance for different sample sizes and slices for a given dataset.

 ### Implement a function that evaluates the bias and variance of the empirical estimator for different numbers of slices and samples, and find a configuration you are comfortable with.

In [None]:
# Sweep over sample sizes (2**6 ... 2**13) and numbers of slices (10, 100, 1000).
# For every configuration, draw two independent sets of real images NREP times
# and record the value and run time of both sliced-Wasserstein estimators.
NREP = 5
nlog2range = range(6, 14)
sliceslog10range = range(1, 4)
wdist = {}
for nlog2 in nlog2range:
    nsamples = 2**nlog2
    # Ceiling division: the smallest number of batches holding at least
    # `nsamples` images. (`n // B + n % B` over-counts whenever BATCH_SIZE
    # does not divide `nsamples`.)
    n_batches = -(-nsamples // BATCH_SIZE)
    wdist[nsamples] = {}
    for sliceslog10 in sliceslog10range:
        nslices = 10**sliceslog10
        wdist[nsamples][nslices] = {"mean": [], "max": []}
        for _ in tqdm(range(NREP), desc=f"N samples: {nsamples}, N slices: {nslices}"):
            # Two passes over the shuffling dataloader; the slice trims any
            # surplus images from the last batch.
            sampled_datas = {
                j: torch.concatenate(
                    [
                        batch["image"]
                        for _, batch in zip(
                            range(n_batches),
                            datasets["real"]["dataloader"],
                        )
                    ]
                )[:nsamples]
                for j in range(2)
            }
            for kind, distance_fn in (
                ("mean", sliced_wasserstein_distance),
                ("max", max_sliced_wasserstein_distance),
            ):
                # As before, the timed region covers flattening, the transfer
                # to `device` and the distance computation itself.
                start = time.perf_counter()
                value = distance_fn(
                    sampled_datas[0].flatten(1, -1).to(device),
                    sampled_datas[1].flatten(1, -1).to(device),
                    n_projections=nslices,
                ).item()
                run_time = time.perf_counter() - start
                wdist[nsamples][nslices][kind].append(
                    {"value": value, "run_time": run_time}
                )

In [None]:
# Estimator value (mean +/- std over repetitions) against the number of
# samples, with one series per estimator type and number of slices.
fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(5, 5))
# One x position per sample size, in increasing order. Iterating over every
# (nsamples, nslices) pair here would plot each point several times.
sample_sizes = sorted(wdist)
for t, fmt in zip(["max", "mean"], ["*", "o"]):
    for n_slices in sorted(10**s for s in sliceslog10range):
        values_per_size = [
            [it["value"] for it in wdist[n][n_slices][t]] for n in sample_sizes
        ]
        mean_values = [np.mean(values) for values in values_per_size]
        std_values = [np.std(values) for values in values_per_size]
        ax.errorbar(sample_sizes, mean_values, yerr=std_values, fmt=fmt, label=f"{t} {n_slices}")
ax.set_yscale("log")
# The x-axis holds sample counts; the number of slices is in the legend.
ax.set_xlabel("N samples")
ax.set_ylabel("Value")
ax.legend()
# The x-axis is left linear; ax.set_xscale("log") is an option since the
# sample sizes are powers of two.
fig.show()

In [None]:
# Estimator value against run time, one point per (nsamples, nslices)
# configuration, annotated with that configuration.
fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(5, 5))
identifiers = {
    (nsamples, nslices)
    for nsamples, wdist_per_samples in wdist.items()
    for nslices in wdist_per_samples.keys()
}
for t in ["max", "mean"]:
    mean_values, std_values, mean_rt, std_rt = [], [], [], []
    for n_samp, n_sl in identifiers:
        runs = wdist[n_samp][n_sl][t]
        run_values = [run["value"] for run in runs]
        run_times = [run["run_time"] for run in runs]
        mean_values.append(np.mean(run_values))
        std_values.append(np.std(run_values))
        mean_rt.append(np.mean(run_times))
        std_rt.append(np.std(run_times))
    ax.errorbar(mean_rt, mean_values, xerr=std_rt, yerr=std_values, fmt="o", label=t)
    # Label every point with its sample count and number of slices.
    for i, x, y in zip(identifiers, mean_rt, mean_values):
        ax.annotate(f"{i[0]}\n{i[1]}", (x, y))
ax.set_xlim([0.01, 0.2])
ax.set_yscale("log")
ax.set_xlabel("Run time")
ax.set_ylabel("Value")
ax.legend()
ax.set_xscale("log")
fig.show()

 ### Evaluate, for the chosen configuration, the sliced Wasserstein distances between the real data and each generative model.

In [None]:
# Sliced Wasserstein distances in pixel space between the real images and each
# dataset (including a second, independent draw of the real images as a
# reference), repeated NREP times.
N_SAMPLES = 4096
N_SLICES = 1000
NREP = 5
# Ceiling division: the smallest number of batches holding at least N_SAMPLES
# images. (`n // B + n % B` over-counts whenever BATCH_SIZE does not divide
# N_SAMPLES.)
n_batches = -(-N_SAMPLES // BATCH_SIZE)
wdist2real = {k: {"mean": [], "max": []} for k in datasets.keys()}
for label in wdist2real.keys():
    for _ in tqdm(range(NREP), desc=label):
        # Draw the real images first, then the compared dataset. Each draw is
        # trimmed to N_SAMPLES, flattened to one row per image, and moved to
        # `device` once so both estimators reuse the same tensors.
        real_samples, other_samples = [
            torch.concatenate(
                [
                    batch["image"]
                    for _, batch in zip(
                        range(n_batches),
                        datasets[name]["dataloader"],
                    )
                ]
            )[:N_SAMPLES]
            .flatten(1, -1)
            .to(device)
            for name in ("real", label)
        ]
        wdist2real[label]["mean"].append(
            sliced_wasserstein_distance(
                real_samples,
                other_samples,
                n_projections=N_SLICES,
            ).item()
        )
        wdist2real[label]["max"].append(
            max_sliced_wasserstein_distance(
                real_samples,
                other_samples,
                n_projections=N_SLICES,
            ).item()
        )

In [None]:
# Summary table: for every dataset and estimator, "mean std" over the repetitions.
pd.DataFrame.from_records(
    {
        dataset_name: {
            metric: f"{np.mean(runs):.4f} {np.std(runs):.4f}"
            for metric, runs in per_metric.items()
        }
        for dataset_name, per_metric in wdist2real.items()
    }
)

 ## Evaluating Wasserstein using embedding

 It is often the case that one does not care about the exact distribution in the ambient space (pixel space in our case),

 but rather about the similarity between distributions in some embedding space. For images, this embedding space is often given by the last layers

 of some classification network used for a relevant task such as classifying the objects in the photo.

 We are going to do just that: we are going to use the last layer before the final classification layer of the Inception v3 classifier as our embedding space, and

 we are going to evaluate whether, in this embedding space, our generative models resemble the real data.

In [None]:
# This defines the preprocessing and loads the network used as embedding.
# The first Normalize computes (x + 1) / 2, undoing the mean-0.5 / std-0.5
# normalization applied when the images were loaded; the second applies the
# per-channel statistics the pretrained network was trained with.
full_transform = transforms.Compose(
    [
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# `pretrained=True` is deprecated in torchvision; `weights="IMAGENET1K_V1"`
# selects the same checkpoint through the current API.
inception_net = inception_v3(weights="IMAGENET1K_V1", transform_input=False)
inception_net.requires_grad_(False)
inception_net.eval()
inception_net = inception_net.to(device)
train_nodes, eval_nodes = get_graph_node_names(inception_net)

# Keep only the second-to-last node of the eval graph, i.e. drop the final layer.
# NOTE(review): this is expected to be the "flatten" node read by get_feature.
return_nodes = eval_nodes[-2:-1]

# Create a feature extractor that returns only that node.
feat_inception = create_feature_extractor(inception_net, return_nodes=return_nodes)

feat_inception = torch.compile(feat_inception).to(device)

def get_feature(x: torch.Tensor) -> torch.Tensor:
    """Return the embedding of a batch of images.

    The batch is moved to ``device``, preprocessed with ``full_transform`` and
    passed through ``feat_inception``; the ``"flatten"`` output is returned
    with every dimension after the first flattened.
    """
    on_device = x.to(device)
    preprocessed = full_transform(on_device)
    outputs = feat_inception(preprocessed)
    return outputs["flatten"].flatten(1, -1)

In [None]:
# Sliced Wasserstein distances in the embedding space between the real images
# and each dataset (including a second, independent draw of the real images as
# a reference), repeated NREP times.
N_SAMPLES = 4096
N_SLICES = 1000
NREP = 5
# Ceiling division: the smallest number of batches holding at least N_SAMPLES
# images. (`n // B + n % B` over-counts whenever BATCH_SIZE does not divide
# N_SAMPLES.)
n_batches = -(-N_SAMPLES // BATCH_SIZE)
wdist2real = {k: {"mean": [], "max": []} for k in datasets.keys()}
for label in wdist2real.keys():
    for _ in tqdm(range(NREP), desc=label):
        # Embed the real images first, then the compared dataset. Each set of
        # embeddings is trimmed to N_SAMPLES and prepared once so both
        # estimators reuse the same tensors.
        real_features, other_features = [
            torch.concatenate(
                [
                    get_feature(batch["image"])
                    for _, batch in zip(
                        range(n_batches),
                        datasets[name]["dataloader"],
                    )
                ]
            )[:N_SAMPLES]
            .flatten(1, -1)
            .to(device)
            for name in ("real", label)
        ]
        wdist2real[label]["mean"].append(
            sliced_wasserstein_distance(
                real_features,
                other_features,
                n_projections=N_SLICES,
            ).item()
        )
        wdist2real[label]["max"].append(
            max_sliced_wasserstein_distance(
                real_features,
                other_features,
                n_projections=N_SLICES,
            ).item()
        )

In [None]:
# Summary table: for every dataset and estimator, "mean std" over the repetitions.
pd.DataFrame.from_records(
    {
        dataset_name: {
            metric: f"{np.mean(runs):.4f} {np.std(runs):.4f}"
            for metric, runs in per_metric.items()
        }
        for dataset_name, per_metric in wdist2real.items()
    }
)

 One of the most famous metrics is the FID (Fréchet Inception Distance). It is obtained by first calculating the mean and covariance matrix

 of the embeddings for each set of samples, and then calculating the Wasserstein-2 distance between the two corresponding Gaussian distributions, which is available in closed form.