wan2.1 transformer gguf load error #11088

Open
Passenger12138 opened this issue Mar 17, 2025 · 4 comments
Labels
bug Something isn't working

Comments

@Passenger12138

Passenger12138 commented Mar 17, 2025

Describe the bug

I am testing Wan2.1 image-to-video generation on an RTX 4090 using the Diffusers Wan2.1 model (https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers) together with City96's quantized GGUF checkpoint (https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/tree/main), following this documentation: https://huggingface.co/docs/diffusers/quantization/gguf. However, I hit an error while loading the GGUF model, and it appears to be downloading additional data that I don't need. Any suggestions to resolve this would be greatly appreciated.

[Image attachment]

Reproduction

import os
import torch
import numpy as np
from diffusers import (
    AutoencoderKLWan,
    WanImageToVideoPipeline,
    WanTransformer3DModel,
    UniPCMultistepScheduler,
    GGUFQuantizationConfig,
)
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

# Model path
model_id = "/share/haobang.geng/cache/Wan2.1-I2V-14B-480P-Diffusers"

# Load the image encoder
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)

# Load the VAE
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)

# Load the GGUF-quantized transformer
ckpt_path = "/share/haobang.geng/cache/Wan2.1-I2V-14B-480P-gguf/wan2.1-i2v-14b-480p-Q4_1.gguf"
transformer = WanTransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Configure the scheduler
scheduler = UniPCMultistepScheduler(
    prediction_type='flow_prediction',
    use_flow_sigmas=True,
    num_train_timesteps=1000,
    flow_shift=3.0  # flow_shift value for 480P
)

# Build the pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, 
    transformer=transformer, 
    vae=vae, 
    image_encoder=image_encoder, 
    torch_dtype=torch.bfloat16, 
    scheduler=scheduler
)

pipe.enable_model_cpu_offload()

# Load and resize the input image
image = load_image("/share/haobang.geng/dataset/character-imgs/1.jpeg")
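# Resize so that height * width stays close to max_area while preserving the
# aspect ratio, rounding both dimensions down to the stride required by the
# VAE scale factor and the transformer patch size.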
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

# Prompts
prompt = (
    "Anime thick painted cartoon illustration, a blonde girl gracefully lifted her wine glass and gently sipped a mouthful of red wine. "
    "She has delicate facial features, purple eyes shimmering with wisdom, and cheeks slightly red, appearing particularly charming. "
    "She was dressed in luxurious traditional attire, adorned with a golden headpiece, and the background was a blurry indoor scene with faintly visible wooden structures. "
    "Soft light and shadow effects create a classical and romantic atmosphere. Close up half body close-up perspective."
)
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
    "still picture, messy background, three legs, many people in the background, walking backwards"
)

# Run image-to-video generation
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=40,
).frames[0]

os.makedirs("/share/haobang.geng/code/wanx-baseline/results/check-wanmulti-framework/gguf", exist_ok=True)
# Export the video
export_to_video(
    output, 
    "/share/haobang.geng/code/wanx-baseline/results/check-wanmulti-framework/gguf/wanx-diffusers-Q4.mp4", 
    fps=15
)

Logs

System Info


  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.0
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.46.2
  • Accelerate version: 1.5.2
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
    NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@DN6 @a-r-r-o-w

Passenger12138 added the bug (Something isn't working) label on Mar 17, 2025
@hlky
Member

hlky commented Mar 17, 2025

This is working on my end; the config is downloaded from Wan-AI/Wan2.1-I2V-14B-480P-Diffusers. If you have connection issues reaching the Hub, try export HF_ENDPOINT=https://hf-mirror.com

import torch
from diffusers import (
    WanTransformer3DModel,
    GGUFQuantizationConfig,
)

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="city96/Wan2.1-I2V-14B-480P-gguf", filename="wan2.1-i2v-14b-480p-Q4_1.gguf"
)
transformer = WanTransformer3DModel.from_single_file(
    path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
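
If the goal is to avoid the extra download entirely, the config can also be read from a local copy of the Diffusers repo instead of the Hub. This is a minimal sketch, assuming from_single_file accepts the config/subfolder arguments for this model, and reusing the local paths from the reproduction above:

import torch
from diffusers import WanTransformer3DModel, GGUFQuantizationConfig

# Paths reused from the reproduction above; adjust to your setup.
model_id = "/share/haobang.geng/cache/Wan2.1-I2V-14B-480P-Diffusers"
ckpt_path = "/share/haobang.geng/cache/Wan2.1-I2V-14B-480P-gguf/wan2.1-i2v-14b-480p-Q4_1.gguf"

transformer = WanTransformer3DModel.from_single_file(
    ckpt_path,
    config=model_id,          # read the transformer config from the local checkout
    subfolder="transformer",  # the config lives in the "transformer" subfolder of the repo
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)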

@nitinmukesh

Hello @hlky

Does the diffusers Wan2.1 pipeline support GGUF for the text encoder?
https://huggingface.co/city96/umt5-xxl-encoder-gguf/tree/main

@hlky
Member

hlky commented Mar 17, 2025

@nitinmukesh not at the moment; UMT5 comes from transformers, so please request support there: https://github.com/huggingface/transformers/issues/new/choose

Here is a minimal reproduction you can share with transformers team.

import torch
from transformers import UMT5EncoderModel

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="city96/umt5-xxl-encoder-gguf", filename="umt5-xxl-encoder-Q8_0.gguf"
)
text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", # or any Hub path with the correct config
    subfolder="text_encoder",
    gguf_file=path,
    torch_dtype=torch.bfloat16,
)
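
Once that support lands in transformers, the quantized encoder would be passed into the pipeline like any other component. A hypothetical usage sketch (not working today, since transformers does not yet load this GGUF file):

import torch
from diffusers import WanImageToVideoPipeline

# text_encoder is the UMT5EncoderModel loaded above; passing it here would
# replace the pipeline's default text encoder with the quantized one.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    text_encoder=text_encoder,
    torch_dtype=torch.bfloat16,
)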

@nitinmukesh

Thank you.
