# Soft-Prompt Experiment Notebook

This notebook orchestrates hyperparameter experiments for **Brand Voice Rewriter**.

**What it does:**  
- Runs `scripts/train_gemma_softprompt.py` multiple times with different hyperparameters.  
- Loads the latest `run_*` directory for each style from `artifacts/softprompt/<style>/` and aggregates metrics computed from its outputs.  
- Picks the best configurations by validation proxies (time/seconds, etc.) and your custom metrics.  
- Optionally renders visualization figures using `scripts/visualize_softprompt.py`.

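> **Run this notebook from a directory one level below the project root** (e.g. `notebooks/`); the setup cell switches to the parent directory, which must contain `scripts/` and `config/`.  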
> Requires a GPU machine for training; CPU is fine for data aggregation/plots.


In [None]:
import os, json, time, subprocess, shlex
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import os
import yaml

import os
from pathlib import Path
import math

import pandas as pd
import numpy as np
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt

# Resolve the project root. The notebook normally lives one directory below the
# root (e.g. notebooks/), but after the chdir below a re-run of this cell would
# otherwise climb one more level every time; accept the current directory when
# it already contains config/app.yaml.
_NOTEBOOK_CWD = Path.cwd()
PROJECT_ROOT = (
    _NOTEBOOK_CWD
    if (_NOTEBOOK_CWD / "config" / "app.yaml").is_file()
    else _NOTEBOOK_CWD.parent
)
os.chdir(PROJECT_ROOT)

CONFIG_PATH = PROJECT_ROOT / "config" / "app.yaml"
RUNS_ROOT   = PROJECT_ROOT / "artifacts" / "softprompt"   # <style>/run_* adapters
FIG_ROOT = PROJECT_ROOT / "artifacts" / "reports" / "figures"
FIG_ROOT.mkdir(parents=True, exist_ok=True)

TEST_CSV    = PROJECT_ROOT / "data" / "datasets" / "neutral_texts.csv"
SOFTP_ROOT  = None  # placeholder; not referenced in the cells shown here

# Use a context manager so the config file handle is closed deterministically.
with open(CONFIG_PATH, "r", encoding="utf-8") as config_file:
    cfg = yaml.safe_load(config_file)
BASE_MODEL_ID = cfg["model_id"]

STYLES = ["fintech", "compliance", "motivation_guru", "ai_newsletter"]

print("Working dir:", PROJECT_ROOT)
print("Config:", CONFIG_PATH)
print("Runs root:", RUNS_ROOT)


## Choose style and search space

In [None]:
# Effective batch = PER_DEVICE_BSZ * GRAD_ACCUM. EFF_BATCH is the target and
# GRAD_ACCUM is derived from it (floor division, never below 1).
# NOTE: run_trial does not currently forward grad_accum to the trainer, so this
# value is only recorded in the results table.
PER_DEVICE_BSZ = 4           # set 2 on 16GB T4, 4 on 24GB L4/A10
EFF_BATCH      = 4
GRAD_ACCUM     = max(1, EFF_BATCH // PER_DEVICE_BSZ)

# === Search grid (coarse) ===
V_CHOICES   = [8]       # number of virtual (soft-prompt) tokens
SEQ_CHOICES = [64]      # max sequence length
LR_CHOICES  = [1e-3]    # learning rate
EPOCHS      = [5]       # training epochs

# Initialization prompt variants (short style hints), keyed by the name the
# trainer receives via BVW_INIT_TEXT_VARIANT.
INIT_PROMPTS = {
    "TEXT_neutral": "[rewrite]",
}
INIT_KEYS = list(INIT_PROMPTS.keys())

print("grad_accum:", GRAD_ACCUM, "(per_device_bsz:", PER_DEVICE_BSZ, ")")


## Run a single training trial

In [None]:
def run_trial(style:str, vtok:int, seq_len:int, lr:float, epochs:int, bsz:int, grad_accum:int, init_key:str):
    """Launch one training run of ``scripts/train_gemma_softprompt.py``.

    Args:
        style: Style name, forwarded as ``--style``.
        vtok: Number of virtual (soft-prompt) tokens, ``--virtual-tokens``.
        seq_len: Maximum sequence length, ``--max-seq-len``.
        lr: Learning rate, ``--lr``.
        epochs: Number of epochs, ``--epochs``.
        bsz: Per-device batch size, ``--bsz``.
        grad_accum: Not forwarded to the trainer yet (see the comment below);
            accepted so callers can pass and record it.
        init_key: Exported as ``BVW_INIT_TEXT_VARIANT`` so a trainer that
            supports it can pick its init text.

    Returns:
        ``(ok, stdout, stderr, duration_sec)``, where ``ok`` is True when the
        process exited with return code 0.
    """
    import sys

    env = os.environ.copy()
    # Prepend the project root so the trainer can import project modules,
    # without discarding a PYTHONPATH the user already set.
    root = str(PROJECT_ROOT)
    prior = env.get("PYTHONPATH")
    env["PYTHONPATH"] = os.pathsep.join([root, prior]) if prior else root
    env["BVW_INIT_TEXT_VARIANT"] = init_key  # your trainer can read this to choose init

    # Pass an argv list (no shell-style parsing of values). Run it with the
    # kernel's own interpreter, so the trainer sees the same installed packages
    # as this notebook rather than whatever "python" resolves to on PATH.
    cmd = [
        sys.executable, "scripts/train_gemma_softprompt.py",
        "--style", str(style),
        "--virtual-tokens", str(vtok),
        "--max-seq-len", str(seq_len),
        "--lr", str(lr),
        "--epochs", str(epochs),
        "--bsz", str(bsz),
    ]
    # If your trainer supports grad_accum and checkpointing flags, add them here:
    # cmd += ["--grad-accum", str(grad_accum), "--ckpt"]

    print("\n[RUN]", shlex.join(cmd))
    start = time.perf_counter()  # monotonic clock for durations
    proc = subprocess.run(cmd, cwd=PROJECT_ROOT, capture_output=True, text=True, env=env)
    dur = time.perf_counter() - start
    ok = proc.returncode == 0
    print(proc.stdout)
    if not ok:
        print("[ERR]", proc.stderr)
    return ok, proc.stdout, proc.stderr, dur


## Launch coarse grid search

In [None]:
import itertools

results = []
# Every combination of the search grid, including all EPOCHS values (previously
# only EPOCHS[0] was used). Styles still vary fastest, matching the old nesting.
for v, seql, lrate, init_key, n_epochs, style in itertools.product(
    V_CHOICES, SEQ_CHOICES, LR_CHOICES, INIT_KEYS, EPOCHS, STYLES
):
    ok, out, err, dur = run_trial(
        style=style, vtok=v, seq_len=seql, lr=lrate, epochs=n_epochs,
        bsz=PER_DEVICE_BSZ, grad_accum=GRAD_ACCUM, init_key=init_key,
    )
    results.append({
        "style": style, "vtok": v, "seq_len": seql, "lr": lrate, "epochs": n_epochs,
        # Record the init variant; without it, trials that differ only by
        # init_key are indistinguishable in the results table.
        "init_key": init_key,
        "per_device_bsz": PER_DEVICE_BSZ, "grad_accum": GRAD_ACCUM, "duration_sec": round(dur, 2),
        "ok": ok,
    })

df = pd.DataFrame(results)

## Collect outputs and metrics from the latest runs in `artifacts/softprompt`

In [None]:
# Load the neutral test set; the generation cell below reads its `id` and
# `text` columns.
df_base = pd.read_csv(filepath_or_buffer=TEST_CSV)
df_base.head(n=5)


In [None]:
def get_latest_run_dir(style: str, runs_root: "Path | None" = None) -> Path:
    """Return the most recent ``run_*`` directory for ``style``.

    Args:
        style: Style name; runs are looked up under ``<runs_root>/<style>/``.
        runs_root: Root with one sub-directory per style. Defaults to the
            notebook-level ``RUNS_ROOT``.

    Returns:
        The last ``run_*`` sub-directory in natural sort order. Digit runs
        compare numerically, so ``run_10`` ranks after ``run_9``; plain string
        sorting would put it first. Zero-padded or timestamp names keep their
        previous order.

    Raises:
        RuntimeError: If the style directory does not exist or contains no
            ``run_*`` sub-directories.
    """
    import re

    def _natural_key(name: str):
        # re.split with a capturing group puts digit runs at the odd indices,
        # so equal positions always hold the same kind of element.
        return [
            (1, int(tok), tok) if idx % 2 else (0, 0, tok)
            for idx, tok in enumerate(re.split(r"(\d+)", name))
        ]

    root = RUNS_ROOT if runs_root is None else Path(runs_root)
    style_root = root / style
    if not style_root.exists():
        raise RuntimeError(f"No softprompt dir for style '{style}' at {style_root}")
    runs = sorted(
        (d for d in style_root.iterdir() if d.is_dir() and d.name.startswith("run_")),
        key=lambda d: _natural_key(d.name),
    )
    if not runs:
        raise RuntimeError(f"No run_* dirs for style '{style}' at {style_root}")
    return runs[-1]


def load_base_model_and_tokenizer(device: str = "cpu"):
    """Load ``BASE_MODEL_ID``'s tokenizer and causal-LM weights for inference.

    Weights are float16 when ``device == "cuda"`` and float32 otherwise; the
    model is moved to ``device`` and put in eval mode.

    Returns:
        ``(tokenizer, model)``.
    """
    # Pick the weight precision up front: half precision only on CUDA.
    torch_dtype = torch.float32
    if device == "cuda":
        torch_dtype = torch.float16

    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # repo; keep BASE_MODEL_ID pointed at a trusted source.
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID, trust_remote_code=True, use_fast=True
    )
    # NOTE(review): `dtype=` is the newer transformers spelling; older releases
    # expect `torch_dtype=` — confirm against the installed version.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        dtype=torch_dtype,
    )
    base = base.to(device)
    base.eval()
    return tokenizer, base


def load_style_model(style: str, base_model, device: str = "cpu"):
    """Attach the newest ``run_*`` adapter for ``style`` to ``base_model``.

    Prints which run directory was chosen, moves the wrapped model to
    ``device``, switches it to eval mode and returns it.

    NOTE(review): the same ``base_model`` instance is passed in for every
    style; confirm the adapter type does not modify it in place.
    """
    latest = get_latest_run_dir(style)
    print(f"[LOAD] style={style} run={latest.name}")

    wrapped = PeftModel.from_pretrained(base_model, latest)
    wrapped.to(device)
    wrapped.eval()
    return wrapped


In [None]:
def rewrite_with_style(
    tok,
    styled_model,
    device: str,
    text: str,
    style: str,
    max_new_tokens: int = 160,
    temperature: float = 0.0,
    top_p: float = 0.9
) -> str:
    """Rewrite ``text`` into ``style``'s brand voice using a soft-prompted model.

    The prompt asks for a ``[rewrite] ... [/rewrite]`` block and pre-fills the
    opening tag. Only the generated continuation is decoded, and the text
    between the tags is returned.

    Args:
        tok: Tokenizer; called on the prompt with ``return_tensors="pt"``, and
            must provide ``decode`` and ``eos_token_id``.
        styled_model: Object exposing ``generate`` (e.g. the model returned by
            ``load_style_model``).
        device: Device the input tensors are moved to.
        text: Source text; surrounding whitespace is stripped.
        style: Style name inserted into the instruction.
        max_new_tokens: Upper bound on generated tokens.
        temperature: ``0`` (default) means greedy decoding. A positive value
            enables sampling at this temperature; previously it was silently
            ignored.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        The extracted rewrite, stripped of surrounding whitespace.
    """
    system_prompt = (
        "You are a brand voice rewriter. "
        f"Rewrite the text into the '{style}' brand voice. "
        "Preserve meaning and approximate length. "
        "Your output MUST follow this EXACT format:\n"
        "[rewrite]\n"
        "<rewritten text here>\n"
        "[/rewrite]\n"
        "No explanations. No alternatives. No extra words. "
        "DO NOT output anything outside the [rewrite] block."
    )

    prompt = (
        f"{system_prompt}\n\n"
        f"Original text:\n{text.strip()}\n\n"
        "Rewritten text:\n"
        "[rewrite]\n"
    )

    inputs = tok(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    pad_id = getattr(tok, "pad_token_id", None)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": tok.eos_token_id,
        # Explicit pad id (falling back to EOS) instead of generate()'s implicit fallback.
        "pad_token_id": pad_id if pad_id is not None else tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding ignores temperature/top_p, so don't pass them at all.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out = styled_model.generate(**inputs, **gen_kwargs)

    # Decode only the continuation, not the echoed prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    # Keep the text between the tags; either tag may be missing.
    if "[rewrite]" in decoded:
        decoded = decoded.split("[rewrite]", 1)[-1]
    if "[/rewrite]" in decoded:
        decoded = decoded.split("[/rewrite]", 1)[0]

    return decoded.strip()


In [None]:
def rewrite_neutral(tok, base_model, device: str, text: str,
                    max_new_tokens: int = 160, temperature: float = 0.0, top_p: float = 0.9) -> str:
    """Rewrite ``text`` in a neutral tone with the un-adapted base model.

    This is the baseline that styled rewrites are compared against. The
    prompt scaffold matches ``rewrite_with_style``, including the pre-filled
    ``[rewrite]`` tag, so the two outputs differ only in the requested voice.

    Args:
        tok: Tokenizer; called on the prompt with ``return_tensors="pt"``, and
            must provide ``decode`` and ``eos_token_id``.
        base_model: Object exposing ``generate``.
        device: Device the input tensors are moved to.
        text: Source text; surrounding whitespace is stripped.
        max_new_tokens: Upper bound on generated tokens.
        temperature: ``0`` (default) means greedy decoding; a positive value
            enables sampling at this temperature.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        The extracted rewrite, stripped of surrounding whitespace.
    """
    system_prompt = (
        "You are a brand voice rewriter. "
        "Rewrite the text in neutral tone. "
        "Preserve meaning and approximate length. "
        "Your output MUST follow this EXACT format:\n"
        "[rewrite]\n"
        "<rewritten text here>\n"
        "[/rewrite]\n"
        "No explanations. No alternatives. No extra words. "
        "DO NOT output anything outside the [rewrite] block."
    )
    prompt = (
        f"{system_prompt}\n\n"
        "Original text:\n"
        f"{text.strip()}\n\n"
        "Rewritten text:\n"
        # Pre-fill the opening tag exactly as rewrite_with_style does.
        "[rewrite]\n"
    )

    inputs = tok(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    pad_id = getattr(tok, "pad_token_id", None)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": tok.eos_token_id,
        # Explicit pad id (falling back to EOS) instead of generate()'s implicit fallback.
        "pad_token_id": pad_id if pad_id is not None else tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding ignores temperature/top_p, so don't pass them at all.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out = base_model.generate(**inputs, **gen_kwargs)

    # Decode only the continuation, not the echoed prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    # Keep the text between the tags; either tag may be missing.
    if "[rewrite]" in decoded:
        decoded = decoded.split("[rewrite]", 1)[-1]
    if "[/rewrite]" in decoded:
        decoded = decoded.split("[/rewrite]", 1)[0]

    return decoded.strip()


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
tok, base_model = load_base_model_and_tokenizer(device=device)

# Neutral baseline: one rewrite per test row from the un-adapted base model.
df_base["neutral_out"] = [
    rewrite_neutral(tok, base_model, device, record["text"])
    for _, record in df_base.iterrows()
]

# Styled rewrites: load each style's newest adapter once, then rewrite every row.
rows = []
for style in STYLES:
    styled_model = load_style_model(style, base_model, device=device)
    for _, record in df_base.iterrows():
        source_text = record["text"]
        rows.append(
            dict(
                id=record["id"],
                style=style,
                src=source_text,
                neutral_out=record["neutral_out"],
                styled_out=rewrite_with_style(tok, styled_model, device, source_text, style),
            )
        )

df_styles = pd.DataFrame(rows)

df_styles.to_csv('test_results_brak.csv')
df_base.to_csv('test_results_neutral_brak.csv')


In [None]:
# %pip install -U sentence-transformers

# Sentence-embedding model used by the angle and centroid cells below.
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
emb_model = SentenceTransformer(EMB_MODEL_NAME)

def cosine_angle(v1: np.ndarray, v2: np.ndarray) -> tuple[float, float]:
    """Return ``(angle_degrees, cosine_similarity)`` between two vectors.

    The cosine is clipped to [-1, 1] before ``acos`` to absorb rounding error.
    The product of the norms is used directly, with no epsilon added to each
    norm; the old epsilon skewed identical vectors to a small nonzero angle.
    If either vector has zero length, the result is (90.0, 0.0), the same as
    before.
    """
    denom = float(np.linalg.norm(v1)) * float(np.linalg.norm(v2))
    if denom == 0.0:
        cos_sim = 0.0
    else:
        cos_sim = float(np.clip(np.dot(v1, v2) / denom, -1.0, 1.0))
    angle = math.degrees(math.acos(cos_sim))
    return angle, cos_sim

# One (style, id) pair per group, taking the first row as before. Texts are
# embedded in two batched encode() calls rather than two single-text calls per
# pair. Batching may shift embeddings at float-rounding level only.
pairs = [
    (style, sample_id, grp.iloc[0]["neutral_out"], grp.iloc[0]["styled_out"])
    for (style, sample_id), grp in df_styles.groupby(["style", "id"])
]

angles_rows = []
if pairs:
    neutral_embs = emb_model.encode([p[2] for p in pairs], convert_to_numpy=True)
    styled_embs = emb_model.encode([p[3] for p in pairs], convert_to_numpy=True)

    for (style, sample_id, _, _), emb_neutral, emb_styled in zip(pairs, neutral_embs, styled_embs):
        angle_deg, cos_sim = cosine_angle(emb_neutral, emb_styled)
        angles_rows.append({
            "id": sample_id,
            "style": style,
            "angle_deg": angle_deg,
            "cosine": cos_sim,
            "embedding": emb_styled,
        })

df_angles = pd.DataFrame(angles_rows)
df_angles.head()


In [None]:
# Pass an explicit Axes. DataFrame.boxplot(by=...) otherwise opens a figure of
# its own, and the plt.figure() made beforehand was left empty.
fig, ax = plt.subplots(figsize=(8, 5))
df_angles.boxplot(column="angle_deg", by="style", ax=ax)
ax.set_title("Angle between neutral and styled outputs (degrees)")
fig.suptitle("")  # drop pandas' automatic "Boxplot grouped by ..." heading
ax.set_ylabel("Angle (deg)")
ax.set_xlabel("Style")
ax.grid(alpha=0.3)
plt.show()


In [None]:
# Rows: sample ids (sorted); columns: styles; cells: neutral-vs-styled angle.
pivot = (
    df_angles
    .pivot(index="id", columns="style", values="angle_deg")
    .sort_index()
)
col_labels = list(pivot.columns)
row_labels = list(pivot.index)

plt.figure(figsize=(8, 6))
plt.imshow(pivot.to_numpy(), aspect="auto")
plt.colorbar(label="Angle (deg)")
plt.xticks(ticks=range(len(col_labels)), labels=col_labels, rotation=45, ha="right")
plt.yticks(ticks=range(len(row_labels)), labels=row_labels)
plt.title("Angle heatmap: neutral vs styled per sample/style")
plt.tight_layout()
plt.show()


In [None]:
# Word count of each sample's neutral rewrite, keyed by sample id. Missing
# outputs count as 0 words instead of raising on a float NaN.
length_map = (
    df_base.set_index("id")["neutral_out"]
    .fillna("")
    .astype(str)
    .str.split()
    .str.len()
)
df_angles["neutral_len"] = df_angles["id"].map(length_map)

# Reuse STYLES from the setup cell rather than re-declaring (and shadowing) it
# here, so the style list is edited in one place. Styles without rows are skipped.
fig, ax = plt.subplots(figsize=(7, 5))
for style in STYLES:
    sub = df_angles[df_angles["style"] == style]
    if sub.empty:
        continue
    ax.scatter(sub["neutral_len"], sub["angle_deg"], label=style, alpha=0.7)

ax.set_xlabel("Neutral output length (words)")
ax.set_ylabel("Angle (deg)")
ax.set_title("Angle vs neutral text length")
ax.legend()
ax.grid(alpha=0.3)
plt.show()


In [None]:
def _ensure_text_column(frame, target, fallbacks, missing_msg):
    """Return ``frame`` guaranteed to have a ``target`` column.

    If ``target`` is absent, the first column in ``fallbacks`` that is present
    is renamed to it. Raises ``RuntimeError(missing_msg)`` when none are found.
    """
    if target in frame.columns:
        return frame
    for candidate in fallbacks:
        if candidate in frame.columns:
            return frame.rename(columns={candidate: target})
    raise RuntimeError(missing_msg)


df_base = _ensure_text_column(
    df_base, "neutral_out", ("output", "text"),
    "Cannot find neutral text column in df_base",
)
df_styles = _ensure_text_column(
    df_styles, "styled_out", ("output", "text"),
    "Cannot find styled text column in df_styles",
)

# Coerce to plain strings so missing values (e.g. after a CSV round-trip) become "".
df_base["neutral_out"] = df_base["neutral_out"].fillna("").astype(str)
df_styles["styled_out"] = df_styles["styled_out"].fillna("").astype(str)
df_styles["style"] = df_styles["style"].astype(str)

def cosine_angle(v1: np.ndarray, v2: np.ndarray) -> tuple[float, float]:
    """Return ``(angle_degrees, cosine_similarity)`` between two vectors.

    The cosine is clipped to [-1, 1] before ``acos`` to absorb rounding error.
    The product of the norms is used directly, with no epsilon added to each
    norm; the old epsilon skewed identical vectors to a small nonzero angle.
    If either vector has zero length, the result is (90.0, 0.0), the same as
    before.
    """
    denom = float(np.linalg.norm(v1)) * float(np.linalg.norm(v2))
    if denom == 0.0:
        cos_sim = 0.0
    else:
        cos_sim = float(np.clip(np.dot(v1, v2) / denom, -1.0, 1.0))
    angle = math.degrees(math.acos(cos_sim))
    return angle, cos_sim

def _mean_embedding(texts):
    """Return the mean embedding of the non-blank strings in ``texts``, or None if none remain."""
    kept = [t for t in texts if isinstance(t, str) and t.strip()]
    if not kept:
        return None
    return emb_model.encode(kept, convert_to_numpy=True).mean(axis=0)


centroids = {}

# Filter blank outputs on the neutral side too, as the styled side always did;
# otherwise embeddings of empty rewrites skew the neutral centroid.
neutral_centroid = _mean_embedding(df_base["neutral_out"].tolist())
if neutral_centroid is not None:
    centroids["neutral"] = neutral_centroid

styles = sorted(df_styles["style"].unique().tolist())

for style in styles:
    centroid = _mean_embedding(df_styles.loc[df_styles["style"] == style, "styled_out"].tolist())
    if centroid is not None:
        centroids[style] = centroid


# Only label groups that actually have a centroid. The centroid cell skips
# groups whose outputs are all blank, and looking them up here raised KeyError.
labels = [label for label in ["neutral", *styles] if label in centroids]
n = len(labels)

# Pairwise angle (degrees) between every pair of centroids.
angle_matrix = np.zeros((n, n), dtype=float)
for i, s1 in enumerate(labels):
    for j, s2 in enumerate(labels):
        angle_deg, _ = cosine_angle(centroids[s1], centroids[s2])
        angle_matrix[i, j] = angle_deg

plt.figure(figsize=(6, 5))
im = plt.imshow(angle_matrix, aspect="equal")

plt.colorbar(im, label="Angle between style centroids (degrees)")
plt.xticks(ticks=np.arange(n), labels=labels, rotation=45, ha="right")
plt.yticks(ticks=np.arange(n), labels=labels)
plt.title("Style vs Style: angular distance (neutral + styled)")
plt.tight_layout()
plt.show()