In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt
from diffusers import (
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    AutoencoderKL,
    AutoencoderTiny,
)
from torchvision import transforms
from PIL import Image
from einops import rearrange, repeat, reduce
from tqdm.notebook import tqdm
import math

## Model Loading

Load the memory-optimized FLUX.1-dev model. The model weights are downcast to fp8, while activations stay in bf16. With cached layer-wise upcasting and gradient checkpointing, it is possible to back-propagate through the model on 24 GB GPUs.

Memory footprint:

- T5 encoder weights: 9.5GB (will be automatically offloaded when not needed)
- Flux transformer weights: 12GB
- Activations for backward (to `hidden_states`): 1.65GB (gradient checkpointing enabled)

TAEF1 is used for live preview of the sampling loop.

In [None]:
from pipeline import create_low_vram_flux_pipeline, freeze_pipeline

# Activation dtype and target GPU shared by every cell below.
dtype = torch.bfloat16
device = torch.device("cuda:0")

# Build the low-VRAM Flux pipeline, then freeze it
# (presumably disables gradients on its weights -- see pipeline.freeze_pipeline).
pipe = create_low_vram_flux_pipeline(device)
freeze_pipeline(pipe)

# Tiny autoencoder used only for cheap live previews while sampling.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
taef1 = taef1.to(device)

## Image and Latent Space

In the diffusers implementation of Flux, the image tensor is in BCHW format with values in $[-1, 1]$. The VAE latent space has 16 channels and is spatially downsampled by a factor of 8 relative to the input image. However, the transformer expects 2×2 patchified (space-to-depth) latents as its `hidden_states`. These `hidden_states` are in BND format, where $N = H\times W/256$ for an $H\times W$ image and $D = 16 \times 2 \times 2 = 64$.

In the following cells, `latent` refers to the transformer's `hidden_states`, and `image` refers to decoded image tensor.

In [None]:
# Working resolution (pixels) used by all helpers below.
height = width = 1024


def show_image_tensor(x):
    """Turn the first image of a decoded BCHW tensor into a PIL image.

    The tensor is detached and copied to the CPU, then handed to the pipeline's
    image processor for conversion.
    """
    host_tensor = x.detach().cpu()
    pil_images = pipe.image_processor.postprocess(host_tensor, output_type="pil")
    return pil_images[0]


def load_image_tensor(path):
    """Load an image file as a batched tensor on ``device`` in ``dtype``.

    The image is resized to the global ``height`` x ``width`` by the pipeline's
    image processor.

    Two fixes over a bare ``Image.open``:
    - The file is opened in a context manager, so its handle is closed.
    - The image is converted to RGB first. The image processor keeps whatever
      channels PIL provides (unless configured with ``do_convert_rgb``), so an
      RGBA/LA/grayscale/palette PNG would otherwise reach the 3-channel VAE
      encoder with the wrong channel count. Any alpha channel is dropped.
    """
    with Image.open(path) as src:
        rgb = src.convert("RGB")  # returns a new, fully loaded image
    tensor = pipe.image_processor.preprocess(rgb, height=height, width=width)
    return tensor.to(device=device, dtype=dtype)


def image_to_latent(img: torch.Tensor):
    """Encode an image tensor into the transformer's packed latent layout.

    Steps:
    - Draws a sample from the VAE posterior (stochastic).
    - Applies the VAE's shift/scale normalisation.
    - Folds each 2x2 spatial patch into the channel dimension, giving tokens of
      size ``c * 4``.
    """
    vae: AutoencoderKL = pipe.vae
    cfg = vae.config
    posterior = vae.encode(img).latent_dist
    z: torch.Tensor = posterior.sample()
    z = cfg.scaling_factor * (z - cfg.shift_factor)
    return rearrange(z, "b c (h nh) (w nw) -> b (h w) (c nh nw)", nh=2, nw=2)


def make_empty_latent(height: int, width: int, generator=None):
    """Draw a packed Gaussian-noise latent for a ``height`` x ``width`` image.

    The token count is ``(height // n) * (width // n)`` with
    ``n = 2 * pipe.vae_scale_factor``. This is exactly the grid that
    ``latent_to_image`` unpacks and ``make_velocity_function`` builds ids for.
    The previous ``height * width // 256`` agrees with it only when both sizes
    are multiples of 16 (and the VAE factor is 8).

    Args:
        height: target image height in pixels.
        width: target image width in pixels.
        generator: optional ``torch.Generator`` for reproducible noise. It must live
            on ``device``. ``None`` uses the global RNG, as before.

    Returns:
        Tensor of shape ``(1, (height // n) * (width // n), in_channels)`` on
        ``device`` with dtype ``dtype``. ``in_channels`` is read from
        ``pipe.transformer.config``.
    """
    n = pipe.vae_scale_factor * 2
    n_tokens = (height // n) * (width // n)
    return torch.randn(
        1,
        n_tokens,
        pipe.transformer.config.in_channels,
        generator=generator,
        device=device,
        dtype=dtype,
    )


def latent_to_image(
    latent: torch.Tensor, height: int, width: int, vae: AutoencoderKL = None
):
    """Decode a packed transformer latent back into an image tensor.

    Args:
        latent: packed tokens of shape ``(b, (h w), (c 2 2))``.
        height: output image height in pixels.
        width: output image width in pixels.
        vae: decoder to use; ``None`` selects ``pipe.vae``. Any decoder exposing
            ``config.scaling_factor``, ``config.shift_factor`` and
            ``decode(..., return_dict=False)`` works (e.g. the preview decoder).

    Returns:
        The first element returned by ``vae.decode``, i.e. the decoded image batch.
    """
    decoder = pipe.vae if vae is None else vae
    patch = 2 * pipe.vae_scale_factor
    # Undo the 2x2 space-to-depth packing: tokens -> (b, c, H/8, W/8).
    spatial = rearrange(
        latent,
        "b (h w) (c nh nw) -> b c (h nh) (w nw)",
        h=height // patch,
        w=width // patch,
        nh=2,
        nw=2,
    )
    cfg = decoder.config
    unscaled = spatial / cfg.scaling_factor + cfg.shift_factor
    decoded: torch.Tensor = decoder.decode(unscaled, return_dict=False)[0]
    return decoded

In [None]:
# Round-trip sanity check: encode the reference image to packed latents and decode it back.
# `img` and `latent_img` are reused by later cells.
img = load_image_tensor("assets/editing_cat.png")
latent_img = image_to_latent(img)  # samples the VAE posterior, so this is stochastic
img_decoded = latent_to_image(latent_img, height, width)
show_image_tensor(img_decoded)

In [None]:
# Queue of (PIL image, caption) pairs waiting to be drawn by show_gallery().
_gallery = []


def add_gallery(img, tag):
    """Queue a decoded image tensor, captioned ``tag``, for the next ``show_gallery()``."""
    # append() mutates the list in place, so no `global` declaration is needed here.
    _gallery.append((show_image_tensor(img), tag))


def show_gallery(n_cols=6):
    """Draw all queued images in a grid of at most ``n_cols`` columns, then clear the queue.

    ``squeeze=False`` makes ``plt.subplots`` always return a 2-D array of Axes. This
    covers the cases the previous ``axs[i]`` indexing crashed on:
    - a single image, where a bare Axes is returned;
    - more than one row, where ``axs[i]`` is a row of Axes.

    Axes left over in the last row are hidden.
    """
    global _gallery
    n = len(_gallery)
    if n == 0:
        print("Nothing to show")
        return
    n_cols = max(1, min(n_cols, n))
    n_rows = math.ceil(n / n_cols)
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3.2), squeeze=False
    )
    flat_axes = axs.ravel()
    for ax, (img, tag) in zip(flat_axes, _gallery):
        ax.imshow(img)
        ax.set_title(tag)
    # Hide ticks/frames on every cell, including unused trailing ones.
    for ax in flat_axes:
        ax.axis("off")
    fig.tight_layout()
    plt.show()
    _gallery = []

In [None]:
def calculate_timestep_shift(t: torch.Tensor, latent_len: int, scheduler_config=None):
    """Apply Flux's resolution-dependent time shift to a schedule ``t``.

    ``mu`` is linearly interpolated in ``latent_len``:
    - ``base_shift`` at ``base_image_seq_len``;
    - ``max_shift`` at ``max_image_seq_len``.

    Each time is then mapped as ``e^mu / (e^mu + (1/t - 1))``. This keeps
    ``t = 1`` fixed and pushes intermediate times towards 1 for longer sequences.

    Args:
        t: floating tensor of times in [0, 1]. ``t = 0`` maps to 0 because ``1/0``
            evaluates to ``inf`` for tensors; a Python float 0.0 would raise instead.
        latent_len: number of packed latent tokens (sequence length).
        scheduler_config: object with ``base_shift``, ``max_shift``,
            ``base_image_seq_len`` and ``max_image_seq_len``. Defaults to
            ``pipe.scheduler.config``, which preserves the original behaviour.

    Returns:
        The shifted times, with the same shape as ``t``.
    """
    scfg = pipe.scheduler.config if scheduler_config is None else scheduler_config
    m = (scfg.max_shift - scfg.base_shift) / (
        scfg.max_image_seq_len - scfg.base_image_seq_len
    )
    b = scfg.base_shift - m * scfg.base_image_seq_len
    mu = m * latent_len + b
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1))

In [None]:
n_timesteps = 28
t_linear = torch.linspace(1.0, 0.0, n_timesteps + 1)
t = calculate_timestep_shift(t_linear, latent_img.shape[1])

# Compare the resolution-shifted schedule with the uniform schedule it is derived from.
fig, ax = plt.subplots()
ax.plot(t_linear.numpy(), label="uniform")
ax.plot(t.numpy(), label="shifted")
ax.set(title=f"Timestep schedule ({n_timesteps} steps)", xlabel="step i", ylabel="t")
ax.legend()
plt.show()

In [None]:
def make_velocity_function(
    prompt: str, height: int, width: int, guidance_scale: float = 3.5
):
    """Encode ``prompt`` once and return a closure ``velocity(latent, t)`` around the transformer.

    Args:
        prompt: text prompt, passed as both ``prompt`` and ``prompt_2`` to
            ``pipe.encode_prompt``.
        height: image height in pixels.
        width: image width in pixels.
            Together, height and width determine the positional ids of the packed
            latent grid: ``height // n`` x ``width // n`` with
            ``n = 2 * pipe.vae_scale_factor``.
        guidance_scale: value fed to the transformer's ``guidance`` input.
            The default 3.5 is the value previously hard-coded.

    Returns:
        ``velocity(latent, t)``. It runs ``pipe.transformer`` on the packed
        ``latent`` at time ``t`` and returns the first element of its output.
        ``t`` may be a Python number, a 0-d tensor or a 1-d tensor.
    """
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        device=device,
    )
    n = pipe.vae_scale_factor * 2
    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        batch_size=1, height=height // n, width=width // n, device=device, dtype=dtype
    )
    guidance = torch.tensor([guidance_scale], device=device, dtype=dtype)

    def velocity(latent: torch.Tensor, t: torch.Tensor):
        transformer: FluxTransformer2DModel = pipe.transformer
        # Normalise every accepted form of `t` to a 1-d tensor on the model device.
        # The old isinstance/dim branches left `timestep` unbound for ints and 1-d
        # tensors, and kept tensor timesteps (e.g. CPU schedule entries) on their
        # original device.
        timestep = torch.as_tensor(t, device=device, dtype=dtype).reshape(-1)
        noise_pred = transformer(
            hidden_states=latent,
            timestep=timestep,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
        )[0]
        return noise_pred

    return velocity


def sample_preview(latent, velocity, t):
    """Integrate the flow ODE along schedule ``t`` with explicit Euler steps, showing live previews.

    Each step computes ``latent + velocity(latent, t[i]) * (t[i+1] - t[i])`` in
    float32 and stores the result back in ``dtype``.

    Previews are decoded with ``taef1`` after each step; the final frame is
    decoded with the full VAE. ``display_id=True`` gives every call its own
    output area. The previous fixed id ``"latent_preview"`` made later calls, in
    this or any other cell, overwrite every earlier preview that shared the id.

    Returns:
        The latent after the last step, in ``dtype``.
    """

    def _preview(x, vae):
        # vae=None selects the full pipeline VAE inside latent_to_image.
        return show_image_tensor(latent_to_image(x, height, width, vae=vae))

    n_timesteps = len(t) - 1
    latent_display = display(_preview(latent, taef1), display_id=True)
    for i in tqdm(range(n_timesteps)):
        noise_pred = velocity(latent, t[i])
        dt = t[i + 1] - t[i]
        # fp32 accumulation limits bf16 round-off in the update.
        latent_hp = latent.to(torch.float32) + noise_pred.to(torch.float32) * dt
        latent = latent_hp.to(dtype)
        latent_display.update(_preview(latent, taef1))
    latent_display.update(_preview(latent, None))
    return latent

In [None]:
# Seed the global RNG so the noise drawn by make_empty_latent is reproducible on re-run.
torch.manual_seed(0)
n_timesteps = 28
velocity = make_velocity_function("a sitting orange cat in winter", height, width)
latent = make_empty_latent(height, width)
t_shift = calculate_timestep_shift(
    torch.linspace(1.0, 0.0, n_timesteps + 1), latent.shape[1]
)
latent = sample_preview(latent, velocity, t_shift)

## Inversion

Naive Euler inversion (integrating from $t=0$ up to $t_{\text{inv}}$ and back) is almost useless with fewer than 100 iterations.

In [None]:
t_inv = 0.8
latent = image_to_latent(img)
add_gallery(img, "Original")

# Inversion: integrate from the clean image (t=0) up towards noise.
t_shift = calculate_timestep_shift(
    torch.linspace(0.0, t_inv, n_timesteps + 1), latent.shape[1]
)
latent_inv = sample_preview(latent, velocity, t_shift)
# The resolution shift moves the end point away from t_inv, so report the time actually reached.
add_gallery(
    latent_to_image(latent_inv, height, width),
    f"Inverse t={t_inv} (shifted {float(t_shift[-1]):.2f})",
)

# Reconstruction: integrate back down to t=0 from the inverted latent.
t_shift = calculate_timestep_shift(
    torch.linspace(t_inv, 0.0, n_timesteps + 1), latent.shape[1]
)
latent = sample_preview(latent_inv, velocity, t_shift)
add_gallery(latent_to_image(latent, height, width), "Reconstructed")
show_gallery()

## Interp Editor

Original FlowGrad:

$$
Z_{i+1} = Z_i + u_i + (t_{i+1}-t_i) \cdot v(Z_i, t_i)
$$

Interp Editor:

$$
Z_{i+1} = Z_i + (t_{i+1}-t_i) \cdot v'(Z_i, t_i)
$$

$$
v'(Z, t) = \epsilon v(Z, t) + (1-\epsilon) v_{\text{target}}(Z, t)
$$

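$$
v_{\text{target}}(Z_i, t_i) = \frac{1}{t_i}(Z_i - X_{\text{target}})
$$

This follows from the rectified-flow interpolation $Z_t = t\,N + (1-t)\,X$ with velocity $v = N - X$, where $N$ is noise. Solving for $N$ given $Z_t$ and $X = X_{\text{target}}$ gives $v = (Z_t - X_{\text{target}})/t$. A single Euler step from $t_i$ to $0$ therefore lands exactly on $X_{\text{target}}$.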

To apply FlowGrad to the Interp Editor, there are two places to inject $u_i$. Note that the code below uses `eta` $= 1-\epsilon$, the weight on $v_{\text{target}}$:

1. (Same as Original) As addition to $Z_i$. $\epsilon$ is only used to initialize $u_i$.
2. Replace $v_{\text{target}}$ with $u_i$. $\epsilon$ is always used.

In [None]:
def make_eta_schedule(start_t, end_t, eta, eta_trend, alpha=0.0):
    """Build a time-dependent guidance weight ``eta_fn(t)`` for the Interp Editor.

    The weight is 0 outside the window spanned by ``start_t`` and ``end_t``.
    Inside the window, ``tau`` runs from 0 at ``start_t`` to 1 at ``end_t``, and
    the weight follows ``eta_trend``:

    - ``"constant"``: ``eta``
    - ``"linear_decrease"``: ``eta * (1 - tau)``
    - ``"exponential_decrease"``: ``eta * (1 - (exp(alpha*tau) - 1) / (exp(alpha) - 1))``.
      This falls back to the linear form when ``|alpha| < 1e-5``.

    ``start_t`` may be larger than ``end_t`` (e.g. ``1 -> 0.75``, the order in which
    the sampler visits t). The window is then mirrored with ``t -> 1 - t``, so ``tau``
    still grows from ``start_t`` towards ``end_t``. A degenerate window
    (``start_t == end_t``) yields ``tau = 0`` instead of dividing by zero.

    ``t`` may be a Python float or a 0-d tensor.

    Raises:
        NotImplementedError: for an unknown ``eta_trend``. This is raised when the
            schedule is built, not on the first in-window call.
    """
    if eta_trend not in ("constant", "linear_decrease", "exponential_decrease"):
        raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}")

    # Explicit mirrored bounds instead of relying on the closure seeing a later
    # reassignment of start_t/end_t.
    reverse = start_t > end_t
    lo, hi = (1 - start_t, 1 - end_t) if reverse else (start_t, end_t)

    def _decay(tau):
        # Weight profile inside the window, tau in [0, 1].
        if eta_trend == "constant":
            return eta
        if eta_trend == "linear_decrease" or abs(alpha) < 1e-5:
            return eta * (1 - tau)
        numerator = math.exp(alpha * tau) - 1
        denominator = math.exp(alpha) - 1
        return eta * (1 - (numerator / denominator))

    def eta_schedule(t):
        if reverse:
            t = 1 - t
        if t < lo or t > hi:
            return 0
        tau = 0.0 if hi == lo else (t - lo) / (hi - lo)
        return _decay(tau)

    return eta_schedule


eta_fn = make_eta_schedule(1, 0.75, 0.93, "exponential_decrease", alpha=2.0)
t_grid = torch.linspace(1, 0, 100)
eta = [float(eta_fn(t)) for t in t_grid]

# Plot against t (reversed axis = sampling order 1 -> 0) rather than the raw list index.
fig, ax = plt.subplots()
ax.plot(t_grid.numpy(), eta)
ax.invert_xaxis()
ax.set(title="Guidance weight eta(t)", xlabel="t", ylabel="eta")
plt.show()

In [None]:
def flowgrad1_init_u(latent, latent_target, velocity, t, eta_fn):
    """Run the Interp Editor sampler and record the per-step control ``u_i``.

    At each step:
    - ``v_orig`` is the model velocity.
    - ``v_target = (latent - latent_target) / t[i]`` points straight at the target.
    - The blended velocity is ``v_total = eta * v_target + (1 - eta) * v_orig``,
      with ``eta = eta_fn(t[i])``.
    - The latent takes an Euler step ``latent + v_total * (t[i+1] - t[i])``,
      computed in float32 and stored back in ``dtype``.

    Args:
        latent: starting packed latent.
        latent_target: packed target latent; float32 is expected by the fp32 update.
        velocity: closure as returned by ``make_velocity_function``.
        t: schedule of times. ``t[i]`` is a divisor for ``i < len(t) - 1``, so those
            entries must be non-zero.
        eta_fn: guidance weight schedule, e.g. from ``make_eta_schedule``.

    Returns:
        ``(u, latent)``:
        - ``u`` is a list of float32 tensors ``u_i = v_total - v_orig``. These are
          in velocity space; the latent-space offset they cause per step is
          ``u_i * (t[i+1] - t[i])``.
        - ``latent`` is the final latent in ``dtype``.
    """

    def _preview(x, vae):
        # vae=None selects the full pipeline VAE inside latent_to_image.
        return show_image_tensor(latent_to_image(x, height, width, vae=vae))

    u = []
    n_timesteps = len(t) - 1
    # display_id=True: a fixed id would make this call overwrite previews of other calls/cells.
    latent_display = display(_preview(latent, taef1), display_id=True)
    for i in tqdm(range(n_timesteps)):
        v_orig = velocity(latent, t[i]).to(torch.float32)
        latent = latent.to(torch.float32)

        eta = eta_fn(t[i])
        v_target = (latent - latent_target) / t[i]
        v_total = eta * v_target + (1 - eta) * v_orig
        u.append(v_total - v_orig)
        latent = (latent + v_total * (t[i + 1] - t[i])).to(dtype)

        latent_display.update(_preview(latent, taef1))

    latent_display.update(_preview(latent, None))
    return u, latent

In [None]:
# Interp Editor forward pass: start from fresh noise and steer the trajectory toward the
# packed latent of `img` while following the "sitting tiger" prompt.
# Keep the statement order: make_empty_latent (randn) and image_to_latent (posterior sample)
# both draw random numbers.
latent = make_empty_latent(height, width)
latent_target = image_to_latent(img).to(torch.float32)  # fp32 to match the fp32 update in flowgrad1_init_u
velocity = make_velocity_function("a photo of a sitting tiger", height, width)
t = calculate_timestep_shift(torch.linspace(1.0, 0.0, n_timesteps + 1), latent.shape[1])
# Target weight: 0.93 at t=1, decaying to 0 at t=0.75, and 0 for t below 0.75.
eta_fn = make_eta_schedule(1, 0.75, 0.93, "exponential_decrease", alpha=2.0)
u, latent_out = flowgrad1_init_u(latent, latent_target, velocity, t, eta_fn)