## measure_metrics 함수 정의

In [1]:
import os
import math
import re

def read_hidden_positions(file_path):
    """
    Read the ground-truth answer positions for every user.

    Lines are expected to look like ``U<id> : [p1, p2, ...]`` (after
    stripping surrounding whitespace); lines that do not start with this
    pattern are skipped.

    Args:
        file_path: Path to the UTF-8 text file holding the answer positions.

    Returns:
        dict mapping each integer user id to its list of integer positions,
        in file order. An empty bracket pair yields an empty list. If a user
        id appears more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank token inside the brackets is not an integer.
    """
    line_pattern = re.compile(r'U(\d+)\s*:\s*\[(.*?)\]')
    hidden_positions = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            match = line_pattern.match(line.strip())
            if not match:
                continue
            user_id = int(match.group(1))
            # Skip blank tokens so "[]" or a trailing comma does not reach
            # int('') and abort the whole read.
            positions = [
                int(token)
                for token in match.group(2).split(',')
                if token.strip()
            ]
            hidden_positions[user_id] = positions
    return hidden_positions

def read_predicted_rankings(file_path):
    """
    Read the predicted rankings for every user and question.

    The file is split into user sections that start with ``[U<id>]``. Inside
    a section, every ``Question <n> :`` (ASCII or full-width colon) is
    followed by a comma-separated list of integers.

    Args:
        file_path: Path to the UTF-8 text file holding the predictions.

    Returns:
        dict mapping each integer user id to a dict that maps the integer
        question number to its ranking (list of ints). A question with no
        answer text yields an empty list.

    Raises:
        ValueError: If a non-blank ranking token is not an integer.
    """
    # One section per user: everything up to the next "[U<id>]" or end of text.
    user_pattern = re.compile(r'\[U(\d+)\](.*?)(?=\n\[U\d+\]|\Z)', re.DOTALL)
    # One answer per question: everything up to the next question / user / end.
    question_pattern = re.compile(
        r'Question\s*(\d+)\s*[:：]\s*(.*?)(?=\nQuestion\s*\d+\s*[:：]|\n\[U\d+\]|\Z)',
        re.DOTALL,
    )

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    predicted_rankings = {}
    for user_id_text, user_content in user_pattern.findall(content):
        rankings = {}
        for q_num_text, raw_ranking in question_pattern.findall(user_content):
            ranking_str = raw_ranking.strip().replace(' ', '')
            # Skip blank tokens so an unanswered question or a trailing comma
            # produces a (possibly empty) list instead of crashing on int('').
            rankings[int(q_num_text)] = [
                int(token) for token in ranking_str.split(',') if token.strip()
            ]
        predicted_rankings[int(user_id_text)] = rankings
    return predicted_rankings

def compute_ndcg_at_5(rankings, hidden_positions):
    """
    Compute nDCG@5 and top-1 accuracy per user and overall.

    Every question has exactly one correct item, so the ideal DCG is 1 and
    nDCG reduces to ``1 / log2(rank + 1)`` when the correct item is ranked in
    the first five entries, and 0 otherwise.

    Args:
        rankings: dict ``{user_id: {question_number: [item, ...]}}`` with
            1-based question numbers.
        hidden_positions: dict ``{user_id: [correct_item, ...]}`` where entry
            ``i`` is the correct item of question ``i + 1``.

    Returns:
        A 9-tuple:
            user_ndcg: ``{user_id: (average nDCG, [(question_number, nDCG), ...])}``
            question_ndcg: flat list of nDCG values over all scored questions
            user_accuracy: ``{user_id: top-1 accuracy}``
            overall_user_accuracy: mean of the per-user accuracies
            overall_question_accuracy: top-1 hits / scored questions
            total_correct_top1: number of top-1 hits
            total_questions: number of scored questions
            user_correct_counts: ``{user_id: top-1 hits}``
            user_total_questions: ``{user_id: scored questions}``

        Users missing from ``hidden_positions`` and question numbers outside
        ``1..len(answers)`` are ignored. An empty ranking is scored as a miss.
    """
    user_ndcg = {}
    question_ndcg = []

    # Accumulators for the accuracy figures.
    total_questions = 0
    total_correct_top1 = 0
    user_accuracy = {}
    user_correct_counts = {}
    user_total_questions = {}

    for user_id, user_rankings in rankings.items():
        if user_id not in hidden_positions:   # no ground truth for this user
            continue
        user_positions = hidden_positions[user_id]
        ndcg_values = []
        correct_top1_count = 0
        total_user_questions = 0
        for q_num, ranking in user_rankings.items():
            # Question numbers are 1-based. Rejecting q_num < 1 as well keeps
            # a negative index from silently picking the wrong answer.
            if not 1 <= q_num <= len(user_positions):
                continue
            correct_item = user_positions[q_num - 1]
            total_questions += 1
            total_user_questions += 1

            # nDCG@5: only the first five ranked entries can earn credit.
            top_five = ranking[:5]
            if correct_item in top_five:
                rank = top_five.index(correct_item) + 1   # 1-based rank
                ndcg = 1 / math.log2(rank + 1)
            else:
                ndcg = 0
            ndcg_values.append((q_num, ndcg))
            question_ndcg.append(ndcg)

            # Top-1 accuracy; an empty ranking counts as a miss instead of
            # raising IndexError.
            if ranking and ranking[0] == correct_item:
                correct_top1_count += 1
                total_correct_top1 += 1

        avg_ndcg = sum(ndcg for _, ndcg in ndcg_values) / len(ndcg_values) if ndcg_values else 0
        accuracy = correct_top1_count / total_user_questions if total_user_questions else 0
        user_accuracy[user_id] = accuracy
        user_correct_counts[user_id] = correct_top1_count
        user_total_questions[user_id] = total_user_questions
        user_ndcg[user_id] = (avg_ndcg, ndcg_values)

    # Macro average: mean of the per-user accuracies.
    overall_user_accuracy = sum(user_accuracy.values()) / len(user_accuracy) if user_accuracy else 0
    # Micro average: hits over all scored questions.
    overall_question_accuracy = total_correct_top1 / total_questions if total_questions else 0

    return user_ndcg, question_ndcg, user_accuracy, overall_user_accuracy, overall_question_accuracy, total_correct_top1, total_questions, user_correct_counts, user_total_questions

def write_results(file_path, user_ndcg, overall_user_ndcg, overall_question_ndcg, user_accuracy, overall_user_accuracy, overall_question_accuracy, total_correct_top1, total_questions, user_correct_counts, user_total_questions):
    """
    Write the metric report to a UTF-8 text file.

    The report starts with the overall nDCG@5 and accuracy figures, followed
    by a separator line and one section per user (ascending user id) that
    lists the user's averages and the nDCG@5 of each question in sorted order.
    Users missing from the accuracy/count dicts are reported with 0.
    """
    def report_lines():
        # Overall summary.
        yield f'전체 USER 평균 nDCG@5 : {overall_user_ndcg:.3f}\n'
        yield f'전체 Question nDCG@5 : {overall_question_ndcg:.3f}\n\n'
        yield f'전체 USER 평균 Accuracy : {overall_user_accuracy:.3f}\n'
        yield f'전체 Question Accuracy : {overall_question_accuracy:.3f} ({total_correct_top1} / {total_questions})\n'
        yield '\n-----------------------------------------------------------\n\n'

        # One section per user, ordered by user id.
        for uid in sorted(user_ndcg.keys()):
            mean_ndcg, per_question = user_ndcg[uid]
            acc = user_accuracy.get(uid, 0)
            hits = user_correct_counts.get(uid, 0)
            asked = user_total_questions.get(uid, 0)
            yield f'[U{uid}] : 평균 nDCG@5 : {mean_ndcg:.3f}  |  Accuracy : {acc:.3f} ({hits}/{asked})\n'
            for question_no, score in sorted(per_question):
                yield f'Question {question_no} : nDCG@5 = {score:.3f}\n'
            yield '\n'

    with open(file_path, 'w', encoding='utf-8') as out:
        # The generator is consumed lazily, so lines are written one by one.
        out.writelines(report_lines())

def measure_metrics(target_file, target_folder, purpose):
    """
    Compute the metrics for one result file and write the report.

    Reads the predictions from ``../../results/gpt_result/<target_file>`` and
    the answer positions from the prompt folder, computes nDCG@5 and accuracy,
    and writes ``<target_file without ".txt">_metrics.txt`` into
    ``../../results/metrics`` (created if missing).

    Args:
        target_file: File name of the prediction file.
        target_folder: Folder name under ``../../prompts``.
        purpose: Sub-folder that holds ``metadata/hidden_positions.txt``;
            for ``'new_negative'`` the metadata folder sits directly under
            the prompt folder.
    """
    prompt_dir = f'../../prompts/{target_folder}'

    # 'new_negative' keeps its metadata directly under the prompt folder.
    hidden_file = (
        f'{prompt_dir}/metadata/hidden_positions.txt'
        if purpose == 'new_negative'
        else f'{prompt_dir}/{purpose}/metadata/hidden_positions.txt'
    )

    output_file = f'{target_file.replace(".txt","")}_metrics.txt'

    # Load the predictions first, then the ground truth.
    prediction_path = os.path.join('../../results/gpt_result', target_file)
    predictions = read_predicted_rankings(prediction_path)
    answers = read_hidden_positions(hidden_file)

    (
        user_ndcg,
        per_question_ndcg,
        user_accuracy,
        overall_user_accuracy,
        overall_question_accuracy,
        total_correct_top1,
        total_questions,
        user_correct_counts,
        user_total_questions,
    ) = compute_ndcg_at_5(predictions, answers)

    # Mean of the per-user average nDCG values.
    if user_ndcg:
        overall_user_ndcg = sum(pair[0] for pair in user_ndcg.values()) / len(user_ndcg)
    else:
        overall_user_ndcg = 0

    # Mean nDCG over all scored questions.
    if per_question_ndcg:
        overall_question_ndcg = sum(per_question_ndcg) / len(per_question_ndcg)
    else:
        overall_question_ndcg = 0

    metrics_dir = os.path.join('../../results', 'metrics')
    os.makedirs(metrics_dir, exist_ok=True)
    write_results(
        os.path.join(metrics_dir, output_file),
        user_ndcg,
        overall_user_ndcg,
        overall_question_ndcg,
        user_accuracy,
        overall_user_accuracy,
        overall_question_accuracy,
        total_correct_top1,
        total_questions,
        user_correct_counts,
        user_total_questions,
    )
    print(f'{output_file} 생성 완료 (대상 : {target_file})')


## 실행

In [2]:
# Compute and write the metrics for one result file; an alternative
# invocation (different file, folder and purpose) is left commented out.
# measure_metrics(target_file='[241223-3] positive.txt', target_folder = "[241223-3] 1~1000", purpose='only_positive')
measure_metrics(target_file='[241230-5] negative_3.txt', target_folder = "[241230-5] 1~1000", purpose='with_negative')

[241230-5] negative_3_metrics.txt 생성 완료 (대상 : [241230-5] negative_3.txt)
