# Evaluate text to image generation

Build a test captions list.

In [7]:
from pycocotools.coco import COCO
import torch
import clip
from model import CVAE

# Load the test dataset
test_annotations_file = "autodl-tmp/annotations/captions_val2017.json"
test_coco = COCO(test_annotations_file)

# Hyperparameters
latent_dim = 128
condition_dim = 512
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
subset_fraction = 0.5  # Fraction of the dataset to evaluate

# Extract test captions
test_captions = []
all_image_ids = list(test_coco.imgs.keys())
subset_size = int(len(all_image_ids) * subset_fraction)
image_ids = all_image_ids[:subset_size]

for img_id in image_ids:
    ann_ids = test_coco.getAnnIds(imgIds=img_id)
    annotations = test_coco.loadAnns(ann_ids)
    caption = annotations[0]['caption']
    test_captions.append(caption)

print(f"Number of captions: {len(test_captions)}")
print(f"Example caption: {test_captions[0]}")

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Number of captions: 2500
Example caption: A man is in a kitchen making pizzas.


Load trained model

In [54]:
from model import CVAE
import torch

latent_dim = 128
condition_dim = 512  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CVAE(condition_dim=condition_dim, latent_dim=latent_dim).to(device)
model.load_state_dict(torch.load("model_250.pth"))
model.eval()  
print("Model loaded successfully!")

Model loaded successfully!


In [55]:
# Load CLIP model
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)

print("CLIP model loaded successfully!")

# CLIP Similarity Evaluation Function
def clip_similarity(model, captions, device, num_samples=10):
    """
    Evaluates the CLIP similarity between generated images and input captions.

    Args:
        model: CVAE model.
        captions: List of text captions for evaluation.
        device: Device to perform computations.
        num_samples: Number of samples to evaluate (default: 10).

    Returns:
        Average CLIP similarity score.
    """
    total_similarity = 0.0
    num_evaluated = min(len(captions), num_samples)  # Limit the number of samples

    for i in range(num_evaluated):
        # Get the text caption
        caption = captions[i]
        text_token = clip.tokenize([caption]).to(device)

        with torch.no_grad():
            # Encode text into CLIP feature space
            text_feature = model_clip.encode_text(text_token)

            # Generate latent vector and decode an image
            z = torch.randn(1, latent_dim).to(device)  # Random latent vector
            generated_image = model.decoder(z, text_feature).view(1, 3, 224, 224).to(device)

            # Encode the generated image into CLIP feature space
            image_feature = model_clip.encode_image(generated_image)

            # Compute cosine similarity
            similarity = torch.nn.functional.cosine_similarity(text_feature, image_feature, dim=-1).item()
            total_similarity += similarity

    # Return the average similarity score
    return total_similarity / num_evaluated

# Evaluate the model
average_similarity = clip_similarity(model, test_captions, device, num_samples=2000)
print(f"Average CLIP Similarity: {average_similarity:.4f}")

CLIP model loaded successfully!
Average CLIP Similarity: 0.2002


-----


In [61]:
# from pytorch_fid import calculate_fid_given_paths
import os
from torchvision.utils import save_image
from tqdm import tqdm
from pycocotools.coco import COCO
import torch
from model import CVAE
import clip
from PIL import Image

# Directories for real and generated images
real_images_dir = "real_images"
generated_images_dir = "generated_images"
os.makedirs(real_images_dir, exist_ok=True)
os.makedirs(generated_images_dir, exist_ok=True)

# Dataset and Model Setup
test_annotations_file = "autodl-tmp/annotations/captions_val2017.json"
test_coco = COCO(test_annotations_file)

latent_dim = 128
condition_dim = 512
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
subset_fraction = 0.1  # Smaller subset for FID calculation

# Load CVAE Model
model = CVAE(condition_dim=condition_dim, latent_dim=latent_dim).to(device)
model.load_state_dict(torch.load("model_250.pth"))
model.eval()

# Load CLIP Model
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)

# Transform for Real Images
from torchvision.transforms import Compose, Resize, ToTensor

transform = Compose([
    Resize((224, 224)),
    ToTensor(),
])

# Extract a Subset of Real Images and Save Them
image_ids = list(test_coco.imgs.keys())
subset_size = int(len(image_ids) * subset_fraction)
image_ids = image_ids[:subset_size]

print("Saving real images...")
for i, img_id in enumerate(tqdm(image_ids)):
    image_info = test_coco.loadImgs(img_id)[0]
    img_path = f"autodl-tmp/val2017/{image_info['file_name']}"  # Update to your dataset path
    img = Image.open(img_path).convert("RGB")
    img = transform(img)
    save_image(img, os.path.join(real_images_dir, f"{i}.png"))

# Generate Synthetic Images and Save Them
test_captions = []
for img_id in image_ids:
    ann_ids = test_coco.getAnnIds(imgIds=img_id)
    annotations = test_coco.loadAnns(ann_ids)
    caption = annotations[0]['caption']
    test_captions.append(caption)

print("Generating and saving synthetic images...")
for i, caption in enumerate(tqdm(test_captions)):
    with torch.no_grad():
        text_token = clip.tokenize([caption]).to(device)
        text_feature = model_clip.encode_text(text_token)
        z = torch.randn(1, latent_dim).to(device)  # Sample a latent vector
        generated_image = model.decoder(z, text_feature).squeeze(0).cpu()
        save_image(generated_image, os.path.join(generated_images_dir, f"{i}.png"))

# # Calculate FID
# print("Calculating FID...")
# fid_score = calculate_fid_given_paths(
#     [real_images_dir, generated_images_dir],
#     batch_size=50,
#     device=device,
#     dims=2048,  # Default dimension used in InceptionV3
# )
# print(f"FID Score: {fid_score}")

loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
Saving real images...


100%|██████████| 500/500 [00:09<00:00, 50.48it/s]


Generating and saving synthetic images...


100%|██████████| 500/500 [00:12<00:00, 38.87it/s]


In [25]:
import pytorch_fid
print(dir(pytorch_fid))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__']
