In [None]:
"""
코랩용 기존 파이프라인 vs. 파인튜닝된 파이프라인 비교 스크립트
./adapters_dpo에 존재하는 어댑터를 기준으로 GPT-Score / 비GPT-Score를 비교.
random_persona_campaign.csv의 더미 데이터를 기준으로 평가함.
비교 문서는 adapter_comparison_{timestamp}.md로 저장.
"""

In [1]:
import torch
# Sanity check: True means a CUDA GPU is visible to this Colab runtime.
torch.cuda.is_available()

True

In [2]:
!pip install datasets peft trl bitsandbytes accelerate
!pip install -U transformers
!pip show transformers

Name: transformers
Version: 4.57.3
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /usr/local/lib/python3.12/dist-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, sentence-transformers, trl


In [3]:
# Mount Google Drive: the adapter directory and the report output paths used
# below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import os
# Show where the notebook runs and what is there (the cloned
# AmoRe_crm_generator repo should be listed).
print(os.getcwd())
print(os.listdir())

/content
['.config', '.env', 'drive', '.ipynb_checkpoints', 'AmoRe_crm_generator', 'sample_data']


In [5]:
# One-time setup per fresh runtime (commented out because the repo is already
# present): clone the repo and switch to the working branch.
# !git clone https://github.com/jjjh02/AmoRe_crm_generator.git
# %cd AmoRe_crm_generator
# !git checkout jinhyeok
# !git branch
# BASE_DIR / SRC_DIR / DEFAULT_CSV are derived from the working directory in the
# evaluation cell, so it must be the finetuning/ folder.
os.chdir("/content/AmoRe_crm_generator/finetuning")
print(os.getcwd())

/content/AmoRe_crm_generator/finetuning


In [6]:
# Load environment variables from a .env file via python-dotenv.
# NOTE(review): presumably this is where OPENAI_API_KEY (read by _call_gpt)
# comes from — verify.
from dotenv import load_dotenv
load_dotenv()

True

In [None]:
#!/usr/bin/env python3
import argparse
import csv
import json
import os
import re
import sys
import urllib.error
import urllib.request
from collections import Counter
from contextlib import contextmanager
from datetime import datetime, timezone


# Every path below is resolved against the working directory selected by the
# os.chdir cell (expected: .../AmoRe_crm_generator/finetuning).
BASE_DIR = os.getcwd()
SRC_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "src"))
DEFAULT_CSV = os.path.join(BASE_DIR, "random_persona_campaign.csv")
DEFAULT_ADAPTER_DIR = "/content/drive/MyDrive/멋사/adapters_dpo_2"
# AARRR funnel stages; a CSV row's stage_index is a position in this list.
STAGE_ORDER = "Acquisition Activation Retention Revenue Referral".split()
# Legacy output location; the evaluation code reassigns out_path before writing.
out_path = "/content/drive/MyDrive/멋사/comparison_dpo/comparison_01"

print(BASE_DIR, SRC_DIR)
def _log(message):
    print(message)


def _import_pipeline_module():
    if SRC_DIR not in sys.path:
        sys.path.insert(0, SRC_DIR)
    try:
        import run_qwen_exaone_pipeline as pipeline_module
    except Exception as exc:
        raise ImportError(
            "Failed to import main from ../src/run_qwen_exaone_pipeline.py"
        ) from exc
    return pipeline_module


def _load_json(path):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def _parse_bool(value):
    if isinstance(value, bool):
        return value
    if value is None:
        return False
    if isinstance(value, (int, float)):
        return bool(value)
    text = str(value).strip().lower()
    return text in {"1", "true", "yes", "y", "t"}


def _load_rows(csv_path):
    with open(csv_path, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if not row:
                continue
            persona_raw = row.get("persona", "").strip()
            brand_raw = row.get("brand", "").strip()
            product_raw = row.get("product", "").strip()
            stage_raw = row.get("stage_index", "").strip()
            style_raw = row.get("style_index", "").strip()
            if not persona_raw or not brand_raw or not product_raw:
                continue
            if not stage_raw or not style_raw:
                continue
            try:
                persona = int(persona_raw)
                stage_index = int(stage_raw)
                style_index = int(style_raw)
            except ValueError:
                continue
            yield {
                "persona": persona,
                "brand": brand_raw,
                "product": product_raw,
                "stage_index": stage_index,
                "style_index": style_index,
                "is_event": _parse_bool(row.get("is_event", "")),
            }


def _get_stage_name(stage_index):
    if isinstance(stage_index, int) and 0 <= stage_index < len(STAGE_ORDER):
        return STAGE_ORDER[stage_index]
    return ""


def _get_crm_goal(crm_goals, stage_index, stage_name=None):
    if not isinstance(crm_goals, dict):
        return {}
    if stage_name and stage_name in crm_goals:
        return crm_goals.get(stage_name, {}) or {}
    stage_name = _get_stage_name(stage_index)
    if stage_name:
        return crm_goals.get(stage_name, {}) or {}
    return {}


def _get_brand_story(brand_stories, brand_name):
    if not isinstance(brand_stories, dict) or not brand_name:
        return {}
    if brand_name in brand_stories:
        return brand_stories.get(brand_name, {}) or {}
    for story in brand_stories.values():
        if str(story.get("name_en", "")).lower() == brand_name.lower():
            return story
    return {}


def _format_event(selected_event):
    if selected_event in (None, "", {}):
        return "none"
    if isinstance(selected_event, dict):
        for key in ("title", "name", "event_name", "event"):
            if selected_event.get(key):
                return str(selected_event.get(key))
        return json.dumps(selected_event, ensure_ascii=False)
    return str(selected_event)


def _format_price(price):
    if price in (None, ""):
        return ""
    if isinstance(price, (int, float)):
        return f"{int(price):,} KRW"
    text = str(price).strip()
    if not text:
        return ""
    if text.replace(",", "").isdigit():
        return f"{int(text.replace(',', '')):,} KRW"
    return text


def _format_persona(persona_profile):
    if not isinstance(persona_profile, dict):
        return str(persona_profile or "")
    name = persona_profile.get("name", "")
    extras = []
    value_focus = persona_profile.get("value_focus")
    skin_type = persona_profile.get("skin_type")
    traits = persona_profile.get("traits")
    shopping_style = persona_profile.get("shopping_style")
    if value_focus:
        extras.append(str(value_focus))
    if skin_type:
        extras.append(str(skin_type))
    if traits:
        if isinstance(traits, list):
            extras.append(", ".join([str(t) for t in traits if t]))
        else:
            extras.append(str(traits))
    if shopping_style:
        extras.append(str(shopping_style))
    extra_text = ", ".join([e for e in extras if e])
    if name and extra_text:
        return f"{name} ({extra_text})"
    return name or extra_text


def _build_context_block(out, max_style_templates=3):
    persona = _format_persona(out.get("persona_profile"))
    stage = out.get("stage_name") or out.get("stage_kr") or ""
    brand = out.get("brand") or ""
    product_basic = out.get("product_basic") if isinstance(out.get("product_basic"), dict) else {}
    product_name = product_basic.get("name") or out.get("product_query") or ""
    price = _format_price(product_basic.get("price"))
    objective = out.get("objective") or ""
    target_state = out.get("target_state") or ""
    style_templates = out.get("style_templates") or []
    if isinstance(style_templates, list):
        style_templates = style_templates[:max_style_templates]
    selected_event = _format_event(out.get("selected_event"))

    lines = ["[Context]"]
    if persona:
        lines.append(f"- Persona: {persona}")
    if stage:
        lines.append(f"- Stage: {stage}")
    if brand or product_name:
        lines.append(f"- Brand/Product: {brand} / {product_name}".strip())
    if price:
        lines.append(f"- Price: {price}")
    if objective:
        lines.append(f"- Objective: {objective}")
    if target_state:
        lines.append(f"- Target state: {target_state}")
    if style_templates:
        lines.append("- Style templates:")
        for item in style_templates:
            lines.append(f"  - {item}")
    lines.append(f"- Event: {selected_event}")
    return "\n".join(lines).strip()


def _extract_message(out):
    exaone = out.get("exaone", {}) if isinstance(out, dict) else {}
    return exaone.get("result_raw") or ""


def _tokenize(text):
    if not text:
        return []
    return [t for t in re.split(r"\s+", str(text)) if len(t) > 1]


def _split_tokens(text):
    if not text:
        return []
    cleaned = re.sub(r"[^\w\uac00-\ud7a3]+", " ", str(text), flags=re.UNICODE)
    return [t for t in cleaned.split() if len(t) > 1]


def _extract_keywords(texts, max_terms=30):
    counter = Counter()
    for text in texts:
        for token in _split_tokens(text):
            if token.isdigit():
                continue
            counter[token] += 1
    if not counter:
        return []
    return [item for item, _ in counter.most_common(max_terms)]


def _coverage_score(message, out):
    total = 0
    hits = 0
    if not message:
        return 0.0

    brand = out.get("brand")
    if brand:
        total += 1
        if brand in message:
            hits += 1

    product_basic = out.get("product_basic") if isinstance(out.get("product_basic"), dict) else {}
    product_name = product_basic.get("name") or out.get("product_query") or ""
    if product_name:
        total += 1
        if product_name in message:
            hits += 1

    selected_event = _format_event(out.get("selected_event"))
    if selected_event and selected_event != "none":
        total += 1
        if selected_event in message:
            hits += 1

    stage_terms = []
    for text in (out.get("stage_kr"), out.get("objective"), out.get("target_state")):
        stage_terms.extend(_tokenize(text))
    if stage_terms:
        total += 1
        if any(term in message for term in stage_terms):
            hits += 1

    return hits / total if total else 0.0


def _tone_match_score(message, brand_story):
    if not message or not isinstance(brand_story, dict):
        return 0.0
    tone_keywords = brand_story.get("tone_keywords") or []
    if not tone_keywords:
        return 0.0
    hits = sum(1 for kw in tone_keywords if kw and kw in message)
    return hits / len(tone_keywords)


def _style_match_score(message, style_templates, max_terms=30):
    if not message or not style_templates:
        return 0.0
    if not isinstance(style_templates, list):
        style_templates = [str(style_templates)]
    keywords = _extract_keywords(style_templates, max_terms=max_terms)
    if not keywords:
        return 0.0
    hits = sum(1 for kw in keywords if kw in message)
    return hits / len(keywords)


def _info_density(message, out):
    if not message:
        return 0.0
    persona = out.get("persona_profile") if isinstance(out.get("persona_profile"), dict) else {}
    product_basic = out.get("product_basic") if isinstance(out.get("product_basic"), dict) else {}
    context_texts = [
        out.get("brand"),
        product_basic.get("name"),
        out.get("product_query"),
        out.get("stage_kr"),
        out.get("objective"),
        out.get("target_state"),
        persona.get("value_focus"),
        persona.get("skin_type"),
    ]
    if isinstance(persona.get("traits"), list):
        context_texts.extend(persona.get("traits"))
    if persona.get("shopping_style"):
        context_texts.append(persona.get("shopping_style"))

    keywords = _extract_keywords([t for t in context_texts if t], max_terms=40)
    if not keywords:
        return 0.0
    message_tokens = _split_tokens(message)
    if not message_tokens:
        return 0.0
    hits = sum(1 for kw in keywords if kw in message)
    return hits / len(message_tokens)


def _repetition_stats(message):
    tokens = _split_tokens(message)
    if not tokens:
        return 0.0, 0.0
    unique_tokens = set(tokens)
    repeat_token_ratio = (len(tokens) - len(unique_tokens)) / len(tokens)

    if len(tokens) < 6:
        return repeat_token_ratio, 0.0
    n = 3
    ngrams = [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    counts = Counter(ngrams)
    total_ngrams = len(ngrams)
    repeated = sum(count - 1 for count in counts.values() if count > 1)
    repeat_ngram_ratio = repeated / total_ngrams if total_ngrams else 0.0
    return repeat_token_ratio, repeat_ngram_ratio


def _length_target(stage_name):
    if stage_name == "Acquisition":
        return 60, 200
    if stage_name == "Activation":
        return 60, 200
    if stage_name == "Retention":
        return 60, 180
    if stage_name == "Revenue":
        return 60, 180
    if stage_name == "Referral":
        return 60, 160
    return 50, 220


def _length_ok(message, stage_name):
    if not message:
        return False
    min_len, max_len = _length_target(stage_name)
    return min_len <= len(message) <= max_len


def _forbidden_violations(message, crm_goal):
    if not message or not isinstance(crm_goal, dict):
        return 0
    forbidden = crm_goal.get("forbidden_context") or []
    if not forbidden:
        return 0
    hits = 0
    for term in forbidden:
        if term and term in message:
            hits += 1
    return hits


def _cta_present(message):
    if not message:
        return False
    cta_markers = [
        "\uc9c0\uae08", "\ud655\uc778", "\uad6c\ub9e4", "\uc2e0\uccad", "\ucc38\uc5ec",
        "\ud074\ub9ad", "\ubc1b\uae30", "\ud61c\ud0dd", "\ud560\uc778", "\ucfe0\ud3f0",
        "\ud574\ubcf4\uc138\uc694", "\ud558\uc138\uc694", "\ub458\ub7ec\ubcf4\uae30",
        "\ubc14\ub85c", "\ucd94\ucc9c", "\ubb38\uc758"
    ]
    return any(marker in message for marker in cta_markers)


def _call_gpt(context_block, base_message, adapter_message):
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set.")

    candidate_block = (
        "[0]\n"
        f"{base_message}\n\n"
        "[1]\n"
        f"{adapter_message}"
    )

    system_prompt = (
        "You are evaluating CRM messages. Pick the message more likely to drive conversion.\n"
        "Compare with these criteria:\n"
        "1) Likelihood of action (click/repurchase)\n"
        "2) Fit to persona and stage goals\n"
        "3) Brand/product value delivery\n"
        "4) Use of style templates and event context when applicable\n"
        "5) Clarity without unnecessary decoration\n"
        "Return only the best candidate index as an integer (0 or 1)."
    )
    user_prompt = (
        "Context:\n"
        f"{context_block}\n\n"
        "Candidates:\n"
        f"{candidate_block}\n\n"
        "Return only the best candidate index."
    )

    payload = {
        "model": "gpt-5-nano",
        "input": [
            {
                "role": "system",
                "content": [{"type": "input_text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [{"type": "input_text", "text": user_prompt}],
            },
        ],
    }

    request = urllib.request.Request(
        "https://api.openai.com/v1/responses",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=30) as response:
            data = json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"OpenAI API error {exc.code}: {body}") from exc

    output_text = _extract_response_text(data)
    match = re.search(r"-?\d+", str(output_text))
    if not match:
        raise ValueError(f"Invalid evaluator response: {output_text}")
    choice = int(match.group(0))
    if choice not in (0, 1):
        raise ValueError(f"Evaluator index out of range: {choice}")
    return choice

def _extract_response_text(data):
    if isinstance(data, dict):
        output_text = data.get("output_text")
        if isinstance(output_text, str) and output_text.strip():
            return output_text.strip()

        output = data.get("output")
        if isinstance(output, list):
            parts = []
            for item in output:
                if not isinstance(item, dict):
                    continue
                content = item.get("content", [])
                if isinstance(content, list):
                    for block in content:
                        if isinstance(block, dict) and isinstance(block.get("text"), str):
                            parts.append(block["text"])
                        elif isinstance(block, str):
                            parts.append(block)
                elif isinstance(content, str):
                    parts.append(content)
            if parts:
                return "".join(parts).strip()

    return ""

@contextmanager
def _patch_exaone(pipeline_module, adapter_path=None):
    import tone_correction

    class PatchedExaoneToneCorrector(tone_correction.ExaoneToneCorrector):
        _cache = {}

        def __init__(self, model_name="LGAI-EXAONE/EXAONE-4.0-1.2B"):
            key = (model_name, adapter_path)
            cached = self._cache.get(key)
            if cached:
                self.device = cached["device"]
                self.model_name = model_name
                self.tokenizer = cached["tokenizer"]
                self.model = cached["model"]
                return
            super().__init__(model_name=model_name)
            if adapter_path:
                try:
                    from peft import PeftModel
                except ImportError as exc:
                    raise RuntimeError("peft is required to load adapters.") from exc
                self.model = PeftModel.from_pretrained(self.model, adapter_path)
                try:
                    self.model.eval()
                except Exception:
                    pass
            self._cache[key] = {
                "device": self.device,
                "tokenizer": self.tokenizer,
                "model": self.model,
            }

    original = pipeline_module.ExaoneToneCorrector
    pipeline_module.ExaoneToneCorrector = PatchedExaoneToneCorrector
    try:
        yield
    finally:
        pipeline_module.ExaoneToneCorrector = original


def _run_pipeline_main(pipeline_main, row):
    argv = [
        "run_qwen_exaone_pipeline.py",
        "--persona",
        str(row["persona"]),
        "--brand",
        row["brand"],
        "--product",
        row["product"],
        "--stage_index",
        str(row["stage_index"]),
        "--style_index",
        str(row["style_index"]),
        "--is_event",
        "1" if row.get("is_event", False) else "0",
    ]
    old_argv = sys.argv
    try:
        sys.argv = argv
        return pipeline_main()
    finally:
        sys.argv = old_argv


def _write_report(out_path, summary, rows, max_examples):
    lines = []
    lines.append("# Adapter Comparison Report")
    lines.append("")
    lines.append(f"- CSV: {summary['csv']}")
    lines.append(f"- Adapter: {summary['adapter']}")
    lines.append(f"- Samples: {summary['samples']}")
    lines.append("")
    lines.append("## Summary")
    lines.append("")
    for item in summary["metrics"]:
        lines.append(f"- {item}")
    lines.append("")
    lines.append("## Per-sample Results")
    lines.append("")
    lines.append(
        "| idx | persona | brand/product | stage | event | gpt winner | base len | adapter len | base cov | adapter cov | base tone | adapter tone | base style | adapter style | base dens | adapter dens | base rep tok | adapter rep tok | base rep 3g | adapter rep 3g | base len ok | adapter len ok | base forb | adapter forb | base cta | adapter cta |"
    )
    lines.append(
        "| --- | --- | --- | --- | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- | --- | ---: | ---: | --- | --- |"
    )
    for item in rows:
        lines.append(
            "| {idx} | {persona} | {brand_product} | {stage} | {event} | {winner} | {base_len} | {adapter_len} | {base_cov:.2f} | {adapter_cov:.2f} | {base_tone:.2f} | {adapter_tone:.2f} | {base_style:.2f} | {adapter_style:.2f} | {base_density:.2f} | {adapter_density:.2f} | {base_rep_token:.2f} | {adapter_rep_token:.2f} | {base_rep_ngram:.2f} | {adapter_rep_ngram:.2f} | {base_len_ok} | {adapter_len_ok} | {base_forbidden} | {adapter_forbidden} | {base_cta} | {adapter_cta} |".format(
                idx=item["idx"],
                persona=item["persona"],
                brand_product=item["brand_product"],
                stage=item["stage"],
                event=item["event"],
                winner=item["winner"],
                base_len=item["base_len"],
                adapter_len=item["adapter_len"],
                base_cov=item["base_cov"],
                adapter_cov=item["adapter_cov"],
                base_tone=item["base_tone"],
                adapter_tone=item["adapter_tone"],
                base_style=item["base_style"],
                adapter_style=item["adapter_style"],
                base_density=item["base_density"],
                adapter_density=item["adapter_density"],
                base_rep_token=item["base_rep_token"],
                adapter_rep_token=item["adapter_rep_token"],
                base_rep_ngram=item["base_rep_ngram"],
                adapter_rep_ngram=item["adapter_rep_ngram"],
                base_len_ok="yes" if item["base_len_ok"] else "no",
                adapter_len_ok="yes" if item["adapter_len_ok"] else "no",
                base_forbidden=item["base_forbidden"],
                adapter_forbidden=item["adapter_forbidden"],
                base_cta="yes" if item["base_cta"] else "no",
                adapter_cta="yes" if item["adapter_cta"] else "no",
            )
        )
    lines.append("")

    if max_examples > 0:
        lines.append("## Examples")
        lines.append("")
        for item in rows[:max_examples]:
            lines.append(f"### Sample {item['idx']}")
            lines.append("")
            lines.append("Context:")
            lines.append("")
            lines.append("```")
            lines.append(item["context"])
            lines.append("```")
            lines.append("")
            lines.append("Base message:")
            lines.append("")
            lines.append("```")
            lines.append(item["base_message"])
            lines.append("```")
            lines.append("")
            lines.append("Adapter message:")
            lines.append("")
            lines.append("```")
            lines.append(item["adapter_message"])
            lines.append("```")
            lines.append("")
            lines.append(f"GPT winner: {item['winner']}")
            lines.append("")

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        print([line for line in lines])
        f.write("\n".join(lines))


# Directory for timestamped reports when --out_path does not name a .md file.
DEFAULT_OUT_DIR = "/content/drive/MyDrive/멋사/comparison_dpo"

parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", default=DEFAULT_CSV)
parser.add_argument("--adapter_path", default=DEFAULT_ADAPTER_DIR)
# Either a directory, in which adapter_comparison_<UTC timestamp>.md is created,
# or an explicit .md file path. The old default lacked the leading "/", and the
# value was never actually used.
parser.add_argument("--out_path", default=DEFAULT_OUT_DIR)
parser.add_argument("--max_rows", type=int, default=10)
parser.add_argument("--max_examples", type=int, default=3)
parser.add_argument("--skip_llm_eval", action="store_true")
parser.add_argument("--max_style_templates", type=int, default=3)
# parse_known_args ignores unrelated argv entries (e.g. the notebook kernel's flags).
args, _ = parser.parse_known_args()

if not os.path.exists(args.csv_path):
    raise FileNotFoundError(f"CSV not found: {args.csv_path}")
if not os.path.exists(args.adapter_path):
    raise FileNotFoundError(f"Adapter not found: {args.adapter_path}")
# Fail before the expensive generation loop rather than at the first GPT call.
if not args.skip_llm_eval and not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set (set it or pass --skip_llm_eval).")

pipeline_module = _import_pipeline_module()
pipeline_main = pipeline_module.main
brand_stories = _load_json(os.path.join(BASE_DIR, "data", "brand_stories.json"))
crm_goals = _load_json(os.path.join(BASE_DIR, "data", "crm_goals.json"))

rows = []
for idx, row in enumerate(_load_rows(args.csv_path)):
    if args.max_rows is not None and idx >= args.max_rows:
        break
    rows.append(row)

if not rows:
    raise RuntimeError("No rows to evaluate.")

results = []
wins = {"base": 0, "adapter": 0}
llm_errors = 0

for idx, row in enumerate(rows, start=1):
    _log(
        "[Row {idx}] persona={persona} brand={brand} product={product} "
        "stage_index={stage_index} style_index={style_index} is_event={is_event}".format(
            idx=idx,
            persona=row["persona"],
            brand=row["brand"],
            product=row["product"],
            stage_index=row["stage_index"],
            style_index=row["style_index"],
            is_event=row.get("is_event", False),
        )
    )

    _log("  Running base pipeline...")
    base_out = _run_pipeline_main(pipeline_main, row)

    _log("  Running adapter pipeline...")
    with _patch_exaone(pipeline_module, adapter_path=args.adapter_path):
        adapter_out = _run_pipeline_main(pipeline_main, row)

    # base_out supplies the shared context for both messages, so it must be a dict.
    if not isinstance(base_out, dict):
        raise RuntimeError(f"Base pipeline did not return a result dict for row {idx}.")

    base_message = _extract_message(base_out)
    adapter_message = _extract_message(adapter_out)

    context_block = _build_context_block(base_out, args.max_style_templates)
    stage_name = base_out.get("stage_name") or _get_stage_name(row["stage_index"])
    crm_goal = _get_crm_goal(crm_goals, row["stage_index"], stage_name)
    brand_story = _get_brand_story(brand_stories, base_out.get("brand"))

    winner = "n/a"
    if not args.skip_llm_eval:
        try:
            choice = _call_gpt(context_block, base_message, adapter_message)
        except (RuntimeError, ValueError, OSError) as exc:
            # Keep this row's (expensive) generations instead of aborting the run.
            llm_errors += 1
            winner = "error"
            _log(f"  GPT evaluation failed: {exc}")
        else:
            winner = "base" if choice == 0 else "adapter"
            wins[winner] += 1

    # Both messages are scored against the same (base) context.
    base_cov = _coverage_score(base_message, base_out)
    adapter_cov = _coverage_score(adapter_message, base_out)
    base_tone = _tone_match_score(base_message, brand_story)
    adapter_tone = _tone_match_score(adapter_message, brand_story)
    base_style = _style_match_score(base_message, base_out.get("style_templates"))
    adapter_style = _style_match_score(adapter_message, base_out.get("style_templates"))
    base_density = _info_density(base_message, base_out)
    adapter_density = _info_density(adapter_message, base_out)
    base_rep_token, base_rep_ngram = _repetition_stats(base_message)
    adapter_rep_token, adapter_rep_ngram = _repetition_stats(adapter_message)
    base_len_ok = _length_ok(base_message, stage_name)
    adapter_len_ok = _length_ok(adapter_message, stage_name)
    base_forbidden = _forbidden_violations(base_message, crm_goal)
    adapter_forbidden = _forbidden_violations(adapter_message, crm_goal)
    base_len = len(base_message)
    adapter_len = len(adapter_message)

    results.append(
        {
            "idx": idx,
            "persona": row["persona"],
            "brand_product": f"{row['brand']} / {row['product']}",
            "stage": base_out.get("stage_name") or base_out.get("stage_kr") or "",
            "event": _format_event(base_out.get("selected_event")),
            "winner": winner,
            "base_len": base_len,
            "adapter_len": adapter_len,
            "base_cov": base_cov,
            "adapter_cov": adapter_cov,
            "base_tone": base_tone,
            "adapter_tone": adapter_tone,
            "base_style": base_style,
            "adapter_style": adapter_style,
            "base_density": base_density,
            "adapter_density": adapter_density,
            "base_rep_token": base_rep_token,
            "adapter_rep_token": adapter_rep_token,
            "base_rep_ngram": base_rep_ngram,
            "adapter_rep_ngram": adapter_rep_ngram,
            "base_len_ok": base_len_ok,
            "adapter_len_ok": adapter_len_ok,
            "base_forbidden": base_forbidden,
            "adapter_forbidden": adapter_forbidden,
            "base_cta": _cta_present(base_message),
            "adapter_cta": _cta_present(adapter_message),
            "context": context_block,
            "base_message": base_message,
            "adapter_message": adapter_message,
        }
    )


def _avg(key):
    """Mean of a numeric per-sample field over `results` (non-empty here)."""
    return sum(r[key] for r in results) / len(results)


def _share(key, predicate=bool):
    """Fraction of samples whose `key` value satisfies `predicate`."""
    return sum(1 for r in results if predicate(r[key])) / len(results)


avg_base_cov = _avg("base_cov")
avg_adapter_cov = _avg("adapter_cov")
avg_base_tone = _avg("base_tone")
avg_adapter_tone = _avg("adapter_tone")
avg_base_style = _avg("base_style")
avg_adapter_style = _avg("adapter_style")
avg_base_density = _avg("base_density")
avg_adapter_density = _avg("adapter_density")
avg_base_rep_token = _avg("base_rep_token")
avg_adapter_rep_token = _avg("adapter_rep_token")
avg_base_rep_ngram = _avg("base_rep_ngram")
avg_adapter_rep_ngram = _avg("adapter_rep_ngram")
base_len_ok_rate = _share("base_len_ok")
adapter_len_ok_rate = _share("adapter_len_ok")
base_forbidden_rate = _share("base_forbidden", lambda n: n > 0)
adapter_forbidden_rate = _share("adapter_forbidden", lambda n: n > 0)
base_cta_rate = _share("base_cta")
adapter_cta_rate = _share("adapter_cta")
avg_base_len = _avg("base_len")
avg_adapter_len = _avg("adapter_len")

timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
# Previously `os.path.join(...) or args.out_path`, which never fell back
# because the join result is always non-empty.
if args.out_path.lower().endswith(".md"):
    out_path = args.out_path
else:
    out_path = os.path.join(args.out_path, f"adapter_comparison_{timestamp}.md")

metrics = [
    f"GPT wins: adapter {wins['adapter']} / base {wins['base']} (skip_llm_eval={args.skip_llm_eval})",
    f"Avg coverage: adapter {avg_adapter_cov:.2f}, base {avg_base_cov:.2f}",
    f"Avg tone match: adapter {avg_adapter_tone:.2f}, base {avg_base_tone:.2f}",
    f"Avg style match: adapter {avg_adapter_style:.2f}, base {avg_base_style:.2f}",
    f"Avg info density: adapter {avg_adapter_density:.2f}, base {avg_base_density:.2f}",
    f"Repeat token ratio: adapter {avg_adapter_rep_token:.2f}, base {avg_base_rep_token:.2f}",
    f"Repeat 3-gram ratio: adapter {avg_adapter_rep_ngram:.2f}, base {avg_base_rep_ngram:.2f}",
    f"Length ok rate: adapter {adapter_len_ok_rate:.2f}, base {base_len_ok_rate:.2f}",
    f"Forbidden violation rate: adapter {adapter_forbidden_rate:.2f}, base {base_forbidden_rate:.2f}",
    f"CTA rate: adapter {adapter_cta_rate:.2f}, base {base_cta_rate:.2f}",
    f"Avg length: adapter {avg_adapter_len:.1f}, base {avg_base_len:.1f}",
]
if llm_errors:
    metrics.insert(1, f"GPT evaluation errors: {llm_errors}")

summary = {
    "csv": args.csv_path,
    "adapter": args.adapter_path,
    "samples": len(results),
    "metrics": metrics,
}

_write_report(out_path, summary, results, args.max_examples)
_log(f"Saved report: {out_path}")

In [14]:
# Override the report path set by the evaluation cell with a fixed file name,
# so the next cell can re-save the already computed results there.
out_path = "/content/drive/MyDrive/멋사/comparison_dpo/comparison_01.md"

'content/drive/MyDrive/멋사/comparison_dpo'

In [18]:
# Re-save the report for the existing `summary`/`results` to `out_path`
# without re-running the pipelines or the GPT evaluation.
_write_report(out_path, summary, results, args.max_examples)
_log(f"Saved report: {out_path}")

['# Adapter Comparison Report', '', '- CSV: /content/AmoRe_crm_generator/finetuning/random_persona_campaign.csv', '- Adapter: /content/drive/MyDrive/멋사/adapters_dpo_2', '- Samples: 10', '', '## Summary', '', '- GPT wins: adapter 5 / base 5 (skip_llm_eval=False)', '- Avg coverage: adapter 0.33, base 0.24', '- Avg tone match: adapter 0.00, base 0.00', '- Avg style match: adapter 0.05, base 0.07', '- Avg info density: adapter 0.15, base 0.19', '- Repeat token ratio: adapter 0.06, base 0.04', '- Repeat 3-gram ratio: adapter 0.00, base 0.00', '- Length ok rate: adapter 0.00, base 0.10', '- Forbidden violation rate: adapter 0.00, base 0.00', '- CTA rate: adapter 1.00, base 0.80', '- Avg length: adapter 349.8, base 297.5', '', '## Per-sample Results', '', '| idx | persona | brand/product | stage | event | gpt winner | base len | adapter len | base cov | adapter cov | base tone | adapter tone | base style | adapter style | base dens | adapter dens | base rep tok | adapter rep tok | base rep 3g

In [16]:
# Display the aggregate metrics of the last evaluation run.
summary

{'csv': '/content/AmoRe_crm_generator/finetuning/random_persona_campaign.csv',
 'adapter': '/content/drive/MyDrive/멋사/adapters_dpo_2',
 'samples': 10,
 'metrics': ['GPT wins: adapter 5 / base 5 (skip_llm_eval=False)',
  'Avg coverage: adapter 0.33, base 0.24',
  'Avg tone match: adapter 0.00, base 0.00',
  'Avg style match: adapter 0.05, base 0.07',
  'Avg info density: adapter 0.15, base 0.19',
  'Repeat token ratio: adapter 0.06, base 0.04',
  'Repeat 3-gram ratio: adapter 0.00, base 0.00',
  'Length ok rate: adapter 0.00, base 0.10',
  'Forbidden violation rate: adapter 0.00, base 0.00',
  'CTA rate: adapter 1.00, base 0.80',
  'Avg length: adapter 349.8, base 297.5']}