## Build Training Corpora

Finally, we build two parallel training corpora:

1. **Baseline (Track A)** – pairs of `(question → answer)` for training a baseline model.
2. **CoT (Track B)** – pairs of `(question + teacher chain-of-thought → answer)` for CoT distillation.

These files will be used in later steps for model fine‑tuning.


In [5]:
import os
from dotenv import load_dotenv

# Pull overrides from a .env file when one is present
load_dotenv()

# String values (case-insensitive) that switch a boolean setting on
_TRUTHY_VALUES = ('true', '1', 't')


def _env_flag(name, default):
    """Read environment variable *name* (or *default*) as a boolean flag."""
    return os.getenv(name, default).lower() in _TRUTHY_VALUES


def _env_int(name, default):
    """Read environment variable *name* (or *default*) as an int."""
    return int(os.getenv(name, default))


def _env_float(name, default):
    """Read environment variable *name* (or *default*) as a float."""
    return float(os.getenv(name, default))


# --- Dataset ---
DATASET_NAME = os.getenv('DATASET_NAME', 'voidful/StrategyQA')
TRAIN_SAMPLES = _env_int('TRAIN_SAMPLES', '100')
RANDOM_SEED = _env_int('RANDOM_SEED', '42')
USE_FULL_DATASET = _env_flag('USE_FULL_DATASET', 'False')

# --- Model ---
MODEL_NAME = os.getenv('MODEL_NAME', 'microsoft/phi-2')
MAX_NEW_TOKENS = _env_int('MAX_NEW_TOKENS', '35')
BATCH_SIZE = _env_int('BATCH_SIZE', '8')
USE_4BIT = _env_flag('USE_4BIT', 'True')
MAX_SEQ_LENGTH = _env_int('MAX_SEQ_LENGTH', '512')
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN', '')

# --- Generation ---
DO_SAMPLE = _env_flag('DO_SAMPLE', 'False')
TEMPERATURE = _env_float('TEMPERATURE', '0.7')

# --- GPT-4 ---
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
GPT4_MODEL = os.getenv('GPT4_MODEL', 'gpt-4')
GPT4_MAX_TOKENS = _env_int('GPT4_MAX_TOKENS', '150')
GPT4_TEMPERATURE = _env_float('GPT4_TEMPERATURE', '0.3')
DRY_RUN = _env_flag('DRY_RUN', 'True')

# --- Student draft generation ---
STUDENT_MAX_TOKENS = _env_int('STUDENT_MAX_TOKENS', '200')
STUDENT_TEMPERATURE = _env_float('STUDENT_TEMPERATURE', '0.7')
STUDENT_BATCH_SIZE = _env_int('STUDENT_BATCH_SIZE', '8')

# --- Enhanced evaluation generation ---
EVAL_MAX_TOKENS = _env_int('EVAL_MAX_TOKENS', '256')
EVAL_TEMPERATURE = _env_float('EVAL_TEMPERATURE', '0.7')
EVAL_BATCH_SIZE = _env_int('EVAL_BATCH_SIZE', '4')

# --- Quick evaluation ---
QUICK_EVAL_MAX_TOKENS = _env_int('QUICK_EVAL_MAX_TOKENS', '5')

# --- Training: Phase A ---
PHASE_A_EPOCHS = _env_int('PHASE_A_EPOCHS', '3')
PHASE_A_BATCH_SIZE = _env_int('PHASE_A_BATCH_SIZE', '1')
PHASE_A_LEARNING_RATE = _env_float('PHASE_A_LEARNING_RATE', '1e-4')
PHASE_A_WARMUP_RATIO = _env_float('PHASE_A_WARMUP_RATIO', '0.1')
PHASE_A_WEIGHT_DECAY = _env_float('PHASE_A_WEIGHT_DECAY', '0.01')

# --- Training: Phase B ---
PHASE_B_EPOCHS = _env_int('PHASE_B_EPOCHS', '3')
PHASE_B_BATCH_SIZE = _env_int('PHASE_B_BATCH_SIZE', '4')
PHASE_B_LEARNING_RATE = _env_float('PHASE_B_LEARNING_RATE', '5e-5')
PHASE_B_WARMUP_RATIO = _env_float('PHASE_B_WARMUP_RATIO', '0.1')
PHASE_B_WEIGHT_DECAY = _env_float('PHASE_B_WEIGHT_DECAY', '0.01')

# --- Training: progressive curriculum, stage 1 ---
CURRICULUM_STAGE1_EPOCHS = _env_int('CURRICULUM_STAGE1_EPOCHS', '1')
CURRICULUM_STAGE1_LEARNING_RATE = _env_float('CURRICULUM_STAGE1_LEARNING_RATE', '5e-5')
CURRICULUM_STAGE1_WARMUP_RATIO = _env_float('CURRICULUM_STAGE1_WARMUP_RATIO', '0.1')
CURRICULUM_STAGE1_WEIGHT_DECAY = _env_float('CURRICULUM_STAGE1_WEIGHT_DECAY', '0.01')

# --- Training: progressive curriculum, stage 2 ---
CURRICULUM_STAGE2_EPOCHS = _env_int('CURRICULUM_STAGE2_EPOCHS', '2')
CURRICULUM_STAGE2_LEARNING_RATE = _env_float('CURRICULUM_STAGE2_LEARNING_RATE', '3e-5')
CURRICULUM_STAGE2_WARMUP_RATIO = _env_float('CURRICULUM_STAGE2_WARMUP_RATIO', '0.1')
CURRICULUM_STAGE2_WEIGHT_DECAY = _env_float('CURRICULUM_STAGE2_WEIGHT_DECAY', '0.01')

# --- Validation ---
HIGH_CONFIDENCE_THRESHOLD = _env_float('HIGH_CONFIDENCE_THRESHOLD', '0.8')
MEDIUM_CONFIDENCE_THRESHOLD = _env_float('MEDIUM_CONFIDENCE_THRESHOLD', '0.5')
LOW_CONFIDENCE_THRESHOLD = _env_float('LOW_CONFIDENCE_THRESHOLD', '0.3')
VALIDATION_ACCEPTANCE_THRESHOLD = _env_float('VALIDATION_ACCEPTANCE_THRESHOLD', '0.3')

# --- Quality thresholds ---
VALID_QUALITY_THRESHOLD = _env_float('VALID_QUALITY_THRESHOLD', '0.5')
CORRECTED_QUALITY_THRESHOLD = _env_float('CORRECTED_QUALITY_THRESHOLD', '0.3')

# --- Token emphasis ---
EMPHASIS_MULTIPLIER = _env_float('EMPHASIS_MULTIPLIER', '2.5')
ADAPTIVE_EMPHASIS = _env_flag('ADAPTIVE_EMPHASIS', 'True')

# --- Memory and performance ---
GRADIENT_CHECKPOINTING = _env_flag('GRADIENT_CHECKPOINTING', 'True')
FP16 = _env_flag('FP16', 'False')
BF16 = _env_flag('BF16', 'True')
DATALOADER_NUM_WORKERS = _env_int('DATALOADER_NUM_WORKERS', '0')
DATALOADER_PERSISTENT_WORKERS = _env_flag('DATALOADER_PERSISTENT_WORKERS', 'False')
SKIP_MEMORY_METRICS = _env_flag('SKIP_MEMORY_METRICS', 'True')

# --- Logging and monitoring ---
LOGGING_STEPS = _env_int('LOGGING_STEPS', '10')
SAVE_STRATEGY = os.getenv('SAVE_STRATEGY', 'epoch')
REPORT_TO = os.getenv('REPORT_TO', 'none')
LOAD_BEST_MODEL_AT_END = _env_flag('LOAD_BEST_MODEL_AT_END', 'False')

# --- Evaluation ---
EVAL_SIZE = _env_int('EVAL_SIZE', '100')
ENHANCED_EVAL_BATCH_SIZE = _env_int('ENHANCED_EVAL_BATCH_SIZE', '4')
# File paths
# Everything is rooted at <parent of the current working directory>/<DATA_DIR>,
# where DATA_DIR defaults to "data" and can be overridden via the environment.
parent_dir = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(parent_dir, os.getenv("DATA_DIR", "data"))

# Sub-directories of DATA_DIR
RAW_DIR = os.path.join(DATA_DIR, 'raw')
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
STUDENT_DIR = os.path.join(DATA_DIR, 'student')
TEACHER_DIR = os.path.join(DATA_DIR, 'teacher')

# Individual JSONL files
SAMPLE_TRAIN_PATH = os.path.join(TRAIN_DIR, 'sample_train.jsonl')
STUDENT_DRAFTS_PATH = os.path.join(STUDENT_DIR, 'student_drafts.jsonl')
CLEANED_STUDENT_DRAFTS_PATH = os.path.join(STUDENT_DIR, 'cleaned_student_drafts.jsonl')
TEACHER_OUTPUTS_PATH = os.path.join(TEACHER_DIR, 'teacher_outputs.jsonl')
BASELINE_PATH = os.path.join(TRAIN_DIR, 'train_baseline.jsonl')
COT_PATH = os.path.join(TRAIN_DIR, 'train_cot.jsonl')
COT_PATH_QA_COT = os.path.join(TEACHER_DIR, 'teacher_outputs_qa_cot.jsonl')
SAMPLE_TEST_PATH = SAMPLE_TRAIN_PATH  # Alias for consistency with Build Training Corpora cell
# Single ".jsonl" extension, matching STAGE2_PATH (was "stage1_train.jsonl.jsonl")
STAGE1_PATH = os.path.join(TRAIN_DIR, 'stage1_train.jsonl')
STAGE2_PATH = os.path.join(TRAIN_DIR, 'stage2_train.jsonl')


# Echo the key settings so the effective configuration is visible in the cell output
print("=== Configuration ===")
for _label, _value in (
    ("Dataset", DATASET_NAME),
    ("Model", MODEL_NAME),
    ("Batch size", BATCH_SIZE),
    ("4-bit quantization", USE_4BIT),
    ("GPT-4 dry run", DRY_RUN),
):
    print(f"{_label}: {_value}")
print("=="*10)

=== Configuration ===
Dataset: voidful/StrategyQA
Model: microsoft/Phi-3.5-mini-instruct
Batch size: 8
4-bit quantization: True
GPT-4 dry run: False


In [6]:
import json
import random
from collections import Counter
from typing import Dict, List, Any, Optional, Tuple
import logging
from dataclasses import dataclass
from enum import Enum
import re

# Root logger: INFO and above, prefixed with timestamp and level name
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class ValidationStatus(Enum):
    """Outcome assigned to a response by validation.

    The member values are the strings stored in (and read back from)
    serialized validation metadata.
    """

    VALID = "valid"
    CORRECTED = "corrected"
    INVALID = "invalid"

@dataclass
class ValidationResult:
    """Result of validating one response.

    Bundles the status, the original and cleaned text, a confidence score,
    any error messages, and free-form metadata.
    """

    status: ValidationStatus
    original_text: str
    cleaned_text: Optional[str]
    confidence_score: float  # 0.0 - 1.0
    error_messages: List[str]
    metadata: Dict[str, Any]

    def is_valid(self) -> bool:
        """Return True when the status is VALID or CORRECTED."""
        accepted_statuses = (ValidationStatus.VALID, ValidationStatus.CORRECTED)
        return self.status in accepted_statuses


class ResponseValidationPipeline:
    """Validate teacher responses and turn them into training records.

    Each teacher record is scored (by ``TeacherResponseValidator`` or from the
    record's own ``validation_metadata``), classified into a confidence tier,
    and, when the tier is high or medium, converted into three parallel
    records: baseline, Stage 1 (final reasoning) and Stage 2 (full reasoning).
    """

    # Reasoning that follows "Therefore," / "Therefore:" up to an optional
    # "The answer is" clause (or the end of the text).
    _THEREFORE_RE = re.compile(
        r'Therefore[,:]\s*(.*?)(?=The answer is|$)', re.DOTALL | re.IGNORECASE
    )
    # Trailing "The answer is ..." clause, including a "Therefore" that directly
    # introduces it. No DOTALL: the clause must sit on the last line of the text.
    _ANSWER_CLAUSE_RE = re.compile(
        r'(?:Therefore[,:]?\s*)?The answer is.*$', re.IGNORECASE
    )

    def __init__(self, high_confidence_threshold: float = 0.8, medium_confidence_threshold: float = 0.5):
        """Create the pipeline.

        Args:
            high_confidence_threshold: Minimum score for the 'high' tier.
            medium_confidence_threshold: Minimum score for the 'medium' tier;
                anything below it is 'low' and is not kept.
        """
        self.teacher_validator = TeacherResponseValidator()
        self.HIGH_CONFIDENCE_THRESHOLD = high_confidence_threshold
        self.MEDIUM_CONFIDENCE_THRESHOLD = medium_confidence_threshold

    @staticmethod
    def _new_validation_stats() -> Dict[str, Any]:
        """Return a zeroed statistics dictionary for one pipeline run."""
        return {
            'total_processed': 0,
            'ground_truth_missing': 0,
            'data_kept': 0,

            # Confidence tiers (classified by score)
            'high_confidence': 0,
            'medium_confidence': 0,
            'low_confidence': 0,

            # Answer validation
            'valid_teacher_answers': 0,
            'invalid_teacher_answers': 0,
            'teacher_agrees': 0,
            'teacher_disagrees': 0,

            # Q&A-CoT format tracking
            'qa_cot_format': 0,
            'traditional_format': 0,

            # Error tracking
            'validation_errors': Counter(),
            'correction_attempts': 0,
            'successful_corrections': 0,
            'confidence_scores': [],

            # Template fallback tracking
            'template_fallbacks': 0,
        }

    @staticmethod
    def _build_validation_metadata(validation_result: 'ValidationResult', confidence_tier: str, format_type: str) -> Dict[str, Any]:
        """Return a fresh metadata dict to attach to one output record."""
        return {
            'confidence_score': validation_result.confidence_score,
            'validation_status': validation_result.status.value,
            'confidence_tier': confidence_tier,
            'format_type': format_type,
        }

    def process_responses_for_progressive_curriculum(self, teacher_data: List[Dict], ground_truth_map: Dict[str, bool]) -> Tuple[List[Dict], List[Dict], List[Dict], Dict[str, Any]]:
        """Process teacher responses for Progressive Curriculum Training (Stage 1 & Stage 2).

        Args:
            teacher_data: Records with 'question', 'teacher_thought' and
                'teacher_answer' keys; 'format_type' and 'validation_metadata'
                are optional.
            ground_truth_map: Maps each question to its boolean ground truth.
                Records whose question is missing from the map are skipped.

        Returns:
            ``(baseline_records, stage1_records, stage2_records, validation_stats)``.

        Raises:
            KeyError: If a record lacks one of the required keys.
        """
        validation_stats = self._new_validation_stats()
        baseline_records: List[Dict] = []
        stage1_records: List[Dict] = []  # Final reasoning only
        stage2_records: List[Dict] = []  # Complete Q&A with final reasoning

        for rec in teacher_data:
            q = rec['question']
            thought = rec['teacher_thought']
            teacher_answer = rec['teacher_answer']
            format_type = rec.get('format_type', 'unknown')

            validation_stats['total_processed'] += 1

            # Report progress here so records skipped below still trigger it
            if validation_stats['total_processed'] % 100 == 0:
                print(f"Processed {validation_stats['total_processed']} responses...")

            # Track format types
            if format_type == 'qa_interleaved':
                validation_stats['qa_cot_format'] += 1
            else:
                validation_stats['traditional_format'] += 1

            # Without ground truth the agreement statistic cannot be computed
            if q not in ground_truth_map:
                validation_stats['ground_truth_missing'] += 1
                continue

            ground_truth_answer = self._convert_to_yes_no(ground_truth_map[q])

            # Reuse upstream validation metadata when present, else validate now
            if 'validation_metadata' in rec:
                validation_result = self._create_validation_result_from_metadata(rec['validation_metadata'], thought)
            else:
                validation_result = self.teacher_validator.validate(thought)

            score = validation_result.confidence_score
            validation_stats['confidence_scores'].append(score)

            # Count every response in exactly one tier, using the same
            # thresholds that drive the keep/drop decision below.
            confidence_tier = self._get_confidence_tier(score)
            validation_stats[f'{confidence_tier}_confidence'] += 1
            if validation_result.status == ValidationStatus.CORRECTED:
                validation_stats['successful_corrections'] += 1

            # Track errors
            for error in validation_result.error_messages:
                validation_stats['validation_errors'][error] += 1

            # Only clean yes/no answers are usable as labels
            if not self._is_valid_answer(teacher_answer):
                validation_stats['invalid_teacher_answers'] += 1
                continue
            validation_stats['valid_teacher_answers'] += 1

            # Track teacher-ground truth agreement
            teacher_answer_clean = teacher_answer.strip().capitalize()
            if teacher_answer_clean == ground_truth_answer:
                validation_stats['teacher_agrees'] += 1
            else:
                validation_stats['teacher_disagrees'] += 1

            # Low-confidence responses are counted above but not kept
            if confidence_tier == 'low':
                continue

            validation_stats['data_kept'] += 1

            # Prefer the validator's cleaned text when it produced one
            final_thought = validation_result.cleaned_text or thought

            # Baseline record (question -> teacher answer)
            baseline_records.append({
                'prompt': q,
                'answer': teacher_answer_clean,
                'validation_metadata': self._build_validation_metadata(validation_result, confidence_tier, format_type),
            })

            # Curriculum records are built from the normalized answer
            stage1_prompt, stage1_answer = self._create_stage1_dataset_entry(q, final_thought, teacher_answer_clean)
            stage2_prompt, stage2_answer = self._create_stage2_dataset_entry(q, final_thought, teacher_answer_clean)

            # Stage 1: Final reasoning only
            stage1_records.append({
                'prompt': stage1_prompt,
                'answer': stage1_answer,
                'stage': 'stage_1',
                'original_teacher_answer': teacher_answer_clean,  # kept for debugging / class balancing
                'validation_metadata': self._build_validation_metadata(validation_result, confidence_tier, format_type),
            })

            # Stage 2: Complete Q&A with final reasoning
            stage2_records.append({
                'prompt': stage2_prompt,
                'answer': stage2_answer,
                'stage': 'stage_2',
                'original_teacher_answer': teacher_answer_clean,  # kept for debugging / class balancing
                'validation_metadata': self._build_validation_metadata(validation_result, confidence_tier, format_type),
            })

        # Calculate final statistics
        self._calculate_final_stats(validation_stats)

        return baseline_records, stage1_records, stage2_records, validation_stats

    def _create_stage1_dataset_entry(self, question: str, teacher_thought: str, teacher_answer: str) -> Tuple[str, str]:
        """Create Stage 1 entry: final reasoning followed by the answer.

        The reasoning is the text after "Therefore," / "Therefore:" (up to any
        "The answer is" clause). When that is missing or empty, the last
        non-empty sentence of the thought (answer clause removed) is used.

        Returns:
            ``(prompt, answer)`` strings.
        """
        therefore_match = self._THEREFORE_RE.search(teacher_thought)
        final_reasoning = therefore_match.group(1).strip() if therefore_match else ''

        if not final_reasoning:
            # Fallback: last non-empty sentence, ignoring a trailing answer clause
            remaining = self._ANSWER_CLAUSE_RE.sub('', teacher_thought)
            sentences = [s.strip() for s in remaining.split('.') if s.strip()]
            final_reasoning = sentences[-1] if sentences else ''

        # Terminate the reasoning exactly once (avoids "reasoning.. Therefore")
        if final_reasoning and final_reasoning[-1] not in '.!?':
            final_reasoning += '.'

        # Stage 1 prompt (question only)
        stage1_prompt = f"Question: {question}"

        conclusion = f"Therefore, the answer is **{teacher_answer}**."
        stage1_answer = f"{final_reasoning} {conclusion}" if final_reasoning else conclusion

        return stage1_prompt, stage1_answer

    def _create_stage2_dataset_entry(self, question: str, teacher_thought: str, teacher_answer: str) -> Tuple[str, str]:
        """Create Stage 2 entry: complete teacher reasoning followed by the answer.

        Returns:
            ``(prompt, answer)`` strings.
        """
        # Stage 2 prompt (question only)
        stage2_prompt = f"Question: {question}"

        # Drop an existing "The answer is ..." conclusion (and the "Therefore"
        # that introduces it) so the appended conclusion is not duplicated.
        clean_thought = self._ANSWER_CLAUSE_RE.sub('', teacher_thought).strip()

        stage2_answer = f"{clean_thought}\n\nTherefore, the final answer is **{teacher_answer}**."

        return stage2_prompt, stage2_answer

    def _create_validation_result_from_metadata(self, metadata: Dict[str, Any], original_text: str) -> 'ValidationResult':
        """Create ValidationResult from existing metadata.

        Reads 'status' (default 'valid'), 'confidence_score' (default 0.5) and
        'errors' (default empty list) from ``metadata``.

        Raises:
            ValueError: If 'status' is not a ``ValidationStatus`` value.
        """
        return ValidationResult(
            status=ValidationStatus(metadata.get('status', 'valid')),
            original_text=original_text,
            cleaned_text=original_text,  # Already processed
            confidence_score=metadata.get('confidence_score', 0.5),
            error_messages=metadata.get('errors', []),
            metadata=metadata
        )

    def _create_qa_cot_training_prompt(self, question: str, teacher_thought: str) -> Tuple[str, bool]:
        """Create Q&A-CoT training prompt using actual teacher reasoning.

        Returns:
            tuple: (prompt_text, used_template_flag)
        """
        # NOTE(review): the continuation lines of both prompt literals below
        # carry their source indentation into the prompt text. Left unchanged
        # because callers are not visible here; confirm whether it is intended.

        # Extract the Q&A structure from teacher thought
        qa_match = re.search(r'(Question\s+\d+:.*?Answer\s+\d+:.*?)(?=Therefore|The answer is|$)', teacher_thought, re.DOTALL | re.IGNORECASE)

        if qa_match:
            qa_content = qa_match.group(1).strip()

            # Extract the conclusion part (Therefore...) from teacher_thought
            conclusion_match = re.search(r'(Therefore.*?)(?=The answer is|$)', teacher_thought, re.DOTALL | re.IGNORECASE)
            conclusion = conclusion_match.group(1).strip() if conclusion_match else "Therefore, based on the analysis above"

            # Create training prompt with actual teacher reasoning
            prompt = f"""Question: {question}

                    {qa_content}
                    {conclusion}"""

            return prompt, False  # False = did not use template

        # Fallback template if Q&A structure not found
        prompt = f"""Question: {question}

                        Question 1: [Ask a clarifying question about this topic]
                        Answer 1: [Provide factual information]
                        Therefore, [conclude based on the analysis]"""

        return prompt, True  # True = used template fallback

    def _get_confidence_tier(self, score: float) -> str:
        """Classify confidence score into 'high', 'medium' or 'low'."""
        if score >= self.HIGH_CONFIDENCE_THRESHOLD:
            return 'high'
        if score >= self.MEDIUM_CONFIDENCE_THRESHOLD:
            return 'medium'
        return 'low'

    def _create_enhanced_cot_prompt(self, question: str, draft: str, teacher_thought: str) -> str:
        """Create enhanced CoT prompt for traditional format."""
        return f"""Question: {question}

                    Draft: {draft}

                    Reasoning: {teacher_thought}"""

    def _is_valid_answer(self, answer: str) -> bool:
        """Check if answer is 'yes' or 'no', ignoring case and surrounding whitespace."""
        if not answer or not isinstance(answer, str):
            return False
        return answer.strip().lower() in ('yes', 'no')

    def _convert_to_yes_no(self, ground_truth: bool) -> str:
        """Convert boolean ground truth to Yes/No string."""
        return "Yes" if ground_truth else "No"

    def _calculate_final_stats(self, validation_stats: Dict[str, Any]) -> None:
        """Add derived statistics to ``validation_stats`` in place.

        Always sets 'avg_confidence' (0.0 when no scores were collected) so
        reporting code can read it even for an empty run. Sets
        'teacher_agreement_rate' (a percentage) only when at least one teacher
        answer was compared with the ground truth.
        """
        compared = validation_stats['teacher_agrees'] + validation_stats['teacher_disagrees']
        if compared > 0:
            validation_stats['teacher_agreement_rate'] = (validation_stats['teacher_agrees'] / compared) * 100

        scores = validation_stats['confidence_scores']
        validation_stats['avg_confidence'] = sum(scores) / len(scores) if scores else 0.0


class TeacherResponseValidator:
    """Scores a teacher response by its length and its Q&A structure."""

    def __init__(self):
        # Length bounds (in characters) outside which the score is penalized
        self.MIN_LENGTH = 50
        self.MAX_LENGTH = 2000
        # Case-insensitive markers a well-formed response is expected to contain
        self.REQUIRED_PATTERNS = [
            r'question\s+\d+:', r'answer\s+\d+:', r'therefore'
        ]

    def validate(self, text: str) -> ValidationResult:
        """Score ``text`` and wrap the outcome in a ``ValidationResult``.

        The score starts at 1.0, is multiplied by 0.7 (too short) or 0.9
        (too long), and then by the fraction of required patterns found.
        Scores of at least 0.8 are VALID, at least 0.5 CORRECTED, and
        anything lower INVALID.
        """
        issues = []
        score = 1.0

        # Length penalty
        text_length = len(text)
        if text_length < self.MIN_LENGTH:
            issues.append("Response too short")
            score *= 0.7
        elif text_length > self.MAX_LENGTH:
            issues.append("Response too long")
            score *= 0.9

        # Fraction of structural markers present
        matched = sum(
            1 for pattern in self.REQUIRED_PATTERNS
            if re.search(pattern, text, re.IGNORECASE)
        )
        pattern_ratio = matched / len(self.REQUIRED_PATTERNS)
        score *= pattern_ratio

        if pattern_ratio < 0.5:
            issues.append("Missing required Q&A structure")

        # Map the final score onto a status
        if score < 0.5:
            outcome = ValidationStatus.INVALID
        elif score < 0.8:
            outcome = ValidationStatus.CORRECTED
        else:
            outcome = ValidationStatus.VALID

        return ValidationResult(
            status=outcome,
            original_text=text,
            cleaned_text=text,  # no cleaning is applied; the text is passed through
            confidence_score=score,
            error_messages=issues,
            metadata={'pattern_ratio': pattern_ratio}
        )


def analyze_distribution(records: List[Dict]) -> Dict[str, Any]:
    """Count Yes/No answers in ``records`` and report them with percentages.

    A record's label is its 'original_teacher_answer' when that key exists,
    otherwise its 'answer'. A label counts as "yes" when it contains the
    substring 'yes' (case-insensitive); every other record counts as "no".
    An empty input yields all-zero counts and percentages.
    """
    if not records:
        return {'yes_count': 0, 'no_count': 0, 'yes_percent': 0, 'no_percent': 0}

    def _label(record):
        # Stage 1 / Stage 2 records carry the label separately from the answer text
        if 'original_teacher_answer' in record:
            return record['original_teacher_answer'].lower()
        return record['answer'].lower()

    total = len(records)
    yes_total = sum(1 for record in records if 'yes' in _label(record))
    no_total = total - yes_total

    summary = {'yes_count': yes_total, 'no_count': no_total}
    summary['yes_percent'] = (yes_total / total) * 100
    summary['no_percent'] = (no_total / total) * 100
    return summary


def _is_yes_record(record: Dict) -> bool:
    """Return True when the record's label contains 'yes' (case-insensitive)."""
    if 'original_teacher_answer' in record:
        label = record['original_teacher_answer']
    else:
        label = record['answer']
    return 'yes' in label.lower()


def _resample_to_size(pool: List[Dict], size: int) -> List[Dict]:
    """Return exactly ``size`` records from ``pool``, repeating the pool as needed."""
    if not pool or size <= 0:
        return []
    cycles, remainder = divmod(size, len(pool))
    resampled = list(pool) * cycles
    if remainder:
        # Random partial cycle, drawn without reordering the pool itself
        resampled.extend(random.sample(pool, remainder))
    return resampled


def balance_dataset_with_oversampling(records: List[Dict], target_ratio: float = 0.5, max_samples: Optional[int] = None, imbalance_threshold: float = 0.15) -> List[Dict]:
    """Balance Yes/No classes by oversampling the minority class when needed.

    Args:
        records: Records labelled by 'original_teacher_answer' (preferred) or
            'answer'. The list is never modified.
        target_ratio: Accepted for interface compatibility; the function always
            balances towards an even 50/50 split.
        max_samples: Optional cap on the size of the returned dataset.
        imbalance_threshold: Oversampling is applied only when the Yes share
            deviates from 0.5 by more than this amount.

    Returns:
        ``records`` itself when no change is needed, otherwise a new list. When
        balancing, each class is resampled to the same size: the majority class
        size, capped at ``max_samples // 2``. The cap never exceeds the
        majority size, so a large ``max_samples`` cannot inflate the minority
        class past the majority.
    """
    if not records:
        return records

    # Separate by answer label
    yes_records: List[Dict] = []
    no_records: List[Dict] = []
    for record in records:
        (yes_records if _is_yes_record(record) else no_records).append(record)

    print(f"Original distribution: Yes={len(yes_records)}, No={len(no_records)}")

    # Distance of the Yes share from an even split
    imbalance = abs(len(yes_records) / len(records) - 0.5)
    print(f"Imbalance ratio: {imbalance:.3f} (threshold: {imbalance_threshold:.3f})")

    # Only apply oversampling if imbalance > threshold
    if imbalance <= imbalance_threshold:
        print("Imbalance within acceptable threshold, skipping oversampling")
        if max_samples and len(records) > max_samples:
            # Random subset; sampling (not in-place shuffling) leaves the caller's list intact
            return random.sample(records, max_samples)
        return records

    # Apply oversampling for significant imbalance
    print("Applying oversampling due to significant class imbalance")

    if len(yes_records) > len(no_records):
        majority_records, minority_records = yes_records, no_records
    else:
        majority_records, minority_records = no_records, yes_records

    # Both classes are brought to the same size
    target_per_class = len(majority_records)
    if max_samples:
        target_per_class = min(target_per_class, max_samples // 2)

    majority_sampled = random.sample(majority_records, target_per_class)

    if minority_records:
        minority_resampled = _resample_to_size(minority_records, target_per_class)
    else:
        print("WARNING: No minority class records found, cannot oversample")
        minority_resampled = []

    # Combine and shuffle
    balanced_records = majority_sampled + minority_resampled
    random.shuffle(balanced_records)

    # Verify final distribution
    final_yes = sum(1 for record in balanced_records if _is_yes_record(record))
    final_no = len(balanced_records) - final_yes
    print(f"Balanced distribution: Yes={final_yes}, No={final_no}")

    return balanced_records


# Load and process teacher data for Progressive Curriculum Training
print("🔍 Loading teacher data for Progressive Curriculum Training...")

# Load the teacher data: one JSON object per line. The encoding is explicit so
# the read does not depend on the platform default; blank lines are skipped
# because json.loads() rejects an empty string.
with open(COT_PATH_QA_COT, 'r', encoding='utf-8') as f:
    teacher_data = [json.loads(line) for line in f if line.strip()]

print(f"📊 Loaded {len(teacher_data)} teacher responses")

# Load ground truth as a question -> answer mapping
ground_truth_map = {}
with open(SAMPLE_TEST_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line)
        ground_truth_map[item['question']] = item['answer']

print(f"🎯 Loaded {len(ground_truth_map)} ground truth answers")

# Initialize and run progressive curriculum pipeline
pipeline = ResponseValidationPipeline()
baseline_records, stage1_records, stage2_records, validation_stats = pipeline.process_responses_for_progressive_curriculum(teacher_data, ground_truth_map)

# Display comprehensive validation results
total_processed = validation_stats['total_processed']
# Denominator for the percentages below; falls back to 1 so an empty run
# prints 0.0% instead of raising ZeroDivisionError.
pct_base = total_processed or 1
template_fallbacks = validation_stats.get('template_fallbacks', 0)

print("\n=== PROGRESSIVE CURRICULUM VALIDATION RESULTS ===")
print(f"📊 Total processed: {total_processed}")
print(f"✅ Data kept: {validation_stats['data_kept']} ({validation_stats['data_kept']/pct_base*100:.1f}%)")
# .get(): the average may be absent when nothing was processed
print(f"🎯 Average confidence: {validation_stats.get('avg_confidence', 0.0):.3f}")

print("\n=== DATASET GENERATION RESULTS ===")
print(f"📝 Baseline dataset: {len(baseline_records)} examples")
print(f"🎯 Stage 1 dataset (Final Reasoning): {len(stage1_records)} examples")
print(f"🧠 Stage 2 dataset (Complete Q&A): {len(stage2_records)} examples")

print("\n=== FORMAT PROCESSING RESULTS ===")
print(f"🧠 Q&A-CoT format processed: {validation_stats['qa_cot_format']} ({validation_stats['qa_cot_format']/pct_base*100:.1f}%)")
print(f"📝 Traditional format processed: {validation_stats['traditional_format']} ({validation_stats['traditional_format']/pct_base*100:.1f}%)")
print(f"📋 Template fallbacks used: {template_fallbacks} ({template_fallbacks/pct_base*100:.1f}%)")

print("\n=== CONFIDENCE TIER DISTRIBUTION ===")
total_validated = validation_stats['high_confidence'] + validation_stats['medium_confidence'] + validation_stats['low_confidence']
if total_validated > 0:
    print(f"🔥 High (≥0.8): {validation_stats['high_confidence']} ({validation_stats['high_confidence']/total_validated*100:.1f}%)")
    print(f"🟡 Medium (0.5-0.8): {validation_stats['medium_confidence']} ({validation_stats['medium_confidence']/total_validated*100:.1f}%)")
    print(f"🔴 Low (<0.5): {validation_stats['low_confidence']} ({validation_stats['low_confidence']/total_validated*100:.1f}%)")

print("\n=== ANSWER VALIDATION ===")
print(f"✅ Valid teacher answers: {validation_stats['valid_teacher_answers']}")
print(f"❌ Invalid teacher answers: {validation_stats['invalid_teacher_answers']}")
# The agreement rate is only present when at least one answer was compared
if 'teacher_agreement_rate' in validation_stats:
    print(f"🤝 Teacher-ground truth agreement: {validation_stats['teacher_agreement_rate']:.1f}%")

print("\n=== ERROR ANALYSIS ===")
if validation_stats['validation_errors']:
    print("Most common validation errors:")
    for error, count in validation_stats['validation_errors'].most_common(3):
        print(f"  - {error}: {count} occurrences")

print(f"🔧 Correction attempts: {validation_stats['correction_attempts']}")
print(f"✅ Successful corrections: {validation_stats['successful_corrections']}")

# Yes/No class distribution of each dataset before any balancing is applied.
print("\n=== DISTRIBUTION ANALYSIS (PRE-BALANCING) ===")
baseline_dist, stage1_dist, stage2_dist = [
    analyze_distribution(_records)
    for _records in (baseline_records, stage1_records, stage2_records)
]

# One summary line per dataset, all in the same layout.
for _dataset_name, _dist in (
    ("Baseline", baseline_dist),
    ("Stage 1", stage1_dist),
    ("Stage 2", stage2_dist),
):
    print(
        f"{_dataset_name}: Yes={_dist['yes_count']} ({_dist['yes_percent']:.1f}%), "
        f"No={_dist['no_count']} ({_dist['no_percent']:.1f}%)"
    )

# Run every dataset through the same balancing call with identical settings.
# NOTE(review): per the original comments, oversampling is meant to kick in
# only when the class imbalance exceeds the threshold -- confirm against
# balance_dataset_with_oversampling.
print("\n=== APPLYING SMART CLASS BALANCING (15% THRESHOLD) ===")
TARGET_TRAIN_SIZE = 1500  # Conservative target for quality
IMBALANCE_THRESHOLD = 0.15  # Only apply oversampling if imbalance > 15%

# Shared keyword arguments so the three calls cannot drift apart.
_balancing_kwargs = {
    'target_ratio': 0.5,
    'max_samples': TARGET_TRAIN_SIZE,
    'imbalance_threshold': IMBALANCE_THRESHOLD,
}
balanced_baseline_records = balance_dataset_with_oversampling(baseline_records, **_balancing_kwargs)
balanced_stage1_records = balance_dataset_with_oversampling(stage1_records, **_balancing_kwargs)
balanced_stage2_records = balance_dataset_with_oversampling(stage2_records, **_balancing_kwargs)

# Yes/No class distribution of each dataset after balancing.
balanced_baseline_dist, balanced_stage1_dist, balanced_stage2_dist = [
    analyze_distribution(_balanced)
    for _balanced in (balanced_baseline_records, balanced_stage1_records, balanced_stage2_records)
]

print("\n=== DISTRIBUTION ANALYSIS (POST-BALANCING) ===")
# One summary line per dataset, all in the same layout.
for _balanced_name, _balanced_dist in (
    ("Baseline", balanced_baseline_dist),
    ("Stage 1", balanced_stage1_dist),
    ("Stage 2", balanced_stage2_dist),
):
    print(
        f"{_balanced_name}: Yes={_balanced_dist['yes_count']} ({_balanced_dist['yes_percent']:.1f}%), "
        f"No={_balanced_dist['no_count']} ({_balanced_dist['no_percent']:.1f}%)"
    )

# Persist the balanced datasets as JSON Lines: one JSON object per line.
print("\n🚀 Saving Progressive Curriculum Training datasets...")

_dataset_outputs = (
    (BASELINE_PATH, balanced_baseline_records),  # baseline: direct Q→A
    (STAGE1_PATH, balanced_stage1_records),      # Stage 1: final reasoning focus
    (STAGE2_PATH, balanced_stage2_records),      # Stage 2: complete Q&A
)
for _out_path, _out_records in _dataset_outputs:
    # Mode 'w' truncates any file left over from a previous run.
    with open(_out_path, 'w') as _out_file:
        _out_file.writelines(json.dumps(_row) + '\n' for _row in _out_records)

print(f"💾 Saved {len(balanced_baseline_records)} baseline records to: {BASELINE_PATH}")
print(f"🎯 Saved {len(balanced_stage1_records)} Stage 1 records to: {STAGE1_PATH}")
print(f"🧠 Saved {len(balanced_stage2_records)} Stage 2 records to: {STAGE2_PATH}")

# Final summary of the data-generation run: headline statistics followed by
# two pass/warn checks (average confidence and Q&A-CoT adoption).
print("\n=== PROGRESSIVE CURRICULUM TRAINING DATA GENERATION SUMMARY ===")
print("🚀 Pipeline Status: ✅ COMPLETE")
print("🎯 Stage 1 Focus: Final reasoning after 'Therefore' phrase")
print("🧠 Stage 2 Focus: Complete Q&A with final reasoning")
print(f"📊 Data Quality: {validation_stats['data_kept']} high/medium confidence examples")
print("⚖️  Class Balance: Smart oversampling (15% threshold)")
print(f"🎯 Confidence Score: {validation_stats['avg_confidence']:.3f} average")

# Both percentages below share the same denominator. Guard it so that a run
# with zero processed responses reports 0.0% instead of raising
# ZeroDivisionError.
_summary_total = validation_stats['total_processed']
_retention_rate = validation_stats['data_kept'] / _summary_total * 100 if _summary_total else 0.0
print(f"🔍 Validation Rate: {_retention_rate:.1f}% data retention")

# Calculate Q&A-CoT adoption rate (percentage of processed responses).
qa_cot_adoption = validation_stats['qa_cot_format'] / _summary_total * 100 if _summary_total else 0.0
print(f"🧠 Q&A-CoT Adoption: {qa_cot_adoption:.1f}% of training data uses interleaved Q&A format")

# Check 1: average confidence of at least 0.7 counts as success.
if validation_stats['avg_confidence'] >= 0.7:
    print("🎯 SUCCESS: Achieved high-confidence Progressive Curriculum pipeline")
else:
    print("⚠️  Moderate confidence - consider adjusting thresholds")

# Check 2: at least 80% of processed responses in Q&A-CoT format.
if qa_cot_adoption >= 80:
    print("🎯 SUCCESS: High Q&A-CoT format adoption (80%+)")
else:
    print(f"⚠️  Q&A-CoT adoption at {qa_cot_adoption:.1f}% - consider improving prompt consistency")

print("\n🏁 Ready for Progressive Curriculum Training!")

# Expose the dataset paths under the *_ENHANCED names. They are plain aliases
# of the paths written above, not new files.
BASELINE_PATH_ENHANCED, STAGE1_PATH_ENHANCED, STAGE2_PATH_ENHANCED = (
    BASELINE_PATH,
    STAGE1_PATH,
    STAGE2_PATH,
)

print("\n📂 Updated training paths:")
for _alias_name, _alias_value in (
    ("BASELINE_PATH_ENHANCED", BASELINE_PATH_ENHANCED),
    ("STAGE1_PATH_ENHANCED", STAGE1_PATH_ENHANCED),
    ("STAGE2_PATH_ENHANCED", STAGE2_PATH_ENHANCED),
):
    print(f"- {_alias_name} = '{_alias_value}'")

🔍 Loading teacher data for Progressive Curriculum Training...
📊 Loaded 200 teacher responses
🎯 Loaded 200 ground truth answers
Processed 100 responses...
Processed 200 responses...

=== PROGRESSIVE CURRICULUM VALIDATION RESULTS ===
📊 Total processed: 200
✅ Data kept: 200 (100.0%)
🎯 Average confidence: 1.000

=== DATASET GENERATION RESULTS ===
📝 Baseline dataset: 200 examples
🎯 Stage 1 dataset (Final Reasoning): 200 examples
🧠 Stage 2 dataset (Complete Q&A): 200 examples

=== FORMAT PROCESSING RESULTS ===
🧠 Q&A-CoT format processed: 200 (100.0%)
📝 Traditional format processed: 0 (0.0%)
📋 Template fallbacks used: 0 (0.0%)

=== CONFIDENCE TIER DISTRIBUTION ===
🔥 High (≥0.8): 200 (100.0%)
🟡 Medium (0.5-0.8): 0 (0.0%)
🔴 Low (<0.5): 0 (0.0%)

=== ANSWER VALIDATION ===
✅ Valid teacher answers: 200
❌ Invalid teacher answers: 0
🤝 Teacher-ground truth agreement: 91.5%

=== ERROR ANALYSIS ===
🔧 Correction attempts: 0
✅ Successful corrections: 0

=== DISTRIBUTION ANALYSIS (PRE-BALANCING) ===
Basel