In [2]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np

# Seed every RNG up front so the DataLoader shuffle below (which draws from
# torch's default generator when iteration starts) is reproducible on
# Restart & Run All.
torch.manual_seed(1)
np.random.seed(1)

image_path = './'
batch_size = 64

# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1] via
# (x - 0.5) / 0.5. MNIST has a single channel, so mean/std are one-element
# tuples -- note the trailing comma: `(0.5)` would just be a float.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
                                           train=True,
                                           transform=transform,
                                           download=True)

## Set up the dataset
# drop_last=True discards the final partial batch, so every batch yielded
# holds exactly `batch_size` images.
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size,
                      shuffle=True, drop_last=True)


In [5]:
import os
import numpy as np
from torchvision.datasets import MNIST
from PIL import Image

# DataLoader built in the dataset-loading cell.
mnist_data_loader = mnist_dl

# Define the folder where you want to save the images
output_folder = 'mnist_images'

# Parameters of the transforms.Normalize applied when loading the dataset.
# They must match that transform so the inversion below is exact.
NORM_MEAN = 0.5
NORM_STD = 0.5

# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)


def _to_grayscale_pil(image_tensor):
    """Convert one normalized image tensor to a grayscale PIL image.

    Assumes a single-channel (1, H, W) tensor as produced by ToTensor on
    MNIST -- TODO confirm if the dataset/transform changes.

    The loader's tensors have been through Normalize(mean=0.5, std=0.5), so
    they lie in [-1, 1]. Multiplying them directly by 255 would yield negative
    values whose cast to uint8 is garbage, corrupting the saved images. Undo
    the normalization, clip to [0, 1], and round (not truncate) so float error
    from the round trip does not shift pixel values down by one.
    """
    pixels = image_tensor.numpy() * NORM_STD + NORM_MEAN
    pixels = np.clip(pixels, 0.0, 1.0)
    pixels = np.round(pixels * 255).astype(np.uint8)
    # Drop the channel axis; a 2-D uint8 array is inferred as mode 'L'.
    return Image.fromarray(pixels[0])


# Iterate through the DataLoader and save the images. A running counter gives
# unique filenames even if a batch is smaller than the others.
saved_count = 0
for images, _labels in mnist_data_loader:
    for image_tensor in images:
        _to_grayscale_pil(image_tensor).save(
            os.path.join(output_folder, f'image_{saved_count:04d}.png'))
        saved_count += 1

print(f"Saved {saved_count} images to {output_folder}.")


Saved 59968 images to mnist_images.


In [None]:
import torch
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3

# Folders of image files to compare. The real MNIST samples are written to
# 'mnist_images' by the image-export cell; point the second path at a folder
# of generated samples.
real_images_path = 'mnist_images'
generated_images_path = 'path_to_generated_images_folder'

# pytorch_fid computes FID directly from two image folders and builds the
# InceptionV3 feature extractor internally. InceptionV3's own constructor takes
# output-block indices, not folder paths, and there is no
# `fid_score.calculate_fid`.
FID_BATCH_SIZE = 50
FID_DIMS = 2048  # final average-pool features, the standard FID setting
assert FID_DIMS in InceptionV3.BLOCK_INDEX_BY_DIM, f'unsupported dims: {FID_DIMS}'
# NOTE(review): each folder should contain at least FID_DIMS images, otherwise
# the feature covariance is rank-deficient and the score is unreliable.

# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Calculate the FID score
fid = fid_score.calculate_fid_given_paths(
    [real_images_path, generated_images_path],
    batch_size=FID_BATCH_SIZE,
    device=device,
    dims=FID_DIMS,
)

print(f'FID Score: {fid:.2f}')
