<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 03/05/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Evaluation (Metrics)**

In [None]:
import sys
sys.path.append('./..')

import os

import torch
from torch import Tensor
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms
import numpy as np

from image_gen import GenerativeModel
from image_gen.samplers import ExponentialIntegrator
from image_gen.diffusion import VariancePreserving
from image_gen.noise import LinearNoiseSchedule
from image_gen.metrics import BaseMetric, BitsPerDimension, FrechetInceptionDistance, InceptionScore

from typing import Dict, List, Optional, Union

import matplotlib.pyplot as plt
from image_gen.visualization import display_images

In [None]:
# Experiment configuration shared by the cells below.
epoch_values = [1, 5, 25, 100]  # training lengths to compare; also the x-axis of the score plots
train_percent = 0.8             # fraction of the filtered data used for training; the rest is the test split
digit, class_id = 3, 1          # MNIST digit and CIFAR-10 class index that are modelled
seed = 1234                     # seed for reproducibility (passed to model.generate)

In [None]:
# MNIST: keep only the selected digit. ToTensor yields pixels in [0, 1].
data_mnist = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
indices_digit = torch.where(data_mnist.targets == digit)[0]
data_mnist = Subset(data_mnist, indices_digit)
train_size_mnist = int(train_percent * len(data_mnist))
test_size_mnist = len(data_mnist) - train_size_mnist
# The split must be reproducible: trained models are cached on disk and reloaded in later
# runs, so an unseeded split would yield a different test set each run and the cached models
# would be scored on images they were trained on. Checkpoints saved before the split was
# seeded were trained on some other split and should be deleted and retrained.
data_mnist_train, data_mnist_test = torch.utils.data.random_split(
    data_mnist,
    [train_size_mnist, test_size_mnist],
    generator=torch.Generator().manual_seed(seed),
)

# CIFAR-10: ToTensor gives [0, 1]; Normalize with mean 0.5 / std 0.5 then maps each channel
# to [-1, 1] (unlike the MNIST data above, which stays in [0, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
data_cifar = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform
)
# CIFAR10.targets is a plain list, so convert it before the element-wise comparison.
indices_class_id = torch.where(torch.tensor(data_cifar.targets) == class_id)[0]
data_cifar = Subset(data_cifar, indices_class_id)
train_size_cifar = int(train_percent * len(data_cifar))
test_size_cifar = len(data_cifar) - train_size_cifar
# Seeded for the same reason as the MNIST split above.
data_cifar_train, data_cifar_test = torch.utils.data.random_split(
    data_cifar,
    [train_size_cifar, test_size_cifar],
    generator=torch.Generator().manual_seed(seed),
)

In [None]:
trained_models_mnist = []
trained_models_cifar = []


def _load_or_train(dataset, filename: str, epochs: int) -> GenerativeModel:
    """Return a VP / linear-schedule / exponential-integrator model for *filename*.

    If the checkpoint exists it is loaded; otherwise the model is trained on
    *dataset* for *epochs* epochs and saved to *filename*.

    The "_train" part of the checkpoint names used below avoids loading models that
    were trained on the full dataset (including the test data).
    NOTE: the file name does not encode which train/test split was used. If the
    split changes, delete the cached checkpoints so models are not evaluated on
    data they were trained on.
    """
    model = GenerativeModel(
        diffusion=VariancePreserving,
        noise_schedule=LinearNoiseSchedule,
        sampler=ExponentialIntegrator,
    )
    if os.path.isfile(filename):
        model.load(filename)
    else:
        # Create the checkpoint directory before the (long) training run, in case
        # model.save does not create missing parent directories itself.
        os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
        model.train(dataset, epochs=epochs)
        model.save(filename)
    return model


# Same order as the original loop: for each epoch budget, the MNIST model first, then the CIFAR one.
for epochs in epoch_values:
    trained_models_mnist.append(
        _load_or_train(data_mnist_train, f'saved_models/mnist_{digit}_train_vp-lin_{epochs}e.pth', epochs)
    )
    trained_models_cifar.append(
        _load_or_train(data_cifar_train, f'saved_models/cifar10_{class_id}_train_vp-lin_{epochs}e.pth', epochs)
    )

In [None]:
def plot_scores(scores: Dict[str, List[float]], metrics: List[BaseMetric], title: Optional[str] = None,
                epochs: Optional[List[int]] = None):
    """Plot each metric's score against the number of training epochs.

    One subplot is drawn per metric, with a log-scaled x-axis, and the best point
    is highlighted in red.

    Args:
        scores: Maps a metric's ``name`` to one score per trained model, in the
            same order as ``epochs``.
        metrics: Metrics to plot. ``name`` selects the series in ``scores`` and
            ``is_lower_better`` decides which point counts as the best.
        title: Optional figure title.
        epochs: x-axis values. Defaults to the notebook-level ``epoch_values``.
    """
    if epochs is None:
        epochs = epoch_values

    # squeeze=False always returns a 2-D array of Axes, even for a single metric.
    fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 5), squeeze=False)

    for ax, metric in zip(axes[0], metrics):
        name = metric.name
        values = np.asarray(scores[name], dtype=float)

        ax.plot(epochs, values, 'o-', label=name, color='blue')
        ax.set_title(f"{name} ({'lower' if metric.is_lower_better else 'higher'} is better)")
        ax.set_xlabel('Training Epochs')
        ax.set_ylabel("Value")
        ax.set_xscale("log")
        ax.grid(True)

        # Find the best point, ignoring NaN scores (e.g. from a metric evaluated on no samples).
        if np.all(np.isnan(values)):
            continue
        best_idx = int(np.nanargmin(values) if metric.is_lower_better else np.nanargmax(values))
        best_epoch = epochs[best_idx]
        best_value = values[best_idx]
        ax.plot(best_epoch, best_value, 'ro')
        # Offset the label in screen points, not data units, so it stays clear of the
        # point even when every value is identical.
        ax.annotate(f'Best: {best_value:.3f}',
                    xy=(best_epoch, best_value),
                    xytext=(0, 15),
                    textcoords='offset points',
                    ha='center',
                    arrowprops=dict(arrowstyle='->', color='black'))

    if title is not None:
        fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.show()

## **Overview**

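`model.score` returns a dictionary mapping each metric's name to its score, for example (the cells below pass `BaseMetric` instances through the `metrics=` argument instead):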

> ```python
> # Get scores for generated samples
> scores = model.score(
>     real=real_data, 
>     generated=fake_samples,
>     scores=["bpd", "fid", "is"]
> )
> ```

**Implemented Metrics:**
| Metric | Full Name | Range | Ideal | Best For |
|--------|-----------|-------|-------|----------|
| [BPD](#bpd) | Bits Per Dimension | $[0, \infty)$ | Lower | Density Estimation |
| [FID](#fid) | Fréchet Inception Distance | $[0, \infty)$ | Lower | Image Quality |
| [IS](#is) | Inception Score | $[1, K]$, $K$ = number of Inception classes | Higher | Quality & Diversity |

### <span id="bpd">**Bits Per Dimension (BPD)**</span>

**Measurement:** Negative log-likelihood in bits/dimension  
**Interpretation:**
- Lower = Better density modeling
- Sensitive to training stability
- Values < 3.0 are generally acceptable (a rough heuristic; typical values depend strongly on the dataset)

### <span id="fid">**Fréchet Inception Distance (FID)**</span>

**Measurement:** Distance between real/fake feature distributions  
**Interpretation:**
- Lower = Better visual quality
- < 50 = Excellent
- 50-100 = Good
- \> 100 = Needs improvement

### <span id="is">**Inception Score (IS)**</span>

**Measurement:** Exponential of the average KL divergence between the conditional class distribution $p(y \mid x)$ and the marginal class distribution $p(y)$  
**Interpretation:**
- Higher = Better diversity/quality
- \> 10 = Excellent
- 5-10 = Good
- < 5 = Poor diversity

### **Metric Comparison by Epochs**

#### CIFAR10

In [None]:
scores_cifar = {}
# FID/IS estimated from this few samples are strongly biased; compare trends across
# epochs rather than absolute values, or increase this for more reliable numbers.
n_samples = 16

for model in trained_models_cifar:
    metrics = [BitsPerDimension(model), FrechetInceptionDistance(model), InceptionScore(model)]
    samples = model.generate(n_samples, seed=seed)
    scores = model.score(data_cifar_test, samples, metrics=metrics)
    # One score per model for each metric name, in the same order as epoch_values.
    for name, value in scores.items():
        scores_cifar.setdefault(name, []).append(value)

# Show the images generated by the last model in trained_models_cifar
display_images(samples)

In [None]:
# Plot the CIFAR-10 scores against training length, one subplot per metric.
plot_scores(scores=scores_cifar, metrics=metrics, title="CIFAR10 - Scores vs Training Epochs")

#### MNIST

In [None]:
scores_mnist = {}
# FID/IS estimated from this few samples are strongly biased; compare trends across
# epochs rather than absolute values, or increase this for more reliable numbers.
n_samples = 16

for model in trained_models_mnist:
    metrics = [BitsPerDimension(model), FrechetInceptionDistance(model), InceptionScore(model)]
    samples = model.generate(n_samples, seed=seed)
    scores = model.score(data_mnist_test, samples, metrics=metrics)
    # One score per model for each metric name, in the same order as epoch_values.
    for name, value in scores.items():
        scores_mnist.setdefault(name, []).append(value)

# Show the images generated by the last model in trained_models_mnist
display_images(samples)

In [None]:
# Plot the MNIST scores against training length, one subplot per metric.
plot_scores(scores=scores_mnist, metrics=metrics, title="MNIST - Scores vs Training Epochs")

Unlike the `CIFAR10` experiment, where every score improved the longer the model was trained, on `MNIST` the best `BitsPerDimension` and `InceptionScore` values are not obtained by the model trained for the most epochs.

This highlights some of the problems to take into account:

**BPD:**
1. BPD measures likelihood, which doesn't always correlate with sample quality in diffusion models
2. As the model focuses on generating clear digit patterns, it may actually assign lower likelihood to some variations in the real data distribution
3. Likelihood and sample quality are largely independent (see Theis et al., 2016, "A note on the evaluation of generative models"), so a diffusion model can produce better-looking samples while its likelihood score gets worse

**IS:**
1. IS measures both quality and diversity simultaneously
2. As training progresses, the model may generate more accurate digits but with less stylistic variation
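3. The Inception network was trained on natural images, not MNIST-style grayscale digits, which makes IS less reliable for digit evaluation. Moreover, because only a single digit is modelled here, the Inception predictions for all samples tend to be similar, which inherently limits the diversity part of the score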

Such disagreement between metrics and visible improvement is well documented. Keep the following in mind:

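1. Of the three metrics, FID is the most reliable here and shows a clear improvement with training. However, it is computed from only 16 generated samples, so its absolute values are strongly biased and only the trend across epochs is meaningful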
2. Visual inspection remains crucial for evaluating digit generation quality
3. BPD in particular can be misleading for evaluating sample quality in diffusion models

Each of these problems can be mitigated in several ways. For example, to assess digit quality, a custom metric that uses OCR (Optical Character Recognition) to read the generated digits should track human judgement more closely.

### **Creating Custom Metrics**

Custom metrics are created by subclassing `BaseMetric` and implementing the `__call__` method together with the `name` and `is_lower_better` properties. The example below also overrides `config()` to report the metric's parameters.

#### Implementation Example (OCR Metric for MNIST)

In [None]:
from PIL import Image
import easyocr

class DigitOCRMetric(BaseMetric):
    """Fraction of generated images in which EasyOCR reads the target digit.

    Higher is better. The first (real-data) argument of ``__call__`` is ignored:
    the score depends only on the generated images.
    """

    # One reader shared by all instances: constructing an easyocr.Reader loads its
    # models, so building one per metric object (e.g. per evaluated model) is wasteful.
    _shared_reader = None

    def __init__(self,
                 model: GenerativeModel,
                 digit: Union[int, str] = 3,
                 display: bool = False # Added as part of the demonstration in this notebook (a real implementation would not include this nor any of the code to display the images)
                 ):
        super().__init__(model)
        self.digit = str(digit)
        self.display = display
        if DigitOCRMetric._shared_reader is None:
            DigitOCRMetric._shared_reader = easyocr.Reader(['en'])
        self.reader = DigitOCRMetric._shared_reader

    def config(self) -> dict:
        """Return the metric's parameters."""
        return {
            "digit": self.digit
        }

    @property
    def name(self) -> str:
        return f"Digit {self.digit} OCR Accuracy"

    @property
    def is_lower_better(self) -> bool:
        return False

    def _batch_to_pil(self, batch: Tensor) -> list[Image.Image]:
        """Convert a channel-first (N, C, H, W) batch into grayscale PIL images.

        Assumes pixel values in [0, 1]; anything outside is clipped.
        """
        # round() rather than truncating, so e.g. 0.999 maps to 255 instead of 254
        batch = batch.clamp(0, 1).mul(255).round().to(torch.uint8)
        imgs = []
        for img in batch:
            if img.shape[0] == 1:
                # A 2-D uint8 array is read as mode "L"; passing mode= explicitly is deprecated in recent Pillow.
                pil = Image.fromarray(img[0].cpu().numpy())
            else:
                pil = Image.fromarray(img.permute(1, 2, 0).cpu().numpy())
            imgs.append(pil.convert("L"))
        return imgs

    @staticmethod
    def _outline(batch: Tensor, correct: List[bool]) -> Tensor:
        """Return a 3-channel copy of ``batch`` with a 1-pixel border per image:
        green where OCR found the digit, red where it did not."""
        # clone(): batch.to("cpu") returns the caller's own tensor when it is already on the
        # CPU, so drawing in place would corrupt the user's generated samples.
        out = batch.clone()
        if out.shape[1] == 1:
            out = out.repeat(1, 3, 1, 1)
        # Colours in the image's own value range (1.0 for float images, 255 for integer ones).
        peak = 1.0 if out.is_floating_point() else 255
        for i, right in enumerate(correct):
            rgb = [0, peak, 0] if right else [peak, 0, 0]
            color = torch.tensor(rgb, dtype=out.dtype)[:, None]
            out[i, :, 0, :] = color
            out[i, :, -1, :] = color
            out[i, :, :, 0] = color
            out[i, :, :, -1] = color
        return out

    def __call__(self,
                 _,
                 generated: Union[Tensor, torch.utils.data.Dataset],
                 *args,
                 **kwargs) -> float:
        """Return the fraction of ``generated`` images whose OCR text contains the digit.

        Args:
            _: Real data (unused).
            generated: A batch tensor, or a Dataset whose items are tensors or
                (tensor, ...) tuples.

        Returns:
            Accuracy in [0, 1], or NaN when ``generated`` is empty.
        """
        if isinstance(generated, Tensor):
            batches = [(generated, )]
        else:
            batches = DataLoader(generated, batch_size=64, shuffle=False)

        total = 0
        correct = 0
        outlines = []

        for batch_tuple in batches:
            batch = batch_tuple[0] if isinstance(batch_tuple, (list, tuple)) else batch_tuple
            batch = batch.to("cpu")
            pil_images = self._batch_to_pil(batch)

            batch_correct = []
            for img in pil_images:
                result = self.reader.readtext(np.array(img))
                # Join the text field (index 1) of every detection returned by readtext.
                text = ''.join(word_info[1] for word_info in result)
                # Lenient match: the digit only has to appear somewhere in the recognised text.
                batch_correct.append(self.digit in text)

            correct += sum(batch_correct)
            total += len(pil_images)

            if self.display:
                outlines.append((batch, batch_correct))

        acc = correct / total if total > 0 else float("nan")

        if self.display and outlines:
            display_images(torch.cat([self._outline(b, c) for b, c in outlines], dim=0))

        return acc

In [None]:
scores_mnist_ocr = {}

for model in trained_models_mnist:
    metrics = [
        # display=True also shows each sample with a green (digit read) or red (missed) border
        DigitOCRMetric(model, digit=digit, display=True)
    ]
    samples = model.generate(16, seed=seed)
    scores = model.score(data_mnist_test, samples, metrics=metrics)
    # One OCR accuracy per model, in the same order as epoch_values.
    for name, value in scores.items():
        scores_mnist_ocr.setdefault(name, []).append(value)

plot_scores(scores_mnist_ocr, metrics, "MNIST - OCR Accuracy vs Training Epochs")