In [33]:
from collections import defaultdict
import json

# Lista dos arquivos .jsonl (com barra invertida corrigida para barra normal se estiver em ambiente Unix/Linux)
jsonl_files = [
    "results\samples_gsm8k_cot_62777.jsonl",
    "results\samples_gsm8k_cot_17456.jsonl",
    "results\samples_gsm8k_cot_46379.jsonl",
    "results\samples_gsm8k_cot_15136.jsonl",
]

# Agrupar tentativas por doc_id
samples_by_doc = defaultdict(list)

  "results\samples_gsm8k_cot_62777.jsonl",
  "results\samples_gsm8k_cot_17456.jsonl",
  "results\samples_gsm8k_cot_46379.jsonl",
  "results\samples_gsm8k_cot_15136.jsonl",


In [34]:
samples_by_doc

defaultdict(list, {})

In [35]:
for path in jsonl_files:
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            sample = json.loads(line)
            doc_id = sample["doc_id"]
            filter_type = sample.get("filter", "unknown")  # pegar o filtro
            is_correct = sample.get("exact_match", 0) == 1.0
            samples_by_doc[(doc_id, filter_type)].append(is_correct)

In [36]:
samples_by_doc

defaultdict(list,
            {(0, 'strict-match'): [True, True, True, True],
             (1, 'strict-match'): [True, True, True, True],
             (2, 'strict-match'): [False, True, True, False],
             (3, 'strict-match'): [True, True, True, True],
             (4, 'strict-match'): [True, True, True, True],
             (5, 'strict-match'): [True, True, True, True],
             (6, 'strict-match'): [True, True, True, True],
             (7, 'strict-match'): [False, False, False, False],
             (8, 'strict-match'): [False, False, False, False],
             (9, 'strict-match'): [True, True, True, True],
             (10, 'strict-match'): [True, True, True, True],
             (11, 'strict-match'): [True, True, True, True],
             (12, 'strict-match'): [False, False, False, False],
             (13, 'strict-match'): [False, True, True, True],
             (14, 'strict-match'): [True, True, True, True],
             (15, 'strict-match'): [True, True, True, True],
 

In [37]:
# Função para calcular pass@N
def pass_at_n(samples_by_doc, N):
    total = len(samples_by_doc)
    passed = sum(any(resps[:N]) for resps in samples_by_doc.values())
    print(passed)
    return passed / total if total > 0 else 0.0

In [38]:
# Agora calculamos pass@N para cada filtro separadamente:
results_by_filter = {}

# Primeiro identificar filtros únicos
filtros = set(f for _, f in samples_by_doc.keys())

for f in filtros:
    filtered_samples = {k: v for k, v in samples_by_doc.items() if k[1] == f}
    max_n = max(len(v) for v in filtered_samples.values())
    results = {f"pass@{n}": pass_at_n(filtered_samples, n) for n in [1, 2, 3, 4] if n <= max_n}
    results["total_problems"] = len(filtered_samples)
    results["attempts_per_problem"] = max_n
    results_by_filter[f] = results

231
243
249
252
218
232
238
242


In [39]:
results_by_filter

{'strict-match': {'pass@1': 0.77,
  'pass@2': 0.81,
  'pass@3': 0.83,
  'pass@4': 0.84,
  'total_problems': 300,
  'attempts_per_problem': 4},
 'flexible-extract': {'pass@1': 0.7266666666666667,
  'pass@2': 0.7733333333333333,
  'pass@3': 0.7933333333333333,
  'pass@4': 0.8066666666666666,
  'total_problems': 300,
  'attempts_per_problem': 4}}

In [1]:
import os
import json
import re
from collections import defaultdict

# Seus arquivos
jsonl_files = [
    "results/samples_gsm8k_cot_62777.jsonl",
    "results/samples_gsm8k_cot_17456.jsonl",
    "results/samples_gsm8k_cot_46379.jsonl",
    "results/samples_gsm8k_cot_15136.jsonl",
]

# Regex para extrair o index
def extract_index(path):
    match = re.search(r"cot_(\d+)", path)
    return match.group(1) if match else "unknown"

# Coletar samples agrupando por (index, filter, doc_id)
samples = defaultdict(lambda: defaultdict(list))  # samples[index][filter] = list of listas

for path in jsonl_files:
    index = extract_index(path)
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            sample = json.loads(line)
            doc_id = sample["doc_id"]
            filter_type = sample.get("filter", "unknown")
            is_correct = sample.get("exact_match", 0) == 1.0
            key = (index, filter_type, doc_id)
            samples[(index, filter_type)][doc_id].append(is_correct)

# Função pass@N
def pass_at_n(samples_by_doc, N):
    total = len(samples_by_doc)
    passed = sum(any(resps[:N]) for resps in samples_by_doc.values())
    return passed / total if total > 0 else 0.0

# Resultados por index e filtro
results_all = {}

for (index, filter_type), doc_dict in samples.items():
    max_n = max(len(v) for v in doc_dict.values())
    result = {
        f"pass@{n}": pass_at_n(doc_dict, n)
        for n in [1, 2, 3, 5] if n <= max_n
    }
    result["total_problems"] = len(doc_dict)
    result["attempts_per_problem"] = max_n
    results_all[(index, filter_type)] = result

# Mostrar resultados organizados
from pprint import pprint
pprint(results_all)

# Agora identificar o melhor index por filtro com base em pass@1
best_by_filter = {}

for filter_type in set(f for _, f in results_all):
    filtered = {
        index: metrics["pass@1"]
        for (index, f), metrics in results_all.items()
        if f == filter_type and "pass@1" in metrics
    }
    best = max(filtered.items(), key=lambda x: x[1], default=("none", 0.0))
    best_by_filter[filter_type] = {"best_index": best[0], "pass@1": best[1]}

print("\n🏆 Melhores índices por filtro (baseado em pass@1):")
pprint(best_by_filter)


{('15136', 'flexible-extract'): {'attempts_per_problem': 1,
                                 'pass@1': 0.7133333333333334,
                                 'total_problems': 300},
 ('15136', 'strict-match'): {'attempts_per_problem': 1,
                             'pass@1': 0.7733333333333333,
                             'total_problems': 300},
 ('17456', 'flexible-extract'): {'attempts_per_problem': 1,
                                 'pass@1': 0.6966666666666667,
                                 'total_problems': 300},
 ('17456', 'strict-match'): {'attempts_per_problem': 1,
                             'pass@1': 0.7766666666666666,
                             'total_problems': 300},
 ('46379', 'flexible-extract'): {'attempts_per_problem': 1,
                                 'pass@1': 0.7,
                                 'total_problems': 300},
 ('46379', 'strict-match'): {'attempts_per_problem': 1,
                             'pass@1': 0.7866666666666666,
                        