In [7]:
import os
import shutil

import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_fid import fid_score
from skimage.io import imread, imsave
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize
from tqdm.auto import tqdm, trange

from data_generator import DataGenerator
from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ConditionalResNet, ResidualBlock, ResNet

# Optional: pin the notebook to one GPU. Must be set before CUDA is first initialised.
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Fall back to CPU so `.to(device)` / `map_location=device` still work on machines without a GPU
# (sampling will be much slower there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Определим папку с настоящими картинками

In [3]:
def create_dir(path: str) -> None:
    """Create the directory ``path`` (including missing parents) if it does not exist.

    Safe to call repeatedly on the same path.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory. The check-then-create
            version silently accepted this, so later image writes failed confusingly.
    """
    # exist_ok=True removes the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)

In [6]:
# Export the MNIST training split, resized to 32x32, as numbered PNGs
# (../real_images_MNIST/<index>.png). This flat folder is the real-image set for FID.
create_dir('../real_images_MNIST')

resize_to_32 = Compose([Resize((32, 32))])
real_dataset = MNIST(root='../data', train=True, download=True, transform=resize_to_32)

progress = tqdm(real_dataset, total=len(real_dataset))
for sample_idx, (pil_digit, _label) in enumerate(progress):
    imsave(f"../real_images_MNIST/{sample_idx}.png", np.array(pil_digit))

  0%|          | 0/60000 [00:00<?, ?it/s]

#### Определим папку для синтетических картинок и сгенерируем 60 000 картинок

In [8]:
# Unconditional sampler: a DiffusionRunner built from the default MNIST config with eval=True.
mnist_config = create_default_mnist_config()
uncond_diff = DiffusionRunner(mnist_config, eval=True)

In [None]:
# Generate TOTAL_IMAGES_COUNT unconditional samples and save them as ../uncond_mnist/<k>.png.
UNCOND_DIR = '../uncond_mnist'
create_dir(UNCOND_DIR)

TOTAL_IMAGES_COUNT = 60_000
BATCH_SIZE = 200
# Ceil division: if the total is not a multiple of BATCH_SIZE, a final partial batch is
# generated instead of silently dropping the remainder.
NUM_ITERS = -(-TOTAL_IMAGES_COUNT // BATCH_SIZE)

global_idx = 0
for _ in trange(NUM_ITERS):
    current_batch = min(BATCH_SIZE, TOTAL_IMAGES_COUNT - global_idx)
    images: torch.Tensor = uncond_diff.sample_images(batch_size=current_batch).cpu()
    # Move axis 1 (presumably channels) to the end for the image writer.
    # NOTE(review): the uint8 cast assumes sample_images already returns values in [0, 255] — confirm.
    images = images.permute(0, 2, 3, 1).detach().numpy().astype(np.uint8)
    if images.shape[-1] == 1:
        # Single-channel samples are saved as 2-D grayscale arrays; some image backends
        # reject (H, W, 1) arrays.
        images = images[..., 0]

    for image in images:
        imsave(os.path.join(UNCOND_DIR, f'{global_idx}.png'), image)
        global_idx += 1

  0%|          | 0/300 [00:00<?, ?it/s]

In [None]:
# FID between the real MNIST export and the unconditional samples. The paths must be the
# folders the export/generation cells write to: '../real_images_MNIST' and '../uncond_mnist'.
fid_value = fid_score.calculate_fid_given_paths(
    paths=['../real_images_MNIST', '../uncond_mnist'],
    batch_size=200,
    device=device,
    dims=2048,
)
fid_value

> Какой FID получился? Сравните FID для безусловной и условной генерации. Сгенерируйте по 6 000 картинок для каждого класса и посчитайте FID между реальными и условно сгенерированными картинками.

Разобьем MNIST по папкам классов:

In [None]:
# Split the MNIST training set into per-class folders: ../real_images_MNIST/<label>/<index>.png.
# Indices stay global, so each file name matches its counterpart in the flat export.
real_dataset = MNIST(root='../data', download=True, train=True, transform=Compose([Resize((32, 32))]))

for digit in range(10):
    create_dir(f'../real_images_MNIST/{digit}')

for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    image = np.array(image_mnist)
    imsave(f'../real_images_MNIST/{int(label)}/{idx}.png', image)

Сгенерируем картинки с помощью условной генерации:

In [None]:
# Noise-conditional classifier, presumably used to guide the conditional sampler below.
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
noisy_classifier = ConditionalResNet(**classifier_args)
noisy_classifier.to(device)

# map_location lets a checkpoint saved on a different device (another GPU, or GPU vs CPU) load here.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
classifier_state = torch.load('./ddpm_checkpoints/classifier.pth', map_location=device)
noisy_classifier.load_state_dict(classifier_state)
noisy_classifier.eval()

cond_diff = DiffusionRunner(create_default_mnist_config(), eval=True)
# NOTE(review): T presumably scales/tempers the classifier signal — confirm in DiffusionRunner.set_classifier.
cond_diff.set_classifier(noisy_classifier, T=1.)
create_dir('../cond_mnist')

In [None]:
def sample_class(class_num: int, total_images: int = 6_000, batch_size: int = 200,
                 out_root: str = '../cond_mnist'):
    """Generate ``total_images`` samples with ``cond_diff`` and save them as PNGs.

    Files are written to ``<out_root>/<class_num>/<k>.png`` with ``k`` counting from 0.

    Args:
        class_num: class label; used as the name of the output sub-folder.
        total_images: number of images to generate (a final partial batch is included).
        batch_size: number of images requested per ``cond_diff.sample_images`` call.
        out_root: parent folder that holds the per-class sub-folders.

    NOTE(review): ``class_num`` is NOT passed to the sampler, so every class folder receives
    samples from the same distribution. Wire the label into ``cond_diff`` once its
    conditioning API is confirmed.
    """
    dir_name = os.path.join(out_root, str(class_num))
    create_dir(dir_name)

    # Ceil division so a remainder is not silently dropped.
    num_iters = -(-total_images // batch_size)
    global_idx = 0
    for _ in trange(num_iters):
        current_batch = min(batch_size, total_images - global_idx)
        images: torch.Tensor = cond_diff.sample_images(batch_size=current_batch).cpu()
        # Move axis 1 (presumably channels) to the end for the image writer.
        # NOTE(review): the uint8 cast assumes samples are already in [0, 255] — confirm.
        images = images.permute(0, 2, 3, 1).detach().numpy().astype(np.uint8)
        if images.shape[-1] == 1:
            # Save single-channel samples as 2-D grayscale arrays.
            images = images[..., 0]

        for image in images:
            imsave(os.path.join(dir_name, f'{global_idx}.png'), image)
            global_idx += 1

In [None]:
# For each digit: generate its samples, then compute FID against the matching real-class folder.
# Paths match the writers: '../real_images_MNIST/<digit>' and '../cond_mnist/<digit>'.
class_fids = {}
for digit in range(10):
    sample_class(digit)
    fid_value = fid_score.calculate_fid_given_paths(
        paths=[f'../real_images_MNIST/{digit}', f'../cond_mnist/{digit}'],
        batch_size=200,
        device=device,
        dims=2048,
    )
    class_fids[digit] = fid_value
    print(f"Class {digit} FID = {fid_value}")

Посчитаем общий FID:

In [None]:
# Pool the per-class samples into one folder for the overall FID.
# - Files are copied (not moved), so the per-class folders stay intact and the cell can be re-run.
# - Every class folder numbers its images from 0, so a class prefix keeps the pooled names unique.
POOLED_DIR = '../cond_mnist/all'
create_dir(POOLED_DIR)
for digit in range(10):
    class_dir = f'../cond_mnist/{digit}'
    for file_name in os.listdir(class_dir):
        shutil.copy(
            os.path.join(class_dir, file_name),
            os.path.join(POOLED_DIR, f'{digit}_{file_name}'),
        )

In [None]:
# Overall FID: all real MNIST images vs. all pooled conditional samples.
# NOTE(review): pytorch_fid appears to collect only the image files at the top level of each
# folder, so the per-class sub-folders inside ../real_images_MNIST are not double-counted — verify.
fid_value = fid_score.calculate_fid_given_paths(
    paths=['../real_images_MNIST', '../cond_mnist/all'],
    batch_size=200,
    device=device,
    dims=2048,
)
fid_value