## Imports

In [None]:
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline
from transformers import BlipForConditionalGeneration, BlipProcessor
import torch
from PIL import Image
from util import img_grid

In [None]:
# Load the reference image that is captioned, inverted and finally compared
# against its reconstruction.
# `convert("RGB")` guarantees a three-channel image even if the JPEG is stored
# as grayscale or CMYK, and it returns a fully loaded copy, so the `with` block
# can release the underlying file handle right away.
with Image.open("../../data/tmp/concept_bmw.jpeg") as _src_img:
    base_img = _src_img.convert("RGB")

## Captioning Model

In [None]:
# Hub id shared by the BLIP captioning model and its processor.
captioner_id = "Salesforce/blip-image-captioning-base"

# Captioning network, loaded in half precision with `low_cpu_mem_usage` enabled.
model = BlipForConditionalGeneration.from_pretrained(
    captioner_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Processor that belongs to the same checkpoint as the model above.
processor = BlipProcessor.from_pretrained(captioner_id)

## Model Inversion Pipeline

In [None]:
# Stable Diffusion checkpoint that backs the inversion pipeline.
sd_model_ckpt = "CompVis/stable-diffusion-v1-4"

# Build the Pix2Pix-Zero pipeline in half precision, attach the BLIP captioner
# loaded earlier, and load it without a safety checker.
inv_pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    sd_model_ckpt,
    torch_dtype=torch.float16,
    safety_checker=None,
    caption_generator=model,
    caption_processor=processor,
)

# Replace the checkpoint's scheduler with DDIM, then derive the inverse
# scheduler from that DDIM configuration so both directions share settings.
_ddim_scheduler = DDIMScheduler.from_config(inv_pipeline.scheduler.config)
inv_pipeline.scheduler = _ddim_scheduler
inv_pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(_ddim_scheduler.config)

# Turn on diffusers' model-level CPU offloading for this pipeline.
inv_pipeline.enable_model_cpu_offload()

## Generate Caption

In [None]:
# Caption the image with the BLIP model attached to the pipeline. The result
# is kept under its own name so it can be inspected instead of being thrown
# away by the assignment below.
generated_caption = inv_pipeline.generate_caption(base_img)

# The inversion and reconstruction cells run with an empty prompt, exactly as
# before. Assign `generated_caption` here instead to condition them on the
# BLIP caption.
caption = ""

## Invert the image into latent space

In [None]:
# Seed the global torch RNG; `torch.manual_seed` returns the generator that is
# handed to the inversion call.
generator = torch.manual_seed(0)

# Run the inversion of `base_img` under `caption` for 100 steps and keep the
# resulting latents for the reconstruction cell.
_inversion_output = inv_pipeline.invert(
    caption,
    image=base_img,
    generator=generator,
    num_inference_steps=100,
)
inv_latents = _inversion_output.latents

In [None]:
# Move the inversion pipeline to the CPU before the reconstruction pipeline is
# placed on the GPU (presumably to free GPU memory).
inv_pipeline.to("cpu")

model_id = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# The latents come from DDIM inversion, so they have to be denoised with the
# matching DDIM sampler. The scheduler shipped with the checkpoint follows a
# different update rule and would not retrace the inversion trajectory.
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to("cuda")

# Reconstruct from the inverted latents with the same prompt and step count
# that were used for the inversion. guidance_scale=1.0 switches classifier-free
# guidance off; the inversion call did not pass a guidance scale
# (NOTE(review): confirm that the installed diffusers version defaults to 1
# there), so the reconstruction should not amplify the prompt either.
recon_image = pipeline(
    prompt=caption,
    latents=inv_latents,
    num_inference_steps=100,
    guidance_scale=1.0,
).images[0]

In [None]:
# Show the original image next to its reconstruction.
# NOTE(review): the meaning of the positional arguments 2 and 1 is defined by
# `util.img_grid`, which is not visible here; confirm against that helper.
_comparison_images = [base_img, recon_image]
img_grid(_comparison_images, 2, 1)