In [None]:
import os
import torch
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 

device = 'cuda'

# Directory holding the Flux checkpoints. Set FLUX_MODEL_DIR to point elsewhere
# instead of editing this cell; the default is the original local location.
model_dir = os.environ.get(
    'FLUX_MODEL_DIR', '/home/jeeves/JJ_Projects/klara_models/flux_dev/train-model'
)
ae_path = os.path.join(model_dir, 'ae.safetensors')
transformer_path = os.path.join(model_dir, 'flux1-dev.safetensors')

# NOTE(review): flux.util presumably reads FLUX_DEV / AE when it is imported in
# the next cell (its loaders are called with hf_download=False), so these must be
# set before that import — confirm against the flux.util version in use.
os.environ['FLUX_DEV'] = transformer_path
os.environ['AE'] = ae_path

In [None]:
from flux.util import load_ae, load_clip, load_flow_model, load_t5
from PIL import Image
import numpy as np

# Where the IP-adapter's image encoder and weights come from.
IMAGE_ENCODER_PATH = 'openai/clip-vit-large-patch14'
IP_ADAPTER_PATH = '/home/jeeves/JJ_Projects/github/ComfyUI/models/xlabs/ipadapters/ip_adapter.safetensors'


def _show_attn_processors(banner, processors):
    """Print `banner`, then one `name processor` line per entry of `processors`."""
    print(banner)
    for proc_name, proc in processors.items():
        print(proc_name, proc)


# Flux transformer first, so its attention processors can be inspected
# before and after the IP-adapter is attached.
transformer = load_flow_model(name='flux-dev', device=device, hf_download=False)
_show_attn_processors(
    '**** flux attn processor before load ip-adapter *********',
    transformer.attn_processors,
)

transformer.load_ip_adapter(image_encoder_path=IMAGE_ENCODER_PATH, ip_model_path=IP_ADAPTER_PATH)

attn_procs = transformer.attn_processors
_show_attn_processors(
    '**** flux attn processor after load ip-adapter *********',
    attn_procs,
)

# Autoencoder plus the CLIP and T5 encoders used by `prepare` in the sampling cell.
ae = load_ae(name='flux-dev', device=device, hf_download=False)
clip = load_clip(device=device)
t5 = load_t5(device=device, max_length=512)

class ImageProjector:
    """Turn a reference image into CLIP image embeddings for the IP-adapter.

    Holds a ``CLIPVisionModelWithProjection`` and a ``CLIPImageProcessor``.
    Calling the instance preprocesses the image, runs the vision encoder with
    gradient tracking disabled, and returns the encoder's ``image_embeds``
    cast to ``output_dtype`` on ``device``.

    Args:
        image_encoder_path: Hugging Face id or local path of the CLIP vision model.
        device: Device the encoder is moved to and the embeddings are returned on.
        encoder_dtype: dtype the encoder weights are converted to.
        output_dtype: dtype of the returned embeddings.
    """

    def __init__(
        self,
        image_encoder_path: str = 'openai/clip-vit-large-patch14',
        device: "str | torch.device" = 'cuda',
        encoder_dtype: torch.dtype = torch.float16,
        output_dtype: torch.dtype = torch.bfloat16,
    ):
        self.device = device
        self.output_dtype = output_dtype
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(
            device, dtype=encoder_dtype
        )
        self.clip_image_processor = CLIPImageProcessor()

    def __call__(self, image_prompt: "Image.Image | np.ndarray") -> torch.Tensor:
        """Encode `image_prompt` and return its image embeddings.

        Args:
            image_prompt: PIL image or numpy array accepted by the CLIP processor.

        Returns:
            The encoder's ``image_embeds`` on ``self.device`` in ``self.output_dtype``.
        """
        pixel_values = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(self.image_encoder.device)

        # The encoder is only used for inference. Without no_grad the returned
        # tensor would carry an autograd graph and keep the encoder's saved
        # activations alive for as long as the embedding is referenced.
        with torch.no_grad():
            image_embeds = self.image_encoder(pixel_values).image_embeds
            return image_embeds.to(device=self.device, dtype=self.output_dtype)

# Instantiate the CLIP image encoder once; it is called on the reference image in the sampling cell.
img_embbeder = ImageProjector()

# Hand cached-but-unused CUDA allocator memory (e.g. left over from loading the
# models above) back to the driver. Tensors that are still referenced are not freed.
torch.cuda.empty_cache()


In [None]:
from PIL import Image
import einops
# Now we implement the sampling function below ourselves, step by step.
from flux.sampling import get_noise, prepare, denoise, get_schedule, unpack
from songmisc.utils import plot_multi_images


# Sampling configuration.
width, height = 1024, 1024
device = torch.device('cuda')
dtype = torch.bfloat16
seed = 42
# prompt = "a handsome man wearing suit, running on the street"
# prompt = "a handsome man wearing suit, standing in the classroom, hugging with a woman with white dress"
# prompt = 'anime style, a man wearing suit, running on the street'
prompt = 'anime style, (a man wearing black white interleaved suit and glasses) standing on the street, hugging with (a woman with ponytail hair, wearing white t-shirt and jeans), looking at each other'

steps = 50        # number of schedule steps passed to get_schedule
guidance = 3.5    # value fed to the transformer's `guidance` input
ip_scale = 0.5    # IP-adapter scale passed to the transformer as `ip_scale`

path = 'imgs/qye.png'
ref_image = Image.open(path)

with torch.inference_mode():
    # Encode the reference image inside inference mode as well. Computed outside
    # it, image_emb would carry an autograd graph (and the encoder activations it
    # saves) for as long as the tensor stays alive.
    image_emb = img_embbeder(ref_image)
    print('image emb', image_emb.shape)

    # Initial latent noise.
    x_T = get_noise(num_samples=1, height=height, width=width, device=device, dtype=dtype, seed=seed)

    # Build the model inputs: img, img_ids, txt, txt_ids, vec.
    inputs_dict = prepare(t5, clip, x_T, prompt)
    timesteps = get_schedule(steps, inputs_dict["img"].shape[1], shift=True)

    # Denoising loop: one explicit Euler step from t_curr to t_prev per iteration.
    latent = inputs_dict['img']
    guidance_vec = torch.full((latent.shape[0],), guidance, device=device, dtype=latent.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((latent.shape[0],), t_curr, device=device, dtype=latent.dtype)
        pred = transformer(
            img=latent, img_ids=inputs_dict['img_ids'],
            txt=inputs_dict['txt'], txt_ids=inputs_dict['txt_ids'],
            y=inputs_dict['vec'],
            timesteps=t_vec,
            guidance=guidance_vec,
            image_embeddings=image_emb,
            ip_scale=ip_scale
        )

        latent = latent + (t_prev - t_curr) * pred
    x = unpack(latent.float(), height, width)

    # Decode the latent back to pixel space.
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        x = ae.decode(x)

    # Map the first sample from [-1, 1] to [0, 255] and wrap it as a PIL image.
    x = x.clamp(-1, 1)
    x = einops.rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

plot_multi_images([ref_image, img])
