In [5]:
import torch
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
from diffusion.utils.config import SanaConfig, model_init_config
from diffusion.utils.logger import get_root_logger

from diffusion import DPMS_SDE
from app.sana_pipeline_tts import SanaPipeline

In [None]:
# Target GPU is hard-coded to CUDA index 4; adjust the index when running on a
# machine with a different GPU layout.
device = torch.device('cuda:4')
# Pipeline config YAML, given relative to the notebook's working directory.
config = "../configs/sana_config/1024ms/Sana_1600M_img1024.yaml"
# Checkpoint location. The "hf://" prefix presumably tells the loader to fetch
# the file from the Hugging Face Hub -- TODO confirm against SanaPipeline.
model_path = "hf://Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth"
# Construct the pipeline from the config on the chosen device, then hand it the
# checkpoint path via from_pretrained. `pipe` is the object later cells use.
pipe = SanaPipeline(config, device)
pipe.from_pretrained(model_path)

2025-08-04 13:38:45 - [1m[Sana][0m - INFO - Sampler flow_dpm-solver, flow_shift: 3.0
2025-08-04 13:38:45 - [1m[Sana][0m - INFO - Inference with torch.bfloat16, PAG guidance layer: [8]


[1m[AutoencoderDC] Loading model from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers[0m


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
    """Choose the guidance mode name from the PAG scale and attention type.

    Args:
        default_guidance_type: Accepted for interface compatibility. The two
            outcomes below are exhaustive, so this value is never returned.
        pag_scale: PAG guidance scale; PAG is only selected when it is > 1.0.
        attn_type: Attention type name; PAG is only selected for "linear".

    Returns:
        "classifier-free_PAG" when ``pag_scale > 1.0`` and
        ``attn_type == "linear"``, otherwise "classifier-free".
    """
    pag_enabled = pag_scale > 1.0 and attn_type == "linear"
    if pag_enabled:
        return "classifier-free_PAG"
    return "classifier-free"


def classify_height_width_bin(height: int, width: int, ratios: dict):
    """Snap a requested size to the bin whose aspect-ratio key is nearest.

    Args:
        height: Requested height.
        width: Requested width (must be non-zero; it is used as a divisor).
        ratios: Mapping whose keys are convertible with ``float()`` and whose
            values are indexable with the binned height at index 0 and the
            binned width at index 1.

    Returns:
        ``(binned_height, binned_width)`` as a tuple of ints. When two keys are
        equally close to ``height / width``, the one that comes first in the
        mapping's iteration order is used.
    """
    target_ratio = float(height / width)

    def distance_to_target(ratio_key):
        # Keys may be strings such as "1.0"; compare them numerically.
        return abs(float(ratio_key) - target_ratio)

    nearest_key = min(ratios.keys(), key=distance_to_target)
    binned_hw = ratios[nearest_key]
    binned_height = int(binned_hw[0])
    binned_width = int(binned_hw[1])
    return binned_height, binned_width

@torch.inference_mode()
def forward(
    self,
    prompt=None,
    height=1024,
    width=1024,
    negative_prompt="",
    num_inference_steps=20,
    guidance_scale=4.5,
    pag_guidance_scale=1.0,
    num_images_per_prompt=1,
    generator=torch.Generator().manual_seed(42),
    latents=None,
    use_resolution_binning=True,
):
    """Sample images for one or more prompts with the DPMS_SDE flow sampler.

    Args:
        prompt: A prompt string or a list of prompt strings. ``None`` is
            treated as a single empty prompt.
        height: Requested output height.
        width: Requested output width.
        negative_prompt: When non-empty, it is encoded and stored on
            ``self.null_caption_embs`` as the unconditional embedding.
        num_inference_steps: Number of sampler steps.
        guidance_scale: Classifier-free guidance scale passed to the sampler.
        pag_guidance_scale: PAG scale; also drives the guidance-type choice.
        num_images_per_prompt: Batch size generated for each prompt.
        generator: RNG used to draw the initial noise when ``latents`` is None.
            A CPU generator may be combined with a non-CPU ``self.device``; the
            noise is then drawn on CPU and moved to the device.
        latents: Optional pre-made initial latents; moved to ``self.device``
            and used instead of fresh noise (for every prompt).
        use_resolution_binning: Snap the size to ``self.base_ratios`` before
            sampling and resize/crop the result back to the requested size.

    Returns:
        For a single prompt, the decoded batch for that prompt (the same value
        the function returned before). For several prompts, the per-prompt
        batches concatenated along dim 0 (assumes decoded outputs are
        batch-first tensors -- TODO confirm against ``vae_decode``). For an
        empty prompt list, an empty list.

    Notes:
        * The default ``generator`` is created once, when the function is
          defined, so its state advances across calls that rely on the
          default. Pass a freshly seeded generator for repeatable output.
        * NOTE(review): a non-empty ``negative_prompt`` overwrites
          ``self.null_caption_embs`` and the override stays in effect for
          later calls made with ``negative_prompt=""``.
        * ``self.ori_height/ori_width``, ``self.height/width``,
          ``self.latent_size_h/latent_size_w`` and ``self.guidance_type`` are
          updated as a side effect of each call.
    """
    self.ori_height, self.ori_width = height, width
    if use_resolution_binning:
        self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
    else:
        self.height, self.width = height, width
    self.latent_size_h, self.latent_size_w = (
        self.height // self.config.vae.vae_downsample_rate,
        self.width // self.config.vae.vae_downsample_rate,
    )
    self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)

    # 1. pre-compute negative embedding
    if negative_prompt != "":
        null_caption_token = self.tokenizer(
            negative_prompt,
            max_length=self.max_sequence_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
            0
        ]

    if prompt is None:
        prompt = [""]
    prompt_list = prompt if isinstance(prompt, list) else [prompt]
    samples = []

    # The decorator already disables autograd for the whole call, so no nested
    # torch.no_grad() blocks are needed below.
    for single_prompt in prompt_list:
        # 2. data prepare: per-image size / aspect-ratio conditioning
        hw = torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
            num_images_per_prompt, 1
        )
        ar = torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1)
        # One cleaned copy of the prompt per requested image. A distinct name
        # is used so the list being iterated is never shadowed.
        batch_prompts = [
            prepare_prompt_ar(single_prompt, self.base_ratios, device=self.device, show=False)[0].strip()
            for _ in range(num_images_per_prompt)
        ]

        # 3. prepare text feature
        if not self.config.text_encoder.chi_prompt:
            max_length_all = self.config.text_encoder.model_max_length
            prompts_all = batch_prompts
        else:
            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
            prompts_all = [chi_prompt + text for text in batch_prompts]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            max_length_all = (
                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
            )  # magic number 2: [bos], [_]

        caption_token = self.tokenizer(
            prompts_all,
            max_length=max_length_all,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device=self.device)
        # Keep the first position plus the trailing (model_max_length - 1)
        # positions of the encoded sequence.
        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
            :, :, select_index
        ].to(self.weight_dtype)
        emb_masks = caption_token.attention_mask[:, select_index]
        null_y = self.null_caption_embs.repeat(len(batch_prompts), 1, 1)[:, None].to(self.weight_dtype)

        # 4. initial latents
        n = len(batch_prompts)
        if latents is None:
            noise_shape = (
                n,
                self.config.vae.vae_latent_dim,
                self.latent_size_h,
                self.latent_size_w,
            )
            target_device = torch.device(self.device)
            if generator is not None and generator.device.type == "cpu" and target_device.type != "cpu":
                # torch.randn rejects a CPU generator when asked to create the
                # tensor on another device type (this is the case for the
                # default generator on CUDA), so draw on CPU and move it.
                z = torch.randn(noise_shape, generator=generator, device="cpu").to(target_device)
            else:
                z = torch.randn(noise_shape, generator=generator, device=self.device)
        else:
            z = latents.to(self.device)
        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)

        # 5. sample
        scheduler = DPMS_SDE(
            self.model,
            condition=caption_embs,
            uncondition=null_y,
            guidance_type=self.guidance_type,
            cfg_scale=guidance_scale,
            pag_scale=pag_guidance_scale,
            pag_applied_layers=self.config.model.pag_applied_layers,
            model_type="flow",
            model_kwargs=model_kwargs,
            schedule="FLOW",
        )
        scheduler.register_progress_bar(self.progress_fn)
        sample = scheduler.sample(
            z,
            steps=num_inference_steps,
            order=2,
            skip_type="time_uniform_flow",
            method="multistep",
            flow_shift=self.flow_shift,
        )

        # 6. decode latents and undo resolution binning
        sample = sample.to(self.vae_dtype)
        sample = vae_decode(self.config.vae.vae_type, self.vae, sample)

        if use_resolution_binning:
            sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
        samples.append(sample)

    # A single prompt keeps returning its decoded batch unchanged. Previously
    # an early `return` inside the loop also dropped every prompt after the
    # first; now all prompts are rendered and stacked along the batch axis.
    if len(samples) == 1:
        return samples[0]
    if not samples:
        return samples
    return torch.cat(samples, dim=0)