In [1]:
from prompt_to_prompt.ptp_utils import load_512
from prompt_to_prompt.null_text_inversion_batched import NullTextInversion



In [2]:
import random
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline, DDIMScheduler
from tqdm import tqdm
from PIL import Image
import numpy as np

# Single seed shared by every RNG created in this cell.
SEED = 8888

# Seed Python's RNG; `random.sample` is used later to pick the dataset subset.
random.seed(SEED)

# Resolve the device first so everything below can honour the CPU fallback.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# `generator_cuda` is created on `device` rather than on a hard-coded "cuda:0",
# so this cell no longer raises on machines without CUDA (it is then a CPU
# generator with the same seed). On a CUDA machine nothing changes.
generator_cuda = torch.Generator(device).manual_seed(SEED)
generator_cpu = torch.Generator().manual_seed(SEED)

# dtype handed to the pipeline loader further down.
torch_dtype = torch.float32

In [3]:
# Number of diffusion steps; passed to NullTextInversion further down.
NUM_DIFFUSION_STEPS = 50
# Guidance scale; passed to NullTextInversion further down.
GUIDANCE_SCALE = 7.5
# NOTE(review): not referenced in the visible cells — presumably the maximum
# prompt token length; confirm it is still needed.
MAX_NUM_WORDS = 77
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
LOW_RESOURCE = False

In [4]:
from torchvision.datasets import OxfordIIITPet
# Oxford-IIIT Pet "trainval" split rooted in the working directory
# (download=True asks torchvision to fetch it).
ds = OxfordIIITPet(root=".", split="trainval", download=True)

# Reference image read from a local file and passed through `load_512`.
# To use the first dataset sample instead: image = load_512(np.array(ds[0][0]))
milo_pixels = np.array(Image.open("milo.jpeg"))
image = load_512(milo_pixels)

In [5]:
def _with_article(name):
    """Return ``name`` prefixed with "an" if it starts with a vowel, else "a"."""
    article = 'an' if name[0] in vowels else 'a'
    return f"{article} {name}"


vowels = "aeiou"
# Lower-cased dataset class names, in the order of `ds.class_to_idx`.
classes = [label.lower() for label in ds.class_to_idx.keys()]
# Class names with their indefinite article, e.g. for use inside prompts.
class_strings = [_with_article(label) for label in classes]

In [6]:
# Identifier of the Stable Diffusion v1.5 checkpoint to load.
model_path = "runwayml/stable-diffusion-v1-5"

# The safety checker is explicitly disabled (safety_checker=None); the loader
# prints a licence reminder about this, which is expected.
pipeline_kwargs = dict(torch_dtype=torch_dtype, safety_checker=None)
model = StableDiffusionPipeline.from_pretrained(model_path, **pipeline_kwargs).to(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


## Load MS COCO annotations

In [7]:
import json

# NOTE: despite the variable name, this points at the COCO *train2017* captions file.
val_captions = "captions_train2017.json"

# Explicit encoding so the JSON is read as UTF-8 on every platform.
with open(val_captions, "r", encoding="utf-8") as f:
    captions_train = json.load(f)

annotations = captions_train["annotations"]
# NOTE: `images` is rebound later in the notebook to the sampled pet images.
images = captions_train["images"]

# Precomputed annotation embeddings. `map_location=device` places the tensor on
# the device resolved earlier instead of calling `.cuda()` unconditionally, so
# the cell also works on CPU-only machines and with files saved from a GPU.
# SECURITY: torch.load unpickles the file — only load files you trust.
annotation_embeddings_index = torch.load("annotation_embeddings_index", map_location=device)

## Null-text Inversion

In [8]:
# Build the null-text inversion helper around the loaded pipeline, passing the
# step count and guidance scale from the configuration cell.
nti = NullTextInversion(model, NUM_DIFFUSION_STEPS, GUIDANCE_SCALE)

In [9]:
# Number of dataset samples to invert.
dataset_size = 100

# Random subset of dataset indices, without replacement (depends on the
# `random.seed` call earlier in the notebook).
ds_indexes = random.sample(range(len(ds)), dataset_size)

# Fetch each sample exactly once: `ds[i]` was previously evaluated twice per
# index (once for the image, once for the label).
sampled_pairs = [ds[i] for i in ds_indexes]

# NOTE: this rebinds `images`, which earlier held the COCO image records.
images = [np.array(pil_image) for pil_image, _ in sampled_pairs]
prompts = [f"a photo of {class_strings[label]}" for _, label in sampled_pairs]

In [None]:
# Fit null-text embeddings for the sampled images and their prompts.
# NOTE(review): the meaning of the tuning keywords is defined by
# NullTextInversion.fit — confirm there before changing them.
null_embeddings = nti.fit(
    images,
    prompts,
    max_steps=100,
    num_inner_steps=1,
    lr_scale_factor=1e-2,
)

  0%|                                     | 0/100 [00:00<?, ?it/s]

## Interpreting Optimised Latents

In [None]:
from sklearn.decomposition import PCA

# Fixed random_state for the PCA below.
seed = 50
pca = PCA(n_components=7, random_state=seed)

# Move the annotation embedding index to host memory as a NumPy array, then
# fit the PCA on it and keep its 7-component projection.
annotation_embeddings_np = annotation_embeddings_index.cpu().numpy()
annotation_embeddings_index_proj = pca.fit_transform(annotation_embeddings_np)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def interpret_aligned_latents(embeddings_list, idxs=(0, 10, 20, 30, 40, 49)):
    """Print the closest annotation caption for selected entries of ``embeddings_list``.

    For every index in ``idxs`` the embedding is averaged over axis 1,
    projected with the module-level ``pca`` and compared by cosine similarity
    against ``annotation_embeddings_index_proj``. The caption of the
    best-matching annotation is printed as ``<idx> <caption>``.

    Args:
        embeddings_list: Indexable collection of torch tensors.
            Assumes each tensor is 3-D so that the mean over axis 1 is the
            2-D array ``pca.transform`` expects — TODO confirm against
            ``NullTextInversion.fit``.
        idxs: Positions in ``embeddings_list`` to interpret. Defaults to the
            original fixed selection ``(0, 10, 20, 30, 40, 49)``.

    Raises:
        IndexError: If an entry of ``idxs`` is out of range for
            ``embeddings_list``.
    """
    for idx in idxs:
        # detach() is a no-op for tensors that do not require grad, and lets
        # optimised tensors that still do be converted to NumPy.
        mean_embedding = embeddings_list[idx].detach().cpu().numpy().mean(axis=1)
        cond_embedding = pca.transform(mean_embedding)

        # Shape: (n_annotations, n_rows of cond_embedding).
        similarities = cosine_similarity(annotation_embeddings_index_proj, cond_embedding)

        # A flat argmax is only a valid annotation row when there is a single
        # column. Unravelling keeps that case identical and, for several
        # columns, yields the annotation row of the overall best match instead
        # of a wrong or out-of-range index.
        best_row, _ = np.unravel_index(np.argmax(similarities), similarities.shape)

        annotation = annotations[int(best_row)]
        print(idx, annotation["caption"])

In [None]:
# Print the nearest annotation caption for the selected entries of the fitted
# null embeddings.
interpret_aligned_latents(null_embeddings)