In [None]:
# ============================================
# Llama-3-70B triplet generation via vLLM (OpenAI-compatible)
# ============================================
import os, json, math, random, time, collections, datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import jsonlines
from tenacity import retry, stop_after_attempt, wait_exponential_jitter

try:
    from openai import OpenAI
except Exception:
    !pip -q install openai==1.51.2
    from openai import OpenAI

# -----------------------
# CONFIG
# -----------------------
# --- Endpoint: the local vLLM server (started in step 1) ---
BASE_URL = "http://127.0.0.1:8000/v1"
API_KEY = "EMPTY"  # vLLM ignores the key, but the client expects a string
MODEL_NAME = "<hf_model_id>"  # same id you passed to vLLM (or its --served-model-name)

# --- Target accounting (73 triplets were already produced from InternVL) ---
TARGET_TOTAL = 2000
ALREADY_HAVE = 73
REMAINING = TARGET_TOTAL - ALREADY_HAVE  # 1927 still to generate in this run

# --- Papers to leave out of this run ---
SKIP_SUBSTRINGS = ["InternVL"]

# --- Concurrency and reproducibility ---
MAX_WORKERS = 8
RNG_SEED = 13
random.seed(RNG_SEED)  # fixes the chunk sampling done later in the script

# --- Triplet policy: how many triplets to request per chunk ---
N_TRIPLETS_PER_CHUNK = 1  # keeps volume and cost predictable

# --- Paths (adjust if needed) ---
CHUNKS_DIR = Path("data/chunks")  # holds the *.chunks.jsonl files created earlier
OUT_DIR = Path("outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Output file for this run, stamped with the local start time ---
ts = format(datetime.datetime.now(), "%Y%m%d_%H%M%S")
triplets_path = OUT_DIR.joinpath(f"triplets_run_llama70b_{ts}.jsonl")

# --- Prompts (same structure as the GPT-5 pipeline) ---
SYSTEM_PROMPT = " ".join(
    [
        "You are a meticulous data constructor.",
        "Given a technical passage (CHUNK), produce instruction-tuning triplets that are useful for training.",
        "Prefer concrete, unambiguous, domain-grounded questions.",
    ]
)
# `{k}` is filled in with the number of triplets to request.
GEN_PROMPT = "\n".join(
    [
        "You will receive a technical passage (CHUNK). Produce {k} high-quality instruction-tuning triplets.",
        'Each triplet is a JSON object with fields: "question", "input", "response".',
        "Return a JSON array of triplets only.",
    ]
)

# --- OpenAI-compatible client pointed at the vLLM server ---
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

# -----------------------
# DISCOVER ELIGIBLE PAPERS
# -----------------------
def skip_file(p: Path) -> bool:
    """Return True if the file name of `p` contains any SKIP_SUBSTRINGS entry.

    The comparison is case-insensitive and looks only at the final path
    component, not at the parent directories.
    """
    lowered_name = p.name.lower()
    for needle in SKIP_SUBSTRINGS:
        if needle.lower() in lowered_name:
            return True
    return False

# Collect the chunk files and drop the papers excluded for this run.
chunk_files = sorted(CHUNKS_DIR.glob("*.chunks.jsonl"))
eligible_files = [p for p in chunk_files if not skip_file(p)]
# Explicit raise instead of `assert`: assertions are stripped under `python -O`,
# which would let the division below fail with an unrelated ZeroDivisionError.
if not eligible_files:
    raise RuntimeError("No eligible papers after applying SKIP_SUBSTRINGS.")

# Spread REMAINING as evenly as possible: every paper gets `base` triplets and
# the first `extra` papers (in sorted order) get one more.
num_papers = len(eligible_files)
base, extra = divmod(REMAINING, num_papers)
per_paper_targets = {
    p: (base + 1 if i < extra else base)
    for i, p in enumerate(eligible_files)
}

# Describe the actual split rather than a hard-coded one, so the log stays
# correct when the number of papers or the target changes.
mix = f"{base}/{base + 1} mix" if extra else f"all {base}"
print(f"Eligible papers: {num_papers}")
print(f"Per-paper targets: base={base}, extra={extra}  → {mix}")
print(f"k (triplets per chunk) = {N_TRIPLETS_PER_CHUNK}")
print(f"Output file → {triplets_path}")

# -----------------------
# READ CHUNKS & BUILD WORK
# -----------------------
def read_chunks(file_path: Path):
    """Load every record of the JSON Lines file at `file_path` into a list.

    Records are returned in file order, one list element per line object.
    """
    with jsonlines.open(file_path, "r") as reader:
        return list(reader)

# Flat list of (paper_path, chunk_obj) pairs, one entry per planned model call.
work = []
for cf in eligible_files:
    need_triplets = per_paper_targets[cf]  # this paper's share of REMAINING
    need_chunks = math.ceil(need_triplets / N_TRIPLETS_PER_CHUNK)
    rows = read_chunks(cf)
    if len(rows) > need_chunks:
        # More chunks than required: draw a seeded random subset.
        chosen = random.sample(rows, need_chunks)
    else:
        # Not enough chunks to sample from: use all of them.
        chosen = rows
    work.extend((cf, obj) for obj in chosen)

# Never plan more chunks than the overall target can consume.
max_chunks_needed = math.ceil(REMAINING / N_TRIPLETS_PER_CHUNK)
work = work[:max_chunks_needed] if len(work) > max_chunks_needed else work

print(f"Planned chunks to process: {len(work)}  (cap: {max_chunks_needed})")

# -----------------------
# MODEL CALL (with retry)
# -----------------------
def _parse_triplet_array(text):
    """Best-effort conversion of a model reply into a list; returns [] if unusable."""
    if not isinstance(text, str) or not text.strip():
        return []
    candidate = text.strip()
    try:
        data = json.loads(candidate)
    except (ValueError, RecursionError):
        # Chat models frequently wrap the JSON in prose or ``` fences; fall back
        # to the outermost [...] span before giving up on the reply.
        start, end = candidate.find("["), candidate.rfind("]")
        if start == -1 or end <= start:
            return []
        try:
            data = json.loads(candidate[start:end + 1])
        except (ValueError, RecursionError):
            return []
    return data if isinstance(data, list) else []


# NOTE: tenacity's wait_exponential_jitter takes `initial` (not `min`) for the
# first delay; passing `min=` raises TypeError as soon as this cell is run.
@retry(stop=stop_after_attempt(4), wait=wait_exponential_jitter(initial=1, max=8))
def call_model(chunk_text: str, k: int):
    """Ask the model for `k` instruction-tuning triplets about `chunk_text`.

    Args:
        chunk_text: The passage to send as the CHUNK.
        k: Number of triplets requested in the prompt.

    Returns:
        The parsed JSON array from the reply as a list, or an empty list when
        the reply contains no parseable JSON array. Malformed output is
        deliberately not an error, so one bad reply cannot fail the batch.

    Raises:
        Whatever tenacity raises once the request itself has failed on all
        4 attempts; parsing problems never trigger a retry.
    """
    user_prompt = f"CHUNK:\n\n{chunk_text}\n\n---\nPlease output exactly a JSON array of {k} triplets."
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",   "content": GEN_PROMPT.format(k=k)},
        {"role": "user",   "content": user_prompt},
    ]
    # vLLM supports Chat Completions compat
    resp = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.2,          # adjust if you need more variety
        max_tokens=800,           # guardrail for long outputs
    )
    return _parse_triplet_array(resp.choices[0].message.content)

def make_records(meta_chunk_obj, triplets_list):
    """Turn raw model triplets into output records tagged with chunk metadata.

    Args:
        meta_chunk_obj: The chunk object the triplets were generated from; it
            is stored as-is (not copied) under the "meta" key of every record.
        triplets_list: Parsed model output, expected to be a list of dicts with
            "question", "input" and "response" keys. `None` is treated as empty.

    Returns:
        A list of dicts with keys "question", "input", "response", "meta".
        Items that are not dicts, or whose question/response is missing, not a
        string, or blank after stripping, are dropped rather than raising —
        model output is untrusted and must not abort the run.
    """
    out = []
    for t in triplets_list or ():
        if not isinstance(t, dict):
            continue
        question = t.get("question")
        response = t.get("response")
        if not (isinstance(question, str) and isinstance(response, str)):
            continue
        question, response = question.strip(), response.strip()
        if not (question and response):
            continue
        out.append(
            {
                "question": question,
                "input": t.get("input", ""),
                "response": response,
                "meta": meta_chunk_obj,
            }
        )
    return out

# -----------------------
# PARALLEL EXECUTION (single writer)
# -----------------------
written = 0
failed = 0
per_source_written = collections.Counter()

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, jsonlines.open(triplets_path, "a") as w:
    # Track each future's origin in a dict instead of attaching ad-hoc
    # attributes to Future objects.
    pending = {
        ex.submit(call_model, chunk_obj["text"], N_TRIPLETS_PER_CHUNK): (src_path, chunk_obj)
        for (src_path, chunk_obj) in work
    }

    for fut in as_completed(pending):
        src_path, chunk_obj = pending[fut]
        try:
            triplets = fut.result()
        except Exception as exc:
            # One chunk that still fails after its retries must not abort the
            # whole run; report it and carry on with the remaining chunks.
            failed += 1
            print(f"[WARN] Skipping a chunk from {src_path.name}: {exc!r}")
            continue

        for rec in make_records(chunk_obj, triplets):
            if written >= REMAINING:
                break
            w.write(rec)  # single writer (main thread) → no file corruption
            written += 1
            # Tally by the chunk's own "source" when present, otherwise by the
            # file it was read from, so a missing key cannot raise KeyError.
            per_source_written[chunk_obj.get("source", src_path.name)] += 1

        if written >= REMAINING:
            # Quota reached: cancel requests that have not started yet, since
            # leaving the `with` block waits for everything still queued.
            for other in pending:
                other.cancel()
            break

print(f"[DONE] Wrote {written} triplets → {triplets_path}")
if failed:
    print(f"[WARN] {failed} chunk(s) failed and were skipped.")
if written < REMAINING:
    print(f"[WARN] Short of target by {REMAINING - written} triplets.")
print("Per-paper tally (top 10):")
for k, v in per_source_written.most_common(10):
    print(f"  {k}: {v}")
print("Total papers written:", len(per_source_written))
