In [1]:
import subprocess
import os

# Source /etc/network_turbo inside a bash subshell and capture every
# environment line that mentions "proxy".
result = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
output = result.stdout

# Copy each captured NAME=value pair into this kernel's environment so that
# later cells inherit the same settings.
for line in output.splitlines():
    if '=' not in line:
        continue
    var, value = line.split('=', 1)
    os.environ[var] = value

In [None]:
import torch
import math
import matplotlib.pyplot as plt

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms


# Precision used when loading the pipeline weights.
DTYPE = torch.bfloat16
MODEL_ID = "black-forest-labs/FLUX.1-dev"

pipe = FluxPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
)

# Optional LoRA experiments, kept disabled:
# pipe.load_lora_weights("/root/autodl-tmp/data/2rf-full-lora.safetensors")
# pipe.load_lora_weights("/root/autodl-tmp/lora_ckpt/reflow-dev-various/checkpoint-4500/pytorch_lora_weights.safetensors")
# pipe.load_lora_weights("/root/autodl-tmp/data/Flux_Aquarell_Watercolor_v2.safetensors", adapter_name="water")
# pipe.set_adapters(["accelerate", "water"], adapter_weights=[0.8, 0.8])
# pipe.fuse_lora(adapter_names=["accelerate", "water"], lora_scale=1.0)

# Move every pipeline component to the GPU.
pipe.to("cuda")

In [3]:
@torch.inference_mode()
def decode_imgs(latents, pipeline):
    """Decode VAE latents into images.

    The latents are first mapped back to the VAE's own range (divide by
    `scaling_factor`, add `shift_factor`), then decoded, then handed to the
    pipeline's image processor with `output_type="pil"`.

    Args:
        latents: Latent tensor in the scaled/shifted space produced by
            `encode_imgs`.
        pipeline: Object exposing `vae` (with `config` and `decode`) and
            `image_processor.postprocess`.

    Returns:
        Whatever `pipeline.image_processor.postprocess` returns for
        `output_type="pil"`.
    """
    vae = pipeline.vae
    raw_latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    decoded = vae.decode(raw_latents)[0]
    return pipeline.image_processor.postprocess(decoded, output_type="pil")

@torch.inference_mode()
def encode_imgs(imgs, pipeline):
    """Encode images into scaled, shifted VAE latents in bfloat16.

    Samples from the VAE posterior, subtracts `shift_factor`, multiplies by
    `scaling_factor` (the inverse of the mapping undone in `decode_imgs`),
    and casts the result to `torch.bfloat16`.

    Args:
        imgs: Image tensor accepted by `pipeline.vae.encode`.
        pipeline: Object exposing `vae` with `encode` and `config`.

    Returns:
        The normalized latent tensor with dtype `torch.bfloat16`.
    """
    vae = pipeline.vae
    sampled = vae.encode(imgs).latent_dist.sample()
    normalized = (sampled - vae.config.shift_factor) * vae.config.scaling_factor
    return normalized.to(dtype=torch.bfloat16)

def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded Gaussian noise shaped like a batch of 16-channel latents.

    Each spatial side is `2 * ceil(pixels / 16)`, so it is always even.

    Args:
        num_samples: Batch size.
        height: Image height in pixels.
        width: Image width in pixels.
        device: Device for both the generator and the returned tensor.
        dtype: Dtype of the returned tensor.
        seed: Seed for the dedicated generator; equal seeds give equal noise.

    Returns:
        Tensor of shape `[num_samples, 16, 2*ceil(height/16), 2*ceil(width/16)]`.
    """
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # A private generator keeps the draw independent of the global RNG state.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    return torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        generator=rng,
        dtype=dtype,
        device=device,
    )

In [None]:
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps `t` by `exp(mu) / (exp(mu) + (1/t - 1) ** sigma)`.

    Args:
        mu: Shift strength; enters the formula only through `exp(mu)`.
        sigma: Exponent applied to `(1/t - 1)`.
        t: Timesteps to warp.

    Returns:
        The warped timesteps, same shape as `t`.
    """
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)

def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    """Build the straight line through `(x1, y1)` and `(x2, y2)`.

    Returns:
        A one-argument callable mapping `x` to `slope * x + intercept`.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return `num_steps + 1` timesteps descending from 1 to 0.

    Args:
        num_steps: Number of sampler steps; the result has one more entry.
        image_seq_len: Sequence length used to pick the shift strength.
        base_shift: Shift strength at the lower anchor of the linear map.
        max_shift: Shift strength at the upper anchor of the linear map.
        shift: When False, the uniform grid is returned unchanged. When True,
            it is warped by `time_shift` with `mu` interpolated linearly
            from `image_seq_len`.

    Returns:
        The schedule as a plain list of floats.
    """
    uniform = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform.tolist()

    # Longer sequences get a larger mu from the linear map.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, uniform).tolist()

# Preview the 4-step schedule for a 1024x1024 image.
# Each side is divided by 16 to get the token count (vae_scale_factor = 16).
tokens_per_side = 1024 // 16
timesteps = get_schedule(  # list with num_steps + 1 entries
    num_steps=4,
    image_seq_len=tokens_per_side * tokens_per_side,
    shift=True,  # Set True for Flux-dev, False for Flux-schnell
)

print(timesteps)

In [None]:
@torch.inference_mode()
def forward_denoise(pipeline, num_steps, prompt, resolution=512, guidance_scale=3.5, seed=0):
    """Sample image latents for `prompt` with a fixed-step Euler loop.

    Starts from seeded Gaussian noise and walks the shifted schedule from
    `get_schedule` (t = 1 down to t = 0), applying
    `latents += (t_prev - t_curr) * prediction` at each step.

    Args:
        pipeline: Loaded Flux pipeline; its `encode_prompt`, `transformer`,
            `progress_bar` and `vae_scale_factor` are used.
        num_steps: Number of Euler steps.
        prompt: Text prompt, passed as both `prompt` and `prompt_2`.
        resolution: Side length in pixels of the square image.
        guidance_scale: Guidance value passed to the transformer for every sample.
        seed: Seed for the initial noise.

    Returns:
        The unpacked latents after the last step (per the original notes:
        `[num_samples, 16, resolution // 8, resolution // 8]`), ready for
        `decode_imgs`.
    """
    timesteps = get_schedule(  # list with num_steps + 1 entries
        num_steps=num_steps,
        image_seq_len=(resolution // 16) * (resolution // 16),  # vae_scale_factor = 16
        shift=True,  # Set True for Flux-dev, False for Flux-schnell
    )

    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt=prompt, prompt_2=prompt)

    noise = get_noise(  # shape [num_samples, 16, resolution // 8, resolution // 8]
        num_samples=1,
        height=resolution,
        width=resolution,
        device="cuda",
        dtype=torch.bfloat16,
        seed=seed,
    )

    # NOTE(review): the height/width convention of `_prepare_latent_image_ids`
    # differs between diffusers releases (full latent size vs. already halved).
    # Confirm that passing the latent size matches the installed version.
    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        noise.shape[0],
        noise.shape[2],
        noise.shape[3],
        noise.device,
        torch.bfloat16,
    )

    packed_latents = FluxPipeline._pack_latents(  # shape [num_samples, (resolution // 16) ** 2, 16 * 2 * 2]
        noise,
        batch_size=noise.shape[0],
        num_channels_latents=noise.shape[1],
        height=noise.shape[2],
        width=noise.shape[3],
    )

    # Denoising loop in latent space: integrate from t = 1 (noise) towards t = 0.
    with pipeline.progress_bar(total=len(timesteps) - 1) as progress_bar:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, device=packed_latents.device)
            guidance_vec = torch.full((packed_latents.shape[0],), guidance_scale, device=packed_latents.device, dtype=packed_latents.dtype)
            print(f"time step: {t_vec[0]}")
            pred = pipeline.transformer(
                hidden_states=packed_latents,  # shape: [batch_size, seq_len, num_channels_latents], e.g. [1, 4096, 64] for 1024x1024
                timestep=t_vec,  # range: [0, 1]
                guidance=guidance_vec,  # scalar guidance values for each sample in the batch
                pooled_projections=pooled_prompt_embeds,  # CLIP text embedding
                encoder_hidden_states=prompt_embeds,  # T5 text embedding
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=None,
                # The original passed the pipeline object here by mistake.
                # False returns a plain tuple, so [0] is the prediction.
                return_dict=False,
            )[0]
            # Euler step; (t_prev - t_curr) is negative because the schedule descends.
            packed_latents = packed_latents + (t_prev - t_curr) * pred
            progress_bar.update()

    img_latents = FluxPipeline._unpack_latents(  # shape [num_samples, 16, resolution // 8, resolution // 8]
        packed_latents,
        height=resolution,
        width=resolution,
        vae_scale_factor=pipeline.vae_scale_factor,
    )
    return img_latents

# Prompt candidates for the step-count comparison below. Exactly one assignment
# is left uncommented and becomes the active `prompt`; swap the comment markers
# to try another.
# prompt = 'A water color painting of a lone hiker standing triumphantly on the summit of a snow-capped mountain. The hiker, wearing a red jacket and backpack, is silhouetted against a clear blue sky. The word "ASCEND" is painted in bold, blue letters at the base of the mountain, adding a powerful message to the image. The overall impression is one of determination, achievement, and the beauty of nature. Dramatic high quality'

# prompt = 'A high resolution photo of a scientist, white background, photo-realistic, high-detail'

# prompt = 'A vibrant, starry night sky illuminates a lively street café, with warm golden lights spilling from its windows. The café is nestled on a narrow cobblestone street, surrounded by rustic buildings with swirling, textured brushstrokes. Bold, dynamic colors—deep blues and glowing yellows—fill the scene. People are seated at small round tables, sipping coffee, and chatting. The atmosphere is cozy and inviting, yet full of movement and energy, capturing the timeless essence of a Van Gogh painting.'

# prompt = 'Jewelry design, a ring with bright rose-cut blue diamonds, surrounded by small lily-of-the-valley flower-shaped diamonds, golden stems form the ring of the ring. The center of the base is a beautiful rose gold, with a detachable black ring on both sides'

# prompt = 'An exquisite gothic queen vampiress with dark blue hair and crimson red eyes: Her sensuous white skin gleams in the atmospheric, dense fog, creating an epic and dramatic mood. This hyper-realistic portrait is filled with morbid beauty, from her gothicattire to the intense lighting that highlights every intricate detail. The scene combines glamour with dark, mysterious elements, blending fantasy and horror in a visually stunning way.'

# prompt = "A photo of a street performer juggling, with a sign that reads 'Tips Welcome' next to him."

# The active prompt: a text-rendering test (billboard with a quoted title).
prompt = "A billboard advertising a new movie, titled 'The Great Adventure', in bold cinematic style."


# Step counts to compare; each entry becomes one panel in the figure.
steps_list = [1, 4, 8, 10, 16, 24]

# steps_list = [40]

# Size the grid from the number of entries (at most 3 columns, 5 inches per
# panel) so any list length works. The original fixed 2x3 grid raised
# IndexError for more than 6 entries and left empty framed axes for fewer.
n_cols = max(1, min(3, len(steps_list)))
n_rows = max(1, math.ceil(len(steps_list) / n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(5 * n_cols, 5 * n_rows),
    dpi=300,
    squeeze=False,  # always a 2-D array, even for a single panel
)
flat_axes = axes.ravel()

for ax, num_steps in zip(flat_axes, steps_list):
    # The seed is fixed, so every panel starts from the same noise and only
    # the number of sampler steps varies.
    img_latents = forward_denoise(pipe, num_steps=num_steps, prompt=prompt,
                                  resolution=1024, guidance_scale=3.5, seed=100)
    out = decode_imgs(img_latents, pipe)[0]

    ax.imshow(out)
    ax.set_title(f"{num_steps} Steps")
    ax.axis('off')

# Hide any panels left over when the list does not fill the grid.
for ax in flat_axes[len(steps_list):]:
    ax.axis('off')

plt.tight_layout()
plt.show()