Run this code on an A100 GPU.

In [None]:
!pip -q install --upgrade pip
!pip -q install git+https://github.com/huggingface/diffusers
!pip -q install git+https://github.com/huggingface/transformers accelerate
!pip -q install gradio pillow safetensors sentencepiece bitsandbytes qwen-vl-utils

In [None]:
import gc
import random
import psutil
import torch
from PIL import Image

def mem(tag=""):
    """Print a one-line memory snapshot prefixed with ``[tag]``.

    Always reports used host RAM (via ``psutil``). When CUDA is available the
    line also carries the allocated and reserved VRAM figures reported by
    ``torch.cuda``; otherwise it says CUDA is unavailable. All values are
    printed in units of 1e9 bytes with one decimal place.
    """
    bytes_per_gb = 1e9
    host_used = psutil.virtual_memory().used / bytes_per_gb

    # Guard clause: nothing GPU-related to report without CUDA.
    if not torch.cuda.is_available():
        print(f"[{tag}] RAM={host_used:.1f}GB | CUDA unavailable")
        return

    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"[{tag}] RAM={host_used:.1f}GB | VRAM alloc={allocated:.1f}GB reserved={reserved:.1f}GB")

def cleanup(tag="cleanup"):
    """Run a garbage-collection pass, empty the CUDA cache, then log memory.

    ``tag`` is forwarded to ``mem`` so the printed line identifies the call
    site. The CUDA cache is only touched when CUDA is available.
    """
    gc.collect()

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Hand cached, currently-unused blocks back to the driver.
        torch.cuda.empty_cache()

    mem(tag)

# Choose the compute device and its matching default dtype once, up front:
# bfloat16 on CUDA, float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

print("DEVICE:", DEVICE, "| DTYPE:", DTYPE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# Let float32 matmuls use TF32 where the hardware supports it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Log the memory baseline before any model is loaded.
cleanup("startup")


# ---------------------------
# Presets (platform sizes)
# ---------------------------
# Maps the label shown in the UI dropdown to the output size as a
# (width, height) tuple in pixels; generate_poster unpacks it in that order.
# The first four entries are smaller platform sizes, the "(hi-res)" entries
# are larger variants of common aspect ratios.
ASPECT_PRESETS = {
    "Instagram Post (1:1) — 768×768": (768, 768),
    "Instagram Story (9:16) — 768×1344": (768, 1344),
    "Banner (16:9) — 1024×576": (1024, 576),
    "Poster (3:4) — 832×1104": (832, 1104),
    "1:1 — 1328×1328 (hi-res)": (1328, 1328),
    "9:16 — 928×1664 (hi-res)": (928, 1664),
    "16:9 — 1664×928 (hi-res)": (1664, 928),
    "3:4 — 1104×1472 (hi-res)": (1104, 1472),
    "4:3 — 1472×1104 (hi-res)": (1472, 1104),
}

# Default negative prompt pre-filled in the UI's "Negative prompt" textbox.
# It lists image-quality and text-rendering failure modes to steer away from.
DEFAULT_NEG = (
    "low resolution, blurry, jpeg artifacts, messy layout, "
    "distorted text, misspellings, gibberish letters, watermark, logo"
)


# ---------------------------
# 3) Step A: Draft -> Brief (VL)
# ---------------------------
def describe_draft_image(draft_pil: Image.Image, language: str = "English") -> str:
    """
    Turn a rough poster draft into a short, sectioned design brief.

    Loads Qwen2.5-VL-3B-Instruct in 4-bit, asks it for a design brief, then
    drops every reference to the model and calls ``cleanup``. The unload runs
    in a ``finally`` block, so it also happens when loading or generation
    fails and the model does not linger next to the T2I pipeline.

    Args:
        draft_pil: The draft poster. It is converted to RGB and downscaled to
            fit within 1024x1024 before being sent to the model.
        language: Language the brief should be written in; interpolated into
            the instruction prompt.

    Returns:
        Only the model's reply (prompt tokens are cut off before decoding),
        with surrounding whitespace stripped.

    Raises:
        RuntimeError: If ``DEVICE`` is not ``"cuda"``.
    """
    # Fail fast on CPU, before the heavy imports and the model download.
    if DEVICE != "cuda":
        raise RuntimeError("VL description on CPU will be too slow / RAM heavy. Use GPU.")

    from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
    from qwen_vl_utils import process_vision_info

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    cleanup("before VL load")

    # Pre-bind everything the finally-block deletes so `del` never hits an
    # unbound name when an early step raises.
    vl_model = vl_processor = inputs = out = None
    try:
        bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

        vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=bnb_cfg,
            device_map={"": 0},
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )
        vl_processor = AutoProcessor.from_pretrained(model_id)

        mem("after VL load")

        # convert() yields a new image, so thumbnail() (in-place) resizes that
        # new image rather than the object passed in by the caller.
        img = draft_pil.convert("RGB")
        img.thumbnail((1024, 1024))

        prompt = f"""
You are a senior creative director. Analyze the uploaded poster draft.

Return a concise, usable design brief in {language}.
Return exactly these sections:
- Layout:
- Style:
- Color palette:
- Typography:
- Keep:
- Improve:
- Notes:

Do NOT invent brand logos or copyrighted slogans.
Be concrete (grid, margins, hierarchy, spacing).
""".strip()

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = vl_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        out = vl_model.generate(**inputs, max_new_tokens=220)

        # generate() returns prompt + continuation. Slice off the prompt so
        # the brief does not echo the chat template and the instructions
        # (it is pasted verbatim into the image prompt and shown in the UI).
        prompt_len = inputs["input_ids"].shape[1]
        reply_ids = out[:, prompt_len:]
        brief = vl_processor.batch_decode(reply_ids, skip_special_tokens=True)[0].strip()
        del reply_ids
    finally:
        # Drop all references, then collect and empty the CUDA cache, whether
        # or not the steps above succeeded.
        del vl_model, vl_processor, inputs, out
        cleanup("after VL unload")

    return brief


# ---------------------------
# 4) Step B: Brief + product info -> Prompt
# ---------------------------
def build_final_prompt(
    product_name: str,
    product_desc: str,
    offer: str,
    price: str,
    cta: str,
    benefits_text: str,
    tone: str,
    style_keywords: str,
    draft_brief: str | None,
    language: str,
) -> str:
    brief_block = f"\n\nDRAFT-BASED DESIGN BRIEF:\n{draft_brief}\n" if draft_brief else ""
    return f"""
Create a high-converting e-commerce promotional poster in {language}. Clean grid, strong hierarchy.

- Name: "{product_name}"
- Description: "{product_desc}"
- Offer headline (exact): "{offer}"
- Price (exact): "{price}"
- CTA button text (exact): "{cta}"
- Benefits (use these exact phrases, no typos):
{benefits_text}

- Tone: {tone}
- Style keywords: {style_keywords}

- Text must be legible and correctly spelled.
- No extra words, no fake prices, no random letters.
- Align to a neat grid with consistent margins.

{brief_block}

Output: premium, realistic lighting, clean composition, readable typography.
""".strip()


# ---------------------------
# 5) Step C: Generate Poster (T2I)
# ---------------------------
from diffusers import QwenImagePipeline

def load_t2i_pipeline():
    """Load the Qwen-Image-2512 text-to-image pipeline and return it.

    The pipeline is loaded in ``DTYPE`` and moved to ``DEVICE`` (the same
    module-level settings the rest of the file uses, instead of a hard-coded
    ``"cuda"``). If the pipeline exposes a ``vae`` it is cast to float32 and
    tiling is requested on a best-effort basis. Memory is logged before and
    after loading.

    Returns:
        The ready-to-call ``QwenImagePipeline`` instance.
    """
    cleanup("before T2I load")

    pipe = QwenImagePipeline.from_pretrained(
        "Qwen/Qwen-Image-2512",
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to(DEVICE, dtype=DTYPE)

    vae = getattr(pipe, "vae", None)
    if vae is not None:
        # NOTE(review): the VAE is deliberately kept in float32 while the rest
        # of the pipeline runs in DTYPE, presumably for decode quality; confirm.
        vae.to(dtype=torch.float32)
        try:
            vae.enable_tiling()
        except Exception as exc:
            # Tiling is optional; keep going, but say why it is off instead of
            # swallowing the failure silently.
            print(f"[T2I] VAE tiling not enabled ({type(exc).__name__}: {exc})")

    mem("after T2I load")
    return pipe

# Load the pipeline once when this cell runs; generate_poster reuses this
# module-level instance for every request.
t2i = load_t2i_pipeline()

def generate_poster(
    prompt: str,
    negative_prompt: str,
    preset_name: str,
    steps: int,
    true_cfg_scale: float,
    seed: int,
):
    """Render one poster with the module-level ``t2i`` pipeline.

    Args:
        prompt: Text prompt for the image.
        negative_prompt: Negative prompt passed straight to the pipeline.
        preset_name: Key into ``ASPECT_PRESETS`` selecting (width, height).
        steps: Number of inference steps.
        true_cfg_scale: Value forwarded as the pipeline's ``true_cfg_scale``.
        seed: Seed for the ``torch.Generator`` used for sampling.

    Returns:
        The first image of the pipeline result.

    Raises:
        KeyError: If ``preset_name`` is not in ``ASPECT_PRESETS``.
    """
    width, height = ASPECT_PRESETS[preset_name]

    # Seeded generator on the active device for this request.
    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(int(seed))

    result = t2i(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        true_cfg_scale=float(true_cfg_scale),
        generator=rng,
    )
    return result.images[0]

In [None]:
# ---------------------------
# 6) Gradio UI
# ---------------------------
import gradio as gr

def run_pipeline(
    draft_img,
    use_draft_brief,
    platform_preset,
    product_name,
    product_desc,
    offer,
    price,
    cta,
    benefits,
    tone,
    style_keywords,
    language,
    negative_prompt,
    steps,
    true_cfg_scale,
    seed,
):
    """Gradio click handler: optional draft brief -> prompt -> poster.

    A missing or negative ``seed`` is replaced by a random one. When
    ``use_draft_brief`` is set, the uploaded draft is described by
    ``describe_draft_image`` and the brief is folded into the prompt.

    Returns:
        A 3-tuple ``(brief text or placeholder, final prompt, image)``
        matching the three UI outputs.

    Raises:
        gr.Error: If the draft brief is enabled but no image was uploaded.
    """
    # Resolve the seed: None / negative means "pick one at random".
    wants_random_seed = seed is None or int(seed) < 0
    seed = random.randint(0, 2**31 - 1) if wants_random_seed else int(seed)

    brief = None
    if use_draft_brief:
        if draft_img is None:
            raise gr.Error("You enabled 'Use draft brief' but did not upload an image.")
        rgb_draft = Image.fromarray(draft_img).convert("RGB")
        brief = describe_draft_image(rgb_draft, language=language)

    final_prompt = build_final_prompt(
        product_name=product_name,
        product_desc=product_desc,
        offer=offer,
        price=price,
        cta=cta,
        benefits_text=benefits,
        tone=tone,
        style_keywords=style_keywords,
        draft_brief=brief,
        language=language,
    )

    poster = generate_poster(
        prompt=final_prompt,
        negative_prompt=negative_prompt,
        preset_name=platform_preset,
        steps=int(steps),
        true_cfg_scale=float(true_cfg_scale),
        seed=seed,
    )

    # Show a placeholder in the brief box when no brief was produced.
    brief_text = brief if brief else "(draft brief disabled)"
    return brief_text, final_prompt, poster

# Declarative UI: a two-column layout with all inputs on the left and the
# three outputs (brief, final prompt, image) on the right.
with gr.Blocks(title="Qwen Poster Studio") as demo:
    gr.Markdown("## Qwen Poster Studio")
    gr.Markdown(
        "- Uses **Qwen2.5-VL 4-bit** only temporarily to extract a design brief, then unloads it.\n"
        "- Keeps **Qwen-Image-2512** loaded on GPU for repeated generations.\n"
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            # type="numpy" means run_pipeline receives the upload as an array
            # (or None when nothing was uploaded).
            draft = gr.Image(label="Upload rough draft poster (optional)", type="numpy")
            use_draft = gr.Checkbox(value=True, label="Use draft → design brief (loads VL briefly, then unloads)")

            # Choices are the ASPECT_PRESETS keys; the default must be one of them.
            platform = gr.Dropdown(
                choices=list(ASPECT_PRESETS.keys()),
                value="Instagram Post (1:1) — 768×768",
                label="Platform preset (aspect ratio / size)"
            )

            # Product copy that is interpolated into the final prompt.
            product_name = gr.Textbox(value="RÅSKOG Utility Cart", label="Product name")
            product_desc = gr.Textbox(value="Compact rolling cart for small spaces. Durable metal frame.", lines=2, label="Product description")
            offer = gr.Textbox(value="NEW YEAR SALE — 20% OFF", label="Offer headline (exact)")
            price = gr.Textbox(value="$29.99", label="Price (exact)")
            cta = gr.Textbox(value="Shop now", label="CTA button text (exact)")
            benefits = gr.Textbox(value="- Compact\n- Easy to move\n- Fits small spaces", lines=4, label="Benefits (exact phrases)")

            tone = gr.Dropdown(["Premium", "Minimal", "Bold", "Playful", "Tech"], value="Premium", label="Tone")
            style = gr.Textbox(value="Scandinavian minimal, clean grid, soft warm lighting, premium product photography", label="Style keywords")
            language = gr.Dropdown(["English", "中文"], value="English", label="Language")

            negative = gr.Textbox(value=DEFAULT_NEG, label="Negative prompt", lines=2)

            # Sampling controls forwarded to generate_poster.
            with gr.Row():
                steps = gr.Slider(10, 60, value=25, step=1, label="Steps")
                cfg = gr.Slider(1.0, 8.0, value=4.0, step=0.1, label="true_cfg_scale")

            # A negative seed makes run_pipeline draw a random one.
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")

        # Right column: outputs.
        with gr.Column(scale=1):
            brief_out = gr.Textbox(label="Draft-based design brief", lines=10)
            prompt_out = gr.Textbox(label="Final prompt sent to Qwen-Image-2512", lines=10)
            out_img = gr.Image(label="Output poster")

    # The order of `inputs` must match run_pipeline's positional parameters,
    # and `outputs` must match the 3-tuple it returns.
    btn.click(
        fn=run_pipeline,
        inputs=[
            draft, use_draft, platform,
            product_name, product_desc, offer, price, cta, benefits,
            tone, style, language,
            negative, steps, cfg, seed
        ],
        outputs=[brief_out, prompt_out, out_img]
    )

# Start the app: share=True requests a public Gradio link, debug=True keeps
# the cell attached so errors are printed in the notebook output.
demo.launch(share=True, debug=True)
