## THIS CODE DOWNLOADS MNIST AS PNG IMAGES (10000 SAMPLES)

In [9]:
import os
from torchvision.datasets import MNIST
from PIL import Image


# Download MNIST dataset
dataset = MNIST(root='./data', train=True, download=True)

os.makedirs('mnist_png', exist_ok=True)
# Create a single directory to save all images
output_dir = 'mnist_png_small'
os.makedirs(output_dir, exist_ok=True)

# Save each image as PNG in the same directory
for index, (image, label) in enumerate(dataset):
    img_path = f'{output_dir}/{label}_{index}.png'  # Include label in the filename to avoid overwrites
    image.save(img_path, 'PNG')  # Save the PIL image directly

    if index >= 10000 :  # Progress update every 1000 images
        print(f'Saved {index} images...')
        break
print("All images have been saved in the 'mnist_png/train' directory.")

Saved 10000 images...
All images have been saved in the 'mnist_png/train' directory.


## HERE YOU NEED TO PUT YOUR ARCHITECTURE AND THEN GENERATE THE SAMPLES (YOU NEED TO LOAD TTHE CHECKPOINTS ALSO)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from tqdm import trange
import argparse
from torchvision import datasets, transforms, models
import torch.optim as optim
from torchvision.utils import save_image
import torchvision
import argparse

class Generator(nn.Module):
    def __init__(self, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)

        x = torch.tanh(self.fc4(x))
        #print("gen", x.shape)
        x = x.view(x.shape[0], 1, 28, 28)
        return x


class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 512)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = x.view(x.shape[0], -1).cuda()
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.fc4(x)
        #print("des ", x.shape)
        return x #torch.sigmoid(self.fc4(x))

def Descriminator_train(x, G, D, D_optimizer, clip_value=0.01):
    #=======================Train the discriminator=======================#

    D_optimizer.zero_grad()

    # train discriminator on real
    x_real = x
    x_real = x_real.cuda()


    # train discriminator on fake
    z = torch.randn(x.shape[0], 100).cuda()
    x_fake = G(z).detach()

    # gradient backprop & optimize ONLY D's parameters
    D_loss = -torch.mean(D(x_real)) + torch.mean(D(x_fake))
    D_loss.backward()
    D_optimizer.step()

    # Clip weights of discriminator
    for p in D.parameters():
        p.data.clamp_(-clip_value, clip_value)

    return  D_loss.data.item()


def Generator_train(x, G, D, G_optimizer):
    #=======================Train the generator=======================#
    G_optimizer.zero_grad()

    z = torch.randn(x.shape[0], 100).cuda()

    G_output = G(z)

    G_loss = -torch.mean(D(G_output))

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_output

def save_models(G, D, folder):
    torch.save(G.state_dict(), os.path.join(folder,'G.pth'))
    torch.save(D.state_dict(), os.path.join(folder,'D.pth'))


def load_model(G, folder):
    ckpt = torch.load(os.path.join(folder,'G.pth'))
    G.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
    return G


batch_size = 2048
print('Model Loading...')
# Model Pipeline
mnist_dim = 784

model = Generator(g_output_dim = mnist_dim).cuda()
model = load_model(model, 'checkpoints')
model.eval()

print('Model loaded.')

print('Start Generating')
os.makedirs('samples', exist_ok=True)

n_samples = 0
with torch.no_grad():
    while n_samples<10000:
        z = torch.randn(batch_size, 100).cuda()
        x = model(z)
        #x = x.view(batch_size,1, 28, 28)
        for k in range(x.shape[0]):
            if n_samples<10000:
                save_image(x[k], "samples/%d.png" % k, normalize=True)
                #torchvision.utils.save_image(x[k:k+1], os.path.join('samples', f'{n_samples}.png'), normalize=True)
                n_samples += 1

Model Loading...
Model loaded.
Start Generating


  ckpt = torch.load(os.path.join(folder,'G.pth'))


## COMPUTE FID

In [5]:
!pip install pytorch-fid

Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Downloading pytorch_fid-0.3.0-py3-none-any.whl (15 kB)
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.3.0


In [10]:
!python -m pytorch_fid mnist_png_small samples

100% 201/201 [00:40<00:00,  4.99it/s]
100% 41/41 [00:08<00:00,  5.02it/s]
FID:  100.02741126259218


## COMPUTE PRECISION AND RECALL

In [15]:
#https://github.com/youngjung/improved-precision-and-recall-metric-pytorch/blob/master/improved_precision_recall.py

#You need first to charge this file

! python improved_precision_recall.py mnist_png_small samples

done
extracting features of 5000 images: 100% 100/100 [00:27<00:00,  3.63it/s]
extracting features of 2048 images: 100% 41/41 [00:11<00:00,  3.61it/s]
computing precision...: 100% 2048/2048 [00:00<00:00, 14714.31it/s]
computing recall...: 100% 5000/5000 [00:00<00:00, 103675.70it/s]
precision: 0.14208984375
recall: 0.1234
found 1 images in mnist_png_small
Traceback (most recent call last):
  File "/content/improved_precision_recall.py", line 398, in <module>
    first_image = iter(dataloader).next()
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute 'next'
