In [2]:
# ==========================================================
# FAST Text Variant Generator (batched LLM + semantic filtering + resume/continue)
# ==========================================================
import os, re, json, random, math, difflib, time
from typing import List, Dict
from tqdm import tqdm

# ----------------- CONFIG -----------------
REPO_ROOT = "/home/myid/bp67339/code/crop-care/backend"
DATA_DIR  = f"{REPO_ROOT}/data"

IN_TRAIN   = f"{DATA_DIR}/train_clean_step1.jsonl"          # [{"text":"...", "label":"..."}]
OUT_AUG    = f"{DATA_DIR}/train_phase3_augmented.jsonl"     # originals + LLM variants
OUT_PAIRS  = f"{DATA_DIR}/train_phase3_pairs.jsonl"         # (orig, variant) for consistency eval
OUT_REPORT = f"{DATA_DIR}/phase3_generation_report_llm.txt"
PROGRESS_FP= f"{DATA_DIR}/phase3_progress.json"             # resume checkpoint

# Small/fast instruct model (you can switch later)
HF_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Speed knobs
BATCH_SIZE         = 16
VARIANTS_PER_SEED  = 2
MAX_NEW_TOKENS     = 80
TEMPERATURE        = 0.6
TOP_P              = 0.9
REPETITION_PENALTY = 1.05
SAVE_EVERY         = 500

# Scope
MAX_SEEDS          = None   # process all
SHUFFLE_SEEDS      = False  # deterministic continuation

# Optional heavy filters (off for speed)
USE_SBERT    = False
SBERT_MODEL  = "sentence-transformers/all-MiniLM-L6-v2"
SBERT_THRESH = 0.80
USE_NLI      = False
NLI_MODEL    = "roberta-large-mnli"

random.seed(42)
os.makedirs(DATA_DIR, exist_ok=True)

# ----------------- IO helpers -----------------
def read_jsonl(fp):
    rows=[]
    with open(fp,"r",encoding="utf-8") as f:
        for line in f:
            line=line.strip()
            if not line: continue
            rows.append(json.loads(line))
    return rows

def write_jsonl(fp, rows):
    os.makedirs(os.path.dirname(fp), exist_ok=True)
    with open(fp,"w",encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

# ----------------- Load seeds -----------------
rows = read_jsonl(IN_TRAIN)
print("Loaded:", len(rows))
if SHUFFLE_SEEDS:
    random.shuffle(rows)
if MAX_SEEDS:
    rows = rows[:MAX_SEEDS]
    print("Using subset:", len(rows))

# ----------------- Filters (lightweight) -----------------
def toks(s): return re.findall(r"[a-zA-Z']+", s.lower())
def keyword_set(s):
    stop=set("""a an the and or but if so then than with without of on in to for from by at is are was were be been being it its it's im i'm i you your u our we they them this that those these there's""".split())
    return set([w for w in toks(s) if w not in stop and len(w)>=3])

def similarity(a,b):  # fast string sim just for guardrails
    return difflib.SequenceMatcher(None, a.lower(), b.lower()).ratio()

def keep_variant(orig, var):
    var=var.strip()
    if not var or len(var.split())<5 or len(var)>400: return False
    sim=similarity(orig,var)
    if sim>=0.98 or sim<0.50: return False
    base=keyword_set(orig); ks=keyword_set(var)
    if len(base & ks) < max(1, min(3, math.ceil(0.2*len(base)))): return False
    if var.lower().startswith("disease:"): return False
    return True

# ----------------- Load LLM -----------------
import torch, importlib.util
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

load_kwargs = {"device_map": "auto"}
use_4bit = torch.cuda.is_available() and importlib.util.find_spec("bitsandbytes") is not None
if use_4bit:
    load_kwargs["load_in_4bit"] = True
    print("Using 4-bit quantization via bitsandbytes.")
else:
    print("bitsandbytes not found (or no CUDA). Using regular precision.")

tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, use_fast=True)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL,
    torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
    **load_kwargs
)

gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    do_sample=True,
    repetition_penalty=REPETITION_PENALTY,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    return_full_text=False,
)
print("Accelerate device map:", getattr(model, "hf_device_map", None))

# ----------------- Prompting -----------------
SYSTEM_MSG = (
    "You rewrite plant disease descriptions into short, farmer-style sentences "
    "while strictly preserving the original meaning and symptoms. "
    "Do not add or remove symptoms. Do not name a disease. 1–2 sentences max."
)

def clamp(txt, max_chars=600):
    return txt if len(txt) <= max_chars else (txt[:max_chars] + "…")

def build_prompt(text, n):
    return (
        f"{SYSTEM_MSG}\n\n"
        f"ORIGINAL:\n{clamp(text)}\n\n"
        f"Write {n} farmer-style variants that preserve meaning. "
        f"Output EXACTLY a JSON list of strings (no extra text)."
    )

def parse_json_list(s: str):
    m = re.search(r"\[[\s\S]*\]", s)
    if not m: return []
    try:
        arr = json.loads(m.group(0))
        if isinstance(arr, list):
            return [x.strip() for x in arr if isinstance(x,str) and x.strip()]
    except Exception:
        return []
    return []

def get_generated_text(item):
    if isinstance(item, list):
        item = item[0]
    return item.get("generated_text", "")

# ----------------- Continue from existing -----------------
seen_orig, seen_dupes = set(), set()
out_rows, pairs_rows = [], []

if os.path.exists(OUT_AUG):
    out_rows = read_jsonl(OUT_AUG)
    for r in out_rows:
        txt  = (r.get("text") or "").strip().lower()
        lbl  = r.get("label")
        src  = r.get("source")
        if not txt or not lbl: 
            continue
        if src == "orig":
            seen_orig.add((lbl, txt))
        elif src == "aug":
            seen_dupes.add((lbl, txt))

if os.path.exists(OUT_PAIRS):
    pairs_rows = read_jsonl(OUT_PAIRS)

print(f"Already have {len(seen_orig)} originals and {len(seen_dupes)} aug variants from previous runs.")

# Filter to only remaining
def norm(s): return (s or "").strip().lower()
remaining_rows = []
for r in rows:
    base  = norm(r.get("text"))
    label = r.get("label")
    if base and label and (label, base) not in seen_orig:
        remaining_rows.append({"text": base, "label": label})

rows = remaining_rows
print("Remaining to process:", len(rows))

# Reset progress file
if os.path.exists(PROGRESS_FP):
    os.remove(PROGRESS_FP)

# ----------------- Batch loop -----------------
def process_batch(batch_rows):
    prompts = [build_prompt((r.get("text") or "").strip(), VARIANTS_PER_SEED) for r in batch_rows]
    return gen(prompts, batch_size=BATCH_SIZE)

print("Generating variants in batches…")
B = BATCH_SIZE
total = len(rows)
num_batches = (total + B - 1) // B
t0 = time.time()

for b, i in enumerate(tqdm(range(0, total, B), total=num_batches)):
    batch_rows = rows[i : min(i+B, total)]

    # add originals
    for r in batch_rows:
        base = (r.get("text") or "").strip()
        label= r.get("label")
        if not base or not label: 
            continue
        out_rows.append({"text": base, "label": label, "source":"orig", "tags":[]})

    # generate batch
    try:
        outs = process_batch(batch_rows)
    except Exception as e:
        print(f"\nBatch {i}-{i+len(batch_rows)} failed: {e}")
        continue

    # parse/filter
    for r, out in zip(batch_rows, outs):
        base  = (r.get("text") or "").strip()
        label = r.get("label")
        if not base or not label: 
            continue
        text_out = get_generated_text(out)
        cand = parse_json_list(text_out)
        if not cand:
            cand = [ln.strip(" -") for ln in text_out.split("\n") if len(ln.strip())>10][:VARIANTS_PER_SEED]

        kept_here = 0
        for v in cand:
            if kept_here >= VARIANTS_PER_SEED: break
            if not keep_variant(base, v): continue
            key=(label, v.lower().strip())
            if key in seen_dupes: continue
            seen_dupes.add(key)
            out_rows.append({"text": v, "label": label, "source":"aug", "tags":["llm"]})
            pairs_rows.append({"orig": base, "variant": v, "label": label, "tags":["llm"]})
            kept_here += 1

    # periodic save
    if (b + 1) % max(1, SAVE_EVERY // B) == 0:
        write_jsonl(OUT_AUG, out_rows)
        write_jsonl(OUT_PAIRS, pairs_rows)

# final save
write_jsonl(OUT_AUG, out_rows)
write_jsonl(OUT_PAIRS, pairs_rows)
with open(OUT_REPORT,"w") as f:
    f.write(f"Seeds used: {len(rows)}\n")
    f.write(f"Total (orig+aug): {len(out_rows)}\n")
    f.write(f"Pairs: {len(pairs_rows)}\n")
    f.write(f"Model: {HF_MODEL} | Batch: {BATCH_SIZE} | Variants/seed: {VARIANTS_PER_SEED} | MaxNewTokens: {MAX_NEW_TOKENS}\n")
    f.write(f"Took {(time.time()-t0)/60:.1f} min\n")

print("✅ Saved")
print("  Augmented:", OUT_AUG)
print("  Pairs:    ", OUT_PAIRS)
print("  Report:   ", OUT_REPORT)
print(f"Final counts — originals: {sum(1 for r in out_rows if r['source']=='orig')}, "
      f"variants: {sum(1 for r in out_rows if r['source']=='aug')}, pairs: {len(pairs_rows)}")

Loaded: 44494
bitsandbytes not found (or no CUDA). Using regular precision.


Device set to use cuda:0


Accelerate device map: {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}
Already have 5000 originals and 2870 aug variants from previous runs.
Remaining to process: 39493
Generating variants in batches…


  0%|          | 0/2469 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  0%|          | 1/2469 [00:02<1:29:47,  2.18s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  0%|          | 2/2469 [00:04<1:30:22,  2.20s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  0%|          | 3/2469 [00:06<1:30:17,  2.20s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  0%|          | 4/2469 [00:08<1:29:43,  2.18s/it]A decoder-only architecture is being used, but right-padding was det

✅ Saved
  Augmented: /home/myid/bp67339/code/crop-care/backend/data/train_phase3_augmented.jsonl
  Pairs:     /home/myid/bp67339/code/crop-care/backend/data/train_phase3_pairs.jsonl
  Report:    /home/myid/bp67339/code/crop-care/backend/data/phase3_generation_report_llm.txt
Final counts — originals: 44493, variants: 25928, pairs: 25928
