In [1]:
import pandas as pd
import subprocess, re, time
from pathlib import Path
from sklearn.metrics import classification_report, accuracy_score, f1_score

# ----------------------------
# Config
# ----------------------------
# Ollama model tag used for every call.
MODEL = "llama3.1:8b"

# Input data files.
TRAIN_CSV = "incidents_train.csv"
VALID_CSV = "incidents_valid.csv"

# Results directory (created at import time if missing) and the output file
# for the targeted 4-example few-shot run.
OUT_DIR = Path("results")
OUT_DIR.mkdir(exist_ok=True)
OUT_CSV = OUT_DIR.joinpath("few_shot4_valid_strict_stdin_canon.csv")

SAVE_EVERY = 20  # rows between intermediate CSV saves
SEED = 42        # base seed for picking few-shot examples

# Allowed hazard-category labels.
HAZ_LIST = [
    "allergens",
    "biological",
    "chemical",
    "food additives and flavourings",
    "foreign bodies",
    "fraud",
    "migration",
    "organoleptic aspects",
    "other hazard",
    "packaging defect",
]

# Allowed product-category labels.
PROD_LIST = [
    "alcoholic beverages",
    "cereals and bakery products",
    "cocoa and cocoa preparations, coffee and tea",
    "confectionery",
    "dietetic foods, food supplements, fortified foods",
    "fats and oils",
    "feed materials",
    "food additives and flavourings",
    "food contact materials",
    "fruits and vegetables",
    "herbs and spices",
    "honey and royal jelly",
    "ices and desserts",
    "meat, egg and dairy products",
    "non-alcoholic beverages",
    "nuts, nut products and seeds",
    "other food product / mixed",
    "pet feed",
    "prepared dishes and snacks",
    "seafood",
    "soups, broths, sauces and condiments",
    "sugars and syrups",
]

# ----------------------------
# Ollama runner (stdin + timeout) - Windows-safe for long prompts
# ----------------------------
def run_ollama(prompt: str, timeout_sec: int = 180, model: "str | None" = None) -> str:
    """Send ``prompt`` to an Ollama model over stdin and return its reply.

    The prompt is piped through stdin rather than passed on the command line.

    Args:
        prompt: Full prompt text written to the process's stdin.
        timeout_sec: Seconds to wait before the call is abandoned.
        model: Model tag to run; when ``None`` the module-level ``MODEL`` is used.

    Returns:
        The process's stdout with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the ``ollama`` executable cannot be started, the call
            times out, the process exits with a non-zero status, or stdout is
            empty. The underlying exception, if any, is chained as the cause.
    """
    model_name = MODEL if model is None else model
    try:
        res = subprocess.run(
            ["ollama", "run", model_name],
            input=prompt,                 # send prompt via stdin
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout_sec,
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(f"ollama timeout after {timeout_sec}s") from exc
    except OSError as exc:
        # Missing or non-executable binary: report it through the same
        # exception type as every other failure of this function.
        raise RuntimeError(f"could not start ollama: {exc}") from exc

    stderr = res.stderr or ""
    if res.returncode != 0:
        raise RuntimeError(f"ollama error: {stderr[:800]}")

    out = (res.stdout or "").strip()
    if not out:
        raise RuntimeError(f"empty output (stderr={stderr[:300]})")
    return out

# ----------------------------
# Parser (strict + canonicalization)
# ----------------------------
# Known near-miss spellings (lower-cased) mapped to the official hazard label.
_HAZ_CANON = {
    "food additives and flavorings": "food additives and flavourings",
    "food additives & flavourings": "food additives and flavourings",
    "food additives & flavorings": "food additives and flavourings",
    "foreign body": "foreign bodies",
    "packaging defects": "packaging defect",
    "organoleptic": "organoleptic aspects",
    "other hazards": "other hazard",
}

# Known near-miss spellings (lower-cased) mapped to the official product label.
_PROD_CANON = {
    "non alcoholic beverages": "non-alcoholic beverages",
    "nonalcoholic beverages": "non-alcoholic beverages",
    "cocoa, coffee and tea": "cocoa and cocoa preparations, coffee and tea",
    "cocoa and coffee and tea": "cocoa and cocoa preparations, coffee and tea",
    "other food product/mixed": "other food product / mixed",
    "other food product /mixed": "other food product / mixed",
    "other food product mixed": "other food product / mixed",
    "meat egg and dairy products": "meat, egg and dairy products",
    "soups broths sauces and condiments": "soups, broths, sauces and condiments",
    "food contact material": "food contact materials",
}

_HAZ_LINE_RE = re.compile(r"hazard-category\s*:\s*(.+)", flags=re.IGNORECASE)
_PROD_LINE_RE = re.compile(r"product-category\s*:\s*(.+)", flags=re.IGNORECASE)

# Characters a label may be wrapped in: whitespace, quotes, markdown emphasis,
# and brackets.
_WRAPPERS = " \t\r\n\"'`*<>[]"


def _clean_label(raw):
    """Strip wrappers and trailing punctuation and collapse inner whitespace."""
    if raw is None:
        return None
    text = re.sub(r"\s+", " ", raw.strip(_WRAPPERS))
    text = text.rstrip(_WRAPPERS + ".,:;")
    return text or None


def _resolve_label(raw, canon_map, allowed):
    """Return the entry of ``allowed`` that ``raw`` denotes, or ``None``."""
    label = _clean_label(raw)
    if label is None:
        return None
    key = label.lower()
    key = canon_map.get(key, key).lower()
    # Case-insensitive membership: the alias lookup above is already
    # case-insensitive, so the final list check must be as well.
    by_lower = {item.lower(): item for item in allowed}
    return by_lower.get(key)


def parse_two_lines(out: str, haz_list=None, prod_list=None):
    """Extract and validate the two predicted labels from a model reply.

    The first ``hazard-category: ...`` and ``product-category: ...`` matches
    (case-insensitive) are taken from ``out``. Each value is cleaned, mapped
    through the alias tables, and accepted only if it matches an allowed label
    ignoring case; the allowed list's own spelling is returned.

    Args:
        out: Raw model output; ``None`` or empty yields a failed parse.
        haz_list: Allowed hazard labels; defaults to the module-level ``HAZ_LIST``.
        prod_list: Allowed product labels; defaults to the module-level ``PROD_LIST``.

    Returns:
        A tuple ``(haz, prod, parse_ok, haz_in_list, prod_in_list)``. ``haz`` and
        ``prod`` are the accepted labels or ``None``; ``parse_ok`` is True only
        when both labels were accepted.
    """
    haz_allowed = HAZ_LIST if haz_list is None else haz_list
    prod_allowed = PROD_LIST if prod_list is None else prod_list

    text = out or ""
    haz_match = _HAZ_LINE_RE.search(text)
    prod_match = _PROD_LINE_RE.search(text)

    haz = _resolve_label(
        haz_match.group(1) if haz_match else None, _HAZ_CANON, haz_allowed
    )
    prod = _resolve_label(
        prod_match.group(1) if prod_match else None, _PROD_CANON, prod_allowed
    )

    haz_in_list = haz is not None
    prod_in_list = prod is not None
    parse_ok = haz_in_list and prod_in_list
    return haz, prod, parse_ok, haz_in_list, prod_in_list

# ----------------------------
# Metrics
# ----------------------------
def compute_metrics(df_res: pd.DataFrame, true_col: str, pred_col: str, title: str):
    """Print accuracy, F1 variants and a classification report for one label.

    Only rows whose ``parse_ok`` column is True are scored; the header line
    states how many rows that is out of the total. If no row qualifies, a
    short notice is printed and nothing else is computed.

    Args:
        df_res: Results table containing ``parse_ok``, ``true_col`` and ``pred_col``.
        true_col: Column holding the gold labels.
        pred_col: Column holding the predicted labels.
        title: Name shown in the printed headers.
    """
    parsed = df_res.loc[df_res["parse_ok"] == True].copy()
    if not len(parsed):
        print(f"\n=== {title} Metrics ===\nNo parse_ok samples.")
        return

    # Labels are compared as strings.
    gold = parsed[true_col].astype(str)
    guess = parsed[pred_col].astype(str)

    # Compute every score before printing anything.
    scores = [("Accuracy:", accuracy_score(gold, guess))]
    for label, averaging in (
        ("Macro-F1:", "macro"),
        ("Micro-F1:", "micro"),
        ("Weighted-F1:", "weighted"),
    ):
        scores.append(
            (label, f1_score(gold, guess, average=averaging, zero_division=0))
        )

    print(f"\n=== {title} Metrics (on parse_ok samples: {len(parsed)}/{len(df_res)}) ===")
    for label, value in scores:
        print(label, value)
    print("\nClassification Report:\n")
    print(classification_report(gold, guess, zero_division=0))

# ----------------------------
# Targeted 4-shot builder
# ----------------------------
def _fmt_example(ex) -> str:
    return (
        "Example:\n"
        f"Report:\nTitle: {str(ex['title'])}\nText: {str(ex['text'])}\n"
        f"Output:\nhazard-category: {ex['hazard-category']}\nproduct-category: {ex['product-category']}\n"
    )

def pick_one(train: pd.DataFrame, hazard_cat: str = None, prod_cat: str = None, seed: int = 0):
    """Sample one row of ``train`` matching the requested categories.

    Args:
        train: Table with ``hazard-category`` and ``product-category`` columns.
        hazard_cat: Required hazard label, or ``None`` for any.
        prod_cat: Required product label, or ``None`` for any.
        seed: ``random_state`` passed to the sampling call.

    Returns:
        The sampled row as a Series, or ``None`` when nothing matches.
    """
    candidates = train
    # Apply the hazard filter first, then the product filter, skipping unset ones.
    for column, wanted in (
        ("hazard-category", hazard_cat),
        ("product-category", prod_cat),
    ):
        if wanted is not None:
            candidates = candidates[candidates[column] == wanted]

    if not len(candidates):
        return None
    drawn = candidates.sample(1, random_state=seed)
    return drawn.iloc[0]

def build_few_shot4_block(train: pd.DataFrame) -> str:
    """Build the few-shot prompt section from up to four targeted examples.

    One training row is drawn for each known weak spot, in this order:
    hazard "fraud", hazard "organoleptic aspects", product
    "soups, broths, sauces and condiments", and product
    "cocoa and cocoa preparations, coffee and tea". Targets with no matching
    row are skipped, and a row drawn twice (same title and text) appears once.

    Returns:
        The formatted examples joined by a newline; empty if none were found.
    """
    # Each target is drawn with its own seed: SEED, SEED + 1, SEED + 2, SEED + 3.
    targets = (
        {"hazard_cat": "fraud"},
        {"hazard_cat": "organoleptic aspects"},
        {"prod_cat": "soups, broths, sauces and condiments"},
        {"prod_cat": "cocoa and cocoa preparations, coffee and tea"},
    )
    drawn = [
        pick_one(train, seed=SEED + offset, **target)
        for offset, target in enumerate(targets)
    ]

    # Keyed by (title, text) so duplicates collapse; dict order keeps draw order.
    rendered = {}
    for row in drawn:
        if row is None:
            continue
        identity = (str(row["title"]), str(row["text"]))
        if identity not in rendered:
            rendered[identity] = _fmt_example(row)

    return "\n".join(rendered.values())

# ----------------------------
# Main
# ----------------------------
def main():
    """Run the 4-shot classification over the validation set and report metrics.

    Steps: load the train/validation CSVs, build the few-shot block from the
    training data, make one sanity call to Ollama, classify every validation
    row (rows already present in ``OUT_CSV`` are reused instead of re-run),
    save intermediate results every ``SAVE_EVERY`` rows, then write the final
    CSV and print parse rates, latency and per-label metrics.

    Returns early (after printing a message) if no few-shot example could be
    built or the sanity call fails.
    """
    train = pd.read_csv(TRAIN_CSV)
    valid = pd.read_csv(VALID_CSV)
    n = len(valid)
    print(f"Loaded train={len(train)} valid={n}")

    few_shot_examples = build_few_shot4_block(train)
    if not few_shot_examples.strip():
        print("ERROR: Could not build few-shot examples (missing target classes in train?).")
        return

    haz_opts = "; ".join(HAZ_LIST)
    prod_opts = "; ".join(PROD_LIST)

    # Static part of the prompt. The per-row title/text are appended by plain
    # concatenation rather than str.format(): the few-shot examples are raw
    # training text, and any "{" or "}" in them would be parsed as a format
    # field and raise before the row is even sent to the model.
    prompt_head = f"""You are a food safety incident classifier.

Choose labels ONLY from the allowed lists.
Do NOT invent new labels.

Output ONLY two lines, no extra text.
Line1 must be: hazard-category: <one of: {haz_opts}>
Line2 must be: product-category: <one of: {prod_opts}>

Here are labeled examples:
{few_shot_examples}

Now classify this new report:
Report:
"""

    def build_prompt(title: str, text: str) -> str:
        """Return the full prompt for one report."""
        return f"{prompt_head}Title: {title}\nText: {text}"

    # Sanity check
    print("\nSanity check (single call)...")
    sanity_prompt = (
        "You MUST output ONLY two lines, no extra text.\n"
        "Line1 must be: hazard-category: allergens\n"
        "Line2 must be: product-category: confectionery\n"
    )
    try:
        sanity_out = run_ollama(sanity_prompt, timeout_sec=120)
        print("Sanity output:\n", sanity_out[:400], "\n")
    except Exception as e:
        print("Sanity check failed:", str(e))
        # Name the configured model so the hint stays correct if MODEL changes.
        print(f"Fix Ollama/model first (e.g., `ollama pull {MODEL}`, then `ollama run {MODEL}`).")
        return

    # resume support: rows already written to OUT_CSV, keyed by validation index
    done = {}
    if OUT_CSV.exists():
        prev = pd.read_csv(OUT_CSV)
        for _, r in prev.iterrows():
            done[int(r["idx"])] = r.to_dict()
        print(f"Resuming: found {len(done)} already processed rows in {OUT_CSV}")

    rows = []
    start = time.time()

    for i, r in valid.iterrows():
        if int(i) in done:
            rows.append(done[int(i)])
            continue

        prompt = build_prompt(str(r["title"]), str(r["text"]))
        t0 = time.time()
        try:
            out = run_ollama(prompt, timeout_sec=180)
            haz_pred, prod_pred, parse_ok, haz_ok, prod_ok = parse_two_lines(out)
            err = ""
        except Exception as e:
            # Best effort: record the failure for this row and keep going.
            out = ""
            haz_pred, prod_pred, parse_ok, haz_ok, prod_ok = None, None, False, False, False
            err = str(e)

        dt = time.time() - t0

        rows.append({
            "idx": int(i),
            "parse_ok": bool(parse_ok),
            "haz_pred": haz_pred,
            "prod_pred": prod_pred,
            "haz_in_list": bool(haz_ok),
            "prod_in_list": bool(prod_ok),
            "haz_true": r.get("hazard-category"),
            "prod_true": r.get("product-category"),
            "latency_sec": round(dt, 3),
            "error": err,
            "raw_output": out[:1500]
        })

        if (i + 1) % 10 == 0:
            elapsed = time.time() - start
            print(f"[{i+1}/{n}] parse_ok={parse_ok} haz={haz_pred} prod={prod_pred} (elapsed {elapsed/60:.1f} min)")

        if (i + 1) % SAVE_EVERY == 0:
            pd.DataFrame(rows).to_csv(OUT_CSV, index=False)

    out_df = pd.DataFrame(rows)
    out_df.to_csv(OUT_CSV, index=False)
    print(f"\nSaved final: {OUT_CSV}")

    print("\nParse OK rate:", out_df["parse_ok"].mean())
    print("Hazard in-list rate:", out_df["haz_in_list"].mean())
    print("Product in-list rate:", out_df["prod_in_list"].mean())
    print("Avg latency (sec):", out_df["latency_sec"].mean())

    compute_metrics(out_df, "haz_true", "haz_pred", "Hazard-category")
    compute_metrics(out_df, "prod_true", "prod_pred", "Product-category")

# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

Loaded train=5082 valid=565

Sanity check (single call)...
Sanity output:
 hazard-category: allergens
product-category: confectionery 

[10/565] parse_ok=True haz=biological prod=meat, egg and dairy products (elapsed 0.1 min)
[20/565] parse_ok=True haz=biological prod=fruits and vegetables (elapsed 0.2 min)
[30/565] parse_ok=True haz=allergens prod=cereals and bakery products (elapsed 0.3 min)
[40/565] parse_ok=True haz=allergens prod=meat, egg and dairy products (elapsed 0.4 min)
[50/565] parse_ok=True haz=foreign bodies prod=fruits and vegetables (elapsed 0.6 min)
[60/565] parse_ok=True haz=foreign bodies prod=prepared dishes and snacks (elapsed 0.6 min)
[70/565] parse_ok=True haz=foreign bodies prod=soups, broths, sauces and condiments (elapsed 0.7 min)
[80/565] parse_ok=True haz=foreign bodies prod=confectionery (elapsed 0.9 min)
[90/565] parse_ok=True haz=biological prod=prepared dishes and snacks (elapsed 1.0 min)
[100/565] parse_ok=True haz=allergens prod=meat, egg and dairy pro