In [1]:
import torch
from torch import nn
from torchvision import transforms
import numpy as np
from PIL import Image

Get pretrained Inception v3 model

https://pytorch.org/vision/stable/models.html

In [2]:
from torchvision import models
inception = models.inception_v3(pretrained=True)
inception.fc = nn.Identity()  # Remove last layer
inception.eval()

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [3]:
def fid(model, img_1, img_2):
    '''
    Calculate the Frechet Inception Distance between two images.
    Takes in two images of the same size as Pytorch tensors.
    Both images need to be 299x299 with 3 channels (3xHxW).
    '''
    # normalize images to specifications of pretrained model
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
    normal_img_1 = normalize(img_1)
    normal_img_2 = normalize(img_2)

    # Pass the images through the model
    pred_img_1 = model(normal_img_1)
    pred_img_2 = model(normal_img_2)


    # Calculate score
    mu_1 = torch.mean(pred_img_1, axis=0)
    mu_2 = torch.mean(pred_img_2, axis=0)

    sigma_1 = torch.Tensor(np.cov(pred_img_1.detach().numpy(), rowvar=False))
    sigma_1 = torch.atleast_2d(sigma_1)
    sigma_2 = torch.Tensor(np.cov(pred_img_2.detach().numpy(), rowvar=False))
    sigma_2 = torch.atleast_2d(sigma_2)

    square_diff_mu = torch.sum((mu_1 - mu_2) ** 2)
    tr_sigma = torch.trace(sigma_1 + sigma_2 - 2 * torch.sqrt(torch.matmul(sigma_1, sigma_2)))
    return float(square_diff_mu + tr_sigma)

In [4]:
transform = transforms.ToTensor()

### Test FID Score

In [7]:
dog_1 = Image.open("/content/dog_1.jpeg").resize((299, 299))
dog_2 = Image.open("/content/dog_2.jpeg").resize((299, 299))
dolphin_1 = Image.open("/content/dolphin_1.jpeg").resize((299, 299))
dolphin_2 = Image.open("/content/dolphin_2.jpeg").resize((299, 299))

img_1_tensor = torch.unsqueeze(transform(dog_1), dim=0)
img_2_tensor = torch.unsqueeze(transform(dog_2), dim=0)
img_3_tensor = torch.unsqueeze(transform(dolphin_1), dim=0)
img_4_tensor = torch.unsqueeze(transform(dolphin_2), dim=0)

# Same animal
print(fid(inception, img_1_tensor, img_2_tensor))
print(fid(inception, img_3_tensor, img_4_tensor))

# Different animal (should be greater)
print(fid(inception, img_1_tensor, img_3_tensor))
print(fid(inception, img_2_tensor, img_4_tensor))

169.03912353515625
116.71086883544922
274.4468078613281
301.8088684082031
