# In [None]:
# Generate synthetic text with a LoRA (PEFT) fine-tuned causal LM using Hugging Face transformers + peft.

import logging
import os
import pickle
import time
from typing import List, Dict, Any

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


# -----------------------
# Logging: configure the root logger (INFO level, timestamped format)
# -----------------------
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # shared by every helper in this module


# -----------------------
# User-editable settings (no hard-coded paths)
# -----------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"

BASE_MODEL_ID = "your-model-id"                 # e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct", "meta-llama/Llama-2-7b-chat-hf"
CKPT_PATH = "path/to/dp_lora_checkpoint.pth"    # saved model.state_dict() from training
PROMPTS_PATH = "path/to/prompts.pkl"            # pickle dict -> values are prompt strings

DEVICE_INDEX = 0                                # choose GPU index if available

PROMPT_LIMIT = 1304                             # set None to use all prompts in PROMPTS_PATH
BATCH_SIZE = 4

MAX_PROMPT_LEN = 324                            # 512, 768
MAX_NEW_TOKENS = 700                            # 1280

DO_SAMPLE = True
TEMPERATURE = 0.7                               # 0.3, 0.5
TOP_K = 50                                      # 10, 50
REPETITION_PENALTY = 1.2

NUM_REPEATS = 3                                 # save synth outputs 1..NUM_REPEATS
RUN_TAG = "run_tag"                             # used in output filenames, e.g., "100context_1dp"


# -----------------------
# LoRA settings: must be identical to the training run so checkpoint keys line up
# -----------------------
LORA_R: int = 8                 # adapter rank
LORA_ALPHA: int = 32            # LoRA scaling numerator (PEFT scales adapters by alpha / r)
LORA_DROPOUT: float = 0.05      # inactive at inference since the model is put in eval mode; alternative: 0.0
INCLUDE_LM_HEAD: bool = True    # set False if lm_head was not LoRA-wrapped during training


# -----------------------
# Helpers
# -----------------------
def sanitize_prompts(texts: List[str], eos_str: str) -> List[str]:
    """
    Return cleaned copies of ``texts`` with stray end-of-sequence markers removed.

    Each literal ``"</s>"`` is replaced by a space first. Then, if ``eos_str``
    is non-empty, each occurrence of it is replaced by a space. Finally the
    text is stripped of leading/trailing whitespace.

    Args:
        texts: Prompt strings to clean.
        eos_str: The tokenizer's EOS token string ("" to skip that pass).

    Returns:
        A new list of cleaned prompts, same length and order as ``texts``.
    """
    markers = ["</s>"]
    if eos_str:
        markers.append(eos_str)

    def _clean(text: str) -> str:
        # Apply markers in order: "</s>" first, then the tokenizer's EOS.
        for marker in markers:
            text = text.replace(marker, " ")
        return text.strip()

    return [_clean(text) for text in texts]


def strip_opacus_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove a uniform ``'_module.'`` key prefix added by Opacus module wrapping.

    The prefix is stripped only when EVERY key carries it. Otherwise (or for
    an empty/falsy ``state_dict``) the input object is returned unchanged.

    Args:
        state_dict: Mapping of parameter names to tensors.

    Returns:
        A new dict with the prefix removed, or ``state_dict`` itself.
    """
    prefix = "_module."
    if not state_dict:
        return state_dict
    if any(not key.startswith(prefix) for key in state_dict):
        return state_dict
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


def load_prompts(path: str, limit: int | None) -> List[str]:
    with open(path, "rb") as f:
        prompt_dict = pickle.load(f)
    prompts = list(prompt_dict.values())
    return prompts if limit is None else prompts[:limit]


def build_model_and_tokenizer():
    """
    Build the tokenizer and the LoRA-wrapped 8-bit base model, then load the checkpoint.

    - Tokenizer: loaded from BASE_MODEL_ID with left padding for batched
      decoder-only generation. If no pad token exists, the EOS token is reused.
    - Base model: loaded in 8-bit via bitsandbytes and pinned to GPU
      DEVICE_INDEX when CUDA is available. Its pad_token_id is synced with the
      tokenizer and the KV cache is enabled.
    - LoRA: adapters are re-created on the attention/MLP projections, plus
      lm_head when INCLUDE_LM_HEAD, using the LORA_* settings. These must match
      training so the checkpoint keys line up.
    - Checkpoint: CKPT_PATH is loaded on CPU, any Opacus '_module.' prefix is
      stripped, and the weights are loaded non-strictly. Missing LoRA weights
      and unexpected keys are logged as warnings, since either likely means
      the adapter setup does not match the checkpoint.

    Side effect: disables autograd globally via ``torch.set_grad_enabled(False)``.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    # Tokenizer: left padding for decoder-only models
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base model: 8-bit, pinned to one GPU (if CUDA)
    bnb = BitsAndBytesConfig(load_in_8bit=True)
    device_map = {"": DEVICE_INDEX} if torch.cuda.is_available() else None

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=bnb,
        low_cpu_mem_usage=True,
        device_map=device_map,
    )
    base.config.pad_token_id = tokenizer.pad_token_id
    base.config.use_cache = True

    # Rebuild LoRA modules (must match training)
    targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    if INCLUDE_LM_HEAD:
        targets.append("lm_head")

    lora_cfg = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=targets,
        bias="none",
        lora_dropout=LORA_DROPOUT,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base, lora_cfg)

    # Load checkpoint
    sd = torch.load(CKPT_PATH, map_location="cpu")
    sd = strip_opacus_prefix(sd)

    missing, unexpected = model.load_state_dict(sd, strict=False)
    logger.info("[load_state_dict] Missing: %d | Unexpected: %d", len(missing), len(unexpected))

    # strict=False tolerates gaps, but if adapter keys do not match, the LoRA
    # layers keep their fresh initialization (a no-op adapter under PEFT's
    # default init). Generation would then silently use the base model, so
    # surface it loudly.
    missing_lora = [k for k in missing if "lora_" in k]
    if missing_lora:
        logger.warning(
            "[load_state_dict] %d LoRA parameters were NOT loaded from %s (e.g. %s); "
            "check that LORA_* / INCLUDE_LM_HEAD match training.",
            len(missing_lora),
            CKPT_PATH,
            missing_lora[:3],
        )
    if unexpected:
        logger.warning(
            "[load_state_dict] %d unexpected checkpoint keys were ignored (e.g. %s)",
            len(unexpected),
            list(unexpected)[:3],
        )

    model.eval()
    torch.set_grad_enabled(False)

    return model, tokenizer


def generate_batch(model, tokenizer, batch_prompts: List[str]) -> List[str]:
    """
    Generate one completion per prompt for a single batch.

    Prompts are cleaned with ``sanitize_prompts``, padded, and truncated to
    MAX_PROMPT_LEN tokens without adding special tokens. They are then passed
    to ``model.generate``. Only the newly generated tokens are decoded, so the
    prompt text is not echoed back.

    Args:
        model: Causal LM exposing ``parameters()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``. Assumed to pad on the left,
            as configured in ``build_model_and_tokenizer``.
        batch_prompts: Prompt strings for this batch.

    Returns:
        One whitespace-stripped completion per prompt, in input order
        (``[]`` for an empty batch).
    """
    if not batch_prompts:
        # Nothing to generate; skip tokenizing/generating an empty batch.
        return []

    batch_prompts = sanitize_prompts(batch_prompts, tokenizer.eos_token or "")

    enc = tokenizer(
        batch_prompts,
        padding=True,                 # left padding
        truncation=True,
        max_length=MAX_PROMPT_LEN,
        add_special_tokens=False,     # prevents BOS/EOS injection into prompt
        return_tensors="pt",
    )

    dev = next(model.parameters()).device
    input_ids = enc["input_ids"].to(dev, non_blocking=True)
    attn_mask = enc["attention_mask"].to(dev, non_blocking=True)

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "repetition_penalty": REPETITION_PENALTY,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "num_return_sequences": 1,
    }
    if DO_SAMPLE:
        # temperature / top_k only affect sampling. Passing them together with
        # do_sample=False has no effect and just triggers transformers warnings.
        gen_kwargs["temperature"] = TEMPERATURE
        gen_kwargs["top_k"] = TOP_K

    with torch.inference_mode():
        gen = model.generate(input_ids=input_ids, attention_mask=attn_mask, **gen_kwargs)

    # With left padding every row's prompt ends at the same column, so slicing
    # at the padded prompt length keeps exactly the generated tokens.
    new_tokens = gen[:, input_ids.shape[1] :]
    return [t.strip() for t in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)]


def generate_with_retries(model, tokenizer, batch_prompts: List[str], retries: int = 2) -> List[str]:
    """
    Run ``generate_batch`` with up to ``retries`` extra attempts; never raise.

    Any ``Exception`` from generation is logged as a warning and the batch is
    retried. When CUDA is available, the allocator cache is released between
    attempts so a retry after an out-of-memory error has a chance to succeed.
    If every attempt fails, an error is logged and one empty string per
    prompt is returned, so the caller's output stays aligned with its inputs.

    Args:
        model: Model passed through to ``generate_batch``.
        tokenizer: Tokenizer passed through to ``generate_batch``.
        batch_prompts: Prompt strings for this batch.
        retries: Extra attempts after the first; negative values are treated as 0.

    Returns:
        The generated completions, or ``[""] * len(batch_prompts)`` on failure.
    """
    attempts = max(retries, 0) + 1  # always make at least one attempt
    for attempt in range(1, attempts + 1):
        try:
            return generate_batch(model, tokenizer, batch_prompts)
        except Exception as e:
            logger.warning("Generation error (attempt %d/%d): %r", attempt, attempts, e)
            if torch.cuda.is_available():
                # Free cached blocks (e.g. after an OOM) before the next attempt.
                torch.cuda.empty_cache()

    logger.error(
        "Generation failed after %d attempt(s); returning %d empty outputs",
        attempts,
        len(batch_prompts),
    )
    return [""] * len(batch_prompts)


def save_pickle(obj: Any, out_path: str) -> None:
    """
    Pickle ``obj`` to ``out_path``, creating parent directories as needed.

    The data is first written to a sibling ``<out_path>.tmp`` file and then
    moved into place with ``os.replace``. An interrupted or failed dump
    therefore never leaves a truncated pickle at ``out_path``, and the temp
    file is removed on failure.

    Args:
        obj: Any picklable object.
        out_path: Destination file path.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    tmp_path = out_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, out_path)  # atomic rename on POSIX
    except BaseException:
        # Best-effort cleanup of the partial temp file, then propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info("Saved: %s", out_path)


# -----------------------
# Main
# -----------------------
def main():
    """
    Generate NUM_REPEATS synthetic datasets from the prompts and pickle each one.

    The run configuration and prompts are validated before the slow model
    load, so a bad PROMPTS_PATH or BATCH_SIZE fails fast. Each repeat sends
    every prompt through ``generate_with_retries`` in batches of BATCH_SIZE.
    It then writes the completions (one string per prompt, "" where a batch
    failed after retries) to ``synth_<RUN_TAG>_<rep>.pkl`` in CKPT_PATH's
    directory.

    Raises:
        ValueError: if BATCH_SIZE < 1 or no prompts were loaded.
    """
    # Validate cheap inputs before loading the model.
    if BATCH_SIZE < 1:
        raise ValueError(f"BATCH_SIZE must be >= 1, got {BATCH_SIZE}")

    prompt_list = load_prompts(PROMPTS_PATH, PROMPT_LIMIT)
    n = len(prompt_list)
    if not prompt_list:
        raise ValueError(f"No prompts found in {PROMPTS_PATH}")

    model, tokenizer = build_model_and_tokenizer()
    logger.info("padding_side: %s", tokenizer.padding_side)

    out_dir = os.path.dirname(CKPT_PATH) or "."

    for rep in range(1, NUM_REPEATS + 1):
        notes: List[str] = []
        start = time.perf_counter()  # monotonic clock for elapsed-time reporting

        for i in range(0, n, BATCH_SIZE):
            batch = prompt_list[i : i + BATCH_SIZE]
            logger.info("%4d/%4d  elapsed=%7.2fs  batch=%d", i, n, time.perf_counter() - start, len(batch))

            out = generate_with_retries(model, tokenizer, batch, retries=2)
            print(out)
            notes.extend(out)

        logger.info("Repeat %d | Total generated: %d | Total time: %.2fs", rep, len(notes), time.perf_counter() - start)
        n_empty = sum(1 for text in notes if not text)
        if n_empty:
            logger.warning("Repeat %d | %d empty completions (failed batches or blank outputs)", rep, n_empty)

        out_path = os.path.join(out_dir, f"synth_{RUN_TAG}_{rep}.pkl")
        save_pickle(notes, out_path)


if __name__ == "__main__":
    main()
