In [None]:
import torch
import math
import matplotlib.pyplot as plt

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms


# Checkpoint locations (machine-specific absolute paths).
BASE_MODEL_PATH = "/root/autodl-tmp/data/FLUX-dev"
LORA_WEIGHTS_PATH = "/root/autodl-tmp/flux-lora-dreambooth"

# Precision requested when loading the pipeline weights.
DTYPE = torch.bfloat16

# Load the base Flux pipeline, attach the LoRA weights, then move everything to the GPU.
pipe = FluxPipeline.from_pretrained(BASE_MODEL_PATH, torch_dtype=DTYPE)
pipe.load_lora_weights(LORA_WEIGHTS_PATH)
pipe.to("cuda")

In [None]:
@torch.inference_mode()
def decode_imgs(latents, pipeline):
    """Turn VAE latents back into images.

    The latents are first de-normalised (divided by the VAE's
    ``scaling_factor`` and offset by its ``shift_factor``), then passed
    through ``pipeline.vae.decode`` and finally through the pipeline's image
    post-processor with ``output_type="pil"``.

    Args:
        latents: Normalised latents to decode.
        pipeline: Object exposing ``vae`` (with ``config`` and ``decode``)
            and ``image_processor`` (with ``postprocess``).

    Returns:
        Whatever ``image_processor.postprocess`` returns for
        ``output_type="pil"`` (presumably a list of PIL images; callers
        index it with ``[0]``).
    """
    vae = pipeline.vae
    vae_cfg = vae.config
    # Inverse of the normalisation applied in ``encode_imgs``.
    raw_latents = (latents / vae_cfg.scaling_factor) + vae_cfg.shift_factor
    # ``decode`` returns an indexable result; element 0 is the decoded sample.
    decoded = vae.decode(raw_latents)[0]
    return pipeline.image_processor.postprocess(decoded, output_type="pil")

@torch.inference_mode()
def encode_imgs(imgs, pipeline):
    """Encode images into normalised VAE latents.

    A latent is sampled from the distribution returned by
    ``pipeline.vae.encode`` (so the result is stochastic), then normalised by
    subtracting the VAE's ``shift_factor`` and multiplying by its
    ``scaling_factor``.

    Args:
        imgs: Image batch accepted by ``pipeline.vae.encode``.
        pipeline: Object exposing ``vae`` with ``config`` and ``encode``.

    Returns:
        The normalised latents cast to ``torch.bfloat16``.
    """
    vae = pipeline.vae
    vae_cfg = vae.config
    sampled = vae.encode(imgs).latent_dist.sample()
    # Inverse of the de-normalisation applied in ``decode_imgs``.
    normalised = (sampled - vae_cfg.shift_factor) * vae_cfg.scaling_factor
    return normalised.to(dtype=torch.bfloat16)

def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded Gaussian noise shaped like a batch of 16-channel latents.

    Args:
        num_samples: Batch size.
        height: Image height in pixels.
        width: Image width in pixels.
        device: Device on which both the generator and the tensor live.
        dtype: Data type of the returned tensor.
        seed: Seed for a fresh ``torch.Generator``; the same seed, device and
            shape reproduce the same noise.

    Returns:
        Tensor of shape ``[num_samples, 16, 2 * ceil(height / 16),
        2 * ceil(width / 16)]`` (i.e. ``H // 8`` x ``W // 8`` when the sizes
        are multiples of 16).
    """
    # Spatial sizes are rounded up to an even number of latent pixels.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    rng = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        device=device,
        dtype=dtype,
        generator=rng,
    )

In [None]:
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` with ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift strength; ``mu = 0`` with ``sigma = 1`` leaves ``t`` unchanged.
        sigma: Exponent applied to ``1 / t - 1``.
        t: Timesteps to warp.

    Returns:
        The warped timesteps, same shape as ``t``.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)

def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    """Return the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Args:
        x1: x-coordinate of the first anchor point.
        y1: y-coordinate of the first anchor point.
        x2: x-coordinate of the second anchor point (must differ from ``x1``).
        y2: y-coordinate of the second anchor point.

    Returns:
        A one-argument callable evaluating ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the descending timestep schedule used by the sampler.

    Args:
        num_steps: Number of integration steps; the schedule has
            ``num_steps + 1`` entries so consecutive pairs form the steps.
        image_seq_len: Sequence length used to pick the shift strength
            (only read when ``shift`` is true).
        base_shift: Shift strength at the default lower anchor of
            ``get_lin_function`` (sequence length 256).
        max_shift: Shift strength at the default upper anchor of
            ``get_lin_function`` (sequence length 4096).
        shift: When true, warp the uniform grid with ``time_shift``.

    Returns:
        ``num_steps + 1`` floats going from 1.0 down to 0.0.
    """
    uniform_grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform_grid.tolist()

    # Shift strength is interpolated linearly in the sequence length.
    mu_of_seq_len = get_lin_function(y1=base_shift, y2=max_shift)
    mu = mu_of_seq_len(image_seq_len)
    return time_shift(mu, 1.0, uniform_grid).tolist()

# The checkpoint loaded into `pipe` is FLUX-dev, so the resolution-dependent
# time shift is enabled (shift=True for Flux-dev, shift=False for Flux-schnell).
timesteps = get_schedule(  # list of num_steps + 1 floats running from 1.0 down to 0.0
    num_steps=100,
    image_seq_len=(1024 // 16) * (1024 // 16),  # one token per 16x16-pixel patch of a 1024x1024 image
    shift=True,
)

print(timesteps)

In [None]:
@torch.inference_mode()
def forward_denoise(pipeline, timesteps, prompt, resolution=1024, guidance_scale=3.5, seed=0):
    """Sample image latents from noise with a plain Euler loop over ``timesteps``.

    Starting from seeded Gaussian noise, each step queries the Flux
    transformer for a prediction at ``t_curr`` and advances the packed
    latents by ``(t_prev - t_curr) * pred``.

    Args:
        pipeline: A loaded ``FluxPipeline`` (text encoders, transformer, VAE).
        timesteps: Descending sequence of floats in ``[0, 1]`` (e.g. from
            ``get_schedule``); consecutive pairs define the integration steps.
        prompt: Text prompt, fed to both text encoders.
        resolution: Side length in pixels of the square image to generate.
        guidance_scale: Guidance value passed to the transformer.
        seed: Seed for the initial noise.

    Returns:
        Unpacked image latents (``[1, 16, resolution // 8, resolution // 8]``
        for resolutions that are multiples of 16), ready for ``decode_imgs``.
    """
    # Encode the prompt. The embedding arguments are outputs of this call, so
    # they must not be passed in (doing so referenced unassigned locals).
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=prompt)

    noise = get_noise( # save, shape [num_samples, 16, resolution // 8, resolution // 8]
        num_samples=1,
        height=resolution,
        width=resolution,
        device="cuda",
        dtype=torch.bfloat16,
        seed=seed,
    )

    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        noise.shape[0],
        noise.shape[2],
        noise.shape[3],
        noise.device,
        torch.bfloat16,
    )

    packed_latents = FluxPipeline._pack_latents( # shape [num_samples, (resolution // 16 * resolution // 16), 16 * 2 * 2]
        noise,
        batch_size=noise.shape[0],
        num_channels_latents=noise.shape[1],
        height=noise.shape[2],
        width=noise.shape[3],
    )

    batch_size = packed_latents.shape[0]

    # The guidance vector is the same at every step, so build it once.
    # Mirrors FluxPipeline.__call__: only guidance-distilled transformers
    # (config.guidance_embeds, e.g. Flux-dev) take a guidance input. If the
    # flag is absent, keep sending guidance as before.
    if getattr(pipeline.transformer.config, "guidance_embeds", True):
        guidance_vec = torch.full((batch_size,), guidance_scale, device=packed_latents.device, dtype=packed_latents.dtype)
    else:
        guidance_vec = None

    # Denoising loop in latent space: Euler steps from t=1 (noise) towards t=0 (image)
    with pipeline.progress_bar(total=len(timesteps)-1) as progress_bar:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((batch_size,), t_curr, dtype=packed_latents.dtype, device=packed_latents.device)
            print(f"time step: {t_vec[0]}")
            pred = pipeline.transformer(
                    hidden_states=packed_latents, # shape: [batch_size, seq_len, num_channels_latents], e.g. [1, 4096, 64] for 1024x1024
                    timestep=t_vec,        # range: [0, 1]
                    guidance=guidance_vec, # scalar guidance values for each sample in the batch
                    pooled_projections=pooled_prompt_embeds, # CLIP text embedding
                    encoder_hidden_states=prompt_embeds,     # T5 text embedding
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=None,
                    return_dict=False,  # plain tuple; element 0 is the prediction
                )[0]
            # Euler update; t_prev < t_curr, so the step size is negative.
            packed_latents = packed_latents + (t_prev - t_curr) * pred
            progress_bar.update()

    # Unpack with the requested resolution rather than a hard-coded 1024.
    img_latents = FluxPipeline._unpack_latents( # save, shape [num_samples, 16, resolution//8, resolution//8]
            packed_latents,
            height=resolution,
            width=resolution,
            vae_scale_factor=pipeline.vae_scale_factor,
    )
    return img_latents

# Sample latents for the prompt, then decode them to an image.
img_latents = forward_denoise(
    pipe,
    timesteps,
    "a photo of sks dog running on the beach",
    resolution=1024,
    guidance_scale=3.5,
    seed=0,
)

# decode_imgs returns an indexable collection; a single sample was generated.
out = decode_imgs(img_latents, pipe)[0]

# Show the result on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)
ax.imshow(out)