# Synthetic data

In the final part of this lab, we are going to generate synthetic image data. You can think of synthetic data as an extension of augmentation, with the exception that we want to generate completely new samples.

Why is synthetic data desirable? There are at least two reasons:

1. More high-quality data helps us train better models.
2. Using synthetic data in lieu of "natural data" helps us avoid privacy issues.

Generating synthetic data is a _generative task_. We thus need a model that learns an approximation to the "true" distribution that we wish to sample from (e.g. all cat images). There are many possible model architectures for this, e.g.:

- Generative Adversarial Networks
- (Variational) Autoencoders
- Normalizing flows
- Diffusion models
- ... and probably many more

Note that one can also generate data in other ways. For instance, a popular approach is to use 3D rendering software like Blender.

Here, we're using a diffusion model. You have already seen this type of model in a previous lab, so you should be able to complete this lab in a breeze. ;) Synthetic data from diffusion models has also [been found to improve classification performance](https://arxiv.org/abs/2304.08466).

## Generating synthetic samples

For this demonstration, we will generate synthetic snacks from the [`snacks` dataset](https://huggingface.co/datasets/Matthijs/snacks).
This is a dataset of 20 different types of snack foods that accompanies the book [Machine Learning by Tutorials](https://www.raywenderlich.com/books/machine-learning-by-tutorials/v2.0). Let us begin by loading the ground truth data. We do this using the [`datasets`](https://huggingface.co/docs/datasets/index) library, which is like the `transformers` library we saw in a previous lab, but for datasets.

In [None]:
from typing import Union

from datasets import (
  Dataset,
  DatasetDict,
  IterableDataset,
  IterableDatasetDict,
  load_dataset,
)

# You can ignore this, it's just to make the type checker happy.
DatasetType = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]

dataset: DatasetType = load_dataset("Matthijs/snacks")

Datasets behave much like the data structures you know from e.g. `torch` or `pandas`. We can access the `train` split by indexing the dataset like a dictionary:

In [None]:
dataset["train"]

To access a sample, just use the sample index:

In [None]:
dataset["train"][42]

Each sample is just a dictionary, so we can access its image by key:

In [None]:
dataset["train"][42]["image"]

For your convenience, we provide a dictionary that maps the numeric class ids to the class names.

In [None]:
LABEL_MAP = {
    0: "apple",
    1: "banana",
    2: "cake",
    3: "candy",
    4: "carrot",
    5: "cookie",
    6: "doughnut",
    7: "grape",
    8: "hot dog",
    9: "ice cream",
    10: "juice",
    11: "muffin",
    12: "orange",
    13: "pineapple",
    14: "popcorn",
    15: "pretzel",
    16: "salad",
    17: "strawberry",
    18: "waffle",
    19: "watermelon",
}

Next, we'll use a diffusion model from the Hugging Face Hub. Much like the `transformers` library implements various transformer architectures, the `diffusers` library implements many state-of-the-art diffusion models. We will use `Stable Diffusion XL Turbo`, which is a distilled version of Stable Diffusion that can produce high-quality images in as little as one step. If you are blessed with beefy hardware, you can of course also use a different model.

In [None]:
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

# Use this if you are on a Mac:
# pipe.to("mps")

# Use this if you have an NVIDIA GPU:
pipe = pipe.to("cuda")

# Else, use this:
# pipe = pipe.to("cpu")

# Compilation accelerates inference. The first inference run will be very slow though!
# Comment this line out if you are using a Mac.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

Now, it is very easy to generate samples. Simply prompt the model as shown below. In Colab, the first run will take around three and a half minutes because the model is compiled on first use. Subsequent runs take around a second.

In [None]:
prompt = "A cinematic shot of a baby raccoon wearing an intricate Italian priest robe."
image = pipe(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image

### Generating some snacks

How do we go from cute raccoons to images of snacks? Well, it's just a matter of the prompt!

In [None]:
prompt = "A basket of snacks."

image = pipe(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image

Let's generate a bit more data. Simply repeat the prompt a few times!

In [None]:
synth_samples = []
num_samples = 10  # Increase as you like. More samples will also make later computations more accurate.
for i in range(num_samples):
  synth_samples.append(pipe(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0])


Now that we have a few samples, let us compare them against the "true" data.

## Evaluating the quality of synthetic datasets

It is vital to know the strengths and weaknesses of your synthetic dataset. In essence, we are interested in four aspects of the generated data:

- Realism: Is the synthetic data indistinguishable from real data?
- Representation: How well represented is the real data among the synthetic samples?
- Variety: How much variety is there among the synthetic samples?
- Novelty: How novel are the synthetic samples?

You can, of course, get an idea by manually looking through the generated samples one by one and comparing them to ground truth data. This is known as _qualitative evaluation_. It is laborious and not particularly systematic, unless you employ a group of highly trained individuals. Qualitative evaluation is common in assessments of new generative models, but has its limits when it comes to identifying the nuances in synthetic datasets.

On the other end of the evaluation spectrum, there's _quantitative evaluation_. There exist various metrics we can compute to compare two statistical distributions. However, it can be quite challenging to apply these methods to image data, given its high dimensionality. Classical tests like the [Kolmogorov-Smirnov test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test) are designed for low-dimensional data and become impractical here: they suffer from the curse of dimensionality.

To combat this, we use two tricks.

1. We use image embeddings.
2. We use the maximum mean discrepancy (MMD), a method that is better suited for high dimensions.

### Excursion: MMD

The Maximum Mean Discrepancy (MMD) is a measure of the discrepancy between two probability distributions. It quantifies the difference between the means of two distributions in a reproducing kernel Hilbert space (RKHS). MMD is often used as a non-parametric test statistic for assessing whether two sets of samples come from the same distribution.

Given two sets of samples $X = \{x_1, x_2, ..., x_n\}$ and $Y = \{y_1, y_2, ..., y_m\}$, and a reproducing kernel $k(\cdot, \cdot)$, the MMD between the distributions $P_X$ and $P_Y$ they represent is defined as:

$ MMD^2(X, Y) = \frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j=1, j \neq i}^{n} k(x_i, x_j) - \frac{2}{nm} \sum_{i=1}^{n} \sum_{j=1}^{m} k(x_i, y_j) + \frac{1}{m(m-1)} \sum_{i=1}^{m} \sum_{j=1, j \neq i}^{m} k(y_i, y_j) $

where:
- $k(\cdot, \cdot)$ is a positive definite kernel function.
- $n$ and $m$ are the number of samples in sets $X$ and $Y$, respectively.

The MMD measures the difference between the empirical kernel mean embeddings of the two distributions in the RKHS induced by the kernel function $k(\cdot, \cdot)$. A smaller MMD indicates that the two distributions are more similar, while a larger MMD suggests greater dissimilarity between them.
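
To make the formula concrete, here is a minimal pure-Python sketch of the estimator above with a Gaussian (RBF) kernel. The helper names `rbf_kernel` and `mmd2_unbiased` and the toy 2-D points are made up for illustration; they are not part of the lab's code.

```python
import math


def rbf_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel between two points given as sequences of floats."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2 * sigma**2))


def mmd2_unbiased(X, Y, sigma=1.0):
    """MMD^2 estimate that skips the diagonal terms (i == j), as in the formula."""
    n, m = len(X), len(Y)
    # Within-set terms: all ordered pairs with i != j
    k_xx = sum(rbf_kernel(X[i], X[j], sigma) for i in range(n) for j in range(n) if i != j)
    k_yy = sum(rbf_kernel(Y[i], Y[j], sigma) for i in range(m) for j in range(m) if i != j)
    # Cross term: all pairs (x_i, y_j)
    k_xy = sum(rbf_kernel(x, y, sigma) for x in X for y in Y)
    return k_xx / (n * (n - 1)) - 2 * k_xy / (n * m) + k_yy / (m * (m - 1))


# Toy data: Y_near lies close to X, Y_far is shifted far away.
X = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
Y_near = [(0.1, 0.0), (1.0, 0.1), (0.0, 0.9)]
Y_far = [(5.0, 5.0), (6.0, 5.0), (5.0, 6.0)]

near = mmd2_unbiased(X, Y_near)
far = mmd2_unbiased(X, Y_far)
print(f"MMD^2 (near): {near:.4f}")
print(f"MMD^2 (far):  {far:.4f}")
```

Because the diagonal terms are left out, this estimate can come out slightly negative when the two sample sets are very similar. That is expected and should be read as "close to zero".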

### Comparing distributions

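In the cells below, we have implemented MMD for you. Take some time to read through the implementation and understand how it relates to the formula presented above. Note that, for simplicity, the code averages over all entries of the kernel matrices, including the diagonal terms $i = j$ that the formula excludes. This is the biased variant of the estimator.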

In [None]:
import torch
import torch.nn.functional as F


def compute_mmd(X, Y, sigma=1.0):
    """
    Compute the squared Maximum Mean Discrepancy (MMD^2) between two sets of samples X and Y.

    Parameters:
        X (torch.Tensor): Samples from distribution P, shape (n, num_features).
        Y (torch.Tensor): Samples from distribution Q, shape (m, num_features).
        sigma (float): Bandwidth (standard deviation) of the Gaussian (RBF) kernel.

    Returns:
        torch.Tensor: Scalar tensor holding the (biased) estimate of MMD^2 between P and Q.
    """
    # Calculate the kernel matrices
    xx, yy, xy = compute_kernel_matrix(X, Y, sigma=sigma)

    # Estimate MMD^2 from the kernel means. Averaging over all entries keeps the
    # diagonal terms k(x_i, x_i), so this is the biased variant of the estimator.
    mmd2 = torch.mean(xx) - 2 * torch.mean(xy) + torch.mean(yy)

    return mmd2


def compute_kernel_matrix(X, Y, sigma=1.0):
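    """
    Compute the Gaussian (RBF) kernel matrices within and between two sets of samples.

    Parameters:
        X (torch.Tensor): Samples from distribution P, shape (n, num_features).
        Y (torch.Tensor): Samples from distribution Q, shape (m, num_features).
        sigma (float): Bandwidth (standard deviation) of the Gaussian kernel.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The kernel matrices
        k(X, X) of shape (n, n), k(Y, Y) of shape (m, m), and k(X, Y) of shape (n, m).
    """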
    # Gaussian (RBF) kernel
    xx = torch.exp(
        -torch.pow(torch.norm(X.unsqueeze(1) - X.unsqueeze(0), dim=2), 2)
        / (2 * sigma**2)
    )
    yy = torch.exp(
        -torch.pow(torch.norm(Y.unsqueeze(1) - Y.unsqueeze(0), dim=2), 2)
        / (2 * sigma**2)
    )
    xy = torch.exp(
        -torch.pow(torch.norm(X.unsqueeze(1) - Y.unsqueeze(0), dim=2), 2)
        / (2 * sigma**2)
    )

    return xx, yy, xy

Next, we need a model that can produce image embeddings for us. For this, we use a [tiny vision transformer from the Hugging Face Hub](https://huggingface.co/timm/tiny_vit_5m_224.dist_in22k).

In [None]:
import timm

model = timm.create_model(
    "tiny_vit_5m_224.dist_in22k",
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)


@torch.no_grad()  # inference only, so skip gradient tracking
def embed(image):
    output = model(transforms(image).unsqueeze(0))
    # output is a (1, num_features) shaped tensor (batch size 1)
    return output

All that is left to do is to take a few true samples and compare them with the generated ones.

In [None]:
import random

sample_idx = random.sample(range(dataset["train"].num_rows), num_samples)
sample_idx

In [None]:
true_samples = dataset["train"][sample_idx]["image"]

In [None]:
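# embed() returns one (1, num_features) row per image, so concatenate along the
# batch dimension to get the (num_samples, num_features) tensors compute_mmd expects.
synth_samples_embeddings = torch.cat([embed(s) for s in synth_samples])
true_samples_embeddings = torch.cat([embed(s) for s in true_samples])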

Finally, let's compute the MMD to see how far the synthetic samples are from the true samples.

In [None]:
compute_mmd(synth_samples_embeddings, true_samples_embeddings)

Now, what does this number tell us? On its own, it is not particularly informative. However, you can use it to guide your data generation process. Above, we generated samples using a prompt that does not use any knowledge about the dataset. But you can change this!
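
One way to give the number some context is to compare it against a baseline: the MMD between two disjoint batches of real images tells you roughly what "same distribution" looks like at this sample size. The sketch below illustrates the idea with random Gaussian vectors standing in for embeddings. The function `mmd2_biased` mirrors the averaging used in `compute_mmd`; all names, dimensions, and numbers here are illustrative assumptions.

```python
import math
import random


def rbf(x, y, sigma=1.0):
    """Gaussian (RBF) kernel between two vectors given as lists of floats."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2 * sigma**2))


def mean_kernel(A, B, sigma=1.0):
    """Average kernel value over all pairs (a, b), diagonal included."""
    return sum(rbf(a, b, sigma) for a in A for b in B) / (len(A) * len(B))


def mmd2_biased(X, Y, sigma=1.0):
    """Biased MMD^2 estimate: mean(k_xx) - 2 * mean(k_xy) + mean(k_yy)."""
    return mean_kernel(X, X, sigma) - 2 * mean_kernel(X, Y, sigma) + mean_kernel(Y, Y, sigma)


def fake_embeddings(n, dim, mean, rng):
    """Stand-in for image embeddings: n Gaussian vectors centred at `mean`."""
    return [[rng.gauss(mean, 1.0) for _ in range(dim)] for _ in range(n)]


rng = random.Random(0)
real_a = fake_embeddings(30, 4, 0.0, rng)     # first batch of "real" data
real_b = fake_embeddings(30, 4, 0.0, rng)     # second, disjoint batch of "real" data
synthetic = fake_embeddings(30, 4, 3.0, rng)  # "synthetic" data from a shifted distribution

baseline = mmd2_biased(real_a, real_b)      # real vs. real
candidate = mmd2_biased(real_a, synthetic)  # real vs. synthetic
print(f"baseline:  {baseline:.4f}")
print(f"candidate: {candidate:.4f}")
```

In the lab, you would replace the stand-ins with embeddings of two disjoint batches from `dataset["train"]` and with `synth_samples_embeddings`. A real-vs-synthetic value close to the real-vs-real baseline is about as good as you can expect.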

For the remainder of this lab, experiment with prompts to make the MMD as small as possible.
For instance:

- Use your knowledge of the classes to generate images that contain these objects.
- Add context to your prompt. Instead of writing "An image of an apple", try things like "An apple on a green table with a cat in the background and silver cutlery next to it."

You can automate the generation of prompts by creating a few templates and then combining them at random! Be creative, maybe a large-scale synthetic image generation pipeline could be a candidate for your MLOps project. :)
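
A minimal sketch of such a template-based prompt generator could look like the following. The templates, settings, and lighting fragments are invented for illustration, and `SNACKS` is just a subset of the class names in `LABEL_MAP`, repeated here so the example runs on its own.

```python
import random

# Subset of the class names from LABEL_MAP (repeated to keep the example self-contained)
SNACKS = ["apple", "banana", "cake", "hot dog", "ice cream", "pretzel"]

# Hypothetical templates and context fragments; swap in your own.
TEMPLATES = [
    "A photo of {snack} {setting}.",
    "A close-up shot of {snack} {setting}, {lighting}.",
    "An amateur smartphone picture of {snack} {setting}.",
]
SETTINGS = [
    "on a wooden kitchen table",
    "on a picnic blanket in a park",
    "held in a hand in front of a market stall",
]
LIGHTING = ["in bright daylight", "under warm indoor lighting"]


def with_article(noun):
    """Prefix a noun with 'a' or 'an' based on its first letter."""
    article = "an" if noun[0] in "aeiou" else "a"
    return f"{article} {noun}"


def random_prompt(rng):
    """Fill a randomly chosen template with randomly chosen fragments."""
    template = rng.choice(TEMPLATES)
    # str.format ignores keyword arguments that a template does not use.
    return template.format(
        snack=with_article(rng.choice(SNACKS)),
        setting=rng.choice(SETTINGS),
        lighting=rng.choice(LIGHTING),
    )


rng = random.Random(42)  # seeded so the prompts are reproducible
prompts = [random_prompt(rng) for _ in range(5)]
for p in prompts:
    print(p)
```

Each generated prompt can then be passed to `pipe(prompt=..., guidance_scale=0.0, num_inference_steps=1)` exactly as in the cells above.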