In [1]:
# Cell 1: Imports & basic config

import re
import os

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from tqdm.auto import tqdm

# === CONFIG ===
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"   
OUT_DIR    = "./mistral7b_qlora_out"               
DATA_PATH  = "spark_llm_dataset_50k.jsonl"         
USE_4BIT   = True                                  
N_EVAL     = 300                                    


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Cell 2: Load base model, tokenizer from OUT_DIR, and attach LoRA

if USE_4BIT:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        quantization_config=bnb_config,
    )
else:
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
    )

# Tokenizer saved in OUT_DIR (your setup)
tokenizer = AutoTokenizer.from_pretrained(OUT_DIR)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Attach LoRA adapter from OUT_DIR
model = PeftModel.from_pretrained(base_model, OUT_DIR)
model.eval()

device = model.device
print("Model with LoRA loaded on:", device)

# Optional: confirm it's a PEFT model
try:
    print("PEFT config keys:", list(model.peft_config.keys()))
except Exception as e:
    print("No peft_config found:", e)


Loading checkpoint shards: 100%|██████████| 3/3 [00:06<00:00,  2.08s/it]


Model with LoRA loaded on: cuda:0
PEFT config keys: ['default']


In [3]:
# Cell 3: Load a 20-sample eval subset + prompt builder + extract_fields

# Load full JSONL dataset
full_ds = load_dataset("json", data_files=DATA_PATH, split="train")
print("Total samples in JSONL:", len(full_ds))

# Take a small random subset for evaluation (20 examples)
N_EVAL_ACTUAL = min(N_EVAL, len(full_ds))
eval_ds = full_ds.shuffle(seed=42).select(range(N_EVAL_ACTUAL))
print("Eval subset size:", len(eval_ds))

# === Prompt builder (eval-time) ===
# You can tweak wording here; this is the strict format version
def build_prompt(ex):
    return (
        "### Instruction:\n"
        "You are an extraction model for Spark logs. "
        "Given the log message below, output ONLY in the following exact format:\n\n"
        "EventId: <ID>\n"
        "EventTemplate: <template>\n\n"
        "Do not include any explanations, markdown, or extra text.\n\n"
        "### Input:\n"
        f"{ex['input']}\n\n"
        "### Response:\n"
    )

# === Regexes & extractor ===
# Non-greedy template capture; search only in the '### Response:' block
event_id_re = re.compile(r"EventId:\s*([A-Za-z0-9_-]+)")
tpl_re      = re.compile(r"EventTemplate:\s*(.+?)(?=\n###|\Z)", re.DOTALL)

def extract_fields(text: str):
    """
    Extract (EventId, EventTemplate) from the *response* portion of the text.
    - If '### Response:' is present, only search in that substring.
    - Template is captured non-greedily until the next '###' header or end of text.
    """
    # Focus only on response block if present
    if "### Response:" in text:
        text = text.split("### Response:", 1)[1]

    eid_matches = event_id_re.findall(text)
    tpl_matches = tpl_re.findall(text)

    eid = eid_matches[-1].strip() if eid_matches else None
    tpl = tpl_matches[-1].strip() if tpl_matches else None

    return eid, tpl

# Quick sanity check: gold output parsing
ex0 = eval_ds[0]
print("\n--- Sanity check on GOLD output (first example) ---")
print("GOLD output raw:\n", ex0["output"])
print("Parsed (gold_eid, gold_tpl):", extract_fields(ex0["output"]))


Total samples in JSONL: 6259
Eval subset size: 300

--- Sanity check on GOLD output (first example) ---
GOLD output raw:
 EventId: E180
EventTemplate: Started <*> remote fetches in <*> ms
Parsed (gold_eid, gold_tpl): ('E180', 'Started <*> remote fetches in <*> ms')


In [4]:
# Cell 4: Evaluate on 20 examples, printing each example in detail

def normalize_template(s: str) -> str:
    # Lowercase + collapse whitespace for a softer comparison
    return " ".join(s.lower().split())

correct_eid   = 0
correct_tpl   = 0
correct_both  = 0

eid_total   = 0
tpl_total   = 0
both_total  = 0

results = []  # store everything for inspection

for idx, ex in enumerate(tqdm(eval_ds, desc="Evaluating 20 examples"), start=1):
    gold_eid, gold_tpl = extract_fields(ex["output"])

    prompt = build_prompt(ex)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=32,
            do_sample=False,
        )

    gen_text = tokenizer.decode(out[0], skip_special_tokens=True)

    pred_eid, pred_tpl = extract_fields(gen_text)

    # --- Metrics bookkeeping ---
    if gold_eid is not None:
        eid_total += 1
        if pred_eid is not None and pred_eid == gold_eid:
            correct_eid += 1

    if gold_tpl is not None and pred_tpl is not None:
        tpl_total += 1
        if normalize_template(pred_tpl) == normalize_template(gold_tpl):
            correct_tpl += 1

    if (
        gold_eid is not None and gold_tpl is not None and
        pred_eid is not None and pred_tpl is not None
    ):
        both_total += 1
        if (pred_eid == gold_eid) and (
            normalize_template(pred_tpl) == normalize_template(gold_tpl)
        ):
            correct_both += 1

    # Save full info for this example
    results.append({
        "idx": idx,
        "input": ex["input"],
        "gold_output": ex["output"],
        "gen_output": gen_text,
        "gold_eid": gold_eid,
        "gold_tpl": gold_tpl,
        "pred_eid": pred_eid,
        "pred_tpl": pred_tpl,
    })

# Verbose print of all 20 examples
print("\n============================")
print("DETAILED PER-EXAMPLE OUTPUT")
print("============================")

for r in results:
    print("\n" + "=" * 80)
    print(f"Example {r['idx']}")
    print("- Input:")
    print(r["input"])
    print("\n- GOLD output:")
    print(r["gold_output"])
    print(f"\n  Parsed GOLD -> EventId={r['gold_eid']} | Template={repr(r['gold_tpl'])}")
    print("\n- GENERATED output:")
    print(r["gen_output"])
    print(f"\n  Parsed PRED -> EventId={r['pred_eid']} | Template={repr(r['pred_tpl'])}")


Evaluating 20 examples:   0%|          | 0/300 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   0%|          | 1/300 [00:01<09:47,  1.96s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   1%|          | 2/300 [00:03<08:48,  1.77s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   1%|          | 3/300 [00:05<08:31,  1.72s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   1%|▏         | 4/300 [00:06<07:52,  1.60s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   2%|▏         | 5/300 [00:08<08:05,  1.64s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   2%|▏         | 6/300 [00:09<07:42,  1.57s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Evaluating 20 examples:   2%|▏     


DETAILED PER-EXAMPLE OUTPUT

Example 1
- Input:
Started 13 remote fetches in 18 ms

- GOLD output:
EventId: E180
EventTemplate: Started <*> remote fetches in <*> ms

  Parsed GOLD -> EventId=E180 | Template='Started <*> remote fetches in <*> ms'

- GENERATED output:
### Instruction:
You are an extraction model for Spark logs. Given the log message below, output ONLY in the following exact format:

EventId: <ID>
EventTemplate: <template>

Do not include any explanations, markdown, or extra text.

### Input:
Started 13 remote fetches in 18 ms

### Response:
EventId: E180
EventTemplate: Started <*> remote fetches in <*> ms

  Parsed PRED -> EventId=E180 | Template='Started <*> remote fetches in <*> ms'

Example 2
- Input:
Registered signal handlers for [TERM, HUP, INT]

- GOLD output:
EventId: E164
EventTemplate: Registered signal handlers for <*>

  Parsed GOLD -> EventId=E164 | Template='Registered signal handlers for <*>'

- GENERATED output:
### Instruction:
You are an extraction mod




In [5]:
# Cell 5: Summary metrics for EventId, Template, and Joint correctness

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

print("\n===================")
print("SUMMARY METRICS (EventId / Template / Joint)")
print("===================")

total = len(results)

# Basic counts
correct_eid   = 0
correct_tpl   = 0
correct_both  = 0

eid_total   = 0
tpl_total   = 0
both_total  = 0

y_true_eid = []
y_pred_eid = []

for r in results:
    gold_eid = r["gold_eid"]
    gold_tpl = r["gold_tpl"]
    pred_eid = r["pred_eid"]
    pred_tpl = r["pred_tpl"]

    # EventId metrics
    if gold_eid is not None:
        eid_total += 1
        y_true_eid.append(gold_eid if pred_eid is not None else gold_eid)
        y_pred_eid.append(pred_eid if pred_eid is not None else "__NONE__")
        if pred_eid is not None and pred_eid == gold_eid:
            correct_eid += 1

    # Template metrics
    if gold_tpl is not None and pred_tpl is not None:
        tpl_total += 1
        if normalize_template(pred_tpl) == normalize_template(gold_tpl):
            correct_tpl += 1

    # Joint metrics (both id + template correct)
    if (
        gold_eid is not None and gold_tpl is not None and
        pred_eid is not None and pred_tpl is not None
    ):
        both_total += 1
        if (pred_eid == gold_eid) and (
            normalize_template(pred_tpl) == normalize_template(gold_tpl)
        ):
            correct_both += 1

# Avoid division-by-zero
eid_acc  = correct_eid  / eid_total  if eid_total  > 0 else 0.0
tpl_acc  = correct_tpl  / tpl_total  if tpl_total  > 0 else 0.0
both_acc = correct_both / both_total if both_total > 0 else 0.0

print(f"Total eval examples:        {total}")
print(f"EventId labels present:     {eid_total}")
print(f"Template labels present:    {tpl_total}")
print(f"Both labels present:        {both_total}")

print(f"\nEventId accuracy:           {eid_acc:.3f}")
print(f"EventTemplate accuracy:     {tpl_acc:.3f}")
print(f"Joint accuracy (both):      {both_acc:.3f}")

# --------- Extra: multi-class classification metrics for EventId ---------
if eid_total > 0:
    y_true_eid = np.array(y_true_eid)
    y_pred_eid = np.array(y_pred_eid)

    # Filter out any placeholder "__NONE__" for cleaner metrics if you want strict “predicted nothing” as a class
    # (you can keep it if you want to treat missing predictions as a distinct class).
    # Here we keep them so we don't drop samples; macro F1 will still reflect failures.
    prec_macro = precision_score(y_true_eid, y_pred_eid, average="macro", zero_division=0)
    rec_macro  = recall_score(y_true_eid,  y_pred_eid, average="macro", zero_division=0)
    f1_macro   = f1_score(y_true_eid,     y_pred_eid, average="macro", zero_division=0)
    acc_cls    = accuracy_score(y_true_eid, y_pred_eid)

    print("\n=== EventId Classification Metrics (macro) ===")
    print(f"Accuracy:   {acc_cls:.3f}")
    print(f"Precision:  {prec_macro:.3f}")
    print(f"Recall:     {rec_macro:.3f}")
    print(f"F1-Score:   {f1_macro:.3f}")
else:
    print("\nNo EventId labels available for classification metrics.")



SUMMARY METRICS (EventId / Template / Joint)
Total eval examples:        300
EventId labels present:     300
Template labels present:    300
Both labels present:        300

EventId accuracy:           0.973
EventTemplate accuracy:     0.923
Joint accuracy (both):      0.903

=== EventId Classification Metrics (macro) ===
Accuracy:   0.973
Precision:  0.889
Recall:     0.901
F1-Score:   0.894


In [6]:
# Cell 6: Diagnostic-style metrics for the Spark log extraction model

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

print("\n===================")
print("DIAGNOSTIC METRICS")
print("===================")

total = len(results)

# 1) Format / extraction diagnostics
missing_eid = 0
missing_tpl = 0
missing_any = 0

for r in results:
    if r["pred_eid"] is None:
        missing_eid += 1
    if r["pred_tpl"] is None:
        missing_tpl += 1
    if (r["pred_eid"] is None) or (r["pred_tpl"] is None):
        missing_any += 1

format_ok = total - missing_any

print(f"Total examples:                         {total}")
print(f"Format-compliant outputs (both fields present): "
      f"{format_ok}/{total} ({format_ok/total:.3f})")
print(f"Missing EventId field:                  {missing_eid}/{total} ({missing_eid/total:.3f})")
print(f"Missing EventTemplate field:            {missing_tpl}/{total} ({missing_tpl/total:.3f})")
print(f"Missing at least one field:             {missing_any}/{total} ({missing_any/total:.3f})")

# 2) Exact-match joint correctness (already in previous cell, but restated here)
joint_correct = 0
joint_defined = 0
for r in results:
    if r["gold_eid"] is None or r["gold_tpl"] is None:
        continue
    if r["pred_eid"] is None or r["pred_tpl"] is None:
        joint_defined += 1  # count as defined gold but missing pred
        continue

    joint_defined += 1
    if (r["pred_eid"] == r["gold_eid"] and
        normalize_template(r["pred_tpl"]) == normalize_template(r["gold_tpl"])):
        joint_correct += 1

joint_acc = joint_correct / joint_defined if joint_defined > 0 else 0.0
print(f"\nJoint exact-match correctness (ID + template): "
      f"{joint_correct}/{joint_defined} ({joint_acc:.3f})")

# 3) EventId per-class diagnostic report (like TruthfulQA-style "diagnostic accuracy")
#    Build y_true/y_pred again in case previous cell wasn't run.

y_true_eid_local = []
y_pred_eid_local = []

for r in results:
    if r["gold_eid"] is None:
        continue
    y_true_eid_local.append(r["gold_eid"])
    y_pred_eid_local.append(r["pred_eid"] if r["pred_eid"] is not None else "__NONE__")

if len(y_true_eid_local) == 0:
    print("\nNo EventId labels available for per-class diagnostics.")
else:
    y_true_eid_local = np.array(y_true_eid_local)
    y_pred_eid_local = np.array(y_pred_eid_local)

    print("\n=== EventId Per-class Classification Report ===")
    print(classification_report(
        y_true_eid_local,
        y_pred_eid_local,
        zero_division=0,
        digits=3,
    ))

    # Optional: show confusion matrix for the most frequent EventIds
    unique, counts = np.unique(y_true_eid_local, return_counts=True)
    sorted_idx = np.argsort(counts)[::-1]
    top_k = 10  # change if you want more/less
    top_labels = unique[sorted_idx][:top_k]

    mask = np.isin(y_true_eid_local, top_labels)
    cm = confusion_matrix(
        y_true_eid_local[mask],
        y_pred_eid_local[mask],
        labels=list(top_labels) + ["__NONE__"],
    )

    print(f"\nTop-{top_k} EventIds by frequency (plus '__NONE__' for missing preds):")
    print("Labels (columns/rows):", list(top_labels) + ["__NONE__"])
    print("Confusion matrix:\n", cm)



DIAGNOSTIC METRICS
Total examples:                         300
Format-compliant outputs (both fields present): 300/300 (1.000)
Missing EventId field:                  0/300 (0.000)
Missing EventTemplate field:            0/300 (0.000)
Missing at least one field:             0/300 (0.000)

Joint exact-match correctness (ID + template): 271/300 (0.903)

=== EventId Per-class Classification Report ===
              precision    recall  f1-score   support

          E1      0.000     0.000     0.000         2
         E10      1.000     1.000     1.000         6
        E105      1.000     1.000     1.000         2
        E106      1.000     1.000     1.000         1
        E107      0.000     0.000     0.000         1
        E114      1.000     1.000     1.000         5
        E119      1.000     1.000     1.000         4
        E125      1.000     1.000     1.000         2
        E128      1.000     1.000     1.000         3
        E129      0.667     1.000     0.800         4
  

In [7]:
# Cell X: TEMPLATE-ONLY METRICS (ignore EventId)

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report,
    confusion_matrix,
)

print("\n===================")
print("TEMPLATE-ONLY METRICS")
print("===================")

total = len(results)

# 1) Basic counts: when does the model emit a template and when is it correct?
tpl_total       = 0   # how many examples have a gold template
tpl_pred_any    = 0   # how many predictions had a non-None template
tpl_exact_match = 0   # how many predictions exactly match the gold template (normalized)

gold_tpl_norm = []
pred_tpl_norm = []

for r in results:
    gold_tpl = r["gold_tpl"]
    pred_tpl = r["pred_tpl"]

    if gold_tpl is None:
        continue  # nothing to evaluate on this example

    tpl_total += 1
    gold_norm = normalize_template(gold_tpl)
    gold_tpl_norm.append(gold_norm)

    if pred_tpl is None:
        # model failed to produce a template
        pred_tpl_norm.append("__NONE__")
        continue

    tpl_pred_any += 1
    pred_norm = normalize_template(pred_tpl)
    pred_tpl_norm.append(pred_norm)

    if pred_norm == gold_norm:
        tpl_exact_match += 1

tpl_exact_acc      = tpl_exact_match / tpl_total if tpl_total > 0 else 0.0
tpl_pred_coverage  = tpl_pred_any / tpl_total    if tpl_total > 0 else 0.0

print(f"Total eval examples:                  {total}")
print(f"Examples with gold template:          {tpl_total}")
print(f"Predicted *any* template:             {tpl_pred_any}/{tpl_total} ({tpl_pred_coverage:.3f})")
print(f"Exact template matches (normalized):  {tpl_exact_match}/{tpl_total} ({tpl_exact_acc:.3f})")

# 2) Template-as-label classification metrics (macro/micro F1, etc.)
if tpl_total == 0:
    print("\nNo gold templates available; cannot compute template metrics.")
else:
    y_true_tpl = np.array(gold_tpl_norm)
    y_pred_tpl = np.array(pred_tpl_norm)

    # Treat each distinct normalized template as a class label
    acc_tpl     = accuracy_score(y_true_tpl, y_pred_tpl)
    prec_macro  = precision_score(y_true_tpl, y_pred_tpl, average="macro", zero_division=0)
    rec_macro   = recall_score(y_true_tpl,   y_pred_tpl, average="macro", zero_division=0)
    f1_macro    = f1_score(y_true_tpl,      y_pred_tpl, average="macro", zero_division=0)

    prec_micro  = precision_score(y_true_tpl, y_pred_tpl, average="micro", zero_division=0)
    rec_micro   = recall_score(y_true_tpl,   y_pred_tpl, average="micro", zero_division=0)
    f1_micro    = f1_score(y_true_tpl,      y_pred_tpl, average="micro", zero_division=0)

    print("\n=== Template Classification Metrics ===")
    print(f"Accuracy:          {acc_tpl:.3f}")
    print(f"Precision (macro): {prec_macro:.3f}")
    print(f"Recall (macro):    {rec_macro:.3f}")
    print(f"F1-Score (macro):  {f1_macro:.3f}")
    print(f"Precision (micro): {prec_micro:.3f}")
    print(f"Recall (micro):    {rec_micro:.3f}")
    print(f"F1-Score (micro):  {f1_micro:.3f}")

    # 3) Optional: per-template report + confusion matrix for top-k templates
    print("\n=== Per-template Classification Report ===")
    print(classification_report(
        y_true_tpl,
        y_pred_tpl,
        zero_division=0,
        digits=3,
    ))

    # Confusion matrix for the most frequent templates
    unique, counts = np.unique(y_true_tpl, return_counts=True)
    sorted_idx = np.argsort(counts)[::-1]
    top_k = 10  # change if you want a different number
    top_labels = unique[sorted_idx][:top_k]

    mask = np.isin(y_true_tpl, top_labels)
    cm = confusion_matrix(
        y_true_tpl[mask],
        y_pred_tpl[mask],
        labels=list(top_labels) + ["__NONE__"],
    )

    print(f"\nTop-{top_k} templates by frequency (plus '__NONE__' for missing preds):")
    print("Labels (columns/rows):", list(top_labels) + ["__NONE__"])
    print("Confusion matrix:\n", cm)



TEMPLATE-ONLY METRICS
Total eval examples:                  300
Examples with gold template:          300
Predicted *any* template:             300/300 (1.000)
Exact template matches (normalized):  277/300 (0.923)

=== Template Classification Metrics ===
Accuracy:          0.923
Precision (macro): 0.851
Recall (macro):    0.847
F1-Score (macro):  0.848
Precision (micro): 0.923
Recall (micro):    0.923
F1-Score (micro):  0.923

=== Per-template Classification Report ===
                                                                                                                                                                                                                                                                                                                                                                                                                   precision    recall  f1-score   support

                                                                                 