In [1]:
import torch
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
import os
import torch

In [None]:
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Settings for the FID evaluation in this notebook.

    Every field carries a type annotation: ``@dataclass`` only turns *annotated*
    class attributes into fields, so without annotations none of these values
    could be overridden through the constructor.
    """

    image_size: int = 32  # target resolution used by the torchvision `preprocess` pipeline
    eval_batch_size: int = 5  # DataLoader batch size used when feeding images to the FID metric
    test_images_path: str = ""  # directory of real/reference images; must be set before computing FID
    generated_images_path: str = ""  # directory of generated images for the single-run FID cell; must be set
    # Training-time settings; the FID evaluation code does not read them.
    mixed_precision: str = "fp16"
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    gradient_accumulation_steps: int = 1

    seed: int = 24


config = TrainingConfig()

# Preprocess images

In [3]:
from torchvision import transforms

# torchvision pipeline for PIL images: resize to config.image_size, centre-crop a
# config.image_size square, then convert to a (C, H, W) float tensor
# (ToTensor scales 8-bit images to [0, 1] per the torchvision docs).
# NOTE(review): no cell shown here applies `preprocess`; read_images() converts
# images with preprocess_image() and keeps their native size.
preprocess = transforms.Compose(
    [
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
    ]
)

In [4]:
from PIL import Image
import torch

def preprocess_image(image):
    """Convert one channels-last image array into a batched, channels-first tensor.

    Args:
        image: array-like of shape (H, W, C), e.g. ``np.array`` of an RGB PIL image.

    Returns:
        Tensor of shape (1, C, H, W) with values divided by 255 (so uint8 input
        lands in [0, 1]).
    """
    channels_first = torch.tensor(image).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
    scaled = channels_first / 255.0
    return scaled[None]  # prepend a batch axis of size 1

def read_images(dir_path, transform=None):
    """Load every image file in ``dir_path`` into one batch tensor.

    Files are read in sorted filename order so the result is deterministic.
    Hidden entries (names starting with ".") and sub-directories are skipped, so
    stray items such as ``.DS_Store`` or ``.ipynb_checkpoints`` do not break loading.

    Args:
        dir_path: directory containing the images.
        transform: optional callable applied to each RGB ``PIL.Image``; it must
            return a (C, H, W) tensor (for example the torchvision ``preprocess``
            pipeline). When ``None`` (default), each image is converted with
            ``preprocess_image`` at its native size and scaled by 1/255.

    Returns:
        Tensor of shape (N, C, H, W). All images must end up the same size,
        otherwise ``torch.cat`` raises.

    Raises:
        ValueError: if ``dir_path`` contains no image files.
    """
    image_paths = sorted(
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if not name.startswith(".") and os.path.isfile(os.path.join(dir_path, name))
    )
    if not image_paths:
        raise ValueError(f"No image files found in {dir_path!r}")

    tensors = []
    for path in image_paths:
        # `with` releases the file handle deterministically (Pillow keeps it open
        # after load() for multi-frame formats); convert() returns an independent copy.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        if transform is None:
            tensors.append(preprocess_image(np.array(rgb)))
        else:
            tensors.append(transform(rgb).unsqueeze(0))
    # Images are converted one at a time, so all numpy copies are never held at once.
    return torch.cat(tensors)

# Compare images

In [5]:
from torchmetrics.image.fid import FrechetInceptionDistance
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from tqdm import tqdm

In [None]:
import warnings

# Compute FID for each saved checkpoint's generated samples against the real test images.
# Generated samples for epoch `i` are read from the directory "CIFAR10AT/{i}epoch".
# NOTE(review): an earlier comment said checkpoints are saved every 50 epochs, but this
# range steps by 100 (50, 150, ..., 1950) — confirm which epochs should be evaluated.
FIRST_EPOCH = 50
LAST_EPOCH = 1950
EPOCH_STEP = 100


def _fid_between(fid, generated_images, real_images, batch_size):
    """Reset ``fid`` and return the FID between ``generated_images`` and ``real_images``.

    Inputs are image tensors as returned by ``read_images`` (expected in [0, 1], which
    matches ``normalize=True``). Batches are consumed in lock-step via ``zip``, so only
    the first ``min(len(generated_loader), len(real_loader))`` batches of each set count.
    """
    # Reusing one metric avoids reloading the Inception network for every checkpoint;
    # with the default reset_real_features=True, reset() clears both real and fake stats.
    fid.reset()
    generated_loader = DataLoader(generated_images, batch_size=batch_size, shuffle=False)
    real_loader = DataLoader(real_images, batch_size=batch_size, shuffle=False)
    num_batches = min(len(generated_loader), len(real_loader))
    with tqdm(zip(generated_loader, real_loader), total=num_batches,
              desc="Calculating FID", unit="batch") as batches:
        for generated_batch, real_batch in batches:
            fid.update(real_batch, real=True)
            fid.update(generated_batch, real=False)
    return float(fid.compute())


# The real/test set is identical for every checkpoint, so load it once.
val_images = read_images(config.test_images_path)
fid = FrechetInceptionDistance(normalize=True)

fid_results = []
for i in range(FIRST_EPOCH, LAST_EPOCH + 1, EPOCH_STEP):
    generated_images = read_images(f"CIFAR10AT/{i}epoch")
    if len(generated_images) != len(val_images):
        warnings.warn(
            f"epoch {i}: {len(generated_images)} generated vs {len(val_images)} real images; "
            "batches beyond the shorter set are ignored by the FID computation"
        )
    fid_value = _fid_between(fid, generated_images, val_images, config.eval_batch_size)
    print(f"FID for i={i}: {fid_value}")
    fid_results.append({'i': i, 'FID': fid_value})

# Save results to a CSV file
results_df = pd.DataFrame(fid_results)
results_df.to_csv('result.csv', index=False)

In [None]:
import warnings

# Single-run FID: compare one directory of generated images with the real/test images.
# Requires `config.generated_images_path` to be defined and to point at an image directory.
generated_images = read_images(config.generated_images_path)
val_images = read_images(config.test_images_path)

# Batches are consumed in lock-step via zip(), so anything beyond the shorter set is dropped;
# make that visible instead of silently comparing unequal subsets.
if len(generated_images) != len(val_images):
    warnings.warn(
        f"{len(generated_images)} generated vs {len(val_images)} real images; "
        "batches beyond the shorter set are ignored by the FID computation"
    )

generated_loader = DataLoader(generated_images, batch_size=config.eval_batch_size, shuffle=False)
val_loader = DataLoader(val_images, batch_size=config.eval_batch_size, shuffle=False)

fid = FrechetInceptionDistance(normalize=True)

num_batches = min(len(generated_loader), len(val_loader))

with tqdm(zip(generated_loader, val_loader), total=num_batches,
          desc="Calculating FID", unit="batch") as batches:
    for generated_batch, real_batch in batches:
        fid.update(real_batch, real=True)
        fid.update(generated_batch, real=False)

fid_value = float(fid.compute())
print(f"FID: {fid_value}")