In [1]:
import shutil

import numpy as np
import os
from PIL import Image
import sys
from shutil import copyfile
from pathlib import Path

from diffusers.schedulers import LMSDiscreteScheduler
from diffusers import StableDiffusionPipeline


import torch

import torchvision.transforms as transforms
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer

import glob
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration.
concept = 'terrorist'        # concept whose learned word decomposition is inspected
folder = './' + concept      # root directory; the alpha/dictionary files live under <folder>/output/
prompt = 'a photo of a '     # prompt prefix (not used by the visible cells; presumably combined with '<>' later)
target_seed = 55             # seed for reproducible generation (not used by the visible cells)
num_inference_steps = 25     # diffusion sampling steps used inside load_alphas

In [None]:
# Stable Diffusion 2.1 (base) on the GPU, with an LMS scheduler and progress bars disabled.
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe.to("cuda")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)

# Register the placeholder token '<>' whose embedding row load_alphas overwrites.
# add_tokens returns how many tokens were actually added. If '<>' already exists
# in the vocabulary it returns 0, and trained_id would then point at an existing
# token whose embedding would be silently clobbered, so fail loudly instead.
if pipe.tokenizer.add_tokens('<>') != 1:
    raise ValueError("Placeholder token '<>' is already in the tokenizer vocabulary; choose an unused token.")
trained_id = pipe.tokenizer.convert_tokens_to_ids('<>')
# Grow the text encoder's embedding matrix to cover the new token id.
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
# Freeze the embedding matrix: its '<>' row is only ever written manually (in load_alphas).
_ = pipe.text_encoder.get_input_embeddings().weight.requires_grad_(False)

# CLIP ViT-B/32: the model and tokenizer are used below to compare words by text-feature similarity.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# PIL image -> channels-first float tensor (ToTensor scales 8-bit images to [0, 1]).
transform_tensor = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
def clip_transform(image_tensor):
    """Resize images to 224x224 and apply the OpenAI CLIP per-channel normalization.

    Args:
        image_tensor: image batch passed to ``interpolate``. Bicubic mode requires a
            4-D (N, C, H, W) tensor; values are presumably RGB in [0, 1] (TODO confirm).

    Returns:
        The bicubically resized tensor, normalized with the CLIP mean/std below.
    """
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    resized = torch.nn.functional.interpolate(
        image_tensor,
        size=(224, 224),
        mode='bicubic',
        align_corners=False,
    )
    return transforms.Normalize(mean=clip_mean, std=clip_std)(resized)

def load_alphas(alphas_projection, token_embeddings, seed, prompt, avg_norm=0.1, guidance_scale=7.5):
    """Inject an alpha-weighted embedding as the '<>' token and generate one image.

    The injected embedding is ``alphas_projection @ token_embeddings``, rescaled to
    L2 norm ``avg_norm``. It is written into row ``trained_id`` of the text encoder's
    token-embedding matrix, and ``pipe`` is then sampled once with a seeded CUDA
    generator.

    Args:
        alphas_projection: weights over the rows of ``token_embeddings``
            (presumably shape (vocab,) — TODO confirm).
        token_embeddings: embedding matrix (presumably shape (vocab, dim)).
        seed: seed for the dedicated CUDA ``torch.Generator`` passed to the pipeline.
        prompt: text prompt passed to the pipeline.
        avg_norm: target L2 norm of the injected embedding.
        guidance_scale: classifier-free guidance scale passed to the pipeline.

    Returns:
        ``pipe(...)[0][0]`` with ``return_dict=False``, i.e. the first generated image.

    Raises:
        ValueError: if the weighted embedding has zero norm (e.g. all alphas are zero).
            Rescaling it would otherwise produce a NaN embedding.
    """
    embedding = torch.matmul(alphas_projection, token_embeddings)
    embedding_norm = embedding.norm()
    if embedding_norm.item() == 0:
        raise ValueError("load_alphas: weighted embedding has zero norm (are all alphas zero?)")
    embedding = torch.mul(embedding, 1 / embedding_norm)
    embedding = torch.mul(embedding, avg_norm)
    # Write the row without autograd tracking. Assigning a grad-requiring tensor
    # (e.g. an nn.Parameter) here would make the frozen embedding weight require
    # grad and chain a new CopySlices node onto its history on every call.
    with torch.no_grad():
        pipe.text_encoder.text_model.embeddings.token_embedding.weight[trained_id] = embedding
    generator = torch.Generator("cuda").manual_seed(seed)
    return pipe(prompt, guidance_scale=guidance_scale,
                generator=generator,
                return_dict=False,
                num_images_per_prompt=1,
                num_inference_steps=num_inference_steps)[0][0]

In [None]:
"""
torch.Generator:
Purpose: A Generator is an object that holds the state of its own, independent pseudo-random number generator. It gives finer-grained control over random number generation than the global RNG.
Use Cases: You might use a Generator when you need to maintain multiple independent random number generators. 
For example, in a scenario where different parts of your code need to be randomized independently of each other, 
or when you want to ensure reproducibility of specific operations without affecting the global RNG state.
Flexibility: With Generator, you can create different streams of random numbers that are not influenced by the global seed or by each other. 
This is particularly useful in parallel processing or when you want to isolate the randomness in different parts of your code.
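Usage: To use a Generator, first create an instance (e.g., gen = torch.Generator(), or torch.Generator("cuda") for a GPU generator), 
optionally seed it (e.g., gen.manual_seed(1234), which returns the generator itself), and then pass it as the `generator` argument of functions that accept one.
For example, load_alphas creates torch.Generator("cuda").manual_seed(seed) and passes it to the pipeline, so each call's randomness is controlled by `seed` alone, independently of the global RNG state.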
"""

In [None]:
# Load the learned decomposition for `concept`, keep its top `num_alphas` words,
# and zero out words whose CLIP text features are nearly identical to the concept name.
num_alphas = 50                      # number of highest-weighted words to keep
concept_similarity_threshold = 0.92  # CLIP cosine sim above which a word is treated as the concept itself
word_similarity_threshold = 0.95     # CLIP cosine sim above which two kept words are flagged as near-duplicates

concept_nu = concept.replace('_', ' ')
concept_u = concept.replace(' ', '_')

# Snapshot of the text encoder's token embeddings and their mean L2 norm.
orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()
# One vectorized reduction and one device->host copy, instead of a GPU sync per row.
norms = orig_embeddings.norm(dim=1).tolist()
avg_norm = np.mean(norms)

# NOTE(review): torch.load unpickles its input — only load files this project produced.
alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)

dictionary = torch.load(f'{folder}/output/dictionary.pt')
sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)
# Token ids of the top-weighted words, in descending alpha order.
top_word_idx = [dictionary[i] for i in sorted_indices[:num_alphas]]
alpha_ids = [(rank, pipe.tokenizer.decode([token_id])) for rank, token_id in enumerate(top_word_idx)]

# Vocabulary-sized weight vector that is zero except at the kept words.
# sorted_alphas[k] == alphas_dict[sorted_indices[k]], so zipping pairs each id with its alpha.
alphas = torch.zeros(orig_embeddings.shape[0]).cuda()
for token_id, weight in zip(top_word_idx, sorted_alphas[:num_alphas]):
    alphas[token_id] = weight

# CLIP text features are only used for similarity filtering, so skip autograd bookkeeping.
with torch.no_grad():
    clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors="pt").to('cuda')
    clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)

    top_words = [pipe.tokenizer.decode([x]) for x in top_word_idx]
    clip_text_inputs = clip_tokenizer(top_words, padding=True, return_tensors="pt").to('cuda')
    clip_text_features = clip_model.get_text_features(**clip_text_inputs)

# Pairwise cosine similarity between the kept words' CLIP text features.
clip_words_similarity = (torch.matmul(clip_text_features, clip_text_features.transpose(1, 0)) /
                         torch.matmul(clip_text_features.norm(dim=1).unsqueeze(1),
                                      clip_text_features.norm(dim=1).unsqueeze(0)))

concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)
similar_words = (concept_words_similarity.detach().cpu().numpy() > concept_similarity_threshold).nonzero()[0]
# Boolean (k, k) matrix over the k kept words: True marks a near-duplicate pair.
clip_words_similarity = clip_words_similarity.detach().cpu().numpy() > word_similarity_threshold

# Zero-out words that are (near-)synonyms of the concept name itself.
for i in similar_words:
    alphas[top_word_idx[i]] = 0

In [None]:
# NOTE(review): this cell recomputes exactly the same variables as the previous cell
# from the same files. Running it a second time is redundant, and the cell can be deleted.
# Load the learned decomposition for `concept`, keep its top `num_alphas` words,
# and zero out words whose CLIP text features are nearly identical to the concept name.
num_alphas = 50                      # number of highest-weighted words to keep
concept_similarity_threshold = 0.92  # CLIP cosine sim above which a word is treated as the concept itself
word_similarity_threshold = 0.95     # CLIP cosine sim above which two kept words are flagged as near-duplicates

concept_nu = concept.replace('_', ' ')
concept_u = concept.replace(' ', '_')

# Snapshot of the text encoder's token embeddings and their mean L2 norm.
orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()
# One vectorized reduction and one device->host copy, instead of a GPU sync per row.
norms = orig_embeddings.norm(dim=1).tolist()
avg_norm = np.mean(norms)

# NOTE(review): torch.load unpickles its input — only load files this project produced.
alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)

dictionary = torch.load(f'{folder}/output/dictionary.pt')
sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)
# Token ids of the top-weighted words, in descending alpha order.
top_word_idx = [dictionary[i] for i in sorted_indices[:num_alphas]]
alpha_ids = [(rank, pipe.tokenizer.decode([token_id])) for rank, token_id in enumerate(top_word_idx)]

# Vocabulary-sized weight vector that is zero except at the kept words.
# sorted_alphas[k] == alphas_dict[sorted_indices[k]], so zipping pairs each id with its alpha.
alphas = torch.zeros(orig_embeddings.shape[0]).cuda()
for token_id, weight in zip(top_word_idx, sorted_alphas[:num_alphas]):
    alphas[token_id] = weight

# CLIP text features are only used for similarity filtering, so skip autograd bookkeeping.
with torch.no_grad():
    clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors="pt").to('cuda')
    clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)

    top_words = [pipe.tokenizer.decode([x]) for x in top_word_idx]
    clip_text_inputs = clip_tokenizer(top_words, padding=True, return_tensors="pt").to('cuda')
    clip_text_features = clip_model.get_text_features(**clip_text_inputs)

# Pairwise cosine similarity between the kept words' CLIP text features.
clip_words_similarity = (torch.matmul(clip_text_features, clip_text_features.transpose(1, 0)) /
                         torch.matmul(clip_text_features.norm(dim=1).unsqueeze(1),
                                      clip_text_features.norm(dim=1).unsqueeze(0)))

concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)
similar_words = (concept_words_similarity.detach().cpu().numpy() > concept_similarity_threshold).nonzero()[0]
# Boolean (k, k) matrix over the k kept words: True marks a near-duplicate pair.
clip_words_similarity = clip_words_similarity.detach().cpu().numpy() > word_similarity_threshold

# Zero-out words that are (near-)synonyms of the concept name itself.
for i in similar_words:
    alphas[top_word_idx[i]] = 0