In [1]:
import os
import torch
from PIL import Image
from torch.nn import CosineSimilarity
from transformers import CLIPModel, CLIPProcessor, logging

# Setup
clip_model_id = 'openai/clip-vit-large-patch14'
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `model.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.set_verbosity_error()  # Silence transformers warnings/info output
torch.manual_seed(1)  # Fixed seed for reproducibility

# Textual inversion settings
property_name = 'grooty'  # Name of learned property

# Each test below pairs a text prompt with the directory of images that
# were generated for it; the prompt is later compared against those images.

# Background test
background_prompt = 'A photo of on the moon'
samples_path_background = f'generated_images/samples/{property_name}_background'

# Style test
style_prompt = 'An oil painting of'
samples_path_style = f'generated_images/samples/{property_name}_style'

# Composition test
composition_prompt = 'Elmo holding a'
samples_path_composition = f'generated_images/samples/{property_name}_composition'

# Hugging Face access token
# Read from a local file (first line, surrounding whitespace stripped) so the
# secret is never hard-coded in the notebook itself.
token = ''
with open('hugging_face_token.txt', 'r') as secret:
    token = secret.readline().strip()

In [2]:
# Load CLIP components
# The processor handles image preprocessing and text tokenisation; the model
# produces the embeddings. Both come from the same checkpoint id.
processor = CLIPProcessor.from_pretrained(clip_model_id)
# Load the weights, move them to the target device and switch to inference
# mode in one chain (`.to()` and `.eval()` both return the model itself).
model = CLIPModel.from_pretrained(clip_model_id).to(device).eval()
print('Loaded CLIP model successfully!')

Loaded CLIP model successfully!


In [3]:
# Load embeddings by passing images through CLIP
def load_avg_embeddings(image_dir):
    """Return the mean CLIP image embedding of all image files in `image_dir`.

    Every regular file directly inside `image_dir` is opened as an image,
    the whole set is passed through the CLIP image encoder as one batch, and
    the resulting embeddings are averaged over the batch dimension.

    Args:
        image_dir: Path of a directory containing the images to embed.

    Returns:
        A tensor holding the element-wise mean of the per-image embeddings
        (the batch dimension is reduced away).

    Raises:
        ValueError: If `image_dir` contains no files.
    """
    # Sort for a reproducible batch order, and skip sub-directories (e.g. the
    # `.ipynb_checkpoints` folder Jupyter may create), which cannot be opened
    # as images.
    file_paths = sorted(
        path
        for path in (os.path.join(image_dir, name) for name in os.listdir(image_dir))
        if os.path.isfile(path)
    )
    if not file_paths:
        raise ValueError(f'No image files found in {image_dir!r}')

    images = []
    for path in file_paths:
        # `Image.open` is lazy and keeps the file handle open. `convert`
        # forces the pixel data to be read and returns an independent copy,
        # so the handle can be released as soon as the `with` block exits.
        with Image.open(path) as image:
            images.append(image.convert('RGB'))

    with torch.no_grad():
        inputs = processor(images=images, return_tensors='pt').to(device)
        embeddings = model.get_image_features(**inputs)
    return torch.mean(embeddings, dim=0)

# Get embeddings for prompts by passing text through CLIP
def get_prompt_embeddings(prompts):
    """Encode `prompts` with the CLIP text encoder.

    Args:
        prompts: Text prompt(s) to encode; they are tokenised together with
            padding enabled so they can be batched.

    Returns:
        The text features produced by `model.get_text_features` for the
        tokenised prompts, computed without gradient tracking.
    """
    with torch.no_grad():
        tokenized = processor(text=prompts, return_tensors='pt', padding=True)
        text_inputs = tokenized.to(device)
        text_features = model.get_text_features(**text_inputs)
    return text_features


In [4]:
# Load CLIP embeddings for each sample
# One averaged image embedding per test, in the same order as the prompts below.
sample_paths = (samples_path_background, samples_path_style, samples_path_composition)
avg_features = [load_avg_embeddings(path) for path in sample_paths]

# Get text embeddings
prompt_features = get_prompt_embeddings([background_prompt, style_prompt, composition_prompt])

# Measure text similarity
# Both operands are 1-D embedding vectors, hence dim=0. The module is
# stateless, so it is built once rather than on every iteration.
cosine_similarity = CosineSimilarity(dim=0)
similarities = [
    cosine_similarity(image_feature, text_feature).item()
    for image_feature, text_feature in zip(avg_features, prompt_features)
]
# Accumulate at full precision; rounding each term before summing would
# compound rounding error into the final score.
similarity_acc = sum(similarities)

# Get score
# Mean image/text similarity over all tests, rounded once for display.
editability_score = round(similarity_acc / len(similarities), 4)
print(f'Editability Score (Text Similarity): {editability_score}')

Editability Score (Text Similarity): 0.2639
