<a href="https://colab.research.google.com/github/johnsunbuns/GansResearch/blob/master/FIDscore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math

import numpy as np
import scipy
import scipy.linalg
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import inception_v3

In [None]:
# inception v3 feature extractor
class PartialInceptionNetwork(torch.nn.Module):
    """Pretrained Inception-v3 used as a 2048-dimensional feature extractor.

    The full torchvision Inception-v3 is run on the input. A forward hook
    captures the output of its ``Mixed_7c`` block, which ``forward`` then
    global-average-pools into one 2048-dimensional vector per image.
    """

    def __init__(self):
        # trigger parent constructor
        super(PartialInceptionNetwork, self).__init__()

        # initialize pretrained network
        self.inception_network = inception_v3(pretrained=True)

        # Run the network with inference behaviour. In training mode the
        # BatchNorm layers in Inception-v3's conv blocks would normalise with
        # the statistics of the current mini-batch, so an image's features
        # would depend on the other images batched with it. They would also
        # overwrite the pretrained running averages.
        # NOTE: calling .train() on this wrapper re-enables training mode.
        self.inception_network.eval()

        # capture Mixed_7c's output on every forward pass of the network
        self.inception_network.Mixed_7c.register_forward_hook(self.output_hook)

    def output_hook(self, module, input, output):
        """Forward hook: store the output of ``Mixed_7c`` on ``self``.

        PyTorch calls this after ``Mixed_7c`` runs, passing the module, its
        positional inputs and its output (N x 2048 x 8 x 8 for 299x299 input).
        A hook is a PyTorch mechanism for reading an internal block's output
        without modifying the network's own forward method.
        """
        self.mixed_7c_output = output

    def forward(self, x):
        """
        Args:
            x: shape (N, 3, 299, 299)
        Returns:
            inception activations: shape (N, 2048)
        """
        assert x.shape[1:] == (3, 299, 299)

        # The network is run only to trigger output_hook; its logits are unused.
        self.inception_network(x)

        # N x 2048 x 8 x 8
        activations = self.mixed_7c_output

        # global average pooling -> N x 2048 x 1 x 1
        activations = torch.nn.functional.adaptive_avg_pool2d(activations, (1, 1))

        # drop the singleton spatial dimensions -> N x 2048
        activations = activations.view(x.shape[0], 2048)

        return activations

# Shared feature extractor used by fid_score(). eval() switches layers such as
# BatchNorm to inference behaviour, so each image's features do not depend on
# the rest of its mini-batch. Module.eval() returns the module itself.
net = PartialInceptionNetwork().eval()

In [None]:
def _extract_features(images, batch_size, model):
    """Run ``model`` over ``images`` in mini-batches and stack the outputs.

    Args:
        images: tensor whose first dimension indexes images.
        batch_size: number of images per forward pass.
        model: callable mapping a batch of images to an (n, d) feature tensor.

    Returns:
        NumPy array with one row of features per image.
    """
    batches = []
    # Inference only: no autograd graph is needed, which keeps memory bounded.
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            features = model(images[start:start + batch_size])
            # detach/cpu so .numpy() also works for GPU tensors
            batches.append(features.detach().cpu())
    return torch.cat(batches, 0).numpy()


def _frechet_distance(xr, xg, eps=1e-6):
    """Return the Frechet distance between Gaussians fitted to two feature sets.

    Computes ||mu_r - mu_g||^2 + Tr(S_r) + Tr(S_g) - 2 * Tr(sqrt(S_r @ S_g)),
    where mu and S are the per-set feature means and covariances.
    """
    # float64 limits precision loss in the covariances and the matrix sqrt
    xr = np.asarray(xr, dtype=np.float64)
    xg = np.asarray(xg, dtype=np.float64)

    mu_r = xr.mean(axis=0)
    mu_g = xg.mean(axis=0)
    # atleast_2d keeps trace/eye arithmetic valid even for 1-d features
    sigma_r = np.atleast_2d(np.cov(xr, rowvar=False))
    sigma_g = np.atleast_2d(np.cov(xg, rowvar=False))

    diff = mu_r - mu_g
    diff_squared = diff.dot(diff)

    sqrt_prod = scipy.linalg.sqrtm(sigma_r.dot(sigma_g))
    if not np.isfinite(sqrt_prod).all():
        # Near-singular covariance product (e.g. fewer samples than feature
        # dimensions): retry with a small diagonal offset.
        offset = np.eye(sigma_r.shape[0]) * eps
        sqrt_prod = scipy.linalg.sqrtm((sigma_r + offset).dot(sigma_g + offset))

    # sqrtm may return tiny imaginary parts from numerical error; drop them
    if np.iscomplexobj(sqrt_prod):
        sqrt_prod = sqrt_prod.real

    return float(
        diff_squared + np.trace(sigma_r) + np.trace(sigma_g) - 2 * np.trace(sqrt_prod)
    )


def fid_score(real_images, gen_images, batch_size, model=None):
    """Compute the Frechet Inception Distance between two sets of images.

    Args:
        real_images: real images, first dimension indexes images. With the
            default model they must be a (N, 3, 299, 299) tensor.
        gen_images: generated images in the same format. The number of
            images may differ from ``real_images``.
        batch_size: number of images per forward pass (must be >= 1).
        model: optional feature extractor mapping a batch to an (n, d)
            tensor. Defaults to the module-level ``net``.

    Returns:
        The FID as a float.

    Raises:
        ValueError: if ``batch_size`` < 1, or if either set has fewer than two
            images (a covariance needs at least two samples).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(real_images) < 2 or len(gen_images) < 2:
        raise ValueError("need at least two real and two generated images")
    if model is None:
        model = net

    # Each set is batched independently, so the two sets may differ in size.
    xr = _extract_features(real_images, batch_size, model)
    xg = _extract_features(gen_images, batch_size, model)
    return _frechet_distance(xr, xg)

In [None]:
# test FID score
# NOTE(review): `train_loader`, `test_set` and `batch_size` are not defined in
# this notebook; they must come from an earlier cell/session — confirm.

# REAL images

# load one mini-batch from the real dataset
images, _ = next(iter(train_loader))

# reshape to N x 1 x 28 x 28 (reshape also copes with non-contiguous tensors)
images = images.reshape(-1, 1, 28, 28)

# repeat the gray channel three times to get RGB input
images = images.repeat(1, 3, 1, 1)

# resize the images to 3x299x299, the input size PartialInceptionNetwork asserts
real_res_images = F.interpolate(images, size=(299, 299))

# GENERATED images, same preprocessing
# NOTE(review): assumes the trained generator is bound to the name `generator`.
# Sampling needs no gradients.
with torch.no_grad():
    images = generator(test_set)
images = images.reshape(-1, 1, 28, 28)
images = images.repeat(1, 3, 1, 1)
gen_res_images = F.interpolate(images, size=(299, 299))

# calculate the FID score
score = fid_score(real_res_images, gen_res_images, batch_size)
score

NameError: ignored