# FID

In [None]:
import argparse
import torch
from torch import nn
from torchvision.models import inception_v3
import cv2
import multiprocessing
import numpy as np
import glob
import os
from scipy import linalg

In [None]:
def to_cuda(elements):
    """
    Move ``elements`` to the GPU when CUDA is available.
    Args:
        elements: object exposing a ``.cuda()`` method
                  (torch.tensor or torch.nn.module)
    Returns:
        ``elements.cuda()`` if ``torch.cuda.is_available()`` is true,
        otherwise the input object unchanged.
    """
    return elements.cuda() if torch.cuda.is_available() else elements

class PartialInceptionNetwork(nn.Module):
    """Wrapper around a pretrained Inception-v3 that returns pooled ``Mixed_7c`` features.

    A forward hook registered on ``Mixed_7c`` stores that layer's output every
    time the wrapped network runs; ``forward`` average-pools the stored tensor
    and flattens it to (N, 2048).
    """

    def __init__(self, transform_input=True):
        super().__init__()
        # Stored on the instance only; nothing in this class reads it back.
        self.transform_input = transform_input
        self.inception_network = inception_v3(pretrained=True)
        self.inception_network.Mixed_7c.register_forward_hook(self.output_hook)

    def output_hook(self, module, input, output):
        """Forward hook: remember the latest ``Mixed_7c`` output (N x 2048 x 8 x 8)."""
        self.mixed_7c_output = output

    def forward(self, x):
        """
        Args:
            x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1
        Returns:
            inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32
        """
        expected_chw = (3, 299, 299)
        assert x.shape[1:] == expected_chw, "Expected input shape to be: (N,3,299,299)" +\
                                            ", but got {}".format(x.shape)
        # Map the [0, 1] input range to [-1, 1]
        scaled = x * 2 - 1

        # The return value is discarded; the call exists to fire the hook.
        self.inception_network(scaled)

        # N x 2048 x H x W  ->  N x 2048 x 1 x 1  ->  N x 2048
        pooled = torch.nn.functional.adaptive_avg_pool2d(self.mixed_7c_output, (1, 1))
        return pooled.view(scaled.shape[0], 2048)


def get_activations(images, batch_size):
    """
    Calculates the pooled inception activations for all images.

    The network is run in eval mode under ``torch.no_grad()``: the activations
    are converted to numpy right away, so no autograd graph is needed and
    skipping it keeps memory use flat across batches.
    --
        images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32
        batch_size: batch size used for inception network
    --
    Returns: np array shape: (N, 2048), dtype: np.float32
    """
    assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" +\
                                              ", but got {}".format(images.shape)

    num_images = images.shape[0]
    inception_network = to_cuda(PartialInceptionNetwork())
    inception_network.eval()
    inception_activations = np.zeros((num_images, 2048), dtype=np.float32)

    with torch.no_grad():
        for start_idx in range(0, num_images, batch_size):
            # Slicing clips at num_images, so the last batch may be shorter.
            end_idx = start_idx + batch_size

            ims = to_cuda(images[start_idx:end_idx])
            activations = inception_network(ims)
            activations = activations.detach().cpu().numpy()
            assert activations.shape == (ims.shape[0], 2048), "Expexted output shape to be: {}, but was: {}".format((ims.shape[0], 2048), activations.shape)
            inception_activations[start_idx:end_idx, :] = activations
    return inception_activations

def calculate_activation_statistics(images, batch_size):
    """Compute the mean and covariance of the inception activations (FID statistics).

    Args:
        images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1
        batch_size: batch size forwarded to ``get_activations``
    Returns:
        mu:     per-feature mean of the activations (mean over axis 0)
        sigma:  covariance matrix of the activations, with rows treated as
                observations (``rowvar=False``)
    """
    activations = get_activations(images, batch_size)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : The sample mean over the activations for the first image set
               (e.g. generated samples).
    -- mu2   : The sample mean over the activations for the second image set
               (e.g. precalculated on a representative data set).
    -- sigma1: The covariance matrix over the activations for the first set.
    -- sigma2: The covariance matrix over the activations for the second set.
    -- eps   : Value added to the diagonal of both covariances when the
               matrix square root of their product is not finite.
    Returns:
    --   : The Frechet Distance.
    Raises:
    -- ValueError: if the matrix square root has a diagonal imaginary part
                   larger than 1e-3.
    """
    # Imported here because the file's import cell does not provide it; the
    # singular-product branch below would otherwise fail with a NameError.
    import warnings

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2
    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def preprocess_image(im):
    """Resizes and shifts the dynamic range of image to 0-1
    Args:
        im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8
    Return:
        im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1
    """
    # Check the rank first: indexing shape[2] on a 2-D array would raise an
    # IndexError before the intended assertion could report the problem.
    assert len(im.shape) == 3
    assert im.shape[2] == 3
    if im.dtype == np.uint8:
        im = im.astype(np.float32) / 255
    im = cv2.resize(im, (299, 299))
    # (H, W, C) -> (C, H, W)
    im = np.rollaxis(im, axis=2)
    im = torch.from_numpy(im)
    assert im.max() <= 1.0
    assert im.min() >= 0.0
    assert im.dtype == torch.float32
    assert im.shape == (3, 299, 299)

    return im

def preprocess_images(images, use_multiprocessing):
    """Run ``preprocess_image`` on every image and collect the results in one tensor.

    Args:
        images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8
        use_multiprocessing: If true, the per-image work is dispatched to a
            ``multiprocessing.Pool`` with one worker per CPU
    Return:
        final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1
    """
    if not use_multiprocessing:
        final_images = torch.stack([preprocess_image(im) for im in images], dim=0)
    else:
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            # Submit everything first, then collect in submission order.
            pending = [pool.apply_async(preprocess_image, (im,)) for im in images]
            final_images = torch.zeros(images.shape[0], 3, 299, 299)
            for slot, job in enumerate(pending):
                final_images[slot] = job.get()

    assert final_images.shape == (images.shape[0], 3, 299, 299)
    assert final_images.max() <= 1.0
    assert final_images.min() >= 0.0
    assert final_images.dtype == torch.float32
    return final_images

def calculate_fid(images1, images2, use_multiprocessing, batch_size):
    """Compute the FID between two image sets.

    Both sets are preprocessed, reduced to activation statistics, and the
    Frechet distance between those statistics is returned.

    Args:
        images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        use_multiprocessing: If multiprocessing should be used to pre-process the images
        batch_size: batch size used for inception network
    Returns:
        FID (scalar)
    """
    tensors1 = preprocess_images(images1, use_multiprocessing)
    tensors2 = preprocess_images(images2, use_multiprocessing)
    stats1 = calculate_activation_statistics(tensors1, batch_size)
    stats2 = calculate_activation_statistics(tensors2, batch_size)
    return calculate_frechet_distance(stats1[0], stats1[1], stats2[0], stats2[1])

def load_images(path, pred_dir='./comparison_results/SRGAN_re/'):
    """ Loads all .png or .jpg images from a given path together with their
    matching prediction images.

    For every image ``<name>.<ext>`` found in ``path`` the prediction
    ``<pred_dir>/<name>_pred.png`` is loaded as well.

    Warnings: Expects all images to be of same dtype and shape.
    Args:
        path: relative path to directory with the reference images
        pred_dir: directory holding the ``<name>_pred.png`` prediction images
    Returns:
        (final_images, final_images1): two np.arrays of shape (N, H, W, 3) in
        RGB channel order and of the image dtype; the first holds the images
        from ``path``, the second the matching predictions.
    Raises:
        FileNotFoundError: if ``path`` contains no .png or .jpg image.
        IOError: if an image or its prediction cannot be read.
    """
    image_paths = []
    for ext in ("png", "jpg"):
        pattern = os.path.join(path, "*.{}".format(ext))
        print("Looking for images in", pattern)
        image_paths.extend(glob.glob(pattern))
    if not image_paths:
        raise FileNotFoundError("No .png or .jpg images found in {}".format(path))
    image_paths.sort()

    first_image = cv2.imread(image_paths[0])
    if first_image is None:
        raise IOError("Could not read image {}".format(image_paths[0]))

    # cv2 arrays are laid out (rows, cols, channels), i.e. height comes first.
    H, W = first_image.shape[:2]
    stack_shape = (len(image_paths), H, W, 3)
    final_images = np.zeros(stack_shape, dtype=first_image.dtype)
    final_images1 = np.zeros(stack_shape, dtype=first_image.dtype)

    for idx, impath in enumerate(image_paths):
        # File name without directory and extension selects the prediction.
        stem = os.path.splitext(os.path.basename(impath))[0]
        impath1 = os.path.join(pred_dir, stem + '_pred.png')

        im = cv2.imread(impath)
        if im is None:
            raise IOError("Could not read image {}".format(impath))
        im1 = cv2.imread(impath1)
        if im1 is None:
            raise IOError("Could not read image {}".format(impath1))

        im = im[:, :, ::-1] # Convert from BGR to RGB
        im1 = im1[:, :, ::-1]
        assert im.dtype == final_images.dtype
        final_images[idx] = im
        assert im1.dtype == final_images1.dtype
        final_images1[idx] = im1

    return final_images, final_images1

In [None]:
# Command-line options for the FID evaluation cell.
parser = argparse.ArgumentParser(description='FID')

parser.add_argument('--path1', type=str, default='./datasets/face_SRnDeblur/val_GT')
parser.add_argument('--path2', type=str, default='./comparison_results/wnoise/')

parser.add_argument('--multiprocessing', action='store_true', default=False)
parser.add_argument('--batch_size',type=int,default=8)

# An explicit argument list is parsed here instead of sys.argv.
# NOTE(review): `global` at module level has no effect.
global args
args = parser.parse_args(['--path1','./datasets/face_SRnDeblur/val_GT',
                          '--path2', './comparison_results/SRResNet_re/',
                          '--batch_size','8'])

# NOTE(review): only args.path1 is passed on; args.path2 is parsed above but
# never read in this cell.
images1, images2 = load_images(args.path1)

fid_value = calculate_fid(images1, images2, args.multiprocessing, args.batch_size)
print(fid_value)

In [None]:
import argparse
import os
import random
import torch
import torchvision
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from PIL import Image

In [None]:
##### Helper Functions for Data Loading & Pre-processing
class ImageFolder(data.Dataset):
    """Dataset yielding (input, prediction, ground truth) image triplets.

    Inputs are listed from ``<opt.dataroot>/valx8``. For each input the
    ground-truth path is derived by swapping ``valx8`` for ``val_GT`` and the
    extension for ``.jpg``; the prediction is read from
    ``./comparison_results/detail1/<name>_pred.png``.
    """

    def __init__(self, opt):
        """
        Args:
            opt: options object providing ``dataroot``, ``no_resize_or_crop``
                 and ``no_flip``
        """
        self.root = opt.dataroot
        self.no_resize_or_crop = opt.no_resize_or_crop
        self.no_flip = opt.no_flip
        # ToTensor followed by per-channel (v - 0.5) / 0.5 normalization
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
        self.transformM = transforms.Compose([transforms.ToTensor()])
        #=====================================================================================#
        self.dir_A = os.path.join(opt.dataroot,'valx8')
        # os.listdir returns every entry of the directory
        self.Aimg_paths = [os.path.join(self.dir_A, entry) for entry in os.listdir(self.dir_A)]
        #=====================================================================================#

    def __getitem__(self, index):
        """Return a dict with tensors 'A', 'B', 'C' and the file stem 'fname'."""
        #=====================================================================================#
        # A : 32x32 (blur + LR)
        # B : 256x256 (LR)
        # C : 256x256 (GT)
        # D : 256x256 (fmask)
        A_path = self.Aimg_paths[index]
        trn = A_path.find('valx8')
        endn = len(A_path)
        # Text after 'valx8' plus one separator character, minus a
        # 4-character extension.
        fname = A_path[trn+6:endn-4]
        C_path = A_path[:trn]+'val_GT'+A_path[trn+5:endn-4]+'.jpg'
        B_path = './comparison_results/detail1/'+fname+'_pred.png'
        A = Image.open(A_path).convert('RGB')
        C = Image.open(C_path).convert('RGB')
        # Convert like A and C so the 3-channel Normalize always gets three
        # channels, whatever mode the prediction PNG was saved in.
        B = Image.open(B_path).convert('RGB')
        A = self.transform(A)
        B = self.transform(B)
        C = self.transform(C)

        return {'A':A,'B':B, 'C':C,'fname':fname}

    def __len__(self):
        """Number of input images found in ``valx8``."""
        return len(self.Aimg_paths)

##### Helper Function for GPU Training
def to_variable(x):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU first when CUDA is available."""
    return Variable(x.cuda() if torch.cuda.is_available() else x)

##### Helper Function for Math
def denorm(x):
    """Map ``x`` through (x + 1) / 2 and clamp the result to [0, 1]."""
    return ((x + 1) / 2).clamp(0, 1)

##### Helper Functions for GAN Loss (4D Loss Comparison)
def GAN_Loss(input, target, criterion):
    """Evaluate ``criterion`` between ``input`` and a constant label tensor.

    The labels have the same size as ``input`` and are filled with 1.0 when
    ``target == True`` and with 0.0 otherwise. They are moved to the GPU
    whenever CUDA is available.

    Returns:
        whatever ``criterion(input, labels)`` returns
    """
    fill_value = 1.0 if target == True else 0.0
    labels = Variable(torch.FloatTensor(input.size()).fill_(fill_value),
                      requires_grad=False)

    if torch.cuda.is_available():
        labels = labels.cuda()

    return criterion(input, labels)
##### Helper Function for Math
def denorm(x):
    """Shift and scale ``x`` by (x + 1) / 2, then clamp to the range [0, 1].

    This redefines the ``denorm`` declared earlier in the file with the same
    computation.
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)

def to_numpy(x):
    """Convert a 3-D tensor to a numpy array with its first axis moved last.

    The tensor is moved to the CPU and detached, its values are mapped
    through (v + 1) / 2, and the axes are reordered with (1, 2, 0).
    Presumably ``x`` is a (C, H, W) image in [-1, 1]; confirm against callers.
    """
    host_array = x.cpu().detach().numpy()
    return np.transpose((host_array + 1) / 2, (1, 2, 0))

def mse(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    The element-wise difference is squared and averaged over all elements.
    (The L2 norm of the difference, which this function used to return, grows
    with the number of elements and is not a mean squared error.)

    Args:
        x, y: array-likes of the same (or broadcastable) shape
    Returns:
        scalar mean of (x - y) ** 2, accumulated in float64
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.mean(np.square(diff))


In [None]:
# Options for the MSE/PSNR/SSIM evaluation cell. This rebinds the
# module-level names `parser` and `args`.
parser = argparse.ArgumentParser(description='Implementation of Pix2Pix')

# Task
parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders train, val, etc)')
parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')

# Options
# NOTE(review): --no_resize_or_crop is a store_true flag, but its help text
# lists mode names; confirm which is intended.
parser.add_argument('--no_resize_or_crop', action='store_true', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--batchSize', type=int, default=1, help='test Batch size')

# misc
parser.add_argument('--model_path', type=str, default='./models')
parser.add_argument('--sample_path', type=str, default='./test_results')
parser.add_argument('--results_txt', type=str, default='./test_MSE_PSNR_SSIM.txt')

##### Helper Functions for Data Loading & Pre-processing
# NOTE(review): IMG_EXTENSIONS is not referenced anywhere else in the visible
# source.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

# An explicit argument list is parsed here instead of sys.argv.
args = parser.parse_args(['--dataroot','./datasets/face_SRnDeblur','--which_direction','AtoB',
                          '--num_epochs','451','--batchSize','16','--no_resize_or_crop',
                          '--model_path','./Model3_1015_b16e451/models',
                          '--sample_path','./Model3_1015_b16e451/results_e451',
                         '--results_txt','./Model3_1015_b16e451/PSNRSSIM_e451.txt'])
# 741 751 761 771 781
print(args)

# Build the dataset and a shuffling loader over it. `data` must still be
# torch.utils.data at this point; a later cell imports a different `data`
# from skimage.
dataset = ImageFolder(args)
data_loader = data.DataLoader(dataset=dataset,
                              batch_size=args.batchSize,
                              shuffle=True,
                              num_workers=2)

In [None]:
from PIL import Image
from skimage import data, img_as_float
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

In [None]:
# Accumulate MSE / PSNR / SSIM between each comparison image and its ground
# truth over the whole loader, then print the per-image averages.
total_step = len(data_loader) # For Print Log
mse_in_all, mse_out_all, psnr_in_all, psnr_out_all, ssim_in_all, ssim_out_all = 0,0,0,0,0,0
num_samples = 0  # images actually evaluated; used as the averaging divisor
for i, sample in enumerate(data_loader):

    input_A = sample['A']
    comp = sample['B']
    GTHR = sample['C']
    testfileN = sample['fname']

    in_blurLR = to_variable(input_A)
    v_comp = to_variable(comp)
    v_GTHR = to_variable(GTHR)

    # print the log info
    print('Validation[%d/%d]' % (i + 1, total_step))
    # save the sampled images

    # Keep only the first three channels of each batch
    real_Br = v_GTHR[:,0:3,:,:]
    comp_Br = v_comp[:,0:3,:,:]

    # Iterate over the real batch size: the last batch can hold fewer images
    # than the configured batch size.
    for k in range(comp_Br.shape[0]):

        fake_Br_ = img_as_float(to_numpy(comp_Br[k,:,:,:]))
        real_Br_ = img_as_float(to_numpy(real_Br[k,:,:,:]))

        # Dynamic range of the comparison image, shared by PSNR and SSIM
        data_range = fake_Br_.max()-fake_Br_.min()

        mse_out = mse(real_Br_,fake_Br_)
        psnr_out = psnr(real_Br_,fake_Br_,data_range=data_range)
        ssim_out = ssim(real_Br_,fake_Br_,data_range=data_range,multichannel=True)

        mse_out_all += mse_out
        psnr_out_all += psnr_out
        ssim_out_all += ssim_out
        num_samples += 1

#         torchvision.utils.save_image(denorm(comp_Br[k,:,:,:].data), os.path.join('./comparison_results/SRResNet_re', '%s.png' % testfileN[k]))

print('Average of MSE PSNR SSIM \n')
# Guard against an empty loader so the averages print as 0 instead of raising.
denom = max(num_samples, 1)
print('mse_out : %f, psnr_out : %f, ssim_out : %f \n' % (mse_out_all/denom, psnr_out_all/denom, ssim_out_all/denom))


In [None]:
# Show the file names of the last batch processed by the evaluation loop.
print(testfileN)