In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, CLIPModel
from diffusers import StableDiffusionPipeline

In [None]:
# Stable Diffusion pipeline; its tokenizer and text-encoder embeddings are used below
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")

# Concept under study; its artifacts are read from ./<concept>/output/
concept = 'sweetpepper'
folder = f'./{concept}'

# CLIP text model used to score word similarity. It is moved to the GPU so its
# inputs (which this notebook sends to 'cuda') and weights share a device;
# otherwise get_text_features raises a device-mismatch error.
clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')

In [None]:
# Example of torch.sort: sort each row of a 3x4 matrix in descending order.
# A locally seeded generator makes the output reproducible without touching the
# global RNG state. The result is NOT named `sorted`, since that would shadow
# Python's builtin sorted() for every later cell in this kernel.
x = torch.randn(3, 4, generator=torch.Generator().manual_seed(0))

sorted_vals, indices = torch.sort(x, dim=1, descending=True)
print(sorted_vals)
print(indices)

In [None]:
# In case the concept is a compound word, replace the underscores with spaces
concept_nu = concept.replace('_', ' ')

# Similarity thresholds (previously inline magic numbers)
CONCEPT_SIM_THRESHOLD = 0.92  # top words at least this CLIP-similar to the concept are zeroed out
WORD_SIM_THRESHOLD = 0.95     # word pairs above this are flagged as near-duplicates

# Copy the text encoder's token-embedding matrix; clone().detach() yields an
# independent tensor outside the autograd graph.
orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()

# Average L2 norm of the token embeddings (one vectorised norm over all rows
# instead of a Python loop over the whole vocabulary).
norms = orig_embeddings.norm(dim=1).tolist()
avg_norm = np.mean(norms)

# One weight ("alpha") per entry of `dictionary`, presumably produced by an earlier
# optimisation run - TODO confirm. Conceptually the concept is approximated as a
# weighted mix of other words, e.g. sweetpepper ~ 0.8 * fingers + 0.5 * pepper.
alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)

# `dictionary[k]` is the tokenizer token id that alphas_dict[k] weights
# (it is decoded with pipe.tokenizer and used to index the vocabulary below).
dictionary = torch.load(f'{folder}/output/dictionary.pt')

# Rank the alphas; sorted_indices are positions into `dictionary`, highest weight first
sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)
num_alphas = 50

# Token ids and decoded text of the top-`num_alphas` words (decoded once, reused below)
top_word_idx = [dictionary[i] for i in sorted_indices[:num_alphas]]
top_words = [pipe.tokenizer.decode([token_id]) for token_id in top_word_idx]
alpha_ids = list(enumerate(top_words))  # (rank, word) pairs for inspection

# Scatter the top weights into a vocabulary-sized vector; every other entry stays 0.
# sorted_alphas[k] == alphas_dict[sorted_indices[k]], so this matches the ranking above.
alphas = torch.zeros(orig_embeddings.shape[0]).cuda()
for token_id, weight in zip(top_word_idx, sorted_alphas[:num_alphas]):
    alphas[token_id] = weight

# CLIP text features for the concept itself and for each top word. Inputs follow
# the model's device (avoids a CPU/CUDA mismatch), and no autograd graph is built.
with torch.no_grad():
    clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors="pt").to(clip_model.device)
    clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)

    clip_text_inputs = clip_tokenizer(top_words, padding=True, return_tensors="pt").to(clip_model.device)
    clip_text_features = clip_model.get_text_features(**clip_text_inputs)

# Pairwise cosine similarity between the top words (num_alphas x num_alphas);
# intended for a later step that drops near-duplicates of a removed word.
normed_text_features = nn.functional.normalize(clip_text_features, dim=1)
clip_words_similarity = normed_text_features @ normed_text_features.T

# Cosine similarity between the concept and each top word (length num_alphas).
concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)
similar_words = (concept_words_similarity.detach().cpu().numpy() > CONCEPT_SIM_THRESHOLD).nonzero()[0]
# NOTE: rebound from a float matrix to a boolean "near-duplicate" mask
clip_words_similarity = clip_words_similarity.detach().cpu().numpy() > WORD_SIM_THRESHOLD

# Zero out top words that are essentially synonyms of the concept itself
for i in similar_words:
    alphas[top_word_idx[i]] = 0

# TODO / open issues with this filtering:
# - The alphas only decide which words reach the top-50; they play no part in the
#   similarity test. A word with a very high alpha but low similarity to the concept
#   is therefore kept, although arguably it should be removed as well.
# - If several near-duplicate words all have high alphas, it is unclear which ones
#   to keep when only a couple of hidden concepts may be selected.