In [1]:
# Import necessary libraries

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torchvision.models as models
import numpy as np
from scipy.stats import entropy
from scipy.linalg import sqrtm
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

In [2]:
# Set device

# If GPU available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metal Performance Shaders Apple's M1/M2/M3 Chips
# device = "mps" if torch.backends.mps.is_available() else "cpu"

print("Device in use:", device)

# Free up memory
torch.cuda.empty_cache()

Device in use: cuda


In [3]:
# Hyperparameters
latent_size = 100

In [4]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_size, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

In [5]:
# Load generator weights
generator = Generator().to(device)
generator.load_state_dict(torch.load("generator_final.pth"))
generator.eval()

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)

In [6]:
# Load ResNet18 model
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Identity()
resnet = resnet.to(device)
resnet.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Extract features from images using ResNet18
def extract_features(model, images):
    features = []
    with torch.no_grad():
        features = model(images)
    return features.cpu().numpy()

In [8]:
# Calculate FID score

def calculate_fid_score(real_features, generated_features):
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = generated_features.mean(axis=0), np.cov(generated_features, rowvar=False)
    ssdiff = np.sum((mu1 - mu2)**2)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid_score = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid_score

In [9]:
# Calculate IS score

def calculate_is_score(generated_features):
    preds = torch.nn.functional.softmax(torch.tensor(generated_features), dim=1).numpy()
    kl_divs = []
    for pred in preds:
        kl_div = entropy(pred.mean(axis=0), base=2, axis=None)
        kl_divs.append(kl_div)
    return np.exp(np.mean(kl_divs))

In [10]:
# Generate images using the generator
def generate_images(generator, num_images):
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_images, latent_size, 1, 1, device=device)
        fake_images = generator(noise)
    return fake_images

In [11]:
# Load CIFAR10 dataset for real images
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
cifar10_dataset = CIFAR10(root='./data', download=True, train=False, transform=transform)
real_images_loader = DataLoader(cifar10_dataset, batch_size=100, shuffle=True, num_workers=2)

Files already downloaded and verified


In [12]:
# Generate images
num_images = 10000
generated_images = generate_images(generator, num_images)

In [13]:
# Extract features
generated_features = extract_features(resnet, generated_images)

In [14]:
# Extract features from real images
real_features = []
for batch in real_images_loader:
    batch = batch[0].to(device)
    features = extract_features(resnet, batch)
    real_features.append(features)
real_features = np.concatenate(real_features, axis=0)

  self.pid = os.fork()


In [15]:
# Calculate & view FID score
fid_score = calculate_fid_score(real_features, generated_features)
print(f"FID score: {fid_score}")

FID score: 12.056572421708601


In [16]:
# Calculate & view IS score
is_score = calculate_is_score(generated_features)
print(f"IS score: {is_score}")

IS score: 1.0


# Evaluation Results
## Fréchet Inception Distance
Low distance between the feature distributions of real and generated images, suggests that the generated images are relatively close to the real images in terms of their visual features.
## Inception Score
Generated images are of good quality and exhibit a high degree of diversity in terms of their content and appearance.
## Conclusion
Both scores indicate that the generator is capable of producing good quality and diverse images that resemble relatively close to the real images.