In [None]:
#Prompt(Chatgpt 4o): How to calculate FID and Clip score of trained emoji generator model

In [1]:
# Imports
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel
from PIL import Image
import os
import pandas as pd
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

In [2]:
# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Prompts for generating stickers
PROMPTS = [
    "A happy cat holding a tiny heart",
    "Cool panda wearing sunglasses and headphones",
    "Cute avocado giving a thumbs up",
    "Sleepy moon with a pillow and stars",
    "Dancing banana with sparkles around",
    "Angry chili pepper blowing fire",
    "Smiling robot with a peace sign",
    "Laughing unicorn with rainbow hair",
    "Cute ghost saying boo with sparkles",
    "Excited frog jumping with joy",
    "Shy mushroom hiding behind leaves",
    "Cute sushi roll with rosy cheeks",
    "Cheerful sun waving hello",
    "Sad teardrop emoji with big eyes",
    "Boba tea with happy face and straw",
    "High-fiving stars with happy faces",
    "Crying cupcake with melting frosting",
    "Cool dog skateboarding",
    "Excited duck throwing confetti",
    "Cute alien giving a thumbs up",
    "Happy watermelon slice with smile",
    "Surprised bunny with wide eyes",
    "Tired coffee cup with droopy eyes",
    "Excited bee flying with sparkles",
    "Winking taco with salsa hat",
    "Angry broccoli flexing muscles",
    "Cute penguin with scarf and mittens",
    "Peaceful cloud meditating",
    "Jumping cookie saying 'Yay!'",
    "Nervous donut biting its lip"
]

# Output directory
save_dir = "outputs/Sticker_manual_prompt_clip_only"
os.makedirs(save_dir, exist_ok=True)

# Load Stable Diffusion with LoRA
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
pipe.unet = PeftModel.from_pretrained(pipe.unet, "../evaluation/sticker_diffusion_qlora/final_model")
pipe.enable_attention_slicing()
pipe.safety_checker = lambda images, clip_input, **kwargs: (images, [False] * len(images))
print(" LoRA model loaded and ready.")

# Generate images
results = []
for i, prompt in enumerate(tqdm(PROMPTS, desc="Generating images")):
    result = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, height=256, width=256)
    image = result.images[0]
    image_path = os.path.join(save_dir, f"sticker_{i:02}.png")
    image.save(image_path)
    results.append({"prompt": prompt, "generated_image_path": image_path})

results_df = pd.DataFrame(results)

# Load CLIP model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Compute CLIP scores
clip_scores = []
for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Calculating CLIP scores"):
    image = Image.open(row["generated_image_path"]).convert("RGB")
    inputs = clip_processor(text=row["prompt"], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        # Run processor
        inputs = clip_processor(text=row["prompt"], images=image, return_tensors="pt", padding=True).to(device)

        # Extract pixel and text features separately
        image_inputs = {"pixel_values": inputs["pixel_values"]}
        text_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"]
        }

        # Get features
        image_embeds = clip_model.get_image_features(**image_inputs)
        text_embeds = clip_model.get_text_features(**text_inputs)

        # Normalize and compute cosine similarity
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        score = torch.nn.functional.cosine_similarity(image_embeds, text_embeds).item()
        clip_scores.append(score)

results_df["clip_score"] = clip_scores
results_df.to_parquet(os.path.join(save_dir, "clip_scores.parquet"), index=False)

# Print summary
print(f"\n All done! Results saved to: {save_dir}")
print(f" Average CLIP Score: {results_df['clip_score'].mean():.4f}")

Using device: cuda


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

 LoRA model loaded and ready.


Generating images:   0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:   3%|▎         | 1/30 [00:06<03:22,  6.98s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:   7%|▋         | 2/30 [00:12<02:59,  6.40s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  10%|█         | 3/30 [00:18<02:46,  6.17s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  13%|█▎        | 4/30 [00:24<02:35,  5.98s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  17%|█▋        | 5/30 [00:30<02:27,  5.91s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  20%|██        | 6/30 [00:36<02:22,  5.93s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  23%|██▎       | 7/30 [00:42<02:16,  5.93s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  27%|██▋       | 8/30 [00:48<02:10,  5.93s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  30%|███       | 9/30 [00:54<02:04,  5.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  33%|███▎      | 10/30 [01:00<01:59,  5.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  37%|███▋      | 11/30 [01:05<01:52,  5.90s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  40%|████      | 12/30 [01:11<01:45,  5.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  43%|████▎     | 13/30 [01:17<01:40,  5.89s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  47%|████▋     | 14/30 [01:23<01:34,  5.90s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  50%|█████     | 15/30 [01:29<01:29,  5.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  53%|█████▎    | 16/30 [01:35<01:24,  6.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  57%|█████▋    | 17/30 [01:42<01:19,  6.15s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  60%|██████    | 18/30 [01:48<01:14,  6.24s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  63%|██████▎   | 19/30 [01:55<01:08,  6.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  67%|██████▋   | 20/30 [02:01<01:03,  6.35s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  70%|███████   | 21/30 [02:07<00:55,  6.21s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  73%|███████▎  | 22/30 [02:13<00:48,  6.07s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  77%|███████▋  | 23/30 [02:18<00:41,  5.94s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  80%|████████  | 24/30 [02:24<00:35,  5.89s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  83%|████████▎ | 25/30 [02:30<00:29,  5.91s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  87%|████████▋ | 26/30 [02:36<00:23,  6.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  90%|█████████ | 27/30 [02:43<00:18,  6.10s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  93%|█████████▎| 28/30 [02:49<00:12,  6.17s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images:  97%|█████████▋| 29/30 [02:55<00:06,  6.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Generating images: 100%|██████████| 30/30 [03:02<00:00,  6.08s/it]
Calculating CLIP scores: 100%|██████████| 30/30 [00:00<00:00, 30.86it/s]


 All done! Results saved to: outputs/Sticker_manual_prompt_clip_only
 Average CLIP Score: 0.3070





In [5]:
import torch
import os
import pandas as pd
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from tqdm import tqdm
from diffusers import StableDiffusionPipeline
from peft import PeftModel
from PIL import Image
from pytorch_fid.fid_score import calculate_fid_given_paths

In [6]:
# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load validation dataset
df = pd.read_parquet("../data/processed_sticker_dataset.parquet")

class EmojiDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_tensor = torch.load(self.df.iloc[idx]['image_path']).float()
        prompt = self.df.iloc[idx]['prompt']
        return {"pixel_values": image_tensor, "prompt": prompt}

dataset = EmojiDataset(df)
train_size = int(0.99 * len(dataset))
_, val_set = random_split(dataset, [train_size, len(dataset) - train_size])

# Generate images from validation prompts
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
pipe.unet = PeftModel.from_pretrained(pipe.unet, "../evaluation/sticker_diffusion_qlora/final_model")
pipe.enable_attention_slicing()
pipe.safety_checker = lambda images, clip_input, **kwargs: (images, [False] * len(images))
print(" LoRA model loaded.")

# Output directories for FID
generated_dir = "outputs/sticker_fid/generated"
real_dir = "outputs/sticker_fid/real"
os.makedirs(generated_dir, exist_ok=True)
os.makedirs(real_dir, exist_ok=True)

print("🖼 Generating validation images...")
for i, sample in enumerate(tqdm(val_set)):
    prompt = sample["prompt"]
    real_image_tensor = sample["pixel_values"]
    
    # Save real image
    real_image = transforms.ToPILImage()(real_image_tensor)
    real_image.save(os.path.join(real_dir, f"real_{i:04}.png"))

    # Generate and save fake image
    result = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, height=256, width=256)
    gen_image = result.images[0]
    gen_image.save(os.path.join(generated_dir, f"gen_{i:04}.png"))

# Compute FID using pytorch_fid
print("📊 Calculating FID...")
fid_score = calculate_fid_given_paths([real_dir, generated_dir], batch_size=32, device=device, dims=2048)
print(f" FID Score: {fid_score:.4f}")


Using device: cuda


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

 LoRA model loaded.
🖼 Generating validation images...


  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  2%|▏         | 1/44 [00:06<04:26,  6.19s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  5%|▍         | 2/44 [00:12<04:13,  6.03s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  7%|▋         | 3/44 [00:18<04:08,  6.06s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

  9%|▉         | 4/44 [00:24<04:10,  6.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 11%|█▏        | 5/44 [00:31<04:07,  6.35s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 14%|█▎        | 6/44 [00:38<04:07,  6.53s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 16%|█▌        | 7/44 [00:44<04:02,  6.55s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 18%|█▊        | 8/44 [00:51<03:54,  6.52s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 20%|██        | 9/44 [00:57<03:41,  6.33s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 23%|██▎       | 10/44 [01:03<03:32,  6.24s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 25%|██▌       | 11/44 [01:09<03:23,  6.17s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 27%|██▋       | 12/44 [01:14<03:13,  6.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 30%|██▉       | 13/44 [01:21<03:07,  6.06s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 32%|███▏      | 14/44 [01:27<03:04,  6.16s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 34%|███▍      | 15/44 [01:33<02:58,  6.16s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 36%|███▋      | 16/44 [01:39<02:54,  6.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 39%|███▊      | 17/44 [01:46<02:49,  6.29s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 41%|████      | 18/44 [01:52<02:42,  6.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 43%|████▎     | 19/44 [01:58<02:35,  6.20s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 45%|████▌     | 20/44 [02:04<02:29,  6.21s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 48%|████▊     | 21/44 [02:10<02:21,  6.13s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 50%|█████     | 22/44 [02:17<02:15,  6.15s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 52%|█████▏    | 23/44 [02:22<02:07,  6.09s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 55%|█████▍    | 24/44 [02:29<02:02,  6.13s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 57%|█████▋    | 25/44 [02:35<01:56,  6.12s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 59%|█████▉    | 26/44 [02:41<01:50,  6.13s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 61%|██████▏   | 27/44 [02:47<01:44,  6.13s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 64%|██████▎   | 28/44 [02:53<01:37,  6.09s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 66%|██████▌   | 29/44 [02:59<01:30,  6.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 68%|██████▊   | 30/44 [03:05<01:23,  5.96s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 70%|███████   | 31/44 [03:11<01:16,  5.90s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 73%|███████▎  | 32/44 [03:16<01:09,  5.82s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 75%|███████▌  | 33/44 [03:22<01:05,  5.97s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 77%|███████▋  | 34/44 [03:29<01:01,  6.16s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 80%|███████▉  | 35/44 [03:36<00:56,  6.28s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 82%|████████▏ | 36/44 [03:42<00:51,  6.40s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 84%|████████▍ | 37/44 [03:49<00:45,  6.44s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 86%|████████▋ | 38/44 [03:55<00:37,  6.29s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 89%|████████▊ | 39/44 [04:01<00:30,  6.18s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 91%|█████████ | 40/44 [04:06<00:24,  6.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 93%|█████████▎| 41/44 [04:12<00:18,  6.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 95%|█████████▌| 42/44 [04:18<00:11,  5.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 98%|█████████▊| 43/44 [04:25<00:06,  6.13s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 44/44 [04:31<00:00,  6.18s/it]


📊 Calculating FID...


100%|██████████| 2/2 [00:03<00:00,  1.95s/it]
100%|██████████| 2/2 [00:03<00:00,  1.92s/it]


 FID Score: 240.1287
