In [1]:
# Example of calculating the Fréchet inception distance (FID)
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy.random import random
from scipy.linalg import sqrtm
import cv2 as cv
from os import listdir 

In [2]:
# calculate frechet inception distance
def calculate_fid(act1, act2):
    """Return the Frechet distance between two sample sets.

    Both inputs are treated as 2-D arrays whose rows are samples and whose
    columns are features: the mean is taken over axis 0 and the covariance
    is computed with ``rowvar=False``.

    Parameters
    ----------
    act1, act2 : numpy.ndarray
        Sample matrices. They need the same number of columns so that the
        mean vectors and covariance matrices can be combined.

    Returns
    -------
    float
        ``abs(sum((mu1 - mu2)**2) + trace(s1 + s2 - 2 * sqrtm(s1 . s2)))``.
    """
    # First- and second-order statistics of each sample set.
    mean_a = act1.mean(axis=0)
    cov_a = cov(act1, rowvar=False)
    mean_b = act2.mean(axis=0)
    cov_b = cov(act2, rowvar=False)

    # Squared Euclidean distance between the two mean vectors.
    mean_gap = numpy.sum((mean_a - mean_b) ** 2.0)

    # Matrix square root of the covariance product; sqrtm may hand back a
    # complex matrix, in which case only the real part is kept.
    cov_root = sqrtm(cov_a.dot(cov_b))
    cov_root = cov_root.real if iscomplexobj(cov_root) else cov_root

    # The absolute value is taken so the returned score is never negative.
    return abs(mean_gap + trace(cov_a + cov_b - 2.0 * cov_root))

In [3]:
def getFID(paint, foto, n):
    """Average, over `paint`, of the smallest FID against any `foto` image.

    Every image is read with OpenCV, converted to grayscale and scaled by
    1/255 before being handed to `calculate_fid`, which therefore sees a 2-D
    array per image (image rows act as samples, image columns as features).

    Parameters
    ----------
    paint : iterable of str
        Paths of the images to score.
    foto : iterable of str
        Paths of the reference images each `paint` image is compared with.
    n : number
        Divisor applied to the summed minimum FIDs.
        NOTE(review): presumably meant to equal the number of `paint`
        images -- the function does not check this; confirm at the caller.

    Returns
    -------
    float
        Sum of the per-image minimum FIDs divided by `n`.

    Raises
    ------
    ValueError
        If an image cannot be read, or if there is at least one `paint`
        image but no `foto` image to compare it with.
    """
    def _load_gray(path):
        # cv.imread signals failure by returning None rather than raising.
        image = cv.imread(path)
        if image is None:
            raise ValueError('could not read image: %s' % path)
        return cv.cvtColor(image, cv.COLOR_BGR2GRAY) / 255

    # The reference images do not depend on the current `paint` image, so
    # load and convert them once instead of once per `paint` image.
    references = [_load_gray(path) for path in foto]

    total_fid = 0
    for path in paint:
        img = _load_gray(path)
        if not references:
            raise ValueError('foto is empty: nothing to compare against')
        # min() keeps the first smallest value, like the original
        # "fid < min_fid" scan did.
        total_fid += min(calculate_fid(img, ref) for ref in references)
    return total_fid / n

In [4]:
def getImg(path):
    """List the entries of directory `path` as '<path>/<name>' strings.

    The macOS metadata file '.DS_Store' is skipped; every other entry is
    returned, whatever its type.

    Entries are returned in sorted name order. os.listdir yields names in an
    arbitrary order, so without sorting any position-based reporting on the
    returned list would differ between runs or machines.

    Parameters
    ----------
    path : str
        Directory to list.

    Returns
    -------
    list of str
        One '<path>/<name>' string per kept entry, sorted by name.
    """
    return [
        path + "/" + name
        for name in sorted(listdir(path))
        if name != '.DS_Store'
    ]

In [5]:
# Score the 'finish' images against the 'result_resized' images.
paint = getImg('finish')
foto = getImg('result_resized')
# The third argument is the divisor getFID applies to the summed minima.
# NOTE(review): 60 is presumably the number of images in 'finish' -- confirm.
net_fid = getFID(paint, foto, 60)
print('net:%.3f' % net_fid)

net:10.442


In [6]:
# Baseline: score the 'finish_1' images against the 'finish_2' images.
paint = getImg('finish_1')
foto = getImg('finish_2')
# The third argument is the divisor getFID applies to the summed minima.
# NOTE(review): 30 is presumably the number of images in 'finish_1' -- confirm.
baseline_fid = getFID(paint, foto, 30)
print('baseline:%.3f' % baseline_fid)

baseline:6.914


In [12]:
def getFID_2(paint, foto, low=9, high=11):
    """Print the index pairs whose FID lies strictly between `low` and `high`.

    Every image is read with OpenCV, converted to grayscale and scaled by
    1/255, then each `paint` image is compared with each `foto` image via
    `calculate_fid`. For every pair with ``low < fid < high`` one line
    "<paint index> <foto index>" is printed; both indices are 1-based
    positions in the given sequences.

    Parameters
    ----------
    paint : iterable of str
        Paths of the images to score.
    foto : iterable of str
        Paths of the reference images.
    low, high : number, optional
        Exclusive bounds of the reported FID window (defaults 9 and 11).

    Returns
    -------
    None
        The matches are only printed.

    Raises
    ------
    ValueError
        If an image cannot be read.
    """
    def _load_gray(path):
        # cv.imread signals failure by returning None rather than raising.
        image = cv.imread(path)
        if image is None:
            raise ValueError('could not read image: %s' % path)
        return cv.cvtColor(image, cv.COLOR_BGR2GRAY) / 255

    # The reference images do not depend on the current `paint` image, so
    # load and convert them once instead of once per `paint` image.
    references = [_load_gray(path) for path in foto]

    for a, paint_path in enumerate(paint, start=1):
        img = _load_gray(paint_path)
        for b, ref in enumerate(references, start=1):
            fid = calculate_fid(img, ref)
            if low < fid < high:
                print(a, b)

In [13]:
paint=getImg('finish')
foto=getImg('result_resized')
# getFID_2 prints the matching (paint, foto) index pairs itself and returns
# None, so its result must not be fed to a '%.3f' format: that raises
# TypeError as soon as the function finishes.
getFID_2(paint, foto)

9 4
11 2
11 12
11 27
13 2
14 17
15 10
17 2
17 10
19 20
19 23
20 25
21 12
22 28
24 2


KeyboardInterrupt: 

In [14]:
# Spot-check a single pair: read each image, convert it to grayscale and
# scale it by 1/255, then score the pair with calculate_fid.
a = cv.cvtColor(cv.imread("finish/35.jpg"), cv.COLOR_BGR2GRAY) / 255
b = cv.cvtColor(cv.imread("result_resized/image (5).png"), cv.COLOR_BGR2GRAY) / 255

fid = calculate_fid(a, b)
print('FID:%.3f' % fid)

FID:9.682
