<a href="https://colab.research.google.com/github/Schwaldlander/DomainNameSuggest/blob/main/model_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook gives a quick, hands-on look at the model's outputs for a sample brief.

Domain Suggestion Demo (Colab)
==============================

- Loads the base model plus an optional LoRA adapter
- Generates strict-JSON domain suggestions for a given brief
- Runs spec checks: length / TLD / digits / hyphen / ASCII
- Displays a clean preview table

Tips:
- Set `ADAPTER = None` to run the raw base model.
- Set `IMPROVED_ADAPTER` to a second checkpoint to compare the two.

In [None]:
# ===========================


!pip -q install "transformers>=4.43" "peft>=0.12" "bitsandbytes>=0.43" pandas sentencepiece

import json, re, os, torch, pandas as pd
from typing import Dict, Any, List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ------------------------------------------------------------------
# Configuration: base model, LoRA adapter paths, decoding parameters
# ------------------------------------------------------------------

# Hugging Face id of the base model. build_prompt() emits ChatML-style
# <|im_start|>/<|im_end|> markup, so swap in a compatible model if changing this.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"

# Adapter checkpoint directories; set either one to None to skip it.
ADAPTER = "/content/checkpoints/baseline_qlora_fixed"   # baseline LoRA
IMPROVED_ADAPTER = "/content/checkpoints/dpo_v1"        # candidate LoRA to compare against the baseline

# Decoding settings forwarded to model.generate() by generate_json().
MAX_NEW_TOKENS = 380
TEMPERATURE = 0.6
TOP_P = 0.9
DO_SAMPLE = True

# --------------------------
# Briefs (same dataset schema)
# --------------------------
job_networking_brief = {
    "brief_id": "job_networking_001",
    "title": "Job Networking Platform",
    "language": "en",
    "script": "Latin",
    "tone": "professional, trustworthy, aspirational",
    "keywords": ["jobs","hire","connect","career","talent","work","network","match","growth"],
    "constraints": {
        "max_len": 14,
        "allowed_tlds": [".com",".io",".ai"],
        "forbid_digits": True,
        "forbid_hyphens": True,
        "ascii_only": True
    },
    "notes": "SaaS platform for job seekers & recruiters"
}

donut_shop_brief = {
    "brief_id": "donut_shop_001",
    "title": "Local Donut Shop",
    "language": "en",
    "script": "Latin",
    "tone": "friendly, cozy, welcoming",
    "keywords": ["donut","coffee","sweet","local"],
    "constraints": {
        "max_len": 12,
        "allowed_tlds": [".com",".shop",".cafe"],
        "forbid_digits": True,
        "forbid_hyphens": True,
        "ascii_only": True
    },
    "notes": "Neighborhood bakery + cafe"
}

# Pick which brief to demo
BRIEF = job_networking_brief  # or donut_shop_brief

# --------------------------
# Utilities
# --------------------------
def spec_checks(brief: Dict[str, Any], domain: str) -> Tuple[bool, List[str]]:
    """Check one suggested domain against the brief's constraints.

    The domain is split into a name part and a TLD. If the domain ends with
    one of the brief's ``allowed_tlds`` (compared case-insensitively), the
    longest such suffix is taken as the TLD, so multi-label TLDs such as
    ``.co.uk`` work. Otherwise the TLD is the text after the last dot.

    Args:
        brief: Brief dict. Its optional ``constraints`` mapping may set:
            ``max_len`` (default 12); ``allowed_tlds`` (default: any TLD);
            ``forbid_digits``, ``forbid_hyphens`` and ``ascii_only``
            (each default True).
        domain: Candidate domain, e.g. ``"jobmatch.com"``. Surrounding
            whitespace and a trailing dot are ignored.

    Returns:
        ``(ok, reasons)``. ``reasons`` lists every failed check, in this
        order: ``length_exceeded``, ``digits_forbidden``,
        ``hyphen_forbidden``, ``non_ascii``, ``tld_not_allowed``. It is
        ``["parse_failed"]`` when the domain is not a string or lacks a
        non-empty name and TLD.
    """
    if not isinstance(domain, str):
        return False, ["parse_failed"]
    # Surrounding whitespace and a trailing FQDN dot are formatting noise.
    domain = domain.strip().rstrip(".")

    constraints = brief.get("constraints") or {}
    max_len = constraints.get("max_len", 12)
    allowed = constraints.get("allowed_tlds", [])
    forbid_digits = constraints.get("forbid_digits", True)
    forbid_hyphens = constraints.get("forbid_hyphens", True)
    ascii_only = constraints.get("ascii_only", True)

    # Domain names are case-insensitive: normalise allowed TLDs to lower-case ".xyz" form.
    allowed_norm = [
        "." + t.strip().lower().lstrip(".")
        for t in allowed
        if t.strip().strip(".")
    ]

    # Prefer the longest allowed suffix so ".co.uk" wins over ".uk".
    name, tld = None, None
    for cand in sorted(allowed_norm, key=len, reverse=True):
        if len(domain) > len(cand) and domain[-len(cand):].lower() == cand:
            name, tld = domain[: -len(cand)], cand
            break
    if tld is None:
        name, dot, last = domain.rpartition(".")
        if not dot or not last:
            # No dot (or nothing after it): there is no TLD to check.
            return False, ["parse_failed"]
        tld = "." + last.lower()
    if not name:
        # e.g. ".com" — an empty name must not pass every check.
        return False, ["parse_failed"]

    reasons = []
    if max_len is not None and len(name) > max_len:
        reasons.append("length_exceeded")
    if forbid_digits and any(ch.isdigit() for ch in name):
        reasons.append("digits_forbidden")
    if forbid_hyphens and "-" in name:
        reasons.append("hyphen_forbidden")
    if ascii_only and not name.isascii():
        reasons.append("non_ascii")
    if allowed and tld not in allowed_norm:
        reasons.append("tld_not_allowed")

    return (len(reasons) == 0), reasons

def extract_first_json(text: str) -> Optional[dict]:
    """Return the first JSON object embedded in *text*, or None if there is none.

    Each ``{`` is tried in order as the start of a JSON value using
    ``json.JSONDecoder.raw_decode``. The first start that decodes to a dict
    is returned. Because the real JSON parser does the matching, the
    following no longer derail extraction:

    - braces inside string values (e.g. a rationale containing "}");
    - stray ``}`` characters before the object;
    - an unparseable ``{...}`` fragment ahead of the real object.

    Trailing text after the object, such as chat end-of-turn tokens, is
    ignored.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return None

def build_prompt(brief: Dict[str, Any]) -> str:
    """Render *brief* as a ChatML-style system/user prompt ending at the assistant turn.

    Requires ``brief['title']`` and ``brief['constraints']`` (KeyError otherwise).
    Optional fields fall back as follows:

    - ``language`` becomes "en";
    - ``tone`` becomes "";
    - ``keywords`` and ``allowed_tlds`` render as empty lists;
    - other missing constraint values render as ``None``.

    The exact text layout is kept stable because fine-tuned adapters may
    depend on it.
    """
    system_msg = (
        "You generate brand-safe domain suggestions and strictly refuse unsafe requests. "
        "Output strict JSON only."
    )

    user_lines = [
        "[BRIEF]",
        f"title: {brief['title']}",
        f"language: {brief.get('language', 'en')}",
        f"tone: {brief.get('tone', '')}",
        f"keywords: {', '.join(brief.get('keywords', []))}",
        "constraints:",
    ]
    limits = brief['constraints']
    user_lines += [
        f"  max_len: {limits.get('max_len')}",
        f"  allowed_tlds: {', '.join(limits.get('allowed_tlds', []))}",
        f"  forbid_digits: {limits.get('forbid_digits')}",
        f"  forbid_hyphens: {limits.get('forbid_hyphens')}",
        f"  ascii_only: {limits.get('ascii_only')}",
    ]
    user_msg = "\n".join(user_lines) + "\n"

    return (
        f"<|im_start|>system\n{system_msg}\n<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}\n<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

def load_model(base_model: str, adapter_dir: Optional[str] = None):
    """Load a 4-bit (NF4, double-quantized) causal LM and optionally attach a LoRA adapter.

    Args:
        base_model: Hugging Face model id or local path of the base model.
        adapter_dir: Directory of a saved PEFT adapter, or None/empty to use
            the base model as-is.

    Returns:
        ``(tokenizer, model)``. ``model`` is a ``PeftModel`` only when the
        adapter was actually loaded. If the directory does not exist or
        loading fails, a warning is printed and the plain base model is
        returned instead.
    """
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tok.pad_token_id is None:
        # Ensure a pad token exists by reusing EOS.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_cfg,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    if adapter_dir:
        if not os.path.isdir(adapter_dir):
            # Say so explicitly: otherwise a mistyped path would quietly run
            # the base model while the results are labelled as the adapter.
            print(f"Warning: adapter directory not found: {adapter_dir} — using base model")
        else:
            try:
                model = PeftModel.from_pretrained(model, adapter_dir, is_trainable=False)
                print(f"Loaded LoRA adapter from: {adapter_dir}")
            except Exception as e:  # best-effort: keep going with the base model
                print("Warning: could not load adapter:", e)

    model.eval()  # inference only; applied to the final (possibly wrapped) model
    return tok, model

def generate_json(tok, model, brief: Dict[str, Any]) -> Dict[str, Any]:
    """Generate suggestions for *brief* and return the first JSON object in the reply.

    Decoding uses the module-level MAX_NEW_TOKENS / DO_SAMPLE / TEMPERATURE /
    TOP_P settings and stops at the tokenizer's EOS token. Only the newly
    generated tokens are decoded, with special tokens kept.

    Returns:
        The parsed JSON object. If none can be extracted, returns
        ``{"error": "parse_error", "raw": <first 600 chars of the reply>}``.
    """
    encoded = tok(build_prompt(brief), return_tensors="pt").to(model.device)
    decode_cfg = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "eos_token_id": tok.eos_token_id,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **decode_cfg)

    # Drop the echoed prompt so only the completion is parsed.
    n_prompt_tokens = encoded["input_ids"].shape[1]
    completion = tok.decode(generated[0][n_prompt_tokens:], skip_special_tokens=False)

    parsed = extract_first_json(completion)
    if parsed is not None:
        return parsed
    return {"error": "parse_error", "raw": completion[:600]}

def normalize_output(obj: Dict[str, Any], brief: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Turn a parsed model reply into spec-checked suggestion rows.

    Args:
        obj: Parsed reply, expected to hold a ``"suggestions"`` list of dicts,
            each with ``"domain"`` and optionally ``"rationale"``.
        brief: Brief whose constraints the domains are checked against.
            Defaults to the module-level ``BRIEF``.

    Returns:
        ``[{"domain": ..., "rationale": ..., "spec_ok": bool, "spec_reasons": [...]}, ...]``.
        Suggestions that are not dicts or lack a string ``domain`` are skipped.
        If ``obj`` is not a dict or ``suggestions`` is not a list (e.g. a
        parse-error record or ``null``), the result is ``[]`` rather than a
        crash.
    """
    if brief is None:
        brief = BRIEF
    if not isinstance(obj, dict):
        return []
    suggestions = obj.get("suggestions")
    if not isinstance(suggestions, list):
        return []

    rows = []
    for s in suggestions:
        if not isinstance(s, dict):
            continue
        domain = s.get("domain")
        if not isinstance(domain, str):
            continue
        ok, reasons = spec_checks(brief, domain)
        rows.append({
            "domain": domain,
            "rationale": s.get("rationale", ""),
            "spec_ok": ok,
            "spec_reasons": reasons,
        })
    return rows

def show_table(rows: List[Dict[str, Any]], title: str):
    """Print *title* followed by a table of suggestion rows.

    Rows that pass the spec checks come first. Within each group, rows are
    ordered by the length of the domain's part before the first dot,
    shortest first. The index is reset, so it reads as the display rank.

    IPython's ``display`` builtin (Jupyter/Colab) is used when it is
    available. Otherwise a plain-text table is printed, so the function
    also works in a regular Python process.
    """
    if not rows:
        print(title, "— (no suggestions)")
        return
    df = pd.DataFrame(rows)
    name_len = df["domain"].apply(lambda d: len(d.split(".")[0]) if isinstance(d, str) else 0)
    df = (
        df.assign(name_len=name_len)
        .sort_values(by=["spec_ok", "name_len"], ascending=[False, True])
        .drop(columns=["name_len"])
        .reset_index(drop=True)
    )
    print(title)
    try:
        show = display  # injected as a builtin by IPython / Colab
    except NameError:
        show = None
    if show is not None:
        show(df)
    else:
        print(df.to_string(index=False))

# --------------------------
# Run: single model demo
# --------------------------
# Reseed right before each generation so sampled outputs are reproducible and
# both checkpoints are compared under the same random stream.
SEED = 0

tok, model = load_model(BASE_MODEL, ADAPTER)
torch.manual_seed(SEED)
out = generate_json(tok, model, BRIEF)
if out.get("error") == "parse_error":
    # Otherwise a malformed reply would only surface as "(no suggestions)".
    print("Could not parse JSON from model output. Raw text:\n" + out.get("raw", ""))
rows = normalize_output(out)
# load_model returns the plain base model when the adapter path is missing or
# fails to load, so label by what was actually loaded, not by the config value.
label = f"adapter: {ADAPTER}" if isinstance(model, PeftModel) else "base model"
show_table(rows, title=f"Suggestions — {label}")

# --------------------------
# Optional: compare improved adapter
# --------------------------
if IMPROVED_ADAPTER:
    tok2, model2 = load_model(BASE_MODEL, IMPROVED_ADAPTER)
    torch.manual_seed(SEED)
    out2 = generate_json(tok2, model2, BRIEF)
    if out2.get("error") == "parse_error":
        print("Could not parse JSON from model output. Raw text:\n" + out2.get("raw", ""))
    rows2 = normalize_output(out2)
    label2 = (
        f"improved adapter: {IMPROVED_ADAPTER}"
        if isinstance(model2, PeftModel)
        else "base model (improved adapter not loaded)"
    )
    show_table(rows2, title=f"Suggestions — {label2}")

    # Quick side-by-side of unique spec-ok domains; domain names are
    # case-insensitive, so compare them lower-cased.
    s1 = {r["domain"].lower() for r in rows if r["spec_ok"]}
    s2 = {r["domain"].lower() for r in rows2 if r["spec_ok"]}
    print("\nOverlap (spec-ok):", sorted(s1 & s2))
    print("Baseline-only:", sorted(s1 - s2))
    print("Improved-only:", sorted(s2 - s1))
