# Baseline FID and Inception Score (IS) using VQ-Diffusion

In [None]:
from inference_VQ_Diffusion import VQ_Diffusion
# Perform Imports
import sys
sys.path.append(".")
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import torch
import torch.nn.functional as F
import torcheval.metrics as metrics
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

from collections import OrderedDict

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.regression import *
from ignite.utils import *

In [None]:
# Preliminary configuration used when loading the dataset.
batch_size = 256
d_proj = 512
d_hidden = 32
config_path = "logs/vqgan_gumbel_f8/configs/model.yaml"
chkpt_path = "logs/vqgan_gumbel_f8/checkpoints/last.ckpt"

# Per-image preprocessing, applied in this order:
#   1. PIL image -> tensor
#   2. resize to d_hidden x d_hidden
#   3. convert the dtype to torch.float
transform = T.Compose(
    [
        T.PILToTensor(),
        T.Resize((d_hidden, d_hidden)),
        T.ConvertImageDtype(torch.float),
    ]
)

# Define a custom collate_fn
def custom_collate_fn(batch):
    """Collate ``(image, captions)`` samples into one padded image batch.

    Every image is zero-padded on its right and bottom edges until it matches
    the largest height (dim 1) and width (dim 2) found in the batch, then the
    padded images are stacked along a new leading dimension.

    Args:
        batch: Sequence of ``(image, captions)`` pairs, where ``image`` is a
            tensor indexed as (channel, height, width).

    Returns:
        A ``(images, captions)`` pair: the stacked image tensor and a tuple
        holding each sample's captions in batch order.
    """
    images, captions = zip(*batch)

    # Target spatial size is the largest height / width seen in this batch.
    target_h = max(t.shape[1] for t in images)
    target_w = max(t.shape[2] for t in images)

    def _pad_to_target(t):
        # F.pad takes (left, right, top, bottom) for the last two dims.
        extra_w = target_w - t.shape[2]
        extra_h = target_h - t.shape[1]
        return F.pad(t, (0, extra_w, 0, extra_h))

    stacked = torch.stack([_pad_to_target(t) for t in images])
    return stacked, captions

# Define helper functions
def normalize_output(x):
    """Clamp ``x`` to [-1, 1] and linearly rescale the result to [0, 1].

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the same shape as ``x`` and values in [0, 1].
    """
    clipped = torch.clamp(x, min=-1., max=1.)
    return (clipped + 1.) / 2.

# COCO captions dataset; every image goes through `transform` when it is loaded.
dataset = datasets.CocoCaptions(
    root='TEST_ROOT_PATH',
    annFile='TEST_ANN_FILE',
    transform=transform,
)

# Shuffled batches of `batch_size` samples, assembled by custom_collate_fn.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Load the model - coco pretrained
VQ_Diffusion_model = VQ_Diffusion(
    config='OUTPUT/pretrained_model/config_text.yaml',
    path='OUTPUT/pretrained_model/coco_learnable.pth',
)

In [None]:
truncation_rate = 0.86
# Sampling-mode string handed to generate_content; it does not change between
# batches, so build it once instead of on every iteration.
sample_type = f"top{truncation_rate}r"

# Metric accumulators: FID from torcheval, Inception Score from ignite.
FID = metrics.FrechetInceptionDistance()
IS = InceptionScore()

# For each batch: generate images conditioned on the captions, then feed the
# dataset images and the generated images into the metric accumulators.
for img, txt in tqdm(dataloader):
    # Loop-local name so the module-level `batch_size` setting is not
    # overwritten by the size of the last batch.
    n_images = img.shape[0]

    # NOTE(review): `txt` is the tuple of captions for the whole batch, and the
    # request is additionally replicated `n_images` times. Confirm that
    # generate_content expects this, and that 'label' is the condition key the
    # loaded config reads — TODO confirm against the model code.
    data_i = {'label': [txt], 'image': None}

    with torch.no_grad():
        model_out = VQ_Diffusion_model.model.generate_content(
            batch=data_i,
            filter_ratio=0,
            replicate=n_images,
            content_ratio=1,
            return_att_weight=False,
            sample_type=sample_type,
        )  # B x C x H x W

    content = model_out['content']

    # NOTE(review): torcheval's FrechetInceptionDistance documents float
    # images with values in [0, 1]. Confirm that `content` has that range and
    # dtype and lives on the same device as the metric — TODO confirm.
    FID.update(images=img, is_real=True)
    FID.update(images=content, is_real=False)

    IS.update(output=content)

# calculate the FID score
fid_score = FID.compute()
print("FID Score: ", fid_score)

# calculate the Inception Score
is_score = IS.compute()
