In [9]:
import torch
import numpy as np
import torch.nn as nn
from scipy import linalg
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image
from PIL import Image



class FIDScore:
    def __init__(self, path_a, path_b, image_size, batch_size, device='cuda'):
        self.device = device
        self.image_size = image_size
        self.path_a = path_a
        self.path_b = path_b
        self.batch_size = batch_size
        self.inception = self.load_patched_inception_v3().eval().to(device)
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    @torch.no_grad()
    def extract_features(self, loader):
        for batch_idx, (data, target) in enumerate(loader):
            print(f"Batch {batch_idx} - Data shape: {data.shape}, Target shape: {target.shape}")
        pbar = tqdm(loader)
        feature_list = []
        for img, _ in loader:
            img = img.to(self.device)
            feature = self.inception(img)[0].view(img.shape[0], -1)
            feature_list.append(feature.to('cpu'))
        features = torch.cat(feature_list, 0)
        return features

    def calc_fid(self, sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
        cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
        if not np.isfinite(cov_sqrt).all():
            print('product of cov matrices is singular')
            offset = np.eye(sample_cov.shape[0]) * eps
            cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
        if np.iscomplexobj(cov_sqrt):
            if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
                m = np.max(np.abs(cov_sqrt.imag))
                raise ValueError(f'Imaginary component {m}')
            cov_sqrt = cov_sqrt.real
        mean_diff = sample_mean - real_mean
        mean_norm = mean_diff @ mean_diff
        trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
        fid = mean_norm + trace
        return fid

    def __call__(self, num_workers=4):
        dset_a = ImageFolder(self.path_a, self.transform)
       
        print("Number of real images:", len(dset_a))
        loader_a = DataLoader(dset_a, batch_size=self.batch_size, num_workers=num_workers)
        features_a = self.extract_features(loader_a).numpy()
        real_mean = np.mean(features_a, 0)
        real_cov = np.cov(features_a, rowvar=False)

        dset_b = ImageFolder(self.path_b, self.transform)
        print("Number of synthetic images:", len(dset_b))
        loader_b = DataLoader(dset_b, batch_size=self.batch_size, num_workers=num_workers, drop_last=True)
        features_b = self.extract_features(loader_b).numpy()
        sample_mean = np.mean(features_b, 0)
        sample_cov = np.cov(features_b, rowvar=False)

        fid = self.calc_fid(sample_mean, sample_cov, real_mean, real_cov)

        return fid

    @staticmethod
    def load_patched_inception_v3():
        inception = torch.hub.load('pytorch/vision:v0.9.0', 'inception_v3', pretrained=True)
        inception.fc = nn.Identity()
        return inception

In [12]:
import os
path_a = os.path.join(os.getcwd(), 'images')
print (path_a)
path_b = os.path.join('/home', 'jovyan', 'FastGAN', 'eval_40000')
print (path_b)
fid = FIDScore(path_a, path_b, 256, 32)
fid_score = fid()
print ('fid score:' , fid_score)

Using cache found in /home/jovyan/.cache/torch/hub/pytorch_vision_v0.9.0
  0%|          | 0/62 [00:14<?, ?it/s]

/home/jovyan/FastGAN/benchmarking/images
/home/jovyan/FastGAN/eval_40000





Number of real images: 160


  0%|          | 0/5 [00:00<?, ?it/s]

Batch 0 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 1 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 2 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 3 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 4 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Number of synthetic images: 2006


  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s]


Batch 0 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 1 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 2 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 3 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 4 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 5 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 6 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 7 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 8 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 9 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 10 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 11 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.S

  0%|          | 0/5 [00:02<?, ?it/s]


Batch 20 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 21 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 22 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 23 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 24 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 25 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 26 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 27 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 28 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 29 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 30 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 31 - Data shape: torch.Size([32, 3, 256, 256]), Target shap

  0%|          | 0/62 [00:00<?, ?it/s]

Batch 60 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
Batch 61 - Data shape: torch.Size([32, 3, 256, 256]), Target shape: torch.Size([32])
fid score: 1.6408695886198288


In [1]:
import torchvision
print (torchvision.__version__)

0.5.0
