# In [18]:
import scipy
import numpy as np
import torch
import scipy.optimize

# This is the matrix square root function you will be using
def matrix_sqrt(x):
    '''
    Function that takes in a matrix and returns the square root of that matrix.
    For an input matrix A, the output matrix B would be such that B @ B is the matrix A.
    The root is computed on the CPU with scipy.linalg.sqrtm; only the real part of
    the result is kept, and it is returned on the same device as the input.
    Parameters:
        x: a matrix
    Returns:
        a tensor of torch's default floating-point dtype located on x.device
    '''
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    # The legacy torch.Tensor(...) constructor only accepts CPU devices, so it
    # fails for accelerator inputs; torch.tensor(...) places the result on any device.
    return torch.tensor(y.real, dtype=torch.get_default_dtype(), device=x.device)

def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    '''
    Function that computes the Fréchet distance between two multivariate Gaussians
    described by their means and covariance matrices:
        ||mu_x - mu_y||^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x @ sigma_y))
    Parameters:
        mu_x: the mean of the first Gaussian, (n_features)
        mu_y: the mean of the second Gaussian, (n_features)
        sigma_x: the covariance matrix of the first Gaussian, (n_features, n_features)
        sigma_y: the covariance matrix of the second Gaussian, (n_features, n_features)
    '''
    # Squared Euclidean distance between the two means.
    mean_gap = mu_x - mu_y
    squared_mean_gap = mean_gap.dot(mean_gap)

    # Trace terms of the two covariances and of their matrix square-root product.
    trace_x = torch.trace(sigma_x)
    trace_y = torch.trace(sigma_y)
    cross_trace = torch.trace(matrix_sqrt(sigma_x @ sigma_y))

    return squared_mean_gap + trace_x + trace_y - 2*cross_trace

def preprocess(img):
    '''
    Function that resizes a batch of images to 299x299 using bilinear interpolation
    (align_corners disabled).
    Parameters:
        img: a batch of images with two trailing spatial dimensions
    '''
    target_size = (299, 299)
    return torch.nn.functional.interpolate(
        img,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )

def get_covariance(features):
    '''
    Function that estimates the covariance matrix of a set of feature vectors.
    Parameters:
        features: a 2-D tensor with one sample per row and one feature per column
    Returns:
        the covariance matrix computed by np.cov, as a CPU tensor
    '''
    # Move to the CPU before converting: .numpy() cannot read accelerator tensors.
    features_np = features.detach().cpu().numpy()
    # rowvar=False tells np.cov that columns are variables and rows are observations.
    return torch.Tensor(np.cov(features_np, rowvar=False))

def _inception_features(model, samples, batch_size):
    # Run `samples` through `model` in chunks of `batch_size`; return the stacked outputs on the CPU.
    outputs = []
    with torch.no_grad():  # inference only: do not build an autograd graph
        for start in range(0, samples.shape[0], batch_size):
            # Slice into a fresh name so the full sample tensor is never overwritten.
            batch = samples[start:start + batch_size, :, :, :].detach().to('cpu')
            outputs.append(model(preprocess(batch)).detach().to('cpu'))
    return torch.cat(outputs)

def get_FID(real_samples, fake_samples, batch_size=4):
    '''
    Function that computes the Fréchet Inception Distance between real and fake images.
    Both sets are resized by preprocess, passed through a pretrained Inception-v3
    network in batches, and the Fréchet distance is taken between Gaussians fitted
    to the two sets of network outputs.
    Parameters:
        real_samples: batch of real images, indexed as [sample, :, :, :]
        fake_samples: batch of generated images, indexed as [sample, :, :, :]
        batch_size: number of images fed to the network at once (default 4)
    Returns:
        the distance as a Python float
    Raises:
        ValueError: if batch_size is not positive, or either set has fewer than
            two samples (a covariance cannot be estimated from a single sample)
    '''
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')
    if real_samples.shape[0] < 2 or fake_samples.shape[0] < 2:
        raise ValueError('get_FID needs at least two real and two fake samples')

    inception_model = torch.hub.load('pytorch/vision:v0.9.0', 'inception_v3', pretrained=True)
    inception_model.eval()
    # NOTE(review): the network's final layer is left in place, so the "features"
    # are the model's default outputs; FID is usually computed on the pooled
    # activations before the classifier — confirm this is intended.

    # Every sample is visited exactly once, including a final partial batch,
    # and the two sets may have different sizes.
    real_features_all = _inception_features(inception_model, real_samples, batch_size)
    fake_features_all = _inception_features(inception_model, fake_samples, batch_size)

    mu_fake = fake_features_all.mean(0)
    mu_real = real_features_all.mean(0)
    sigma_fake = get_covariance(fake_features_all)
    sigma_real = get_covariance(real_features_all)

    with torch.no_grad():
        FID = frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item()

    return FID