In [1]:
import re

# Annotation header line: "[<start>-<end>] <label>, File: <filename>[, ...]".
HEADER_RE = re.compile(
    r"^\s*\[(?P<span>\d+-\d+)\]\s*(?P<label>[^,]+),\s*File:\s*(?P<filename>[^,]+)",
    flags=re.IGNORECASE
)

def extract_ann_text(path, txt_path, encoding ="utf-8"):
    """
    Parse an annotation export into a list of annotation records.

    The file at `path` is scanned for header lines matching HEADER_RE. The
    lines between one header and the next (or EOF) form the annotated text
    of that header.

    Args:
        path: path of the annotation file.
        txt_path: directory holding the source documents named in the headers.
        encoding: text encoding used for the annotation file AND for the
            source documents (previously the documents were opened with the
            platform default encoding, which could differ from `encoding`).

    Returns:
        A list of dicts with keys "filename", "start", "end", "label",
        "text" (None when the body is empty) and "full_text". Returns an
        empty list when no header line is found.

    Raises:
        OSError: if the annotation file or a referenced document cannot be opened.
    """
    with open(path, "r", encoding=encoding) as fh:
        lines = fh.readlines()

    # Locate every header line together with its regex match.
    headers = []  # list of tuples (line_index, matchobj)
    for i, line in enumerate(lines):
        m = HEADER_RE.match(line)
        if m:
            headers.append((i, m))

    results = []
    if not headers:
        return results

    # Many annotations point at the same document; read each document once.
    full_text_by_filename = {}

    for idx, (line_idx, match) in enumerate(headers):
        next_idx = headers[idx + 1][0] if (idx + 1) < len(headers) else len(lines)

        # Body = lines after the header up to (not including) the next header.
        body_lines = [ln.rstrip("\n") for ln in lines[line_idx + 1 : next_idx]]
        body = "\n".join(body_lines).strip() or None

        # All three named groups are mandatory in HEADER_RE, so they are
        # always present on a successful match.
        start_str, end_str = match.group("span").strip().split("-")
        filename = match.group("filename").strip()
        label = match.group("label").strip()

        if filename not in full_text_by_filename:
            with open(f"{txt_path}/{filename}", 'r', encoding=encoding) as f:
                full_text_by_filename[filename] = f.read()

        results.append({
            "filename": filename,
            "start": int(start_str),
            "end": int(end_str),
            "label": label,
            "text": body,
            "full_text": full_text_by_filename[filename],
        })

    return results

In [2]:
# Annotation export holding the agreed gold annotations.
gold = "../annotated_data/first_ten_agreed.txt"

# Directory containing the documents referenced by the annotation headers.
txt_path = "../data/to_annotate"
# One record per annotation header found in `gold` (see extract_ann_text).
goldset = extract_ann_text(gold, txt_path)

In [3]:
# Sanity check: number of annotation records parsed from the gold file.
len(goldset)

64

In [4]:
from difflib import SequenceMatcher
import unicodedata

def _normalize_with_map(s: str):
    """
    Normalize in a way that often matches annotation tools:
    - Unicode normalize (NFC)
    - Convert CRLF/CR to LF
    - Replace NBSP with space
    - Remove common zero-width chars
    Returns: (normalized_string, norm_index -> original_index map)
    """
    s = unicodedata.normalize("NFC", s)

    norm_chars = []
    norm_to_orig = []

    i = 0
    while i < len(s):
        ch = s[i]

        # normalize newlines
        if ch == "\r":
            # treat \r\n or \r as a single \n
            if i + 1 < len(s) and s[i + 1] == "\n":
                # map the normalized '\n' to the start of the pair
                norm_chars.append("\n")
                norm_to_orig.append(i)
                i += 2
                continue
            else:
                norm_chars.append("\n")
                norm_to_orig.append(i)
                i += 1
                continue

        # normalize NBSP
        if ch == "\u00A0":
            ch = " "

        # drop zero-width characters (common culprits)
        if ch in ("\u200b", "\u200c", "\u200d", "\ufeff"):
            i += 1
            continue

        norm_chars.append(ch)
        norm_to_orig.append(i)
        i += 1

    return "".join(norm_chars), norm_to_orig

def fuzzy_span_start_end_mapped(
    span_text: str,
    full_text: str,
    threshold: float = 0.80,
    window_slack: int = 20,
    output_1_based_inclusive: bool = False,
):
    """
    Locate `span_text` inside `full_text` by approximate matching.

    Both strings are normalized with _normalize_with_map(). Every window of
    the normalized full text whose length is within `window_slack` characters
    of the normalized span length is scored with difflib's
    SequenceMatcher.ratio(); the first window reaching the highest score wins.

    Args:
        span_text: text to look for.
        full_text: text to search in.
        threshold: minimum ratio the best window must reach.
        window_slack: how many characters the window may be shorter/longer
            than the span.
        output_1_based_inclusive: return (start, end) as 1-based inclusive
            instead of 0-based end-exclusive.

    Returns:
        (start, end) in coordinates of `full_text`, or (-1, -1) when either
        input is empty (also after normalization) or no window reaches
        `threshold`.
    """
    if not span_text or not full_text:
        return -1, -1

    norm_full, norm_map = _normalize_with_map(full_text)
    norm_span, _ = _normalize_with_map(span_text)

    # Inputs made only of dropped characters normalize to "" and cannot match.
    if not norm_span or not norm_full:
        return -1, -1

    target_len = len(norm_span)
    best_score = 0.0
    best_start = -1
    best_end = -1  # exclusive in normalized coordinates

    min_len = max(1, target_len - window_slack)
    max_len = min(len(norm_full), target_len + window_slack)

    for win_len in range(min_len, max_len + 1):
        for i in range(0, len(norm_full) - win_len + 1):
            candidate = norm_full[i:i + win_len]
            # autojunk=False: the default heuristic treats frequent items of
            # a 200+ item second sequence as "popular" and stops anchoring
            # matches on them, which deflates character-level ratios for
            # long spans.
            matcher = SequenceMatcher(None, norm_span, candidate, autojunk=False)
            # Both quick ratios are upper bounds on ratio(), so a window that
            # cannot beat the current best is skipped without the costly
            # exact computation; the winner is unchanged.
            if matcher.real_quick_ratio() <= best_score:
                continue
            if matcher.quick_ratio() <= best_score:
                continue
            score = matcher.ratio()
            if score > best_score:
                best_score = score
                best_start = i
                best_end = i + win_len

    if best_score < threshold or best_start < 0:
        return -1, -1

    # Map normalized [best_start, best_end) back to ORIGINAL full_text indices.
    # start maps directly; end maps via last char + 1 (to keep exclusive end).
    orig_start = norm_map[best_start]
    orig_end = norm_map[best_end - 1] + 1

    if output_1_based_inclusive:
        # [orig_start, orig_end) 0-based exclusive == [orig_start + 1, orig_end]
        # 1-based inclusive.
        return orig_start + 1, orig_end

    return orig_start, orig_end



def exact_span_start_end_mapped(
    span_text: str,
    full_text: str,
    *,
    output_1_based_inclusive: bool = False,
):
    """
    Exact substring match in normalized space, then map back to original indices.

    Uses the same _normalize_with_map() convention as fuzzy_span_start_end_mapped,
    so indices are directly comparable.

    Args:
        span_text: text to look for.
        full_text: text to search in.
        output_1_based_inclusive: return (start, end) as 1-based inclusive
            instead of 0-based end-exclusive.

    Returns:
        (start, end) of the first occurrence, or (-1, -1) when either input
        is empty, the span normalizes to an empty string, or it is not found.
    """
    if not span_text or not full_text:
        return -1, -1

    norm_full, norm_map = _normalize_with_map(full_text)
    norm_span, _ = _normalize_with_map(span_text)

    # A span made only of dropped (zero-width) characters normalizes to "".
    # str.find("") would report a match at 0 and the end lookup below would
    # read norm_map[-1], yielding a bogus range (or an IndexError).
    if not norm_span:
        return -1, -1

    idx = norm_full.find(norm_span)
    if idx == -1:
        return -1, -1

    norm_start = idx
    norm_end = idx + len(norm_span)  # exclusive in normalized space

    # Map normalized [norm_start, norm_end) back to ORIGINAL full_text indices.
    orig_start = norm_map[norm_start]
    orig_end = norm_map[norm_end - 1] + 1  # exclusive end

    if output_1_based_inclusive:
        # convert [orig_start, orig_end) -> (1-based, inclusive)
        return orig_start + 1, orig_end

    return orig_start, orig_end

def add_mapped_indices_to_records(
    records,
    *,
    use_fuzzy_fallback: bool = True,
    fuzzy_threshold: float = 0.80,
    window_slack: int = 20,
    output_1_based_inclusive: bool = True,
):
    """
    Recompute 'start'/'end' for every record from its 'text' and 'full_text'.

    An exact (normalized) match is attempted first. When it fails and
    `use_fuzzy_fallback` is set, the fuzzy matcher is tried with
    `fuzzy_threshold` and `window_slack`. Unlocated spans get (-1, -1).

    The records are updated in place; the same list is returned.
    """
    for record in records:
        needle = record["text"]
        haystack = record["full_text"]

        # First choice: exact match in normalized space.
        bounds = exact_span_start_end_mapped(
            needle,
            haystack,
            output_1_based_inclusive=output_1_based_inclusive,
        )

        # Second choice (optional): approximate match.
        if use_fuzzy_fallback and bounds[0] == -1:
            bounds = fuzzy_span_start_end_mapped(
                needle,
                haystack,
                threshold=fuzzy_threshold,
                window_slack=window_slack,
                output_1_based_inclusive=output_1_based_inclusive,
            )

        record["start"], record["end"] = bounds

    return records


In [5]:
# Overwrites each record's 'start'/'end' in place with offsets located in its
# 'full_text'. With the defaults these are 1-based inclusive, and (-1, -1)
# marks a span that could not be located.
goldset = add_mapped_indices_to_records(goldset)

In [6]:
# Drop everything from the first occurrence of "References" onwards, so only
# the text before it is kept (records are updated in place).
for record in goldset:
    record['full_text'] = record['full_text'].partition("References")[0]

# Group the records per paper: goldset_sorted[k] holds the annotations of
# paper_{k+1}.txt, in their original order.
goldset_sorted = [
    [record for record in goldset if record['filename'] == f"paper_{paper_no}.txt"]
    for paper_no in range(1, 11)
]

In [7]:
import json 

# Persist the per-paper gold annotations so later sessions can reuse them.
folder = "../annotated_data"
filename = "goldset_sorted.json"
with open(f"{folder}/{filename}", 'w') as f:
    json.dump(goldset_sorted, f, indent=4)

### Evaluation using exact-span F1 and token-level F1

In [8]:
# Run this block if you're resuming work after restarting your kernel.
# It reloads everything the evaluation cells below need: both prediction
# files and the per-paper gold annotations saved earlier as goldset_sorted.json.
import json

fewshot_path = "../experiments/final/fewshot_preds.json"
zeroshot_path = "../experiments/final/zeroshot_preds.json"
goldset_sorted_path = "../annotated_data/goldset_sorted.json"

with open(fewshot_path, 'r', encoding="utf-8") as f:
    fewshot = json.load(f)

with open(zeroshot_path, 'r', encoding="utf-8") as f:
    zeroshot = json.load(f)

# Without this, goldset_sorted only exists if the annotation-parsing cells
# above were run in the same kernel session.
with open(goldset_sorted_path, 'r', encoding="utf-8") as f:
    goldset_sorted = json.load(f)


In [9]:
from collections import Counter
from typing import List, Tuple, Dict, Any

Span = Tuple[int, int]  # (start_char, end_char_inclusive)

def span_f1_exact(gold_spans, pred_spans, dedupe=True):
    """
    Exact-match precision / recall / F1 between two collections of spans.

    A span counts only if both bounds are ints and end >= start; anything
    else is dropped from both sides before scoring. A predicted span is a
    true positive only when an identical (start, end) pair exists in gold.

    Args:
        gold_spans: iterable of (start, end) reference spans.
        pred_spans: iterable of (start, end) predicted spans.
        dedupe: when True, duplicates are collapsed (set semantics); when
            False, multiplicities are respected (multiset semantics).

    Returns:
        dict with "precision", "recall", "f1", "tp", "fp", "fn" and the span
        lists "matched_spans", "unmatched_pred_spans", "unmatched_gold_spans".
    """

    def _well_formed(span: Span):
        lo, hi = span
        return isinstance(lo, int) and isinstance(hi, int) and hi >= lo

    gold_ok = list(filter(_well_formed, gold_spans))
    pred_ok = list(filter(_well_formed, pred_spans))

    if dedupe:
        gold_unique, pred_unique = set(gold_ok), set(pred_ok)

        matched = sorted(gold_unique & pred_unique)
        unmatched_pred = sorted(pred_unique - gold_unique)
        unmatched_gold = sorted(gold_unique - pred_unique)
    else:
        gold_counts, pred_counts = Counter(gold_ok), Counter(pred_ok)

        # A span shared by both sides matches as often as its rarer side has it.
        matched = []
        for shared in (gold_counts.keys() & pred_counts.keys()):
            matched.extend([shared] * min(gold_counts[shared], pred_counts[shared]))

        # Counter subtraction keeps only positive surpluses; elements()
        # repeats each leftover span by its remaining count.
        unmatched_pred = list((pred_counts - gold_counts).elements())
        unmatched_gold = list((gold_counts - pred_counts).elements())

    tp, fp, fn = len(matched), len(unmatched_pred), len(unmatched_gold)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "matched_spans": matched,
        "unmatched_pred_spans": unmatched_pred,
        "unmatched_gold_spans": unmatched_gold,
    }

In [10]:
def flatten_predictions_by_paper(predictions, num_papers=10):
    """
    Turn the per-file prediction mapping into a per-paper list of span dicts.

    Args:
        predictions: mapping keyed by filename ("paper_<k>.txt"); each value
            is a mapping whose optional "all_spans" entry lists span dicts
            with keys "span_text", "start", "end" and "gpt_label".
        num_papers: number of papers; files paper_1.txt .. paper_<num_papers>.txt
            are looked up.

    Returns:
        list of length num_papers; element k lists the spans of
        paper_{k+1}.txt as dicts with keys "filename", "span", "start",
        "end", "label". Spans whose start is -1 are skipped, and papers
        absent from `predictions` yield an empty list.

    Raises:
        TypeError: if `predictions` is not a mapping. Passing the already
            flattened list back in (e.g. by re-running the calling cell)
            used to return all-empty results silently.
    """
    from collections.abc import Mapping

    if not isinstance(predictions, Mapping):
        raise TypeError(
            "predictions must be a mapping keyed by filename, got "
            f"{type(predictions).__name__}; was it already flattened?"
        )

    papers = [[] for _ in range(num_papers)]

    for i in range(num_papers):
        fn = f"paper_{i+1}.txt"
        if fn not in predictions:
            continue

        paper_obj = predictions[fn]
        for span in paper_obj.get("all_spans", []):
            # start == -1 marks a span that could not be located in the text
            if span.get("start") == -1:
                continue

            papers[i].append({
                "filename": fn,
                "span": span.get("span_text"),
                "start": span.get("start"),
                "end": span.get("end"),
                "label": span.get("gpt_label"),
            })

    return papers

# The names are rebound from the raw {filename: ...} dicts to per-paper lists.
# The isinstance guards make re-running this cell a no-op instead of feeding
# the already flattened lists back into the function.
if isinstance(fewshot, dict):
    fewshot = flatten_predictions_by_paper(fewshot, num_papers=10)
if isinstance(zeroshot, dict):
    zeroshot = flatten_predictions_by_paper(zeroshot, num_papers=10)


In [11]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']


def _collect_spans_by_category(papers, categories, num_papers=10):
    """
    Group the (start, end) pairs of each paper by annotation category.

    Args:
        papers: list indexed by paper; each entry is a list of dicts with
            keys 'label', 'start' and 'end'.
        categories: category names; labels are compared case-insensitively.
        num_papers: how many leading entries of `papers` to process.

    Returns:
        list of length num_papers; each element maps every category name to
        the list of (start, end) tuples carrying that label. Records whose
        label is missing, not a string, or not a known category are skipped.
    """
    canonical = {category.lower(): category for category in categories}
    grouped_per_paper = []
    for i in range(num_papers):
        grouped = {"Coherence": [],
                   "Format": [],
                   "Unsupported claim": [],
                   "Lacks synthesis": []}
        for record in papers[i]:
            label = record.get('label')
            # Predictions may carry no label at all (None); skip those
            # instead of failing on .lower().
            if not isinstance(label, str):
                continue
            category = canonical.get(label.lower())
            if category is not None:
                grouped[category].append((record['start'], record['end']))
        grouped_per_paper.append(grouped)
    return grouped_per_paper


goldset_spans = _collect_spans_by_category(goldset_sorted, categories)
predset_fewshot_spans = _collect_spans_by_category(fewshot, categories)
predset_zeroshot_spans = _collect_spans_by_category(zeroshot, categories)

print("Gold set: ", goldset_spans)
print("Few shot set: ", predset_fewshot_spans)
print("Zero shot set: ", predset_zeroshot_spans)

Gold set:  [{'Coherence': [(742, 1101), (1508, 1713)], 'Format': [(1705, 1713)], 'Unsupported claim': [(468, 557), (993, 997), (1091, 1100), (1373, 1506)], 'Lacks synthesis': []}, {'Coherence': [(2037, 2506), (3462, 4014)], 'Format': [(1755, 1777), (2576, 2592), (4882, 4896)], 'Unsupported claim': [], 'Lacks synthesis': [(868, 1363), (1366, 2035)]}, {'Coherence': [], 'Format': [], 'Unsupported claim': [(716, 1194), (3148, 3155), (3191, 3232), (3242, 3248), (3451, 3458)], 'Lacks synthesis': []}, {'Coherence': [], 'Format': [], 'Unsupported claim': [(31, 180), (182, 386), (596, 726), (729, 1010), (1203, 1396), (1398, 1540), (1542, 1589), (4489, 4651)], 'Lacks synthesis': []}, {'Coherence': [], 'Format': [], 'Unsupported claim': [(760, 797), (803, 833), (2362, 2376)], 'Lacks synthesis': []}, {'Coherence': [(2028, 2462)], 'Format': [(369, 388), (2051, 2071)], 'Unsupported claim': [(62, 224), (1643, 1769)], 'Lacks synthesis': [(1791, 2026)]}, {'Coherence': [(394, 735), (1225, 1857)], 'Forma

In [12]:
# Spot check: gold vs. few-shot spans of the first paper, grouped by category.
print(goldset_spans[0])
print(predset_fewshot_spans[0])

{'Coherence': [(742, 1101), (1508, 1713)], 'Format': [(1705, 1713)], 'Unsupported claim': [(468, 557), (993, 997), (1091, 1100), (1373, 1506)], 'Lacks synthesis': []}
{'Coherence': [(731, 1083)], 'Format': [(1680, 1687)], 'Unsupported claim': [(462, 551), (977, 982), (1073, 1083), (1486, 1599), (1680, 1687), (1776, 1858)], 'Lacks synthesis': [(553, 1228)]}


### Micro averaged scores per category:
- For each category, gather all spans across papers and compute one single F1/precision/recall using total TP/FP/FN

In [40]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']

fewshot_results_per_category = {}

for category in categories:
    # Spans are character offsets inside ONE paper, so they are matched paper
    # by paper and only the counts are pooled. Pooling the raw spans of all
    # papers first would let a gold span of one paper match a prediction with
    # the same offsets in another paper.
    pooled = {
        "tp": 0,
        "fp": 0,
        "fn": 0,
        "matched_spans": [],
        "unmatched_pred_spans": [],
        "unmatched_gold_spans": [],
    }

    for i in range(0, 10):
        res = span_f1_exact(
            goldset_spans[i].get(category, []),
            predset_fewshot_spans[i].get(category, []),
            dedupe=False,
        )
        for key in pooled:
            pooled[key] = pooled[key] + res[key]

    tp, fp, fn = pooled["tp"], pooled["fp"], pooled["fn"]
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

    # Same keys as a single span_f1_exact() result.
    fewshot_results_per_category[category] = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        **pooled,
    }

In [41]:
# save results to a json file 
import json

folder = "../experiments"
filename = "fewshot_micro_avg_first_ten.json"

# Store the full per-category results (scores, counts and span lists).
with open(f"{folder}/{filename}", 'w') as f:
    json.dump(fewshot_results_per_category, f, indent=4)

# Human-readable summary, one paragraph per category.
for category, scores in fewshot_results_per_category.items():
    print(f"Scores for {category}: ")
    for caption, key in (
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Micro F1", "f1"),
        ("True positives", "tp"),
        ("False positives", "fp"),
    ):
        print(f"{caption}: {scores[key]}")
    print(f"False negatives: {scores['fn']} \n")

Scores for Unsupported claim: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 51
False negatives: 41 

Scores for Format: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 9
False negatives: 10 

Scores for Coherence: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 10
False negatives: 7 

Scores for Lacks synthesis: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 12
False negatives: 6 



In [13]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']

zeroshot_results_per_category = {}

for category in categories:
    # Match spans within each paper and pool only the counts: offsets are
    # local to a paper, so spans from different papers must never be
    # compared with each other.
    pooled = {
        "tp": 0,
        "fp": 0,
        "fn": 0,
        "matched_spans": [],
        "unmatched_pred_spans": [],
        "unmatched_gold_spans": [],
    }

    for i in range(0, 10):
        res = span_f1_exact(
            goldset_spans[i].get(category, []),
            predset_zeroshot_spans[i].get(category, []),
            dedupe=False,
        )
        for key in pooled:
            pooled[key] = pooled[key] + res[key]

    tp, fp, fn = pooled["tp"], pooled["fp"], pooled["fn"]
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

    # Same keys as a single span_f1_exact() result.
    zeroshot_results_per_category[category] = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        **pooled,
    }

In [14]:
# save results to a json file 
import json 

folder = "../experiments"
filename = "zeroshot_micro_avg_first_ten.json"

# Store the full per-category results (scores, counts and span lists).
with open(f"{folder}/{filename}", 'w') as f:
    json.dump(zeroshot_results_per_category, f, indent=4)

# Human-readable summary, one paragraph per category.
for category, scores in zeroshot_results_per_category.items():
    print(f"Scores for {category}: ")
    for caption, key in (
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Micro F1", "f1"),
        ("True positives", "tp"),
        ("False positives", "fp"),
    ):
        print(f"{caption}: {scores[key]}")
    print(f"False negatives: {scores['fn']} \n")

Scores for Unsupported claim: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 52
False negatives: 41 

Scores for Format: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 7
False negatives: 10 

Scores for Coherence: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 10
False negatives: 7 

Scores for Lacks synthesis: 
Precision: 0.0
Recall: 0.0
Micro F1: 0.0
True positives: 0
False positives: 9
False negatives: 6 



### Macro average scores per category:

In [15]:
from typing import Dict, List, Tuple, Any

Span = Tuple[int, int]

def macro_average_span_f1(
    goldset_spans, predset_spans, categories, dedupe = True, unit = "category",  # "paper_category" | "paper" | "category"
    skip_empty_gold = True):
    """
    Macro-average exact-match span F1 over a chosen scoring unit.

    Args:
        goldset_spans: list indexed by paper; each entry maps a category name
            to a list of (start, end) gold spans.
        predset_spans: same structure for the predictions.
        categories: category names to score.
        dedupe: forwarded to span_f1_exact (set vs. multiset matching).
        unit: what one F1 value is computed for before averaging:
            - "paper_category": every (paper, category) pair;
            - "paper": every paper, with the spans of all categories merged
              (labels are ignored when matching);
            - "category": every category, with counts summed over papers.
        skip_empty_gold: leave out units that have no gold span at all
            (tp == 0 and fn == 0).

    Returns:
        dict with "macro_f1", "num_units_scored", "unit", "skip_empty_gold"
        and "details" (one span_f1_exact-style dict per scored unit).

    Raises:
        ValueError: if `unit` is not one of the three supported values.

    Note:
        For unit="category" spans are matched within each paper and only the
        counts are pooled. Pooling raw offsets of different papers (as done
        before) allowed cross-paper matches and, with dedupe=True, merged
        identical offsets coming from different papers.
    """
    if unit not in ("paper_category", "paper", "category"):
        raise ValueError("unit must be one of: 'paper_category', 'paper', 'category'")

    f1s = []
    details = []

    # Only papers present on both sides can be scored.
    n = min(len(goldset_spans), len(predset_spans))

    def _record(res, **unit_keys):
        """Store one unit's result unless it must be skipped for empty gold."""
        if skip_empty_gold and res["tp"] == 0 and res["fn"] == 0:
            return
        f1s.append(res["f1"])
        details.append({**unit_keys, **res})

    if unit == "paper_category":
        for i in range(n):
            for cat in categories:
                gold = goldset_spans[i].get(cat, [])
                pred = predset_spans[i].get(cat, [])
                _record(span_f1_exact(gold, pred, dedupe=dedupe), paper=i + 1, category=cat)

    elif unit == "paper":
        for i in range(n):
            gold_all, pred_all = [], []
            for cat in categories:
                gold_all.extend(goldset_spans[i].get(cat, []))
                pred_all.extend(predset_spans[i].get(cat, []))
            _record(span_f1_exact(gold_all, pred_all, dedupe=dedupe), paper=i + 1)

    else:  # unit == "category"
        for cat in categories:
            pooled = {
                "tp": 0,
                "fp": 0,
                "fn": 0,
                "matched_spans": [],
                "unmatched_pred_spans": [],
                "unmatched_gold_spans": [],
            }
            for i in range(n):
                res_i = span_f1_exact(
                    goldset_spans[i].get(cat, []),
                    predset_spans[i].get(cat, []),
                    dedupe=dedupe,
                )
                for key in pooled:
                    pooled[key] = pooled[key] + res_i[key]

            tp, fp, fn = pooled["tp"], pooled["fp"], pooled["fn"]
            precision = tp / (tp + fp) if (tp + fp) else 0.0
            recall = tp / (tp + fn) if (tp + fn) else 0.0
            f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

            _record({"precision": precision, "recall": recall, "f1": f1, **pooled}, category=cat)

    macro_f1 = sum(f1s) / len(f1s) if f1s else 0.0

    return {
        "macro_f1": macro_f1,
        "num_units_scored": len(f1s),
        "unit": unit,
        "skip_empty_gold": skip_empty_gold,
        "details": details,  # remove if too large
    }


In [16]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']

# Macro F1 of the few-shot predictions: one exact-match F1 per category
# (unit="category"), averaged over the categories that have gold spans.
out = macro_average_span_f1(
    goldset_spans,
    predset_fewshot_spans,
    categories,
    dedupe=True,  # identical (start, end) pairs count once
    unit="category",
    skip_empty_gold=True,  # categories without any gold span are left out
)

# Prints the macro F1 followed by the number of units that were averaged.
print(out["macro_f1"], out["num_units_scored"])


0.0 4


### Token overlap 

In [17]:
import re
import json
from typing import List, Tuple, Dict, Any

Span = Tuple[int, int]  # (start_char, end_char) end-exclusive by default
_WORD_RE = re.compile(r"\S+")

def whitespace_tokenize_with_offsets(text: str):
    """
    Split `text` on whitespace.

    Returns a list of (token, start, end) tuples where [start, end) is the
    token's character range in `text`.
    """
    tokens = []
    for match in _WORD_RE.finditer(text):
        tokens.append((match.group(0), match.start(), match.end()))
    return tokens

def spans_to_token_indices(
    text: str,
    spans: List[Span],
    token_offsets=None,
    inclusive_end: bool = False,
):
    """
    Return the set of token indices touched by at least one span.

    Args:
        text: the document; only tokenized when `token_offsets` is None.
        spans: (start, end) character spans.
        token_offsets: optional precomputed (token, start, end) tuples.
        inclusive_end: treat each span's end as inclusive (it is widened by
            one character) instead of exclusive.

    A token counts as covered as soon as its character range overlaps a span
    by at least one character.
    """
    if token_offsets is None:
        token_offsets = whitespace_tokenize_with_offsets(text)

    # An inclusive end is turned into an exclusive one by adding 1.
    end_shift = 1 if inclusive_end else 0

    return {
        index
        for index, (_tok, tstart, tend) in enumerate(token_offsets)
        if any(tend > s and tstart < e + end_shift for s, e in spans)
    }

def token_overlap_metrics_aggregate(
    texts: List[str],
    gold_spans_list: List[List[Span]],
    pred_spans_list: List[List[Span]],
    *,
    method,
    filenames: List[str],
    tokenize_with_offsets=None,
    inclusive_end: bool = False,
    out_prefix: str = "token_overlap",
    out_dir: str = "../experiments/to_examine",
) -> Dict[str, Any]:
    """
    Micro-averaged token overlap metrics across all texts + dumps TP/FP/FN token instances.

    Every text is tokenized; a token is "gold" / "predicted" when it overlaps
    at least one gold / predicted span. TP, FP and FN tokens are counted per
    document and summed before precision, recall and F1 are computed.

    Args:
        texts: one document per entry.
        gold_spans_list: gold (start, end) spans per document.
        pred_spans_list: predicted (start, end) spans per document.
        method: name of the sub-directory of `out_dir` receiving the dumps.
        filenames: document names stored in the dumped token records.
        tokenize_with_offsets: optional callable text -> [(token, start, end)];
            whitespace tokenization is used when None.
        inclusive_end: whether span ends are inclusive.
        out_prefix: file name prefix of the three dumps.
        out_dir: base output directory (defaults to the previously
            hard-coded location).

    Produces, inside {out_dir}/{method}/ (created if missing):
      {out_prefix}_tp.json
      {out_prefix}_fp.json
      {out_prefix}_fn.json

    Returns:
        dict with "tp", "fp", "fn", "precision", "recall", "f1" and the
        three dump file names "tp_file", "fp_file", "fn_file" (names only,
        without the directory).

    Raises:
        ValueError: if filenames, gold_spans_list or pred_spans_list differ
            in length from texts (zip() would otherwise drop the extra
            documents silently).
    """
    import os

    for arg_name, arg_value in (
        ("filenames", filenames),
        ("gold_spans_list", gold_spans_list),
        ("pred_spans_list", pred_spans_list),
    ):
        if len(arg_value) != len(texts):
            raise ValueError(f"{arg_name} must be same length as texts: {len(arg_value)} vs {len(texts)}")

    def _token_records(doc_id, fn, token_offsets, idxs):
        """Describe the tokens at `idxs` (ascending) as JSON-ready dicts."""
        records = []
        for i in sorted(idxs):
            tok, s, e = token_offsets[i]
            records.append({
                "doc_id": doc_id,
                "filename": fn,
                "token_index": i,
                "token": tok,
                "char_start": s,
                "char_end": e,
            })
        return records

    dumped = {"tp": [], "fp": [], "fn": []}
    totals = {"tp": 0, "fp": 0, "fn": 0}

    for doc_id, (fn, text, gold_spans, pred_spans) in enumerate(
        zip(filenames, texts, gold_spans_list, pred_spans_list)
    ):
        token_offsets = (
            whitespace_tokenize_with_offsets(text)
            if tokenize_with_offsets is None
            else tokenize_with_offsets(text)
        )

        gold_idxs = spans_to_token_indices(text, gold_spans, token_offsets, inclusive_end)
        pred_idxs = spans_to_token_indices(text, pred_spans, token_offsets, inclusive_end)

        per_kind = {
            "tp": gold_idxs & pred_idxs,
            "fp": pred_idxs - gold_idxs,
            "fn": gold_idxs - pred_idxs,
        }
        for kind, idxs in per_kind.items():
            totals[kind] += len(idxs)
            dumped[kind].extend(_token_records(doc_id, fn, token_offsets, idxs))

    total_tp, total_fp, total_fn = totals["tp"], totals["fp"], totals["fn"]

    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 0.0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

    # Dump JSON files; the target directory may not exist on a fresh checkout.
    target_dir = f"{out_dir}/{method}"
    os.makedirs(target_dir, exist_ok=True)
    for kind in ("tp", "fp", "fn"):
        with open(f"{target_dir}/{out_prefix}_{kind}.json", "w", encoding="utf-8") as f:
            json.dump(dumped[kind], f, indent=2, ensure_ascii=False)

    return {
        "tp": total_tp,
        "fp": total_fp,
        "fn": total_fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp_file": f"{out_prefix}_tp.json",
        "fp_file": f"{out_prefix}_fp.json",
        "fn_file": f"{out_prefix}_fn.json",
    }

In [18]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']

# NOTE(review): the gold spans come from add_mapped_indices_to_records(), whose
# default is 1-based inclusive output, while spans_to_token_indices() compares
# span starts directly with 0-based token offsets. Confirm which convention
# the predicted spans use and whether a -1 shift is needed here.
for category in categories:
    # The document text is taken from the first gold annotation of each
    # paper, so every paper needs at least one gold record.
    all_texts = [goldset_sorted[i][0]['full_text'] for i in range(0, 10)]
    all_gold_spans = [goldset_spans[i].get(category, []) for i in range(0, 10)]
    all_pred_spans = [predset_fewshot_spans[i].get(category, []) for i in range(0, 10)]

    token_metrics = token_overlap_metrics_aggregate(
        all_texts,
        all_gold_spans,
        all_pred_spans,
        method='fewshot',
        filenames=[f"paper_{i+1}.txt" for i in range(10)],
        tokenize_with_offsets=None,
        inclusive_end=True,
        out_prefix=f"fewshot_{category.replace(' ', '_').lower()}"
    )

    print(f"Token-level scores for {category}: ")
    for caption, key in (("Precision", "precision"), ("Recall", "recall"), ("Micro F1", "f1")):
        print(f"{caption}: {round(token_metrics[key], 4)}")
    print(token_metrics["tp_file"], token_metrics["fp_file"], token_metrics["fn_file"])
    

Token-level scores for Unsupported claim: 
Precision: 0.4681
Recall: 0.6661
Micro F1: 0.5498
fewshot_unsupported_claim_tp.json fewshot_unsupported_claim_fp.json fewshot_unsupported_claim_fn.json
Token-level scores for Format: 
Precision: 0.2667
Recall: 0.1951
Micro F1: 0.2254
fewshot_format_tp.json fewshot_format_fp.json fewshot_format_fn.json
Token-level scores for Coherence: 
Precision: 0.2745
Recall: 0.4989
Micro F1: 0.3542
fewshot_coherence_tp.json fewshot_coherence_fp.json fewshot_coherence_fn.json
Token-level scores for Lacks synthesis: 
Precision: 0.557
Recall: 0.7924
Micro F1: 0.6542
fewshot_lacks_synthesis_tp.json fewshot_lacks_synthesis_fp.json fewshot_lacks_synthesis_fn.json


In [50]:
categories = ['Unsupported claim', 'Format', 'Coherence', 'Lacks synthesis']

# Token-level scores of the zero-shot predictions, one category at a time.
for category in categories:
    all_texts, all_gold_spans, all_pred_spans = [], [], []

    for paper_idx in range(0, 10):
        # full_text is read from the paper's first gold annotation
        all_texts.append(goldset_sorted[paper_idx][0]['full_text'])
        all_gold_spans.append(goldset_spans[paper_idx].get(category, []))
        all_pred_spans.append(predset_zeroshot_spans[paper_idx].get(category, []))

    token_metrics = token_overlap_metrics_aggregate(
        all_texts,
        all_gold_spans,
        all_pred_spans,
        method='zeroshot',
        filenames=[f"paper_{i+1}.txt" for i in range(10)],
        tokenize_with_offsets=None,
        inclusive_end=True,
        out_prefix=f"zeroshot_{category.replace(' ', '_').lower()}"
    )

    precision = round(token_metrics['precision'], 4)
    recall = round(token_metrics['recall'], 4)
    micro_f1 = round(token_metrics['f1'], 4)

    print(f"Token-level scores for {category}: ")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Micro F1: {micro_f1}")
    print(token_metrics["tp_file"], token_metrics["fp_file"], token_metrics["fn_file"])

Token-level scores for Unsupported claim: 
Precision: 0.4451
Recall: 0.624
Micro F1: 0.5196
zeroshot_unsupported_claim_tp.json zeroshot_unsupported_claim_fp.json zeroshot_unsupported_claim_fn.json
Token-level scores for Format: 
Precision: 0.3
Recall: 0.2195
Micro F1: 0.2535
zeroshot_format_tp.json zeroshot_format_fp.json zeroshot_format_fn.json
Token-level scores for Coherence: 
Precision: 0.1505
Recall: 0.3045
Micro F1: 0.2014
zeroshot_coherence_tp.json zeroshot_coherence_fp.json zeroshot_coherence_fn.json
Token-level scores for Lacks synthesis: 
Precision: 0.5027
Recall: 0.5727
Micro F1: 0.5354
zeroshot_lacks_synthesis_tp.json zeroshot_lacks_synthesis_fp.json zeroshot_lacks_synthesis_fn.json
