In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.schedulers import LMSDiscreteScheduler
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, CLIPProcessor, CLIPModel
import torchvision.transforms as transforms
from PIL import Image
from IPython import display

In [2]:
# Root folder for this concept's saved artifacts (the concept dictionary and the alphas are loaded from here).
output_dir = '/'.join(['.', 'concept', 'sweetpepper'])

In [None]:
class Config:
    """Run-wide settings for the concept-decomposition notebook."""

    # cumulative share of the normalized alpha weights used to select the dominant hidden concepts
    threshold: float = 0.8
    # seed for torch.Generator so generated images are reproducible
    seed: int = 1024
    # torch device string
    device: str = 'cuda'
    # NOTE(review): not referenced by any cell in this notebook
    batch_size: int = 8
    # number of denoising steps passed to the diffusion pipeline
    num_inference_steps: int = 200
    # concept identifier; may contain underscores, which the preprocessing cell turns into spaces
    concept: str = 'sweetpepper'

In [None]:
# Turn the configured concept identifier into readable text: every underscore becomes a space.
concept = ' '.join(Config.concept.split('_'))

In [None]:
# Load the concept dictionary stored under the output directory.
# NOTE(review): torch.load unpickles its input — only load trusted files.
dictionary_path = f'{output_dir}/output/dictionary.pt'
concept_dict = torch.load(dictionary_path)

Generate a ground-truth image, then decompose the concept into its hidden concepts.


In [None]:
# Stable Diffusion 2.1-base on the configured device, with an LMS scheduler and no progress bars.
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe.to(Config.device)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)

# Register the extra token '<>' and grow the text encoder's embedding table to the new vocab size.
# NOTE(review): presumably the slot for a learned concept embedding (trained_id) — confirm.
pipe.tokenizer.add_tokens('<>')
trained_id = pipe.tokenizer.convert_tokens_to_ids('<>')
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
# Freeze the input embedding weights.
_ = pipe.text_encoder.get_input_embeddings().weight.requires_grad_(False)


# CLIP ViT-B/32 model (on the configured device), processor and tokenizer.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(Config.device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# PIL image -> tensor conversion.
transform_tensor = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# Ground-truth image for the concept. Uses the underscore-normalized `concept`
# from the preprocessing cell rather than re-reading Config.concept.
# NOTE(review): prompt wording kept as-is; presumably it must match the prompt used when the alphas were fit — confirm.
prompt = f'A image of {concept}'
generator = torch.Generator(Config.device).manual_seed(Config.seed)
# return_dict=False yields a tuple whose first element is the image list; take the single image.
image = pipe(prompt, guidance_scale=7.5,
             generator=generator,
             return_dict=False,
             num_images_per_prompt=1,
             num_inference_steps=Config.num_inference_steps)[0][0]
# Display the original image. `display` is bound to the IPython.display module,
# so its display() function must be called explicitly.
display.display(image.resize((224, 224)))

In [None]:
# Load the best alpha coefficients saved in output_dir.
# `.detach()` already returns a tensor that does not require grad, so no further call is needed.
# NOTE(review): despite its name, alpha_dict is a tensor (it supports .detach()), presumably 1-D — confirm.
alpha_dict = torch.load(f'{output_dir}/best_alphas.pt').detach()

# inspect the loaded alphas
print(alpha_dict)

In [None]:
# alpha holds the weight of each hidden concept;
# normalize the weights so they add up to 1.
alpha_sum = alpha_dict.sum(0)
# A zero total would silently turn every normalized weight into NaN/inf.
assert torch.all(alpha_sum != 0), 'alpha weights sum to zero; cannot normalize'
alpha_normalized = alpha_dict / alpha_sum

# Sort in descending order. torch.sort returns a (values, indices) named tuple:
# alpha_sort.values are the sorted weights, alpha_sort.indices their positions in alpha_dict.
alpha_sort = torch.sort(alpha_normalized, descending=True)

# inspect the sorted weights and their positions
print(alpha_sort)

In [None]:
# Normalized weights in descending order (the `values` field of torch.sort's result).
alpha_median = alpha_sort.values

# inspect the sorted weights
print(alpha_median)

# running total of the sorted weights
alpha_cumsum = torch.cumsum(alpha_median, dim=0)

# Positions (in sorted order) whose running total stays within Config.threshold.
top_80_percent = torch.where(alpha_cumsum <= Config.threshold)[0]
# If the single largest weight already exceeds the threshold, keep at least that one.
if top_80_percent.numel() == 0:
    top_80_percent = torch.zeros(1, dtype=torch.long, device=alpha_cumsum.device)

# inspect the selected positions
print(top_80_percent)

In [None]:
# Map the selected sorted positions back to their positions in alpha_dict, then look up
# the corresponding entries of the concept dictionary. These values are decoded with the
# tokenizer later, so they must be token ids, not the alpha weights themselves.
# NOTE(review): assumes concept_dict[i] is the token id paired with the i-th alpha — confirm.
top_80_indices = alpha_sort.indices[top_80_percent].tolist()
top_80_tokens = [concept_dict[idx] for idx in top_80_indices]

# inspect the selected tokens
print(top_80_tokens)

In [None]:
# Readable names of the selected hidden concepts: decode each token id with the pipeline's tokenizer.
# NOTE(review): assumes top_80_tokens holds token ids — confirm against how concept_dict was built.
top_80_concepts = [pipe.tokenizer.decode(token) for token in top_80_tokens]

In [None]:
# Render one image per selected hidden concept. The generator is re-seeded with
# Config.seed each time, as for the ground-truth image, so images are comparable.
for hidden in top_80_tokens:
    hidden_concepts = pipe.tokenizer.decode(hidden)
    # Put the decoded word (not the raw token) into the prompt.
    prompt = f'a photo of a {hidden_concepts}'
    generator = torch.Generator(Config.device).manual_seed(Config.seed)
    image = pipe(prompt, guidance_scale=7.5,
                 generator=generator,
                 return_dict=False,
                 num_images_per_prompt=1,
                 num_inference_steps=Config.num_inference_steps)[0][0]
    # Display the image. `display` is bound to the IPython.display module, so call its function.
    display.display(image.resize((224, 224)))
