# Frechet Inception Distance (FID) Calculation
Reference: https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid#Fretchet-Inception-Distance

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.autograd import Function
import torchvision.transforms as transforms
import glob
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import math
import copy
from scipy import linalg

In [None]:
# %cd /content/drive/My\ Drive/CV2022

/content/drive/My Drive/CV2022


## Sample Generation

In [None]:
class EqualLR:
    """Forward pre-hook implementing the equalized learning-rate trick.

    The trainable tensor is stored on the module as ``<name>_orig``; before
    every forward pass the hook rescales it by ``sqrt(2 / fan_in)`` and
    assigns the result to ``<name>`` so the wrapped layer uses it.
    """

    def __init__(self, name):
        # Attribute must stay ``name``: pickled checkpoints restore hook
        # instances with exactly this attribute.
        self.name = name

    def compute_weight(self, module):
        """Return ``module.<name>_orig`` scaled by ``sqrt(2 / fan_in)``."""
        raw = getattr(module, f'{self.name}_orig')
        # fan_in = size of dim 1 times the element count of one [out][in]
        # slice (the kernel area for conv weights, 1 for linear weights).
        per_slice = raw.data[0][0].numel()
        fan_in = raw.data.size(1) * per_slice
        scale = math.sqrt(2 / fan_in)
        return raw * scale

    @staticmethod
    def apply(module, name):
        """Move ``module.<name>`` to ``<name>_orig`` and install the hook."""
        hook = EqualLR(name)
        original = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(f'{name}_orig', nn.Parameter(original.data))
        module.register_forward_pre_hook(hook)
        return hook

    def __call__(self, module, input):
        """Pre-hook entry point: refresh ``module.<name>`` before forward."""
        setattr(module, self.name, self.compute_weight(module))


def equal_lr(module, name='weight'):
    """Apply :class:`EqualLR` to ``module.<name>`` and return ``module``."""
    EqualLR.apply(module, name)
    return module

In [None]:
class EqualLinear(nn.Module):
    """``nn.Linear`` with N(0, 1) weights, zero bias and equalized-LR scaling."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layer = nn.Linear(in_dim, out_dim)
        # Unit-variance init; the EqualLR hook applies sqrt(2 / fan_in) at runtime.
        layer.weight.data.normal_()
        layer.bias.data.zero_()
        # Attribute must stay ``linear`` so pickled checkpoints still match.
        self.linear = equal_lr(layer)

    def forward(self, x):
        """Apply the wrapped, equalized linear layer to ``x``."""
        layer = self.linear
        return layer(x)

In [None]:
class EqualConv2d(nn.Module):
    """``nn.Conv2d`` with N(0, 1) weights, zero bias and equalized-LR scaling.

    Accepts exactly the positional/keyword arguments of ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # With ``bias=False`` nn.Conv2d leaves ``bias`` as None; only zero it
        # when it exists instead of crashing on ``None.data``.
        if conv.bias is not None:
            conv.bias.data.zero_()
        # Attribute must stay ``conv`` so pickled checkpoints still match.
        self.conv = equal_lr(conv)

    def forward(self, input):
        """Apply the wrapped, equalized convolution."""
        return self.conv(input)

In [None]:
class ConstantInput(nn.Module):
    """Learned constant feature map, tiled across the batch.

    Args:
        channel: number of channels of the constant.
        size: spatial height and width of the constant (default 4).
    """

    def __init__(self, channel, size=4):
        super().__init__()
        # Stored as ``input`` so pickled checkpoints keep loading.
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, x):
        """Return the constant repeated ``x.shape[0]`` times.

        Only the leading (batch) dimension of ``x`` is read.
        """
        return self.input.repeat(x.shape[0], 1, 1, 1)

In [None]:
class NoiseInjection(nn.Module):
    """Add ``noise`` scaled by a learned per-channel weight (zero at init)."""

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise`` (weight broadcasts over B, H, W)."""
        scaled_noise = self.weight * noise
        return image + scaled_noise

In [None]:
class AdaptiveInstanceNorm(nn.Module):
    """AdaIN: instance-normalize, then apply a style-dependent affine map.

    A linear layer maps the style vector to ``2 * in_channel`` values that are
    split into a per-channel scale (gamma) and shift (beta). The bias is
    initialised so that gamma starts at 1 and beta at 0.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)
        self.style = nn.Linear(style_dim, in_channel * 2)
        bias = self.style.bias.data
        bias[:in_channel] = 1  # gamma half
        bias[in_channel:] = 0  # beta half

    def forward(self, input, style):
        # Append two singleton dims so gamma/beta broadcast over H and W
        # (assumes style is (batch, style_dim) -- TODO confirm for other callers).
        params = self.style(style)[:, :, None, None]
        gamma, beta = params.chunk(2, dim=1)
        normalized = self.norm(input)
        return gamma * normalized + beta

In [None]:
class BlurFunctionBackward(Function):
    """Gradient of BlurFunction, written as its own autograd Function so that
    its backward() defines the second-order (double backward) gradient."""

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        # Keep both kernels: backward() below needs the unflipped one.
        ctx.save_for_backward(kernel, kernel_flip)

        # Input-gradient of the depthwise blur: convolve the incoming gradient
        # with the flipped kernel, one group per channel (valid for the 3x3,
        # stride-1 kernels built by Blur).
        grad_input = F.conv2d(
            grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
        )

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, kernel_flip = ctx.saved_tensors

        # Differentiating the flipped-kernel convolution brings back the
        # original kernel.
        grad_input = F.conv2d(
            gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
        )

        # No gradient is propagated to the kernels.
        return grad_input, None, None


class BlurFunction(Function):
    """Depthwise blur whose backward pass is itself differentiable."""

    @staticmethod
    def forward(ctx, input, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)

        # groups == channel count: each channel is convolved with its own
        # kernel copy; padding=1 keeps H and W unchanged for 3x3 kernels.
        output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])

        return output

    @staticmethod
    def backward(ctx, grad_output):
        kernel, kernel_flip = ctx.saved_tensors

        # Delegate to BlurFunctionBackward, whose own backward() defines the
        # second-order gradient.
        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)

        # No gradient is propagated to the kernels.
        return grad_input, None, None


# Functional entry point used by the Blur module.
blur = BlurFunction.apply


class Blur(nn.Module):
    """Depthwise 3x3 binomial blur ([1, 2, 1] x [1, 2, 1], normalised to sum 1).

    Args:
        channel: number of input channels; the kernel is replicated once per
            channel so the convolution runs depthwise.
    """

    def __init__(self, channel):
        super().__init__()
        taps = torch.tensor([1, 2, 1], dtype=torch.float32)
        # Outer product via broadcasting -> [[1, 2, 1], [2, 4, 2], [1, 2, 1]].
        kernel = (taps[:, None] * taps[None, :]).view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        flipped = torch.flip(kernel, dims=(2, 3))
        # Buffer names are part of saved checkpoints; keep them unchanged.
        self.register_buffer('weight', kernel.repeat(channel, 1, 1, 1))
        self.register_buffer('weight_flip', flipped.repeat(channel, 1, 1, 1))

    def forward(self, input):
        """Blur ``input`` channel-wise with the stored kernel."""
        return blur(input, self.weight, self.weight_flip)

In [None]:
class StyledConvBlock(nn.Module):
    """Two (conv -> noise -> LeakyReLU -> AdaIN) stages sharing one style.

    Args:
        in_channel: input channels (with ``initial=True``, the channel count
            of the learned constant; since the noise/AdaIN layers use
            ``out_channel`` this presumably needs in_channel == out_channel).
        out_channel: output channels of both stages.
        kernel_size, padding: forwarded to the ``EqualConv2d`` layers.
        style_dim: length of the style vector consumed by AdaIN.
        initial: use a learned ``ConstantInput`` instead of the first conv.
        upsample: prefix the first conv with 2x nearest upsampling and follow
            it with ``Blur`` (ignored when ``initial`` is True).
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        padding=1,
        style_dim=512,
        initial=False,
        upsample=False
    ):
        super().__init__()

        # Attribute names (conv1, noise1, ...) must stay as they are: saved
        # checkpoints pickle the module tree under these names.
        if initial:
            first = ConstantInput(in_channel)
        elif upsample:
            first = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                EqualConv2d(in_channel, out_channel, kernel_size, padding=padding),
                Blur(out_channel),
            )
        else:
            first = EqualConv2d(in_channel, out_channel, kernel_size, padding=padding)
        self.conv1 = first

        self.noise1 = NoiseInjection(out_channel)
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
        self.noise2 = NoiseInjection(out_channel)
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)

    def forward(self, x, style, noise):
        """Run both stages; the same ``noise`` tensor is injected in each."""
        stages = (
            (self.conv1, self.noise1, self.lrelu1, self.adain1),
            (self.conv2, self.noise2, self.lrelu2, self.adain2),
        )
        out = x
        for conv, inject, activation, adain in stages:
            out = adain(activation(inject(conv(out), noise)), style)
        return out

In [None]:
class Generator(nn.Module):
    """StyleGAN-style generator: mapping MLP followed by synthesis blocks.

    Args:
        z_dim: latent size and width of every mapping layer. The synthesis
            blocks use the default 512-d style size, so any other value
            would mismatch the AdaIN input size.
        n_linear: number of (EqualLinear, LeakyReLU) pairs in the mapping MLP.
    """

    def __init__(self, z_dim=512, n_linear=5):
        super().__init__()
        layers = []
        for _ in range(n_linear):
            layers.append(EqualLinear(z_dim, z_dim))
            layers.append(nn.LeakyReLU(0.2))
        self.style = nn.Sequential(*layers)
        self.progression = nn.ModuleList(
            [
              StyledConvBlock(512, 512, 3, 1, initial=True),
              StyledConvBlock(512, 512, 3, 1, upsample=True),
              StyledConvBlock(512, 256, 3, 1, upsample=True),
              StyledConvBlock(256, 128, 3, 1, upsample=True),
              StyledConvBlock(128, 64, 3, 1, upsample=True),
            ]
        )
        self.to_rgb = EqualConv2d(64, 3, 1)

    def forward(self, x, noise=None, step=0):
        """Map a batch of latents to RGB images.

        Args:
            x: latent batch; rescaled to unit RMS along dim 1 before mapping.
            noise: optional list holding one noise tensor per progression
                block (block i runs at 4 * 2**i pixels). Drawn when None.
            step: accepted for backward compatibility only. Every progression
                block always runs (``to_rgb`` expects the last block's 64
                channels), so it no longer limits how much noise is drawn.

        Returns:
            Output of ``to_rgb``: (batch, 3, 64, 64) with this block layout.
        """
        batch = x.size(0)
        if noise is None:
            # Draw one map for every block (not just step + 1): the loop
            # below indexes noise[i] for all blocks, so fewer maps raised
            # IndexError for any step below len(self.progression) - 1.
            noise = [
                torch.randn(batch, 1, 4 * 2 ** i, 4 * 2 ** i, device=x.device)
                for i in range(len(self.progression))
            ]
        # Pixel norm: rescale each latent to unit RMS along dim 1.
        x = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
        styles = self.style(x)
        # The initial block's ConstantInput only reads the batch size of this.
        out = noise[0]
        for i, block in enumerate(self.progression):
            out = block(out, styles, noise[i])
        return self.to_rgb(out)

In [None]:
# Pinned to CPU. To use a GPU, set this to "cuda" and make sure the latents
# passed to G are created on the same device.
device = torch.device("cpu")
# NOTE(review): torch.load unpickles the whole Generator object, which can run
# arbitrary code -- only load trusted checkpoints. Unpickling also requires the
# model classes above to be defined under the same names in this session.
# Newer PyTorch releases default to weights_only=True and may refuse a full
# pickled module unless weights_only=False is passed -- confirm for the
# installed version.
G = torch.load('Generator_v2_150.pth', map_location=device)

In [None]:
def generate_and_save_random_image(G, index, img_size=64,
                                   out_dir='generated_images', z_dim=512):
    """Sample one latent, run ``G`` on it and save the image as a PNG.

    Args:
        G: generator, called as ``G(z, step=step)``.
        index: integer embedded in the output file name.
        img_size: target resolution; ``step = log2(img_size) - 2``.
        out_dir: directory for the PNG (created if missing).
        z_dim: latent size.

    Returns:
        ``(imgpath, z)``: path of the written PNG and the latent used
        (on G's device when G exposes parameters).
    """
    import os

    # log2 is exact for powers of two; math.log(x, 2) can land just below
    # the integer and be truncated to the wrong step.
    step = int(math.log2(img_size)) - 2

    # Draw the latent on the generator's device so CPU and GPU models work.
    try:
        device = next(G.parameters()).device
    except (AttributeError, StopIteration):
        device = None  # fall back to torch's default device
    z = torch.randn((1, z_dim), device=device)

    with torch.no_grad():
        img = G(z, step=step)[0]

    os.makedirs(out_dir, exist_ok=True)
    imgpath = f'{out_dir}/random_image_{index}.png'
    # Clamp to [0, 1] and reorder CHW -> HWC for plt.imsave.
    # NOTE(review): the dataset cells normalise real images to [-1, 1]; if G
    # outputs that range, clipping here discards the negative half -- confirm.
    imgdata = torch.clip(img, 0, 1).permute([1, 2, 0]).detach().cpu().numpy()
    plt.imsave(imgpath, imgdata)
    return imgpath, z

In [None]:
import os

# Write 10000 samples to generated_images/; the array-building cell below
# reads 10000 generated images, so at least that many must exist.
os.makedirs('generated_images', exist_ok=True)  # plt.imsave needs the directory
for i in range(10000):
    generate_and_save_random_image(G, i)

## Dataset Construction

In [None]:
class CustomDataset(Dataset):
    """Dataset over every file directly inside ``root_dir``, loaded as RGB.

    Args:
        root_dir: directory whose (non-recursive) files are the samples.
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, root_dir, transform=None):
        import os

        super().__init__()
        # Sorted for a deterministic index -> file mapping (glob order is
        # filesystem-dependent); sub-directories are skipped because
        # Image.open cannot read them.
        self.files = sorted(
            path for path in glob.glob(os.path.join(root_dir, "*"))
            if os.path.isfile(path)
        )
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file once the RGB copy is made
        # (convert() loads the pixel data before returning).
        with Image.open(self.files[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
# Real (training) images: every file directly under this folder is a sample.
dataroot = "animefacedataset/images"

img_size = 64
# One image per batch: the array-building cell pulls images one at a time.
batch_size = 1
# Resize to 64x64, ToTensor -> float tensor in [0, 1], then
# Normalize(mean=0.5, std=0.5) maps each channel to [-1, 1].
dataset = CustomDataset(root_dir=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize((img_size, img_size)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader (shuffled, so the 10000 real images read later are a
# random sample when the folder holds more than that).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Generated images written by generate_and_save_random_image.
gen_dataroot = "generated_images"

img_size = 64
batch_size = 1
# Same preprocessing as the real images so both sets end up in [-1, 1].
gen_dataset = CustomDataset(root_dir=gen_dataroot,
                           transform=transforms.Compose([
                               transforms.Resize((img_size, img_size)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
gen_dataloader = torch.utils.data.DataLoader(gen_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Stack 10000 real and 10000 generated images as (N, 3, 64, 64) arrays.
# float32 holds the ToTensor/Normalize output exactly and halves memory
# versus the float64 default (~0.5 GB instead of ~1 GB per array).
init_images_array = np.zeros((10000, 3, 64, 64), dtype=np.float32)
gen_images_array = np.zeros((10000, 3, 64, 64), dtype=np.float32)

iter_dataloader = iter(dataloader)
iter_gen_dataloader = iter(gen_dataloader)

for i in range(10000):
    if i % 10 == 0:
        print(i)  # progress
    # [0] drops the batch dimension (batch_size is 1).
    init_images_array[i] = next(iter_dataloader)[0].numpy()
    gen_images_array[i] = next(iter_gen_dataloader)[0].numpy()

In [None]:
init_images_array.shape  # sanity check: expect (10000, 3, 64, 64)

(10000, 3, 64, 64)

In [None]:
print(init_images_array[9999].shape)  # last stored image: expect (3, 64, 64)

(3, 64, 64)


In [None]:
# np.save('init_array', init_images_array)
# np.save('gen_array', gen_images_array)

## FID Calculation

In [None]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Args:
        mu1, mu2: mean vectors (scalars are promoted to length-1 vectors).
        sigma1, sigma2: covariance matrices (promoted to at least 2-D).
        eps: diagonal offset added to both covariances when the square root
            of their product is not finite.

    Returns:
        The squared distance d^2 (no square root is taken).

    Raises:
        AssertionError: mean or covariance shapes disagree.
        ValueError: the matrix square root has a non-negligible imaginary part.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, \
        'Training and test mean vectors have different lengths'
    assert cov_a.shape == cov_b.shape, \
        'Training and test covariances have different dimensions'

    delta = mean_a - mean_b

    # sqrt(C_1 C_2); with disp=False sqrtm returns (root, error estimate).
    root, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(root).all():
        # Near-singular product: regularise both covariances and retry.
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(cov_a.shape[0]) * eps
        root = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off can leave a tiny imaginary part; drop it only if negligible.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            worst = np.max(np.abs(root.imag))
            raise ValueError('Imaginary component {}'.format(worst))
        root = root.real

    return (delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b)
            - 2 * np.trace(root))

In [None]:
# Reload the arrays written by the (commented-out) np.save cell above; this
# replaces whatever the image-loading loop built in this session.
init_images_array = np.load('init_array.npy')
gen_images_array = np.load('gen_array.npy')

In [None]:
init_images_array.shape  # sanity check: expect (10000, 3, 64, 64)

(10000, 3, 64, 64)

In [None]:
# Pixel-space "FID": for each colour channel and each 32x32 quadrant, fit a
# Gaussian to the flattened 1024-pixel patches of the real and generated
# images and take the Frechet distance between the two fits. No Inception
# features are involved, so these values are not comparable to standard FID.
fids = []  # 12 scores, channel-major: 4 consecutive quadrants per colour
colors = ['RED', 'GREEN', 'BLUE']
# Each quadrant is [[row_start, row_end], [col_start, col_end]] over (H, W).
sizes = [[[0, 32], [0, 32]], [[32, 64], [0, 32]], [[0, 32], [32, 64]], [[32, 64], [32, 64]]]

for i, color in enumerate(colors):
    print('channel: ', color)
    for current_quarter in sizes:
        (x1, x2), (y1, y2) = current_quarter
        print('quarter: ', current_quarter)

        # Slicing + reshape already yields independent arrays and nothing
        # below mutates them, so no deep copy is needed.
        init_part = init_images_array[:10000, i:i + 1, x1:x2, y1:y2]
        gen_part = gen_images_array[:10000, i:i + 1, x1:x2, y1:y2]
        init_2D = init_part.reshape(init_part.shape[0], -1)
        gen_2D = gen_part.reshape(gen_part.shape[0], -1)

        # Accumulate means in float64 whatever the storage dtype; np.cov
        # already computes in float64.
        init_mu = np.mean(init_2D, axis=0, dtype=np.float64)
        init_sigma = np.cov(init_2D, rowvar=False)

        gen_mu = np.mean(gen_2D, axis=0, dtype=np.float64)
        gen_sigma = np.cov(gen_2D, rowvar=False)

        current_fid = calculate_frechet_distance(init_mu, init_sigma, gen_mu, gen_sigma)
        fids.append(current_fid)
        print('FID: ', current_fid)

channel:  RED
quarter:  [[0, 32], [0, 32]]
FID:  209.2895935186839
quarter:  [[32, 64], [0, 32]]
FID:  203.3839375458764
quarter:  [[0, 32], [32, 64]]
FID:  291.6656503531822
quarter:  [[32, 64], [32, 64]]
FID:  210.04824968332048
channel:  GREEN
quarter:  [[0, 32], [0, 32]]
FID:  344.65171807566
quarter:  [[32, 64], [0, 32]]
FID:  269.7404155800184
quarter:  [[0, 32], [32, 64]]
FID:  432.3298171981106
quarter:  [[32, 64], [32, 64]]
FID:  298.1838433448992
channel:  BLUE
quarter:  [[0, 32], [0, 32]]
FID:  420.66032142787185
quarter:  [[32, 64], [0, 32]]
FID:  295.78259560661604
quarter:  [[0, 32], [32, 64]]
FID:  515.0935615107355
quarter:  [[32, 64], [32, 64]]
FID:  334.66293931552804


In [None]:
# Average the quadrant scores per channel, then average the channel means.
# fids is channel-major (len(sizes) consecutive entries per colour), so the
# colour is indexed directly instead of being derived from the slice offset.
quarters_per_channel = len(sizes)
overall_mean = 0
for c, color in enumerate(colors):
    print('channel: ', color)
    start = c * quarters_per_channel
    current_mean = np.mean(fids[start:start + quarters_per_channel])
    overall_mean += current_mean
    print('mean: ', current_mean)

overall_mean /= len(colors)
print('OVERALL FID: ', overall_mean)

channel:  RED
mean:  228.59685777526573
channel:  GREEN
mean:  336.2264485496721
channel:  BLUE
mean:  391.5498544651879
OVERALL FID:  318.7910535967086
