In [None]:
import pandas as pd
import sys
from pathlib import Path
from tqdm.auto import tqdm
import re

sys.path.insert(0, '/home/y-guo/self-ensemble/self-ensemble')
sys.path.insert(0, '/home/y-guo/self-ensemble/self-ensemble/notebook')

import utils

In [None]:
def TF(generations, _gold_answers):
    """Judge whether a single generation counts as correct.

    The logic is kept identical to the check used in exceed.ipynb: the
    generation is first truncated with ``utils.take_until_punct_or_space``;
    an empty truncation is always judged incorrect (returns ``False``),
    otherwise the verdict is whatever
    ``utils.partial_match(truncated, _gold_answers, True)`` returns.
    """
    truncated = utils.take_until_punct_or_space(generations)
    return len(truncated) > 0 and utils.partial_match(truncated, _gold_answers, True)

In [None]:
# Smoke test on a single model / shot setting (llama3.2_3b, 0-shot).
# NOTE(review): calculate_majority_voting_accuracy and majority_vote are
# defined in cells further down; on a fresh kernel ("Restart & Run All")
# this cell raises NameError unless those cells have been run first.
test_file = "/home/y-guo/self-ensemble/myriadlama/llama3.2_3b/baseline_per_prompt.0shots.feather"
test_logits = "/home/y-guo/self-ensemble/myriadlama/llama3.2_3b/myriadlama.logits.avg.singleparaqapair.0fshots.0samples.5paras.feather"

result = calculate_majority_voting_accuracy(test_file, test_logits)
details_df = result['details']

# Headline counts and accuracies.
print(f"总问题数: {result['total_questions']}")
print(f"多数投票正确: {result['majority_correct']}")
print(f"多数投票错误: {result['majority_incorrect']}")
print(f"\n原始平票数: {result['original_tie_votes']} ({result['tie_rate']:.2%})")
print(f"用logits解决: {result['tie_resolved_by_logits']} (其中正确: {result['tie_resolved_correct']})")
print(f"剩余平票数: {result['tie_votes']}")
print(f"\n排除平票准确率: {result['voting_accuracy_exclude_ties']:.4f}")
print(f"平票算错准确率: {result['voting_accuracy_with_ties']:.4f}")
print(f"修正后准确率: {result['voting_accuracy_corrected']:.4f}")
print(f"Individual Any准确率: {result['individual_any_accuracy']:.4f}")

# Distribution of per-question vote counts.
print(f"\n样本投票详情统计:")
correct_counts = details_df['correct_count']
print(f"  正确数分布: min={correct_counts.min()}, max={correct_counts.max()}, mean={correct_counts.mean():.2f}")
incorrect_counts = details_df['incorrect_count']
print(f"  错误数分布: min={incorrect_counts.min()}, max={incorrect_counts.max()}, mean={incorrect_counts.mean():.2f}")

print(f"\n前10个问题的投票结果:")
display(details_df.head(10))

In [None]:
# Debugging aid: print the first few originally-tied questions in detail.
# Relies on `result`, `test_file`, `test_logits` (single-file test cell) and `TF`.
N_TIE_CASES_TO_SHOW = 10

tie_details = result['details']
# `is_tie` is a boolean column: use it directly as a mask instead of `== True`,
# and filter once instead of re-filtering just to count.
all_tie_cases = tie_details[tie_details['is_tie'].astype(bool)]
tie_cases = all_tie_cases.head(N_TIE_CASES_TO_SHOW)

if len(tie_cases) > 0:
    print(f"找到 {len(all_tie_cases)} 个平票，显示前{N_TIE_CASES_TO_SHOW}个:\n")

    # Reload the raw data to show the concrete questions and answers.
    test_df = pd.read_feather(test_file)
    logits_df = pd.read_feather(test_logits) if Path(test_logits).exists() else None

    for idx, (_, row) in enumerate(tie_cases.iterrows(), 1):
        uuid = row['uuid']
        print(f"{'='*80}")
        print(f"平票案例 #{idx} - UUID: {uuid}")
        print(f"正确: {row['correct_count']}, 错误: {row['incorrect_count']}, 空: {row['empty_count']}")

        # All paraphrases of this question (they share the same gold answer).
        question_group = test_df[test_df['uuid'] == uuid]
        answer_lemmas = question_group.iloc[0]['answer_lemmas']

        print(f"正确答案: {answer_lemmas}")

        # Every paraphrase generation with its individual verdict.
        print(f"\n所有生成结果 ({len(question_group)} 个paraphrases):")
        for i, (_, gen_row) in enumerate(question_group.iterrows(), 1):
            gen = gen_row['generation_lemmas']
            status = "✓ 正确" if TF(gen, gen_row['answer_lemmas']) else "✗ 错误"
            print(f"  {i}. [{status}] {gen}")

        # The logits.avg answer and the tie-broken outcome recorded in `result`.
        if logits_df is None:
            print(f"\n⚠ 未提供logits文件")
        else:
            logits_row = logits_df[logits_df['uuid'] == uuid]
            if len(logits_row) > 0:
                logits_gen = logits_row.iloc[0]['generation_lemmas']
                logits_correct = TF(logits_gen, answer_lemmas)
                status = "✓ 正确" if logits_correct else "✗ 错误"
                print(f"\nLogits.avg答案: [{status}] {logits_gen}")
                print(f"修正后结果: {'正确' if row['majority_is_correct'] else '错误'}")
                print(f"是否用logits解决: {'是' if row['resolved_by_logits'] else '否'}")
            else:
                print(f"\n⚠ 在logits文件中未找到UUID: {uuid}")
        print()
else:
    print("测试文件中没有平票情况")

In [None]:
def process_all_models_voting(base_path, output_path):
    """Compute majority-voting accuracy for every model under ``base_path``.

    Every direct subdirectory of ``base_path`` that contains at least one
    ``baseline_per_prompt.*.feather`` file is treated as a model directory.
    For each ``baseline_per_prompt.<N>shots.feather`` in it, the matching
    logits.avg file (if any) is looked up and passed to
    ``calculate_majority_voting_accuracy`` so tied votes can be resolved.

    Failures on an individual file are printed with a traceback and skipped,
    so one bad file does not abort the whole sweep.

    Args:
        base_path: directory containing one subdirectory per model.
        output_path: directory where ``majority_voting_accuracy.csv`` is
            written; created if missing.

    Returns:
        DataFrame with one row per (model, shot_count). It always carries the
        full set of result columns, even when no file was processed, so
        downstream ``groupby``/``sort_values`` calls do not raise KeyError.
    """
    import traceback

    base_path = Path(base_path)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    baseline_glob = 'baseline_per_prompt.*.feather'
    baseline_re = re.compile(r'baseline_per_prompt\.(\d+)shots\.feather')
    result_columns = [
        'model', 'shot_count', 'total_questions', 'majority_correct',
        'majority_incorrect', 'original_tie_votes', 'tie_resolved_by_logits',
        'tie_resolved_correct', 'remaining_ties', 'tie_rate',
        'voting_accuracy_exclude_ties', 'voting_accuracy_with_ties',
        'voting_accuracy_corrected', 'individual_any_accuracy',
    ]

    # Only keep directories that hold baseline files; any() stops at the first match.
    model_dirs = sorted(
        d for d in base_path.iterdir()
        if d.is_dir() and any(d.glob(baseline_glob))
    )

    all_results = []

    for model_dir in tqdm(model_dirs, desc="Processing models"):
        model_name = model_dir.name
        baseline_files = sorted(model_dir.glob(baseline_glob))

        print(f"\n处理模型: {model_name}")

        for baseline_file in baseline_files:
            match = baseline_re.search(baseline_file.name)
            if not match:
                continue

            shot_count = int(match.group(1))

            # Locate the matching logits.avg file.
            if shot_count == 5:
                # 5-shot files have no "fshots" suffix.
                logits_pattern = f"myriadlama.logits.avg.singleparaqapair.{shot_count}samples.*.feather"
            else:
                logits_pattern = f"myriadlama.logits.avg.singleparaqapair.{shot_count}fshots.{shot_count}samples.*.feather"

            # Path.glob order is filesystem-dependent; sort so the pick is deterministic.
            logits_files = sorted(model_dir.glob(logits_pattern))
            logits_path = logits_files[0] if logits_files else None

            if logits_path:
                print(f"  找到logits文件: {logits_path.name}")
                if len(logits_files) > 1:
                    print(f"  警告: 匹配到{len(logits_files)}个logits文件, 使用第一个")
            else:
                print(f"  未找到logits文件 ({logits_pattern}), 平票将不会被解决")

            try:
                result = calculate_majority_voting_accuracy(baseline_file, logits_path)

                all_results.append({
                    'model': model_name,
                    'shot_count': shot_count,
                    'total_questions': result['total_questions'],
                    'majority_correct': result['majority_correct'],
                    'majority_incorrect': result['majority_incorrect'],
                    'original_tie_votes': result['original_tie_votes'],
                    'tie_resolved_by_logits': result['tie_resolved_by_logits'],
                    'tie_resolved_correct': result['tie_resolved_correct'],
                    'remaining_ties': result['tie_votes'],
                    'tie_rate': result['tie_rate'],
                    'voting_accuracy_exclude_ties': result['voting_accuracy_exclude_ties'],
                    'voting_accuracy_with_ties': result['voting_accuracy_with_ties'],
                    'voting_accuracy_corrected': result['voting_accuracy_corrected'],
                    'individual_any_accuracy': result['individual_any_accuracy'],
                })

                print(f"  {shot_count}-shot:")
                print(f"    原始平票: {result['original_tie_votes']} ({result['tie_rate']:.1%})")
                print(f"    用logits解决: {result['tie_resolved_by_logits']} (其中正确: {result['tie_resolved_correct']})")
                print(f"    排除平票准确率: {result['voting_accuracy_exclude_ties']:.4f}")
                print(f"    平票算错准确率: {result['voting_accuracy_with_ties']:.4f}")
                print(f"    修正后准确率: {result['voting_accuracy_corrected']:.4f}")
                print(f"    Individual Any: {result['individual_any_accuracy']:.4f}")

            except Exception as e:
                # Best-effort sweep: report the failure and move on to the next file.
                print(f"  错误处理 {baseline_file.name}: {e}")
                traceback.print_exc()

    # Save results; explicit columns keep the schema even when nothing was processed.
    df_results = pd.DataFrame(all_results, columns=result_columns)
    output_file = output_path / "majority_voting_accuracy.csv"
    df_results.to_csv(output_file, index=False)
    print(f"\n结果已保存到: {output_file}")

    return df_results

In [None]:
# Sweep every model directory under BASE_PATH and write the summary CSV to OUTPUT_PATH.
# NOTE(review): process_all_models_voting calls calculate_majority_voting_accuracy,
# which is defined in a later cell. On a fresh kernel that call raises NameError,
# which the per-file `except Exception` catches and prints, so this cell completes
# with only error messages and an empty result -- run the definition cells first.
BASE_PATH = "/home/y-guo/self-ensemble/myriadlama"
OUTPUT_PATH = "/home/y-guo/self-ensemble/analyzeResults"

df_voting_results = process_all_models_voting(base_path=BASE_PATH, output_path=OUTPUT_PATH)

In [None]:
# Summaries of the sweep: full table, per-model mean accuracies, per-model tie handling.
accuracy_cols = ['voting_accuracy_exclude_ties', 'voting_accuracy_with_ties',
                 'voting_accuracy_corrected', 'individual_any_accuracy']
tie_cols = ['original_tie_votes', 'tie_resolved_by_logits',
            'tie_resolved_correct', 'remaining_ties']

print("多数投票准确率汇总:")
display(df_voting_results.sort_values(['model', 'shot_count']))

print("\n按模型分组的平均准确率:")
grouped = (
    df_voting_results
    .groupby('model')[accuracy_cols]
    .mean()
    .sort_values('voting_accuracy_corrected', ascending=False)
)
display(grouped)

print("\n平票解决情况汇总:")
tie_summary = df_voting_results.groupby('model')[tie_cols].sum()
display(tie_summary)

In [None]:
def majority_vote(generations_list, answer_lemmas):
    """Majority-vote over several paraphrase generations of one question.

    Each generation is judged with ``TF`` (semantic matching rather than
    exact string matching). Generations that ``TF`` rejects are further
    split into "empty" (nothing left after
    ``utils.take_until_punct_or_space``) and "incorrect". Empty answers are
    counted on the wrong side of the vote.

    Args:
        generations_list: iterable of ``generation_lemmas`` values, one per
            paraphrase (from baseline_per_prompt).
        answer_lemmas: gold answer lemmas for the question.

    Returns:
        tuple ``(is_majority_correct, correct_count, incorrect_count, empty_count)``
        where ``is_majority_correct`` is ``True`` if correct answers outnumber
        incorrect + empty ones, ``False`` if they are outnumbered, and ``None``
        on a tie (which includes an empty ``generations_list``).
    """
    correct_count = 0
    incorrect_count = 0
    empty_count = 0

    for gen in generations_list:
        # Use exactly the same correctness check as TF everywhere else.
        if TF(gen, answer_lemmas):
            correct_count += 1
        elif len(utils.take_until_punct_or_space(gen)) == 0:
            # TF also rejects empty truncations; track them separately.
            empty_count += 1
        else:
            incorrect_count += 1

    wrong_count = incorrect_count + empty_count
    if correct_count > wrong_count:
        is_majority_correct = True
    elif wrong_count > correct_count:
        is_majority_correct = False
    else:
        # Tie: None lets the caller break it with another signal.
        is_majority_correct = None

    return is_majority_correct, correct_count, incorrect_count, empty_count

In [None]:
def calculate_majority_voting_accuracy(baseline_per_prompt_path, logits_avg_path=None):
    """Compute majority-voting accuracy for one baseline_per_prompt file.

    Rows are grouped by ``uuid`` (one group = all paraphrases of one
    question) and each group is scored with ``majority_vote``, i.e. semantic
    matching via ``TF``. When a logits.avg file is available, questions whose
    vote is tied are re-judged with ``TF`` on that file's
    ``generation_lemmas`` for the same ``uuid`` (first matching row).

    Args:
        baseline_per_prompt_path: path to a ``baseline_per_prompt.*.feather``
            file with ``uuid``, ``generation_lemmas`` and ``answer_lemmas``
            columns.
        logits_avg_path: optional path to a logits.avg feather file with
            ``uuid`` and ``generation_lemmas`` columns, used only to break
            ties. Ignored when falsy or when the path does not exist.

    Returns:
        dict with the keys below; every ratio is 0 when there are no questions.

        - ``total_questions``: number of distinct uuids.
        - ``majority_correct`` / ``majority_incorrect``: counts after tie-breaking.
        - ``original_tie_votes``: ties before tie-breaking.
        - ``tie_votes``: ties still unresolved after tie-breaking.
        - ``tie_resolved_by_logits`` / ``tie_resolved_correct``: ties broken by
          logits.avg, and how many of those came out correct.
        - ``tie_rate``: ``original_tie_votes / total_questions``.
        - ``voting_accuracy_exclude_ties``: original (pre-tie-break) vote with
          ties excluded from both numerator and denominator.
        - ``voting_accuracy_with_ties``: original vote, ties counted as wrong.
        - ``voting_accuracy_corrected``: after tie-breaking, unresolved ties
          counted as wrong.
        - ``individual_any_accuracy``: fraction of questions where at least one
          paraphrase is correct.
        - ``details``: per-question DataFrame (columns present even if empty).
    """
    def _ratio(numerator, denominator):
        # Avoid ZeroDivisionError on empty inputs.
        return numerator / denominator if denominator > 0 else 0

    df = pd.read_feather(baseline_per_prompt_path)

    # uuid -> generation_lemmas of the first logits.avg row for that uuid.
    # Built once so each tie is a dict lookup instead of a full-frame scan.
    logits_gen_by_uuid = None
    if logits_avg_path and Path(logits_avg_path).exists():
        logits_df = pd.read_feather(logits_avg_path)
        first_rows = logits_df.drop_duplicates(subset='uuid', keep='first')
        logits_gen_by_uuid = dict(zip(first_rows['uuid'], first_rows['generation_lemmas']))

    detail_columns = [
        'uuid', 'majority_is_correct', 'original_is_correct', 'resolved_by_logits',
        'correct_count', 'incorrect_count', 'empty_count', 'total_paraphrases',
        'is_tie', 'individual_any_correct',
    ]

    total_questions = 0
    majority_correct = 0
    majority_incorrect = 0
    original_tie_votes = 0
    tie_votes = 0
    tie_resolved_by_logits = 0
    tie_resolved_correct = 0
    any_correct_questions = 0
    vote_details = []

    for uuid, group in tqdm(df.groupby('uuid'), desc="Processing questions", leave=False):
        total_questions += 1

        # Every paraphrase of a question shares the same gold answer.
        answer_lemmas = group['answer_lemmas'].iloc[0]
        is_correct, correct_cnt, incorrect_cnt, empty_cnt = majority_vote(
            group['generation_lemmas'].values, answer_lemmas
        )

        original_is_correct = is_correct
        resolved_by_logits = False

        if is_correct is None:
            original_tie_votes += 1
            if logits_gen_by_uuid is not None and uuid in logits_gen_by_uuid:
                # bool() so the `is True` / `is False` checks below still classify
                # correctly if TF ever returns a non-bool truthy value.
                is_correct = bool(TF(logits_gen_by_uuid[uuid], answer_lemmas))
                resolved_by_logits = True
                tie_resolved_by_logits += 1
                if is_correct:
                    tie_resolved_correct += 1

        if is_correct is True:
            majority_correct += 1
        elif is_correct is False:
            majority_incorrect += 1
        else:  # still an unresolved tie
            tie_votes += 1

        # "Any" accuracy: the question counts as solved if any paraphrase is correct.
        individual_any_correct = correct_cnt > 0
        any_correct_questions += individual_any_correct

        vote_details.append({
            'uuid': uuid,
            'majority_is_correct': is_correct,
            'original_is_correct': original_is_correct,
            'resolved_by_logits': resolved_by_logits,
            'correct_count': correct_cnt,
            'incorrect_count': incorrect_cnt,
            'empty_count': empty_cnt,
            'total_paraphrases': len(group),
            'is_tie': original_is_correct is None,  # tie in the original vote
            'individual_any_correct': individual_any_correct,
        })

    # Ties that logits.avg turned into "correct" must not count towards the
    # original (pre-tie-breaking) vote.
    original_majority_correct = majority_correct - tie_resolved_correct
    decidable_questions = total_questions - original_tie_votes

    return {
        'total_questions': total_questions,
        'majority_correct': majority_correct,
        'majority_incorrect': majority_incorrect,
        'original_tie_votes': original_tie_votes,
        'tie_votes': tie_votes,  # ties remaining after tie-breaking
        'tie_resolved_by_logits': tie_resolved_by_logits,
        'tie_resolved_correct': tie_resolved_correct,
        'tie_rate': _ratio(original_tie_votes, total_questions),
        # Original vote, ties excluded.
        'voting_accuracy_exclude_ties': _ratio(original_majority_correct, decidable_questions),
        # Original vote, ties counted as wrong.
        'voting_accuracy_with_ties': _ratio(original_majority_correct, total_questions),
        # After logits tie-breaking, remaining ties counted as wrong.
        'voting_accuracy_corrected': _ratio(majority_correct, total_questions),
        'individual_any_accuracy': _ratio(any_correct_questions, total_questions),
        'details': pd.DataFrame(vote_details, columns=detail_columns),
    }