In [1]:
import torch
from dc_gan import Generator
import os
# Presumably a workaround for the Intel OpenMP "OMP: Error #15 ... already initialized"
# abort that happens when two copies of the OpenMP runtime are loaded into one process.
# NOTE(review): Intel documents this flag as an unsafe workaround (it may crash or silently
# give wrong results); prefer fixing the environment. It only helps if set before the second
# OpenMP-linked library initializes, so keep it above the remaining imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = 'TRUE'
import matplotlib.pyplot as plt  # NOTE(review): not used by any cell shown in this notebook
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Hyper-parameters. NOISE_DIM, CHANNELS_IMG and FEATURES_GEN must match the values used to
# train the saved generator, otherwise load_state_dict() cannot fit the checkpoint.
# NOTE(review): learning_rate, NUM_EPOCHS and FEATURES_DISC are not used by the evaluation
# cells in this notebook; they look like leftovers from the training script.
learning_rate = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1  # MNIST is single-channel grayscale
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

# Bound to `mnist_transform` rather than `transforms`: rebinding `transforms` shadowed the
# `torchvision.transforms` module, so re-running this cell failed on `transforms.Compose`.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1]; one value per channel.
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

# MNIST training split (downloaded into ./dataset on first run); requires CHANNELS_IMG == 1.
dataset = datasets.MNIST(root="dataset/", train=True, transform=mnist_transform, download=True)

cuda


In [3]:
# Rebuild the generator with the training-time hyper-parameters and restore its weights.
GEN_WEIGHTS_PATH = 'generator_weights_dc_gan.pth'

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
# map_location remaps the stored tensors onto `device`, so a GPU-saved checkpoint also loads
# on CPU. weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, and avoids running arbitrary code from the file (torch >= 1.13).
state_dict = torch.load(GEN_WEIGHTS_PATH, map_location=device, weights_only=True)
gen.load_state_dict(state_dict)
# Evaluation mode: the BatchNorm2d layers use their running statistics, not per-batch ones.
gen.eval()

Generator(
  (generator): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding

In [4]:
from torch.utils.data import DataLoader

# Shuffled mini-batches over the MNIST training split; a batch of real images is drawn from it below.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [5]:
import torch.nn.functional as F

In [6]:
from evaluation_metrics import fid_score, PartialInceptionNetwork

INCEPTION_SIZE = (299, 299)  # spatial size fed to the Inception feature extractor
SEED = 0  # fixes both the shuffled real batch and the noise so the score is reproducible


def to_inception_input(batch: torch.Tensor) -> torch.Tensor:
    """Upsample a single-channel image batch to INCEPTION_SIZE and replicate it to 3 channels.

    Interpolating before repeating gives the same values as the reverse order but only
    resizes one channel instead of three. Uses F.interpolate's default ('nearest') mode.
    """
    return F.interpolate(batch, size=INCEPTION_SIZE).repeat(1, 3, 1, 1)


torch.manual_seed(SEED)

# FID for a single batch only: one batch of real images vs. an equally sized generated batch.
# NOTE(review): FID from ~BATCH_SIZE samples is strongly biased and noisy; accumulate features
# over thousands of samples for a number comparable to published results.
images, _ = next(iter(dataloader))
real_images = to_inception_input(images)

with torch.no_grad():
    # Match the generated batch size to the real batch instead of a hard-coded 128.
    noise = torch.randn(images.size(0), NOISE_DIM, 1, 1, device=device)
    generated_img = gen(noise)

gen_images = to_inception_input(generated_img)

model = PartialInceptionNetwork()
model = model.to(device)
# NOTE(review): real_images stay on the CPU while gen_images are on `device` (as before);
# fid_score is assumed to move batches to the model's device. Also confirm whether the model
# needs .eval() and expects the [-1, 1] range produced by the Normalize transform.
# The result is stored in `fid` so the imported fid_score function is not overwritten.
fid = fid_score(real_images, gen_images, BATCH_SIZE, model)
print("FID Score:", fid)



FID Score: 75.51969494775118
