In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from tqdm import trange
import argparse
from torchvision import datasets, transforms, models
import torch.optim as optim
from torchvision.utils import save_image
import torchvision
import argparse

KeyboardInterrupt: 

In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to 1x28x28 images.

    The network takes 100-dimensional inputs and passes them through three
    LeakyReLU(0.2) hidden layers of width 256, 512 and 1024, followed by a
    tanh output layer of width ``g_output_dim``. The output is reshaped to
    ``(batch, 1, 28, 28)``, so ``g_output_dim`` has to be 784 for the reshape
    to succeed.
    """

    def __init__(self, g_output_dim):
        super().__init__()
        # Hidden widths double at every stage: 256 -> 512 -> 1024.
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, g_output_dim)

    def forward(self, x):
        """Return generated images in [-1, 1] with shape (batch, 1, 28, 28)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)

        image = torch.tanh(self.fc4(hidden))
        return image.view(image.shape[0], 1, 28, 28)


class Discriminator(nn.Module):
    """Fully connected critic that scores each input sample with one raw value.

    Inputs are flattened to ``(batch, d_input_dim)`` and passed through three
    LeakyReLU(0.2) hidden layers of width 512, 256 and 128. The final layer
    emits an unbounded score; no sigmoid is applied.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 512)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of critic scores for ``x``."""
        # Move the input to wherever this module's weights live instead of
        # forcing CUDA: the module then works on CPU-only machines and still
        # accepts CPU batches when it has been moved to a GPU.
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], -1).to(self.fc1.weight.device)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return self.fc4(x)



In [None]:
def Descriminator_train(x, G, D, D_optimizer, clip_value=0.01, latent_dim=100):
    """Run one critic update with weight clipping and return its loss.

    The loss is ``-mean(D(real)) + mean(D(fake))``. After the optimizer step
    every parameter of ``D`` is clamped to ``[-clip_value, clip_value]``.

    Args:
        x: batch of real samples; only its first dimension (batch size) and
            its values are used.
        G: generator called as ``G(z)`` with ``z`` of shape
            ``(batch, latent_dim)``. Its output is detached, so ``G`` receives
            no gradient here.
        D: critic being trained.
        D_optimizer: optimizer holding ``D``'s parameters.
        clip_value: symmetric bound applied to every parameter of ``D``.
        latent_dim: size of the noise vectors fed to ``G`` (default 100).

    Returns:
        The critic loss for this step as a Python float.
    """
    # Work on the device the critic already lives on instead of forcing CUDA,
    # so the same function runs on CPU-only machines.
    first_param = next(D.parameters(), None)
    device = first_param.device if first_param is not None else x.device

    D_optimizer.zero_grad()

    # Real batch.
    x_real = x.to(device)

    # Fake batch; detach so the backward pass stops at the generator output.
    z = torch.randn(x.shape[0], latent_dim).to(device)
    x_fake = G(z).detach()

    # Gradient backprop & optimize ONLY D's parameters.
    D_loss = -torch.mean(D(x_real)) + torch.mean(D(x_fake))
    D_loss.backward()
    D_optimizer.step()

    # Clip the critic weights in place without recording the op in autograd.
    with torch.no_grad():
        for p in D.parameters():
            p.clamp_(-clip_value, clip_value)

    return D_loss.item()


def Generator_train(x, G, D, G_optimizer, latent_dim=100):
    """Run one generator update and return the batch it generated.

    The generator loss is ``-mean(D(G(z)))``.

    Args:
        x: batch of real samples; only its first dimension (batch size) is
            used, to draw the same number of noise vectors.
        G: generator being trained.
        D: critic used to score the generated batch. It is not stepped here.
        G_optimizer: optimizer holding ``G``'s parameters.
        latent_dim: size of the noise vectors fed to ``G`` (default 100).

    Returns:
        The generated batch ``G(z)`` (not detached from the autograd graph).
    """
    # Sample on the generator's own device instead of forcing CUDA, so the
    # same function runs on CPU-only machines.
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else x.device

    G_optimizer.zero_grad()

    z = torch.randn(x.shape[0], latent_dim).to(device)
    G_output = G(z)

    G_loss = -torch.mean(D(G_output))

    # Gradient backprop & optimize ONLY G's parameters.
    G_loss.backward()
    G_optimizer.step()

    return G_output

def save_models(G, D, folder):
    """Write the state dicts of ``G`` and ``D`` to ``folder``.

    The files are named ``G.pth`` and ``D.pth``. ``folder`` (including any
    missing parents) is created first, so the call no longer fails when the
    checkpoint directory has not been created yet.
    """
    os.makedirs(folder, exist_ok=True)
    torch.save(G.state_dict(), os.path.join(folder, 'G.pth'))
    torch.save(D.state_dict(), os.path.join(folder, 'D.pth'))


def load_model(G, folder):
    """Load ``<folder>/G.pth`` into ``G`` and return ``G``.

    A leading ``module.`` on a checkpoint key is removed before loading. Only
    the leading occurrence is stripped: a plain substring replace would also
    rewrite keys such as ``submodule.weight`` into ``subweight`` and make
    ``load_state_dict`` fail.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    prefix = 'module.'
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # read on a host without one; load_state_dict then copies each tensor into
    # G's existing parameters on whatever device they are on.
    ckpt = torch.load(os.path.join(folder, 'G.pth'), map_location='cpu')
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }
    G.load_state_dict(cleaned)
    return G



In [None]:
# --- Training cell: hyperparameters, MNIST loaders, models, optimizers, loop ---
# Learning rate shared by both RMSprop optimizers below.
lr = 0.0002
# Samples per batch for both data loaders below.
batch_size  = 64
# Number of full passes over the training loader.
epochs = 200

# Output folders used below: checkpoints for weights, data for the dataset.
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('data', exist_ok=True)

# Data Pipeline
print('Dataset loading...')
# MNIST Dataset
# Note: `(0.5)` is a plain float, not a one-element tuple; Normalize is given
# scalar mean and std, i.e. it computes (x - 0.5) / 0.5.
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5), std=(0.5))])

# Only the training split requests a download; the test split reuses the same
# root directory with download=False.
train_dataset = datasets.MNIST(root='data/MNIST/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data/MNIST/', train=False, transform=transform, download=False)


train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, shuffle=True)
# NOTE(review): test_loader is not referenced again in the visible notebook.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, shuffle=False)
print('Dataset Loaded.')


print('Model Loading...')
# 784 = 28 * 28, the flattened size the generator emits and the critic expects.
mnist_dim = 784
G = Generator(g_output_dim = mnist_dim).cuda()
D = Discriminator(mnist_dim).cuda()

# model = DataParallel(model).cuda()
print('Model loaded.')

# define optimizers
G_optimizer = optim.RMSprop(G.parameters(), lr = lr)
D_optimizer = optim.RMSprop(D.parameters(), lr = lr)

print('Start Training :')
os.makedirs('images', exist_ok=True)

n_epoch = epochs

# Epochs are numbered from 1 so the `epoch % 10` checkpoint test below fires
# on epochs 10, 20, ..., and the tqdm bar shows a 1-based count.
for epoch in trange(1, n_epoch+1, leave=True):
    for batch_idx, (x, _) in enumerate(train_loader):

        # The discriminator is updated on every batch; its returned loss is
        # discarded here.
        Descriminator_train(x, G, D, D_optimizer)

        # The generator is updated once every 5 batches (batch_idx 0, 5, 10, ...).
        if batch_idx % 5 == 0 :
            # z holds the batch returned by Generator_train (generated images,
            # not noise); it is only consumed by the commented-out preview below.
            z = Generator_train(x, G, D, G_optimizer)

    # Disabled per-epoch preview of the first 25 generated images:
    # z = z.view(z.shape[0], 28, 28)
    # save_image(z.data[:25], "images/epoch_%d.png" % epoch, nrow=5, normalize=True)

    # Checkpoint both networks every 10th epoch; each save overwrites the
    # previous G.pth / D.pth in 'checkpoints'.
    if epoch % 10 == 0:
        save_models(G, D, 'checkpoints')

print('Training done')

In [None]:

# --- Sampling cell: load the trained generator and write PNG samples ---
batch_size = 2048  # latent vectors drawn per forward pass
print('Model Loading...')
# Model Pipeline
mnist_dim = 784

model = Generator(g_output_dim = mnist_dim).cuda()
model = load_model(model, 'checkpoints')
model.eval()

print('Model loaded.')

print('Start Generating')
os.makedirs('samples', exist_ok=True)

n_total = 10000  # number of image files to write
n_samples = 0    # running count of files written so far
with torch.no_grad():
    while n_samples < n_total:
        z = torch.randn(batch_size, 100).cuda()
        x = model(z)
        for k in range(x.shape[0]):
            # Stop mid-batch once the requested number of files exists.
            if n_samples >= n_total:
                break
            # Name the file after the running sample count, not the in-batch
            # index k. k restarts at 0 for every batch, so naming by k made
            # each batch overwrite the previous one and left only
            # `batch_size` distinct files on disk instead of `n_total`.
            save_image(x[k], "samples/%d.png" % n_samples, normalize=True)
            n_samples += 1

In [None]:
import os
import torch
import numpy as np
from scipy.linalg import sqrtm
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from torch import nn

# Define a custom dataset class to load the generated images
class GeneratedImagesDataset(Dataset):
    """Dataset over the ``.png`` files found directly inside one directory.

    File names are collected once, at construction time, and kept in sorted
    (lexicographic) order. Indexing decodes the file with ``read_image``,
    converts it to float, divides by 255 and finally applies ``transform``
    when one was supplied.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        png_names = filter(lambda name: name.endswith('.png'), os.listdir(root_dir))
        self.image_files = sorted(png_names)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        # Scale the decoded pixel values by 1/255.
        pixels = read_image(os.path.join(self.root_dir, file_name)).float() / 255.0
        return self.transform(pixels) if self.transform else pixels

# Preprocessing for the real MNIST images: convert to a tensor, then apply
# (x - 0.5) / 0.5. `(0.5)` is a plain float, not a tuple.
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5), std=(0.5))])

# Preprocessing for the generated images, which reach this pipeline already
# as tensors (GeneratedImagesDataset decodes the file before calling it).
transform_generate = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Ensure 1 channel
    transforms.Resize((28, 28)),  # Resize to match model's expected input size
    transforms.Normalize(mean=(0.5), std=(0.5))  # same (x - 0.5) / 0.5 step as the real images
])

# Batch size used by both evaluation loaders below.
batch_size = 64

# Real images: the MNIST training split, read from disk (download=False).
mnist_dataset = datasets.MNIST(root='data/MNIST/', train=True, transform=transform, download=False)

# shuffle=False keeps the iteration order deterministic.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=False)

# Load the generated images from the "samples" directory.
generated_dataset = GeneratedImagesDataset(root_dir="samples", transform=transform_generate)
generated_loader = DataLoader(generated_dataset, batch_size=batch_size, shuffle=False)


# Define your feature extractor model
class MNISTFeatureExtractor(nn.Module):
    """Small CNN that maps 1x28x28 images to 500-dimensional feature vectors.

    Two 5x5 convolutions (20 and 50 channels), each followed by ReLU and 2x2
    max pooling, reduce a 28x28 input to 50 maps of 4x4. A linear layer then
    projects the 800 flattened values to 500 features.
    """

    def __init__(self):
        super(MNISTFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)

    def forward(self, x):
        """Return a ``(batch, 500)`` feature tensor.

        Raises:
            ValueError: if the spatial size of ``x`` does not reduce to the
                4x4 maps the linear layer expects (i.e. ``x`` is not 28x28).
        """
        # Treat a single (C, H, W) image as a batch of one, so the result is
        # (1, 500) as it was with the former view(-1, 800).
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample. The former view(-1, 800) let the batch dimension
        # absorb any mismatch, so a wrongly sized input silently produced the
        # wrong number of rows instead of failing.
        x = x.flatten(1)
        if x.shape[1] != self.fc1.in_features:
            raise ValueError(
                "expected %d flattened features per sample, got %d; "
                "inputs must be 1x28x28" % (self.fc1.in_features, x.shape[1])
            )
        return self.fc1(x)  # 500-dimensional feature vector

# Instantiate the feature extractor and switch it to evaluation mode.
# NOTE(review): no checkpoint is loaded for it anywhere in the visible code, so
# its weights are whatever default initialisation produced. Confirm this is
# intended before comparing the resulting metrics with published numbers.
feature_extractor = MNISTFeatureExtractor()
feature_extractor.eval()

# Function to get feature embeddings
def get_embeddings(loader, model):
    """Embed every batch of a labelled loader and stack the results.

    Args:
        loader: iterable yielding ``(images, labels)`` pairs; the labels are
            ignored.
        model: callable mapping an image batch to a feature tensor.

    Returns:
        A NumPy array holding the features of all batches, concatenated along
        the first axis in iteration order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Send each batch to the device the model's weights are on rather than
    # forcing CUDA, so the function also works when the model is on the CPU.
    params = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(params, None)

    embeddings = []
    with torch.no_grad():
        for (imgs, _) in loader:
            if first_param is not None:
                imgs = imgs.to(first_param.device)
            # NumPy needs the features on the CPU.
            embeddings.append(model(imgs).cpu().numpy())

    if not embeddings:
        raise ValueError("loader yielded no batches; cannot compute embeddings")
    return np.concatenate(embeddings)

# Function to get feature embeddings
def get_embeddings_generated(loader, model):
    """Embed every batch of an unlabelled loader and stack the results.

    Args:
        loader: iterable yielding image batches directly (no labels).
        model: callable mapping an image batch to a feature tensor.

    Returns:
        A NumPy array holding the features of all batches, concatenated along
        the first axis in iteration order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Send each batch to the device the model's weights are on rather than
    # forcing CUDA, so the function also works when the model is on the CPU.
    params = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(params, None)

    embeddings = []
    with torch.no_grad():
        for imgs in loader:
            if first_param is not None:
                imgs = imgs.to(first_param.device)
            # NumPy needs the features on the CPU.
            embeddings.append(model(imgs).cpu().numpy())

    if not embeddings:
        raise ValueError("loader yielded no batches; cannot compute embeddings")
    return np.concatenate(embeddings)

# Move model to GPU if available
# Pick CUDA when it is available, otherwise the CPU, and move the feature
# extractor there. The chosen device is printed for the record.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
feature_extractor = feature_extractor.to(device)

# Get embeddings for real and generated samples. The real loader yields
# (image, label) pairs while the generated loader yields images only, hence
# the two separate helpers; each helper places its batches on a device itself.
real_embeddings = get_embeddings(mnist_loader, feature_extractor)
generated_embeddings = get_embeddings_generated(generated_loader, feature_extractor)

# Define precision, recall, and FID calculation functions (same as before)
def precision_recall(real_features, gen_features, k=3):
    """Nearest-neighbour precision/recall scores between two feature sets.

    For each direction a k-NN index is fitted on one set and queried with the
    other. The score is the fraction of query points whose closest neighbour
    is no farther away than the mean of all k-NN distances returned for that
    query set.

    Args:
        real_features: 2-D array of features from real samples.
        gen_features: 2-D array of features from generated samples.
        k: number of neighbours retrieved per query point.

    Returns:
        ``(precision, recall)``: precision queries the real-sample index with
        generated features, recall queries the generated-sample index with
        real features.
    """
    def _hit_rate(reference, queries):
        # Fraction of queries whose nearest reference point lies within the
        # average k-NN distance observed for this query set.
        knn = NearestNeighbors(n_neighbors=k).fit(reference)
        dists, _ = knn.kneighbors(queries)
        return np.mean(dists.min(axis=1) <= dists.mean())

    precision = _hit_rate(real_features, gen_features)
    recall = _hit_rate(gen_features, real_features)
    return precision, recall

def calculate_fid(real_features, gen_features):
    """Frechet distance between Gaussians fitted to two feature sets.

    Computes ``||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r S_g))`` where
    ``mu`` and ``S`` are the per-set feature mean and covariance.

    Args:
        real_features: array of shape ``(n_real, dim)``.
        gen_features: array of shape ``(n_gen, dim)``.

    Returns:
        The distance as a NumPy float scalar.
    """
    real = np.asarray(real_features, dtype=np.float64)
    gen = np.asarray(gen_features, dtype=np.float64)

    mu_real, mu_gen = real.mean(axis=0), gen.mean(axis=0)
    # np.cov returns a 0-d array for single-column input; atleast_2d keeps the
    # matrix algebra below valid in that case.
    sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(gen, rowvar=False))

    diff = mu_real - mu_gen

    # Call sqrtm without the `disp` keyword, which newer SciPy releases
    # deprecate; the default call returns just the matrix.
    covmean = sqrtm(sigma_real.dot(sigma_gen))
    if not np.isfinite(covmean).all():
        # A (near-)singular product can yield a non-finite root; retry with a
        # small ridge on both covariances.
        ridge = np.eye(sigma_real.shape[0]) * 1e-6
        covmean = sqrtm((sigma_real + ridge).dot(sigma_gen + ridge))
    # Numerical error can leave a small imaginary component; drop it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return fid

# Calculate metrics
# Calculate metrics on the two embedding sets. Real embeddings are passed
# first and generated embeddings second, in both calls.
precision, recall = precision_recall(real_embeddings, generated_embeddings, k=3)
fid = calculate_fid(real_embeddings, generated_embeddings)

# Report each metric with four decimal places.
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"FID: {fid:.4f}")
