In [7]:
import torch
from dc_gan import Generator
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = 'TRUE'
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

#hyper-parameters
learning_rate = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=True)

cuda


In [9]:
# Loading MNIST weights -- DCGAN

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
gen.load_state_dict(torch.load('mnist_dcgan.pth', map_location=device))
gen.eval()

Generator(
  (generator): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding

In [10]:
from torch.utils.data import DataLoader
#dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=True)


In [11]:
import torch.nn.functional as F

In [12]:
from evaluation_metrics import fid_score, PartialInceptionNetwork

#define FID Model
model = PartialInceptionNetwork().to(device)

#this is only for one particular batch --- Calculation of FID
FID_Score = 0
i=0
for imgs, labels in test_dataloader:
    images = imgs.repeat(1, 3, 1, 1)
    real_images = F.interpolate(images, size = (299, 299))

    with torch.no_grad():
        noise = torch.randn(128, NOISE_DIM, 1, 1).to(device)
        generated_img = gen(noise)
    gen_images = F.interpolate(generated_img, size = (299, 299))
    gen_images = gen_images.repeat(1, 3, 1, 1)

    fid_score_batch = fid_score(real_images, gen_images, imgs.shape[0], model)
    FID_Score += fid_score_batch
    print(f"Batch {i}: FID score for batch: {fid_score_batch}")
    i+=1
print("FID Score is equal to: ", FID_Score/len(test_dataloader))



#If you want to Calculate FID of a single batch in dataloader, you can use this code
"""images, labels = next(iter(test_dataloader))
images = images.repeat(1, 3, 1, 1)
real_images = F.interpolate(images, size = (299, 299))
print(real_images.shape)

with torch.no_grad():
    noise = torch.randn(128, NOISE_DIM, 1, 1).to(device)
    generated_img = gen(noise)

gen_images = F.interpolate(generated_img, size = (299, 299))
gen_images = gen_images.repeat(1, 3, 1, 1)
model = PartialInceptionNetwork()
model = model.to(device)
fid_score = fid_score(real_images, gen_images, BATCH_SIZE, model)
print("FID Score:", fid_score)"""



Batch 0: FID score for batch: 63.66679062630564
Batch 1: FID score for batch: 64.23622867654387
Batch 2: FID score for batch: 65.53893269267701
Batch 3: FID score for batch: 62.50555562637081
Batch 4: FID score for batch: 56.95146817049405
Batch 5: FID score for batch: 68.48163775140586
Batch 6: FID score for batch: 60.94788280286966
Batch 7: FID score for batch: 61.515903849379896
Batch 8: FID score for batch: 58.91108151030264
Batch 9: FID score for batch: 65.78818469208659
Batch 10: FID score for batch: 67.58763837174195
Batch 11: FID score for batch: 62.09312978069812
Batch 12: FID score for batch: 57.64450324836312
Batch 13: FID score for batch: 70.21582512530071
Batch 14: FID score for batch: 67.48408916404006
Batch 15: FID score for batch: 63.45694182877597
Batch 16: FID score for batch: 60.568962460628654
Batch 17: FID score for batch: 68.44977347413376
Batch 18: FID score for batch: 65.45670434916946
Batch 19: FID score for batch: 61.77498845267752
FID Score is equal to:  63.6

'images, labels = next(iter(test_dataloader))\nimages = images.repeat(1, 3, 1, 1)\nreal_images = F.interpolate(images, size = (299, 299))\nprint(real_images.shape)\n\nwith torch.no_grad():\n    noise = torch.randn(128, NOISE_DIM, 1, 1).to(device)\n    generated_img = gen(noise)\n\ngen_images = F.interpolate(generated_img, size = (299, 299))\ngen_images = gen_images.repeat(1, 3, 1, 1)\nmodel = PartialInceptionNetwork()\nmodel = model.to(device)\nfid_score = fid_score(real_images, gen_images, BATCH_SIZE, model)\nprint("FID Score:", fid_score)'

In [13]:
# Loading CIFAR weights -- DCGAN

CHANNELS_IMG = 3
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
gen.load_state_dict(torch.load('cifar_dcgan.pth', map_location=device))
gen.eval()

Generator(
  (generator): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding

In [14]:
########################
##### CIFAR DATASET ####
########################

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

transformations = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)
test_dataset = torchvision.datasets.CIFAR10(root='dataset/', train=False, transform=transformations, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=True)
#define FID Model
model = PartialInceptionNetwork().to(device)

#this is only for one particular batch --- Calculation of FID
FID_Score = 0
i=0
for imgs, labels in test_dataloader:
    #images = imgs.repeat(1, 3, 1, 1)   ----> No need to increase to 3 as CIFAR already has 3 channels unlike MNIST which has 1
    real_images = F.interpolate(imgs, size = (299, 299))

    with torch.no_grad():
        noise = torch.randn(128, NOISE_DIM, 1, 1).to(device)
        generated_img = gen(noise)
    gen_images = F.interpolate(generated_img, size = (299, 299))
    #gen_images = gen_images.repeat(1, 3, 1, 1)

    fid_score_batch = fid_score(real_images, gen_images, imgs.shape[0], model)
    FID_Score += fid_score_batch
    print(f"Batch {i}: FID score for batch: {fid_score_batch}")
    i+=1
print("FID Score (CIFAR) is equal to: ", FID_Score/len(test_dataloader))

Files already downloaded and verified




Batch 0: FID score for batch: 195.27387877171245
Batch 1: FID score for batch: 189.0330097595008
Batch 2: FID score for batch: 191.7647556234041
Batch 3: FID score for batch: 192.6418623579033
Batch 4: FID score for batch: 188.18977271790635
Batch 5: FID score for batch: 191.2678238637945
Batch 6: FID score for batch: 192.50876336952322
Batch 7: FID score for batch: 194.23309450843254
Batch 8: FID score for batch: 187.41242913009876
Batch 9: FID score for batch: 189.3634768091946
Batch 10: FID score for batch: 189.34463243265714
Batch 11: FID score for batch: 191.74598388976386
Batch 12: FID score for batch: 193.85380603334295
Batch 13: FID score for batch: 193.30794027861083
Batch 14: FID score for batch: 193.52679807454547
Batch 15: FID score for batch: 200.27921753823068
Batch 16: FID score for batch: 189.07756808401243
Batch 17: FID score for batch: 189.1383179256813
Batch 18: FID score for batch: 196.63813349354632
Batch 19: FID score for batch: 197.0750921586427
FID Score (CIFAR)