In [1]:
import torch
import open_clip
from PIL import Image
import os
import torch.nn.functional as nnf
from tqdm import tqdm
import csv
# Selects which heads_perturbed list is used below (presumably SD vs. SDXL outputs).
sdxl = False
gpu_num = 0
# Always a torch.device (never a bare string) so every `.to(device)` call sees one type.
device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")

# Single source of truth for the CLIP checkpoint used for both model and tokenizer.
CLIP_MODEL_ID = 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL_ID, device=device)
# Per the open_clip README, models are returned in train mode; switch to eval so
# dropout / stochastic-depth layers are inactive while scoring.
model.eval()
tokenizer = open_clip.get_tokenizer(CLIP_MODEL_ID)


# Define heads perturbed: the scoring cell writes one "top_<n>" and one
# "bottom_<n>" CSV row for every value n in the selected list.
_HEADS_PERTURBED_NON_SDXL = [0, 1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101, 111, 121, 128]
_HEADS_PERTURBED_SDXL = [0, 11, 111, 211, 311, 411, 511, 611, 711, 811, 911, 1011, 1111, 1211, 1300]
heads_perturbed = _HEADS_PERTURBED_SDXL if sdxl else _HEADS_PERTURBED_NON_SDXL

# Define the classes
def _parse_items(text):
    """Split a comma-separated string into a list of item names.

    Each item is stripped, so the result is the same whether items are
    separated by "," or ", " -- no stray leading space can end up in a name.

    Args:
        text: comma-separated item names, e.g. "red, blue, green".

    Returns:
        list of item-name strings in their original order.
    """
    return [item.strip() for item in text.split(",")]


Color = _parse_items("red, blue, green, yellow, black, white, purple, gray, pink, brown")
Animals = _parse_items("cat, dog, rabbit, frog, bird, squirrel, deer, lion, penguin, horse")
Fruits_and_Vegetables = _parse_items("lemons, bananas, apples, oranges, blueberries, carrots, broccoli, tomatoes, potatoes, grapes")
Image_Style = _parse_items("cubist, pop art, steampunk, impressionist, black-and-white, watercolor, cartoon, minimalist, sepia, sketch")
Material = _parse_items("glass, copper, marble, jade, gold, basalt, silver, clay, paper, leather")
Nature_Scenes = _parse_items("forest, desert, beach, waterfall, mountain, canyon, glacier, coral reef, jungle, lake")
Weather_Conditions = _parse_items("snowy, rainy, foggy, stormy")
Geometric_Patterns = _parse_items("polka-dot, leopard, stripe, greek-key, plaid")
Furniture = _parse_items("bed, table, chair, sofa, recliner, bookshelf, dresser, wardrobe, coffee table, TV stand")
Electronics = _parse_items("smartphone, laptop, tablet, smart TV, digital camera, drone, desktop computer, microwave, refrigerator, smartwatch")
# NOTE(review): "ballon" looks like a typo for "balloon"; left unchanged in case it
# must match existing output folder names -- confirm before renaming.
Objects_A = _parse_items("car, bench, bowl, ballon, ball")
Objects_B = _parse_items("bowl, cup, table, ball, teapot")
Objects_C = _parse_items("T-shirt, pillow, wallpaper, umbrella, blanket")
Animals_A = _parse_items("cat, dog, rabbit, frog, bird")
Others = _parse_items("castle, mountain, cityscape, farmland, forest")
Animals_ood = _parse_items("rabbit, frog, sheep, pig, chicken, dolphin, goat, duck, deer, fox")
Color_ood = _parse_items("coral, beige, violet, cyan, magenta, indigo, orange, turquoise, teal, khaki")
Material_ood = _parse_items("copper, marble, jade, gold, basalt, silver, clay, steel, tin, bronze")
Fruits_and_Vegetables_ood = _parse_items("lemons, blueberries, onions, raspberries, pineapples, cherries, cucumbers, bell peppers, cauliflowers, mangoes")
Nature_Scenes_ood = _parse_items("glacier, coral reef, swamp, pond, fjord, rainforest, grassland, marsh, creek, island")
Tableware = _parse_items("salad bowl, serving platter, bread basket, fondue pot, spoon, fork, nut dish, coffee pot, tureen, chafing dish")

# Define seeds: each seed names one image file, "<seed>.png", per setting folder.
seeds = list(range(10, 31, 10))

In [None]:
"""Calculate CLIP image-text similarity and save the results as a csv file"""
# --- Change only the following variables --- #
exp_nums = [5]
main_items_list = [Image_Style]
categories = "Image Style".split(", ")
# ------------------------------------------- #

# Prompt mode: "wo_category" scores images against the bare item (e.g. "cubist");
# "w_category" prefixes the category (e.g. "Image Style: cubist").
name = "wo_category"

if name not in ("wo_category", "w_category"):
    raise ValueError(f"Unknown prompt mode {name!r}; expected 'wo_category' or 'w_category'")
if not (len(exp_nums) == len(main_items_list) == len(categories)):
    raise ValueError("exp_nums, main_items_list and categories must have the same length")

os.makedirs('./hp_results', exist_ok=True)

for exp_num, main_items, category in zip(exp_nums, main_items_list, categories):
    directory = f"./hp_outputs/exp_{exp_num}"
    # Sorted so the iteration order (and float summation order) is reproducible.
    subdirectories = sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    score_sums = dict()
    score_counts = dict()

    for subdirectory in tqdm(subdirectories):
        # Prepare the text: the first item whose name occurs in the folder name.
        # Fail loudly instead of silently reusing the previous iteration's prompt.
        item = next((candidate for candidate in main_items if candidate in subdirectory), None)
        if item is None:
            raise ValueError(f"No item of {category!r} found in subdirectory name {subdirectory!r}")
        prompt = item if name == "wo_category" else f"{category}: {item}"
        # One prompt copy per seed image. The text embedding does not depend on the
        # suborder, so encode it once per subdirectory instead of once per suborder.
        text = tokenizer([prompt] * len(seeds)).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)

        # Prepare the images
        for order in ["top", "bottom"]:
            subdirectory_order = os.path.join(directory, subdirectory, order)
            suborders = sorted(
                entry for entry in os.listdir(subdirectory_order)
                if os.path.isdir(os.path.join(subdirectory_order, entry))
            )
            for suborder in suborders:
                images = []
                for seed in seeds:
                    image_path = os.path.join(subdirectory_order, suborder, f"{seed}.png")
                    # Context manager closes the file handle once preprocessing is done.
                    with Image.open(image_path) as img:
                        images.append(preprocess(img).unsqueeze(0).to(device))
                images = torch.cat(images, dim=0)
                with torch.no_grad():
                    image_features = model.encode_image(images)
                    score = nnf.cosine_similarity(image_features, text_features).mean(dim=0).item()
                score_sums[suborder] = score_sums.get(suborder, 0.0) + score
                score_counts[suborder] = score_counts.get(suborder, 0) + 1

    # Average each setting over the subdirectories that actually contained it
    # (dividing by len(subdirectories) would bias settings missing from some).
    scores = {key: score_sums[key] / score_counts[key] for key in score_sums}

    # Save as csv file
    csv_file = f'./hp_results/hp_scores_{name}_exp_{exp_num}.csv'
    keys = [f"top_{num_heads}" for num_heads in heads_perturbed]
    keys += [f"bottom_{num_heads}" for num_heads in heads_perturbed]

    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Key', 'Value'])  # Write the header
        for key in keys:
            writer.writerow([key, f"{scores[key]:.4f}"])

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [03:21<00:00,  4.04s/it]
