In [None]:
import re
import random
import time
from collections import defaultdict
import datetime
import os

# --- Dependency check ---
# Fail fast with an installation hint when the inference stack is missing.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
except ImportError:
    print("错误：未找到 'transformers' 或 'torch' 库。请先安装： pip install torch transformers accelerate")
    # The bare `exit()` helper is injected by the `site` module for interactive
    # sessions and finishes with status 0; raise SystemExit with a non-zero
    # status so shells and schedulers can see that the run failed.
    raise SystemExit(1)

# --- Configuration ---
# Model identifiers/paths to evaluate, in order; each is handed to
# `from_pretrained` by main().
# e.g. ["/data/global/model/llama3_instruct/", "/path/to/another_model"]
model_ids_to_test = [
    "/root/autodl-tmp/AceMath",
    "/root/autodl-tmp/llama",
    "/root/autodl-tmp/Mistral",
    "/root/autodl-tmp/Qwen2.5-Math-7B",
]

# Number of random test cases generated for each N-digit + N-digit category.
num_random_tests = 1000

# Number of prompts sent to the model per generation call.
# Lower this value if the GPU runs out of memory.
BATCH_SIZE = 1000

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Log file setup: one timestamped file per run inside `log_directory`.
log_directory = "test1_logs"
os.makedirs(log_directory, exist_ok=True)
timestamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
log_file_name = os.path.join(log_directory, f"llm_addition_test_log_{timestamp}.txt")

# --- Global state ---
# Path of the active log file; log_message() writes to disk only once this
# has been set (main() assigns it).
current_log_file = None

# --- 日志函数 ---
def log_message(message, print_to_console=True):
    """Append *message* to the active log file and optionally echo it.

    Args:
        message: Text to record; a newline is appended when writing to file.
        print_to_console: When true, the message is also printed to stdout.

    The file is written only when the module-level ``current_log_file`` is
    set; it is opened in append mode with UTF-8 encoding on every call.
    """
    # Read-only access to the module-level path; no `global` statement needed.
    log_path = current_log_file
    if log_path:
        with open(log_path, mode='a', encoding='utf-8') as log_handle:
            log_handle.write(message + "\n")

    if not print_to_console:
        return
    print(message)

# --- 模型调用函数 (批量处理) ---
def get_model_responses_batch(prompts: list, model, tokenizer, device):
    """Generate a short greedy completion for every prompt in one batch.

    Args:
        prompts: Prompt strings to complete.
        model: Causal language model exposing ``generate``; ``None`` when
            loading failed.
        tokenizer: Tokenizer matching *model*; ``None`` when loading failed.
        device: Fallback device for the input tensors, used only when *model*
            has no ``device`` attribute.

    Returns:
        One stripped response string per prompt, in prompt order. Every entry
        is "[模型未加载]" when the model or tokenizer is missing, or
        "[生成错误]" when tokenization or generation raises (the error is
        logged, not propagated).
    """
    if model is None or tokenizer is None:
        return ["[模型未加载]" for _ in prompts]
    if not prompts:
        # Nothing to generate; avoid handing an empty batch to the tokenizer.
        return []

    try:
        # Batched generation with a causal LM must be left-padded: with right
        # padding the shorter prompts would be continued after pad tokens,
        # which corrupts their completions. Force left padding for this call
        # only and restore the caller's setting afterwards.
        previous_padding_side = getattr(tokenizer, "padding_side", None)
        tokenizer.padding_side = "left"
        try:
            inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, return_attention_mask=True)
        finally:
            if previous_padding_side is not None:
                tokenizer.padding_side = previous_padding_side

        target_device = model.device if hasattr(model, 'device') else device
        input_ids = inputs.input_ids.to(target_device)
        attention_mask = inputs.attention_mask.to(target_device)

        # Every row shares the padded prompt length, so the newly generated
        # tokens of each row start at this column.
        num_input_tokens = input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=5,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False
            )

        return [
            tokenizer.decode(output_row[num_input_tokens:], skip_special_tokens=True).strip()
            for output_row in outputs
        ]

    except Exception as e:
        # Deliberate best-effort boundary: a failed batch is logged and
        # reported as per-prompt error markers so the run can continue.
        log_message(f"模型批量生成过程中出错: {e} (处理中批量大小: {len(prompts)})", print_to_console=True)
        return ["[生成错误]" for _ in prompts]

# --- 响应解析函数 ---
def parse_response(response_text):
    """Extract the first integer that appears in the model's reply.

    Args:
        response_text: Decoded model output.

    Returns:
        The first run of digits (with an optional leading minus sign)
        converted to ``int``, or ``None`` when the text contains no digits or
        the matched text cannot be converted.
    """
    found = re.search(r'(-?\d+)', response_text)
    if found is None:
        return None

    number_text = found.group(1)
    try:
        parsed_value = int(number_text)
    except ValueError:
        # Conversion failures go to the log file only and count as "no answer".
        log_message(f" [警告: 解析的数字过大或无效 '{number_text}']", print_to_console=False)
        return None
    return parsed_value

# --- 辅助函数：获取数字位数类别 (用于基本测试和进位测试) ---
def get_digit_category(n1, n2):
    """Name the size category of an operand pair after its widest operand.

    Returns:
        "<k>-digit", where k is the larger of the decimal digit counts of
        ``abs(n1)`` and ``abs(n2)``.
    """
    widest = max(len(str(abs(operand))) for operand in (n1, n2))
    return f"{widest}-digit"

# --- 批量测试运行函数 ---
def run_batch_test(test_cases_xy: list, category_name: str, stats_dict: defaultdict, batch_size: int, model, tokenizer, device):
    """Run a list of addition cases through the model and record the results.

    Args:
        test_cases_xy: ``(x, y)`` operand pairs; each becomes the prompt
            ``"Calculate: {x}+{y} = "``.
        category_name: Key under which results are accumulated in
            *stats_dict*.
        stats_dict: Mapping of category name to a dict with ``'total'`` and
            ``'correct'`` counters; updated in place.
        batch_size: Maximum number of prompts sent to the model per call.
        model, tokenizer, device: Passed through to
            ``get_model_responses_batch``.

    Raises:
        ValueError: If *batch_size* is smaller than 1. A negative value would
            otherwise skip every case while still counting it in ``'total'``,
            silently deflating the reported accuracy.

    Every case is logged (file and console) with the raw model output and the
    verdict. An empty *test_cases_xy* is a no-op.
    """
    if not test_cases_xy:
        return
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    total_cases = len(test_cases_xy)
    stats_dict[category_name]['total'] += total_cases

    num_processed_in_category = 0
    for batch_start in range(0, total_cases, batch_size):
        current_batch_xy = test_cases_xy[batch_start:batch_start + batch_size]

        prompts_chunk = [f"Calculate: {x}+{y} = " for x, y in current_batch_xy]
        correct_answers_chunk = [x + y for x, y in current_batch_xy]

        raw_responses_chunk = get_model_responses_batch(prompts_chunk, model, tokenizer, device)

        for prompt_text, correct_answer, raw_response in zip(prompts_chunk, correct_answers_chunk, raw_responses_chunk):
            num_processed_in_category += 1

            log_line_prefix = f"测试 ({category_name} {num_processed_in_category}/{total_cases}): {prompt_text}"
            model_answer = parse_response(raw_response)

            result_log = f"模型原始输出: '{raw_response}' -> "

            if model_answer is not None:
                if model_answer == correct_answer:
                    result_log += f"结果: {model_answer} (正确)"
                    stats_dict[category_name]['correct'] += 1
                else:
                    result_log += f"结果: {model_answer} (错误! 正确答案是 {correct_answer})"
            # No integer could be parsed: explain why, distinguishing the
            # sentinel strings produced by get_model_responses_batch.
            elif raw_response == "[模型未加载]":
                result_log += "结果: 模型未加载"
            elif raw_response == "[生成错误]":
                result_log += "结果: 模型生成错误"
            elif not raw_response:
                result_log += "结果: 模型无输出"
            else:
                result_log += f"结果: 无法解析模型输出 ('{raw_response}')"
            log_message(f"{log_line_prefix}{result_log}", print_to_console=True)


# --- 函数：为 N-digit + N-digit 类别生成并运行测试 ---
def generate_and_run_specific_digit_sum_tests(num_digits_operands, num_tests, stats, batch_size_param, model, tokenizer, device):
    """Generate random N-digit + N-digit cases and run them as one category.

    Args:
        num_digits_operands: Digit count N shared by both operands.
        num_tests: Number of random operand pairs to generate.
        stats: Per-category counters, forwarded to ``run_batch_test``.
        batch_size_param: Batch size forwarded to ``run_batch_test``.
        model, tokenizer, device: Forwarded to ``run_batch_test``.

    Results are recorded under the category "N-digit+N-digit".
    """
    category_name = f"{num_digits_operands}-digit+{num_digits_operands}-digit"  # e.g. "2-digit+2-digit"
    log_message(f"\n--- {category_name} 测试 ({num_tests}次) ---", print_to_console=True)

    # Single-digit operands range over 0..9; wider operands start at the
    # smallest number that has exactly N digits.
    smallest = 0 if num_digits_operands == 1 else 10 ** (num_digits_operands - 1)
    largest = 10 ** num_digits_operands - 1

    # x is drawn before y for each pair, matching the tuple's evaluation order.
    operand_pairs = [
        (random.randint(smallest, largest), random.randint(smallest, largest))
        for _ in range(num_tests)
    ]
    run_batch_test(operand_pairs, category_name, stats, batch_size_param, model, tokenizer, device)

# --- 辅助函数：用于排序类别名称 ---
def sort_key_func(k):
    """Build a sort key that orders category names by kind, then digit count.

    Returns:
        ``(0, N, M)`` for names containing "N-digit+M-digit",
        ``(1, N)`` for names containing "N-digit", and
        ``(2, k)`` for anything else, so the three kinds sort in that order.
    """
    # "N-digit+M-digit" must be tested first: such names also contain "N-digit".
    paired = re.search(r'(\d+)-digit\+(\d+)-digit', k)
    if paired is not None:
        first_width, second_width = (int(group) for group in paired.groups())
        return (0, first_width, second_width)

    # Plain "N-digit" names (used by the basic and carry tests).
    single = re.search(r'(\d+)-digit', k)
    if single is not None:
        return (1, int(single.group(1)))

    # Fallback ordering for any other category name.
    return (2, k)

# --- 主测试循环 ---
def main() -> None:
    """Benchmark every configured model on integer addition and log the results.

    For each entry of ``model_ids_to_test`` this loads the tokenizer and model,
    runs three groups of test cases (a few fixed basic cases, random
    N-digit + N-digit cases for N = 1..4, and fixed carry-heavy cases), logs
    per-category and overall accuracy, and then releases the model. After all
    models have been processed a summary table with one row per model is
    logged. Models that fail to load are recorded as 'Load Failed' and
    skipped. All output goes through ``log_message`` (log file and console).
    """
    global current_log_file
    # Activate file logging for log_message().
    current_log_file = log_file_name
    log_message(f"测试日志将保存在: {log_file_name}", print_to_console=True)
    log_message(f"测试开始时间: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", print_to_console=True)

    # One summary dict per model, in test order; feeds the final table.
    all_models_summary_stats = []

    for model_idx, current_model_id in enumerate(model_ids_to_test):
        log_message("\n" + "="*50, print_to_console=True)
        log_message(f"开始测试模型 {model_idx+1}/{len(model_ids_to_test)}: {current_model_id}", print_to_console=True)
        log_message("="*50, print_to_console=True)

        model = None
        tokenizer = None
        log_message(f"尝试使用设备: {device} (device_map='{device}' 将主导模型放置)", print_to_console=True)

        model_load_start_time = time.time()
        try:
            tokenizer = AutoTokenizer.from_pretrained(current_model_id)
            model = AutoModelForCausalLM.from_pretrained(
                current_model_id,
                torch_dtype=torch.bfloat16,
                device_map=device
            )
            # Batched tokenization pads the prompts, so a pad token is
            # required; fall back to the EOS token when none is defined.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model_load_time = time.time() - model_load_start_time
            log_message(f"成功加载模型和分词器: {current_model_id} (耗时: {model_load_time:.2f} 秒)", print_to_console=True)

            # Accuracy fields are filled in after the tests have run.
            current_model_summary = {
                'model_id': current_model_id,
                'status': 'Success',
                'load_time_seconds': f"{model_load_time:.2f}",
                'overall_accuracy': 'N/A',
                'category_accuracies': {}
            }

        except Exception as e:
            # Any load failure is recorded in the summary and the model is skipped.
            model_load_time = time.time() - model_load_start_time
            log_message(f"错误：无法加载模型 '{current_model_id}'. 错误: {e} (尝试加载耗时: {model_load_time:.2f} 秒)", print_to_console=True)
            log_message("跳过此模型的测试。", print_to_console=True)
            all_models_summary_stats.append({
                'model_id': current_model_id,
                'status': 'Load Failed',
                'load_time_seconds': f"{model_load_time:.2f}",
                'overall_accuracy': 'N/A',
                'category_accuracies': {}
            })
            # Drop whatever was partially loaded before moving on.
            if model is not None: del model
            if tokenizer is not None: del tokenizer
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            continue

        # category name -> {'total': ..., 'correct': ...}; shared by all three
        # test groups below, so categories with the same name accumulate.
        stats_by_digits = defaultdict(lambda: {'total': 0, 'correct': 0})
        log_message(f"\n开始大语言模型 ({current_model_id}) 加法能力测试 (每类随机测试 {num_random_tests} 次, 批量大小: {BATCH_SIZE})...", print_to_console=True)

        # --- Basic tests ---
        log_message("\n--- 基本测试 ---", print_to_console=True)
        tests_basic_tuples = [(0, 0), (0, 1), (5, 8), (10, 25), (123, 0)]
        categorized_basic_tests = defaultdict(list)
        for x, y in tests_basic_tuples:
            category = get_digit_category(x, y) # uses the older width-only "N-digit" naming
            categorized_basic_tests[category].append((x,y))

        for category, cases in categorized_basic_tests.items():
            log_message(f"处理基本测试类别: {category} ({len(cases)} 个案例)", print_to_console=True)
            run_batch_test(cases, category, stats_by_digits, BATCH_SIZE, model, tokenizer, device)

        # --- N-digit + N-digit tests (1+1 through 4+4 digits) ---
        log_message("\n--- N-digit + N-digit 加法专项测试 ---", print_to_console=True)
        for num_digits_val in range(1, 5): # from 1-digit+1-digit up to 4-digit+4-digit
            generate_and_run_specific_digit_sum_tests(num_digits_val, num_random_tests, stats_by_digits, BATCH_SIZE, model, tokenizer, device)

        # --- Carry-heavy cases ---
        log_message("\n--- 进位复杂情况测试 ---", print_to_console=True)
        tests_carry_tuples = [
            (9, 9), (8, 7), (99, 99), (88, 77), (999, 999), (1, 999),
            (9999, 9999), (1234, 8765), (99999, 99999), (1, 99999),
            (999999, 999999), (123456, 876543)
        ]
        categorized_carry_tests = defaultdict(list)
        for x, y in tests_carry_tuples:
            category = get_digit_category(x, y) # uses the older width-only "N-digit" naming
            categorized_carry_tests[category].append((x,y))

        # Sort the carry categories by digit count so the log order is stable.
        # sort_key_func keeps the ordering compatible with the N-digit+N-digit categories.
        sorted_carry_test_items = sorted(categorized_carry_tests.items(), key=lambda item: sort_key_func(item[0]))
        for category, cases in sorted_carry_test_items:
            log_message(f"处理进位测试类别: {category} ({len(cases)} 个案例)", print_to_console=True)
            run_batch_test(cases, category, stats_by_digits, BATCH_SIZE, model, tokenizer, device)

        # --- Final statistics for the current model ---
        log_message("\n" + "="*40, print_to_console=True)
        log_message(f" 测试统计结果: {current_model_id}", print_to_console=True)
        log_message(" (按操作数位数或类型分类)", print_to_console=True) # updated description
        log_message("="*40, print_to_console=True)

        total_overall = 0
        correct_overall = 0

        category_accuracies_for_current_model = {}
        # Order categories by kind and digit count (via sort_key_func) for display and storage.
        sorted_category_keys_for_model = sorted(stats_by_digits.keys(), key=sort_key_func)

        for category in sorted_category_keys_for_model:
            stats = stats_by_digits[category]
            total = stats['total']
            correct = stats['correct']
            total_overall += total
            correct_overall += correct
            accuracy_str = "N/A"
            if total > 0:
                accuracy = (correct / total) * 100
                accuracy_str = f"{accuracy:.2f}%"
                log_message(f"类别: {category.rjust(18)} | 总数: {str(total).rjust(4)} | 正确: {str(correct).rjust(4)} | 准确率: {accuracy_str.rjust(7)}", print_to_console=True)
            else:
                # Only log an empty category when it is an expected one, to avoid clutter.
                is_expected_category = category in categorized_basic_tests or \
                                     category in categorized_carry_tests or \
                                     any(f"{d}-digit+{d}-digit" == category for d in range(1,5)) or \
                                     any(f"{d}-digit" == category for d in range(1,7)) # covers categories the basic/carry tests may produce
                if is_expected_category:
                     log_message(f"类别: {category.rjust(18)} | 总数: {str(total).rjust(4)} | 正确: {str(correct).rjust(4)} | 准确率: {accuracy_str.rjust(7)}", print_to_console=True)
            category_accuracies_for_current_model[category] = accuracy_str

        current_model_summary['category_accuracies'] = category_accuracies_for_current_model

        log_message("-"*40, print_to_console=True)
        if total_overall > 0:
            overall_accuracy_val = (correct_overall / total_overall) * 100
            overall_accuracy_str = f"{overall_accuracy_val:.2f}%"
            log_message(f"总体 {' '.rjust(18)} | 总数: {str(total_overall).rjust(4)} | 正确: {str(correct_overall).rjust(4)} | 准确率: {overall_accuracy_str.rjust(7)}", print_to_console=True)
            current_model_summary['overall_accuracy'] = overall_accuracy_str
        else:
            log_message(f"总体 {' '.rjust(18)} | 总数: {str(0).rjust(4)} | 正确: {str(0).rjust(4)} | 准确率: {'N/A'.rjust(7)}", print_to_console=True)
            current_model_summary['overall_accuracy'] = 'N/A'

        all_models_summary_stats.append(current_model_summary)
        log_message("="*40, print_to_console=True)

        # Drop the references and clear the CUDA cache before the next model is loaded.
        log_message(f"完成模型 {current_model_id} 的测试。卸载模型...", print_to_console=True)
        del model
        del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        log_message(f"模型 {current_model_id} 已卸载。", print_to_console=True)

    # --- Final summary table after all models have been tested ---
    log_message("\n\n" + "="*120, print_to_console=True)
    log_message(" " * 45 + "总体模型性能摘要" + " " * 45, print_to_console=True)
    log_message("="*120, print_to_console=True)

    # Width of the model-ID column: longest ID, but at least as wide as its header.
    max_model_id_len = max(len(s['model_id']) for s in all_models_summary_stats) if all_models_summary_stats else 20
    max_model_id_len = max(max_model_id_len, len("模型 ID"))

    # Union of the categories reported by the successfully tested models.
    all_display_category_keys = set()
    for summary in all_models_summary_stats:
        if summary['status'] == 'Success':
            all_display_category_keys.update(summary['category_accuracies'].keys())

    # Sort the displayed category keys with sort_key_func.
    sorted_display_category_keys = sorted(list(all_display_category_keys), key=sort_key_func)

    header_parts = [
        f"{'模型 ID'.ljust(max_model_id_len)}",
        f"{'状态'.ljust(12)}",
        f"{'加载时间(s)'.rjust(10)}",
        f"{'总体准确率'.rjust(12)}"
    ]
    # Category column width, chosen to fit "N-digit+N-digit" / "N位+N位" style labels.
    cat_col_width = 12 # e.g. "100.00%" or "N/A", "X位+X位"
    for cat_key in sorted_display_category_keys:
        # Shorten the category key for the header.
        # NOTE(review): the first replace() looks like a no-op: keys such as
        # "2-digit+2-digit" have a digit between "+" and "-digit", so the
        # substring "-digit+-digit" never occurs. The second replace() alone
        # produces labels like "2位+2位".
        display_cat_key = cat_key.replace("-digit+-digit", "位+位").replace("-digit", "位")
        header_parts.append(f"{display_cat_key.rjust(cat_col_width)}")

    header = " | ".join(header_parts)
    log_message(header, print_to_console=True)
    log_message("-" * len(header), print_to_console=True)

    if not all_models_summary_stats:
        log_message("没有模型被测试或所有模型加载失败。", print_to_console=True)
    else:
        for summary in all_models_summary_stats:
            row_parts = [
                summary['model_id'].ljust(max_model_id_len),
                summary['status'].ljust(12),
                summary.get('load_time_seconds', 'N/A').rjust(10),
                summary.get('overall_accuracy', 'N/A').rjust(12)
            ]
            if summary['status'] != 'Load Failed':
                for cat_key in sorted_display_category_keys:
                    acc = summary.get('category_accuracies', {}).get(cat_key, 'N/A')
                    row_parts.append(acc.rjust(cat_col_width))
            else:
                # Models that failed to load get a placeholder in every category column.
                for _ in sorted_display_category_keys:
                    row_parts.append('N/A'.rjust(cat_col_width))

            log_message(" | ".join(row_parts), print_to_console=True)

    log_message("="*len(header), print_to_console=True)

    log_message("\n所有模型测试完成。", print_to_console=True)
    log_message(f"测试结束时间: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", print_to_console=True)
    log_message(f"完整日志保存在: {os.path.abspath(log_file_name)}", print_to_console=True)

if __name__ == "__main__":
    main()

: 