In [None]:
from torchmetrics.text import CharErrorRate
import json

# Module-level CharErrorRate instance; evaluate_cer calls it once per image.
cer_metric = CharErrorRate()

def evaluate_cer(gt_file, pred_file):
    """Print the per-image and mean character error rate (CER) of predictions.

    Args:
        gt_file: Path to a JSON file holding a list of items. Each item has an
            'image' key and a 'ground_truth' list of annotations, each of
            which carries a 'text' field.
        pred_file: Path to a JSON Lines file. Each record has an 'image' key
            and a 'conversations' list; the 'value' of its last entry is
            taken as the predicted text.

    Returns:
        The mean of the per-image CER values over the images that appear in
        both files, or 0 when no image could be scored.
    """
    with open(gt_file, 'r', encoding='utf-8') as f:
        gt_data = json.load(f)
    with open(pred_file, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. an extra empty line at the end of the file);
        # json.loads would raise on them.
        pred_data = [json.loads(line) for line in f if line.strip()]

    # 1) Ground truth: join all annotation texts of an image with spaces.
    gt_dict = {
        item['image']: ' '.join(ann['text'] for ann in item['ground_truth'])
        for item in gt_data
    }

    # 2) Predictions: the last conversation turn holds the predicted text.
    pred_dict = {
        item['image']: item['conversations'][-1]['value'].strip()
        for item in pred_data
    }

    # 3) CER per image, for images that have both a reference and a prediction.
    total_cer = 0
    count = 0
    for image, gt_text in gt_dict.items():
        if image not in pred_dict:
            print(f"⚠️ {image}에 대한 예측 없음")
            continue
        cer = cer_metric([pred_dict[image]], [gt_text]).item()
        print(f"{image}: CER={cer:.4f}")
        total_cer += cer
        count += 1

    # Unweighted mean of the per-image scores; 0 when nothing was scored.
    avg_cer = total_cer / count if count > 0 else 0
    print(f"📊 평균 CER: {avg_cer:.4f}")
    return avg_cer

# 사용 예시
# evaluate_cer("benchmark_gt.json", "predictions.jsonl")


In [2]:
pip install torch torchmetrics

Collecting torch
  Downloading torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting filelock (from torch)
  Using cached filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1

In [3]:
import torch
from torchmetrics.text import CharErrorRate, WordErrorRate

# Initialize the CER and WER metric objects used by calculate_cer_wer
calc_cer = CharErrorRate()
calc_wer = WordErrorRate()

def calculate_cer_wer(gt_text, pred_text):
    """
    Score one prediction against its reference with the module-level
    CharErrorRate and WordErrorRate objects.

    Args:
        gt_text: Reference (ground-truth) string.
        pred_text: Predicted string to be scored.

    Returns:
        A (cer, wer) tuple of Python numbers obtained via ``.item()``.
    """
    # The metric objects are called with lists, so wrap the single strings.
    predictions, references = [pred_text], [gt_text]

    char_score = calc_cer(predictions, references)
    word_score = calc_wer(predictions, references)

    return char_score.item(), word_score.item()

# Example test: a reference sentence and a prediction that differs from it
# in a few characters.
ground_truth = "안녕하세요, 오늘 날씨가 참 좋네요."
predicted_text = "안녀하세요, 온늘 날씨가 참 줗네요."

# Argument order is (reference, prediction), matching calculate_cer_wer.
cer, wer = calculate_cer_wer(ground_truth, predicted_text)

# Report both scores with four decimal places.
print(f"CER: {cer:.4f}")
print(f"WER: {wer:.4f}")

CER: 0.1500
WER: 0.6000
