In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell executes; edits to local helper modules are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import time
import matplotlib.pyplot as plt
from diffusers import (
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    AutoencoderKL,
    AutoencoderTiny,
)
from torchvision import transforms
from PIL import Image
from einops import rearrange, repeat, reduce
import math

In [None]:
from utils.upcasting import (
    LayerwiseUpcastingGranularity,
    apply_layerwise_upcasting,
    apply_cached_layerwise_upcasting_pytorch_layer,
    get_module_size,
    cast_trainable_parameters
)
from utils.offload_all import apply_offload_all_hook

# Execution device and compute dtype shared by the cells below.
device = torch.device("cuda:0")
dtype = torch.bfloat16
# Alternative compute dtype kept for experiments:
# dtype = torch.float32

# Load the FLUX.1-dev pipeline with bfloat16 weights.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)

# Freeze every transformer parameter, then switch the module to training mode
# with gradient checkpointing enabled. The trailing `;` suppresses cell output.
transformer: FluxTransformer2DModel = pipe.transformer
transformer.requires_grad_(False)
transformer.train()
transformer.enable_gradient_checkpointing();

In [None]:
from peft import LoraConfig

print(f"Before Installing LoRA: {get_module_size(pipe.transformer) / 1e9:.2f} GB")

# LoRA rank; also reused as lora_alpha below.
rank = 128

# Module-name suffixes that receive a LoRA adapter, grouped by role.
attention_targets = [
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "attn.to_out.0",
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
]
feed_forward_targets = [
    "ff.net.0.proj",
    "ff.net.2",
    "ff_context.net.0.proj",
    "ff_context.net.2",
]
target_modules = attention_targets + feed_forward_targets

transformer_lora_config = LoraConfig(
    target_modules=target_modules,
    r=rank,
    lora_alpha=rank,
    lora_bias=True,
    init_lora_weights=True,
)  # type: ignore

# Attach the adapter to the transformer and report the resulting module size.
transformer.add_adapter(transformer_lora_config)
# Disabled alternative: cast the trainable parameters to the compute dtype.
# cast_trainable_parameters(transformer, dtype=dtype)
print(f"After Installing LoRA: {get_module_size(pipe.transformer) / 1e9:.2f} GB")

In [None]:
print(f"Before Installing Hooks: {get_module_size(pipe.transformer) / 1e9:.2f} GB")

# Layerwise upcasting options: float8 storage, `dtype` for compute, applied per
# PyTorch layer. NOTE(review): the helper lives in utils.upcasting and is not
# shown here -- confirm its exact semantics there.
upcasting_options = dict(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=dtype,
    granularity=LayerwiseUpcastingGranularity.PYTORCH_LAYER,
)
apply_layerwise_upcasting(pipe.transformer, **upcasting_options)

# Pipeline components handed to the offload hook; they execute on `device` and
# use the CPU as the offload target.
offloaded_submodules = ["vae", "text_encoder", "text_encoder_2", "transformer"]
apply_offload_all_hook(
    pipe,
    execution_device=device,
    offload_device="cpu",
    submodules=offloaded_submodules,
)

# Disabled experiment: promote every bfloat16 tensor of the transformer to float32.
# def _cast(x):
#     if x.dtype == torch.bfloat16:
#         return x.to(torch.float32)
#     return x
# transformer._apply(_cast)
print(f"After Installing Hooks: {get_module_size(pipe.transformer) / 1e9:.2f} GB")

In [None]:
def show_dtypes(model=None):
    """Print the dtypes of the first block's ``attn.to_q`` LoRA layer.

    Reports the dtype of the wrapped base layer's weight, the dtype of the
    ``default`` adapter's ``lora_A`` weight and, if that weight currently has
    a gradient, the gradient's dtype.

    Args:
        model: Object exposing ``transformer_blocks[0].attn.to_q`` as a
            LoRA-wrapped layer. When omitted, the notebook-level
            ``transformer`` is used, which keeps existing ``show_dtypes()``
            calls working unchanged.
    """
    # Local import: the original annotation referenced `Any` without it ever
    # being imported in the notebook.
    from typing import Any

    if model is None:
        # Looked up at call time so the function can be defined before use.
        model = transformer
    lora_layer: Any = model.transformer_blocks[0].attn.to_q  # type: ignore
    lora_a_weight = lora_layer.lora_A.default.weight

    print(f"Base layer dtype: {lora_layer.base_layer.weight.dtype}")
    print(f"LoRA layer dtype: {lora_a_weight.dtype}")
    # The gradient only exists after a backward pass has populated it.
    if lora_a_weight.grad is not None:
        print(f"LoRA grad dtype: {lora_a_weight.grad.dtype}")
        
# Print the layer dtypes once up front; in notebook order this runs before any
# forward or backward pass, so it serves as the baseline for later calls.
show_dtypes()

In [None]:
# Collect only the parameters that still require gradients. The whole
# transformer was frozen earlier in the notebook, so after the adapter is added
# these are expected to be the LoRA weights -- TODO confirm.
learnable_params = [p for p in transformer.parameters() if p.requires_grad]
learnable_params_count = sum(p.numel() for p in learnable_params)
print(f"Learnable Parameters Count: {learnable_params_count / 1e6:.2f} M")

# Give AdamW just the trainable parameters instead of every transformer
# parameter. The frozen weights never receive gradients, so registering them
# only bloats the optimizer's parameter groups. If nothing is trainable, AdamW
# now raises on the empty list rather than silently building a no-op optimizer.
optimizer = torch.optim.AdamW(
    learnable_params, lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01
)

In [None]:
# Random stand-in inputs for one transformer call. Everything lives on `device`
# in the compute dtype except `guidance`, which is kept in float32.
tensor_kwargs = dict(device=device, dtype=dtype)

hidden_states = torch.randn(1, 8192, 64, **tensor_kwargs)
encoder_hidden_states = torch.randn(1, 512, 4096, **tensor_kwargs)
pooled_projections = torch.randn(1, 768, **tensor_kwargs)
timestep = torch.tensor([1.0], **tensor_kwargs)
img_ids = torch.randn(8192, 3, **tensor_kwargs)
txt_ids = torch.randn(512, 3, **tensor_kwargs)
guidance = torch.tensor([3.5], device=device, dtype=torch.float32)

transformer_inputs = dict(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    guidance=guidance,
)

# Report allocated CUDA memory on either side of the first forward pass.
print(f"Initial Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
res = transformer(**transformer_inputs, return_dict=True)
print(f"After Forward Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
# Collapse the model output to a scalar stand-in loss and back-propagate it,
# then report allocated CUDA memory and the layer dtypes.
loss = res.sample.mean()
loss.backward()
print(f"After Backward Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
show_dtypes()

In [None]:
# Apply the first optimizer update, then report allocated CUDA memory and the
# layer dtypes again.
# NOTE(review): presumably this first step is where AdamW allocates its
# per-parameter state, which would explain any memory jump here -- confirm.
optimizer.step()
print(f"After Optimizer Step Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
show_dtypes()

In [None]:
# Time one complete training step: zero_grad -> forward -> backward -> step.
# CUDA work is queued asynchronously, so wait for the device on both sides of
# the measurement; otherwise the host clock can stop while kernels are still
# running and the reported time is too short.
torch.cuda.synchronize()
time_before = time.perf_counter()  # monotonic clock meant for intervals
optimizer.zero_grad()
show_dtypes()
res = transformer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    guidance=guidance,
    return_dict=True
)
loss = torch.mean(res.sample)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
time_after = time.perf_counter()
print(f"After Second Step Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Time Taken: {time_after - time_before:.2f} seconds")

In [None]:
# Time one optimizer update that accumulates gradients over three
# forward/backward passes. As with any CUDA timing, synchronize before reading
# the clock so queued GPU work is included in the measurement.
torch.cuda.synchronize()
time_before = time.perf_counter()  # monotonic clock meant for intervals
optimizer.zero_grad()
for i in range(3):
    res = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        guidance=guidance,
        return_dict=True
    )
    # NOTE(review): the loss is not divided by the number of accumulation
    # steps, so gradients are summed rather than averaged. That looks fine for
    # a memory/dtype probe, but confirm before reusing this for real training.
    loss = torch.mean(res.sample)
    print(f">> Iteration {i} Before Backward: ")
    show_dtypes()
    loss.backward()
    print(f">> Iteration {i} After Backward: ")
    show_dtypes()
optimizer.step()
torch.cuda.synchronize()
time_after = time.perf_counter()
# Label corrected: this cell measures the accumulated step, not the second one.
print(f"After Accumulated Step Memory Usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Time Taken: {time_after - time_before:.2f} seconds")