In [None]:
import torch
import numpy as np

from skimage.io import imread, imsave
from tqdm.auto import trange, tqdm
from torchvision.datasets import MNIST
from pytorch_fid import fid_score
import torchvision.transforms as T

from data_generator import DataGenerator
from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet

from matplotlib import pyplot as plt

import os
import shutil

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Prefer the GPU, but fall back to the CPU so that a fresh "Restart & Run All"
# on a machine without CUDA still gets a usable device object.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Определим папку с настоящими картинками

In [None]:
def create_dir(path: str):
    """Create directory ``path`` (including missing parents) if it is absent.

    Calling it for a directory that already exists is a no-op, so cells that
    use it can be re-run safely.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(path): os.makedirs(path)`.
    os.makedirs(path, exist_ok=True)

In [None]:
# Dump the MNIST training split, resized to 32x32, as PNG files; this folder
# is later passed to the FID computation as the set of real images.
create_dir('../real_images_MNIST')

resize_to_32 = T.Compose([T.Resize((32, 32))])
real_dataset = MNIST(root='../data', download=True, train=True, transform=resize_to_32)
for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    # Each dataset item is an (image, label) pair; only the image is written here.
    imsave("../real_images_MNIST/{}.png".format(idx), np.array(image_mnist))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 265488159.36it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 34659712.11it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 119204567.57it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13359417.09it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






  0%|          | 0/60000 [00:00<?, ?it/s]

#### Определим папку для синтетических картинок и сгенерируем 60к картинок

In [None]:
# Runner used below for sampling without labels, built from the default MNIST config.
mnist_config = create_default_mnist_config()
uncond_diff = DiffusionRunner(mnist_config, eval=True)

In [None]:
# Sample TOTAL_IMAGES_COUNT images in batches of BATCH_SIZE and write every
# image to ../uncond_mnist as <running index>.png.
create_dir('../uncond_mnist')

TOTAL_IMAGES_COUNT = 60_000
BATCH_SIZE = 200
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0  # running file index across all batches
for _ in trange(NUM_ITERS):
    batch: torch.Tensor = uncond_diff.sample_images(batch_size=BATCH_SIZE).cpu()
    # Move axis 1 to the end before saving (presumably NCHW -> NHWC; confirm
    # against sample_images), then cast to uint8 for the PNG writer.
    batch_np = batch.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for image in batch_np:
        imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), image)
        global_idx += 1

  0%|          | 0/300 [00:00<?, ?it/s]

  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])


In [None]:
# FID between the real MNIST PNGs and the unconditionally generated ones.
uncond_fid_paths = ['../real_images_MNIST', '../uncond_mnist']
fid_value = fid_score.calculate_fid_given_paths(
    paths=uncond_fid_paths,
    batch_size=200,
    device=device,
    dims=2048,
)
fid_value

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:02<00:00, 44.8MB/s]
100%|██████████| 300/300 [00:29<00:00, 10.13it/s]
100%|██████████| 300/300 [00:29<00:00, 10.15it/s]


18.758689161152915

> Какой FID получился? Сравните FID для безусловной генерации и для условной. Сгенерируйте для каждого класса по 6к картинок и посчитайте FID между реальными и условно сгенерированными картинками.

18.8

#### Условная генерация

Разобьем MNIST по папкам классов:

In [None]:
# One sub-folder per digit class, first for generated and then for real images.
for digit in range(10):
    for root in ('../cond_mnist', '../real_images_MNIST'):
        create_dir(f'{root}/{digit}')

In [None]:
# Write every real image again, this time into the sub-folder of its label,
# so that FID can be computed per class.
for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    class_path = f"../real_images_MNIST/{label}/{idx}.png"
    imsave(class_path, np.array(image_mnist))

  0%|          | 0/60000 [00:00<?, ?it/s]

Сгенерируем картинки условной генерацией:

In [None]:
# Build the classifier and load its weights from the checkpoint.
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
noisy_classifier = ConditionalResNet(**classifier_args)
noisy_classifier.to(device)

# map_location makes loading independent of the device the checkpoint was saved from.
noisy_classifier.load_state_dict(
    torch.load('./ddpm_checkpoints/classifier.pth', map_location=device)
)

# Deliberately NOT named `T`: that name is the `torchvision.transforms` alias
# imported at the top of the notebook. Rebinding it to a float would break any
# re-run of the cell that calls `T.Compose(...)`.
CLASSIFIER_TEMPERATURE = 1.0

# Separate runner for guided sampling; the classifier is attached in eval mode.
cond_diff = DiffusionRunner(create_default_mnist_config(), eval=True)
cond_diff.set_classifier(noisy_classifier.eval(), T=CLASSIFIER_TEMPERATURE)
create_dir('../cond_mnist')

In [None]:
def sample_class(class_num: int):
    """Generate 6000 images conditioned on ``class_num`` and save them as PNGs.

    Images are drawn from the module-level ``cond_diff`` runner in batches of
    200, with every label in the batch set to ``class_num``. They are written
    to ``../cond_mnist/<class_num>/`` as ``0.png``, ``1.png``, ...; the
    numbering restarts at 0 on every call.

    Args:
        class_num: Class label passed to the sampler for the whole batch.
    """
    total_images = 6_000
    batch_size = 200
    num_batches = total_images // batch_size
    out_dir = f'../cond_mnist/{class_num}'

    next_index = 0  # running file index inside this class folder
    for _ in range(num_batches):
        labels = torch.full((batch_size,), class_num, dtype=torch.long)
        batch: torch.Tensor = cond_diff.sample_images(batch_size=batch_size, labels=labels).cpu()
        # Move axis 1 to the end (presumably NCHW -> NHWC; confirm against
        # sample_images) and cast to uint8 for the PNG writer.
        batch_np = batch.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

        for image in batch_np:
            imsave(os.path.join(out_dir, f'{next_index}.png'), image)
            next_index += 1

In [None]:
# For every digit: generate its conditional samples, then compute FID against
# the real images of the same digit. Results are collected in class order.
fid_values = []
for digit in trange(10):
    sample_class(digit)
    class_paths = [f'../real_images_MNIST/{digit}', f'../cond_mnist/{digit}']
    fid_values.append(
        fid_score.calculate_fid_given_paths(
            paths=class_paths,
            batch_size=200,
            device=device,
            dims=2048,
        )
    )

In [None]:
# One line per digit class with its FID value.
print("\n".join(f"Class {i} FID = {fid_values[i]}" for i in range(10)))

Class 0 FID = 18.257824641622506
Class 1 FID = 26.605954668569055
Class 2 FID = 21.13187946369723
Class 3 FID = 21.40748077353132
Class 4 FID = 22.125964517526683
Class 5 FID = 22.299480751186366
Class 6 FID = 25.707515630812395
Class 7 FID = 20.083946369352816
Class 8 FID = 24.73293477124298
Class 9 FID = 24.6966394821155


Посчитаем общий FID:

In [None]:
# Gather all conditionally generated images into one folder for the overall FID.
#
# sample_class() numbers its files from 0 inside every class folder, so the
# same names (0.png, 1.png, ...) occur in all ten folders. Moving them under
# their original names makes each class overwrite the previous one (a rename
# onto an existing file replaces it on POSIX), leaving only 6000 images in
# `all`. Prefixing the name with the class index keeps all files distinct.
create_dir('../cond_mnist/all')
for i in range(10):
    class_dir = f'../cond_mnist/{i}'
    for file_name in os.listdir(class_dir):
        shutil.move(
            os.path.join(class_dir, file_name),
            os.path.join('../cond_mnist/all', f'{i}_{file_name}'),
        )

In [None]:
# Overall FID: all real MNIST PNGs versus the merged conditional samples.
cond_fid_paths = ['../real_images_MNIST', '../cond_mnist/all']
fid_value = fid_score.calculate_fid_given_paths(
    paths=cond_fid_paths, batch_size=200, device=device, dims=2048
)
fid_value

100%|██████████| 300/300 [00:29<00:00, 10.17it/s]
100%|██████████| 30/30 [00:03<00:00,  9.65it/s]


43.110569191291546

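К сожалению, получилось больше, чем при безусловной генерации. Однако, судя по прогресс-бару выше (30 батчей по 200), во второй папке оказалось лишь 6000 картинок вместо 60 000, поэтому напрямую сравнивать это значение с FID безусловной генерации нельзя. Стоит также учесть, что мы не занимались подбором гиперпараметра температуры (веса классификатора в скоре). При слишком большой температуре картинки не будут похожи на соответствующие классы (т.е. результат почти как при безусловной генерации), а при недостаточно большой температуре картинки будут слишком однообразными, что ухудшит (увеличит) FID.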