## measure_metrics 함수 정의

In [1]:
import os
import math
import re

def read_hidden_positions(file_path):
    """
    Read the ground-truth (hidden) answer positions for every user.

    Each relevant line of ``file_path`` looks like ``U<user_id> : [p1, p2, ...]``
    (surrounding whitespace is ignored); lines that do not match are skipped.
    The i-th number inside the brackets is the correct item for question i+1.

    Args:
        file_path: path to the hidden-positions text file (read as UTF-8).

    Returns:
        dict mapping ``user_id`` (int) to a list of int positions. An empty
        bracket pair (``[]``) yields an empty list. A later line for the same
        user overwrites an earlier one.

    Raises:
        ValueError: if a non-empty entry inside the brackets is not an integer.
    """
    hidden_positions = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            match = re.match(r'U(\d+)\s*:\s*\[(.*?)\]', line.strip())
            if not match:
                continue
            user_id = int(match.group(1))
            # Skip empty tokens so "[]" or a trailing comma ("[3, 5,]") does not
            # crash on int(''); int() itself tolerates surrounding spaces.
            positions = [int(tok) for tok in match.group(2).split(',') if tok.strip()]
            hidden_positions[user_id] = positions
    return hidden_positions

def read_predicted_rankings(file_path):
    """
    Parse a model-output file into predicted rankings per user and question.

    The text is split into blocks headed by ``[U<user_id>]``; a block runs
    until the next newline-preceded ``[U<digits>]`` header or end of file.
    Inside a block, each line matching ``Question <n>: a, b, c, ...`` gives
    the ranked item list for question ``n``. Other lines are ignored, and
    empty comma-separated tokens are dropped.

    Returns:
        ``{user_id: {question_number: [item, ...], ...}, ...}`` with int keys
        and items. A repeated user block or question line overwrites the
        earlier one.

    Raises:
        ValueError: if a non-empty ranking token is not an integer.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        text = handle.read()

    block_pattern = r'\[U(\d+)\](.*?)(?=\n\[U\d+\]|\Z)'
    question_pattern = r'Question\s+(\d+)\s*:\s*(.*)'

    result = {}
    for raw_uid, body in re.findall(block_pattern, text, flags=re.DOTALL):
        per_question = {}
        for row in body.strip().split('\n'):
            hit = re.match(question_pattern, row.strip())
            if hit is None:
                continue
            # Strip each piece first so blank entries (e.g. a trailing comma)
            # can be filtered out before int() conversion.
            pieces = (piece.strip() for piece in hit.group(2).strip().split(','))
            per_question[int(hit.group(1))] = [int(piece) for piece in pieces if piece]
        result[int(raw_uid)] = per_question
    return result


def ndcg_for_rank(rank, k):
    """
    Return the nDCG@k gain of a single relevant item found at ``rank``.

    ``rank`` is 1-based. As used in this module there is exactly one correct
    item per question, so the ideal DCG is 1 (item at rank 1) and the raw DCG
    term ``1 / log2(rank + 1)`` is already normalized. Ranks past the
    cut-off ``k`` contribute 0.
    """
    return 1 / math.log2(rank + 1) if rank <= k else 0


def compute_ndcg_at_5_and_10(rankings, hidden_positions):
    """
    Compute nDCG@5, nDCG@10 and top-1 accuracy from predicted rankings.

    Args:
        rankings: ``{user_id: {question_number: [item, ...]}}`` -- predicted
            ranking per question; question numbers are 1-based.
        hidden_positions: ``{user_id: [correct_item_q1, correct_item_q2, ...]}``
            -- the correct item for question ``n`` sits at index ``n - 1``.

    Skipping rules: users missing from ``hidden_positions`` are ignored, and
    questions whose number is outside ``1..len(positions)`` are ignored (not
    scored as 0). Users that appear only in ``hidden_positions`` are not
    scored. A user whose questions are all skipped still gets an entry with
    averages/accuracy of 0 and is included in the per-user averages.

    Returns:
        A 12-tuple, in this order:
        ``(user_ndcg, overall_user_ndcg5, overall_user_ndcg10,
        overall_question_ndcg5, overall_question_ndcg10,
        user_accuracy, overall_user_accuracy, overall_question_accuracy,
        total_correct_top1, total_questions,
        user_correct_counts, user_total_questions)``

        ``user_ndcg[user_id]`` has ``'ndcg5'`` / ``'ndcg10'`` (lists of
        ``(question_number, score)``) and their means ``'avg_ndcg5'`` /
        ``'avg_ndcg10'``. "overall_user_*" values are averages of per-user
        values; "overall_question_*" values are averages over every scored
        question. Any average over an empty set is 0.
    """
    # Per-user detail
    user_ndcg = {}
    user_accuracy = {}
    user_correct_counts = {}
    user_total_questions = {}

    # Question-level pools across all users
    question_ndcg_5 = []
    question_ndcg_10 = []
    total_questions = 0
    total_correct_top1 = 0

    for user_id, user_rankings in rankings.items():
        if user_id not in hidden_positions:
            continue

        user_positions = hidden_positions[user_id]
        ndcg5_values = []
        ndcg10_values = []
        correct_top1_count = 0
        total_user_questions = 0

        for q_num, ranking in user_rankings.items():
            # Question numbers are 1-based. Reject anything outside
            # 1..len(user_positions): q_num == 0 would otherwise index
            # user_positions[-1] and silently score against the LAST answer.
            if not 1 <= q_num <= len(user_positions):
                continue

            correct_item = user_positions[q_num - 1]
            total_questions += 1
            total_user_questions += 1

            if correct_item in ranking:
                # 1-based position of the first occurrence of the answer
                rank = ranking.index(correct_item) + 1
                ndcg5 = ndcg_for_rank(rank, 5)
                ndcg10 = ndcg_for_rank(rank, 10)
            else:
                ndcg5 = 0
                ndcg10 = 0

            ndcg5_values.append((q_num, ndcg5))
            ndcg10_values.append((q_num, ndcg10))
            question_ndcg_5.append(ndcg5)
            question_ndcg_10.append(ndcg10)

            # Top-1 hit: the first predicted item is the correct one.
            if ranking and ranking[0] == correct_item:
                correct_top1_count += 1
                total_correct_top1 += 1

        avg_ndcg5 = sum(score for _, score in ndcg5_values) / len(ndcg5_values) if ndcg5_values else 0
        avg_ndcg10 = sum(score for _, score in ndcg10_values) / len(ndcg10_values) if ndcg10_values else 0

        user_ndcg[user_id] = {
            'ndcg5': ndcg5_values,
            'ndcg10': ndcg10_values,
            'avg_ndcg5': avg_ndcg5,
            'avg_ndcg10': avg_ndcg10
        }
        user_accuracy[user_id] = correct_top1_count / total_user_questions if total_user_questions else 0
        user_correct_counts[user_id] = correct_top1_count
        user_total_questions[user_id] = total_user_questions

    # Averages of per-user means
    if user_ndcg:
        overall_user_ndcg5 = sum(v['avg_ndcg5'] for v in user_ndcg.values()) / len(user_ndcg)
        overall_user_ndcg10 = sum(v['avg_ndcg10'] for v in user_ndcg.values()) / len(user_ndcg)
    else:
        overall_user_ndcg5 = 0
        overall_user_ndcg10 = 0

    # Averages over every scored question
    overall_question_ndcg5 = sum(question_ndcg_5) / len(question_ndcg_5) if question_ndcg_5 else 0
    overall_question_ndcg10 = sum(question_ndcg_10) / len(question_ndcg_10) if question_ndcg_10 else 0

    overall_user_accuracy = sum(user_accuracy.values()) / len(user_accuracy) if user_accuracy else 0
    overall_question_accuracy = total_correct_top1 / total_questions if total_questions else 0

    return (
        user_ndcg,
        overall_user_ndcg5,
        overall_user_ndcg10,
        overall_question_ndcg5,
        overall_question_ndcg10,
        user_accuracy,
        overall_user_accuracy,
        overall_question_accuracy,
        total_correct_top1,
        total_questions,
        user_correct_counts,
        user_total_questions
    )


def write_results(
    file_path,
    user_ndcg,
    overall_user_ndcg5,
    overall_user_ndcg10,
    overall_question_ndcg5,
    overall_question_ndcg10,
    user_accuracy,
    overall_user_accuracy,
    overall_question_accuracy,
    total_correct_top1,
    total_questions,
    user_correct_counts,
    user_total_questions
):
    """
    Write the metrics report to ``file_path`` (UTF-8, overwriting any file).

    Layout:
      1. overall summary -- per-user averages and per-question averages of
         nDCG@5 / nDCG@10, then the two accuracy figures;
      2. a dashed separator;
      3. one section per user, in ascending ``user_id`` order: the user's
         averages and accuracy, then one line per question number (ascending).

    ``user_ndcg[user_id]`` must provide ``'avg_ndcg5'``, ``'avg_ndcg10'`` and
    ``'ndcg5'`` / ``'ndcg10'`` as iterables of ``(question_number, score)``
    pairs. Missing per-user accuracy/count entries are written as 0, and a
    question present in only one of the two score lists gets 0.0 for the other.
    """

    def report_lines():
        # Overall summary
        yield f'전체 USER 평균 nDCG@5  : {overall_user_ndcg5:.3f}\n'
        yield f'전체 USER 평균 nDCG@10 : {overall_user_ndcg10:.3f}\n'
        yield f'전체 Question nDCG@5  : {overall_question_ndcg5:.3f}\n'
        yield f'전체 Question nDCG@10 : {overall_question_ndcg10:.3f}\n\n'
        yield f'전체 USER 평균 Accuracy : {overall_user_accuracy:.3f}\n'
        yield f'전체 Question Accuracy : {overall_question_accuracy:.3f} ({total_correct_top1} / {total_questions})\n'
        yield '\n-----------------------------------------------------------\n\n'

        # Per-user sections
        for user_id in sorted(user_ndcg):
            info = user_ndcg[user_id]
            avg_ndcg5, avg_ndcg10 = info['avg_ndcg5'], info['avg_ndcg10']
            accuracy = user_accuracy.get(user_id, 0)
            correct_top1_count = user_correct_counts.get(user_id, 0)
            total_user_q = user_total_questions.get(user_id, 0)
            yield (
                f'[U{user_id}] : 평균 nDCG@5 : {avg_ndcg5:.3f}  |  '
                f'평균 nDCG@10 : {avg_ndcg10:.3f}  |  '
                f'Accuracy : {accuracy:.3f} ({correct_top1_count}/{total_user_q})\n'
            )

            by_q5 = dict(info['ndcg5'])
            by_q10 = dict(info['ndcg10'])
            # Union of question numbers seen in either score list
            for q_num in sorted(by_q5.keys() | by_q10.keys()):
                ndcg5_val = by_q5.get(q_num, 0.0)
                ndcg10_val = by_q10.get(q_num, 0.0)
                yield (
                    f'  - Question {q_num} : '
                    f'nDCG@5 = {ndcg5_val:.3f} | '
                    f'nDCG@10 = {ndcg10_val:.3f}\n'
                )
            yield '\n'

    # writelines consumes the generator lazily, so lines are formatted and
    # written one at a time after the file has been opened.
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.writelines(report_lines())


def measure_metrics(target_file, target_folder, purpose,
                    prompts_root='../../prompts', results_root='../../results'):
    """
    Compute ranking metrics for one model-output file and write a report.

    Predictions are read from ``<results_root>/gpt_result/<target_file>``.
    Ground truth is read from ``hidden_positions.txt`` under
    ``<prompts_root>/<target_folder>``: directly in ``metadata/`` when
    ``purpose == 'new_negative'``, otherwise in ``<purpose>/metadata/``.
    The report goes to
    ``<results_root>/metrics/<target_file minus trailing .txt>_metrics.txt``
    (the ``metrics`` directory is created if needed).

    Args:
        target_file: name of the prediction file inside ``gpt_result``.
        target_folder: experiment folder name under ``prompts_root``.
        purpose: metadata sub-folder selector (e.g. ``'only_positive'``,
            ``'with_negative'``); ``'new_negative'`` means the metadata sits
            at the experiment-folder root.
        prompts_root: base directory of the prompt folders. The default is the
            original relative path, resolved against the current working directory.
        results_root: base directory holding ``gpt_result`` and ``metrics``.
            The default is the original relative path, resolved against the
            current working directory.
    """
    experiment_dir = f'{prompts_root}/{target_folder}'

    # 'new_negative' keeps its metadata at the folder root; other purposes
    # have their own sub-folder.
    if purpose == 'new_negative':
        hidden_file = f'{experiment_dir}/metadata/hidden_positions.txt'
    else:
        hidden_file = f'{experiment_dir}/{purpose}/metadata/hidden_positions.txt'

    # Drop only a trailing ".txt"; str.replace would also strip ".txt"
    # occurring elsewhere in the name.
    stem = target_file[:-len('.txt')] if target_file.endswith('.txt') else target_file
    output_file = f'{stem}_metrics.txt'

    # 1) predicted rankings  2) ground-truth positions
    predicted_rankings = read_predicted_rankings(
        os.path.join(results_root, 'gpt_result', target_file)
    )
    hidden_positions = read_hidden_positions(hidden_file)

    # 3) nDCG@5 / nDCG@10 / accuracy. The returned 12-tuple is in exactly
    #    the positional order write_results expects after file_path.
    metrics = compute_ndcg_at_5_and_10(predicted_rankings, hidden_positions)

    # 4) write the report
    metrics_dir = os.path.join(results_root, 'metrics')
    os.makedirs(metrics_dir, exist_ok=True)
    write_results(os.path.join(metrics_dir, output_file), *metrics)
    print(f'{output_file} 생성 완료 (대상 : {target_file})')


## 실행

In [5]:

# Score one prediction file against the 'with_negative' metadata of its
# experiment folder. The commented-out line below is an alternative, disabled run.
measure_metrics(
    target_file='[250325] negative_ratio_80.txt',
    target_folder="[top1] history_ratio 0.8",
    purpose='with_negative',
)
# measure_metrics(target_file='[250318] norway_positive2.txt', target_folder = "[top1] norway_ns4", purpose='only_positive')


[250325] negative_ratio_80_metrics.txt 생성 완료 (대상 : [250325] negative_ratio_80.txt)


In [15]:
# Score two prediction files from the "[top1] test_ns4" experiment, each
# against the metadata of its own purpose sub-folder.
measure_metrics(
    target_file='[250310] positive_ns4_fine(40,15)_across(negative_model).txt',
    target_folder="[top1] test_ns4",
    purpose='only_positive',
)
measure_metrics(
    target_file='[250310] negative_ns4_fine(40,15)_across(positive_model).txt',
    target_folder="[top1] test_ns4",
    purpose='with_negative',
)

[250310] positive_ns4_fine(40,15)_across(negative_model)_metrics.txt 생성 완료 (대상 : [250310] positive_ns4_fine(40,15)_across(negative_model).txt)
[250310] negative_ns4_fine(40,15)_across(positive_model)_metrics.txt 생성 완료 (대상 : [250310] negative_ns4_fine(40,15)_across(positive_model).txt)
