# Fréchet Inception Distance (FID)

In [None]:
import torch
import torch.utils.data._utils.collate as collate_mod
from torchvision.transforms.functional import resize, to_tensor, to_pil_image

def custom_collate_fn(batch):
    """Collate a batch of images into one tensor of 299x299 images.

    Tensor samples whose dimensions 1 and 2 are already 299 are passed
    through untouched. Any other tensor is converted to a PIL image first,
    and every remaining sample is resized to 299x299 and converted back to a
    tensor before stacking.

    Args:
        batch: Iterable of samples. Tensors are indexed as ``shape[1]`` /
            ``shape[2]`` for the size check (assumes a C x H x W layout --
            TODO confirm); non-tensor samples are presumably PIL images.

    Returns:
        The result of ``torch.stack`` over the processed samples.
    """
    height, width = 299, 299

    def _to_target(sample):
        # Tensors that already have the target spatial size skip the
        # PIL round-trip entirely.
        if isinstance(sample, torch.Tensor):
            if sample.shape[1] == height and sample.shape[2] == width:
                return sample
            sample = to_pil_image(sample)
        return to_tensor(resize(sample, (height, width)))

    return torch.stack([_to_target(sample) for sample in batch])

# Monkey-patch: replace torch's module-level default collate function with the
# resizing variant above. Presumably this is so DataLoaders created without an
# explicit collate_fn (e.g. inside pytorch_fid) can batch images of differing
# sizes -- verify against the installed torch version.
collate_mod.default_collate = custom_collate_fn

import torch.utils.data as data
# Keep a handle on the real DataLoader so the wrapper can delegate to it.
_original_DataLoader = data.DataLoader


def patched_DataLoader(*args, **kwargs):
    """Build a DataLoader with ``num_workers`` forced to 0.

    All positional and keyword arguments are forwarded unchanged to the
    original ``DataLoader``, except that any ``num_workers`` keyword is
    overridden with 0.
    """
    forced_kwargs = dict(kwargs, num_workers=0)
    return _original_DataLoader(*args, **forced_kwargs)


# Install the wrapper in place of torch.utils.data.DataLoader.
data.DataLoader = patched_DataLoader

from pytorch_fid import fid_score

# TODO: point these at your image folders before running this cell.
# (A bare `name =  # comment` is a SyntaxError, so explicit placeholder
# strings are used instead.)
real_images_path = "path/to/real_images"  # Real image folder
synthetic_images_path = "path/to/synthetic_images"  # Synthetic image folder

# Compute the FID between the two folders using pytorch_fid.
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_path, synthetic_images_path],
    batch_size=50,
    # Use the GPU when one is available, otherwise fall back to the CPU
    # instead of failing on CUDA-less machines.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dims=2048,
)

print("FID:", fid_value)


# Inception Score (IS)

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import inception_v3
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# TODO: point these at your image folders before running this cell.
# (A bare `NAME =  # comment` is a SyntaxError, so explicit placeholder
# strings are used instead.)
REAL_DIR = "path/to/real_images"  # Real image folder
FAKE_DIR = "path/to/synthetic_images"  # Synthetic image folder

BATCH = 64    # default images per forward pass
SPLITS = 10   # default number of chunks the predictions are split into
# Use the first GPU when available, otherwise run on the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

class ImgFolder(Dataset):
    """Dataset over the image files located directly inside one folder.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    picked up. Each item is opened as RGB, resized to 299x299, converted to a
    tensor and normalised with mean 0.5 / std 0.5 per channel.

    Raises:
        RuntimeError: If ``root`` contains no matching image file.
    """

    def __init__(self, root):
        self.paths = self._list_image_paths(root)
        if not self.paths:
            raise RuntimeError(f"No image files found in {root}")
        self.tf = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    @staticmethod
    def _list_image_paths(root, exts=(".png", ".jpg", ".jpeg")):
        """Return the sorted paths of files in ``root`` ending in ``exts``."""
        # os.listdir order is arbitrary; sorting keeps the dataset order (and
        # therefore any order-dependent split of the predictions) reproducible.
        return [os.path.join(root, name) for name in sorted(os.listdir(root))
                if name.lower().endswith(exts)]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Close the file handle promptly; convert() returns an independent,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(self.paths[idx]) as img:
            rgb = img.convert("RGB")
        return self.tf(rgb)

def inception_score(img_dir, batch=BATCH, splits=SPLITS, device=DEVICE):
    """Compute the Inception Score of the images in ``img_dir``.

    The images are run through a pretrained Inception v3 network, the
    softmax predictions are divided into ``splits`` chunks, and for each
    chunk ``exp(mean_x KL(p(y|x) || p(y)))`` is computed. The chunk scores
    are then averaged.

    Args:
        img_dir: Folder of images, loaded through ``ImgFolder``.
        batch: Batch size used for the forward passes.
        splits: Number of chunks to split the predictions into. Clamped to
            the number of images so that no chunk is empty.
        device: Device the network and the batches are moved to.

    Returns:
        Tuple ``(score, n_images)``: the mean chunk score as a float and the
        number of images that were scored.

    Raises:
        RuntimeError: Propagated from ``ImgFolder`` when no images are found.
    """
    # Pinned host memory is only useful for host-to-GPU copies.
    use_cuda = torch.device(device).type == "cuda"
    # NOTE(review): if torch.utils.data.DataLoader was monkey-patched earlier
    # in the same session, num_workers may be overridden -- confirm.
    loader = DataLoader(ImgFolder(img_dir), batch_size=batch,
                        shuffle=False, num_workers=4, pin_memory=use_cuda)
    # NOTE(review): newer torchvision releases prefer `weights=` over the
    # legacy `pretrained=` flag; kept as-is for backward compatibility.
    net = inception_v3(pretrained=True, transform_input=False).to(device).eval()

    preds = []
    with torch.no_grad():
        for x in loader:
            x = x.to(device, non_blocking=True)
            preds.append(F.softmax(net(x), dim=1).cpu())
    preds = torch.cat(preds, dim=0).numpy()

    # With fewer images than splits, np.array_split would yield empty chunks
    # whose mean is NaN, poisoning the final score.
    n_parts = min(splits, len(preds))

    scores = []
    for part in np.array_split(preds, n_parts):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution
        # Small epsilon keeps log() finite for zero probabilities.
        kl = part * (np.log(part + 1e-10) - np.log(p_y + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), len(preds)

# Score the synthetic images; inception_score returns (score, image count).
is_syn, n_syn = inception_score(FAKE_DIR)
print(f"IS (Synthetic) [{n_syn} imgs]: {is_syn:.4f}")

# Score the real images the same way.
is_real, n_real = inception_score(REAL_DIR)
print(f"IS (Real)      [{n_real} imgs]: {is_real:.4f}")
