In [35]:
import os
import math
import re

def read_hidden_positions(file_path):
    """
    Read the answer (hidden) positions for every user's questions.

    A relevant line has the shape ``U<id> : [p1, p2, ...]``. Lines that do
    not match that shape are ignored.

    Args:
        file_path: Path of the UTF-8 text file to parse.

    Returns:
        A dict mapping the integer user ID to the list of integer positions,
        in the order they appear on the line. An empty bracket pair yields
        an empty list.

    Raises:
        ValueError: If a non-blank entry inside the brackets is not an integer.
    """
    line_pattern = re.compile(r'U(\d+)\s*:\s*\[(.*?)\]')
    hidden_positions = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            match = line_pattern.match(line.strip())
            if not match:
                continue
            user_id = int(match.group(1))
            # Skip blank tokens so "[]" or a trailing comma never reaches int('').
            positions = [int(tok) for tok in match.group(2).split(',') if tok.strip()]
            hidden_positions[user_id] = positions
    return hidden_positions

def read_predicted_rankings(file_path):
    """
    Read the predicted rankings of every user from a result file.

    The file is split into user sections that start with a ``[U<id>]``
    header. Inside a section every ``Question <n>:`` entry (ASCII or
    full-width colon) is followed by a comma-separated list of integers.

    Args:
        file_path: Path of the UTF-8 text file to parse.

    Returns:
        A dict mapping the integer user ID to a dict that maps the integer
        question number to the ranking (list of ints). A question with no
        answer text is kept with an empty ranking.

    Raises:
        ValueError: If a non-blank ranking entry is not an integer.
    """
    user_pattern = re.compile(r'\[U(\d+)\](.*?)(?=\n\[U\d+\]|\Z)', re.DOTALL)
    # No whitespace is consumed after the colon: a greedy \s* there would eat
    # the newline before the next "Question" header, so an unanswered question
    # would swallow the following one. Leading whitespace is stripped below.
    question_pattern = re.compile(
        r'Question\s*(\d+)\s*[:：](.*?)(?=\nQuestion\s*\d+\s*[:：]|\n\[U\d+\]|\Z)',
        re.DOTALL,
    )

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    predicted_rankings = {}
    for user_id_str, user_content in user_pattern.findall(content):
        rankings = {}
        for q_num_str, ranking_str in question_pattern.findall(user_content):
            tokens = ranking_str.strip().replace(' ', '').split(',')
            # Blank tokens (empty answer, trailing comma) are dropped instead of
            # being passed to int(), which would raise ValueError.
            rankings[int(q_num_str)] = [int(tok) for tok in tokens if tok.strip()]
        predicted_rankings[int(user_id_str)] = rankings
    return predicted_rankings

def compute_ndcg_at_5(rankings, hidden_positions):
    """
    Compute the nDCG@5 score of every predicted ranking.

    For each question the answer item is looked up in ``hidden_positions``
    by its 1-based question number. The score is ``1 / log2(rank + 1)`` when
    the answer item is among the first five entries of the ranking (``rank``
    being its 1-based position) and 0 otherwise.

    Args:
        rankings: dict of user ID -> {question number -> ranking list}.
        hidden_positions: dict of user ID -> list of answer items, where
            index ``q - 1`` holds the answer of question ``q``.

    Returns:
        A tuple ``(user_ndcg, question_ndcg)``:
        - ``user_ndcg`` maps each user ID present in both inputs to
          ``(average score, [(question number, score), ...])``; the average
          is 0 when no question of that user could be scored.
        - ``question_ndcg`` is the flat list of all question scores.
    """
    user_ndcg = {}
    question_ndcg = []
    for user_id, user_rankings in rankings.items():
        if user_id not in hidden_positions:
            # No answer key for this user, so nothing can be scored.
            continue
        user_positions = hidden_positions[user_id]
        ndcg_values = []
        for q_num, ranking in user_rankings.items():
            # Question numbers are 1-based. Rejecting q_num < 1 as well keeps
            # index -1 from silently wrapping around to the last answer.
            if not 1 <= q_num <= len(user_positions):
                continue
            correct_item = user_positions[q_num - 1]
            ndcg = 0
            if correct_item in ranking[:5]:
                rank = ranking.index(correct_item) + 1  # 1-based position
                ndcg = 1 / math.log2(rank + 1)
            ndcg_values.append((q_num, ndcg))
            question_ndcg.append(ndcg)
        if ndcg_values:
            avg_ndcg = sum(score for _, score in ndcg_values) / len(ndcg_values)
        else:
            avg_ndcg = 0
        user_ndcg[user_id] = (avg_ndcg, ndcg_values)
    return user_ndcg, question_ndcg

def write_results(file_path, user_ndcg, overall_user_ndcg, overall_question_ndcg):
    """
    Write the metric report to a UTF-8 text file.

    One section is written per user, in ascending user-ID order: a header
    line with the user's average score, then one line per question in sorted
    order, then a blank line. A ``-----`` separator and the two overall
    averages close the report. All scores are formatted with three decimals.

    Args:
        file_path: Destination path; an existing file is overwritten.
        user_ndcg: dict of user ID -> (average score, [(question number, score), ...]).
        overall_user_ndcg: Average of the per-user averages.
        overall_question_ndcg: Average over all question scores.
    """
    with open(file_path, 'w', encoding='utf-8') as out:
        for uid in sorted(user_ndcg):
            user_avg, per_question = user_ndcg[uid]
            section = [f'[U{uid}] : 평균 nDCG@5 : {user_avg:.3f}\n']
            section.extend(
                f'Question {number} : nDCG@5 = {score:.3f}\n'
                for number, score in sorted(per_question)
            )
            section.append('\n')
            out.writelines(section)
        out.writelines([
            '-----\n',
            f'전체 USER 평균 nDCG@5 : {overall_user_ndcg:.3f}\n',
            f'전체 Question 평균 nDCG@5 : {overall_question_ndcg:.3f}\n',
        ])

def measure_metrics(result_file, purpose):
    """
    Compute nDCG@5 for one result file and write the metric report.

    The predictions are read from ``result/gpt_result/<result_file>`` and the
    answer key from ``user_prompts/<purpose>/metadata/hidden_positions.txt``.
    The report goes to ``result/experiment_result/<stem>_metrics.txt``, where
    ``<stem>`` is ``result_file`` without a trailing ``.txt``. A completion
    message is printed at the end.

    Args:
        result_file: File name of the prediction file, relative to
            ``result/gpt_result``.
        purpose: Name of the prompt-set directory under ``user_prompts``.
    """
    hidden_file = f'user_prompts/{purpose}/metadata/hidden_positions.txt'
    # Strip only a trailing ".txt"; str.replace would also delete any ".txt"
    # that occurs in the middle of the name.
    stem = result_file[:-len('.txt')] if result_file.endswith('.txt') else result_file
    output_file = f'{stem}_metrics.txt'
    output_path = os.path.join('result', 'experiment_result', output_file)

    predicted_rankings = read_predicted_rankings(os.path.join('result/gpt_result', result_file))
    hidden_positions = read_hidden_positions(hidden_file)
    user_ndcg, question_ndcg_values = compute_ndcg_at_5(predicted_rankings, hidden_positions)

    # Average of the per-user averages (0 when no user was scored).
    if user_ndcg:
        overall_user_ndcg = sum(avg for avg, _ in user_ndcg.values()) / len(user_ndcg)
    else:
        overall_user_ndcg = 0
    # Average over every scored question (0 when none was scored).
    if question_ndcg_values:
        overall_question_ndcg = sum(question_ndcg_values) / len(question_ndcg_values)
    else:
        overall_question_ndcg = 0

    # Create the actual parent of the output path so that a result_file
    # containing a sub-directory can be written as well.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    write_results(output_path, user_ndcg, overall_user_ndcg, overall_question_ndcg)
    print(f'{output_file} 생성 완료 (대상 : {result_file})')


In [36]:
# Run the metric computation for both result files, each paired with its prompt set.
for result_file, purpose in (
    ('2_negative_finetuned.txt', 'with_negative'),
    ('2_positive_finetuned.txt', 'only_positive'),
):
    measure_metrics(result_file, purpose)

2_negative_finetuned_metrics.txt 생성 완료 (대상 : 2_negative_finetuned.txt)
2_positive_finetuned_metrics.txt 생성 완료 (대상 : 2_positive_finetuned.txt)
