In [2]:
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict, Any, Union, Callable
import math
import torch
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
import random
import collections
import sys
from collections import Counter
from google.colab import files
import pickle
import matplotlib.pyplot as plt
import os
import json
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import DPOTrainer, DPOConfig
from tqdm.auto import tqdm
from collections import Counter

In [None]:
!pip install git+https://github.com/huggingface/trl.git

**Data, Model, and Helper Functions Preparation**



In [None]:
class GreedyConfig:
  """Shared constants for the decoding experiments in this notebook.

  The class object itself (not an instance) is handed to the ``run_*_eval``
  helpers, which read these attributes directly.
  """

  # --- model / data sources ---
  model_name = "HuggingFaceTB/smollm2-135M-SFT-Only"
  tokenizer_name = None  # the run_* helpers fall back to model_name when this is None
  dataset_name = "yahma/alpaca-cleaned"
  prompt_template = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"

  # --- reproducibility / evaluation size ---
  seed = 42
  test_size = 400

  # --- generation settings ---
  max_new_tokens = 128
  temperature = 1.0
  batch_decode = False
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  # --- logging: progress is printed every `print_every` prompts ---
  print_every = 50

def set_seed(seed: int):
  """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device when CUDA is available)."""
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [3]:

def prepare_model_and_tokenizer(model_name: str, tokenizer_name: Optional[str] = None, device: Optional[str] = None):
  """Load a causal LM and its tokenizer, ensure a pad token exists, and put the model in eval mode.

  Args:
    model_name: identifier passed to ``AutoModelForCausalLM.from_pretrained``.
    tokenizer_name: identifier passed to ``AutoTokenizer.from_pretrained``;
      defaults to ``model_name`` when None.
    device: target device; when falsy, "cuda" is used if available, else "cpu".

  Returns:
    ``(model, tokenizer, device)``.
  """
  if tokenizer_name is None:
    tokenizer_name = model_name

  device = device or ("cuda" if torch.cuda.is_available() else "cpu")

  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
  if tokenizer.pad_token is None:
    # Reuse EOS as the pad token when there is one; otherwise register a new token.
    pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "<|pad|>"
    tokenizer.add_special_tokens({"pad_token": pad_token})

  model = AutoModelForCausalLM.from_pretrained(model_name)
  # Only grow the embedding matrix when the tokenizer really has more entries
  # than the model. Resizing unconditionally would shrink checkpoints whose
  # embedding matrix is larger than the tokenizer vocabulary.
  embedding_rows = model.get_input_embeddings().weight.shape[0]
  if len(tokenizer) > embedding_rows:
    model.resize_token_embeddings(len(tokenizer))
  model.to(device)
  model.eval()
  return model, tokenizer, device


def prepare_eval_prompts(dataset_name: str, prompt_template: str, seed: int = 42, test_size: int = 400) -> List[Dict[str, Any]]:
  """Build a reproducible list of evaluation prompts from a dataset.

  The "test" split is preferred, then "validation", then the first available
  split. The chosen split is shuffled with ``seed`` and the first ``test_size``
  rows (or fewer, if the split is smaller) are formatted with
  ``prompt_template``.

  Args:
    dataset_name: name passed to ``load_dataset``.
    prompt_template: format string with ``{instruction}`` and ``{input}`` fields.
    seed: shuffle seed.
    test_size: maximum number of prompts to return.

  Returns:
    A list of dicts with keys "id" (int position), "prompt", "instruction"
    and "input".
  """
  ds = load_dataset(dataset_name)
  if isinstance(ds, dict):
    for split_name in ("test", "validation"):
      if split_name in ds:
        base = ds[split_name]
        break
    else:
      base = ds[next(iter(ds.keys()))]
  else:
    base = ds

  base = base.shuffle(seed=seed)
  # The original looped over an undefined `selected`; take the first
  # `test_size` shuffled rows, clamped to the size of the split.
  n_keep = max(0, min(test_size, len(base)))
  selected = base.select(range(n_keep))

  prompts = []
  for i, ex in enumerate(selected):
    # Rows are read with .get, so missing/empty fields fall back to "".
    instruction = ex.get("instruction") or ex.get("prompt") or ""
    inp = ex.get("input") or ""
    prompt_text = prompt_template.format(instruction=instruction, input=inp)
    prompts.append({"id": i, "prompt": prompt_text, "instruction": instruction, "input": inp})
  return prompts

def postprocess_generated_text(gen_text: str) -> str:
    """Cut a generation at the first prompt-template marker and strip whitespace.

    Returns "" for None or whitespace-only input. If cutting at the first
    marker would leave nothing, the whole stripped text is returned instead.
    """
    if gen_text is None:
        return ""
    text = gen_text.strip()
    if text == "":
        return ""

    marker_list = ("### Instruction", "### Response", "\n### Instruction", "\n### Response")
    positions = [pos for pos in (text.find(marker) for marker in marker_list) if pos != -1]
    if not positions:
        return text

    head = text[:min(positions)].strip()
    # Fall back to the full text rather than returning an empty string.
    return head or text


In [None]:
def sanitize_prompt_for_generation(prompt: str) -> str:
    """Drop trailing whitespace from ``prompt`` and append the ``Answer:`` cue.

    The same suffix is appended whether or not the prompt ends with
    "### Response:".
    """
    return prompt.rstrip() + "\nAnswer:\n"


@torch.no_grad()
def greedy_decode_single( model: "AutoModelForCausalLM", tokenizer, prompt: str, max_new_tokens: int = 128, temperature: float = 1.0, eos_token_id: Optional[int] = None, device: Optional[str] = None, sanitize_prompt: bool = True, repeat_ngram_block: int = 3, repeat_threshold: int = 4,
):
  """Greedy (argmax) decoding for a single prompt using the model's KV cache.

  Generation stops at ``max_new_tokens``, at EOS, when the same n-gram of the
  last ``repeat_ngram_block`` tokens has been produced ``repeat_threshold``
  times, or when a "### Instruction" / "### Response" marker shows up in the
  decoded output (the text is then cut at the marker and re-tokenized).

  Args:
    model: causal LM; called with ``input_ids`` / ``past_key_values``.
    tokenizer: tokenizer providing ``__call__``, ``decode`` and ``eos_token_id``.
    prompt: prompt text.
    max_new_tokens: maximum number of tokens to generate.
    temperature: logit divisor; affects the reported log-probs, not the argmax.
    eos_token_id: stop token; defaults to ``tokenizer.eos_token_id``.
    device: device for the inputs; defaults to the model's parameter device.
    sanitize_prompt: when True, pass the prompt through
      ``sanitize_prompt_for_generation`` first.
    repeat_ngram_block: n-gram size for the repetition guard.
    repeat_threshold: number of occurrences that triggers the guard.

  Returns:
    Dict with "generated_text", "full_token_ids" (prompt + every generated
    token), "generated_token_ids", "avg_logprob" and "num_generated_tokens".
    Note: "avg_logprob" averages over every token that was generated, even
    when marker trimming shortens "generated_token_ids".

  Raises:
    ValueError: if ``temperature`` is not positive.
  """
  if temperature <= 0:
    raise ValueError("temperature must be > 0")
  if device is None:
    device = next(model.parameters()).device
  if eos_token_id is None:
    eos_token_id = tokenizer.eos_token_id

  if sanitize_prompt:
    prompt = sanitize_prompt_for_generation(prompt)

  enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
  input_ids = enc["input_ids"]
  attention_mask = enc.get("attention_mask", None)

  model.eval()
  # Prompt pass: fills the KV cache AND already yields the logits for the
  # first new token.
  outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, return_dict=True)
  past_key_values = outputs.past_key_values

  generated = input_ids.clone()
  chosen_token_ids: List[int] = []
  chosen_token_logprobs: List[float] = []
  ngram_counter = Counter()
  token_buffer: List[int] = []

  for step in range(max_new_tokens):
    if step > 0:
      # Feed only the token chosen in the previous step. (Previously step 0
      # re-fed the last prompt token on top of a cache that already held it,
      # so every prediction was conditioned on a duplicated token.)
      outputs = model(input_ids=generated[:, -1:], past_key_values=past_key_values, use_cache=True, return_dict=True)
      past_key_values = outputs.past_key_values

    next_token_logits = outputs.logits[:, -1, :]
    log_probs = torch.log_softmax(next_token_logits / float(temperature), dim=-1)

    next_token_id_tensor = torch.argmax(log_probs, dim=-1, keepdim=True)
    next_token_id_item = int(next_token_id_tensor[0, 0].item())
    logprob_of_choice = float(log_probs[0, next_token_id_item].item())

    generated = torch.cat([generated, next_token_id_tensor.to(device)], dim=-1)
    chosen_token_ids.append(next_token_id_item)
    chosen_token_logprobs.append(logprob_of_choice)
    token_buffer.append(next_token_id_item)

    # Repetition guard on the trailing n-gram.
    if repeat_ngram_block > 0 and len(token_buffer) >= repeat_ngram_block:
      ngram = tuple(token_buffer[-repeat_ngram_block:])
      ngram_counter[ngram] += 1
      if ngram_counter[ngram] >= repeat_threshold:
        break

    # Best-effort decode of the recent window to spot template markers.
    try:
      decoded_recent = tokenizer.decode(chosen_token_ids[-64:], skip_special_tokens=False)
    except Exception:
      decoded_recent = ""

    if "### Instruction" in decoded_recent or "### Response" in decoded_recent:
      full_decoded = tokenizer.decode(chosen_token_ids, skip_special_tokens=False)
      marker_positions = [idx for idx in (full_decoded.find("### Instruction"), full_decoded.find("### Response")) if idx != -1]
      cut_at = min(marker_positions + [len(full_decoded)])
      cleaned = full_decoded[:cut_at].strip()
      cleaned_ids = tokenizer(cleaned, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0).tolist() if cleaned else []
      chosen_token_ids = cleaned_ids
      break

    if eos_token_id is not None and next_token_id_item == eos_token_id:
      break

  gen_text = tokenizer.decode(chosen_token_ids, skip_special_tokens=True).strip()
  avg_logprob = float(np.mean(chosen_token_logprobs)) if chosen_token_logprobs else float("nan")

  return {
      "generated_text": gen_text,
      "full_token_ids": generated.squeeze(0).tolist(),
      "generated_token_ids": chosen_token_ids,
      "avg_logprob": avg_logprob,
      "num_generated_tokens": len(chosen_token_ids),
  }

In [None]:
def ngrams_from_text(text: str, n: int) -> List[Tuple[str, ...]]:
  """Return every contiguous n-gram of the whitespace-split words of ``text``, in order."""
  words = text.split()
  grams = []
  if len(words) >= n:
    last_start = len(words) - n
    for start in range(last_start + 1):
      grams.append(tuple(words[start : start + n]))
  return grams

def distinct_n(texts: List[str], n: int, tokenizer) -> float:
  """Distinct-n over a corpus: unique token-id n-grams divided by total n-grams.

  Each text is encoded with ``tokenizer.encode(..., add_special_tokens=False)``;
  n-grams are pooled across all texts. Returns 0.0 when there are no n-grams.
  """
  unique_grams = set()
  total_grams = 0

  for text in texts:
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) < n:
      continue
    for start in range(len(token_ids) - n + 1):
      unique_grams.add(tuple(token_ids[start : start + n]))
      total_grams += 1

  if total_grams == 0:
    return 0.0
  return len(unique_grams) / total_grams


**Greedy Search**

In [None]:
def run_greedy_eval(cfg):
  """Greedy-decode every evaluation prompt and summarise diversity and log-prob.

  Reads model/dataset/generation settings from ``cfg`` and returns a dict with
  "config", per-prompt "results" and an aggregate "summary". The per-prompt
  "num_generated_tokens" is the whitespace word count of the cleaned text.
  """
  set_seed(cfg.seed)
  tokenizer_source = cfg.tokenizer_name or cfg.model_name
  model, tokenizer, device = prepare_model_and_tokenizer(cfg.model_name, tokenizer_source, cfg.device)
  prompts = prepare_eval_prompts(cfg.dataset_name, cfg.prompt_template, seed=cfg.seed, test_size=cfg.test_size)
  print(f"Prepared {len(prompts)} prompts for evaluation.")
  print(f"{prompts[:2]}")

  results = []
  for i, item in enumerate(prompts):
    decoded = greedy_decode_single(
      model,
      tokenizer,
      item["prompt"],
      max_new_tokens=cfg.max_new_tokens,
      temperature=cfg.temperature,
      eos_token_id=tokenizer.eos_token_id,
      device=device,
    )
    text = postprocess_generated_text(decoded.get("generated_text", ""))
    word_count = len(text.split()) if text.strip() else 0
    record = {
      "id": item["id"],
      "prompt": item["prompt"],
      "generated_text": text,
      "avg_logprob": decoded.get("avg_logprob", float("nan")),
      "num_generated_tokens": word_count
    }
    results.append(record)
    if (i+1) % cfg.print_every == 0:
      print(f"Decoded {i+1}/{len(prompts)} prompts...")

  generations = [record["generated_text"] for record in results]
  finite_logprobs = [record["avg_logprob"] for record in results if not math.isnan(record["avg_logprob"])]
  distinct_1 = distinct_n(generations, 1, tokenizer)
  distinct_2 = distinct_n(generations, 2, tokenizer)
  distinct_3 = distinct_n(generations, 3, tokenizer)
  summary = {
    "num_prompts": len(prompts),
    "distinct_1_across": distinct_1,
    "distinct_2_across": distinct_2,
    "distinct_3_across": distinct_3,
    "avg_logprob_all": float(np.nanmean(finite_logprobs))
  }
  print("Greedy summary:", summary)
  return {"config": cfg, "results": results, "summary": summary}

In [None]:
import pickle

# Run the full greedy evaluation using the shared GreedyConfig settings.
greedy_out = run_greedy_eval(GreedyConfig)

# Persist the returned dict (config, per-prompt results, summary) in the working directory.
with open("greedy_out.pkl", "wb") as f:
  pickle.dump(greedy_out, f)

In [None]:
import os
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
import numpy as np

def plot_distinct_across_within(summary_like: Dict[str, Any], outpath: Optional[str] = None, method_name: str = "Method"):
  """Bar chart of Distinct-1/2/3, across prompts and (when present) within a prompt.

  Args:
    summary_like: mapping that may contain ``distinct_{1,2,3}_across`` and
      ``distinct_{1,2,3}_within``; missing keys count as 0.0.
    outpath: optional file path for saving the figure; its parent directory is
      created when the path has one.
    method_name: label used in the title and legend.

  Returns:
    The matplotlib figure.
  """
  d_across = [
    summary_like.get("distinct_1_across", 0.0),
    summary_like.get("distinct_2_across", 0.0),
    summary_like.get("distinct_3_across", 0.0),
  ]
  d_within = [
    summary_like.get("distinct_1_within", 0.0),
    summary_like.get("distinct_2_within", 0.0),
    summary_like.get("distinct_3_within", 0.0),
  ]
  labels = ["Distinct-1", "Distinct-2", "Distinct-3"]
  x = np.arange(len(labels))
  width = 0.35
  fig, ax = plt.subplots(figsize=(7,4.5))
  ax.bar(x - width/2, d_across, width, label="Across-prompts")
  # The within-prompt bars are drawn only if at least one value is positive.
  if any(v > 0 for v in d_within):
    ax.bar(x + width/2, d_within, width, label=f"Within-prompt ({method_name})")
  ax.set_xticks(x)
  ax.set_xticklabels(labels)
  ax.set_ylabel("Distinct-N (unique N-grams / total N-grams)")
  ax.set_title(f"{method_name}: Distinct-N (Across vs Within)")
  ax.set_ylim(0, max(max(d_across + d_within) * 1.2, 0.01))
  ax.legend()
  ax.grid(axis="y", linestyle="--", linewidth=0.4, alpha=0.7)
  plt.tight_layout()
  if outpath:
    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(outpath)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)
  plt.show()
  return fig

def plot_avg_logprob_box(greedy_results: List[Dict[str, Any]], outpath: Optional[str] = None):
  """Box plot of the per-sample ``avg_logprob`` values in ``greedy_results``.

  None and NaN entries are dropped; if nothing remains, a placeholder message
  is drawn instead of a box. When ``outpath`` is given the figure is saved
  there (the parent directory is created when the path has one).

  Returns:
    The matplotlib figure.
  """
  logprobs = [r.get("avg_logprob", np.nan) for r in greedy_results]
  logprobs = np.array([x for x in logprobs if (x is not None) and (not np.isnan(x))], dtype=float)

  fig, ax = plt.subplots(figsize=(6,4))
  if logprobs.size == 0:
    ax.text(0.5, 0.5, "No avg_logprob data", ha="center", va="center")
  else:
    ax.boxplot(logprobs, vert=True, patch_artist=True, widths=0.5, boxprops=dict(facecolor="lightgray", color="black"), medianprops=dict(color="red"))
    mean_lp = float(np.mean(logprobs))
    n = logprobs.size
    # Annotation is positioned in axes coordinates, just right of the plot area.
    ax.text(1.05, 0.95, f"mean = {mean_lp:.3f}\nn = {n}", transform=ax.transAxes, verticalalignment="top", bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"))

  ax.set_ylabel("Average token log-prob (per sample)")
  ax.set_title("Greedy: Distribution of avg token log-prob (quality)")
  ax.grid(axis="y", linestyle="--", linewidth=0.4, alpha=0.7)
  plt.tight_layout()

  if outpath:
    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(outpath)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)

  plt.show()
  return fig

def plot_length_histogram(greedy_results: List[Dict[str, Any]], outpath: Optional[str] = None):
  """Histogram of ``num_generated_tokens`` with mean and median marker lines.

  Prints a message and returns None when ``greedy_results`` is empty;
  otherwise returns the matplotlib figure. When ``outpath`` is given the
  figure is saved there (the parent directory is created when the path has one).
  """
  lengths = [int(r.get("num_generated_tokens", 0)) for r in greedy_results]
  if len(lengths) == 0:
    print("No length data available.")
    return None
  fig, ax = plt.subplots(figsize=(7,4))
  ax.hist(lengths, bins=30, edgecolor="black", alpha=0.75)
  ax.set_xlabel("Number of generated tokens")
  ax.set_ylabel("Count")
  ax.set_title("Greedy: Generated length distribution")
  ax.grid(axis="y", linestyle="--", linewidth=0.4, alpha=0.7)
  mean_len = np.mean(lengths)
  med_len = np.median(lengths)
  ax.axvline(mean_len, color="red", linestyle="--", label=f"mean={mean_len:.1f}")
  ax.axvline(med_len, color="orange", linestyle="-.", label=f"median={med_len:.1f}")
  ax.legend()
  plt.tight_layout()
  if outpath:
    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(outpath)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)
  plt.show()
  return fig

def plot_greedy_from_eval(results, summary, outdir: Optional[str] = None) -> Dict[str, Optional[str]]:
  """Draw the three greedy plots and report where each one was written.

  Returns a dict with keys "distinct", "avg_logprob_box" and "length_hist";
  each value is the target path under ``outdir``, or None when ``outdir`` is
  not given.
  """
  def target(filename):
    # Path under outdir, or None when no output directory was requested.
    return os.path.join(outdir, filename) if outdir else None

  distinct_path = target("greedy_distinct_across_within.png")
  plot_distinct_across_within(summary, outpath=distinct_path, method_name="Greedy")

  box_path = target("greedy_avg_logprob_box.png")
  plot_avg_logprob_box(results, outpath=box_path)

  hist_path = target("greedy_length_hist.png")
  plot_length_histogram(results, outpath=hist_path)

  return {
    "distinct": distinct_path,
    "avg_logprob_box": box_path,
    "length_hist": hist_path,
  }

In [None]:
# Render the greedy plots from the evaluation output and write them under figs/greedy.
saved = plot_greedy_from_eval(greedy_out["results"], greedy_out["summary"], outdir="figs/greedy")

**Beam Search**

In [None]:

@torch.no_grad()
def beam_search_decode_single(model: AutoModelForCausalLM, tokenizer, prompt: str, beam_width: int = 4, max_new_tokens: int = 128,
                              temperature: float = 1.0, eos_token_id: Optional[int] = None, device: Optional[torch.device] = None,
                              repeat_ngram_block: int = 3, repeat_threshold: int = 4):
  """Beam-search decoding for a single prompt, one cached forward pass per live beam.

  Each step expands every live beam with its ``beam_width`` best next tokens,
  keeps the ``beam_width`` best candidates by cumulative log-prob, and ends a
  beam on EOS or when its trailing ``repeat_ngram_block``-gram has occurred
  ``repeat_threshold`` times. The best ended beam is returned if any beam
  ended, otherwise the best live beam. Scores are not length-normalised.

  Args:
    model: causal LM called with ``input_ids`` / ``past_key_values``.
    tokenizer: tokenizer providing ``__call__``, ``decode`` and ``eos_token_id``.
    prompt: prompt text (passed through ``sanitize_prompt_for_generation``).
    beam_width: number of hypotheses kept per step.
    max_new_tokens: maximum number of decoding steps.
    temperature: logit divisor applied before ``log_softmax``.
    eos_token_id: stop token; defaults to ``tokenizer.eos_token_id``.
    device: device for the inputs; defaults to the model's parameter device.
    repeat_ngram_block: n-gram size for the repetition guard.
    repeat_threshold: occurrence count that ends a beam.

  Returns:
    Dict with "generated_text" (post-processed), "generated_token_ids",
    "avg_logprob" (beam score / number of generated ids), "num_generated_tokens"
    (whitespace word count of the text), "beam_score" and "beam_width".

  Raises:
    ValueError: if ``temperature`` is not positive.
  """
  import copy  # local import: used to snapshot KV caches for sibling beams

  if temperature <= 0:
    raise ValueError("temperature must be > 0")
  if device is None:
    device = next(model.parameters()).device
  if eos_token_id is None:
    eos_token_id = tokenizer.eos_token_id

  prompt_for_model = sanitize_prompt_for_generation(prompt)

  enc = tokenizer(prompt_for_model, return_tensors="pt", add_special_tokens=True).to(device)
  input_ids = enc["input_ids"]
  attention_mask = enc.get("attention_mask", None)

  model.eval()
  init_out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, return_dict=True)
  init_next_logits = init_out.logits[:, -1, :]
  vocab_size = init_next_logits.size(-1)
  dv = init_next_logits.device

  beams = [{"tokens": [], "score": 0.0, "past": init_out.past_key_values, "next_logits": init_next_logits, "ended": False, "ngram_counts": Counter()}]

  for step in range(max_new_tokens):
    # Candidate tuples: (parent beam index, token id or None, cumulative score, parent already ended).
    candidate_info: List[Tuple[int, Optional[int], float, bool]] = []
    for b_idx, beam in enumerate(beams):
      if beam["ended"]:
        # Finished hypotheses compete with their frozen score.
        candidate_info.append((b_idx, None, beam["score"], True))
        continue

      logits = beam["next_logits"]
      if logits is None:
        continue
      log_probs = torch.log_softmax(logits / float(temperature), dim=-1).squeeze(0)
      cand_scores = (log_probs + beam["score"]).cpu()
      k = min(beam_width, vocab_size)
      topk_vals, topk_idx = torch.topk(cand_scores, k=k)
      for score_val, token_id in zip(topk_vals.tolist(), topk_idx.tolist()):
        candidate_info.append((b_idx, int(token_id), float(score_val), False))

    if not candidate_info:
      break

    scores_arr = np.array([c[2] for c in candidate_info], dtype=float)
    pick_k = min(beam_width, len(scores_arr))
    top_idxs = scores_arr.argsort()[-pick_k:][::-1]

    new_beams = []
    for idx in top_idxs:
      parent_idx, token_id, score_val, is_ended = candidate_info[idx]
      parent_beam = beams[parent_idx]
      if is_ended:
        new_beams.append({
            "tokens": parent_beam["tokens"].copy(),
            "score": parent_beam["score"],
            "past": None,
            "next_logits": parent_beam["next_logits"],
            "ended": True,
            "ngram_counts": parent_beam["ngram_counts"].copy(),
            "parent": parent_idx,
        })
        continue

      new_tokens = parent_beam["tokens"] + [token_id]
      ended_flag = (eos_token_id is not None and token_id == eos_token_id)
      new_ngram_counts = parent_beam["ngram_counts"].copy()
      if len(new_tokens) >= repeat_ngram_block:
        last_ngram = tuple(new_tokens[-repeat_ngram_block:])
        new_ngram_counts[last_ngram] += 1
        if new_ngram_counts[last_ngram] >= repeat_threshold:
          ended_flag = True

      new_beams.append({
          "tokens": new_tokens,
          "score": score_val,
          "past": None,
          "last_token_id": token_id,
          "next_logits": None,
          "ended": ended_flag,
          "ngram_counts": new_ngram_counts,
          "parent": parent_idx,
      })

    # How many live children still need each parent's cache.
    pending_uses = Counter(nb["parent"] for nb in new_beams if not nb["ended"])
    for nb in new_beams:
      parent_idx = nb.pop("parent")
      last_token_id = nb.pop("last_token_id", None)
      if nb["ended"]:
        # Ended beams are never forwarded again, so they hold no cache.
        continue
      pending_uses[parent_idx] -= 1
      parent_past = beams[parent_idx]["past"]
      # A cache object may be updated in place by the forward pass, so siblings
      # must not share one: every child but the last gets its own snapshot and
      # the last child reuses the parent's cache directly.
      child_past = copy.deepcopy(parent_past) if pending_uses[parent_idx] > 0 else parent_past
      last_tok = torch.tensor([[last_token_id]], dtype=torch.long, device=dv)
      out = model(input_ids=last_tok, past_key_values=child_past, use_cache=True, return_dict=True)
      nb["past"] = out.past_key_values
      nb["next_logits"] = out.logits[:, -1, :].to(dv)

    beams = new_beams

    if all(b["ended"] for b in beams):
      break

  # Prefer hypotheses that actually terminated; otherwise take the best live one.
  ended_beams = [b for b in beams if b["ended"]]
  best_beam = max(ended_beams or beams, key=lambda x: x["score"])

  gen_token_ids = best_beam["tokens"]
  gen_text_raw = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
  gen_text = postprocess_generated_text(gen_text_raw)
  num_generated = len(gen_token_ids)
  avg_logprob = float(best_beam["score"] / num_generated) if num_generated > 0 else float("nan")

  return {
    "generated_text": gen_text,
    "generated_token_ids": gen_token_ids,
    "avg_logprob": avg_logprob,
    "num_generated_tokens": len(gen_text.split()) if gen_text.strip() else 0,
    "beam_score": float(best_beam["score"]),
    "beam_width": beam_width
  }



In [None]:
def run_beam_eval(cfg, beam_width = 4):
  """Beam-search every evaluation prompt and summarise diversity and log-prob.

  Returns a dict with "config", "beam_width", per-prompt "results" and an
  aggregate "summary".
  """
  set_seed(cfg.seed)
  tokenizer_source = cfg.tokenizer_name or cfg.model_name
  model, tokenizer, device = prepare_model_and_tokenizer(cfg.model_name, tokenizer_source, cfg.device)
  prompts = prepare_eval_prompts(cfg.dataset_name, cfg.prompt_template, seed=cfg.seed, test_size=cfg.test_size)

  # Fields copied verbatim from the decoder output into each result row.
  copied_fields = ("generated_text", "avg_logprob", "num_generated_tokens", "beam_score")

  results = []
  for i, item in enumerate(prompts):
    decoded = beam_search_decode_single(
      model, tokenizer,
      item["prompt"],
      beam_width=beam_width,
      max_new_tokens=cfg.max_new_tokens,
      temperature=cfg.temperature,
      eos_token_id=tokenizer.eos_token_id,
      device=device
    )
    record = {"id": item["id"], "prompt": item["prompt"]}
    for field in copied_fields:
      record[field] = decoded[field]
    results.append(record)
    if (i+1) % cfg.print_every == 0:
      print(f"Beam-decoded {i+1}/{len(prompts)} prompts...")

  generations = [record["generated_text"] for record in results]
  finite_logprobs = [record["avg_logprob"] for record in results if not math.isnan(record["avg_logprob"])]
  distinct_1 = distinct_n(generations, 1, tokenizer)
  distinct_2 = distinct_n(generations, 2, tokenizer)
  distinct_3 = distinct_n(generations, 3, tokenizer)
  summary = {
    "num_prompts": len(prompts),
    "distinct_1_across": distinct_1,
    "distinct_2_across": distinct_2,
    "distinct_3_across": distinct_3,
    "avg_logprob_all": float(np.nanmean(finite_logprobs))
  }
  print(f"Beam (B={beam_width}) summary:", summary)
  return {"config": cfg, "beam_width": beam_width, "results": results, "summary": summary}

In [None]:
import pickle

# Run the beam-search evaluation with run_beam_eval's default beam width.
beam_out = run_beam_eval(GreedyConfig)

# Persist the returned dict in the working directory.
with open("beam_out.pkl", "wb") as f:
  pickle.dump(beam_out, f)

In [None]:
def plot_beam_score_hist(beam_results: List[Dict[str, Any]], outpath: Optional[str] = None):
  """Histogram of the ``beam_score`` (cumulative log-prob) of each result.

  None and NaN scores are dropped; if nothing remains, a placeholder message
  is drawn. When ``outpath`` is given the figure is saved there (the parent
  directory is created when the path has one).

  Returns:
    The matplotlib figure.
  """
  scores = [r.get("beam_score", np.nan) for r in beam_results]
  scores = np.array([s for s in scores if (s is not None) and (not np.isnan(s))], dtype=float)
  fig, ax = plt.subplots(figsize=(7,4))
  if scores.size == 0:
    ax.text(0.5,0.5,"No beam_score data", ha="center", va="center")
  else:
    ax.hist(scores, bins=30, edgecolor="black", alpha=0.75)
    ax.set_xlabel("Beam cumulative log-prob (score)")
    ax.set_ylabel("Count")
    ax.set_title("Beam: Distribution of beam scores (cumulative log-prob)")
    ax.grid(axis="y", linestyle="--", linewidth=0.4, alpha=0.7)
  plt.tight_layout()
  if outpath:
    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(outpath)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)
  plt.show()
  return fig

def plot_beam_from_eval_reuse(beam_out: Dict[str, Any], outdir: Optional[str] = None) -> Dict[str, Optional[str]]:
  """Draw the four beam-search plots from a ``run_beam_eval`` style result dict.

  Returns a dict with keys "distinct", "avg_logprob_box", "length_hist" and
  "beam_score_hist"; each value is the target path under ``outdir``, or None
  when ``outdir`` is not given.
  """
  def target(filename):
    # Path under outdir, or None when no output directory was requested.
    return os.path.join(outdir, filename) if outdir else None

  results = beam_out["results"]
  summary = beam_out.get("summary", {})

  distinct_keys = (
    "distinct_1_across",
    "distinct_2_across",
    "distinct_3_across",
    "distinct_1_within",
    "distinct_2_within",
    "distinct_3_within",
  )
  # Missing metrics default to 0.0 before being handed to the shared plot.
  distinct_values = {key: summary.get(key, 0.0) for key in distinct_keys}

  distinct_path = target("beam_distinct_across_within.png")
  plot_distinct_across_within(distinct_values, outpath=distinct_path, method_name="Beam")

  box_path = target("beam_avg_logprob_box.png")
  plot_avg_logprob_box(results, outpath=box_path)

  length_path = target("beam_length_hist.png")
  plot_length_histogram(results, outpath=length_path)

  score_path = target("beam_score_hist.png")
  plot_beam_score_hist(results, outpath=score_path)

  return {
    "distinct": distinct_path,
    "avg_logprob_box": box_path,
    "length_hist": length_path,
    "beam_score_hist": score_path,
  }

# Reload the beam-search results written by the evaluation cell above.
# The path is relative, matching where that cell saves the file.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
pkl_path = "beam_out.pkl"
with open(pkl_path, "rb") as f:
  beam_res = pickle.load(f)

beam_saved = plot_beam_from_eval_reuse(beam_res, outdir="figs/beam")
# Report the saved figure paths (not the whole results dict).
print("Beam plots saved:", beam_saved)

**Top-K Sampling**

In [None]:

@torch.no_grad()
def top_k_sample_single(model, tokenizer, prompt: str, max_new_tokens: int = 128, k: int = 50, temperature: float = 1.0, eos_token_id: Optional[int] = None, device: Optional[torch.device] = None,
                        seed: Optional[int] = None, return_full_token_ids: bool = False, repeat_ngram_block: int = 3, repeat_threshold: int = 4, recent_decode_window: int = 64):
    """Top-k sampling for a single prompt using the model's KV cache.

    At each step the temperature-scaled distribution is truncated to its ``k``
    most probable tokens, renormalised, and sampled. Generation stops at
    ``max_new_tokens``, at EOS, when the trailing ``repeat_ngram_block``-gram
    has occurred ``repeat_threshold`` times, or when a "### Instruction" /
    "### Response" marker appears in the decoded output (the text is then cut
    at the marker and re-tokenized).

    Args:
      model: causal LM called with ``input_ids`` / ``past_key_values``.
      tokenizer: tokenizer providing ``__call__``, ``decode`` and ``eos_token_id``.
      prompt: prompt text (passed through ``sanitize_prompt_for_generation``).
      max_new_tokens: maximum number of sampled tokens.
      k: truncation size, clamped to ``[1, vocab_size]``.
      temperature: logit divisor; must be positive.
      eos_token_id: stop token; defaults to ``tokenizer.eos_token_id``.
      device: device for the inputs; defaults to the model's parameter device.
      seed: when given, sampling uses a dedicated ``torch.Generator`` with this seed.
      return_full_token_ids: also return prompt ids + generated ids.
      repeat_ngram_block: n-gram size for the repetition guard.
      repeat_threshold: occurrence count that stops generation.
      recent_decode_window: number of trailing tokens decoded for the marker check.

    Returns:
      Dict with "generated_text" (post-processed), "generated_token_ids",
      "token_logprobs", "avg_logprob" and "num_generated_tokens" (whitespace
      word count of the text), plus "full_token_ids" when requested. The
      log-probs are taken under the renormalised top-k distribution.

    Raises:
      ValueError: if ``temperature`` is not positive.
    """
    if device is None:
      device = next(model.parameters()).device
    if eos_token_id is None:
      eos_token_id = tokenizer.eos_token_id
    if temperature <= 0:
      raise ValueError("temperature must be > 0")

    gen = None
    if seed is not None:
      gen = torch.Generator(device=device)
      gen.manual_seed(int(seed))

    prompt_for_model = sanitize_prompt_for_generation(prompt)

    enc = tokenizer(prompt_for_model, return_tensors="pt", add_special_tokens=True).to(device)
    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask", None)

    # Match the greedy and beam decoders: eval mode and an explicit attention mask.
    model.eval()
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, return_dict=True)
    past_key_values = outputs.past_key_values
    next_logits = outputs.logits[:, -1, :]

    generated_token_ids: List[int] = []
    token_logprobs: List[float] = []
    token_buffer: List[int] = []
    ngram_counts = Counter()

    vocab_size = next_logits.size(-1)
    k = max(1, min(int(k), vocab_size))

    last_token_id: Optional[int] = None
    for step in range(max_new_tokens):
      if last_token_id is not None:
        # Advance the cache only when another token is actually needed, so no
        # forward pass is spent after the token that ends generation.
        next_input = torch.tensor([[last_token_id]], dtype=torch.long, device=device)
        out = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True, return_dict=True)
        past_key_values = out.past_key_values
        next_logits = out.logits[:, -1, :]

      probs = torch.softmax(next_logits / float(temperature), dim=-1)

      # Truncate to the k most probable tokens and renormalise.
      topk_probs, topk_indices = torch.topk(probs, k=k, dim=-1)
      topk_probs = topk_probs / (topk_probs.sum(dim=-1, keepdim=True) + 1e-20)

      sample_idx = torch.multinomial(topk_probs.squeeze(0), num_samples=1, generator=gen)
      chosen_idx_in_topk = int(sample_idx[0].item())
      chosen_token_id = int(topk_indices[0, chosen_idx_in_topk].item())
      prob_chosen = float(topk_probs[0, chosen_idx_in_topk].item())
      logprob_chosen = math.log(max(prob_chosen, 1e-50))

      generated_token_ids.append(chosen_token_id)
      token_logprobs.append(logprob_chosen)
      token_buffer.append(chosen_token_id)
      last_token_id = chosen_token_id

      # Repetition guard on the trailing n-gram.
      if len(token_buffer) >= repeat_ngram_block:
        ng = tuple(token_buffer[-repeat_ngram_block:])
        ngram_counts[ng] += 1
        if ngram_counts[ng] >= repeat_threshold:
          break

      # Best-effort decode of the recent window to spot template markers.
      try:
        decoded_recent = tokenizer.decode(generated_token_ids[-recent_decode_window:], skip_special_tokens=False)
      except Exception:
        decoded_recent = ""
      if "### Instruction" in decoded_recent or "### Response" in decoded_recent:
        full_decoded = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
        marker_positions = [i for i in (full_decoded.find("### Instruction"), full_decoded.find("### Response")) if i != -1]
        first_marker_idx = min(marker_positions + [len(full_decoded)])
        cleaned = full_decoded[:first_marker_idx].strip()
        cleaned_ids = tokenizer(cleaned, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0).tolist() if cleaned else []
        generated_token_ids = cleaned_ids
        break

      if eos_token_id is not None and chosen_token_id == eos_token_id:
        break

    generated_text_raw = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
    generated_text = postprocess_generated_text(generated_text_raw)

    avg_logprob = float(np.mean(token_logprobs)) if token_logprobs else float("nan")
    num_generated_tokens = len(generated_text.split()) if generated_text.strip() else 0

    result = {
      "generated_text": generated_text,
      "generated_token_ids": generated_token_ids,
      "token_logprobs": token_logprobs,
      "avg_logprob": avg_logprob,
      "num_generated_tokens": num_generated_tokens,
    }

    if return_full_token_ids:
      full_ids = input_ids.squeeze(0).tolist() + generated_token_ids
      result["full_token_ids"] = full_ids

    return result

In [None]:
def run_topk_eval(cfg, k: int = 50, temperatures: Optional[List[float]] = None, max_new_tokens: Optional[int] = None, n_within_repeat: int = 10, prompt_for_within: int = 0,
                  deterministic_seed_per_draw: Optional[int] = None, repeat_ngram_block: int = 3, repeat_threshold: int = 4):
  """Evaluate top-k sampling over several temperatures.

  For every temperature, each evaluation prompt is sampled once (across-prompt
  statistics) and one prompt is sampled ``n_within_repeat`` times
  (within-prompt diversity).

  Args:
    cfg: config object providing model/dataset/generation attributes.
    k: top-k truncation size.
    temperatures: temperatures to sweep; defaults to [0.2, 0.5, 0.8, 1.0, 1.2].
    max_new_tokens: generation cap; defaults to ``cfg.max_new_tokens``.
    n_within_repeat: number of repeated samples for the within-prompt metric.
    prompt_for_within: index of the prompt used for the within-prompt metric
      (clamped to the valid range).
    deterministic_seed_per_draw: when given, every sampling call receives its
      own seed drawn from a ``random.Random`` seeded with this value; when
      None, calls are made with ``seed=None``.
    repeat_ngram_block: forwarded to ``top_k_sample_single``.
    repeat_threshold: forwarded to ``top_k_sample_single``.

  Returns:
    Dict with "config", "k", "temperatures", "results_by_temp" and
    "summary_by_temp" (both keyed by temperature).
  """
  if temperatures is None:
    temperatures = [0.2, 0.5, 0.8, 1.0, 1.2]
  if max_new_tokens is None:
    max_new_tokens = cfg.max_new_tokens

  set_seed(cfg.seed)
  model, tokenizer, device = prepare_model_and_tokenizer(cfg.model_name, cfg.tokenizer_name or cfg.model_name, cfg.device)
  prompts = prepare_eval_prompts(cfg.dataset_name, cfg.prompt_template, seed=cfg.seed, test_size=cfg.test_size)

  results_by_temp = {}
  summary_by_temp = {}

  base_rng = random.Random(deterministic_seed_per_draw if deterministic_seed_per_draw is not None else cfg.seed)

  def next_call_seed() -> Optional[int]:
    # One fresh seed per sampling call, only in deterministic mode.
    if deterministic_seed_per_draw is None:
      return None
    return base_rng.randint(0, 2**31 - 1)

  def sample_once(prompt_text: str, temperature: float):
    # Single place that forwards the shared sampling settings.
    return top_k_sample_single(model=model, tokenizer=tokenizer, prompt=prompt_text, max_new_tokens=max_new_tokens, k=k, temperature=temperature,
                               eos_token_id=tokenizer.eos_token_id, device=device, seed=next_call_seed(), return_full_token_ids=False,
                               repeat_ngram_block=repeat_ngram_block, repeat_threshold=repeat_threshold)

  for T in temperatures:
    per_prompt_results = []
    avg_logprobs = []
    gen_texts = []

    for i, p in enumerate(prompts):
      out = sample_once(p["prompt"], T)

      cleaned_text = out["generated_text"]
      out_num_tokens = len(cleaned_text.split()) if cleaned_text.strip() else 0

      per_prompt_results.append({
        "id": p["id"],
        "prompt": p["prompt"],
        "generated_text": cleaned_text,
        "avg_logprob": out["avg_logprob"],
        "num_generated_tokens": out_num_tokens,
      })

      if not math.isnan(out["avg_logprob"]):
        avg_logprobs.append(out["avg_logprob"])
      gen_texts.append(cleaned_text)

      # Progress line only; the leftover debug print of the temperature list was removed.
      if (i + 1) % cfg.print_every == 0:
        print(f"[Top-K T={T}] Decoded {i+1}/{len(prompts)} prompts...")

    distinct1 = distinct_n(gen_texts, 1, tokenizer)
    distinct2 = distinct_n(gen_texts, 2, tokenizer)
    distinct3 = distinct_n(gen_texts, 3, tokenizer)
    avg_logprob_all = float(np.nanmean(avg_logprobs)) if avg_logprobs else float("nan")

    # Within-prompt diversity: sample the same prompt repeatedly.
    within_prompt_gen = []
    if len(prompts) > 0 and n_within_repeat > 0:
      within_idx = max(0, min(prompt_for_within, len(prompts) - 1))
      single_prompt = prompts[within_idx]["prompt"]
      for rep in range(n_within_repeat):
        within_prompt_gen.append(sample_once(single_prompt, T)["generated_text"])
    within_dist1 = distinct_n(within_prompt_gen, 1, tokenizer)
    within_dist2 = distinct_n(within_prompt_gen, 2, tokenizer)
    within_dist3 = distinct_n(within_prompt_gen, 3, tokenizer)

    results_by_temp[T] = per_prompt_results
    summary_by_temp[T] = {
      "num_prompts": len(prompts),
      "k": k,
      "temperature": T,
      "avg_logprob_all": avg_logprob_all,
      "distinct_1_across": distinct1,
      "distinct_2_across": distinct2,
      "distinct_3_across": distinct3,
      "distinct_1_within": within_dist1,
      "distinct_2_within": within_dist2,
      "distinct_3_within": within_dist3,
    }

    n_total = len(per_prompt_results)
    n_empty = sum(1 for r in per_prompt_results if not r["generated_text"].strip())
    empty_rate = n_empty / n_total if n_total>0 else 0.0
    print(f"[Top-K T={T}] empty_rate after postprocessing: {empty_rate:.2%}")
    print(f"[Top-K summary] T={T:.2f}, k={k} -> avg_logprob={avg_logprob_all:.4f}, distinct1_across={distinct1:.4f}, within1={within_dist1:.4f}")

  return {
    "config": cfg,
    "k": k,
    "temperatures": temperatures,
    "results_by_temp": results_by_temp,
    "summary_by_temp": summary_by_temp,
  }

In [None]:
import pickle

# Run the top-k temperature sweep with run_topk_eval's default settings.
topk_out = run_topk_eval(GreedyConfig)

# Persist the returned dict in the working directory.
with open("topk_out.pkl", "wb") as f:
  pickle.dump(topk_out, f)

In [None]:

def plot_topk_temperature_curves(topk_summary_by_temp: Dict[float, Dict[str, Any]], outdir: Optional[str] = None):
  """Draw two Top-K line charts against temperature: Distinct-1/2/3 and average token log-prob.

  Args:
    topk_summary_by_temp: mapping temperature -> summary dict. The keys
      "distinct_1_across", "distinct_2_across", "distinct_3_across" (default 0.0)
      and "avg_logprob_all" (default NaN) are read from every entry.
    outdir: when truthy, each figure is also saved there as a PNG
      ("topk_distinct_vs_temp.png", "topk_avglogprob_vs_temp.png"); the directory
      is created if it does not exist.

  Returns:
    A dict with the keys "distinct_vs_temp" and "avglogprob_vs_temp", both always None.

  Raises:
    ValueError: if topk_summary_by_temp is empty.
  """
  if not topk_summary_by_temp:
    raise ValueError("empty topk_summary_by_temp")

  temps = sorted(topk_summary_by_temp.keys())

  def _column(key, default):
    # One value per temperature, in ascending temperature order.
    return [topk_summary_by_temp[T].get(key, default) for T in temps]

  distinct_curves = [
    ('Distinct-1', _column("distinct_1_across", 0.0)),
    ('Distinct-2', _column("distinct_2_across", 0.0)),
    ('Distinct-3', _column("distinct_3_across", 0.0)),
  ]
  avg_lp = _column("avg_logprob_all", np.nan)

  def _save_and_show(filename):
    # Save the current figure only when an output directory was requested, then display it.
    if outdir:
      target = os.path.join(outdir, filename)
      os.makedirs(os.path.dirname(target), exist_ok=True)
      plt.savefig(target, dpi=200)
    plt.show()

  # Figure 1: the three Distinct-N curves share one axis.
  fig, ax = plt.subplots(figsize=(7,4))
  for curve_label, values in distinct_curves:
    ax.plot(temps, values, marker='o', label=curve_label)
  ax.set_xlabel("Temperature")
  ax.set_ylabel("Distinct-N")
  ax.set_title("Top-K: Distinct-N vs Temperature")
  ax.legend()
  ax.grid(linestyle='--', linewidth=0.4)
  _save_and_show("topk_distinct_vs_temp.png")

  # Figure 2: average token log-prob.
  fig, ax = plt.subplots(figsize=(7,4))
  ax.plot(temps, avg_lp, marker='o')
  ax.set_xlabel("Temperature")
  ax.set_ylabel("Avg token log-prob")
  ax.set_title("Top-K: Avg token log-prob vs Temperature")
  ax.grid(linestyle='--', linewidth=0.4)
  _save_and_show("topk_avglogprob_vs_temp.png")

  return {"distinct_vs_temp": None, "avglogprob_vs_temp": None}


def plot_topk_within_across_bar(topk_summary_by_temp: Dict[float, Dict[str, Any]], temperature: float, outpath: Optional[str] = None):
  """Plot across-prompt vs within-prompt Distinct-N for a single Top-K temperature.

  Args:
    topk_summary_by_temp: mapping temperature -> summary dict holding the
      "distinct_{1,2,3}_across" and "distinct_{1,2,3}_within" values
      (a missing key is plotted as 0.0).
    temperature: temperature to plot. When it is not a key of the mapping, the
      closest available temperature is used; on a tie the lower one wins.
    outpath: forwarded unchanged to plot_distinct_across_within.

  Returns:
    None.

  Raises:
    ValueError: if topk_summary_by_temp is empty.
  """
  if not topk_summary_by_temp:
    raise ValueError("empty topk_summary_by_temp")

  if temperature not in topk_summary_by_temp:
    # Snap to the nearest evaluated temperature; sorting first makes ties deterministic.
    temps = sorted(topk_summary_by_temp.keys())
    temperature = min(temps, key=lambda x: abs(x - temperature))

  s = topk_summary_by_temp[temperature]
  metric_keys = (
    "distinct_1_across", "distinct_2_across", "distinct_3_across",
    "distinct_1_within", "distinct_2_within", "distinct_3_within",
  )
  # Pass on only the six Distinct-N values, defaulting absent ones to 0.0.
  summary_like = {key: s.get(key, 0.0) for key in metric_keys}
  plot_distinct_across_within(summary_like, outpath=outpath, method_name = "Top-K")
  return None


def plot_topk_scatter_tradeoff(topk_summary_by_temp: Dict[float, Dict[str, Any]], outpath: Optional[str] = None):
  """Scatter Distinct-1 (across prompts) against average token log-prob, one point per temperature.

  Args:
    topk_summary_by_temp: mapping temperature -> summary dict; "distinct_1_across"
      (default 0.0) and "avg_logprob_all" (default NaN) are read from each entry.
    outpath: when truthy, the figure is saved to this path. Missing parent
      directories are created; a bare filename is saved into the current directory.

  Returns:
    The figure object returned by plt.subplots.
  """
  temps = sorted(topk_summary_by_temp.keys())
  xs = [topk_summary_by_temp[T].get("distinct_1_across", 0.0) for T in temps]
  ys = [topk_summary_by_temp[T].get("avg_logprob_all", np.nan) for T in temps]

  fig, ax = plt.subplots(figsize=(6,5))
  ax.scatter(xs, ys)
  # Label every point with the temperature it was measured at.
  for T, x, y in zip(temps, xs, ys):
    ax.annotate(f"T={T}", (x, y))
  ax.set_xlabel("Distinct-1 (Across)")
  ax.set_ylabel("Avg token log-prob")
  ax.set_title("Top-K: Diversity (Distinct-1) vs Quality (avg logprob) across Temperatures")
  ax.grid(linestyle='--', linewidth=0.4)

  if outpath:
    parent_dir = os.path.dirname(outpath)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when the path names one.
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)
  plt.show()
  return fig

# Render the three Top-K figures from the evaluation summary; every call is given an
# output location under figs/topk.
plot_topk_temperature_curves(topk_out["summary_by_temp"], outdir="figs/topk")
# Within-vs-across comparison for T=0.8 (the plot function falls back to the nearest
# available temperature when 0.8 is not a key of the summary).
plot_topk_within_across_bar(topk_out["summary_by_temp"], temperature=0.8, outpath="figs/topk/topk_within_across_T0.8.png")
plot_topk_scatter_tradeoff(topk_out["summary_by_temp"], outpath="figs/topk/topk_tradeoff.png")

**Top-P Sampling**

In [None]:
@torch.no_grad()
def top_p_sample_single(model, tokenizer, prompt: str, max_new_tokens: int = 128, p: float = 0.9, temperature: float = 1.0, eos_token_id: Optional[int] = None,
    device: Optional[torch.device] = None, seed: Optional[int] = None, return_full_token_ids: bool = False, repeat_ngram_block: int = 3, repeat_threshold: int = 4, recent_decode_window: int = 64):
  """Sample one continuation of `prompt` with nucleus (top-p) sampling, one token at a time.

  Each step divides the next-token logits by `temperature`, keeps the most probable
  tokens up to and including the first one whose cumulative probability exceeds `p`,
  renormalises that nucleus and draws a single token from it.

  Generation stops at the first of:
    * `max_new_tokens` tokens have been produced;
    * the same n-gram of the last `repeat_ngram_block` tokens has been seen `repeat_threshold` times;
    * the decoded output contains "### Instruction" or "### Response"; the returned
      token ids are then the re-tokenised text that precedes the first marker;
    * the sampled token equals `eos_token_id`.

  Args:
    model: causal LM called as model(input_ids=..., attention_mask=..., past_key_values=...,
      use_cache=True, return_dict=True) and returning `.logits` / `.past_key_values`.
    tokenizer: tokenizer used to encode the prompt and decode the generated ids.
    prompt: prompt text; it is passed through sanitize_prompt_for_generation first.
    max_new_tokens: upper bound on the number of sampled tokens.
    p: nucleus mass, must be in (0, 1]; p >= 1 samples from the full vocabulary.
    temperature: softmax temperature, must be > 0.
    eos_token_id: stop token; defaults to tokenizer.eos_token_id.
    device: device for the inputs; defaults to the device of the model's first parameter.
    seed: when given, sampling uses a dedicated torch.Generator seeded with it.
    return_full_token_ids: also return prompt ids + generated ids as "full_token_ids".
    repeat_ngram_block: n-gram length used by the repetition stop.
    repeat_threshold: number of occurrences of one n-gram that triggers the stop.
    recent_decode_window: number of trailing tokens decoded per step for the marker check.

  Returns:
    dict with "generated_text" (decoded with special tokens skipped, then run through
    postprocess_generated_text), "generated_token_ids", "token_logprobs" (log of the
    renormalised nucleus probability of every sampled token), "avg_logprob" (their mean,
    NaN when nothing was sampled), "num_generated_tokens" and, if requested, "full_token_ids".

  Raises:
    ValueError: if `p` is outside (0, 1] or `temperature` is not positive.
  """
  if device is None:
    device = next(model.parameters()).device
  if eos_token_id is None:
    eos_token_id = tokenizer.eos_token_id
  if not (0.0 < p <= 1.0):
    raise ValueError("p must be in (0, 1].")
  if temperature <= 0.0:
    raise ValueError("temperature must be > 0")

  gen = None
  if seed is not None:
    try:
      gen = torch.Generator(device=device)
    except Exception:
      # Best effort: if a generator cannot be created on `device`, fall back to a CPU one.
      gen = torch.Generator()
    gen.manual_seed(int(seed))

  prompt_for_model = sanitize_prompt_for_generation(prompt)

  enc = tokenizer(prompt_for_model, return_tensors="pt", add_special_tokens=True).to(device)
  input_ids = enc["input_ids"]
  attention_mask = enc.get("attention_mask", None)

  # Prime the KV cache with the whole prompt; later steps feed one token at a time.
  out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, return_dict=True)
  past_key_values = out.past_key_values
  next_logits = out.logits[:, -1, :]

  generated_token_ids: List[int] = []
  token_logprobs: List[float] = []
  token_buffer: List[int] = []
  ngram_counts = Counter()

  vocab_size = next_logits.size(-1)
  markers = ("### Instruction", "### Response")

  for step in range(max_new_tokens):
    scaled_logits = next_logits / float(temperature)  # (1, vocab)
    probs = torch.softmax(scaled_logits, dim=-1).squeeze(0)  # (vocab,)

    if p >= 1.0:
      nucleus_indices = torch.arange(vocab_size, device=probs.device)
      nucleus_probs = probs
    else:
      sorted_probs, sorted_indices = torch.sort(probs, descending=True)
      cumulative = torch.cumsum(sorted_probs, dim=0)
      if torch.any(cumulative > p):
        # Index of the first token that pushes the cumulative mass past p (it is kept).
        cutoff_idx = int(torch.nonzero(cumulative > p, as_tuple=False)[0].item())
      else:
        cutoff_idx = len(sorted_probs) - 1
      keep_count = cutoff_idx + 1
      nucleus_indices = sorted_indices[:keep_count]
      nucleus_probs = sorted_probs[:keep_count]

    ssum = float(nucleus_probs.sum().item())
    if ssum <= 0:
      # Degenerate distribution: sample uniformly from the nucleus instead.
      nucleus_probs = torch.ones_like(nucleus_probs, device=probs.device)
      ssum = float(nucleus_probs.sum().item())
    nucleus_probs = nucleus_probs / (ssum + 1e-20)

    if gen is not None:
      # torch.multinomial needs the generator and the tensor on the same device; moving the
      # probabilities is a no-op unless the CPU fallback generator is in use.
      sampled = torch.multinomial(nucleus_probs.to(gen.device), num_samples=1, generator=gen).item()
    else:
      sampled = torch.multinomial(nucleus_probs, num_samples=1).item()

    chosen_token_id = int(nucleus_indices[sampled].item())
    prob_chosen = float(nucleus_probs[sampled].item())
    logprob_chosen = math.log(max(prob_chosen, 1e-50))

    generated_token_ids.append(chosen_token_id)
    token_logprobs.append(logprob_chosen)
    token_buffer.append(chosen_token_id)

    # Stop 1: repetition guard on the trailing n-gram.
    if len(token_buffer) >= repeat_ngram_block:
      ng = tuple(token_buffer[-repeat_ngram_block:])
      ngram_counts[ng] += 1
      if ngram_counts[ng] >= repeat_threshold:
        break

    # Stop 2: the model started to emit a new prompt section. Only the recent window is
    # decoded every step; the full decode happens once, when a marker shows up.
    try:
      recent_ids = generated_token_ids[-recent_decode_window:]
      decoded_recent = tokenizer.decode(recent_ids, skip_special_tokens=False)
    except Exception:
      decoded_recent = ""
    if any(marker in decoded_recent for marker in markers):
      full_dec = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
      hits = [idx for idx in (full_dec.find(marker) for marker in markers) if idx != -1]
      first_marker_idx = min(hits + [len(full_dec)])
      cleaned = full_dec[:first_marker_idx].strip()
      cleaned_ids = tokenizer(cleaned, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0).tolist() if cleaned else []
      # NOTE(review): token_logprobs is left untouched here, so avg_logprob still includes
      # the tokens that were cut off and its length can differ from generated_token_ids.
      generated_token_ids = cleaned_ids
      break

    # Stop 3: end-of-sequence token.
    if eos_token_id is not None and chosen_token_id == eos_token_id:
      break

    # Stop 4: token budget used up. None of the stop checks needs the next logits, so the
    # forward pass below is only run when another token will actually be sampled.
    if step == max_new_tokens - 1:
      break

    next_input = torch.tensor([[chosen_token_id]], dtype=torch.long, device=device)
    out2 = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True, return_dict=True)
    past_key_values = out2.past_key_values
    next_logits = out2.logits[:, -1, :]

  generated_text_raw = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
  generated_text = postprocess_generated_text(generated_text_raw)
  avg_logprob = float(np.mean(token_logprobs)) if token_logprobs else float("nan")
  num_gen = len(generated_token_ids)

  result = {
    "generated_text": generated_text,
    "generated_token_ids": generated_token_ids,
    "token_logprobs": token_logprobs,
    "avg_logprob": avg_logprob,
    "num_generated_tokens": num_gen,
  }
  if return_full_token_ids:
    result["full_token_ids"] = input_ids.squeeze(0).tolist() + generated_token_ids

  return result

In [None]:
def run_topp_eval( cfg, p: float = 0.9, temperatures: Optional[List[float]] = None, max_new_tokens: Optional[int] = None,
    n_within_repeat: int = 10, prompt_for_within: int = 0, deterministic_seed_per_draw: Optional[int] = None):
  """Evaluate Top-P sampling over a list of temperatures.

  For every temperature, each evaluation prompt is sampled once (across-prompt
  statistics) and one chosen prompt is sampled `n_within_repeat` times
  (within-prompt statistics). Distinct-1/2/3 and the mean of the per-sample
  average log-probs are recorded per temperature.

  Args:
    cfg: config object; model_name, tokenizer_name, device, dataset_name,
      prompt_template, seed, test_size, max_new_tokens and print_every are read.
    p: nucleus mass forwarded to top_p_sample_single.
    temperatures: temperatures to evaluate; defaults to [0.2, 0.5, 0.8, 1.0, 1.2].
    max_new_tokens: generation budget; defaults to cfg.max_new_tokens.
    n_within_repeat: number of repeated samples for the within-prompt statistics.
    prompt_for_within: index of the prompt used for them (clamped to the valid range).
    deterministic_seed_per_draw: when not None, seeds a random.Random from which an
      individual seed is drawn for every sampling call; otherwise calls are unseeded.

  Returns:
    dict with "config", "p", "temperatures", "results_by_temp" (temperature -> list of
    per-prompt records) and "summary_by_temp" (temperature -> aggregate metrics).
  """
  temperatures = [0.2, 0.5, 0.8, 1.0, 1.2] if temperatures is None else temperatures
  max_new_tokens = cfg.max_new_tokens if max_new_tokens is None else max_new_tokens

  set_seed(cfg.seed)

  model, tokenizer, device = prepare_model_and_tokenizer(cfg.model_name, cfg.tokenizer_name or cfg.model_name, cfg.device)

  prompts = prepare_eval_prompts(cfg.dataset_name, cfg.prompt_template, seed=cfg.seed, test_size=cfg.test_size)
  num_prompts = len(prompts)

  use_call_seeds = deterministic_seed_per_draw is not None
  base_rng = random.Random(deterministic_seed_per_draw if use_call_seeds else cfg.seed)

  def _sample(prompt_text, T):
    # One seed is consumed from base_rng per call so the draw order fixes every sample.
    call_seed = base_rng.randint(0, 2**31 - 1) if use_call_seeds else None
    return top_p_sample_single(model=model, tokenizer=tokenizer, prompt=prompt_text, max_new_tokens=max_new_tokens, p=p,
        temperature=T, eos_token_id=tokenizer.eos_token_id, device=device, seed=call_seed, return_full_token_ids=False)

  results_by_temp = {}
  summary_by_temp = {}

  for T in temperatures:
    per_prompt_results = []
    avg_logprobs = []
    gen_texts = []

    # Across-prompt pass: one sample per evaluation prompt.
    for done, p_item in enumerate(prompts, start=1):
      out = _sample(p_item["prompt"], T)

      per_prompt_results.append({
        "id": p_item["id"],
        "prompt": p_item["prompt"],
        "generated_text": out["generated_text"],
        "avg_logprob": out["avg_logprob"],
        "num_generated_tokens": out["num_generated_tokens"],
      })

      if not math.isnan(out["avg_logprob"]):
        avg_logprobs.append(out["avg_logprob"])
      gen_texts.append(out["generated_text"])

      if done % cfg.print_every == 0:
        print(f"[Top-P T={T}] Decoded {done}/{len(prompts)} prompts...")

    distinct1, distinct2, distinct3 = [distinct_n(gen_texts, n, tokenizer) for n in (1, 2, 3)]
    avg_logprob_all = float(np.nanmean(avg_logprobs)) if avg_logprobs else float("nan")

    # Within-prompt pass: repeated samples of a single prompt.
    within_prompt_gen = []
    if num_prompts > 0:
      within_idx = max(0, min(prompt_for_within, num_prompts - 1))
      single_prompt = prompts[within_idx]["prompt"]
      for _ in range(n_within_repeat):
        within_prompt_gen.append(_sample(single_prompt, T)["generated_text"])
    within_dist1, within_dist2, within_dist3 = [distinct_n(within_prompt_gen, n, tokenizer) for n in (1, 2, 3)]

    results_by_temp[T] = per_prompt_results
    summary_by_temp[T] = {
      "num_prompts": num_prompts,
      "p": p,
      "temperature": T,
      "avg_logprob_all": avg_logprob_all,
      "distinct_1_across": distinct1,
      "distinct_2_across": distinct2,
      "distinct_3_across": distinct3,
      "distinct_1_within": within_dist1,
      "distinct_2_within": within_dist2,
      "distinct_3_within": within_dist3,
    }

    print(f"[Top-P summary] T={T:.2f}, p={p:.3f} -> avg_logprob={avg_logprob_all:.4f}, distinct-1 across={distinct1:.4f}, within={within_dist1:.4f}")

  return {
    "config": cfg,
    "p": p,
    "temperatures": temperatures,
    "results_by_temp": results_by_temp,
    "summary_by_temp": summary_by_temp,
  }

In [None]:
# Run the Top-P evaluation at p=0.9 over five temperatures. Passing
# deterministic_seed_per_draw makes run_topp_eval draw a separate seed for every sampling call.
p_val = 0.9
temperatures = [0.2, 0.5, 0.8, 1.0, 1.2]
topp_out = run_topp_eval(GreedyConfig, p=p_val, temperatures=temperatures, max_new_tokens=GreedyConfig.max_new_tokens,
                         n_within_repeat=10, prompt_for_within=0, deterministic_seed_per_draw=GreedyConfig.seed)
import pickle
# Write the result dict to "topp_out.pkl" in the current working directory.
with open("topp_out.pkl", "wb") as f:
  pickle.dump(topp_out, f)


In [None]:
def _extract_top_p_structs(topp_obj: Dict[str, Any]):
  """Normalise the supported Top-P result layouts to (summary_by_temp, results_by_temp, temps).

  Accepted inputs, checked in this order:
    1. a dict with "summary_by_temp" and "results_by_temp";
    2. a dict with "summary_by_method" and "results_by_method", from which the
       "top_p" entries are taken (empty dicts when absent);
    3. a bare summary mapping whose first key is an int or float; the results
       part is then an empty dict.

  Returns:
    A 3-tuple (summary mapping, results mapping, sorted list of temperatures).

  Raises:
    ValueError: for any other input, including an empty dict.
  """
  if isinstance(topp_obj, dict):
    if "summary_by_temp" in topp_obj and "results_by_temp" in topp_obj:
      summaries = topp_obj["summary_by_temp"]
      return summaries, topp_obj["results_by_temp"], sorted(summaries.keys())

    if "summary_by_method" in topp_obj and "results_by_method" in topp_obj:
      summaries = topp_obj["summary_by_method"].get("top_p", {})
      results = topp_obj["results_by_method"].get("top_p", {})
      return summaries, results, sorted(summaries.keys())

    # Bare summary mapping: only the first key is inspected to recognise it.
    all_keys = list(topp_obj.keys())
    if all_keys and isinstance(all_keys[0], (int, float)):
      return topp_obj, {}, sorted(all_keys)

  raise ValueError("Unsupported topp object shape. Provide topp_out or full sampling result.")


def plot_topp_temperature_curves(topp_obj: Dict[str, Any], outdir: Optional[str] = None):
  """Draw two Top-P line charts against temperature: Distinct-1/2/3 and average token log-prob.

  Args:
    topp_obj: any object accepted by _extract_top_p_structs. From each per-temperature
      summary the keys "distinct_{1,2,3}_across" (default 0.0) and "avg_logprob_all"
      (default NaN) are read.
    outdir: when truthy, each figure is also saved there as a PNG
      ("topp_distinct_vs_temp.png", "topp_avglogprob_vs_temp.png"); the directory
      is created if it does not exist.

  Returns:
    A dict with the keys "distinct_vs_temp" and "avglogprob_vs_temp", both always None.

  Raises:
    ValueError: if no temperatures are found.
  """
  summary_by_temp, _, temps = _extract_top_p_structs(topp_obj)
  if not temps:
    raise ValueError("No temperatures found in Top-P object.")

  def _column(key, default):
    # One value per temperature, in the order returned by _extract_top_p_structs.
    return [summary_by_temp[T].get(key, default) for T in temps]

  distinct_curves = [
    ('Distinct-1', _column("distinct_1_across", 0.0)),
    ('Distinct-2', _column("distinct_2_across", 0.0)),
    ('Distinct-3', _column("distinct_3_across", 0.0)),
  ]
  avg_lp = _column("avg_logprob_all", np.nan)

  def _save_and_show(filename):
    # Save the current figure only when an output directory was requested, then display it.
    if outdir:
      target = os.path.join(outdir, filename)
      os.makedirs(os.path.dirname(target), exist_ok=True)
      plt.savefig(target, dpi=200)
    plt.show()

  # Figure 1: the three Distinct-N curves share one axis.
  fig, ax = plt.subplots(figsize=(7,4))
  for curve_label, values in distinct_curves:
    ax.plot(temps, values, marker='o', label=curve_label)
  ax.set_xlabel("Temperature")
  ax.set_ylabel("Distinct-N")
  ax.set_title("Top-P: Distinct-N vs Temperature")
  ax.legend()
  ax.grid(linestyle='--', linewidth=0.4)
  _save_and_show("topp_distinct_vs_temp.png")

  # Figure 2: average token log-prob.
  fig, ax = plt.subplots(figsize=(7,4))
  ax.plot(temps, avg_lp, marker='o')
  ax.set_xlabel("Temperature")
  ax.set_ylabel("Avg token log-prob")
  ax.set_title("Top-P: Avg token log-prob vs Temperature")
  ax.grid(linestyle='--', linewidth=0.4)
  _save_and_show("topp_avglogprob_vs_temp.png")

  return {"distinct_vs_temp": None, "avglogprob_vs_temp": None}


def plot_topp_within_across_bar(topp_obj: Dict[str, Any], temperature: float, outpath: Optional[str] = None):
  """Plot across-prompt vs within-prompt Distinct-N for a single Top-P temperature.

  Args:
    topp_obj: any object accepted by _extract_top_p_structs.
    temperature: temperature to plot; when it is not present in the summary, the
      closest available temperature is used instead.
    outpath: forwarded unchanged to plot_distinct_across_within.

  Returns:
    None.

  Raises:
    ValueError: if the object contains no temperatures.
  """
  summary_by_temp, _, temps = _extract_top_p_structs(topp_obj)
  if not temps:
    raise ValueError("No temperatures in Top-P data")

  if temperature in summary_by_temp:
    chosen_temp = temperature
  else:
    # Fall back to the evaluated temperature nearest to the requested one.
    chosen_temp = min(temps, key=lambda x: abs(x - temperature))

  s = summary_by_temp[chosen_temp]
  metric_keys = (
    "distinct_1_across", "distinct_2_across", "distinct_3_across",
    "distinct_1_within", "distinct_2_within", "distinct_3_within",
  )
  # Only the six Distinct-N values are passed on; absent ones default to 0.0.
  summary_like = {key: s.get(key, 0.0) for key in metric_keys}
  plot_distinct_across_within(summary_like, outpath=outpath, method_name = "Top-P")
  return None

def plot_topp_scatter_tradeoff(topp_obj: Dict[str, Any], outpath: Optional[str] = None):
  """Scatter Distinct-1 (across prompts) against average token log-prob, one point per temperature.

  Args:
    topp_obj: any object accepted by _extract_top_p_structs; "distinct_1_across"
      (default 0.0) and "avg_logprob_all" (default NaN) are read from each summary.
    outpath: when truthy, the figure is saved to this path. Missing parent
      directories are created; a bare filename is saved into the current directory.

  Returns:
    The figure object returned by plt.subplots.
  """
  summary_by_temp, _, temps = _extract_top_p_structs(topp_obj)
  temps = sorted(temps)
  xs = [summary_by_temp[T].get("distinct_1_across", 0.0) for T in temps]
  ys = [summary_by_temp[T].get("avg_logprob_all", np.nan) for T in temps]

  fig, ax = plt.subplots(figsize=(6,5))
  ax.scatter(xs, ys)
  # Label every point with the temperature it was measured at.
  for T, x, y in zip(temps, xs, ys):
    ax.annotate(f"T={T}", (x, y))
  ax.set_xlabel("Distinct-1 (Across)")
  ax.set_ylabel("Avg token log-prob")
  ax.set_title("Top-P: Diversity (Distinct-1) vs Quality (avg logprob) across Temperatures")
  ax.grid(linestyle='--', linewidth=0.4)

  if outpath:
    parent_dir = os.path.dirname(outpath)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when the path names one.
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    plt.savefig(outpath, dpi=200)
  plt.show()
  return fig

def plot_topp_from_eval(topp_obj: Dict[str, Any], outdir: Optional[str] = None) -> Dict[str, Optional[str]]:
  """Produce the full set of Top-P plots from one evaluation result.

  Draws the temperature curves, the within/across bar chart and (when per-prompt
  results exist for it) the length histogram for the median temperature, and the
  diversity/quality scatter plot.

  Args:
    topp_obj: any object accepted by _extract_top_p_structs.
    outdir: directory the plot functions are asked to save into; None disables saving.

  Returns:
    dict mapping "curves", "within_across", "tradeoff" and, if attempted,
    "length_hist" to the requested output path (None when outdir is not given or,
    for "length_hist", when the histogram could not be drawn). Empty when the
    object holds no temperatures.
  """
  _, results_by_temp, temps = _extract_top_p_structs(topp_obj)
  saved = {}
  if not temps:
    print("No temperature data to plot.")
    return saved

  def _target(filename):
    # Output path inside outdir, or None when nothing should be written.
    return os.path.join(outdir, filename) if outdir else None

  plot_topp_temperature_curves(topp_obj, outdir=outdir)
  saved["curves"] = _target("topp_distinct_vs_temp.png")

  # Upper median of the evaluated temperatures.
  med_temp = sorted(temps)[len(temps)//2]
  p_within = _target(f"topp_within_across_T{med_temp}.png")
  plot_topp_within_across_bar(topp_obj, temperature=med_temp, outpath=p_within)
  saved["within_across"] = p_within

  p_scatter = _target("topp_tradeoff.png")
  plot_topp_scatter_tradeoff(topp_obj, outpath=p_scatter)
  saved["tradeoff"] = p_scatter

  if results_by_temp and (med_temp in results_by_temp):
    p_len = _target(f"topp_length_hist_T{med_temp}.png")
    try:
      plot_length_histogram(results_by_temp[med_temp], outpath=p_len)
      saved["length_hist"] = p_len
    except Exception as exc:
      # The histogram stays optional, but the reason it was skipped is reported
      # instead of being silently discarded.
      print(f"Skipping Top-P length histogram for T={med_temp}: {exc!r}")
      saved["length_hist"] = None

  return saved

# Draw every Top-P plot for the evaluation above, saving under figs/top_p, and print
# the dict of output paths that plot_topp_from_eval returns.
saved = plot_topp_from_eval(topp_out, outdir="figs/top_p")
print("Saved top-p plots:", saved)

In [None]:
def batch_reward_scores(reward_model, reward_tokenizer, questions: List[str], answers: List[str],
    device: Optional[str] = None, batch_size: int = 32, truncation: bool = True, max_length: int = 1024,):
  """Score (question, answer) pairs with a sequence-classification reward model, in batches.

  Args:
    reward_model: model called with the tokenizer output as keyword arguments; its
      `.logits` are read. It is moved to `device` and put in eval mode.
    reward_tokenizer: tokenizer called on (questions, answers) text pairs with padding.
    questions: question texts.
    answers: answer texts, same length as `questions`.
    device: target device; defaults to "cuda" when available, else "cpu".
    batch_size: number of pairs per forward pass.
    truncation: forwarded to the tokenizer.
    max_length: forwarded to the tokenizer.

  Returns:
    One float per pair: the logit itself for 1-D logits, otherwise the first value
    of each row.

  Raises:
    AssertionError: if `questions` and `answers` differ in length.
  """
  assert len(questions) == len(answers)
  if not device:
    device = "cuda" if torch.cuda.is_available() else "cpu"
  reward_model.to(device)
  reward_model.eval()

  scores = []
  total = len(questions)
  with torch.no_grad():
    for start in range(0, total, batch_size):
      stop = min(total, start + batch_size)
      encoded = reward_tokenizer(questions[start:stop], answers[start:stop], return_tensors="pt", padding=True, truncation=truncation, max_length=max_length)
      encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
      logits = reward_model(**encoded).logits

      # Reduce the logits to one value per example.
      if logits.dim() == 1:
        per_example = logits
      elif logits.dim() == 2:
        per_example = logits[:,0]
      else:
        per_example = logits.view(logits.size(0), -1)[:,0]

      scores.extend(float(value) for value in per_example.detach().cpu().tolist())
  return scores


In [None]:
from transformers import AutoModelForSequenceClassification

# Default `reward_model_name` of run_sampling_temperature_grid; the name is passed to
# AutoTokenizer.from_pretrained and AutoModelForSequenceClassification.from_pretrained there.
DEFAULT_REWARD_MODEL = "OpenAssistant/reward-model-deberta-v3-large-v2"

def _reward_score_for_pair(reward_model, reward_tokenizer, question: str, answer: str, device: Optional[str] = None,
                           truncation: bool = True, max_length: int = 1024):
  """Score a single (question, answer) pair with the reward model.

  Args:
    reward_model: model called with the tokenizer output as keyword arguments; its
      `.logits` are read. The model is not moved here, so it must already live on `device`.
    reward_tokenizer: tokenizer called on the (question, answer) text pair.
    question: question text.
    answer: answer text.
    device: device the inputs are moved to; defaults to "cuda" when available, else "cpu".
    truncation: forwarded to the tokenizer.
    max_length: forwarded to the tokenizer.

  Returns:
    The first logit of the model output as a Python float.
  """
  device = device or ("cuda" if torch.cuda.is_available() else "cpu")
  inputs = reward_tokenizer(question, answer, return_tensors="pt", truncation=truncation, max_length=max_length)
  inputs = {k: v.to(device) for k, v in inputs.items()}
  with torch.no_grad():
    logits = reward_model(**inputs).logits
  # For 1-D, (1, 1) and (1, C) logits alike the score is the first element in flattened order.
  score = float(logits.reshape(-1)[0].cpu().item())
  return score

def run_sampling_temperature_grid(cfg, k: int = 40, p_val: float = 0.9, temperatures: Optional[List[float]] = None, max_new_tokens: Optional[int] = None, test_size: Optional[int] = None,
                                  reward_model_name: str = DEFAULT_REWARD_MODEL, reward_batch_size: int = 32, n_within_repeat: int = 10, prompt_for_within: int = 0,
                                  deterministic_seed_per_draw: Optional[int] = None, trust_remote_code_for_reward: bool = False,):
  """Evaluate Top-K and Top-P sampling over a grid of temperatures, including reward scores.

  For every temperature the Top-K pass runs first and the Top-P pass second. Each pass
  samples every evaluation prompt once, scores the outputs with the reward model
  (batched, falling back to per-sample scoring if the batched call raises), and then
  samples one chosen prompt `n_within_repeat` times for the within-prompt Distinct-N.

  Args:
    cfg: config object; model_name, tokenizer_name, device, dataset_name,
      prompt_template, seed, test_size, max_new_tokens and print_every are read.
    k: `k` forwarded to top_k_sample_single.
    p_val: `p` forwarded to top_p_sample_single.
    temperatures: temperatures to evaluate; defaults to [0.2, 0.5, 0.8, 1.0, 1.2].
    max_new_tokens: generation budget; defaults to cfg.max_new_tokens.
    test_size: number of evaluation prompts requested; defaults to cfg.test_size.
    reward_model_name: name passed to from_pretrained for the reward tokenizer and model.
    reward_batch_size: batch size for batch_reward_scores.
    n_within_repeat: number of repeated samples for the within-prompt statistics.
    prompt_for_within: index of the prompt used for them (clamped to the valid range).
    deterministic_seed_per_draw: when not None, seeds a random.Random from which an
      individual seed is drawn for every sampling call; otherwise calls are unseeded.
    trust_remote_code_for_reward: forwarded as trust_remote_code when loading the reward model.

  Returns:
    dict with "config", "k", "p", "temperatures", "results_by_method" and
    "summary_by_method"; the last two map "top_k"/"top_p" -> temperature ->
    per-prompt records / aggregate metrics.
  """
  if temperatures is None:
    temperatures = [0.2, 0.5, 0.8, 1.0, 1.2]
  if max_new_tokens is None:
    max_new_tokens = cfg.max_new_tokens
  if test_size is None:
    test_size = cfg.test_size

  set_seed(cfg.seed)
  model, tokenizer, device = prepare_model_and_tokenizer(cfg.model_name, cfg.tokenizer_name or cfg.model_name, cfg.device)

  reward_device = ("cuda" if torch.cuda.is_available() else "cpu")
  reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name, use_fast=True)
  reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name, trust_remote_code=trust_remote_code_for_reward).to(reward_device)
  reward_model.eval()

  prompts = prepare_eval_prompts(cfg.dataset_name, cfg.prompt_template, seed=cfg.seed, test_size=test_size)
  num_prompts = len(prompts)
  base_rng = random.Random(deterministic_seed_per_draw if deterministic_seed_per_draw is not None else cfg.seed)

  results_by_method = {"top_k": {}, "top_p": {}}
  summary_by_method = {"top_k": {}, "top_p": {}}

  # Everything that differs between the two methods. The evaluation order (Top-K, then
  # Top-P) is the list order, which also fixes the order in which seeds are drawn.
  method_specs = [
    {
      "key": "top_k", "tag": "Top-K", "sampler": top_k_sample_single, "sampler_kwargs": {"k": k},
      "param_name": "k", "param_value": k, "param_text": f"k={k}",
      "fallback_note": "FALLING BACK to per-sample scoring.",
    },
    {
      "key": "top_p", "tag": "Top-P", "sampler": top_p_sample_single, "sampler_kwargs": {"p": p_val},
      "param_name": "p", "param_value": p_val, "param_text": f"p={p_val:.3f}",
      "fallback_note": "FALLING BACK.",
    },
  ]

  def _make_question_text(p_item):
    # Question text given to the reward model: the instruction plus, if present, the input.
    instr = p_item.get("instruction", "") or ""
    inp = p_item.get("input", "") or ""
    return (instr.strip() + ("\nInput: " + inp.strip() if inp.strip() else "")).strip()

  def _sample(spec, prompt_text, T):
    # One seed is drawn per sampling call (only in deterministic mode).
    call_seed = None
    if deterministic_seed_per_draw is not None:
      call_seed = base_rng.randint(0, 2**31 - 1)
    return spec["sampler"](model=model, tokenizer=tokenizer, prompt=prompt_text, max_new_tokens=max_new_tokens,
        temperature=T, eos_token_id=tokenizer.eos_token_id, device=device, seed=call_seed, return_full_token_ids=False,
        **spec["sampler_kwargs"])

  def _evaluate_method(spec, T):
    # Full evaluation of one method at one temperature; stores results and prints the summary.
    tag = spec["tag"]
    per_prompt_results = []
    gen_texts = []
    avg_logprobs = []
    questions = []

    # Across-prompt pass: one sample per evaluation prompt.
    for i, p_item in enumerate(prompts):
      out = _sample(spec, p_item["prompt"], T)

      gen_text = out["generated_text"]
      avg_lp = out["avg_logprob"]
      gen_texts.append(gen_text)
      questions.append(_make_question_text(p_item))
      per_prompt_results.append({
        "id": p_item["id"],
        "prompt": p_item["prompt"],
        "generated_text": gen_text,
        "avg_logprob": avg_lp,
        "reward_score": None,
        "num_generated_tokens": out["num_generated_tokens"],
        "distinct_1": distinct_n([gen_text], 1, tokenizer),
        "distinct_2": distinct_n([gen_text], 2, tokenizer),
        "distinct_3": distinct_n([gen_text], 3, tokenizer),
      })
      if not math.isnan(avg_lp):
        avg_logprobs.append(avg_lp)
      if (i + 1) % cfg.print_every == 0:
        print(f"[{tag} T={T}] Decoded {i+1}/{num_prompts} prompts...")

    # Reward scoring: batched first, per-sample if the batched call fails.
    try:
      rewards = batch_reward_scores(reward_model, reward_tokenizer, questions, gen_texts, device=reward_device, batch_size=reward_batch_size, max_length=1024)
    except Exception as e:
      print(f"Batched reward scoring failed for {tag}:", e, spec["fallback_note"])
      rewards = [_reward_score_for_pair(reward_model, reward_tokenizer, questions[i], gen_texts[i], device=reward_device) for i in range(len(gen_texts))]

    reward_values = []
    for idx, rscore in enumerate(rewards):
      per_prompt_results[idx]["reward_score"] = float(rscore)
      reward_values.append(float(rscore))

    distinct1 = distinct_n(gen_texts, 1, tokenizer)
    distinct2 = distinct_n(gen_texts, 2, tokenizer)
    distinct3 = distinct_n(gen_texts, 3, tokenizer)

    avg_logprob_all = float(np.nanmean(avg_logprobs)) if avg_logprobs else float("nan")
    avg_reward_all = float(np.mean(reward_values)) if reward_values else float("nan")

    # Within-prompt pass: repeated samples of a single prompt.
    within_gen = []
    if num_prompts > 0:
      idx_within = max(0, min(prompt_for_within, num_prompts - 1))
      single_prompt = prompts[idx_within]
      for _ in range(n_within_repeat):
        o = _sample(spec, single_prompt["prompt"], T)
        within_gen.append(o["generated_text"])
    within_dist1 = distinct_n(within_gen, 1, tokenizer)
    within_dist2 = distinct_n(within_gen, 2, tokenizer)
    within_dist3 = distinct_n(within_gen, 3, tokenizer)

    results_by_method[spec["key"]][T] = per_prompt_results
    summary_by_method[spec["key"]][T] = {
      "num_prompts": num_prompts, spec["param_name"]: spec["param_value"], "temperature": T,
      "avg_logprob_all": avg_logprob_all, "avg_reward_all": avg_reward_all,
      "distinct_1_across": distinct1, "distinct_2_across": distinct2, "distinct_3_across": distinct3,
      "distinct_1_within": within_dist1, "distinct_2_within": within_dist2, "distinct_3_within": within_dist3,
    }

    param_text = spec["param_text"]
    print(f"[{tag} summary] T={T:.2f}, {param_text} -> avg_logprob={avg_logprob_all:.4f}, avg_reward={avg_reward_all:.4f}, distinct-1 across={distinct1:.4f}")

  for T in temperatures:
    for spec in method_specs:
      _evaluate_method(spec, T)

  return {
    "config": cfg,
    "k": k,
    "p": p_val,
    "temperatures": temperatures,
    "results_by_method": results_by_method,
    "summary_by_method": summary_by_method,
  }

In [None]:
# Run the combined Top-K / Top-P temperature grid with reward scoring and keep the
# result dict in `sampling_out`. Note that k=50 is passed here, overriding the
# function's default of 40.
temperatures = [0.2, 0.5, 0.8, 1.0, 1.2]

sampling_out = run_sampling_temperature_grid(
    cfg=GreedyConfig,
    k=50,
    p_val=0.9,
    temperatures=temperatures,
    max_new_tokens=128,
    test_size=400,
    reward_model_name=DEFAULT_REWARD_MODEL,
    reward_batch_size=32,
    n_within_repeat=10,
    prompt_for_within=0,
    deterministic_seed_per_draw=GreedyConfig.seed,
    trust_remote_code_for_reward=False
)

In [None]:
def load_sampling_out(path: str) -> Dict[str, Any]:
  """Read a pickled sampling result from `path` and return the unpickled object.

  Unpickling can execute arbitrary code, so only load files from a trusted source.
  """
  with open(path, "rb") as pickle_file:
    loaded = pickle.load(pickle_file)
  return loaded

def _safe_get_summary(sampling_out: Dict[str, Any], method: str, T) -> Optional[Dict[str, Any]]:
  """Return sampling_out["summary_by_method"][method][T], or None if any lookup step raises."""
  node = sampling_out
  try:
    # Walk the three nesting levels one key at a time.
    for key in ("summary_by_method", method, T):
      node = node[key]
  except Exception:
    return None
  return node

def _safe_get_per_prompt(sampling_out: Dict[str, Any], method: str, T) -> Optional[List[Dict[str,Any]]]:
  """Return sampling_out["results_by_method"][method][T], or None if any lookup step raises."""
  node = sampling_out
  try:
    # Walk the three nesting levels one key at a time.
    for key in ("results_by_method", method, T):
      node = node[key]
  except Exception:
    return None
  return node

def _nan_std(values: List[Any]) -> float:
    """Population std of ``values`` ignoring NaNs; 0.0 when nothing usable is present."""
    try:
      if any(not np.isnan(v) for v in values):
        return float(np.nanstd(values))
    except (TypeError, ValueError):
      # Non-numeric entries: error bars are best-effort, so fall back to "no spread".
      pass
    return 0.0


def _per_prompt_spread(per_prompt: List[Dict[str, Any]], distinct_key: str) -> Tuple[float, float]:
    """Return (std of per-prompt distinct values, std of per-prompt ``reward_score``)."""
    # "distinct_1_across" -> "distinct_1": per-prompt key derived from the summary key.
    key_guess = distinct_key.split("_across")[0] if "_across" in distinct_key else None
    prefer_distinct_1 = distinct_key.endswith("_1_across")

    dist_vals = []
    reward_vals = []
    for entry in per_prompt:
      if prefer_distinct_1 and "distinct_1" in entry:
        dist_val = entry.get("distinct_1")
      elif key_guess and key_guess in entry:
        dist_val = entry.get(key_guess)
      else:
        dist_val = None
      reward_val = entry.get("reward_score")
      # Missing or None values become NaN so they are ignored by the NaN-aware std
      # instead of invalidating the whole error bar.
      dist_vals.append(np.nan if dist_val is None else dist_val)
      reward_vals.append(np.nan if reward_val is None else reward_val)
    return _nan_std(dist_vals), _nan_std(reward_vals)


def _collect_method_points(
    sampling_out: Dict[str, Any],
    method: str,
    temp_list: List[float],
    distinct_key: str,
    reward_key: str,
    show_errorbars: bool,
) -> Optional[np.ndarray]:
    """Gather rows of (T, x, y, x_err, y_err) for one method, sorted by T; None if empty."""
    rows = []
    for T in temp_list:
      summary = _safe_get_summary(sampling_out, method, T)
      if summary is None:
        continue

      x_val = summary.get(distinct_key, None)
      y_val = summary.get(reward_key, None)
      # Fall back to average log-probability when the reward summary is absent.
      if y_val is None and "avg_logprob_all" in summary:
        y_val = summary.get("avg_logprob_all")
      if (x_val is None) or (y_val is None):
        continue

      x_std, y_std = 0.0, 0.0
      if show_errorbars:
        per_prompt = _safe_get_per_prompt(sampling_out, method, T)
        if per_prompt is not None:
          x_std, y_std = _per_prompt_spread(per_prompt, distinct_key)

      rows.append((float(T), float(x_val), float(y_val), x_std, y_std))

    if not rows:
      return None
    points = np.array(rows, dtype=float)
    return points[np.argsort(points[:, 0])]


def plot_diversity_quality_tradeoff(
    sampling_out: Dict[str, Any],
    methods: Union[List[str], Tuple[str, ...]] = ("top_k", "top_p"),
    distinct_key: str = "distinct_1_across",
    reward_key: str = "avg_reward_all",
    temp_list: Optional[List[float]] = None,
    labels_map: Optional[Dict[str, str]] = None,
    figsize: tuple = (8,6),
    outpath: Optional[str] = None,
    show_errorbars: bool = True,
    annotate_offset: float = 0.005,
):
    """Plot diversity (x) against quality (y) for each method, one point per temperature.

    Args:
      sampling_out: Mapping with ``"summary_by_method"`` and ``"results_by_method"``
        entries indexed as ``[method][T]``; ``"temperatures"`` is used when
        ``temp_list`` is not given.
      methods: Method names to plot. Any iterable is accepted (it is materialised
        once); a single string is treated as one method name.
      distinct_key: Summary key used for the x value.
      reward_key: Summary key used for the y value; when it is missing the
        summary's ``"avg_logprob_all"`` is used instead if present.
      temp_list: Temperatures to look up. Defaults to ``sampling_out["temperatures"]``.
      labels_map: Optional method -> legend label mapping (defaults to the method name).
      figsize: Figure size passed to matplotlib.
      outpath: If given, the figure is saved there (parent directory created if needed).
      show_errorbars: Draw error bars from the std of per-prompt values when available.
      annotate_offset: Currently unused; kept so existing callers keep working.
        Temperature annotations use a fixed (6, 6) point offset.

    Returns:
      The matplotlib Axes the data was drawn on.

    Raises:
      ValueError: No temperatures were supplied or found in ``sampling_out``.
      RuntimeError: None of the requested methods had any plottable data.
    """
    if temp_list is None:
      temp_list = sampling_out.get("temperatures", None)
      if temp_list is None:
        raise ValueError("No temperatures in sampling_out and no temp_list provided.")

    # Materialise once: `methods` is walked more than once below, and a bare string
    # would otherwise be iterated character by character.
    methods = [methods] if isinstance(methods, str) else list(methods)

    if labels_map is None:
      labels_map = {m: m for m in methods}

    colors = {
        "top_k": "tab:blue",
        "top_p": "tab:orange",
        "greedy": "tab:green",
        "beam": "tab:red"
    }
    markers = {
        "top_k": "o", "top_p": "s", "greedy": "D", "beam": "X"
    }

    # Collect everything before touching matplotlib so a failed run does not leave
    # an empty figure open.
    series = []
    for method in methods:
      points = _collect_method_points(
          sampling_out, method, temp_list, distinct_key, reward_key, show_errorbars
      )
      if points is None:
        print(f"[plot] No data found for method='{method}'. Skipping.")
        continue
      series.append((method, points))

    if not series:
      raise RuntimeError("No data plotted — check sampling_out structure and method names.")

    fig, ax = plt.subplots(figsize=figsize)

    for method, points in series:
      temps_sorted, xs_sorted, ys_sorted, xerr_sorted, yerr_sorted = points.T

      marker = markers.get(method, "o")
      (line,) = ax.plot(
          xs_sorted, ys_sorted, marker=marker, linestyle='-',
          label=labels_map.get(method, method), color=colors.get(method, None),
      )
      # Reuse the colour actually drawn so error bars and annotations match the
      # line even for methods without an entry in `colors`.
      color = line.get_color()

      if show_errorbars and (np.any(xerr_sorted > 0) or np.any(yerr_sorted > 0)):
        ax.errorbar(xs_sorted, ys_sorted, xerr=xerr_sorted, yerr=yerr_sorted, fmt='none', ecolor=color, alpha=0.4, capsize=3)

      for xi, yi, Ti in zip(xs_sorted, ys_sorted, temps_sorted):
        ax.annotate(f"{Ti:.2f}", xy=(xi, yi), xytext=(6,6), textcoords="offset points", fontsize=9, color=color)

    # Keep the original wording for the default keys; otherwise name the keys that
    # were actually plotted so the axes are not mislabelled.
    default_x = distinct_key == "distinct_1_across"
    default_y = reward_key == "avg_reward_all"
    if default_x:
      ax.set_xlabel("Distinct-1 (unique unigrams / total unigrams) — across prompts")
    else:
      ax.set_xlabel(f"Diversity ({distinct_key})")
    if default_y:
      ax.set_ylabel("Quality (reward score) — higher is better")
    else:
      ax.set_ylabel(f"Quality ({reward_key}) — higher is better")
    if default_x and default_y:
      ax.set_title("Diversity vs Quality trade-off (Distinct-1 vs Reward) across Temperatures")
    else:
      ax.set_title(f"Diversity vs Quality trade-off ({distinct_key} vs {reward_key}) across Temperatures")

    ax.grid(axis="both", linestyle='--', alpha=0.4)
    ax.legend()
    fig.tight_layout()

    if outpath:
      # dirname is "" for a bare filename, and os.makedirs("") raises.
      out_dir = os.path.dirname(outpath)
      if out_dir:
        os.makedirs(out_dir, exist_ok=True)
      fig.savefig(outpath, dpi=200)
      print(f"Saved figure to {outpath}")

    plt.show()
    return ax

# Reload pickled sampling results (this rebinds `sampling_out`) and draw the
# top-k vs top-p diversity/quality figure, saving it to the given outpath.
sampling_out = load_sampling_out(
    "/content/reward-model/sampling_out.pkl"
)
plot_diversity_quality_tradeoff(
    sampling_out,
    methods=("top_k", "top_p"),
    outpath="figs/diversity_vs_quality_topk_topp.png",
)