In [1]:
!pip install -q datasets transformers openai bitsandbytes accelerate python-dotenv huggingface_hub huggingface_hub[hf_xet]


[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


## Configuration from environment variables

In [6]:
import os
from dotenv import load_dotenv

# Populate os.environ from a .env file, if one is present, so that the
# os.getenv() lookups below can pick up project-specific overrides.
load_dotenv()

# Spellings accepted as "true" for boolean environment flags (compared after
# stripping whitespace and lower-casing). Anything else, including "", is false.
_TRUTHY_VALUES = frozenset({'true', '1', 't', 'yes', 'y', 'on'})


def env_bool(name: str, default: str) -> bool:
    """Read the environment variable ``name`` as a boolean flag.

    Args:
        name: Environment variable to read.
        default: String value used when ``name`` is unset (e.g. ``'True'``).

    Returns:
        True if the stripped, lower-cased value is one of ``_TRUTHY_VALUES``.
    """
    return os.getenv(name, default).strip().lower() in _TRUTHY_VALUES


# Dataset parameters
DATASET_NAME = os.getenv('DATASET_NAME', 'voidful/StrategyQA')
TRAIN_SAMPLES = int(os.getenv('TRAIN_SAMPLES', '100'))
RANDOM_SEED = int(os.getenv('RANDOM_SEED', '42'))
USE_FULL_DATASET = env_bool('USE_FULL_DATASET', 'False')


# Model parameters
MODEL_NAME = os.getenv('MODEL_NAME', 'microsoft/phi-2')
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', '35'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8'))
USE_4BIT = env_bool('USE_4BIT', 'True')
MAX_SEQ_LENGTH = int(os.getenv('MAX_SEQ_LENGTH', '512'))
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN', '')

# Generation parameters
DO_SAMPLE = env_bool('DO_SAMPLE', 'False')
TEMPERATURE = float(os.getenv('TEMPERATURE', '0.7'))

# GPT-4 parameters
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
GPT4_MODEL = os.getenv('GPT4_MODEL', 'gpt-4')
GPT4_MAX_TOKENS = int(os.getenv('GPT4_MAX_TOKENS', '150'))
GPT4_TEMPERATURE = float(os.getenv('GPT4_TEMPERATURE', '0.3'))
DRY_RUN = env_bool('DRY_RUN', 'True')

# Student Draft Generation
STUDENT_MAX_TOKENS = int(os.getenv('STUDENT_MAX_TOKENS', '200'))
STUDENT_TEMPERATURE = float(os.getenv('STUDENT_TEMPERATURE', '0.7'))
STUDENT_BATCH_SIZE = int(os.getenv('STUDENT_BATCH_SIZE', '8'))

# Enhanced Evaluation Generation
EVAL_MAX_TOKENS = int(os.getenv('EVAL_MAX_TOKENS', '256'))
EVAL_TEMPERATURE = float(os.getenv('EVAL_TEMPERATURE', '0.7'))
EVAL_BATCH_SIZE = int(os.getenv('EVAL_BATCH_SIZE', '4'))

# Quick Evaluation
QUICK_EVAL_MAX_TOKENS = int(os.getenv('QUICK_EVAL_MAX_TOKENS', '5'))

# Training Configuration
# Phase A Training
PHASE_A_EPOCHS = int(os.getenv('PHASE_A_EPOCHS', '3'))
PHASE_A_BATCH_SIZE = int(os.getenv('PHASE_A_BATCH_SIZE', '1'))
PHASE_A_LEARNING_RATE = float(os.getenv('PHASE_A_LEARNING_RATE', '1e-4'))
PHASE_A_WARMUP_RATIO = float(os.getenv('PHASE_A_WARMUP_RATIO', '0.1'))
PHASE_A_WEIGHT_DECAY = float(os.getenv('PHASE_A_WEIGHT_DECAY', '0.01'))

# Phase B Training
PHASE_B_EPOCHS = int(os.getenv('PHASE_B_EPOCHS', '3'))
PHASE_B_BATCH_SIZE = int(os.getenv('PHASE_B_BATCH_SIZE', '4'))
PHASE_B_LEARNING_RATE = float(os.getenv('PHASE_B_LEARNING_RATE', '5e-5'))
PHASE_B_WARMUP_RATIO = float(os.getenv('PHASE_B_WARMUP_RATIO', '0.1'))
PHASE_B_WEIGHT_DECAY = float(os.getenv('PHASE_B_WEIGHT_DECAY', '0.01'))

# Progressive Curriculum Training
CURRICULUM_STAGE1_EPOCHS = int(os.getenv('CURRICULUM_STAGE1_EPOCHS', '1'))
CURRICULUM_STAGE1_LEARNING_RATE = float(os.getenv('CURRICULUM_STAGE1_LEARNING_RATE', '5e-5'))
CURRICULUM_STAGE1_WARMUP_RATIO = float(os.getenv('CURRICULUM_STAGE1_WARMUP_RATIO', '0.1'))
CURRICULUM_STAGE1_WEIGHT_DECAY = float(os.getenv('CURRICULUM_STAGE1_WEIGHT_DECAY', '0.01'))

CURRICULUM_STAGE2_EPOCHS = int(os.getenv('CURRICULUM_STAGE2_EPOCHS', '2'))
CURRICULUM_STAGE2_LEARNING_RATE = float(os.getenv('CURRICULUM_STAGE2_LEARNING_RATE', '3e-5'))
CURRICULUM_STAGE2_WARMUP_RATIO = float(os.getenv('CURRICULUM_STAGE2_WARMUP_RATIO', '0.1'))
CURRICULUM_STAGE2_WEIGHT_DECAY = float(os.getenv('CURRICULUM_STAGE2_WEIGHT_DECAY', '0.01'))

# Validation Configuration
HIGH_CONFIDENCE_THRESHOLD = float(os.getenv('HIGH_CONFIDENCE_THRESHOLD', '0.8'))
MEDIUM_CONFIDENCE_THRESHOLD = float(os.getenv('MEDIUM_CONFIDENCE_THRESHOLD', '0.5'))
LOW_CONFIDENCE_THRESHOLD = float(os.getenv('LOW_CONFIDENCE_THRESHOLD', '0.3'))
VALIDATION_ACCEPTANCE_THRESHOLD = float(os.getenv('VALIDATION_ACCEPTANCE_THRESHOLD', '0.3'))

# Quality Thresholds
VALID_QUALITY_THRESHOLD = float(os.getenv('VALID_QUALITY_THRESHOLD', '0.5'))
CORRECTED_QUALITY_THRESHOLD = float(os.getenv('CORRECTED_QUALITY_THRESHOLD', '0.3'))

# Token Emphasis Configuration
EMPHASIS_MULTIPLIER = float(os.getenv('EMPHASIS_MULTIPLIER', '2.5'))
ADAPTIVE_EMPHASIS = env_bool('ADAPTIVE_EMPHASIS', 'True')

# Memory and Performance
GRADIENT_CHECKPOINTING = env_bool('GRADIENT_CHECKPOINTING', 'True')
FP16 = env_bool('FP16', 'False')
BF16 = env_bool('BF16', 'True')
DATALOADER_NUM_WORKERS = int(os.getenv('DATALOADER_NUM_WORKERS', '0'))
DATALOADER_PERSISTENT_WORKERS = env_bool('DATALOADER_PERSISTENT_WORKERS', 'False')
SKIP_MEMORY_METRICS = env_bool('SKIP_MEMORY_METRICS', 'True')

# Logging and Monitoring
LOGGING_STEPS = int(os.getenv('LOGGING_STEPS', '10'))
SAVE_STRATEGY = os.getenv('SAVE_STRATEGY', 'epoch')
REPORT_TO = os.getenv('REPORT_TO', 'none')
LOAD_BEST_MODEL_AT_END = env_bool('LOAD_BEST_MODEL_AT_END', 'False')

# Evaluation Configuration
EVAL_SIZE = int(os.getenv('EVAL_SIZE', '100'))
ENHANCED_EVAL_BATCH_SIZE = int(os.getenv('ENHANCED_EVAL_BATCH_SIZE', '4'))

# File paths
# NOTE(review): data paths are resolved relative to the PARENT of the current
# working directory, i.e. this assumes the notebook runs from a subdirectory
# (e.g. notebooks/) of the project root -- confirm for your layout.
parent_dir = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(parent_dir, os.getenv("DATA_DIR", "data"))
RAW_DIR = os.path.join(DATA_DIR, 'raw')
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
STUDENT_DIR = os.path.join(DATA_DIR, 'student')
TEACHER_DIR = os.path.join(DATA_DIR, 'teacher')
SAMPLE_TRAIN_PATH = os.path.join(TRAIN_DIR, 'sample_train.jsonl')
STUDENT_DRAFTS_PATH = os.path.join(STUDENT_DIR, 'student_drafts.jsonl')
CLEANED_STUDENT_DRAFTS_PATH = os.path.join(STUDENT_DIR, 'cleaned_student_drafts.jsonl')
TEACHER_OUTPUTS_PATH = os.path.join(TEACHER_DIR, 'teacher_outputs.jsonl')
BASELINE_PATH = os.path.join(TRAIN_DIR, 'train_baseline.jsonl')
COT_PATH = os.path.join(TRAIN_DIR, 'train_cot.jsonl')
COT_PATH_QA_COT = os.path.join(TEACHER_DIR, 'teacher_outputs_qa_cot.jsonl')
SAMPLE_TEST_PATH = SAMPLE_TRAIN_PATH  # Alias for consistency with Build Training Corpora cell

# Print configuration
print("=== Configuration ===")
print(f"Dataset: {DATASET_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Batch size: {BATCH_SIZE}")
print(f"4-bit quantization: {USE_4BIT}")
print(f"GPT-4 dry run: {DRY_RUN}")
print("=="*10)

=== Configuration ===
Dataset: voidful/StrategyQA
Model: microsoft/Phi-3.5-mini-instruct
Batch size: 8
4-bit quantization: True
GPT-4 dry run: False


## Generate Teacher Responses

We now call GPT‑4 to obtain chain‑of‑thought (CoT) reasoning and final yes/no answers for each question/draft pair.  The prompt format follows the plan:

```
Q: <original yes/no question>
Student draft: <answer + clarifying questions>
Teacher: Please think step-by-step and provide your thought process and final Yes/No answer.
```

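To run the actual API calls, you must provide a valid OpenAI API key via the `OPENAI_API_KEY` environment variable. If `DRY_RUN` is set to true (its default in the configuration cell), dummy responses will be generated for testing purposes. Note that the Q&A-CoT prompts defined below supersede the draft format above: the teacher reasons from the question (plus optional facts) using interleaved `Question N:` / `Answer N:` steps, a `Therefore, …` synthesis, and a final `The answer is **Yes/No**.` line.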


In [7]:
import json
import logging
import re
import textwrap
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple

import openai
from openai import OpenAI

class ValidationStatus(Enum):
    """Verdict categories attached to a ValidationResult.

    ValidationResult.is_valid() treats VALID and CORRECTED as acceptable.
    """
    VALID = "valid"          # TeacherResponseValidator: parsed, answer found, high overall quality
    INVALID = "invalid"      # rejected; TeacherResponseValidator sets cleaned_text to None
    CORRECTED = "corrected"  # TeacherResponseValidator: accepted with mid-range overall quality
    FAILED = "failed"        # NOTE(review): not produced in this cell; presumably reserved for hard failures -- confirm

@dataclass
class ValidationResult:
    """Outcome of validating a single model response.

    Attributes are documented inline; ``is_valid`` reports whether the verdict
    is one the pipeline should accept.
    """
    status: ValidationStatus          # the verdict
    original_text: str                # raw response exactly as received
    cleaned_text: Optional[str]       # normalized response, or None when the validator produced none
    confidence_score: float           # 0.0 - 1.0
    error_messages: List[str]         # human-readable problems noticed during validation
    metadata: Dict[str, Any]          # validator-specific details (format, metrics, answer, ...)

    def is_valid(self) -> bool:
        """Return True when the verdict is VALID or CORRECTED."""
        accepted_statuses = (ValidationStatus.VALID, ValidationStatus.CORRECTED)
        return self.status in accepted_statuses


class BaseResponseValidator(ABC):
    """Abstract base class for response validators.

    Subclasses must implement ``validate``. Every instance gets a logger named
    after its concrete class.
    """

    def __init__(self):
        logger_name = self.__class__.__name__
        self.logger = logging.getLogger(logger_name)

    @abstractmethod
    def validate(self, response_text: str) -> ValidationResult:
        """Validate ``response_text`` and return a ValidationResult."""

    def _calculate_base_confidence(self, parsing_success: bool, content_quality: float) -> float:
        """Blend parsing success (60% weight) and content quality (40% weight).

        A failed parse still contributes a floor of 0.3 to the parsing component.
        """
        if parsing_success:
            parse_component = 1.0
        else:
            parse_component = 0.3
        return (parse_component * 0.6) + (content_quality * 0.4)


@dataclass
class TeacherResponse:
    """Structured container for a parsed teacher response.

    NOTE(review): not constructed anywhere in this cell. The first three fields
    mirror the legacy three-section headers parsed by TeacherResponseValidator;
    confirm how callers populate them for Q&A-CoT responses.
    """
    teaching_analysis: str        # text under "## Teaching Analysis" (legacy format)
    step_by_step_reasoning: str   # text under "## Step-by-Step Reasoning" (legacy format)
    final_assessment: str         # text under "## Final Assessment" (legacy format)
    extracted_answer: str  # YES/NO
    confidence_score: float       # presumably 0.0 - 1.0 like ValidationResult.confidence_score -- confirm
    quality_metrics: Dict[str, float]  # presumably the validator's quality_metrics dict -- confirm

class TeacherResponseValidator(BaseResponseValidator):
    """Validate GPT-4 teacher responses in the Q&A-CoT or the legacy three-section format.

    ``validate`` works in layers:

    1. Detect the format and parse it: ``Question N:`` / ``Answer N:`` /
       ``Therefore`` for Q&A-CoT, or the ``## Teaching Analysis`` /
       ``## Step-by-Step Reasoning`` / ``## Final Assessment`` headers for the
       legacy format. If parsing fails, an automatic structural correction is
       tried. The corrected text is re-detected and re-parsed.
    2. Score structure and content with heuristics and extract the YES/NO answer.
    3. Compute a confidence score and map the result to a ValidationStatus.

    Args:
        valid_threshold: Minimum ``overall`` quality for ``ValidationStatus.VALID``.
        corrected_threshold: Minimum ``overall`` quality for ``ValidationStatus.CORRECTED``.
    """

    # Keyword sets for the content heuristics, matched against whole words.
    # Substring matching made both checks pass for almost any text
    # ("is" hit "this", "so" hit "reasoning").
    QUESTION_WORDS = frozenset({'what', 'when', 'where', 'who', 'how', 'why', 'did', 'does', 'is', 'are'})
    CONNECTOR_WORDS = frozenset({'therefore', 'since', 'because', 'so', 'thus'})

    def __init__(self, valid_threshold: float = 0.7, corrected_threshold: float = 0.5):
        super().__init__()
        self.valid_threshold = valid_threshold
        self.corrected_threshold = corrected_threshold
        # Headers of the legacy three-section format
        self.section_headers = [
            r'## Teaching Analysis',
            r'## Step-by-Step Reasoning',
            r'## Final Assessment'
        ]
        # Final-answer patterns, tried in order by _extract_final_answer (first match wins)
        self.final_answer_patterns = [
            r'The answer is \*\*(Yes|No)\*\*',  # Primary Q&A-CoT pattern
            r'## Final Assessment.*?\*\*(YES|NO)\*\*',  # Legacy pattern
            r'Based on.*?the answer is.*?\*\*(YES|NO)\*\*',
            r'\*\*(YES|NO)\*\*',
            r'\b(YES|NO)\b(?=\s*[.!]?\s*$)'
        ]
        # Q&A-CoT pair patterns (group 1 = pair number, group 2 = text), used by _parse_qa_cot_structure
        self.qa_patterns = [
            r'Question\s+(\d+)\s*:?\s*(.+?)(?=Answer\s+\1|$)',
            r'Answer\s+(\d+)\s*:?\s*(.+?)(?=Question\s+\d+|Therefore|The answer is|$)'
        ]

    def validate(self, response_text: str) -> ValidationResult:
        """Validate a teacher response in either the Q&A-CoT or the legacy format.

        Returns:
            A ValidationResult. Its ``metadata`` always has ``format_detected``
            ('qa_cot', 'three_section' or 'unknown'). When parsing succeeded it
            also has ``quality_metrics`` and ``extracted_answer``, plus
            ``structure_corrected=True`` if the automatic correction was needed.
            ``cleaned_text`` is None for INVALID results.
        """
        errors: List[str] = []
        metadata: Dict[str, Any] = {}

        # Detect format type
        has_qa_format = self._detect_qa_format(response_text)
        has_three_section = self._detect_three_section_format(response_text)
        metadata['format_detected'] = self._format_label(has_qa_format, has_three_section)

        # Layer 1: Structural parsing based on the detected format
        working_text = response_text
        parsed_data = self._parse_detected_format(response_text, has_qa_format, has_three_section)
        if parsed_data is None:
            if has_qa_format or has_three_section:
                errors.append("Failed to parse response structure")
            else:
                errors.append("Unknown format: neither Q&A-CoT nor three-section structure detected")

            # Try correction for both failure modes. The correction mostly adds
            # missing colons/headers. Those omissions are exactly what makes an
            # otherwise well-formed response undetectable, so the corrected text
            # is re-detected instead of being parsed with the stale flags.
            corrected = self._attempt_structure_correction(response_text)
            if corrected:
                corrected_qa = self._detect_qa_format(corrected)
                corrected_three = self._detect_three_section_format(corrected)
                corrected_data = self._parse_detected_format(corrected, corrected_qa, corrected_three)
                if corrected_data is not None:
                    parsed_data = corrected_data
                    working_text = corrected
                    has_qa_format, has_three_section = corrected_qa, corrected_three
                    metadata['format_detected'] = self._format_label(corrected_qa, corrected_three)
                    metadata['structure_corrected'] = True
                    errors.append("Structure corrected automatically")

        parsing_success = parsed_data is not None

        # Layer 2: Content validation and answer extraction
        final_answer: Optional[str] = None
        if parsing_success:
            if has_qa_format:
                quality_metrics = self._assess_qa_cot_quality(parsed_data, working_text)
            else:
                quality_metrics = self._assess_teacher_quality(parsed_data)
            metadata['quality_metrics'] = quality_metrics

            final_answer = self._extract_final_answer(working_text, parsed_data)
            metadata['extracted_answer'] = final_answer

            if final_answer in ('YES', 'NO'):
                quality_metrics['answer_extraction'] = 1.0
            else:
                errors.append("Failed to extract valid final answer")
                quality_metrics['answer_extraction'] = 0.0
        else:
            quality_metrics = {
                'overall': 0.0,
                'structural_quality': 0.0,
                'content_quality': 0.0,
                'answer_extraction': 0.0
            }

        # Layer 3: Confidence scoring
        confidence_score = self._calculate_teacher_confidence(
            parsing_success, quality_metrics, len(errors), has_qa_format
        )

        # Determine final status
        overall = quality_metrics['overall']
        if parsing_success and final_answer and overall >= self.valid_threshold:
            status = ValidationStatus.VALID
        elif parsing_success and final_answer and overall >= self.corrected_threshold:
            status = ValidationStatus.CORRECTED
        else:
            status = ValidationStatus.INVALID

        if status is ValidationStatus.INVALID:
            cleaned_text = None
        else:
            cleaned_text = self._format_response_output(parsed_data, final_answer, has_qa_format)

        return ValidationResult(
            status=status,
            original_text=response_text,
            cleaned_text=cleaned_text,
            confidence_score=confidence_score,
            error_messages=errors,
            metadata=metadata
        )

    @staticmethod
    def _format_label(is_qa: bool, is_three_section: bool) -> str:
        """Name of the detected format; Q&A-CoT takes precedence."""
        if is_qa:
            return 'qa_cot'
        return 'three_section' if is_three_section else 'unknown'

    def _parse_detected_format(self, text: str, is_qa: bool, is_three_section: bool) -> Optional[Dict[str, Any]]:
        """Parse ``text`` with the parser for the detected format (Q&A-CoT first), or return None."""
        if is_qa:
            return self._parse_qa_cot_structure(text)
        if is_three_section:
            return self._parse_three_section_structure(text)
        return None

    @staticmethod
    def _word_set(text: str) -> set:
        """Lower-cased alphabetic tokens of ``text``, for whole-word keyword checks."""
        return set(re.findall(r"[a-z]+", text.lower()))

    def _detect_qa_format(self, text: str) -> bool:
        """Detect if text uses Q&A interleaved format."""
        return bool(re.search(r'(?:Question\s+\d+\s*:.*?Answer\s+\d+\s*:|Answer\s*:.*?Questions?\s*:)', text, re.DOTALL | re.IGNORECASE))

    def _detect_three_section_format(self, text: str) -> bool:
        """Detect if text uses three-section format (at least 2 of the 3 headers present)."""
        sections_found = sum(1 for header in self.section_headers
                             if re.search(header, text, re.IGNORECASE))
        return sections_found >= 2

    def _parse_qa_cot_structure(self, text: str) -> Optional[Dict[str, Any]]:
        """Parse Q&A-CoT format: Question 1: ... Answer 1: ... Therefore ...

        Returns None when no ``Question N`` is found. Each question is paired
        with the first ``Answer`` carrying the same number ('' if none).
        """
        question_pattern, answer_pattern = self.qa_patterns
        question_matches = re.findall(question_pattern, text, re.DOTALL | re.IGNORECASE)
        answer_matches = re.findall(answer_pattern, text, re.DOTALL | re.IGNORECASE)

        # First answer wins for each number
        answers_by_number: Dict[str, str] = {}
        for a_num, ans in answer_matches:
            answers_by_number.setdefault(a_num, ans.strip())

        qa_pairs = [
            {
                'number': q_num,
                'question': question.strip(),
                'answer': answers_by_number.get(q_num, '')
            }
            for q_num, question in question_matches
        ]

        # Extract therefore/conclusion
        therefore_match = re.search(r'Therefore,?\s*(.+?)(?=The answer is|$)', text, re.DOTALL | re.IGNORECASE)
        therefore_text = therefore_match.group(1).strip() if therefore_match else ""

        if qa_pairs:  # At least one Q&A pair found
            return {
                'qa_pairs': qa_pairs,
                'therefore_reasoning': therefore_text,
                'format_type': 'qa_cot',
                'pair_count': len(qa_pairs)
            }

        return None

    def _parse_three_section_structure(self, text: str) -> Optional[Dict[str, str]]:
        """Parse the three-section teacher response format (requires at least 2 sections)."""
        sections = {}

        # Extract Teaching Analysis
        teaching_match = re.search(
            r'## Teaching Analysis\s*\n(.*?)(?=\n## |$)',
            text, re.DOTALL | re.IGNORECASE
        )
        if teaching_match:
            sections['teaching_analysis'] = teaching_match.group(1).strip()

        # Extract Step-by-Step Reasoning
        reasoning_match = re.search(
            r'## Step-by-Step Reasoning\s*\n(.*?)(?=\n## |$)',
            text, re.DOTALL | re.IGNORECASE
        )
        if reasoning_match:
            sections['step_by_step_reasoning'] = reasoning_match.group(1).strip()

        # Extract Final Assessment
        assessment_match = re.search(
            r'## Final Assessment\s*\n(.*?)(?=\n## |$)',
            text, re.DOTALL | re.IGNORECASE
        )
        if assessment_match:
            sections['final_assessment'] = assessment_match.group(1).strip()

        # Require at least 2 of 3 sections
        if len(sections) >= 2:
            sections['format_type'] = 'three_section'
            return sections

        return None

    def _assess_qa_cot_quality(self, parsed_data: Dict[str, Any], full_text: str) -> Dict[str, float]:
        """Assess quality of Q&A-CoT format responses.

        ``full_text`` is accepted for interface compatibility but is not used.
        Overall = 0.4 * structural + 0.6 * content.
        """
        qa_pairs = parsed_data.get('qa_pairs', [])
        therefore_text = parsed_data.get('therefore_reasoning', '')

        # Structural quality - based on Q&A pairs
        if len(qa_pairs) == 0:
            structural_quality = 0.0
        elif len(qa_pairs) == 1:
            structural_quality = 0.7  # Single question is okay
        elif len(qa_pairs) == 2:
            structural_quality = 1.0  # Ideal: 2 questions
        else:
            structural_quality = 0.9  # More than 2 is still good but not ideal

        # Content quality assessment
        content_scores = []

        # Question/answer pair quality
        if qa_pairs:
            question_scores = []
            for pair in qa_pairs:
                question = pair.get('question', '')
                answer = pair.get('answer', '')

                # Question should be specific and end with ?
                q_score = 0.0
                if question:
                    q_score += 0.3  # Has question
                    if '?' in question:
                        q_score += 0.3  # Proper question format
                    if len(question.split()) >= 4:
                        q_score += 0.2  # Reasonable length
                    if self._word_set(question) & self.QUESTION_WORDS:
                        q_score += 0.2  # Contains question words (whole words only)

                # Answer should be factual and relevant
                a_score = 0.0
                if answer:
                    a_score += 0.4  # Has answer
                    if len(answer.split()) >= 3:
                        a_score += 0.3  # Substantial answer
                    if len(answer.split()) <= 50:
                        a_score += 0.3  # Not too verbose

                pair_score = (q_score + a_score) / 2
                question_scores.append(pair_score)

            content_scores.append(sum(question_scores) / len(question_scores))
        else:
            content_scores.append(0.0)

        # Therefore reasoning quality
        if therefore_text:
            therefore_score = 0.0
            therefore_score += 0.4  # Has therefore reasoning
            if len(therefore_text.split()) >= 5:
                therefore_score += 0.3  # Reasonable length
            if self._word_set(therefore_text) & self.CONNECTOR_WORDS:
                therefore_score += 0.3  # Contains logical connectors (whole words only)
            content_scores.append(therefore_score)
        else:
            content_scores.append(0.0)

        content_quality = sum(content_scores) / len(content_scores) if content_scores else 0.0

        # Overall quality
        overall_quality = (structural_quality * 0.4) + (content_quality * 0.6)

        return {
            'overall': overall_quality,
            'structural_quality': structural_quality,
            'content_quality': content_quality,
            'qa_pair_count': len(qa_pairs),
            'has_therefore': bool(therefore_text)
        }

    def _extract_final_answer(self, text: str, parsed_data: Dict[str, Any]) -> Optional[str]:
        """Extract the final answer as 'YES'/'NO', or None if none can be found.

        Tries ``final_answer_patterns`` in order (case-insensitive). For the
        three-section format it then falls back to a yes/no word vote inside
        the Final Assessment section.
        """
        for pattern in self.final_answer_patterns:
            match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
            if match:
                answer = match.group(1).upper()
                if answer in ('YES', 'NO'):
                    return answer

        # Fallback for three-section format
        if parsed_data.get('format_type') == 'three_section' and 'final_assessment' in parsed_data:
            final_section = parsed_data['final_assessment']
            yes_count = len(re.findall(r'\byes\b', final_section, re.IGNORECASE))
            no_count = len(re.findall(r'\bno\b', final_section, re.IGNORECASE))

            if yes_count > no_count:
                return 'YES'
            elif no_count > yes_count:
                return 'NO'

        return None

    def _assess_teacher_quality(self, parsed_sections: Dict[str, str]) -> Dict[str, float]:
        """Assess quality of three-section teacher response (length/structure heuristics)."""
        # Structural quality (completeness)
        section_completeness = len([k for k in parsed_sections.keys() if k != 'format_type']) / 3

        # Content quality assessment
        content_scores = []

        for section_name, content in parsed_sections.items():
            if section_name == 'format_type':
                continue

            if not content.strip():
                content_scores.append(0.0)
                continue

            # Basic content quality heuristics
            word_count = len(content.split())

            if section_name == 'teaching_analysis':
                ideal_length = 50
                length_score = min(word_count / ideal_length, 1.0)
            elif section_name == 'step_by_step_reasoning':
                ideal_length = 100
                length_score = min(word_count / ideal_length, 1.0)
                # Bonus for structure
                has_structure = bool(re.search(r'\d+\.|\n-|Step \d+', content))
                length_score += 0.2 if has_structure else 0.0
                length_score = min(length_score, 1.0)
            elif section_name == 'final_assessment':
                ideal_length = 30
                length_score = min(word_count / ideal_length, 1.0)
                # Bonus for conclusion language
                has_conclusion = bool(re.search(r'based on|therefore|in conclusion', content, re.IGNORECASE))
                length_score += 0.2 if has_conclusion else 0.0
                length_score = min(length_score, 1.0)
            else:
                length_score = min(word_count / 50, 1.0)

            content_scores.append(length_score)

        content_quality = sum(content_scores) / len(content_scores) if content_scores else 0.0
        overall_quality = (section_completeness * 0.4) + (content_quality * 0.6)

        return {
            'overall': overall_quality,
            'structural_quality': section_completeness,
            'content_quality': content_quality,
            'section_count': len([k for k in parsed_sections.keys() if k != 'format_type'])
        }

    def _attempt_structure_correction(self, text: str) -> Optional[str]:
        """Attempt to correct common structural issues.

        Adds missing colons after ``Question N`` / ``Answer N`` and inserts
        missing legacy section headers. Returns the corrected text, or None
        if nothing was changed.
        """
        corrected = text

        # Q&A-CoT corrections
        if 'question' in corrected.lower() and 'answer' in corrected.lower():
            # Fix missing colons in Q&A format
            corrected = re.sub(r'\b(Question\s+\d+)\b(?!\s*:)', r'\1:', corrected, flags=re.IGNORECASE)
            corrected = re.sub(r'\b(Answer\s+\d+)\b(?!\s*:)', r'\1:', corrected, flags=re.IGNORECASE)

        # Three-section corrections
        if '## Teaching Analysis' not in corrected and 'Teaching' in corrected:
            corrected = re.sub(r'^(Teaching.*?):', r'## Teaching Analysis\n\1:', corrected, flags=re.MULTILINE)

        if '## Step-by-Step' not in corrected and ('Step' in corrected or 'reasoning' in corrected.lower()):
            corrected = re.sub(r'^(.*reasoning.*?):', r'## Step-by-Step Reasoning\n\1:', corrected, flags=re.MULTILINE | re.IGNORECASE)

        if '## Final Assessment' not in corrected and ('final' in corrected.lower() or 'assessment' in corrected.lower()):
            corrected = re.sub(r'^(.*(?:final|assessment).*?):', r'## Final Assessment\n\1:', corrected, flags=re.MULTILINE | re.IGNORECASE)

        return corrected if corrected != text else None

    def _calculate_teacher_confidence(self, parsing_success: bool, quality_metrics: Dict[str, float], error_count: int, is_qa_format: bool) -> float:
        """Calculate a confidence score in [0, 1] for teacher responses."""
        base_confidence = self._calculate_base_confidence(parsing_success, quality_metrics['overall'])

        # Format-specific adjustments
        format_bonus = 0.1 if is_qa_format else 0.05  # Slight preference for Q&A format
        structure_bonus = 0.1 if quality_metrics['structural_quality'] >= 0.8 else 0.0
        answer_bonus = 0.1 if quality_metrics.get('answer_extraction', 0) >= 1.0 else 0.0
        error_penalty = min(error_count * 0.15, 0.3)

        return max(0.0, min(1.0, base_confidence + format_bonus + structure_bonus + answer_bonus - error_penalty))

    def _format_response_output(self, parsed_data: Dict[str, Any], final_answer: str, is_qa_format: bool) -> str:
        """Re-render the parsed response in canonical form for the detected format."""
        if is_qa_format:
            # Format Q&A-CoT output
            output_parts = []
            qa_pairs = parsed_data.get('qa_pairs', [])

            for pair in qa_pairs:
                output_parts.append(f"Question {pair['number']}: {pair['question']}")
                output_parts.append(f"Answer {pair['number']}: {pair['answer']}")

            therefore_text = parsed_data.get('therefore_reasoning', '')
            if therefore_text:
                output_parts.append(f"Therefore, {therefore_text}")

            output_parts.append(f"The answer is **{final_answer}**.")

            return '\n'.join(output_parts)

        else:
            # Format three-section output
            output_parts = []

            if 'teaching_analysis' in parsed_data:
                output_parts.append(f"## Teaching Analysis\n{parsed_data['teaching_analysis']}")

            if 'step_by_step_reasoning' in parsed_data:
                output_parts.append(f"## Step-by-Step Reasoning\n{parsed_data['step_by_step_reasoning']}")

            if 'final_assessment' in parsed_data:
                output_parts.append(f"## Final Assessment\n{parsed_data['final_assessment']}")
            else:
                output_parts.append(f"## Final Assessment\nBased on this analysis, the answer is: **{final_answer}**")

            return '\n\n'.join(output_parts)


In [9]:
# ============================================================================
# ENHANCED GPT-4 TEACHER PROMPTS FOR Q&A-COT GENERATION (NEW IMPLEMENTATION)
# ============================================================================

def create_qa_cot_teacher_prompt() -> str:
    """Build the GPT-4 system prompt for Q&A-CoT teacher generation.

    The literal is indented to match the code and passed through
    ``textwrap.dedent``, so the model receives flush-left lines. In particular,
    the format template the teacher is told to copy exactly no longer carries
    the source-code indentation.

    Returns:
        The system prompt string (no leading/trailing newline).
    """
    system_prompt = textwrap.dedent("""\
        You are an expert teacher helping a 7B student AI model learn to reason through complex yes/no questions using self-questioning. Your role is to demonstrate the interleaved Q&A reasoning format that the student should learn to mimic.

        CRITICAL INSTRUCTION - INTERLEAVED Q&A FORMAT:
        You must generate responses that teach the student to ask itself clarifying sub-questions and answer them step-by-step. Use this EXACT format:

        Question 1: [Ask a relevant clarifying question the student should consider]
        Answer 1: [Provide a factual, concise answer to that question]
        Question 2: [Optional second question if needed for multi-step reasoning]
        Answer 2: [Answer to the second question]
        Therefore, [synthesize the answers into final reasoning]
        The answer is **[YES/NO]**.

        REQUIREMENTS:
        1. **Self-Questioning Strategy**: Generate 1-2 clarifying questions that break down implicit multi-hop reasoning
        2. **Factual Answers**: Provide correct, concise answers (1-2 sentences) that the student can learn from
        3. **Logical Synthesis**: Use "Therefore" to connect sub-answers to the final conclusion
        4. **Format Consistency**: Always use "Question N:", "Answer N:", "Therefore", and "The answer is **[YES/NO]**"
        5. **Educational Value**: Make implicit knowledge explicit (e.g., "laptops didn't exist in ancient times")

        EXAMPLES OF GOOD Q&A-COT FORMAT:

        Example 1:
        Question 1: Did laptops exist during Aristotle's time?
        Answer 1: No, laptops did not exist in Aristotle's era (they were invented much later).
        Therefore, since laptops didn't exist then, Aristotle could not have used one.
        The answer is **No**.

        Example 2:
        Question 1: What type of animal is a seahorse (fish or mammal)?
        Answer 1: A seahorse is a type of fish (not a mammal).
        Question 2: Do fish produce milk for their young?
        Answer 2: No, fish do not produce milk; only mammals do.
        Therefore, seahorses lack mammalian characteristics (like nursing their young), so a seahorse is not a mammal.
        The answer is **No**.

        STRATEGY GUIDELINES:
        - For single-hop questions: Use 1 question that addresses the key uncertainty
        - For multi-hop questions: Use 2 questions that cover the reasoning chain
        - Questions should be specific and factual (not abstract or philosophical)
        - Answers should provide knowledge the student needs to reach the conclusion
        - Keep total response under 150 words for 7B model capacity

        Your goal is to create training data where the student learns to break down complex reasoning into explicit Q&A steps.""")

    return system_prompt

def create_qa_cot_user_prompt(question: str, facts_list: Optional[List[str]] = None) -> str:
    """Build the user prompt for Q&A-CoT teacher generation.

    Args:
        question: The yes/no question to reason about.
        facts_list: Optional supporting facts, appended as a hint. Blank
            entries are ignored. Each fact is terminated with a period when it
            lacks sentence punctuation (``parse_facts_field`` strips periods),
            so the facts read as separate sentences rather than one run-on.

    Returns:
        The prompt text, with flush-left lines (no source indentation).
    """
    lines = [
        f"QUESTION: {question}",
        "",
        "Generate a Q&A-based chain-of-thought that demonstrates how the student should "
        "reason through this question. Follow the interleaved Q&A format specified in your instructions.",
    ]

    facts = [fact.strip() for fact in (facts_list or []) if fact and fact.strip()]
    if facts:
        facts_text = " ".join(fact if fact[-1] in ".!?" else fact + "." for fact in facts)
        lines += [
            "",
            f"HINT - Relevant facts to guide your reasoning: {facts_text}",
            "Use these facts to determine what questions to ask.",
        ]

    return "\n".join(lines)

def call_gpt4_qa_cot(
    question: str,
    facts_list: Optional[List[str]] = None,
    *,
    dry_run: bool = False,
    raise_on_error: bool = False,
    client: Optional["OpenAI"] = None,
) -> str:
    """Ask GPT-4 for a Q&A-CoT teacher response to ``question``.

    Args:
        question: The yes/no question to reason about.
        facts_list: Optional supporting facts passed to the user prompt as a hint.
        dry_run: If True, return the placeholder response without calling the API.
        raise_on_error: If True, re-raise API errors instead of returning the placeholder.
        client: Optional pre-built OpenAI client to reuse across calls. When
            omitted, a new client is created from ``OPENAI_API_KEY``.

    Returns:
        The stripped model response. The placeholder is returned on a dry run
        or, by default, on an API error.

    Warning:
        The placeholder is a well-formed Q&A-CoT response ending in
        ``The answer is **No**.``, so downstream validation may accept it. When
        building training data, pass ``raise_on_error=True`` (or filter out
        placeholders) so API failures do not become fabricated "No" labels.
    """
    # Placeholder for dry runs and (by default) API failures. Built without
    # source indentation so it looks like a real response line by line.
    placeholder = "\n".join([
        "Question 1: What is the key factor for answering this question?",
        "Answer 1: [Factual answer based on the question context]",
        "Therefore, based on this reasoning, the answer follows logically.",
        "The answer is **No**.",
    ])

    if dry_run:
        return placeholder

    # Client construction stays outside the try block, as before, so errors
    # raised while building the client still propagate to the caller.
    api_client = client if client is not None else OpenAI(api_key=OPENAI_API_KEY)

    system_prompt = create_qa_cot_teacher_prompt()
    user_prompt = create_qa_cot_user_prompt(question, facts_list)

    try:
        response = api_client.chat.completions.create(
            model=GPT4_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=300,  # Increased for Q&A format
            temperature=0.1,  # Low for consistency
            top_p=0.95,
            frequency_penalty=0.1,  # Reduce repetitive phrasing
            presence_penalty=0.1
        )

        return response.choices[0].message.content.strip()

    except Exception as e:
        if raise_on_error:
            raise
        print(f"Error calling GPT-4: {e}")
        # Fallback template for testing
        return placeholder

def parse_facts_field(facts_string: str) -> List[str]:
    """Split a raw ``facts`` field into a list of individual fact sentences.

    The text is split on periods and semicolons. A period between two digits
    (a decimal such as ``330.5``) is not a split point, so numbers survive
    intact. Fragments shorter than three words are dropped as noise.
    Abbreviations such as "Dr. Smith" are still split; this is a known limitation.

    Args:
        facts_string: The raw facts text. A list/tuple of such strings is also
            accepted. Each string element is parsed and the results are
            concatenated in order; non-string elements are ignored.

    Returns:
        The cleaned facts in their original order. The list is empty for empty,
        ``None`` or other non-text input.
    """
    if isinstance(facts_string, (list, tuple)):
        raw_chunks = [chunk for chunk in facts_string if isinstance(chunk, str)]
    elif isinstance(facts_string, str):
        raw_chunks = [facts_string]
    else:
        return []

    facts = []
    for chunk in raw_chunks:
        # One pass handles both delimiters. The look-arounds skip the split
        # only when a period has a digit on both sides (a decimal point).
        pieces = re.split(r'(?<!\d)\.|\.(?!\d)|;', chunk)
        facts.extend(piece.strip() for piece in pieces if piece.strip())

    # Filter out very short fragments
    return [fact for fact in facts if len(fact.split()) >= 3]

# Legacy wrapper for backward compatibility
def extract_yes_no(text: str, default: str = "No") -> str:
    """Extract a Yes/No answer from a model response.

    Strategies are tried in order:
      1. the TeacherResponseValidator's extracted answer (only if the text validates);
      2. an explicit ``The answer is **Yes/No**`` conclusion (trailing ``.``/``!``
         inside the bold marker is tolerated);
      3. a ``## Final Assessment ... **YES/NO**`` section;
      4. the first standalone "yes"/"no" word anywhere in the text.

    Args:
        text: Raw model output.
        default: Value returned when no strategy finds an answer. It defaults to
            "No" for backward compatibility. Pass e.g. "" to detect failed
            extractions instead of silently counting them as "No".

    Returns:
        "Yes" or "No", or ``default`` when nothing matched.
    """
    validator = TeacherResponseValidator()
    result = validator.validate(text)

    if result.is_valid() and result.metadata.get('extracted_answer'):
        return result.metadata['extracted_answer'].capitalize()

    # Enhanced extraction for Q&A-CoT format. Also accept "**Yes.**" / "**No!**",
    # which the exact "**Yes**" form would otherwise miss.
    qa_answer_match = re.search(r'The answer is \*\*(Yes|No)[.!]?\*\*', text, re.IGNORECASE)
    if qa_answer_match:
        return qa_answer_match.group(1).capitalize()

    # Fallback to original patterns
    final_assessment_match = re.search(r'## Final Assessment.*?\*\*(YES|NO)\*\*', text, re.IGNORECASE | re.DOTALL)
    if final_assessment_match:
        return final_assessment_match.group(1).capitalize()

    match = re.search(r"\b(yes|no)\b", text, re.IGNORECASE)
    if match:
        return match.group(1).capitalize()

    return default

# Warn early when live GPT-4 calls are requested without an API key
if not OPENAI_API_KEY and not DRY_RUN:
    print("Warning: OPENAI_API_KEY not set. Set DRY_RUN=True or provide an API key.")

print("=== ENHANCED Q&A-COT TEACHER GENERATION ===")
print("Implementing interleaved Q&A format from StrategyQA implementation plan...")

# Load training data directly (no more student drafts dependency).
# The file is JSON Lines (one JSON object per line). Blank lines, such as an
# extra empty line at the end, are skipped instead of crashing json.loads.
with open(os.path.join(parent_dir, SAMPLE_TRAIN_PATH), 'r', encoding='utf-8') as f:
    training_data = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(training_data)} questions from training data")

# Count questions whose 'facts' field is present and non-empty
facts_available_count = sum(1 for item in training_data if item.get('facts', ''))

print(f"Facts available for {facts_available_count} questions")

# Enhanced teacher response generation with Q&A-CoT format
# Records that pass validation are accumulated here.
teacher_data = []

# Running counters for the generation loop. Each counter starts at zero, and
# 'quality_scores' collects one confidence score per processed response.
# Key order matters only for display; it mirrors the loop's bookkeeping order.
teacher_validation_stats = dict.fromkeys(
    (
        'total',
        'valid',
        'corrected',
        'invalid',
        'high_confidence',
        'medium_confidence',
        'low_confidence',
        'successful_extractions',
        'failed_extractions',
        'qa_format_detected',
        'traditional_format_detected',
    ),
    0,
)
teacher_validation_stats['quality_scores'] = []

# Initialize teacher validator
teacher_validator = TeacherResponseValidator()

print("\n=== GENERATING Q&A-COT TEACHER RESPONSES ===")
print("Using enhanced GPT-4 prompts for interleaved self-questioning format...")

# Appended to the question on the single retry issued when GPT-4's first reply
# ignores the Q&A format.
QA_COT_STRICT_FORMAT_INSTRUCTION = (
    "STRICT FORMAT REQUIRED - Generate EXACTLY this format:\n\n"
    "Question 1: [question]\n"
    "Answer 1: [answer]\n"
    "Therefore, [reasoning]\n"
    "The answer is **[YES/NO]**."
)

for i, item in enumerate(training_data):
    question = item['question']
    answer = item['answer']

    # Parse facts field instead of decomposition
    facts_string = item.get('facts', '')
    facts_list = parse_facts_field(facts_string) if facts_string else None

    if DRY_RUN:
        # Use mock Q&A-CoT response for testing (fixed text that always concludes "No")
        response_text = f"""Question 1: What is the key factor for determining if this is true?
                            Answer 1: [Mock factual answer based on the question context]
                            Therefore, based on this analysis, we can determine the answer logically.
                            The answer is **No**."""

    else:
        try:
            # Use enhanced Q&A-CoT generation with facts
            response_text = call_gpt4_qa_cot(question, facts_list)

            # Validate it follows Q&A format
            if not re.search(r'Question\s+\d+\s*:', response_text):
                print(f"Warning: Response for question {i+1} doesn't follow Q&A format, retrying...")
                # Retry once with the strict-format instruction actually included in
                # the request. Re-sending the identical low-temperature request would
                # most likely return the same text. call_gpt4_qa_cot only takes the
                # question text, so the instruction travels inside it.
                retry_question = f"{question}\n\n{QA_COT_STRICT_FORMAT_INSTRUCTION}"
                response_text = call_gpt4_qa_cot(retry_question, facts_list)

        except Exception as e:
            print(f"Error generating Q&A-CoT response for question {i+1}: {e}")
            continue

    # Validate Q&A format detection
    has_qa_format = bool(re.search(r'Question\s+\d+\s*:.*?Answer\s+\d+\s*:', response_text, re.DOTALL))
    if has_qa_format:
        teacher_validation_stats['qa_format_detected'] += 1
    else:
        teacher_validation_stats['traditional_format_detected'] += 1

    # Apply professional validation
    validation_result = teacher_validator.validate(response_text)
    teacher_validation_stats['total'] += 1
    teacher_validation_stats['quality_scores'].append(validation_result.confidence_score)

    # Update validation statistics
    if validation_result.status == ValidationStatus.VALID:
        teacher_validation_stats['valid'] += 1
    elif validation_result.status == ValidationStatus.CORRECTED:
        teacher_validation_stats['corrected'] += 1
    else:
        teacher_validation_stats['invalid'] += 1

    # Confidence tiers
    if validation_result.confidence_score >= HIGH_CONFIDENCE_THRESHOLD:
        teacher_validation_stats['high_confidence'] += 1
    elif validation_result.confidence_score >= MEDIUM_CONFIDENCE_THRESHOLD:
        teacher_validation_stats['medium_confidence'] += 1
    else:
        teacher_validation_stats['low_confidence'] += 1

    # Extract answer with enhanced Q&A patterns
    teacher_answer = extract_yes_no(response_text)
    # extract_yes_no returns "No" when it finds nothing, so its return value alone
    # can never signal a failed extraction. Count a success only when the
    # validator extracted an answer or the text contains an explicit yes/no.
    answer_found = bool(
        validation_result.metadata.get('extracted_answer')
        or re.search(r'\b(yes|no)\b', response_text, re.IGNORECASE)
    )
    if answer_found:
        teacher_validation_stats['successful_extractions'] += 1
    else:
        teacher_validation_stats['failed_extractions'] += 1

    # Only keep medium and high confidence responses
    if validation_result.confidence_score >= MEDIUM_CONFIDENCE_THRESHOLD:
        final_response_text = validation_result.cleaned_text if validation_result.cleaned_text else response_text

        out_record = {
            'question': question,
            'teacher_thought': final_response_text,
            'teacher_answer': teacher_answer,
            'format_type': 'qa_interleaved' if has_qa_format else 'traditional',
            'facts': facts_list if facts_list else [],
            'ground_truth_answer': answer,
            'validation_metadata': {
                'confidence_score': validation_result.confidence_score,
                'status': validation_result.status.value,
                'quality_metrics': validation_result.metadata.get('quality_metrics', {}),
                'errors': validation_result.error_messages,
                'qa_format_detected': has_qa_format
            }
        }
        teacher_data.append(out_record)

    if (i + 1) % 50 == 0:
        print(f"Processed {i + 1}/{len(training_data)} teacher responses...")

# Calculate comprehensive statistics
total_teacher_processed = teacher_validation_stats['total']

# Start every summary metric at 0.0. If nothing was processed, this prevents
# later cells from reporting values left in the kernel by an earlier run.
teacher_success_rate = 0.0
avg_teacher_confidence = 0.0
extraction_success_rate = 0.0
qa_format_rate = 0.0


def _pct_of_processed(count: int) -> float:
    """Return ``count`` as a percentage of the processed teacher responses."""
    return count / total_teacher_processed * 100


if total_teacher_processed > 0:
    teacher_success_rate = _pct_of_processed(
        teacher_validation_stats['valid'] + teacher_validation_stats['corrected']
    )
    avg_teacher_confidence = sum(teacher_validation_stats['quality_scores']) / len(teacher_validation_stats['quality_scores'])
    extraction_success_rate = _pct_of_processed(teacher_validation_stats['successful_extractions'])
    qa_format_rate = _pct_of_processed(teacher_validation_stats['qa_format_detected'])

    print(f"\n=== Q&A-COT TEACHER GENERATION RESULTS ===")
    print(f"📊 Total processed: {total_teacher_processed}")
    print(f"✅ Valid: {teacher_validation_stats['valid']} ({_pct_of_processed(teacher_validation_stats['valid']):.1f}%)")
    print(f"🔧 Corrected: {teacher_validation_stats['corrected']} ({_pct_of_processed(teacher_validation_stats['corrected']):.1f}%)")
    print(f"❌ Invalid: {teacher_validation_stats['invalid']} ({_pct_of_processed(teacher_validation_stats['invalid']):.1f}%)")
    print(f"📈 Success Rate: {teacher_success_rate:.1f}% (Target: 80-85%)")
    print(f"🎯 Average Confidence: {avg_teacher_confidence:.3f}")
    print(f"🔍 Answer Extraction: {extraction_success_rate:.1f}%")

    print(f"\n=== Q&A FORMAT ADOPTION ===")
    print(f"🧠 Q&A Interleaved Format: {teacher_validation_stats['qa_format_detected']} ({qa_format_rate:.1f}%)")
    print(f"📝 Traditional Format: {teacher_validation_stats['traditional_format_detected']} ({100-qa_format_rate:.1f}%)")

    # Tier labels are derived from the same thresholds the loop uses, so they
    # cannot drift out of sync with the actual bucketing.
    print(f"\n=== CONFIDENCE DISTRIBUTION ===")
    print(f"🔥 High (≥{HIGH_CONFIDENCE_THRESHOLD}): {teacher_validation_stats['high_confidence']} ({_pct_of_processed(teacher_validation_stats['high_confidence']):.1f}%)")
    print(f"🟡 Medium ({MEDIUM_CONFIDENCE_THRESHOLD}-{HIGH_CONFIDENCE_THRESHOLD}): {teacher_validation_stats['medium_confidence']} ({_pct_of_processed(teacher_validation_stats['medium_confidence']):.1f}%)")
    print(f"🔴 Low (<{MEDIUM_CONFIDENCE_THRESHOLD}): {teacher_validation_stats['low_confidence']} ({_pct_of_processed(teacher_validation_stats['low_confidence']):.1f}%)")

    print(f"\n=== DATA RETENTION ===")
    print(f"Kept after validation: {len(teacher_data)}/{total_teacher_processed} ({_pct_of_processed(len(teacher_data)):.1f}%)")

# Save enhanced Q&A-CoT teacher outputs.
# Derive the relative path once so the Phase B path below does not depend on
# string-stripping parent_dir back off the joined absolute path.
qa_cot_relative_path = TEACHER_OUTPUTS_PATH.replace('.jsonl', '_qa_cot.jsonl')
qa_cot_teacher_path = os.path.join(parent_dir, qa_cot_relative_path)
# Make sure the output directory exists, so the write does not fail on a fresh checkout.
os.makedirs(os.path.dirname(qa_cot_teacher_path) or '.', exist_ok=True)
with open(qa_cot_teacher_path, 'w', encoding='utf-8') as f:
    for record in teacher_data:
        f.write(json.dumps(record) + '\n')

print(f"\n✅ Enhanced Q&A-CoT teacher outputs saved to {qa_cot_teacher_path}")
print(f"🚀 Q&A-CoT teacher generation complete!")

# The rate metrics are only computed when at least one response was processed.
# Printing them unguarded raises NameError on a fresh kernel, or silently
# reports a previous run's numbers.
if total_teacher_processed > 0:
    if teacher_success_rate >= 80:
        print("🎯 SUCCESS: Achieved target teacher validation rate of 80%+")
    else:
        print(f"⚠️  Below target: {80 - teacher_success_rate:.1f}pp improvement needed")

    if qa_format_rate >= 90:
        print("🎯 SUCCESS: High adoption of Q&A interleaved format (90%+)")
    else:
        print(f"⚠️  Q&A format adoption at {qa_format_rate:.1f}%, may need prompt refinement")

    print(f"\n=== Q&A-COT IMPLEMENTATION SUMMARY ===")
    print(f"🧠 Format Innovation: Interleaved Q&A self-questioning implemented")
    print(f"📚 Training Examples: {len(teacher_data)} high-quality Q&A-CoT responses")
    print(f"🎯 Average Quality: {avg_teacher_confidence:.3f} confidence score")
    print(f"🔍 Format Consistency: {qa_format_rate:.1f}% Q&A format adoption")
    print(f"📈 Data Quality: {teacher_success_rate:.1f}% validation success rate")

    print(f"\n🏁 Ready for Phase B training with Q&A-CoT supervision!")
else:
    print("\n⚠️  No teacher responses were processed - nothing to summarise or train on.")

# Update teacher outputs path for Phase B pipeline (relative to parent_dir)
TEACHER_OUTPUTS_PATH_QA_COT = qa_cot_relative_path
print(f"\n🔧 Updated path for Phase B:")
print(f"- TEACHER_OUTPUTS_PATH_QA_COT = '{TEACHER_OUTPUTS_PATH_QA_COT}'")

=== ENHANCED Q&A-COT TEACHER GENERATION ===
Implementing interleaved Q&A format from StrategyQA implementation plan...
Loaded 200 questions from training data
Facts available for 200 questions

=== GENERATING Q&A-COT TEACHER RESPONSES ===
Using enhanced GPT-4 prompts for interleaved self-questioning format...


KeyboardInterrupt: 