In [1]:
import torch
from tqdm.notebook import tqdm
from diffusers import ZImagePipeline
from nunchaku import NunchakuZImageTransformer2DModel
from nunchaku.utils import get_precision
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import os
import random

In [2]:
# Scene descriptions fed to the image generator; each entry yields
# `img_per_prompt` background images in the generation loop.
prompts = [
    "Children playing in a park",
    "Patients sitting in a hospital waiting room",
    "Shoppers walking through wide supermarket aisles",
    "People standing on a train station platform",
    "Commuters waiting at a city bus stop",
    "Pedestrians walking on a wide city sidewalk",
    "Guests standing in a spacious hotel lobby",
    "Visitors walking through a museum gallery",
    "Students walking in a university corridor",
    "Customers waiting in line at a bank",
    "Employees walking in a modern office lobby",
    "Passengers walking in a subway station",
    # Typo fix: the original read "wise supermarket aisle".
    "Shoppers walking through a wide supermarket aisle",
]

In [3]:
# Use the GPU with bfloat16 when CUDA is available, otherwise CPU with float32.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
dtype = torch.bfloat16 if has_cuda else torch.float32
print(f"Using device: {device}, dtype: {dtype}")

# Choose an output directory that does not exist yet: try the base name first,
# then "<base>_1", "<base>_2", ... until a free path is found.
base_dir = 'dataset/backgrounds'
data_dir = base_dir
suffix = 0
while os.path.exists(data_dir):
    suffix += 1
    data_dir = f"{base_dir}_{suffix}"

# exist_ok=False makes the call fail loudly if the path appeared in the meantime.
os.makedirs(data_dir, exist_ok=False)

Using device: cuda, dtype: torch.bfloat16


In [4]:
def load_zit_pipeline() -> ZImagePipeline:
    """Build the Z-Image-Turbo pipeline around a Nunchaku transformer.

    The transformer checkpoint name is assembled from the precision string
    returned by ``get_precision()`` and a fixed rank of 256. The pipeline is
    created with the module-level ``dtype`` and has model CPU offload enabled
    before it is returned.

    Returns:
        The configured ``ZImagePipeline`` instance.
    """
    svd_rank = 256
    checkpoint = (
        f"nunchaku-tech/nunchaku-z-image-turbo/svdq-{get_precision()}_r{svd_rank}-z-image-turbo.safetensors"
    )
    nunchaku_transformer = NunchakuZImageTransformer2DModel.from_pretrained(checkpoint)

    zit = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        transformer=nunchaku_transformer,
        torch_dtype=dtype,
    )
    zit.enable_model_cpu_offload()
    return zit


def zit_generate(pipe, prompt: str, width: int = 1280, height: int = 720, seed=None) -> Image.Image:
    """Generate a single image for ``prompt`` using ``pipe``.

    Args:
        pipe: Pipeline object; it is called with keyword arguments and must
            return an object exposing an ``images`` sequence.
        prompt: Text description of the image to generate.
        width: Output width in pixels passed to the pipeline.
        height: Output height in pixels passed to the pipeline.
        seed: Integer seed for the torch generator. ``None`` (the default)
            draws a random seed in ``[0, 2**31 - 1]``. Any integer,
            including ``0``, is used as given.

    Returns:
        The first image of the pipeline output.
    """
    # Compare against None explicitly: a truthiness test would silently replace
    # a legitimate seed of 0 with a random one and break reproducibility.
    if seed is None:
        seed = random.randint(0, 2 ** 31 - 1)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Sampling settings are fixed here (8 steps, guidance disabled).
    result = pipe(
        prompt=prompt,
        num_inference_steps=8, guidance_scale=0.0,
        generator=generator, width=width, height=height,
    )
    return result.images[0]

In [5]:
# Release cached, currently unused CUDA memory before loading the model.
torch.cuda.empty_cache()
# Build the pipeline once; the generation loop reuses this instance.
pipeline = load_zit_pipeline()

svdq-int4_r256-z-image-turbo.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

quantization_config: {'method': 'svdquant', 'weight': {'dtype': 'int4', 'scale_dtype': None, 'group_size': 64}, 'activation': {'dtype': 'int4', 'scale_dtype': None, 'group_size': 64}, 'rank': 256, 'skip_refiners': False}, rank=256, skip_refiners=False


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Number of passes over the prompt list; each pass makes one image per prompt.
img_per_prompt = 1
total = img_per_prompt * len(prompts)

# Running index used as the output file name, so images are numbered 0..total-1.
counter = 0

with tqdm(total=total) as pbar:
    for _pass in range(img_per_prompt):
        for prompt_text in prompts:
            img = zit_generate(pipeline, prompt_text, width=1280, height=720)
            out_path = f'{data_dir}/{counter}.png'
            img.save(out_path)
            counter += 1
            pbar.update(1)

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]