In [None]:
import torch
import inspect
from typing import List, Optional, Union
from diffusers import FluxPipeline
import numpy as np
import time
# import tqdm
from tqdm import *
import os
import matplotlib.pyplot as plt

In [3]:
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map a latent sequence length to the timestep-shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; lengths outside that range are extrapolated
    (there is no clamping).

    Args:
        image_seq_len: Number of packed latent tokens.
        base_seq_len / max_seq_len: Sequence lengths anchoring the line.
        base_shift / max_shift: Shift values at those two anchors.

    Returns:
        The interpolated ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return intercept + slope * image_seq_len


# Adapted (restructured, same behavior) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    Exactly one of three modes is used:
    * ``timesteps`` given -> ``scheduler.set_timesteps(timesteps=...)``
    * ``sigmas`` given    -> ``scheduler.set_timesteps(sigmas=...)``
    * neither             -> ``scheduler.set_timesteps(num_inference_steps, ...)``
    Extra ``kwargs`` are forwarded to ``set_timesteps`` in every mode.

    Returns:
        ``(scheduler.timesteps, num_steps)``. In the custom modes ``num_steps``
        is the length of the resulting schedule; otherwise it is the
        ``num_inference_steps`` argument unchanged.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are passed, or if the
            scheduler's ``set_timesteps`` has no parameter for the requested
            custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # (keyword accepted by set_timesteps, value, word used in the error message)
    custom = None
    if timesteps is not None:
        custom = ("timesteps", timesteps, "timestep")
    elif sigmas is not None:
        custom = ("sigmas", sigmas, "sigmas")

    if custom is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    keyword, values, label = custom
    if keyword not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(device=device, **{keyword: values}, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)

In [4]:
def initial_pipe(model, data_type=torch.float16, cpu_flag=True, compile_flag=False, compile_flag_all=False,
                 model_path: Optional[str] = None):
    """Load a FLUX pipeline from a local snapshot and prepare it for inference.

    Args:
        model: ``'dev'`` loads the FLUX.1-dev snapshot; any other value loads
            the FLUX.1-schnell snapshot.
        data_type: ``torch_dtype`` passed to ``FluxPipeline.from_pretrained``.
        cpu_flag: If True, use ``enable_model_cpu_offload()``; otherwise move
            the whole pipeline to ``"cuda"``.
        compile_flag: If True, wrap ``pipe.transformer`` with ``torch.compile``.
        compile_flag_all: If True, also compile the VAE and both text encoders.
        model_path: Optional directory (or hub id) to load from. When None, the
            original hard-coded local snapshot directories are used, so
            existing callers behave exactly as before.

    Returns:
        The loaded pipeline with gradients disabled on the VAE, both text
        encoders and the transformer.
    """
    if model_path is None:
        # Defaults kept for backward compatibility; pass model_path on other machines.
        if model == 'dev':
            model_path = "C:/Users/DELL/.cache/huggingface/flux.1-dev-directly-download"
        else:
            model_path = "C:/Users/DELL/.cache/huggingface/flux.1-schnell-directly-download"
    pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=data_type)

    if cpu_flag:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cuda")

    if compile_flag:
        pipe.transformer = torch.compile(pipe.transformer)
    if compile_flag_all:
        pipe.vae = torch.compile(pipe.vae)
        pipe.text_encoder = torch.compile(pipe.text_encoder)
        pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)

    # Inference only: no parameter of any component needs gradients.
    for component in (pipe.vae, pipe.text_encoder_2, pipe.text_encoder, pipe.transformer):
        component.requires_grad_(False)

    # Forces the guidance branch in handle_guidance*/pred_noise*.
    # NOTE(review): presumably only valid for the guidance-distilled "dev"
    # checkpoint; forcing it on "schnell" may break the transformer forward — confirm.
    pipe.transformer.config.guidance_embeds = True
    return pipe

In [5]:
def get_embeds(pipe, prompt, num_images):
    """Encode ``prompt`` with the pipeline's text encoders, then park them on the CPU.

    Args:
        pipe: A FLUX pipeline exposing ``encode_prompt``, ``text_encoder`` and
            ``text_encoder_2``.
        prompt: Prompt text (or list of prompts) to encode.
        num_images: Value passed as ``num_images_per_prompt``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds, text_ids)`` as returned by
        ``pipe.encode_prompt``.
    """
    device = pipe._execution_device
    # Both encoders are moved back to the CPU below, so both must be brought to
    # `device` here. Moving only text_encoder left text_encoder_2 on the CPU on
    # every call after the first when no offload hooks move it automatically
    # (e.g. initial_pipe(..., cpu_flag=False)).
    pipe.text_encoder.to(device)
    pipe.text_encoder_2.to(device)
    try:
        prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            device=device,
            num_images_per_prompt=num_images,
            lora_scale=None,
        )
    finally:
        # Free GPU memory for the transformer even if encoding failed.
        pipe.text_encoder.cpu()
        pipe.text_encoder_2.cpu()
        torch.cuda.empty_cache()
    return prompt_embeds, pooled_prompt_embeds, text_ids

In [6]:
def prepare_latents(pipe,num_images,generator,model_dtype):
    """Sample initial latents at the pipeline's default square resolution.

    Args:
        pipe: A FLUX pipeline.
        num_images: Batch size passed to ``pipe.prepare_latents``.
        generator: RNG forwarded to ``pipe.prepare_latents``.
        model_dtype: dtype for the latents.

    Returns:
        ``(latents, latent_image_ids)`` from ``pipe.prepare_latents``.
    """
    side = pipe.default_sample_size * pipe.vae_scale_factor
    # Transformer input channels divided by 4 (presumably undoing 2x2 latent packing — confirm).
    latent_channels = pipe.transformer.config.in_channels // 4
    positional_args = (
        num_images,
        latent_channels,
        side,  # height
        side,  # width
        model_dtype,
        pipe._execution_device,
        generator,
    )
    latents, latent_image_ids = pipe.prepare_latents(*positional_args, latents=None)
    return latents, latent_image_ids

In [None]:
def prepare_timesteps(pipe, num_inference_steps, latents):
    """Configure ``pipe.scheduler`` using the sequence length of ``latents``.

    Same as ``prepare_timesteps_modified`` with ``cur_len=latents.shape[1]``.

    Returns:
        ``(timesteps, num_inference_steps)`` from ``retrieve_timesteps``.
    """
    return prepare_timesteps_modified(pipe, num_inference_steps, cur_len=latents.shape[1])


def prepare_timesteps_modified(pipe, num_inference_steps, cur_len=4096):
    """Configure ``pipe.scheduler`` for an explicitly given sequence length.

    Args:
        pipe: A FLUX pipeline whose scheduler accepts ``sigmas`` and ``mu``.
        num_inference_steps: Number of denoising steps (must be >= 1).
        cur_len: Packed latent sequence length used to compute the shift ``mu``.

    Returns:
        ``(timesteps, num_inference_steps)`` from ``retrieve_timesteps``. As a
        side effect, ``pipe._num_timesteps`` is set to ``len(timesteps)``.

    Raises:
        ValueError: if ``num_inference_steps < 1`` (previously a bare
            ZeroDivisionError from ``1 / 0``).
    """
    if num_inference_steps < 1:
        raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")

    device = pipe._execution_device
    # Linearly spaced sigmas from 1.0 down to 1/N (N values).
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    scheduler_config = pipe.scheduler.config
    mu = calculate_shift(
        cur_len,
        scheduler_config.get("base_image_seq_len", 256),
        scheduler_config.get("max_image_seq_len", 4096),
        scheduler_config.get("base_shift", 0.5),
        scheduler_config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    pipe._num_timesteps = len(timesteps)
    return timesteps, num_inference_steps

In [8]:
def handle_guidance(pipe, guidance_scale, latents):
    """Build the guidance tensor for a batch the size of ``latents.shape[0]``.

    Same as ``handle_guidance_modified(pipe, guidance_scale, latents.shape[0])``.
    """
    return handle_guidance_modified(pipe, guidance_scale, latents.shape[0])


def handle_guidance_modified(pipe, guidance_scale, len_images):
    """Build the guidance tensor for ``len_images`` samples.

    Args:
        pipe: A FLUX pipeline.
        guidance_scale: Guidance value repeated for every sample.
        len_images: Batch size.

    Returns:
        A float32 tensor of shape ``(len_images,)`` on ``pipe._execution_device``
        filled with ``guidance_scale`` (an expanded view of a 1-element tensor),
        or ``None`` when ``pipe.transformer.config.guidance_embeds`` is falsy.
    """
    if not pipe.transformer.config.guidance_embeds:
        return None
    guidance = torch.full([1], guidance_scale, device=pipe._execution_device, dtype=torch.float32)
    return guidance.expand(len_images)

In [9]:
def pred_noise(pipe,latents,timestep,guidance,pooled_prompt_embeds,prompt_embeds,text_ids,latent_image_ids):
    """Run one gradient-free forward pass of ``pipe.transformer``.

    ``timestep`` is divided by 1000 before being passed in. Returns the first
    element of the transformer's tuple output.
    """
    with torch.no_grad():
        transformer_inputs = {
            "hidden_states": latents,
            "timestep": timestep / 1000,
            "guidance": guidance,
            "pooled_projections": pooled_prompt_embeds,
            "encoder_hidden_states": prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": latent_image_ids,
            "joint_attention_kwargs": {},
            "return_dict": False,
        }
        outputs = pipe.transformer(**transformer_inputs)
        return outputs[0]

In [10]:
def _resolve_guidance(guidance, reference):
    """Turn a plain number into a guidance tensor shaped like ``reference``; pass tensors/None through."""
    if isinstance(guidance, (int, float)):
        return torch.full_like(reference, float(guidance))
    return guidance


def _run_transformer(model, latents, timestep, guidance, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids):
    """One gradient-free forward pass of ``model``; returns the first output element."""
    with torch.no_grad():
        return model(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs={},
            return_dict=False,
        )[0]


def _track_guidance_gap(model, latents, timestep, guidance1, pooled_prompt_embeds, prompt_embeds, text_ids,
                        latent_image_ids, guidance2):
    """Predict with ``guidance1`` and ``guidance2``; return (pred at guidance1, normalised gap on CPU)."""
    guidance2 = _resolve_guidance(guidance2, guidance1)
    noise_pred1 = _run_transformer(model, latents, timestep, guidance1,
                                   pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids)
    noise_pred2 = _run_transformer(model, latents, timestep, guidance2,
                                   pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids)
    # NOTE(review): the gap is divided by guidance1[0] only, not by
    # (guidance1 - guidance2); kept unchanged so saved results stay comparable.
    noise_gap_origion = (noise_pred1 - noise_pred2) / guidance1[0]
    return noise_pred1, noise_gap_origion.cpu()


def _compiled_transformer(pipe):
    """Return a ``torch.compile``-wrapped transformer, compiling at most once per transformer."""
    transformer = pipe.transformer
    # Already compiled (e.g. initial_pipe(compile_flag=True)): don't wrap it a second time.
    if hasattr(transformer, "_orig_mod"):
        return transformer
    cached = getattr(pipe, "_tracked_compiled_transformer", None)
    if cached is None or getattr(cached, "_orig_mod", None) is not transformer:
        cached = torch.compile(transformer)
        pipe._tracked_compiled_transformer = cached
    return cached


def pred_noise_with_track(pipe,latents,timestep,guidance1,pooled_prompt_embeds,prompt_embeds,text_ids,latent_image_ids,guidance2=1):
    """Run the transformer at two guidance values and return the prediction plus their gap.

    Args:
        pipe: A FLUX pipeline.
        latents, timestep, pooled_prompt_embeds, prompt_embeds, text_ids,
        latent_image_ids: Forwarded to the transformer (``timestep`` is divided by 1000).
        guidance1: Guidance tensor for the main prediction; ``guidance1[0]`` is the normaliser.
        guidance2: Second guidance value, as a tensor or a plain number. A number
            (including the default ``1``) is broadcast to ``guidance1``'s shape,
            dtype and device. Previously the raw int was passed straight to the
            transformer as ``guidance``.

    Returns:
        ``(noise_pred1, gap)`` where ``gap = (pred(guidance1) - pred(guidance2)) / guidance1[0]``
        moved to the CPU.
    """
    return _track_guidance_gap(pipe.transformer, latents, timestep, guidance1,
                               pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance2)


def pred_noise_with_track_modified(pipe,latents,timestep,guidance1,pooled_prompt_embeds,prompt_embeds,text_ids,latent_image_ids,guidance2=1):
    """Same as ``pred_noise_with_track`` but runs a ``torch.compile``d transformer.

    The compiled wrapper is created once and reused across calls. Previously a
    new wrapper was created on every denoising step, and an already-compiled
    transformer was wrapped again.
    """
    return _track_guidance_gap(_compiled_transformer(pipe), latents, timestep, guidance1,
                               pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance2)

In [11]:
def denoise_cur(pipe,noise_pred, t, latents):
    """Advance ``latents`` by one ``pipe.scheduler.step`` and return the result.

    If the scheduler changes the dtype, the result is cast back to the input
    dtype only when the MPS backend is available. On other backends the
    scheduler's dtype is kept.
    """
    input_dtype = latents.dtype
    stepped = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    if stepped.dtype != input_dtype and torch.backends.mps.is_available():
        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
        stepped = stepped.to(input_dtype)
    return stepped

In [12]:
def decode_image(pipe,latents,offload_flag=True):
    """Decode packed latents into PIL images, then release model memory.

    Args:
        pipe: A FLUX pipeline.
        latents: Packed latents at the pipeline's default square resolution.
        offload_flag: If True, call ``pipe.maybe_free_model_hooks()``. Otherwise
            move the VAE, both text encoders and the transformer to the CPU
            and empty the CUDA cache.

    Returns:
        The output of ``pipe.image_processor.postprocess(..., output_type="pil")``.
    """
    pipe._current_timestep = None
    side = pipe.default_sample_size * pipe.vae_scale_factor

    unpacked = pipe._unpack_latents(latents, side, side, pipe.vae_scale_factor)
    # Undo the VAE latent normalisation before decoding.
    vae_cfg = pipe.vae.config
    vae_input = unpacked / vae_cfg.scaling_factor + vae_cfg.shift_factor

    decoded = pipe.vae.decode(vae_input, return_dict=False)[0]
    images = pipe.image_processor.postprocess(decoded, output_type="pil")

    if not offload_flag:
        for module in (pipe.vae, pipe.text_encoder_2, pipe.text_encoder, pipe.transformer):
            module.cpu()
        torch.cuda.empty_cache()
    else:
        pipe.maybe_free_model_hooks()

    return images

"A city skyline at sunset with clouds forming the words 'Together we rise, apart we fall. Embrace unity!'"

In [13]:
# Root folder for this experiment's outputs (trial / prompt sub-folder).
source_dir = "/".join(("guidance_nv_track_trial2", "prompt2"))

In [None]:
# ---- Experiment configuration ------------------------------------------------
guidance_scale = 6
cur_path = "dev-guidance{}".format(guidance_scale)
num_images_per_prompt = 1
num_inference_steps = 28
model_dtype = torch.float16
total_images = 100
prompt = "A city skyline at sunset with clouds forming the words 'Together we rise, apart we fall. Embrace unity!'"

# images[0].save below writes into this folder, so create it before the long run starts.
output_dir = os.path.join(source_dir, cur_path)
os.makedirs(output_dir, exist_ok=True)

# NOTE: despite its name, `pipe_schnell` holds the FLUX.1-dev pipeline ("dev").
pipe_schnell = initial_pipe("dev", data_type=model_dtype, cpu_flag=True, compile_flag=True, compile_flag_all=False)

# The prompt and guidance values are identical for every image: encode/build them once
# instead of re-running the text encoders on each of the `total_images` iterations.
prompt_embeds, pooled_prompt_embeds, text_ids = get_embeds(pipe_schnell, prompt, num_images=num_images_per_prompt)
guidance = handle_guidance_modified(pipe_schnell, guidance_scale, num_images_per_prompt)
guidance_2 = handle_guidance_modified(pipe_schnell, 1, num_images_per_prompt)

# One stacked per-step tensor of guidance gaps per generated image.
nv_gap_store1 = []

for num in tqdm(range(total_images)):
    seed = 42 + num
    generator = torch.Generator("cpu").manual_seed(seed)
    latents, latent_image_ids = prepare_latents(pipe_schnell, num_images_per_prompt, generator, model_dtype)
    timesteps, num_inference_steps = prepare_timesteps_modified(pipe_schnell, num_inference_steps, cur_len=4096)

    cur_nv_store = []
    for t in timesteps:
        pipe_schnell._current_timestep = t
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred, nv_gap_origin = pred_noise_with_track_modified(
            pipe_schnell, latents, timestep, guidance,
            pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance_2,
        )
        latents = denoise_cur(pipe_schnell, noise_pred, t, latents)

        cur_nv_store.append(nv_gap_origin.squeeze(0))
        del noise_pred
        torch.cuda.empty_cache()

    images = decode_image(pipe_schnell, latents, offload_flag=False)
    image_path = os.path.join(output_dir, "dev-guidance{}-{}.png".format(guidance_scale, num))
    images[0].save(image_path)
    nv_gap_store1.append(torch.stack(cur_nv_store))

    del latents, latent_image_ids, timesteps
    torch.cuda.empty_cache()

# Per-prompt tensors are no longer needed once all images are generated.
del prompt_embeds, pooled_prompt_embeds, text_ids, guidance, guidance_2

nv_gap_1 = torch.stack(nv_gap_store1)
nv_path = os.path.join(output_dir, "dev-guidance{}.pth".format(guidance_scale))
torch.save(nv_gap_1, nv_path)
torch.cuda.empty_cache()