In [1]:
import torch
import math
import matplotlib.pyplot as plt

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms


# Precision used when loading the pipeline weights.
DTYPE = torch.bfloat16

# Local checkpoint locations: the base FLUX model and the LoRA weights applied on top of it.
BASE_MODEL_PATH = "/root/autodl-tmp/FLUX-dev"
LORA_WEIGHTS_PATH = "/root/autodl-tmp/lora_ckpt/reflow-dev-various/checkpoint-4500/pytorch_lora_weights.safetensors"

pipe = FluxPipeline.from_pretrained(BASE_MODEL_PATH, torch_dtype=DTYPE)
pipe.load_lora_weights(LORA_WEIGHTS_PATH)
# Last expression of the cell: moves the pipeline to the GPU and displays its repr.
pipe.to("cuda")

  warn(


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


FluxPipeline {
  "_class_name": "FluxPipeline",
  "_diffusers_version": "0.31.0.dev0",
  "_name_or_path": "/root/autodl-tmp/FLUX-dev",
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "text_encoder_2": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "FluxTransformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [2]:
@torch.inference_mode()
def decode_imgs(latents, pipeline):
    """Turn normalised VAE latents back into images.

    Args:
        latents: Latent tensor in the normalised space, i.e. after
            ``(x - shift_factor) * scaling_factor`` has been applied.
        pipeline: Object exposing ``vae`` (with ``config.scaling_factor``,
            ``config.shift_factor`` and ``decode``) and
            ``image_processor.postprocess``.

    Returns:
        The result of ``pipeline.image_processor.postprocess`` called with
        ``output_type="pil"`` on the decoded sample.
    """
    vae = pipeline.vae
    # Invert the latent normalisation: divide by the scale, then add the shift.
    raw_latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    # The first element of the decoder output is the decoded sample.
    decoded = vae.decode(raw_latents)[0]
    return pipeline.image_processor.postprocess(decoded, output_type="pil")

@torch.inference_mode()
def encode_imgs(imgs, pipeline):
    """Encode images into normalised bfloat16 VAE latents.

    Args:
        imgs: Image tensor accepted by ``pipeline.vae.encode``.
        pipeline: Object exposing ``vae`` with ``encode``,
            ``config.shift_factor`` and ``config.scaling_factor``.

    Returns:
        A sample drawn from the encoder's latent distribution, shifted and
        scaled with the VAE config values and cast to ``torch.bfloat16``.
    """
    vae = pipeline.vae
    # Sampling (rather than taking the mode) makes the result stochastic.
    sampled = vae.encode(imgs).latent_dist.sample()
    normalised = (sampled - vae.config.shift_factor) * vae.config.scaling_factor
    return normalised.to(dtype=torch.bfloat16)

def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded standard-normal noise shaped like a batch of 16-channel latents.

    Args:
        num_samples: Batch size.
        height: Image height in pixels.
        width: Image width in pixels.
        device: Device on which both the generator and the tensor are created.
        dtype: Dtype of the returned tensor.
        seed: Seed for the dedicated ``torch.Generator``.

    Returns:
        Tensor of shape ``[num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16)]``
        (equal to ``[B, 16, H // 8, W // 8]`` when H and W are multiples of 16).
    """
    # A private generator keeps the draw reproducible without touching global RNG state.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    # Rounding up to a multiple of 2 keeps both spatial dims even.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    shape = (num_samples, 16, latent_h, latent_w)
    return torch.randn(shape, generator=rng, device=device, dtype=dtype)

In [9]:
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` with the shift ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``.

    Args:
        mu: Shift strength; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``(1 / t - 1)``.
        t: Timesteps to warp.

    Returns:
        The warped timesteps, same type/shape as ``t``.
    """
    exp_mu = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + odds)

def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Returns:
        A one-argument callable mapping ``x`` to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the sampling timestep schedule running from 1 down to 0.

    Args:
        num_steps: Number of integration steps; the schedule has ``num_steps + 1`` entries.
        image_seq_len: Sequence length used to pick the shift ``mu``.
        base_shift: Value of ``mu`` at the lower anchor of the linear map.
        max_shift: Value of ``mu`` at the upper anchor of the linear map.
        shift: When true, warp the uniform grid with ``time_shift``.

    Returns:
        The timesteps as a plain list of floats.
    """
    uniform = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform.tolist()

    # mu is linearly interpolated from the sequence length between the two anchors.
    mu_of_seq_len = get_lin_function(y1=base_shift, y2=max_shift)
    mu = mu_of_seq_len(image_seq_len)
    return time_shift(mu, 1.0, uniform).tolist()

# Example: a 4-step schedule (5 boundary values) for a 1024x1024 image.
timesteps = get_schedule(
    num_steps=4,
    # vae_scale_factor = 16, so the sequence length is (1024 // 16) squared
    image_seq_len=(1024 // 16) ** 2,
    # Set True for Flux-dev, False for Flux-schnell
    shift=True,
)

print(timesteps)

[1.0, 0.9045307636260986, 0.7595109343528748, 0.5128440856933594, 0.0]


In [13]:
@torch.inference_mode()
def forward_denoise(pipeline, num_steps, prompt, resolution=1024, guidance_scale=3.5, seed=0):
    """Integrate seeded noise to image latents with explicit Euler steps.

    Starting from Gaussian noise at t=1, the transformer's prediction is used
    as a velocity and the latents are stepped along the shifted schedule down
    to t=0.

    Args:
        pipeline: Loaded Flux pipeline; ``encode_prompt``, ``transformer``,
            ``progress_bar`` and ``vae_scale_factor`` are used.
        num_steps: Number of Euler steps.
        prompt: Text prompt, passed to both text encoders.
        resolution: Side length in pixels of the square output image. The
            schedule and the final unpacking now follow this value instead of
            being fixed to 1024. Assumed to be a multiple of 16 — TODO confirm
            behaviour of ``_unpack_latents`` for other values.
        guidance_scale: Guidance value fed to the transformer for every sample.
        seed: Seed for the initial noise.

    Returns:
        Unpacked image latents as returned by ``FluxPipeline._unpack_latents``
        (``[1, 16, resolution // 8, resolution // 8]`` for the 1024 case used
        in this notebook).

    NOTE(review): the private ``FluxPipeline`` helpers are called with the
    argument conventions of the diffusers version shown in the pipeline repr
    (0.31.0.dev0) — confirm before upgrading diffusers.
    """
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt=prompt, prompt_2=prompt)

    noise = get_noise(  # shape [num_samples, 16, resolution // 8, resolution // 8]
        num_samples=1,
        height=resolution,
        width=resolution,
        device="cuda",
        dtype=torch.bfloat16,
        seed=seed,
    )

    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        noise.shape[0],
        noise.shape[2],
        noise.shape[3],
        noise.device,
        torch.bfloat16,
    )

    packed_latents = FluxPipeline._pack_latents(  # shape [num_samples, seq_len, 16 * 2 * 2]
        noise,
        batch_size=noise.shape[0],
        num_channels_latents=noise.shape[1],
        height=noise.shape[2],
        width=noise.shape[3],
    )

    # Derive the schedule from the actual packed sequence length so the shift
    # tracks `resolution` (4096 tokens for 1024x1024, as before).
    timesteps = get_schedule(  # num_steps + 1 values from 1.0 down to 0.0
        num_steps=num_steps,
        image_seq_len=packed_latents.shape[1],
        shift=True,  # Set True for Flux-dev, False for Flux-schnell
    )

    # Denoising loop in latent space: Euler steps from t=1 (noise) to t=0 (image).
    with pipeline.progress_bar(total=len(timesteps) - 1) as progress_bar:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, device=packed_latents.device)
            guidance_vec = torch.full((packed_latents.shape[0],), guidance_scale, device=packed_latents.device, dtype=packed_latents.dtype)
            print(f"time step: {t_vec[0]}")
            pred = pipeline.transformer(
                hidden_states=packed_latents,  # [batch_size, seq_len, 64], e.g. [1, 4096, 64] for 1024x1024
                timestep=t_vec,                # range: [0, 1]
                guidance=guidance_vec,         # one guidance value per sample in the batch
                pooled_projections=pooled_prompt_embeds,  # CLIP text embedding
                encoder_hidden_states=prompt_embeds,      # T5 text embedding
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=None,
                # Was `return_dict=pipeline` (an object used as a flag); request
                # a plain tuple explicitly so `[0]` is the prediction.
                return_dict=False,
            )[0]
            # (t_prev - t_curr) is negative: step along the predicted velocity toward t=0.
            packed_latents = packed_latents + (t_prev - t_curr) * pred
            progress_bar.update()

    img_latents = FluxPipeline._unpack_latents(
        packed_latents,
        height=resolution,
        width=resolution,
        vae_scale_factor=pipeline.vae_scale_factor,
    )
    return img_latents

# prompt = 'A water color painting of a lone hiker standing triumphantly on the summit of a snow-capped mountain. The hiker, wearing a red jacket and backpack, is silhouetted against a clear blue sky. The word "ASCEND" is painted in bold, blue letters at the base of the mountain, adding a powerful message to the image. The overall impression is one of determination, achievement, and the beauty of nature. Dramatic high quality'

# prompt = 'A high resolution photo of Einstein, white background, photo-realistic, high-detail'

# prompt = 'A vibrant, starry night sky illuminates a lively street café, with warm golden lights spilling from its windows. The café is nestled on a narrow cobblestone street, surrounded by rustic buildings with swirling, textured brushstrokes. Bold, dynamic colors—deep blues and glowing yellows—fill the scene. People are seated at small round tables, sipping coffee, and chatting. The atmosphere is cozy and inviting, yet full of movement and energy, capturing the timeless essence of a Van Gogh painting.'

prompt = 'Jewelry design, a ring with bright rose-cut blue diamonds, surrounded by small lily-of-the-valley flower-shaped diamonds, golden stems form the ring of the ring. The center of the base is a beautiful rose gold, with a detachable black ring on both sides'

# Step counts to compare; one panel of the 2x3 grid per entry.
steps_list = [4, 8, 16, 24, 32, 40]

fig, axes = plt.subplots(2, 3, figsize=(15, 10), dpi=300)

for i, num_steps in enumerate(steps_list):
    # Row-major position i in the grid (same as axes[i // 3, i % 3]).
    ax = axes.flat[i]

    # Same prompt and seed for every run, so only the step count varies.
    img_latents = forward_denoise(
        pipe,
        num_steps=num_steps,
        prompt=prompt,
        resolution=1024,
        guidance_scale=3.5,
        seed=12,
    )
    out = decode_imgs(img_latents, pipe)[0]

    ax.imshow(out)
    ax.set_title(f"{num_steps} Steps")
    ax.axis("off")

plt.tight_layout()
plt.show()

The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/4 [00:00<?, ?it/s]

time step: 1.0
time step: 0.90625
time step: 0.7578125
time step: 0.51171875


The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/8 [00:00<?, ?it/s]

time step: 1.0
time step: 0.95703125
time step: 0.90625
time step: 0.83984375
time step: 0.7578125
time step: 0.65625
time step: 0.51171875
time step: 0.310546875


The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/16 [00:00<?, ?it/s]

time step: 1.0
time step: 0.98046875
time step: 0.95703125
time step: 0.93359375
time step: 0.90625
time step: 0.875
time step: 0.83984375
time step: 0.80078125
time step: 0.7578125
time step: 0.7109375
time step: 0.65625
time step: 0.58984375
time step: 0.51171875
time step: 0.421875
time step: 0.310546875
time step: 0.173828125


The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/24 [00:00<?, ?it/s]

time step: 1.0
time step: 0.98828125
time step: 0.97265625
time step: 0.95703125
time step: 0.94140625
time step: 0.921875
time step: 0.90625
time step: 0.8828125
time step: 0.86328125
time step: 0.83984375
time step: 0.81640625
time step: 0.7890625
time step: 0.7578125
time step: 0.7265625
time step: 0.69140625
time step: 0.65625
time step: 0.61328125
time step: 0.56640625
time step: 0.51171875
time step: 0.453125
time step: 0.38671875
time step: 0.310546875
time step: 0.22265625
time step: 0.12060546875


The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/32 [00:00<?, ?it/s]

time step: 1.0
time step: 0.98828125
time step: 0.98046875
time step: 0.96875
time step: 0.95703125
time step: 0.9453125
time step: 0.93359375
time step: 0.91796875
time step: 0.90625
time step: 0.890625
time step: 0.875
time step: 0.859375
time step: 0.83984375
time step: 0.8203125
time step: 0.80078125
time step: 0.78125
time step: 0.7578125
time step: 0.734375
time step: 0.7109375
time step: 0.68359375
time step: 0.65625
time step: 0.625
time step: 0.58984375
time step: 0.5546875
time step: 0.51171875
time step: 0.46875
time step: 0.421875
time step: 0.369140625
time step: 0.310546875
time step: 0.24609375
time step: 0.173828125
time step: 0.09228515625


The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['determination, achievement, and the beauty of nature. dramatic high quality']


  0%|          | 0/40 [00:00<?, ?it/s]

time step: 1.0
time step: 0.9921875
time step: 0.984375
time step: 0.9765625
time step: 0.96484375
time step: 0.95703125
time step: 0.9453125
time step: 0.9375
time step: 0.92578125
time step: 0.9140625
time step: 0.90625
time step: 0.89453125
time step: 0.87890625
time step: 0.8671875
time step: 0.85546875
time step: 0.83984375
time step: 0.82421875
time step: 0.80859375
time step: 0.79296875
time step: 0.77734375
time step: 0.7578125
time step: 0.7421875
time step: 0.72265625
time step: 0.69921875
time step: 0.6796875
time step: 0.65625
time step: 0.62890625
time step: 0.6015625
time step: 0.57421875
time step: 0.546875
time step: 0.51171875
time step: 0.478515625
time step: 0.44140625
time step: 0.400390625
time step: 0.357421875
time step: 0.310546875
time step: 0.259765625
time step: 0.2041015625
time step: 0.142578125
time step: 0.07470703125
