In [1]:
import os
import torch
import torchvision
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
import tqdm

inception = InceptionScore()
fid = FrechetInceptionDistance()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
fid.cuda()
inception.cuda()

transform = Compose([
    CenterCrop((299, 299)),  # Adjust target size as needed
    ToTensor(),
    lambda x: x*255
])

dataset = ImageFolder(root="coco", transform=transform)

batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size)

for images, labels in tqdm.tqdm(data_loader):
    images = images.type(torch.uint8).cuda()
    fid.update(images, real=True)
    inception.update(images)

100%|█████████████████████████████████████████| 157/157 [01:20<00:00,  1.95it/s]


In [3]:
imgs = []
im_folder = 'controlnet_images'
for filename in os.listdir(im_folder):
    if filename == '.ipynb_checkpoints':
        continue
    f = os.path.join(im_folder, filename)
    img = torchvision.io.read_image(f)[:3]
    #plt.imshow(img.permute(1, 2, 0))
    #plt.show()
    imgs.append(img)
imgs = torch.stack(imgs)

In [4]:
print(imgs.shape)

torch.Size([30, 3, 512, 512])


In [5]:
fid.update(imgs.cuda(), real=False)
FID_mean = fid.compute()
print(FID_mean)

tensor(279.2801, device='cuda:0')


In [6]:
inception.update(imgs.cuda())
IS_mean, IS_std = inception.compute()
print(IS_mean)

tensor(32.0988, device='cuda:0')
