In [None]:
import os
from glob import glob
import warnings
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image

import torch
from torchvision.utils import make_grid

from diffusers import AutoencoderKL
from transformers import AutoTokenizer, CLIPImageProcessor
from mmdiff import MMDiffStableDiffusionXLPipeline
from datasets import PersonalizedDataset

In [None]:
# --- Runtime -----------------------------------------------------------------
# Target device; later cells pass it to `.to(...)`, `load_from_checkpoint` and
# `torch.Generator`.
device = "cuda"

# --- Model locations ---------------------------------------------------------
# Base SDXL model: source of both tokenizers and of the diffusion pipeline.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# VAE handed to the pipeline in place of the one bundled with the base model.
vae_path = "madebyollin/sdxl-vae-fp16-fix"
# Image encoder and MM-Diff checkpoint, both consumed by
# `pipeline.load_from_checkpoint`.
image_encoder_path = "openai/clip-vit-large-patch14"
mmdiff_ckpt = "checkpoints/portrait_generation"

In [None]:
# Preprocessor for the reference images, constructed with its default settings.
image_processor = CLIPImageProcessor()

# The two tokenizers are loaded identically and differ only in the subfolder
# of the base model they are read from.
_tokenizer_options = dict(use_fast=False, local_files_only=True)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_model_path, subfolder="tokenizer", **_tokenizer_options
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_model_path, subfolder="tokenizer_2", **_tokenizer_options
)

# Stand-alone VAE in half precision; it replaces the base model's VAE below.
vae = AutoencoderKL.from_pretrained(
    vae_path, subfolder=None, torch_dtype=torch.float16, local_files_only=True
)

# Build the fp16 pipeline around the components loaded above, then move it to
# the target device.
# NOTE: every loader here passes local_files_only=True, so the weights are
# expected to be available locally already (no download is attempted).
pipeline = MMDiffStableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    variant="fp16",
    torch_dtype=torch.float16,
    local_files_only=True,
)
pipeline = pipeline.to(device)

# Attach the MM-Diff specific weights to the pipeline.
# NOTE(review): presumably this loads the image encoder and the checkpoint and
# merges LoRA weights when fuse_lora=True -- confirm against the mmdiff module.
pipeline.load_from_checkpoint(image_encoder_path, mmdiff_ckpt, device, fuse_lora=True)

### Customized generation with a single reference image

In [None]:
# Folder holding the reference image(s) of the subject.
images_ref_root = "demo_data/Barack Obama"

# Gather every .jpg/.png/.jpeg file in the folder; the first path in sorted
# order is the one previewed below.
image_ref_candidates = sorted(
    path
    for extension in ("jpg", "png", "jpeg")
    for path in glob(os.path.join(images_ref_root, f"*.{extension}"))
)
if not image_ref_candidates:
    # Fail with an actionable message instead of a bare IndexError on `[0]`.
    raise FileNotFoundError(
        f"No .jpg/.png/.jpeg reference images found in {images_ref_root!r}"
    )
image_ref_path = image_ref_candidates[0]

demo_image = Image.open(image_ref_path)
w, h = demo_image.size

# Preview only: the last expression is what the notebook displays. The image is
# scaled to a height of 512 px with the aspect ratio preserved; `demo_image`
# itself is left untouched because `resize` returns a new image.
# Image.BICUBIC is the named form of the resampling filter code 3.
demo_image.resize((int(512 * w / h), 512), resample=Image.BICUBIC)

In [None]:
# --- Prompts -------------------------------------------------------------------
prompt = "a man<|subj|> in front of the White House"      # man<|subj|>: class + trigger token
negative_prompt = "longbody, lowres, bad anatomy, bad teeth, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"

# --- Sampling settings -----------------------------------------------------------
num_inference_images = 8
num_inference_steps = 25
start_merge_step = 0
fuse_scale = 0.8    # weight of image condition

# One seeded generator is shared by all pipeline calls below, so the samples
# differ from each other while the whole run stays repeatable.
generator = torch.Generator(device=device).manual_seed(23)

# --- Reference data --------------------------------------------------------------
# NOTE(review): prepare_data() presumably returns the preprocessed reference
# image(s) in the form the pipeline expects -- confirm in the datasets module.
inference_dataset = PersonalizedDataset(
    prompt=prompt,
    images_ref_root=images_ref_root,
    tokenizer_one=tokenizer_one,
    tokenizer_two=tokenizer_two,
    image_processor=image_processor,
    max_num_objects=1,
)
images_ref = inference_dataset.prepare_data()

# --- Sampling --------------------------------------------------------------------
# The arguments are the same for every sample, so they are collected once.
pipeline_kwargs = dict(
    prompt=prompt,
    images_ref=images_ref,
    negative_prompt=negative_prompt,
    num_inference_steps=num_inference_steps,
    start_merge_step=start_merge_step,
    fuse_scale=fuse_scale,
    guidance_scale=5.0,
    generator=generator,
    height=512,
    width=512,
)

samples_tensor = []
for _ in range(num_inference_images):
    image = pipeline(**pipeline_kwargs).images[0]
    # Move the last axis to the front (presumably H, W, C -> C, H, W) so the
    # samples can be stacked and passed to make_grid.
    image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    samples_tensor.append(image_tensor)

# --- Contact sheet ---------------------------------------------------------------
# Tile the samples four per row, move the first axis back to the end, and let
# the resulting PIL image be the cell's displayed output.
grid = make_grid(torch.stack(samples_tensor, 0), nrow=4)
grid = grid.permute(1, 2, 0).numpy()
Image.fromarray(grid.astype(np.uint8))