In [6]:
!pip install torcheval
!pip install pytorch-ignite
!pip install -r requirements.txt 

Collecting torcheval
  Downloading torcheval-0.0.7-py3-none-any.whl.metadata (8.6 kB)
Downloading torcheval-0.0.7-py3-none-any.whl (179 kB)
Installing collected packages: torcheval
Successfully installed torcheval-0.0.7
Collecting pytorch-ignite
  Downloading pytorch_ignite-0.5.1-py3-none-any.whl.metadata (27 kB)
Downloading pytorch_ignite-0.5.1-py3-none-any.whl (312 kB)
Installing collected packages: pytorch-ignite
Successfully installed pytorch-ignite-0.5.1
Collecting pytorch_lightning (from -r requirements.txt (line 3))
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting omegaconf (from -r requirements.txt (line 4))
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting pycocotools (from -r requirements.txt (line 6))
  Downloading pycocotools-2.0.8-cp312-cp312-macosx_10_9_universal2.whl.metadata (1.1 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning->-r requirements.txt (line 3))
  Downloading torchmetrics-1.6.0-py3-none-a

In [8]:
# Perform Imports
import sys
sys.path.append(".")
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torcheval.metrics as metrics
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.transforms.functional as TF
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from t2i_model import T2IEncoderInput


# Setup

In [None]:
# Notebook-wide settings: batch size, model dimensions, and the config /
# checkpoint files handed to the model constructor further down
batch_size, d_proj, d_hidden = 256, 512, 32
config_path = "logs/vqgan_gumbel_f8/configs/model.yaml"
chkpt_path = "logs/vqgan_gumbel_f8/checkpoints/last.ckpt"

# Image preprocessing, applied in order:
#   1. PIL image -> tensor
#   2. resize to d_hidden x d_hidden
#   3. cast to floating point
transform = T.Compose([
    T.PILToTensor(),
    T.Resize(size=(d_hidden, d_hidden)),
    T.ConvertImageDtype(dtype=torch.float),
])

# Define a custom collate_fn
def custom_collate_fn(batch):
    """Collate (image, captions) samples into one padded image batch.

    Each image is indexed with dim 1 as height and dim 2 as width. Images
    smaller than the largest height/width in the batch are padded on the
    bottom and right so that all of them can be stacked.

    Args:
        batch: sequence of ``(image_tensor, captions)`` pairs.

    Returns:
        A ``(stacked_images, captions)`` pair, where ``stacked_images`` is the
        padded images stacked along a new leading dimension and ``captions``
        is a tuple holding each sample's captions in batch order.
    """
    images, captions = zip(*batch)

    # Largest spatial extent present in this batch
    target_h = max(im.shape[1] for im in images)
    target_w = max(im.shape[2] for im in images)

    padded_images = []
    for im in images:
        extra_w = target_w - im.shape[2]
        extra_h = target_h - im.shape[1]
        # F.pad takes (left, right, top, bottom) for the last two dimensions
        padded_images.append(F.pad(im, (0, extra_w, 0, extra_h)))

    return torch.stack(padded_images), captions

# Define helper functions
def normalize_output(x):
    """Clamp ``x`` to [-1, 1] and map that range linearly onto [0, 1].

    Args:
        x: input tensor.

    Returns:
        A tensor of the same shape with every value in [0, 1].
    """
    clipped = torch.clamp(x, min=-1., max=1.)
    return (clipped + 1.) / 2.

# Load the COCO captions dataset; every image is passed through `transform`
dataset = datasets.CocoCaptions(root = 'TEST_ROOT_PATH',
                                annFile = 'TEST_ANN_FILE',
                                transform=transform)
# Batch the dataset; custom_collate_fn pads and stacks the images and keeps
# the captions as a tuple
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn = custom_collate_fn)

# Load the model checkpoint
# TODO: CHANGE THIS FOR THE DIFFERENT MODELS
# Placeholder, like the dataset paths above: point this at the saved state
# dict of the model being evaluated.
PATH = 'TEST_MODEL_PATH'
model = T2IEncoderInput(d_proj, d_hidden, config_path, chkpt_path, True)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a machine
# without that device; load_state_dict then copies the values into the model.
model.load_state_dict(torch.load(PATH, map_location="cpu", weights_only=True))
model.eval() # set it to evaluation mode

# Getting FID for model checkpoint

In [None]:
# Create the FID metric
# TODO: specify the device as part of the FID object
# TODO: might want to update the number of features in the FID object
FID = metrics.FrechetInceptionDistance()

# Loop through the dataloader, feeding real and generated images to the metric.
# This is inference only: no_grad keeps autograd from recording a graph for
# every generated batch.
with torch.no_grad():
    for (img, txt) in tqdm(dataloader):
        # Keep only the first caption of each image's caption entry
        txt = [t[0] for t in txt]

        # Pass the captions through our model, then clamp/rescale each
        # generated image to [0, 1] and stack them back into one batch
        txt_to_img = model(x = txt)
        txt_to_img = torch.stack([normalize_output(t) for t in txt_to_img], dim = 0)

        # Add original and generated images to the FID object
        FID.update(images=img, is_real=True)
        FID.update(images=txt_to_img, is_real=False)

# Calculate the FID score
fid_score = FID.compute()

print("FID Score: ", fid_score)

# Get Inception Score

In [9]:
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.regression import *
from ignite.utils import *

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [None]:
# Create the Inception Score metric
# NOTE(review): ignite's default Inception feature extractor appears to be
# documented for (N, 3, 299, 299) inputs; confirm the generator's output size
# or resize the images before calling update.
IS = InceptionScore()

# Inference only: no_grad keeps autograd from recording a graph for every
# generated batch.
with torch.no_grad():
    for (img, txt) in tqdm(dataloader):
        # Keep only the first caption of each image's caption entry
        txt = [t[0] for t in txt]

        # Pass the captions through our model, then clamp/rescale each
        # generated image to [0, 1] and stack them back into one batch
        txt_to_img = model(x = txt)
        txt_to_img = torch.stack([normalize_output(t) for t in txt_to_img], dim = 0)

        # Add the generated images to the Inception Score object
        IS.update(output=txt_to_img)

# Calculate the Inception Score
is_score = IS.compute()

print("Inception Score: ", is_score)