In [1]:
import torch
import numpy as np

from skimage.io import imread, imsave
from tqdm.auto import trange, tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize
from pytorch_fid import fid_score

from data_generator import DataGenerator
from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet

from matplotlib import pyplot as plt

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
%cd ..

/home/pasha/ml/mmp/neuralbayes/n4


In [3]:
def create_dir(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

In [4]:
create_dir('./real_images_MNIST')

real_dataset = MNIST(root='./data', download=True, train=True, transform=Compose([Resize((32, 32))]))
for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    image = np.array(image_mnist)
    imsave("./real_images_MNIST/{}.png".format(idx), image)

  0%|          | 0/60000 [00:00<?, ?it/s]

#### Определим папку для синтетических картинок и сгенерируем 60к картинок

##### Безуслованя генерация

In [5]:
uncond_diff = DiffusionRunner(create_default_mnist_config(), eval=True)

In [8]:
create_dir('./uncond_mnist')

TOTAL_IMAGES_COUNT = 10_000
BATCH_SIZE = 200
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0
for idx in trange(NUM_ITERS):
    images: torch.Tensor = uncond_diff.sample_images(batch_size=BATCH_SIZE).cpu()
    images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for i in range(len(images)):
        imsave(os.path.join('./uncond_mnist', f'{global_idx}.png'), images[i])
        global_idx += 1

  0%|          | 0/50 [00:00<?, ?it/s]

  imsave(os.path.join('./uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('./uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('./uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('./uncond_mnist', f'{global_idx}.png'), images[i])


In [10]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['./real_images_MNIST', './uncond_mnist'],
    batch_size=200,
    device='cuda:0',
    dims=2048
)
fid_value

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /home/pasha/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

100%|██████████| 300/300 [01:00<00:00,  4.96it/s]
100%|██████████| 50/50 [00:10<00:00,  4.82it/s]


118.81341380857594

##### Условная генерация

In [11]:
device = torch.device('cuda')
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
noisy_classifier = ConditionalResNet(**classifier_args)
noisy_classifier.to(device)

noisy_classifier.load_state_dict(torch.load('./ddpm_checkpoints/classifier_1.pth'))

conditional_diffusion = DiffusionRunner(create_default_mnist_config(), eval=True)
conditional_diffusion.set_classifier(noisy_classifier, T=0.25)

In [12]:
create_dir('./cond_mnist')

TOTAL_IMAGES_COUNT = 10_000
BATCH_SIZE = 1_000
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0
for idx in trange(NUM_ITERS):
    labels = idx * torch.ones(100, dtype=torch.long).to(device)

    images: torch.Tensor = uncond_diff.sample_images(batch_size=BATCH_SIZE, labels=labels).cpu()
    images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for i in range(len(images)):
        imsave(os.path.join('./cond_mnist', f'{global_idx}.png'), images[i])
        global_idx += 1

  0%|          | 0/10 [00:00<?, ?it/s]

In [13]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['./real_images_MNIST', './cond_mnist'],
    batch_size=200,
    device=device,
    dims=2048
)
fid_value

100%|██████████| 300/300 [01:00<00:00,  4.93it/s]
100%|██████████| 5/5 [00:01<00:00,  4.41it/s]


125.28347629120378

> Какой фид получился? Сравните FID для безусловной генерации и для условной. Сгенерируйте для каждого класса по 6к картинок и посчитайте FID между реальными и условно сгенерированными картинками.

При безусловной генерации FID = 118, при условной 125. Качество сравнимое но при условной генерации все таки хуже. В целом FID для MNIST нормальный.