# Local Manga AI

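This notebook launches a fully local pipeline: **Qwen2.5 → SDXL → Page Composer** and can optionally expose a public HTTPS URL via **Cloudflare Tunnel**. Note: the cells below currently run the Qwen-only workflow (the SDXL sanity checks are skipped), and Cloudflare Tunnel is not needed when running on Lightning AI.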

Model folders:
- `models/qwen2.5/`
- `models/sdxl/`

In [None]:
import os
import warnings

# Silence the Hugging Face Hub deprecation notices about `resume_download`
# and `local_dir_use_symlinks` (matched by message pattern, UserWarning only).
for _hf_warning_pattern in (".*resume_download.*", ".*local_dir_use_symlinks.*"):
    warnings.filterwarnings("ignore", message=_hf_warning_pattern, category=UserWarning)

# Runtime settings: this cell must run BEFORE any torch/diffusers model is loaded.
# setdefault leaves a value that is already present in the environment untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# On Lightning AI, provide MANGA_AI_HF_TOKEN via a Secret/Env Var; never hardcode tokens in notebooks.
# os.environ["MANGA_AI_HF_TOKEN"] = "..."

# Report the effective configuration (only whether a token exists, never its value).
print("PYTORCH_CUDA_ALLOC_CONF=", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
print("HF token present:", bool(os.environ.get("MANGA_AI_HF_TOKEN")))
print("HF deprecation warnings suppressed")


In [None]:
import sys
from pathlib import Path

# Locate the project root: the folder holding both ./scripts and ./models.
# The working directory is tried first, then its parent.
CWD = Path.cwd().resolve()
ROOT = CWD

if not all((ROOT / name).exists() for name in ("scripts", "models")):
    if all((ROOT.parent / name).exists() for name in ("scripts", "models")):
        ROOT = ROOT.parent.resolve()

# Neither candidate has the expected layout: stop with an actionable message.
if any(not (ROOT / name).exists() for name in ("scripts", "models")):
    raise RuntimeError(
        "Could not locate project ROOT. Expected folders 'scripts' and 'models' in the working directory (or its parent). "
        f"CWD={CWD}"
    )

# Put ROOT at the front of sys.path so imports like `from scripts...` resolve.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

print("CWD:", CWD)
print("ROOT:", ROOT)
print("Python:", sys.version)


In [None]:
import os
from pathlib import Path

from scripts.model_downloader import ensure_models_downloaded

# Lightning AI tip: use a persistent path if your workspace provides one.
# You can override the default (<ROOT>/models) with MANGA_AI_MODELS_DIR.
MODELS_DIR = Path(os.environ.get("MANGA_AI_MODELS_DIR", str(Path(ROOT) / "models"))).resolve()
QWEN_DIR = MODELS_DIR / "qwen2.5"
QWEN_DIR.mkdir(parents=True, exist_ok=True)
print("Models dir:", MODELS_DIR)
print("Qwen dir:", QWEN_DIR)

# Keep the try body limited to the call that can fail; the success message
# lives in `else` so it is only printed when no exception was raised.
try:
    ensure_models_downloaded(
        qwen_dir=QWEN_DIR,
        sdxl_dir=MODELS_DIR / "sdxl",  # unused by downloader (kept for compatibility)
        hf_token=os.environ.get("MANGA_AI_HF_TOKEN"),
    )
except Exception as e:
    # Chain explicitly with `from e` so the traceback shows the original
    # failure as the direct cause of this more actionable error.
    raise RuntimeError(
        "Qwen model download/setup failed. "
        "If the model is gated for you, accept the license on HuggingFace and set MANGA_AI_HF_TOKEN in Lightning Secrets. "
        f"Original error: {e}"
    ) from e
else:
    print("Qwen model is present (downloaded if needed).")


In [None]:
# Quick sanity check (doesn't load the full ML pipelines).
# Raise explicitly instead of using `assert`: assert statements are removed
# when Python runs with -O / PYTHONOPTIMIZE, which would let a missing model
# config pass this cell silently.
if not (QWEN_DIR / "config.json").exists():
    raise FileNotFoundError(f"Qwen config.json not found (expected at {QWEN_DIR / 'config.json'})")

print("Sanity checks passed.")


In [None]:
# (Deprecated) The SDXL sanity checks were removed when Animagine/SDXL was dropped from the Qwen-only workflow.
# This cell is intentionally kept as a no-op so running the notebook top-to-bottom doesn't fail.
print("Skipping SDXL sanity checks (Animagine removed).")


## Launch Web UI (Lightning AI)

Running the next cell will:
- Start Gradio inside the Lightning runtime
- Let Lightning expose it as a public URL automatically

Notes:
- The app binds to `0.0.0.0` and listens on the port given by the `PORT` environment variable, falling back to `LIGHTNING_PORT` and then to `7860`.
- You do **not** need Cloudflare Tunnel on Lightning.


In [None]:
import os
from typing import Any, List, Tuple

import gradio as gr
import torch

from scripts.storyboard import load_qwen_backend

# Lightning AI: bind to 0.0.0.0 and use a predictable port.
HOST = "0.0.0.0"
# Port lookup order: PORT, then LIGHTNING_PORT, then 7860.
# Chaining with `or` (instead of nested .get defaults) also skips variables
# that are set but empty, which would otherwise make int("") raise ValueError.
PORT = int(os.environ.get("PORT") or os.environ.get("LIGHTNING_PORT") or "7860")

# Pick the model dtype: float16 when CUDA is available, float32 otherwise.
if torch.cuda.is_available():
    _dtype = torch.float16
else:
    _dtype = torch.float32

# Load the Qwen backend once at cell execution; the chat handlers reuse it.
backend = load_qwen_backend(model_dir=QWEN_DIR, dtype=_dtype)


def _build_chat_input(tokenizer: Any, messages: List[dict]) -> torch.Tensor:
    """Turn a list of chat messages into the tokenizer's ``input_ids``.

    When the tokenizer offers ``apply_chat_template`` the prompt text is
    rendered through it (untokenized, with the generation prompt appended).
    Otherwise the messages are flattened to one ``role: content`` line each.

    Args:
        tokenizer: Callable tokenizer; called as ``tokenizer(text, return_tensors="pt")``.
        messages: Dicts carrying ``"role"`` and ``"content"`` keys.

    Returns:
        The ``input_ids`` attribute of the tokenizer's output.
    """
    if not hasattr(tokenizer, "apply_chat_template"):
        # Plain-text fallback for tokenizers without chat-template support.
        prompt = "\n".join(f"{turn['role']}: {turn['content']}" for turn in messages)
    else:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    encoded = tokenizer(prompt, return_tensors="pt")
    return encoded.input_ids


@torch.inference_mode()
def _chat_generate(
    history: List[Tuple[str, str]],
    user_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate one assistant reply and return the updated chat state.

    Args:
        history: Earlier ``(user, assistant)`` turns; ``None`` or empty starts
            a fresh conversation.
        user_message: New user text. It is stripped once, and the stripped
            text is both sent to the model and stored in the history.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature; ``0`` (or any falsy value) turns
            sampling off.
        top_p: Nucleus-sampling threshold; only forwarded when sampling is on.

    Returns:
        A ``(history, "")`` pair: the history extended with the new
        ``(user, reply)`` turn, and an empty string.
    """
    tokenizer = backend.tokenizer
    model = backend.model

    prior_turns = history or []
    # Normalize once so the model input and the stored history agree.
    prompt = user_message.strip()

    system = "You are Qwen, a helpful assistant for manga story development."

    messages: List[dict] = [{"role": "system", "content": system}]
    for u, a in prior_turns:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": prompt})

    input_ids = _build_chat_input(tokenizer, messages).to(model.device)

    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    # A tokenizer may expose `pad_token_id` with the value None; a getattr
    # default would not catch that, so fall back to EOS explicitly.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    do_sample = bool(temperature and float(temperature) > 0)
    gen_kwargs = {
        "input_ids": input_ids,
        # Single unpadded prompt, so every position is a real token.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "temperature": float(temperature) if do_sample else None,
        "top_p": float(top_p) if do_sample else None,
        "repetition_penalty": 1.05,
        "eos_token_id": eos_token_id,
        "pad_token_id": pad_token_id,
    }
    # Drop unset options so generate() falls back to its own defaults for them.
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    out = model.generate(**gen_kwargs)
    # Decode only the tokens that follow the prompt.
    decoded = tokenizer.decode(out[0][input_ids.shape[-1] :], skip_special_tokens=True).strip()

    return prior_turns + [(prompt, decoded)], ""


def _send(history, message, max_new_tokens, temperature, top_p):
    """Forward a non-blank message to the model; leave the chat unchanged for blank input."""
    text = "" if message is None else str(message)
    if not text.strip():
        return history, ""
    return _chat_generate(history, text, max_new_tokens, temperature, top_p)


# Build the chat UI. Components are created in the order they should appear.
with gr.Blocks(title="Qwen Chat") as demo:
    gr.Markdown("# Qwen Chat\nChat with your local Qwen model")

    chatbot = gr.Chatbot(height=520)
    msg = gr.Textbox(label="Message", lines=3)

    with gr.Row():
        send = gr.Button("Send")
        clear = gr.Button("Clear")

    max_new_tokens = gr.Slider(128, 2048, value=512, step=32, label="Max new tokens")
    temperature = gr.Slider(0.0, 1.2, value=0.4, step=0.05, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # The Send button and pressing Enter in the textbox share the same handler
    # and wiring: the outputs update the chatbot and reset the message box.
    for _trigger in (send.click, msg.submit):
        _trigger(
            _send,
            inputs=[chatbot, msg, max_new_tokens, temperature, top_p],
            outputs=[chatbot, msg],
        )
    clear.click(lambda: [], outputs=[chatbot])

print(f"Starting Gradio on {HOST}:{PORT} (Lightning will expose this as a public URL)")
demo.launch(
    server_name=HOST,
    server_port=PORT,
    share=False,
    inbrowser=False,
    prevent_thread_lock=True,
)
