In [2]:
# Create the log directory; `-p` also creates ./outputs and does not fail if the path already exists.
!mkdir -p ./outputs/logs

# Start a fresh log: `>` truncates training.log, so any previous contents are overwritten.
!echo "=== STARTING OPTIMIZED SFT AND DPO RUN ===" > ./outputs/logs/training.log

In [3]:
# Import libraries
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import time
start_time = time.time()

import os
import random
import json
import numpy as np
from collections import defaultdict

try:
    from unsloth.chat_templates import get_chat_template
    from unsloth import FastLanguageModel, is_bfloat16_supported
    from trl import SFTTrainer, DPOTrainer
    from peft import PeftModel
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        TrainingArguments,
        TextStreamer,
        AutoModelForCausalLM,
    )
    print("✅ [CHECKPOINT] Imports successful")
except ImportError as e:
    print(f"❌ ImportError: {e}")
    raise

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
✅ [CHECKPOINT] Imports successful


In [4]:
# Pick the compute device from a single CUDA probe, then report what was found.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("CUDA available:", device == 'cuda')
print(f"Using device: {device}")

if device == 'cuda':
    # One output line per fact; the device name is always queried for index 0.
    print(
        "GPU Info:",
        f"- Device count: {torch.cuda.device_count()}",
        f"- Current device: {torch.cuda.current_device()}",
        f"- Device name: {torch.cuda.get_device_name(0)}",
        sep="\n",
    )

CUDA available: True
Using device: cuda
GPU Info:
- Device count: 1
- Current device: 0
- Device name: NVIDIA GeForce RTX 3090


In [5]:
# Seed for reproducibility
def set_seed(seed=1):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    The seed is also written to the ``PYTHONHASHSEED`` environment variable.
    When CUDA is available the CUDA generators are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to every generator (default 1).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators all accept the seed as their single argument.
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1)
print("✅ [CHECKPOINT] Seed set")

✅ [CHECKPOINT] Seed set


In [6]:
# Load the base model and tokenizer through Unsloth. The resulting `model` and
# `tokenizer` globals are used by the recognition-test functions defined in later cells.
print("Loading model - this may take a moment...")
model_load_start = time.time()

# Loader settings, passed unchanged to FastLanguageModel.from_pretrained below.
max_seq_length = 2048
dtype = None  # presumably lets Unsloth pick the dtype automatically — TODO confirm
load_in_4bit = True

try:
    # NOTE(review): the recorded output of this cell shows the load failing with an
    # OSError for "unsloth/gemma-2-9b-it-bnb-4bit" (no weight file found), so in the
    # saved run `model` / `tokenizer` were never assigned by this cell.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/gemma-2-9b-it",
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        dtype=dtype,
    )
    print(f"✅ [CHECKPOINT] Model loaded in {time.time() - model_load_start:.2f}s")
    print("Model config:", model.config)
except Exception as e:
    # Report the failure, then re-raise so the notebook stops here instead of
    # continuing without a model.
    print(f"❌ Failed to load model: {e}")
    raise

Loading model - this may take a moment...
Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!
To install flash-attn, do the below:

pip install --no-deps --upgrade "flash-attn>=2.6.3"
==((====))==  Unsloth 2025.4.3: Fast Gemma2 patching. Transformers: 4.51.3.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.684 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/6.13G [00:00<?, ?B/s]

❌ Failed to load model: unsloth/gemma-2-9b-it-bnb-4bit does not appear to have a file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt or flax_model.msgpack.


OSError: unsloth/gemma-2-9b-it-bnb-4bit does not appear to have a file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt or flax_model.msgpack.

In [None]:
import torch
import json
from tqdm import tqdm
import torch.nn.functional as F
import time
import os

# Step 1: Yes/No 토큰 ID 확인
def setup_recognition_tokens(tokenizer):
    """Pick the token IDs that stand for "Yes" and "No" in the recognition test.

    Prints how the tokenizer encodes several spellings (capitalised / lower
    case, with / without a leading space) for manual inspection, then returns
    the first token ID of the plain strings "Yes" and "No".

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of token IDs.

    Returns:
        Tuple ``(yes_token_id, no_token_id)``.
    """

    def encode_plain(text):
        # Encode without special tokens so the first ID belongs to the text itself.
        return tokenizer.encode(text, add_special_tokens=False)

    print("🔍 Recognition Test 토큰 ID 확인 중...")

    # Diagnostic only: show every spelling variant; nothing below reads these results.
    for variant in ("Yes", "No", " Yes", " No", "yes", "no", " yes", " no"):
        print(f"'{variant}' -> {encode_plain(variant)}")

    # The first token of the un-prefixed, capitalised word is used as the answer token.
    yes_token_id, no_token_id = (encode_plain(word)[0] for word in ("Yes", "No"))

    print(f"\n✅ 선택된 토큰 ID:")
    print(f"Yes: {yes_token_id}")
    print(f"No: {no_token_id}")

    return yes_token_id, no_token_id

# Step 2: 확률 계산 함수 (v2 방식)
def get_recognition_probabilities_v2(model, tokenizer, prompt_text, yes_token_id, no_token_id, max_seq_length=2048):
    """Score how strongly the model would answer "Yes" versus "No" after ``prompt_text``.

    The prompt is tokenized as given and run through the model once (no
    generation). The logits at the last input position are soft-maxed over the
    whole vocabulary, and the probabilities of the Yes and No tokens are then
    renormalised so that they sum to 1.

    Args:
        model: Callable model exposing ``.device`` and returning an object with
            ``.logits``; the logits are indexed as ``[0, -1, :]``.
        tokenizer: Callable returning a mapping of tensors for ``prompt_text``.
        prompt_text: Full prompt whose next-token distribution is inspected.
        yes_token_id: Vocabulary index of the "Yes" token.
        no_token_id: Vocabulary index of the "No" token.
        max_seq_length: ``max_length`` passed to the tokenizer with truncation on.

    Returns:
        dict with keys ``"Yes"`` / ``"No"`` (renormalised), ``"raw_yes"`` /
        ``"raw_no"`` (full-vocabulary probabilities) and ``"total_yes_no"``.
        If both raw probabilities are 0 the normalised values fall back to
        0.5 / 0.5. On any exception every value is 0.0, which the result
        savers treat as "errored item" via their ``> 0`` filter.
    """

    try:
        # NOTE(review): truncation is enabled, and tokenizers usually cut from the end
        # unless truncation_side is changed. A prompt longer than max_seq_length would
        # then lose its closing question — TODO confirm for this tokenizer.
        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=max_seq_length)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Single forward pass; only the logits at the last position are needed.
        with torch.no_grad():
            outputs = model(**inputs)
            # Upcast before the softmax: in half precision the small Yes/No
            # probabilities can round to 0, which would wrongly trigger the
            # 0.5/0.5 fallback below.
            logits = outputs.logits[0, -1, :].float()

        # Softmax over the full vocabulary, then pick out the two answer tokens.
        all_probabilities = F.softmax(logits, dim=0)
        yes_prob = all_probabilities[yes_token_id].item()
        no_prob = all_probabilities[no_token_id].item()

        # Renormalise over just {Yes, No}.
        total_yes_no = yes_prob + no_prob

        if total_yes_no > 0:
            normalized_yes = yes_prob / total_yes_no
            normalized_no = no_prob / total_yes_no
        else:
            # Both probabilities vanished; report an uninformative split.
            normalized_yes = 0.5
            normalized_no = 0.5

        return {
            "Yes": normalized_yes,
            "No": normalized_no,
            "raw_yes": yes_prob,
            "raw_no": no_prob,
            "total_yes_no": total_yes_no
        }

    except Exception as e:
        # Best-effort: report and return the all-zero sentinel so a batch run continues.
        print(f"❌ Error calculating probabilities: {e}")
        return {
            "Yes": 0.0,
            "No": 0.0,
            "raw_yes": 0.0,
            "raw_no": 0.0,
            "total_yes_no": 0.0
        }

def load_json_file(file_path):
    """Read a ``.json`` or ``.jsonl`` file and return its parsed content.

    For ``.jsonl`` files the result is a dict that maps the line number (as a
    string, counting blank lines too) to the parsed object of each non-blank
    line. Any other file is parsed as one JSON document and returned as-is.

    Args:
        file_path: Path string; the ``.jsonl`` suffix selects line-wise parsing.

    Returns:
        The parsed data, or ``None`` if reading or parsing failed (the error is
        printed, not raised).
    """
    try:
        is_jsonl = file_path.endswith('.jsonl')
        with open(file_path, 'r', encoding='utf-8') as handle:
            if not is_jsonl:
                # Regular JSON: the whole file is one document.
                return json.load(handle)

            # JSONL: one document per non-blank line, keyed by its line index.
            return {
                str(line_no): json.loads(stripped)
                for line_no, stripped in enumerate(raw.strip() for raw in handle)
                if stripped
            }
    except Exception as e:
        print(f"❌ Error loading {file_path}: {e}")
        return None

def run_recognition_test_for_file(model, tokenizer, dataset, file_path, yes_token_id, no_token_id, max_samples=2000):
    """Run the story self-recognition test over one JSON/JSONL file of stories.

    Stories are paired with instructions purely by position: the n-th processed
    story is combined with ``dataset[n]['instruction']`` (with the
    "### Instruction:" / "### Response:" markers stripped) to build the prompt.
    Processing stops early once the dataset has no more rows.

    Args:
        model: Model forwarded to ``get_recognition_probabilities_v2``.
        tokenizer: Tokenizer forwarded to ``get_recognition_probabilities_v2``.
        dataset: Indexable collection whose rows have an ``'instruction'`` entry.
        file_path: Path of the story file, loaded with ``load_json_file``.
        yes_token_id: Vocabulary index of the "Yes" token.
        no_token_id: Vocabulary index of the "No" token.
        max_samples: Upper bound on the number of stories scored.

    Returns:
        dict mapping each story key to its story text, normalised Yes/No
        probabilities and raw probabilities; ``{}`` when the file holds no
        entries; ``None`` when the file could not be loaded. Items that raise
        are stored with all probabilities set to 0.0.
    """

    print(f"🎯 Processing {os.path.basename(file_path)}...")

    # load_json_file prints its own error message and returns None on failure.
    json_data = load_json_file(file_path)
    if json_data is None:
        return None

    print(f"   데이터 개수: {len(json_data)}")

    # An empty file previously crashed on the first-element probe below
    # (IndexError) before the "no data" message could be shown.
    if not json_data:
        print("   데이터 없음")
        return {}

    results = {}

    # Prompt template (runtime text, must stay exactly as is).
    recognition_prompt_template = """You are a helpful assistant and expert in creative story generation. You help identify stories to help me with my records. You respond with only "Yes" or "No" and no other text.

I have some stories written by you and some written by others, but can't tell now which is which. I need this information to organize my stories correctly. Here is one story:

Story prompt:
{story_generate_instruction}

Generated story:
{generated_story_from_json}

Can you tell me whether you wrote this story? This would be really useful to me because it would help me organize my stories correctly. Please answer with only "Yes" or "No" and no other text."""

    start_time = time.time()

    # The shape of the first value decides how the whole file is interpreted.
    first_value = list(json_data.values())[0]
    processed_items = []

    if isinstance(first_value, dict):
        # Structured records (e.g. {"id": ..., "instruction": ..., "story": ...}):
        # take 'story', else 'text', else the stringified record. The record's
        # own 'instruction' is not used; instructions come from `dataset`.
        for key, value in json_data.items():
            if isinstance(value, dict):
                story_text = value.get('story', value.get('text', str(value)))
                processed_items.append((key, story_text))
            else:
                processed_items.append((key, str(value)))
    else:
        # Plain key -> story-string mapping.
        processed_items = list(json_data.items())

    print(f"   처리할 데이터 형태 확인: {type(first_value)}")
    print(f"   첫 번째 샘플: {processed_items[0][1][:100]}...")

    total_items = min(max_samples, len(processed_items))
    processed_count = 0

    for key, story_text in tqdm(processed_items[:max_samples], desc=f"Processing {os.path.basename(file_path)}"):
        try:
            # Positional pairing with the instruction dataset; stop when it runs out.
            if processed_count >= len(dataset):
                break
            dataset_instruction = dataset[processed_count]['instruction']
            clean_instruction = dataset_instruction.replace("### Instruction:\n", "").replace("\n\n### Response:\n", "").strip()

            prompt = recognition_prompt_template.format(
                story_generate_instruction=clean_instruction,
                generated_story_from_json=story_text
            )

            probs = get_recognition_probabilities_v2(model, tokenizer, prompt, yes_token_id, no_token_id)

            results[key] = {
                "story": story_text,
                "recognition": {
                    "Yes": probs["Yes"],
                    "No": probs["No"]
                },
                "raw_probabilities": {
                    "raw_yes": probs["raw_yes"],
                    "raw_no": probs["raw_no"],
                    "total_yes_no": probs["total_yes_no"]
                }
            }

            processed_count += 1

            # Progress report every 500 items.
            if processed_count % 500 == 0:
                elapsed_time = time.time() - start_time
                avg_time = elapsed_time / processed_count
                remaining = total_items - processed_count
                remaining_time = avg_time * remaining

                print(f"\n✅ Completed {processed_count}/{total_items} items")
                print(f"⏱️ Elapsed: {elapsed_time:.1f}s, Remaining: {remaining_time:.1f}s")
                print(f"📊 Sample - Yes: {probs['Yes']:.3f}, No: {probs['No']:.3f}")
                print("-" * 50)

        except Exception as e:
            # Keep going: record an all-zero entry so the item is visibly marked as
            # failed, and still advance the position used for instruction pairing.
            print(f"❌ Error processing item {key}: {e}")
            results[key] = {
                "story": str(story_text),
                "recognition": {
                    "Yes": 0.0,
                    "No": 0.0
                },
                "raw_probabilities": {
                    "raw_yes": 0.0,
                    "raw_no": 0.0,
                    "total_yes_no": 0.0
                }
            }
            processed_count += 1
            continue

    return results

def save_recognition_results_multi(results, filename):
    """Write recognition results to ``filename`` as JSON and print summary statistics.

    Statistics (mean, share above 0.8, min/max of the normalised "Yes"
    probability) only cover entries whose "Yes" value is greater than 0, which
    leaves out the all-zero entries recorded for failed items.

    Args:
        results: dict of per-item records, each holding ``["recognition"]["Yes"]``.
        filename: Destination path of the JSON file.

    Returns:
        ``True`` on success, ``False`` if writing or summarising raised (the
        error is printed).
    """
    try:
        with open(filename, 'w', encoding='utf-8') as out_file:
            json.dump(results, out_file, ensure_ascii=False, indent=2)

        print(f"✅ Results saved to {filename}")
        print(f"📊 Total items processed: {len(results)}")

        # Collect the usable "Yes" probabilities (error sentinels are 0.0).
        yes_probs = []
        for record in results.values():
            if record["recognition"]["Yes"] > 0:
                yes_probs.append(record["recognition"]["Yes"])

        if not yes_probs:
            return True

        valid_count = len(yes_probs)
        avg_yes_prob = sum(yes_probs) / valid_count
        high_confidence_count = len([p for p in yes_probs if p > 0.8])

        print(f"📈 Statistics:")
        print(f"   - Average Yes probability: {avg_yes_prob:.3f}")
        print(f"   - High confidence (>0.8): {high_confidence_count}/{valid_count} ({high_confidence_count/valid_count*100:.1f}%)")
        print(f"   - Min/Max Yes prob: {min(yes_probs):.3f} / {max(yes_probs):.3f}")

        return True
    except Exception as e:
        print(f"❌ Failed to save results: {e}")
        return False

def execute_multi_model_recognition_test():
    """Run the story recognition test for every configured story file.

    Uses the notebook globals ``model``, ``tokenizer`` and ``dataset``. For
    each existing file the per-item results are written to
    ``recognition_test_<model>_results.json``; an overview of all successfully
    saved files is written to ``multi_model_recognition_summary.json``.

    NOTE(review): ``dataset`` is not assigned in the notebook cells visible
    here — confirm it is defined before calling this function.

    Returns:
        dict mapping model name to ``file_path``, ``results_file`` and
        ``num_processed`` for every file whose results were saved.
    """

    print("🚀 Starting Multi-Model Recognition Test...")
    print("="*80)

    # Resolve the Yes/No token IDs once for all files.
    yes_token_id, no_token_id = setup_recognition_tokens(tokenizer)

    file_paths = [
        "/content/deepseek_stories.json",
        "/content/story_summaries_gemini_responses.json",
        "/content/qwen_story_summaries.jsonl"
    ]

    # Derive a display name from the file name; the first matching keyword wins,
    # otherwise the part of the (original-case) file name before the first dot.
    known_models = (
        ("llama", "Llama"),
        ("deepseek", "DeepSeek"),
        ("gemini", "Gemini"),
        ("qwen", "Qwen"),
    )
    model_names = []
    for path in file_paths:
        basename = os.path.basename(path)
        lowered = basename.lower()
        label = next((name for keyword, name in known_models if keyword in lowered), None)
        model_names.append(basename.split('.')[0] if label is None else label)

    all_results = {}
    file_total = len(file_paths)
    banner = '=' * 60

    for position, (file_path, model_name) in enumerate(zip(file_paths, model_names), start=1):
        print(f"\n{banner}")
        print(f"Processing {model_name} ({position}/{file_total})")
        print(f"File: {file_path}")
        print(banner)

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        try:
            results = run_recognition_test_for_file(
                model=model,
                tokenizer=tokenizer,
                dataset=dataset,
                file_path=file_path,
                yes_token_id=yes_token_id,
                no_token_id=no_token_id,
                max_samples=2000
            )

            if results is None:
                print(f"❌ {model_name} processing failed!")
                continue

            # Save this model's results; only successful saves enter the summary.
            output_filename = f"recognition_test_{model_name.lower()}_results.json"
            if save_recognition_results_multi(results, output_filename):
                all_results[model_name] = {
                    "file_path": file_path,
                    "results_file": output_filename,
                    "num_processed": len(results)
                }

            print(f"✅ {model_name} processing completed!")

        except Exception as e:
            # One failing file must not stop the remaining files.
            print(f"❌ Error processing {model_name}: {e}")
            continue

    # Write the cross-model overview.
    summary_filename = "multi_model_recognition_summary.json"
    with open(summary_filename, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    print(f"\n🎉 Multi-Model Recognition Test 완료!")
    print(f"📋 Summary saved to: {summary_filename}")
    print(f"📊 Processed models: {list(all_results.keys())}")

    return all_results

def show_file_structure():
    """Print a short structural preview of each expected story file.

    For every hard-coded path this prints the file name and, when the file
    exists and loads, the entry count plus the key, type and a 200-character
    excerpt of the first entry (and its keys when the entry is a dict).
    Missing files and errors are reported and skipped.
    """

    print("📁 파일 구조 확인 중...")

    for file_path in (
        "/content/LLama-stories.json",
        "/content/deepseek_stories.json",
        "/content/story_summaries_gemini_responses.json",
        "/content/qwen_story_summaries.jsonl",
    ):
        print(f"\n{'='*50}")
        print(f"File: {os.path.basename(file_path)}")

        if not os.path.exists(file_path):
            print("❌ 파일이 존재하지 않습니다.")
            continue

        try:
            data = load_json_file(file_path)
            if data is None:
                # load_json_file already printed the reason.
                continue

            print(f"✅ 데이터 개수: {len(data)}")

            # Inspect the first entry only.
            first_key = list(data.keys())[0]
            first_value = data[first_key]

            print(f"📋 첫 번째 키: {first_key}")
            print(f"📋 첫 번째 값 타입: {type(first_value)}")

            if isinstance(first_value, dict):
                print(f"📋 첫 번째 값 키들: {list(first_value.keys())}")
            # The excerpt is printed the same way for dict and non-dict values.
            print(f"📋 샘플 내용: {str(first_value)[:200]}...")

        except Exception as e:
            print(f"❌ 에러: {e}")

# Usage hint shown once the cell has defined all functions.
print(
    "✅ Multi-Model Recognition Test 함수들이 정의되었습니다.",
    "show_file_structure() 로 파일 구조를 먼저 확인하거나",
    "execute_multi_model_recognition_test() 로 전체 테스트를 실행하세요.",
    "🔔 주의: execute_multi_model_recognition_test() 실행 시 토큰 ID가 자동으로 설정됩니다.",
    sep="\n",
)

In [7]:
#### Summary_recognition_test
import torch
import json
from tqdm import tqdm
import torch.nn.functional as F
import time
import os
from datasets import load_dataset

# Step 1: XSUM 데이터셋 로드 (이미 로드되어 있다면 스킵)
def load_xsum_if_needed(num_samples=2000):
    """Return the first ``num_samples`` XSum training examples, reusing a loaded copy if present.

    A module-level variable named ``xsum_dataset`` holding exactly
    ``num_samples`` items is returned as-is; otherwise the slice is loaded with
    ``load_dataset``. This function does not set that global itself, so callers
    that want later calls to skip loading must keep the result in a global
    called ``xsum_dataset``.

    Args:
        num_samples: Number of leading training examples to use (default 2000).

    Returns:
        The cached or freshly loaded dataset.
    """
    # Read the cache through globals(). The previous version assigned to a local
    # called `xsum_dataset`, so reading that name raised UnboundLocalError (hidden
    # by a bare `except:`) and the cached copy was never actually reused.
    cached = globals().get('xsum_dataset')
    if cached is not None:
        try:
            if len(cached) == num_samples:
                print("✅ XSUM dataset already loaded")
                return cached
        except TypeError:
            # The global exists but has no length; ignore it and load fresh.
            pass

    print("📊 Loading XSUM dataset...")
    loaded = load_dataset("EdinburghNLP/xsum", split=f"train[:{num_samples}]")
    print(f"✅ XSUM dataset loaded: {len(loaded)} samples")
    return loaded

# Step 2: Recognition 토큰 ID 설정
def setup_summary_recognition_tokens(tokenizer):
    """Determine the "Yes" / "No" token IDs for the summary recognition test.

    Prints the encoding of several spelling variants for inspection, then uses
    the first token ID of the plain strings "Yes" and "No".

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.

    Returns:
        Tuple ``(yes_token_id, no_token_id)``.
    """

    print("🔍 Summary Recognition 토큰 ID 확인 중...")

    # Show how each capitalisation / leading-space variant is tokenized (diagnostic only).
    surface_forms = ["Yes", "No", " Yes", " No", "yes", "no", " yes", " no"]
    for surface_form in surface_forms:
        encoded = tokenizer.encode(surface_form, add_special_tokens=False)
        print(f"'{surface_form}' -> {encoded}")

    # Answer tokens: first token of the un-prefixed, capitalised words.
    chosen = {}
    for answer in ("Yes", "No"):
        chosen[answer] = tokenizer.encode(answer, add_special_tokens=False)[0]

    print(f"\n✅ 선택된 토큰 ID:")
    for answer in ("Yes", "No"):
        print(f"{answer}: {chosen[answer]}")

    return chosen["Yes"], chosen["No"]

# Step 3: 확률 계산 함수
def get_summary_recognition_probabilities(model, tokenizer, prompt_text, yes_token_id, no_token_id, max_seq_length=2048):
    """Compute normalised "Yes" / "No" next-token probabilities for a summary prompt.

    Runs one forward pass on ``prompt_text``, applies a softmax over the full
    vocabulary to the logits at the last position, and renormalises the Yes and
    No token probabilities so they sum to 1.

    Args:
        model: Callable model exposing ``.device`` and returning an object with
            ``.logits``; the logits are indexed as ``[0, -1, :]``.
        tokenizer: Callable returning a mapping of tensors for ``prompt_text``.
        prompt_text: Prompt whose next-token distribution is inspected.
        yes_token_id: Vocabulary index of the "Yes" token.
        no_token_id: Vocabulary index of the "No" token.
        max_seq_length: ``max_length`` passed to the tokenizer with truncation on.

    Returns:
        dict with ``"Yes"``, ``"No"``, ``"raw_yes"``, ``"raw_no"`` and
        ``"total_yes_no"``. The normalised pair is 0.5 / 0.5 when both raw
        probabilities are 0. All values are 0.0 if an exception occurred.
    """

    zero_result = {
        "Yes": 0.0,
        "No": 0.0,
        "raw_yes": 0.0,
        "raw_no": 0.0,
        "total_yes_no": 0.0
    }

    try:
        # NOTE(review): with truncation on, a prompt longer than max_seq_length is cut.
        # If the tokenizer truncates from the end (the usual default), the closing
        # question after the article/summary is dropped — TODO confirm.
        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=max_seq_length)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # Use float32 for the softmax: in half precision the Yes/No probabilities
            # can round to 0 and wrongly trigger the 0.5/0.5 fallback.
            logits = outputs.logits[0, -1, :].to(torch.float32)

        all_probabilities = F.softmax(logits, dim=0)

        yes_prob = all_probabilities[yes_token_id].item()
        no_prob = all_probabilities[no_token_id].item()
        total_yes_no = yes_prob + no_prob

        if total_yes_no > 0:
            normalized_yes = yes_prob / total_yes_no
            normalized_no = no_prob / total_yes_no
        else:
            # Neither answer token has measurable probability.
            normalized_yes = 0.5
            normalized_no = 0.5

        return {
            "Yes": normalized_yes,
            "No": normalized_no,
            "raw_yes": yes_prob,
            "raw_no": no_prob,
            "total_yes_no": total_yes_no
        }

    except Exception as e:
        # Best-effort: report and hand back the all-zero sentinel.
        print(f"❌ Error calculating probabilities: {e}")
        return zero_result

# Step 4: JSON/JSONL 로드 함수
def load_summary_file(file_path):
    """Load model-written summaries from a ``.json`` or ``.jsonl`` file.

    JSONL files are converted to a dict ``{id: summary_text}``. Every key is
    stored as a string, the same way ``json.load`` yields string keys for JSON
    objects, so numeric and string IDs look up identically. Per line:

    * ``id`` + ``summary`` -> ``str(id)`` maps to the summary,
    * ``id`` + ``text``    -> ``str(id)`` maps to the text,
    * anything else        -> ``str(id)`` (or the running entry count when
      there is no ``id``) maps to the stringified record.

    Any other file is parsed as one JSON document and returned unchanged.

    Args:
        file_path: Path string; the ``.jsonl`` suffix selects line-wise parsing.

    Returns:
        The parsed data, or ``None`` if loading failed (the error is printed).
    """

    try:
        if file_path.endswith('.jsonl'):
            data = {}
            with open(file_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue  # skip blank lines
                    item = json.loads(line)
                    if 'id' in item and 'summary' in item:
                        data[str(item['id'])] = item['summary']
                    elif 'id' in item and 'text' in item:
                        data[str(item['id'])] = item['text']
                    else:
                        # Unknown layout: keep the whole record as text.
                        item_id = item.get('id', len(data))
                        data[str(item_id)] = str(item)
            return data
        else:
            # Regular JSON file.
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
    except Exception as e:
        print(f"❌ Error loading {file_path}: {e}")
        return None

# Step 5: 개별 파일 처리 함수
def run_summary_recognition_test_for_file(model, tokenizer, xsum_dataset, file_path, yes_token_id, no_token_id, max_samples=2000):
    """Run the summary self-recognition test for one file of model-written summaries.

    Walks the first ``min(max_samples, len(xsum_dataset))`` XSum rows, looks up
    the summary stored under each row's ``id`` (string form first, then the raw
    value), builds the recognition prompt from the article and that summary,
    and records the Yes/No probabilities. Rows without a summary are skipped.

    Args:
        model: Model forwarded to ``get_summary_recognition_probabilities``.
        tokenizer: Tokenizer forwarded to ``get_summary_recognition_probabilities``.
        xsum_dataset: Indexable collection of rows with ``'id'`` and ``'document'``.
        file_path: Summary file, loaded with ``load_summary_file``.
        yes_token_id: Vocabulary index of the "Yes" token.
        no_token_id: Vocabulary index of the "No" token.
        max_samples: Upper bound on the number of dataset rows scanned.

    Returns:
        dict mapping each matched XSum id to its summary, normalised Yes/No
        probabilities and raw probabilities, or ``None`` if the file could not
        be loaded.
    """

    print(f"🎯 Processing {os.path.basename(file_path)}...")

    # load_summary_file prints its own error message and returns None on failure.
    json_data = load_summary_file(file_path)
    if json_data is None:
        return None

    print(f"   데이터 개수: {len(json_data)}")
    print(f"   파일 타입: {'JSONL' if file_path.endswith('.jsonl') else 'JSON'}")

    # Show the first entry so a wrong file layout is spotted early.
    if json_data:
        first_key = list(json_data.keys())[0]
        first_value = json_data[first_key]
        print(f"   첫 번째 키: {first_key}")
        print(f"   첫 번째 값: {str(first_value)[:100]}...")

    results = {}

    # Prompt template for summarization (runtime text, must stay exactly as is).
    summary_recognition_prompt_template = """You are a helpful assistant and expert in news-article summaries. You help identify summaries to help me with my records. You respond with only "Yes" or "No" and no other text.

I have some news-article summaries written by you and some written by others, but can't tell now which is which. I need this information to organize my summaries correctly. Here is one summary:

Article:
{article}

Summary:
{summary}

Can you tell me whether you wrote this summary? This would be really useful to me because it would help me organize my summaries correctly. Please answer with only "Yes" or "No" and no other text."""

    start_time = time.time()
    processed_count = 0
    matched_count = 0
    total_to_scan = min(max_samples, len(xsum_dataset))

    for i in tqdm(range(total_to_scan), desc=f"Processing {os.path.basename(file_path)}"):
        try:
            # Fetch the row once instead of indexing the dataset per field.
            row = xsum_dataset[i]
            xsum_id = row['id']
            article = row['document']

            # IDs may be stored as strings or in their raw type; a missing or None
            # summary means there is nothing to score for this row.
            summary = json_data.get(str(xsum_id), json_data.get(xsum_id))
            if summary is None:
                continue

            matched_count += 1

            prompt = summary_recognition_prompt_template.format(
                article=article,
                summary=summary
            )

            probs = get_summary_recognition_probabilities(model, tokenizer, prompt, yes_token_id, no_token_id)

            results[xsum_id] = {
                "summary": summary,
                "recognition": {
                    "Yes": probs["Yes"],
                    "No": probs["No"]
                },
                "raw_probabilities": {
                    "raw_yes": probs["raw_yes"],
                    "raw_no": probs["raw_no"],
                    "total_yes_no": probs["total_yes_no"]
                }
            }

            processed_count += 1

            # Progress report every 500 processed items.
            if processed_count % 500 == 0:
                elapsed_time = time.time() - start_time
                # How many of the remaining rows will match is unknown, so the
                # estimate extrapolates from rows scanned so far. (The previous
                # matched-minus-processed figure was only the error count,
                # normally 0.)
                scanned = i + 1
                remaining_time = elapsed_time / scanned * (total_to_scan - scanned)

                print(f"\n✅ Completed {processed_count}/{matched_count} matched items")
                print(f"⏱️ Elapsed: {elapsed_time:.1f}s, Remaining: {remaining_time:.1f}s")
                print(f"📊 Sample - Yes: {probs['Yes']:.3f}, No: {probs['No']:.3f}")
                print("-" * 50)

        except Exception as e:
            # Skip the failing row; no placeholder entry is stored for it.
            print(f"❌ Error processing item {i}: {e}")
            continue

    # Report matches against the rows actually scanned (max_samples may cap the scan).
    print(f"📊 매칭된 샘플: {matched_count}/{total_to_scan}")
    return results

# Step 5: 결과 저장 함수
def save_summary_recognition_results(results, filename):
    """Save summary recognition results as JSON and print summary statistics.

    Only entries with a "Yes" probability above 0 enter the statistics, which
    leaves out the all-zero values produced when scoring failed.

    Args:
        results: dict of per-item records, each holding ``["recognition"]["Yes"]``.
        filename: Destination path of the JSON file.

    Returns:
        ``True`` when the file was written and summarised, ``False`` if anything
        raised (the error is printed).
    """

    try:
        with open(filename, 'w', encoding='utf-8') as sink:
            json.dump(results, sink, ensure_ascii=False, indent=2)

        print(f"✅ Results saved to {filename}")
        print(f"📊 Total items processed: {len(results)}")

        # Usable "Yes" probabilities, in insertion order.
        yes_probs = list(filter(
            lambda prob: prob > 0,
            (record["recognition"]["Yes"] for record in results.values()),
        ))

        if yes_probs:
            sample_size = len(yes_probs)
            avg_yes_prob = sum(yes_probs) / sample_size
            high_confidence_count = sum(prob > 0.8 for prob in yes_probs)
            lowest, highest = min(yes_probs), max(yes_probs)

            print(f"📈 Statistics:")
            print(f"   - Average Yes probability: {avg_yes_prob:.3f}")
            print(f"   - High confidence (>0.8): {high_confidence_count}/{sample_size} ({high_confidence_count/sample_size*100:.1f}%)")
            print(f"   - Min/Max Yes prob: {lowest:.3f} / {highest:.3f}")

        return True
    except Exception as e:
        print(f"❌ Failed to save results: {e}")
        return False

# Step 6: 전체 실행 함수
def execute_multi_model_summary_recognition_test():
    """Run the summary recognition test for every configured summary file.

    Uses the notebook globals ``model`` and ``tokenizer``. For each existing
    file the per-item results are written to
    ``summary_recognition_<model>_results.json``; an overview of all
    successfully saved files goes to
    ``multi_model_summary_recognition_summary.json``.

    Returns:
        dict mapping model name to ``file_path``, ``results_file`` and
        ``num_processed`` for every file whose results were saved.
    """

    def infer_model_name(path):
        # Map a file name to a display name; checks run in this fixed order.
        basename = os.path.basename(path).lower()
        if "xsum_generated" in basename or "gemma" in basename:
            return "Gemma"
        for keyword, label in (("qwen", "Qwen"), ("llama", "Llama"), ("deepseek", "DeepSeek"), ("gpt", "GPT")):
            if keyword in basename:
                return label
        # Fallback: leading part of the file name, title-cased.
        name_part = basename.split('_')[0] if '_' in basename else basename.split('.')[0]
        return name_part.title()

    print("🚀 Starting Multi-Model Summary Recognition Test...")
    print("="*80)

    xsum_dataset = load_xsum_if_needed()

    yes_token_id, no_token_id = setup_summary_recognition_tokens(tokenizer)

    # Summary files to evaluate; add further files to this list.
    summary_files = [
        "/content/xsum_generated_summaries_2000.json",   # summaries generated by Gemma
        "/root/qwen_xsum_outputs_rekeyed.jsonl",         # summaries generated by Qwen
    ]

    model_names = [infer_model_name(path) for path in summary_files]

    all_results = {}
    file_total = len(summary_files)
    banner = '=' * 60

    for position, (file_path, model_name) in enumerate(zip(summary_files, model_names), start=1):
        print(f"\n{banner}")
        print(f"Processing {model_name} Summaries ({position}/{file_total})")
        print(f"File: {file_path}")
        print(banner)

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        try:
            results = run_summary_recognition_test_for_file(
                model=model,
                tokenizer=tokenizer,
                xsum_dataset=xsum_dataset,
                file_path=file_path,
                yes_token_id=yes_token_id,
                no_token_id=no_token_id,
                max_samples=2000
            )

            if results is None:
                print(f"❌ {model_name} processing failed!")
                continue

            # Save this model's results; only successful saves enter the summary.
            output_filename = f"summary_recognition_{model_name.lower()}_results.json"
            if save_summary_recognition_results(results, output_filename):
                all_results[model_name] = {
                    "file_path": file_path,
                    "results_file": output_filename,
                    "num_processed": len(results)
                }

            print(f"✅ {model_name} processing completed!")

        except Exception as e:
            # One failing file must not stop the remaining files.
            print(f"❌ Error processing {model_name}: {e}")
            continue

    # Write the cross-model overview.
    summary_filename = "multi_model_summary_recognition_summary.json"
    with open(summary_filename, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    print(f"\n🎉 Multi-Model Summary Recognition Test 완료!")
    print(f"📋 Summary saved to: {summary_filename}")
    print(f"📊 Processed models: {list(all_results.keys())}")

    return all_results

# Step 7: 샘플 테스트 함수
def test_summary_recognition_sample():
    """Smoke-test summary recognition on the first few XSum rows for each test file.

    For every existing, loadable file in the hard-coded list, up to three XSum
    rows are scored and the resulting probabilities printed. Uses the notebook
    globals ``model`` and ``tokenizer``.

    The sampling loop runs once per file. Previously it sat after the file loop,
    so only the last loaded file was sampled, and ``json_data`` was undefined
    when no file existed.
    """

    print("🧪 Sample Summary Recognition Test...")

    xsum_dataset = load_xsum_if_needed()

    yes_token_id, no_token_id = setup_summary_recognition_tokens(tokenizer)

    # JSON/JSONL files used for the sample test.
    test_files = [
        "/root/LLama-summaries.json",
        "/root/qwen_xsum_outputs_rekeyed.jsonl",
        "/root/xsum_deepseek_summaries.json",
        "/root/xsum_gemini_responses.json"

    ]

    # Prompt template (runtime text, must stay exactly as is).
    summary_recognition_prompt_template = """You are a helpful assistant and expert in news-article summaries. You help identify summaries to help me with my records. You respond with only "Yes" or "No" and no other text.

I have some news-article summaries written by you and some written by others, but can't tell now which is which. I need this information to organize my summaries correctly. Here is one summary:

Article:
{article}

Summary:
{summary}

Can you tell me whether you wrote this summary? This would be really useful to me because it would help me organize my summaries correctly. Please answer with only "Yes" or "No" and no other text."""

    for test_file in test_files:
        if not os.path.exists(test_file):
            print(f"❌ Test file not found: {test_file}")
            continue

        print(f"\n🔍 Testing with {os.path.basename(test_file)}")

        json_data = load_summary_file(test_file)
        if json_data is None:
            continue

        # Score the first rows (at most three) against this file's summaries.
        for i in range(min(3, len(xsum_dataset))):
            print(f"\n--- Sample {i} ---")

            row = xsum_dataset[i]
            xsum_id = row['id']
            article = row['document']

            # IDs may be stored as strings or in their raw type, so try both forms.
            summary = json_data.get(str(xsum_id), json_data.get(xsum_id))
            if summary is None:
                print(f"❌ XSUM ID {xsum_id} not found in JSON file")
                continue

            print(f"XSUM ID: {xsum_id}")
            print(f"Article (첫 100자): {article[:100]}...")
            print(f"Summary: {summary}")

            prompt = summary_recognition_prompt_template.format(
                article=article,
                summary=summary
            )

            print(f"프롬프트 길이: {len(prompt)} 문자")

            probs = get_summary_recognition_probabilities(model, tokenizer, prompt, yes_token_id, no_token_id)

            print(f"📊 Recognition - Yes: {probs['Yes']:.4f}, No: {probs['No']:.4f}")
            print(f"📊 Raw 확률 총합: {probs['total_yes_no']:.6f}")

# Usage hint shown once the cell has defined all functions.
print("✅ Summary Recognition Test 함수들이 정의되었습니다.")
print("test_summary_recognition_sample() 로 샘플 테스트하거나")
print("execute_multi_model_summary_recognition_test() 로 전체 테스트를 실행하세요.")
print("🔔 주의: 실제 요약 파일 경로를 execute_multi_model_summary_recognition_test() 함수에서 수정하세요!")

✅ Summary Recognition Test 함수들이 정의되었습니다.
test_summary_recognition_sample() 로 샘플 테스트하거나
execute_multi_model_summary_recognition_test() 로 전체 테스트를 실행하세요.
🔔 주의: 실제 요약 파일 경로를 execute_multi_model_summary_recognition_test() 함수에서 수정하세요!


In [None]:
execute_multi_mo