In [1]:
import subprocess
import os

# Source /etc/network_turbo in a bash subprocess and copy every exported
# variable whose `env` line contains "proxy" into this kernel's environment,
# so later downloads in the notebook go through the same proxy settings.
# bash is invoked directly (argument list) instead of through a second shell.
result = subprocess.run(
    ["bash", "-c", "source /etc/network_turbo && env | grep proxy"],
    capture_output=True,
    text=True,
)
output = result.stdout

# Best-effort: a missing script or an empty grep result is reported, not fatal.
if result.returncode != 0:
    print(
        f"Proxy setup skipped (exit code {result.returncode}): "
        f"{result.stderr.strip() or 'no proxy variables found'}"
    )

for line in output.splitlines():
    # Split on the first '=' only; values may themselves contain '='.
    var, sep, value = line.partition('=')
    if sep:
        os.environ[var] = value

In [None]:
import torch
import math
import matplotlib.pyplot as plt
from diffusers import FluxPipeline

In [None]:
# Weight dtype for the reference (stock) FLUX.1-dev pipeline.
DTYPE = torch.bfloat16

# Load the complete pretrained pipeline in DTYPE, then move it onto the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=DTYPE,
)
pipe.to("cuda")

In [None]:
from matplotlib import pyplot as plt

prompt = "A high-impact Telegram post with the text 'ATTENTION!': the background is a vibrant and intense gradient of red and orange, with a subtle radial burst effect emanating from the center. The word 'ATTENTION!' is placed prominently in the center in a bold, sans-serif font with a metallic finish, slightly tilted for added dynamism. Surrounding the text, there are subtle, glowing lines and digital glitch effects, creating a sense of urgency and importance. The overall style is modern and eye-catching, perfect for grabbing attention on social media."

# Fixed-seed CUDA generator so the reference image can be regenerated.
seeded_generator = torch.Generator(device='cuda').manual_seed(0)

# Sampling settings for the stock pipeline, gathered in one place.
sampling_kwargs = dict(
    prompt=prompt,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    num_inference_steps=50,
    generator=seeded_generator,
)
out = pipe(**sampling_kwargs).images[0]

# Show the first generated image on an 8x8 inch, 300 dpi figure.
plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(out)

In [None]:
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
)
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast, CLIPTextModel, T5EncoderModel
from reflow.flux_utils import encode_imgs, decode_imgs, get_noise, get_schedule

def get_models(pretrained_model_name_or_path, weight_dtype=torch.bfloat16, text_encoder_dtype=None):
    """Load the individual FLUX components from a pretrained checkpoint.

    Each component is read from its own subfolder of the checkpoint
    (``scheduler``, ``tokenizer``, ``tokenizer_2``, ``text_encoder``,
    ``text_encoder_2``, ``vae``, ``transformer``). All weight-bearing models
    are frozen with ``requires_grad_(False)`` before being returned.

    Args:
        pretrained_model_name_or_path: Hub id or local path handed to every
            ``from_pretrained`` call.
        weight_dtype: dtype used to load the VAE and the transformer.
            Defaults to ``torch.bfloat16``, the value that was previously
            hard-coded.
        text_encoder_dtype: dtype used to load both text encoders. ``None``
            (the default) reuses ``weight_dtype``, which reproduces the
            previous behaviour; pass e.g. ``torch.float32`` to keep the
            text encoders in full precision.

    Returns:
        A 7-tuple ``(scheduler, tokenizer_one, tokenizer_two,
        text_encoder_one, text_encoder_two, vae, transformer)``.
    """
    if text_encoder_dtype is None:
        text_encoder_dtype = weight_dtype

    # The scheduler and tokenizers hold no weights, so no dtype is passed.
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="scheduler",
    )
    tokenizer_one = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer_2",
    )

    text_encoder_one = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=text_encoder_dtype,
    )
    text_encoder_two = T5EncoderModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        torch_dtype=text_encoder_dtype,
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=weight_dtype,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=weight_dtype,
    )

    # Freeze every weight-bearing component; adapters added later opt back in.
    for frozen_module in (transformer, vae, text_encoder_one, text_encoder_two):
        frozen_module.requires_grad_(False)

    return (
        scheduler,
        tokenizer_one,
        tokenizer_two,
        text_encoder_one,
        text_encoder_two,
        vae,
        transformer,
    )

# Load every FLUX.1-dev component separately so the transformer can be
# modified on its own before sampling.
# NOTE(review): the previous comment here said the text encoders are forced to
# fp32, but get_models loads every model in bfloat16 - confirm the intent.
components = get_models("black-forest-labs/FLUX.1-dev")
(
    scheduler,
    tokenizer_one,
    tokenizer_two,
    text_encoder_one,
    text_encoder_two,
    vae,
    transformer,
) = components

# Reassemble the separately loaded parts into a pipeline and move it to the GPU.
pipeline = FluxPipeline(
    scheduler=scheduler,
    tokenizer=tokenizer_one,
    text_encoder=text_encoder_one,
    tokenizer_2=tokenizer_two,
    text_encoder_2=text_encoder_two,
    vae=vae,
    transformer=transformer,
)
pipeline = pipeline.to("cuda")

# List every transformer parameter with its shape and dtype.
for param_name, weight in transformer.named_parameters():
    print(param_name, weight.shape, weight.dtype)

In [None]:
from peft import LoraConfig, set_peft_model_state_dict, get_peft_model_state_dict

# Plan: first initialise a LoRA adapter on the transformer, then read back its
# state dict to see the adapter's keys, and later load the trained weights
# with set_peft_model_state_dict.
transformer_lora_config = LoraConfig(
    r=128,
    lora_alpha=128,
    init_lora_weights="gaussian",  # also try "default"
    # Module-name suffixes that receive a LoRA adapter. "norm.linear" used to
    # be listed twice; the duplicate entry has been removed.
    target_modules=[
        "to_k",
        "to_q",
        "to_v",
        "to_out.0",
        "add_k_proj",
        "add_q_proj",
        "add_v_proj",
        "to_add_out",
        "norm.linear",
        "proj_mlp",
        "proj_out",
        "ff.net.0.proj",
        "ff.net.2",
        "ff_context.net.0.proj",
        "ff_context.net.2",
        "norm1.linear",
        "norm1_context.linear",
        "timestep_embedder.linear_1",
        "timestep_embedder.linear_2",
        "guidance_embedder.linear_1",
        "guidance_embedder.linear_2",
    ],
)

# Attach the freshly initialised adapter to the transformer.
transformer.add_adapter(transformer_lora_config)

# Inspect the adapter-only state dict: key, shape, dtype, device, grad flag.
state_dict = get_peft_model_state_dict(transformer)

for name in state_dict:
    print(name, state_dict[name].shape, state_dict[name].dtype, state_dict[name].device, state_dict[name].requires_grad)

In [None]:
# Refresh the adapter-only state dict from the transformer.
state_dict = get_peft_model_state_dict(transformer)

# Print only the parameters that are still trainable.
trainable_params = (
    (param_name, weight)
    for param_name, weight in transformer.named_parameters()
    if weight.requires_grad
)
for name, param in trainable_params:
    print(name, param.shape, param.dtype, param.requires_grad)

# transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
# for param in transformer_lora_parameters:
# 	print(param.shape, param.dtype)

In [None]:
# Hand-written LoRA loading

from safetensors import safe_open

# Read every tensor of the LoRA checkpoint onto the CPU, printing each entry's
# key, shape, dtype, device and grad flag as it is loaded.
safetensors_path = "/root/autodl-tmp/data/2rf_inference_t.safetensors"
lora_state_dict = {}
with safe_open(safetensors_path, framework="pt", device="cpu") as checkpoint:
    for key in checkpoint.keys():
        tensor = checkpoint.get_tensor(key)
        lora_state_dict[key] = tensor
        print(key, tensor.shape, tensor.dtype, tensor.device, tensor.requires_grad)

# Sanity check: every key in the checkpoint must mention "lora".
for name in lora_state_dict.keys():
    assert "lora" in name, name

# for name, param in transformer.named_parameters():
#     # Find the matching LoRA weight key
#     lora_A_key = f"{name}.lora_A.weight"
#     lora_B_key = f"{name}.lora_B.weight"
    
#     if lora_A_key in lora_weights:
#         print(f"Setting LoRA A weights for {name}")
#         param.lora_A.weight.data = lora_weights[lora_A_key]
    
#     if lora_B_key in lora_weights:
#         print(f"Setting LoRA B weights for {name}")
#         param.lora_B.weight.data = lora_weights[lora_B_key]

In [None]:
from diffusers.utils import convert_unet_state_dict_to_peft

# Load the LoRA checkpoint through the pipeline's own loader.
lora_state_dict = FluxPipeline.lora_state_dict("/root/autodl-tmp/data/2rf_inference_t.safetensors")

# for key in lora_state_dict:
# 	print(key, lora_state_dict[key].shape, lora_state_dict[key].dtype, lora_state_dict[key].device, lora_state_dict[key].requires_grad)

# Keep only the transformer entries and strip the leading "transformer."
# prefix. Slicing removes just the prefix; str.replace would also delete any
# later occurrence of "transformer." inside the key.
transformer_prefix = "transformer."
transformer_state_dict = {
    key[len(transformer_prefix):]: tensor
    for key, tensor in lora_state_dict.items()
    if key.startswith(transformer_prefix)
}

# Rename the keys to the PEFT naming scheme before loading them.
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)

# Copy the weights into the adapter named "default" on the transformer.
incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")

if incompatible_keys is not None:
    # check only for unexpected keys
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
    if unexpected_keys:
        print(
            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
            f" {unexpected_keys}. "
        )

print(incompatible_keys)

In [None]:
# Refresh the adapter-only state dict from the transformer.
state_dict = get_peft_model_state_dict(transformer)

# for name in state_dict:
# 	print(name, state_dict[name].shape, state_dict[name].dtype, state_dict[name].device, state_dict[name].requires_grad)

# Print the trainable parameters, this time including the device they live on.
trainable_params = (
    (param_name, weight)
    for param_name, weight in transformer.named_parameters()
    if weight.requires_grad
)
for name, param in trainable_params:
    print(name, param.shape, param.dtype, param.requires_grad, param.device)

In [None]:
@torch.inference_mode()
def sample(prompt, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=50, seed=0):
    """Generate one image with a hand-written Euler sampling loop.

    Uses the notebook-level ``pipeline``, ``transformer`` and ``vae`` objects:
    the prompt is encoded with ``pipeline.encode_prompt``, seeded noise is
    packed into latent tokens, the transformer is called once per consecutive
    pair of schedule values, and the result is unpacked and decoded with
    ``decode_imgs``.

    Args:
        prompt: Text prompt; passed as both ``prompt`` and ``prompt_2``.
        height: Output height in pixels. Asserted to equal 8x the noise height.
        width: Output width in pixels. Asserted to equal 8x the noise width.
        guidance_scale: Value fed to the transformer's ``guidance`` input.
        num_inference_steps: ``num_steps`` given to ``get_schedule``
            (previously hard-coded to 50).
        seed: Seed given to ``get_noise`` (previously hard-coded to 0).

    Returns:
        The first element of ``decode_imgs(img_latents, vae, pipeline)``.

    Raises:
        AssertionError: If the noise shape does not match ``height``/``width``
            or ``pipeline.vae_scale_factor`` is not 16.
    """
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        device=pipeline.device,
        max_sequence_length=512,
    )

    # The transformer is fed bfloat16 inputs regardless of the encoder dtype.
    prompt_embeds = prompt_embeds.to(torch.bfloat16)
    pooled_prompt_embeds = pooled_prompt_embeds.to(torch.bfloat16)

    # Per the original notes: shape [num_samples, 16, height // 8, width // 8]
    # (assumed from get_noise; the asserts below check the spatial dims).
    noise = get_noise(
        num_samples=1,
        height=height,
        width=width,
        device="cuda",
        dtype=torch.bfloat16,
        seed=seed,
    )
    batch_size, num_channels, latent_height, latent_width = noise.shape

    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        batch_size,
        latent_height,
        latent_width,
        noise.device,
        torch.bfloat16,
    )

    # Per the original notes:
    # [num_samples, (height // 16 * width // 16), 16 * 2 * 2]
    packed_latents = FluxPipeline._pack_latents(
        noise,
        batch_size=batch_size,
        num_channels_latents=num_channels,
        height=latent_height,
        width=latent_width,
    )

    timesteps = get_schedule(
        num_steps=num_inference_steps,
        image_seq_len=(height // 16) * (width // 16),  # vae // 8 and patchify // 2
        shift=True,  # Set True for Flux-dev, False for Flux-schnell
    )

    # One transformer call per consecutive pair of schedule values, so the
    # progress-bar total is derived from the schedule instead of a constant.
    with pipeline.progress_bar(total=len(timesteps) - 1) as progress_bar:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full(
                (packed_latents.shape[0],),
                t_curr,
                dtype=packed_latents.dtype,
                device=packed_latents.device,
            )
            guidance_vec = torch.full(
                (packed_latents.shape[0],),
                guidance_scale,
                device=packed_latents.device,
                dtype=packed_latents.dtype,
            )
            pred = transformer(
                hidden_states=packed_latents,
                timestep=t_vec,
                guidance=guidance_vec,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=None,
                # Was `return_dict=pipeline` (an object used as a flag); a
                # plain tuple is requested and its first item is the prediction.
                return_dict=False,
            )[0]
            # Euler update between the two schedule values.
            packed_latents = packed_latents + (t_prev - t_curr) * pred
            progress_bar.update()

    assert latent_height * 8 == height and latent_width * 8 == width
    # NOTE(review): this pins the scale-factor convention of the installed
    # diffusers version; other releases may report a different value - confirm.
    assert pipeline.vae_scale_factor == 16
    # Per the original notes: shape [num_samples, 16, height//8, width//8]
    img_latents = FluxPipeline._unpack_latents(
        packed_latents,
        height=height,
        width=width,
        vae_scale_factor=pipeline.vae_scale_factor,
    )

    imgs = decode_imgs(img_latents, vae, pipeline)[0]

    return imgs

# Same prompt as the reference generation, now sampled through `sample`.
prompt = "A high-impact Telegram post with the text 'ATTENTION!': the background is a vibrant and intense gradient of red and orange, with a subtle radial burst effect emanating from the center. The word 'ATTENTION!' is placed prominently in the center in a bold, sans-serif font with a metallic finish, slightly tilted for added dynamism. Surrounding the text, there are subtle, glowing lines and digital glitch effects, creating a sense of urgency and importance. The overall style is modern and eye-catching, perfect for grabbing attention on social media."
out = sample(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
)

# Show the result on an 8x8 inch, 300 dpi figure.
plt.figure(dpi=300, figsize=(8, 8))
plt.imshow(out)