In [1]:
!pip install -q datasets transformers openai bitsandbytes accelerate python-dotenv huggingface_hub huggingface_hub[hf_xet]


[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


## Setup ENV var config

In [None]:
import os
from dotenv import load_dotenv

# Load environment variables from .env file if it exists
load_dotenv()

# Dataset parameters
DATASET_NAME = os.getenv('DATASET_NAME', 'voidful/StrategyQA')
TRAIN_SAMPLES = int(os.getenv('TRAIN_SAMPLES', '100'))
RANDOM_SEED = int(os.getenv('RANDOM_SEED', '42'))
USE_FULL_DATASET = os.getenv('USE_FULL_DATASET', 'False').lower() in ('true', '1', 't')


# Model parameters
MODEL_NAME = os.getenv('MODEL_NAME', 'microsoft/phi-2')
MAX_NEW_TOKENS = int(os.getenv('MAX_NEW_TOKENS', '35'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8'))
USE_4BIT = os.getenv('USE_4BIT', 'True').lower() in ('true', '1', 't')
MAX_SEQ_LENGTH = int(os.getenv('MAX_SEQ_LENGTH', '512'))
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN', '')

# Generation parameters
DO_SAMPLE = os.getenv('DO_SAMPLE', 'False').lower() in ('true', '1', 't')
TEMPERATURE = float(os.getenv('TEMPERATURE', '0.7'))

# GPT-4 parameters
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
GPT4_MODEL = os.getenv('GPT4_MODEL', 'gpt-4')
GPT4_MAX_TOKENS = int(os.getenv('GPT4_MAX_TOKENS', '150'))
GPT4_TEMPERATURE = float(os.getenv('GPT4_TEMPERATURE', '0.3'))
DRY_RUN = os.getenv('DRY_RUN', 'True').lower() in ('true', '1', 't')

# Student Draft Generation
STUDENT_MAX_TOKENS = int(os.getenv('STUDENT_MAX_TOKENS', '200'))
STUDENT_TEMPERATURE = float(os.getenv('STUDENT_TEMPERATURE', '0.7'))
STUDENT_BATCH_SIZE = int(os.getenv('STUDENT_BATCH_SIZE', '8'))

# Enhanced Evaluation Generation
EVAL_MAX_TOKENS = int(os.getenv('EVAL_MAX_TOKENS', '256'))
EVAL_TEMPERATURE = float(os.getenv('EVAL_TEMPERATURE', '0.7'))
EVAL_BATCH_SIZE = int(os.getenv('EVAL_BATCH_SIZE', '4'))

# Quick Evaluation
QUICK_EVAL_MAX_TOKENS = int(os.getenv('QUICK_EVAL_MAX_TOKENS', '5'))

# Training Configuration
# Phase A Training
PHASE_A_EPOCHS = int(os.getenv('PHASE_A_EPOCHS', '3'))
PHASE_A_BATCH_SIZE = int(os.getenv('PHASE_A_BATCH_SIZE', '1'))
PHASE_A_LEARNING_RATE = float(os.getenv('PHASE_A_LEARNING_RATE', '1e-4'))
PHASE_A_WARMUP_RATIO = float(os.getenv('PHASE_A_WARMUP_RATIO', '0.1'))
PHASE_A_WEIGHT_DECAY = float(os.getenv('PHASE_A_WEIGHT_DECAY', '0.01'))

# Phase B Training
PHASE_B_EPOCHS = int(os.getenv('PHASE_B_EPOCHS', '3'))
PHASE_B_BATCH_SIZE = int(os.getenv('PHASE_B_BATCH_SIZE', '4'))
PHASE_B_LEARNING_RATE = float(os.getenv('PHASE_B_LEARNING_RATE', '5e-5'))
PHASE_B_WARMUP_RATIO = float(os.getenv('PHASE_B_WARMUP_RATIO', '0.1'))
PHASE_B_WEIGHT_DECAY = float(os.getenv('PHASE_B_WEIGHT_DECAY', '0.01'))

# Progressive Curriculum Training
CURRICULUM_STAGE1_EPOCHS = int(os.getenv('CURRICULUM_STAGE1_EPOCHS', '1'))
CURRICULUM_STAGE1_LEARNING_RATE = float(os.getenv('CURRICULUM_STAGE1_LEARNING_RATE', '5e-5'))
CURRICULUM_STAGE1_WARMUP_RATIO = float(os.getenv('CURRICULUM_STAGE1_WARMUP_RATIO', '0.1'))
CURRICULUM_STAGE1_WEIGHT_DECAY = float(os.getenv('CURRICULUM_STAGE1_WEIGHT_DECAY', '0.01'))

CURRICULUM_STAGE2_EPOCHS = int(os.getenv('CURRICULUM_STAGE2_EPOCHS', '2'))
CURRICULUM_STAGE2_LEARNING_RATE = float(os.getenv('CURRICULUM_STAGE2_LEARNING_RATE', '3e-5'))
CURRICULUM_STAGE2_WARMUP_RATIO = float(os.getenv('CURRICULUM_STAGE2_WARMUP_RATIO', '0.1'))
CURRICULUM_STAGE2_WEIGHT_DECAY = float(os.getenv('CURRICULUM_STAGE2_WEIGHT_DECAY', '0.01'))

# Validation Configuration
HIGH_CONFIDENCE_THRESHOLD = float(os.getenv('HIGH_CONFIDENCE_THRESHOLD', '0.8'))
MEDIUM_CONFIDENCE_THRESHOLD = float(os.getenv('MEDIUM_CONFIDENCE_THRESHOLD', '0.5'))
LOW_CONFIDENCE_THRESHOLD = float(os.getenv('LOW_CONFIDENCE_THRESHOLD', '0.3'))
VALIDATION_ACCEPTANCE_THRESHOLD = float(os.getenv('VALIDATION_ACCEPTANCE_THRESHOLD', '0.3'))

# Quality Thresholds
VALID_QUALITY_THRESHOLD = float(os.getenv('VALID_QUALITY_THRESHOLD', '0.5'))
CORRECTED_QUALITY_THRESHOLD = float(os.getenv('CORRECTED_QUALITY_THRESHOLD', '0.3'))

# Token Emphasis Configuration
EMPHASIS_MULTIPLIER = float(os.getenv('EMPHASIS_MULTIPLIER', '2.5'))
ADAPTIVE_EMPHASIS = os.getenv('ADAPTIVE_EMPHASIS', 'True').lower() in ('true', '1', 't')

# Memory and Performance
GRADIENT_CHECKPOINTING = os.getenv('GRADIENT_CHECKPOINTING', 'True').lower() in ('true', '1', 't')
FP16 = os.getenv('FP16', 'False').lower() in ('true', '1', 't')
BF16 = os.getenv('BF16', 'True').lower() in ('true', '1', 't')
DATALOADER_NUM_WORKERS = int(os.getenv('DATALOADER_NUM_WORKERS', '0'))
DATALOADER_PERSISTENT_WORKERS = os.getenv('DATALOADER_PERSISTENT_WORKERS', 'False').lower() in ('true', '1', 't')
SKIP_MEMORY_METRICS = os.getenv('SKIP_MEMORY_METRICS', 'True').lower() in ('true', '1', 't')

# Logging and Monitoring
LOGGING_STEPS = int(os.getenv('LOGGING_STEPS', '10'))
SAVE_STRATEGY = os.getenv('SAVE_STRATEGY', 'epoch')
REPORT_TO = os.getenv('REPORT_TO', 'none')
LOAD_BEST_MODEL_AT_END = os.getenv('LOAD_BEST_MODEL_AT_END', 'False').lower() in ('true', '1', 't')

# Evaluation Configuration
EVAL_SIZE = int(os.getenv('EVAL_SIZE', '100'))
ENHANCED_EVAL_BATCH_SIZE = int(os.getenv('ENHANCED_EVAL_BATCH_SIZE', '4'))
# File paths
parent_dir = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(parent_dir, os.getenv("DATA_DIR", "data"))
RAW_DIR = os.path.join(DATA_DIR, 'raw')
SAMPLE_TRAIN_PATH = os.path.join(DATA_DIR, '/train/sample_train.jsonl')
STUDENT_DRAFTS_PATH = os.path.join(DATA_DIR, '/student/student_drafts.jsonl')
CLEANED_STUDENT_DRAFTS_PATH = os.path.join(DATA_DIR, '/student/cleaned_student_drafts.jsonl')
TEACHER_OUTPUTS_PATH = os.path.join(DATA_DIR, '/teacher/teacher_outputs.jsonl')
BASELINE_PATH = os.path.join(DATA_DIR, '/train/train_baseline.jsonl')
COT_PATH = os.path.join(DATA_DIR, '/train/train_cot.jsonl')

# Print configuration
print("=== Configuration ===")
print(f"Dataset: {DATASET_NAME}")
print(f"Model: {MODEL_NAME}")
print(f"Batch size: {BATCH_SIZE}")
print(f"4-bit quantization: {USE_4BIT}")
print(f"GPT-4 dry run: {DRY_RUN}")
print("=="*10)

=== Configuration ===
Dataset: voidful/StrategyQA
Model: microsoft/Phi-3.5-mini-instruct
Batch size: 8
4-bit quantization: True
GPT-4 dry run: False


## Load Dataset

In [None]:
from datasets import load_dataset
import json
import os

def save_jsonl(data, filepath):
    """Save data to a JSONL file."""
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')

def load_jsonl(filepath):
    """Load data from a JSONL file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]

# Check if files exist
train_path = os.path.join(RAW_DIR, 'strategyqa_train.jsonl')
val_path = os.path.join(RAW_DIR, 'strategyqa_validation.jsonl')
test_path = os.path.join(RAW_DIR, 'strategyqa_test.jsonl')

print("Looking for files in:")
print(f"- Train: {train_path}")
print(f"- Val: {val_path}")
print(f"- Test: {test_path}")

# Create data directory
os.makedirs(RAW_DIR, exist_ok=True)

files_exist = all(os.path.exists(p) for p in [train_path, val_path, test_path])
print(f"Files exist: {files_exist}")

if files_exist:
    print("Loading dataset from local files...")
    train_data = load_jsonl(train_path)
    val_data = load_jsonl(val_path)
    test_data = load_jsonl(test_path)
else:
    print("Downloading and saving dataset...")
    # Load the dataset from HuggingFace
    dataset = load_dataset(DATASET_NAME)

    train_data = list(dataset['train'])
    val_data = list(dataset['validation'])
    test_data = list(dataset['test'])

    # Save to files for future use
    save_jsonl(train_data, train_path)
    save_jsonl(val_data, val_path)
    save_jsonl(test_data, test_path)

    print(f"Train: {len(train_data)} examples")
    print(f"Val: {len(val_data)} examples")
    print(f"Test: {len(test_data)} examples")

# Create sample training data
sample_train_path = os.path.join(DATA_DIR, '/train/sample_train.jsonl')

if not USE_FULL_DATASET:
    # Create a smaller sample for faster development
    import random
    random.seed(RANDOM_SEED)
    target_train_sampled = random.sample(train_data, min(TRAIN_SAMPLES, len(train_data)))
    print(f"Sample size: {len(target_train_sampled)} examples")
else:
    # Use all training data
    target_train_sampled = train_data
    print(f"Using full training set: {len(target_train_sampled)} examples")

# Also create combined train+validation for Q&A-CoT (more data)
full_train_val = train_data + val_data
full_train_val_path = os.path.join(DATA_DIR, '/train/full_train_val.jsonl')

# Save both sampled and full datasets
save_jsonl(target_train_sampled, sample_train_path)
save_jsonl(full_train_val, full_train_val_path)

print(f"Full training set saved to {train_path}")
print(f"Validation set saved to {val_path}")
print(f"Sampled train set (≈{TRAIN_SAMPLES} entries) saved to {sample_train_path}")
print(f"Combined train+val set ({len(full_train_val)} entries) saved to {full_train_val_path}")

# Update file paths based on choice
if USE_FULL_DATASET:
    # Update the global path variables to point to the full dataset
    SAMPLE_TRAIN_PATH = os.path.join(DATA_DIR, 'full_train_val.jsonl')
    print(f"📊 Updated SAMPLE_TRAIN_PATH to use full dataset: {SAMPLE_TRAIN_PATH}")
    # Update file names to avoid confusion
    STUDENT_DRAFTS_PATH = os.path.join(DATA_DIR, '/student/student_drafts_full.jsonl')
    CLEANED_STUDENT_DRAFTS_PATH = os.path.join(DATA_DIR, '/student/cleaned_student_drafts_full.jsonl')
    TEACHER_OUTPUTS_PATH = os.path.join(DATA_DIR, 'teacher_outputs_full.jsonl')
    BASELINE_PATH = os.path.join(DATA_DIR, 'train_baseline_full.jsonl')
    COT_PATH = os.path.join(DATA_DIR, 'train_cot_full.jsonl')
    print(f"📊 Updated output paths to use '_full' suffix for clarity")

  from .autonotebook import tqdm as notebook_tqdm


Looking for files in:
- Train: c:\Users\noham\Desktop\Self-Improving-LLM\data\raw\strategyqa_train.jsonl
- Val: c:\Users\noham\Desktop\Self-Improving-LLM\data\raw\strategyqa_validation.jsonl
- Test: c:\Users\noham\Desktop\Self-Improving-LLM\data\raw\strategyqa_test.jsonl
Files exist: True
Loading dataset from local files...
Sample size: 200 examples
Full training set saved to c:\Users\noham\Desktop\Self-Improving-LLM\data\raw\strategyqa_train.jsonl
Validation set saved to c:\Users\noham\Desktop\Self-Improving-LLM\data\raw\strategyqa_validation.jsonl
Sampled train set (≈200 entries) saved to c:\Users\noham\Desktop\Self-Improving-LLM\data\sample_train.jsonl
Combined train+val set (1603 entries) saved to c:\Users\noham\Desktop\Self-Improving-LLM\data\full_train_val.jsonl


## Generate Student Drafts

In this section we load a base language model (e.g. `meta-llama/Llama-2-7b-hf` or `gpt2`) and generate a short *student draft* for each question in the sampled training set.  A draft consists of a yes/no answer followed by one or two clarifying questions, as specified in the data‑generation loop.  Adjust the model name based on your available hardware and licences.

> **Tip:** On Colab, you can enable a GPU via *Runtime → Change runtime type → GPU* and use half‑precision weights to reduce memory usage.  For demonstration, we use `gpt2` (which is small) to keep the example runnable on CPU.


In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch
import json
import os
from tqdm import tqdm

def setup_dataset(input_path: str, tokenizer, batch_size: int = BATCH_SIZE):
    """Load and prepare dataset for GPU processing."""
    # Load the dataset
    dataset = load_dataset('json', data_files=input_path, split='train')

    # Keep the original questions for reference
    original_questions = dataset['question']

    # Tokenization function
    def tokenize_function(examples):
        return tokenizer(
            examples['question'],
            truncation=True,
            padding='max_length',
            max_length=MAX_SEQ_LENGTH,
            return_tensors=None  # Return as list, not tensors
        )

    # Apply tokenization
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Create a custom dataset that includes both tokenized data and original questions
    class QADataset(torch.utils.data.Dataset):
        def __init__(self, tokenized_data, original_questions):
            self.tokenized_data = tokenized_data
            self.original_questions = original_questions

        def __len__(self):
            return len(self.tokenized_data)

        def __getitem__(self, idx):
            item = {
                'input_ids': torch.tensor(self.tokenized_data[idx]['input_ids']),
                'attention_mask': torch.tensor(self.tokenized_data[idx]['attention_mask']),
                'question': self.original_questions[idx]
            }
            return item

    # Create custom dataset
    custom_dataset = QADataset(tokenized_dataset, original_questions)

    # Create DataLoader
    loader = DataLoader(
        custom_dataset,
        batch_size=batch_size,
        shuffle=False  # Keep order for output matching
    )

    return loader

# GPU setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model setup
print(f"Loading model: {MODEL_NAME}")

# Use 4-bit quantization if enabled and on GPU
if device.type == 'cuda' and USE_4BIT:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    print("Loading model in 4-bit quantization...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto"
    )
else:
    print("Loading model in standard precision...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Ensure the tokenizer has a padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()

if device.type == 'cuda':
    print(f"GPU Memory after model load: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Load and prepare dataset (fix path)
parent_dir = os.path.dirname(os.getcwd())
sample_train_path_full = os.path.join(parent_dir, SAMPLE_TRAIN_PATH)
print(f"Loading dataset from {sample_train_path_full} with batch size {BATCH_SIZE}")
train_loader = setup_dataset(sample_train_path_full, tokenizer, batch_size=BATCH_SIZE)

Using device: cuda
Loading model: microsoft/Phi-3.5-mini-instruct
Loading model in 4-bit quantization...


Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.80s/it]


GPU Memory after model load: 2.26 GB
Loading dataset from c:\Users\noham\Desktop\Self-Improving-LLM\data\sample_train.jsonl with batch size 8


Generating train split: 200 examples [00:00, 6861.96 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 1408.62 examples/s]


In [5]:
def build_messages_self_questioning(q: str):
    """
    SIMPLIFIED Self-Questioning Format - Focus on working generation first
    """
    return [
        {"role": "system",
         "content": "You are a reasoning assistant that answers yes/no questions with brief reasoning.\n\n"
                   "RESPONSE FORMAT:\n"
                   "Question 1: [Ask a relevant question about the topic]\n"
                   "Answer 1: [Answer with facts]\n"
                   "Therefore: [Brief conclusion]\n"
                   "Final answer: Yes (or No)\n\n"
                   "IMPORTANT: Always end with 'Final answer: Yes' or 'Final answer: No'. Never write 'Uncertain' or 'Maybe'."},

        {"role": "user", "content": "Did Aristotle use a laptop?"},
        {"role": "assistant", "content": "Question 1: Did laptops exist in Aristotle's time?\n"
                                        "Answer 1: No, laptops were invented much later.\n"
                                        "Therefore: Aristotle could not have used something that didn't exist.\n"
                                        "Final answer: No"},

        {"role": "user", "content": q},
    ]

def generate_batch_drafts_simplified(batch):
    """Generate drafts with SIMPLIFIED approach - no complex constraints"""
    import torch
    import gc

    # Set padding side to left for generation
    tokenizer.padding_side = 'left'

    # Create simple prompts
    prompts = [
        tokenizer.apply_chat_template(build_messages_self_questioning(q),
                                    tokenize=False,
                                    add_generation_prompt=True)
        for q in batch["question"]
    ]

    # Tokenize
    inputs = tokenizer(
        prompts,
        padding=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQ_LENGTH
    ).to(device)

    # SIMPLIFIED generation - no complex constraints
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=STUDENT_MAX_TOKENS,
            do_sample=True,
            temperature=STUDENT_TEMPERATURE,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True
        )

    # Decode outputs - handle both tensor and dict outputs
    prompt_lens = inputs["attention_mask"].sum(dim=1)
    if hasattr(outputs, 'sequences'):
        decoded = [tokenizer.decode(seq[p_len:], skip_special_tokens=True).strip()
                   for seq, p_len in zip(outputs.sequences, prompt_lens)]
    else:
        decoded = [tokenizer.decode(seq[p_len:], skip_special_tokens=True).strip()
                   for seq, p_len in zip(outputs, prompt_lens)]
    
    # Simple post-processing - just ensure we have some content
    drafts = []
    for response in decoded:
        if len(response.strip()) < 10:  # Too short
            # Simple fallback
            drafts.append("Question 1: What factors are relevant?\nAnswer 1: Need to consider the key aspects.\nTherefore: Based on available information.\nFinal answer: No")
        else:
            drafts.append(response)

    # Memory cleanup
    del inputs, outputs
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()

    return drafts

# Use simplified version
generate_batch_drafts = generate_batch_drafts_simplified

In [6]:
# Import tqdm for progress bars
from tqdm import tqdm
import json
import os
import torch
import gc

# Create output directory if it doesn't exist (fix path)
parent_dir = os.path.dirname(os.getcwd())
student_drafts_path_full = os.path.join(parent_dir, STUDENT_DRAFTS_PATH)
os.makedirs(os.path.dirname(student_drafts_path_full), exist_ok=True)

# Check if we already have partial results (for resuming)
checkpoint_file = student_drafts_path_full + '.checkpoint'
processed_count = 0
existing_results = []

if os.path.exists(checkpoint_file):
    print(f"Found checkpoint file, loading existing results...")
    with open(checkpoint_file, 'r', encoding='utf-8') as f:
        existing_results = [json.loads(line) for line in f]
    processed_count = len(existing_results)
    print(f"Resuming from {processed_count} already processed examples")

# Process batches and write outputs with error handling
print(f"Writing student drafts to {student_drafts_path_full}")
print(f"Processing {len(train_loader)} batches (starting from batch {processed_count // BATCH_SIZE})")

batch_start_idx = processed_count // BATCH_SIZE
all_results = existing_results.copy()

with open(checkpoint_file, 'a' if existing_results else 'w', encoding='utf-8') as checkpoint_f:
    try:
        for batch_idx, batch in enumerate(tqdm(train_loader, desc='Generating drafts', initial=batch_start_idx)):
            # Skip already processed batches
            if batch_idx < batch_start_idx:
                continue

            # Get original questions directly from the batch
            questions = batch['question']

            # Move input_ids and attention_mask to device
            input_batch = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }

            try:
                # Generate drafts for the batch with error handling
                drafts = generate_batch_drafts({'question': questions})

                # Write results to checkpoint file immediately
                batch_results = []
                for q, draft in zip(questions, drafts):
                    out_rec = {
                        'question': q,
                        'student_draft': draft
                    }
                    batch_results.append(out_rec)
                    all_results.append(out_rec)
                    checkpoint_f.write(json.dumps(out_rec, ensure_ascii=False) + '\n')

                # Flush to disk after each batch
                checkpoint_f.flush()

                # Print progress
                if device.type == 'cuda':
                    memory_gb = torch.cuda.memory_allocated()/1e9
                    print(f"Batch {batch_idx+1}/{len(train_loader)} completed. GPU Memory: {memory_gb:.2f} GB")

                # Periodic cleanup every 10 batches
                if (batch_idx + 1) % 10 == 0:
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    gc.collect()
                    print(f"✅ Memory cleanup completed after batch {batch_idx+1}")

            except RuntimeError as e:
                print(f"❌ Error in batch {batch_idx}: {e}")
                print(f"🔄 Attempting to continue with next batch...")

                # Clean up memory after error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                gc.collect()

                # Save partial results for failed batch
                for q in questions:
                    failed_rec = {
                        'question': q,
                        'student_draft': 'Answer: Error during generation\nQuestions: What went wrong? How to fix this?'
                    }
                    all_results.append(failed_rec)
                    checkpoint_f.write(json.dumps(failed_rec, ensure_ascii=False) + '\n')

                checkpoint_f.flush()
                continue

    except KeyboardInterrupt:
        print(f"\n🛑 Generation interrupted by user. Progress saved to checkpoint.")
        print(f"📊 Processed {len(all_results)} examples so far")

    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        print(f"📊 Progress saved. Processed {len(all_results)} examples")

# Copy final results to main output file
print(f"\n📝 Copying {len(all_results)} results to final output file...")
with open(student_drafts_path_full, 'w', encoding='utf-8') as out_f:
    for result in all_results:
        out_f.write(json.dumps(result, ensure_ascii=False) + '\n')

print(f"✅ Student drafts written to {student_drafts_path_full}")
print(f"📊 Total examples processed: {len(all_results)}")
print(f"🔄 Checkpoint file: {checkpoint_file}")
print("Using optimized student prompts with error recovery and checkpointing!")

Writing student drafts to c:\Users\noham\Desktop\Self-Improving-LLM\data\student_drafts.jsonl
Processing 25 batches (starting from batch 0)


Generating drafts:   4%|▍         | 1/25 [01:03<25:19, 63.30s/it]

Batch 1/25 completed. GPU Memory: 2.27 GB


Generating drafts:   8%|▊         | 2/25 [01:28<15:46, 41.14s/it]

Batch 2/25 completed. GPU Memory: 2.27 GB


Generating drafts:  12%|█▏        | 3/25 [02:31<18:44, 51.10s/it]

Batch 3/25 completed. GPU Memory: 2.27 GB


Generating drafts:  16%|█▌        | 4/25 [04:19<25:37, 73.22s/it]

Batch 4/25 completed. GPU Memory: 2.27 GB


Generating drafts:  20%|██        | 5/25 [06:06<28:32, 85.61s/it]

Batch 5/25 completed. GPU Memory: 2.27 GB


Generating drafts:  24%|██▍       | 6/25 [06:42<21:42, 68.55s/it]

Batch 6/25 completed. GPU Memory: 2.27 GB


Generating drafts:  28%|██▊       | 7/25 [07:08<16:27, 54.85s/it]

Batch 7/25 completed. GPU Memory: 2.27 GB


Generating drafts:  32%|███▏      | 8/25 [09:24<22:49, 80.56s/it]

Batch 8/25 completed. GPU Memory: 2.27 GB


Generating drafts:  36%|███▌      | 9/25 [12:05<28:14, 105.90s/it]

Batch 9/25 completed. GPU Memory: 2.27 GB


Generating drafts:  40%|████      | 10/25 [14:46<30:40, 122.67s/it]

Batch 10/25 completed. GPU Memory: 2.27 GB
✅ Memory cleanup completed after batch 10


Generating drafts:  44%|████▍     | 11/25 [17:26<31:17, 134.11s/it]

Batch 11/25 completed. GPU Memory: 2.27 GB


Generating drafts:  48%|████▊     | 12/25 [18:11<23:10, 107.00s/it]

Batch 12/25 completed. GPU Memory: 2.27 GB


Generating drafts:  52%|█████▏    | 13/25 [19:38<20:12, 101.01s/it]

Batch 13/25 completed. GPU Memory: 2.27 GB


Generating drafts:  56%|█████▌    | 14/25 [22:18<21:47, 118.82s/it]

Batch 14/25 completed. GPU Memory: 2.27 GB


Generating drafts:  60%|██████    | 15/25 [23:17<16:47, 100.75s/it]

Batch 15/25 completed. GPU Memory: 2.27 GB


Generating drafts:  64%|██████▍   | 16/25 [25:57<17:48, 118.67s/it]

Batch 16/25 completed. GPU Memory: 2.27 GB


Generating drafts:  68%|██████▊   | 17/25 [28:37<17:29, 131.17s/it]

Batch 17/25 completed. GPU Memory: 2.27 GB


Generating drafts:  72%|███████▏  | 18/25 [30:47<15:14, 130.65s/it]

Batch 18/25 completed. GPU Memory: 2.27 GB


Generating drafts:  76%|███████▌  | 19/25 [31:37<10:38, 106.43s/it]

Batch 19/25 completed. GPU Memory: 2.27 GB


Generating drafts:  80%|████████  | 20/25 [32:43<07:51, 94.25s/it] 

Batch 20/25 completed. GPU Memory: 2.27 GB
✅ Memory cleanup completed after batch 20


Generating drafts:  84%|████████▍ | 21/25 [35:25<07:38, 114.58s/it]

Batch 21/25 completed. GPU Memory: 2.27 GB


Generating drafts:  88%|████████▊ | 22/25 [38:08<06:27, 129.28s/it]

Batch 22/25 completed. GPU Memory: 2.27 GB


Generating drafts:  92%|█████████▏| 23/25 [40:10<04:14, 127.10s/it]

Batch 23/25 completed. GPU Memory: 2.27 GB


Generating drafts:  96%|█████████▌| 24/25 [40:36<01:36, 96.69s/it] 

Batch 24/25 completed. GPU Memory: 2.27 GB


Generating drafts: 100%|██████████| 25/25 [41:05<00:00, 98.61s/it]

Batch 25/25 completed. GPU Memory: 2.27 GB

📝 Copying 200 results to final output file...
✅ Student drafts written to c:\Users\noham\Desktop\Self-Improving-LLM\data\student_drafts.jsonl
📊 Total examples processed: 200
🔄 Checkpoint file: c:\Users\noham\Desktop\Self-Improving-LLM\data\student_drafts.jsonl.checkpoint
Using optimized student prompts with error recovery and checkpointing!





In [None]:
import json
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
import logging

class ValidationStatus(Enum):
    VALID = "valid"
    INVALID = "invalid"
    CORRECTED = "corrected"
    FAILED = "failed"

@dataclass
class ValidationResult:
    status: ValidationStatus
    original_text: str
    cleaned_text: Optional[str]
    confidence_score: float  # 0.0 - 1.0
    error_messages: List[str]
    metadata: Dict[str, Any]

    def is_valid(self) -> bool:
        return self.status in [ValidationStatus.VALID, ValidationStatus.CORRECTED]

@dataclass
class StudentResponse:
    answer: str  # Yes/No/Uncertain
    reasoning: str
    questions: List[str]
    confidence_score: float
    validation_errors: List[str]
    quality_metrics: Dict[str, float]

class BaseResponseValidator(ABC):
    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def validate(self, response_text: str) -> ValidationResult:
        pass

    def _calculate_base_confidence(self, parsing_success: bool, content_quality: float) -> float:
        """Calculate base confidence score from parsing and content quality."""
        parsing_score = 1.0 if parsing_success else 0.3
        return (parsing_score * 0.6) + (content_quality * 0.4)

class StudentDraftValidator(BaseResponseValidator):
    def __init__(self):
        super().__init__()
        self.answer_patterns = [
            r'Final answer:\s*(Yes|No)',  # New primary pattern
            r'(?:Answer|The answer is)\s*:?\s*(Yes|No|Uncertain)\b',
            r'\*\*Answer\*\*\s*:?\s*(Yes|No|Uncertain)\b',
            r'\b(Yes|No)\b(?=\s*[.!]?\s*$)',  # Final yes/no
        ]
        self.question_patterns = [
            r'Questions?\s*:?\s*(.+?)(?=\n|$)',
            r'Question\s+\d+\s*:?\s*(.+?)(?=\n|Answer|$)',
            r'[^.!?]*\?',  # Any sentence ending with ?
        ]

    def validate(self, response_text: str) -> ValidationResult:
        """Validate student draft with Q&A format support."""
        errors = []
        metadata = {}

        # Layer 1: Structural Parsing
        parsed_data = self._parse_structure(response_text)
        parsing_success = parsed_data is not None

        if not parsing_success:
            errors.append("Failed to parse basic structure")
            # Attempt correction
            corrected = self._attempt_correction(response_text)
            if corrected:
                parsed_data = self._parse_structure(corrected)
                if parsed_data:
                    parsing_success = True
                    errors.append("Structure corrected automatically")

        # Layer 2: Content Validation
        if parsed_data:
            quality_metrics = self._assess_content_quality(parsed_data, response_text)
            metadata['quality_metrics'] = quality_metrics

            # Check Q&A format support
            qa_format = self._detect_qa_format(response_text)
            metadata['qa_format_detected'] = qa_format

            if qa_format:
                qa_quality = self._assess_qa_quality(response_text)
                metadata['qa_quality'] = qa_quality
                quality_metrics['overall'] = (quality_metrics['overall'] + qa_quality) / 2
        else:
            quality_metrics = {'overall': 0.0, 'answer_quality': 0.0, 'reasoning_quality': 0.0}

        # Layer 4: Confidence Scoring
        confidence_score = self._calculate_confidence(
            parsing_success, quality_metrics, len(errors)
        )

        # Determine final status - require valid Yes/No answer
        answer = parsed_data.get('answer') if parsed_data else None
        has_valid_answer = answer and answer in ['Yes', 'No']
        
        if parsing_success and quality_metrics['overall'] >= 0.7 and has_valid_answer:
            status = ValidationStatus.VALID
            cleaned_text = self._format_output(parsed_data) if parsed_data else response_text
        elif parsing_success and quality_metrics['overall'] >= 0.5 and has_valid_answer:
            status = ValidationStatus.CORRECTED
            cleaned_text = self._format_output(parsed_data) if parsed_data else response_text
        else:
            status = ValidationStatus.INVALID
            cleaned_text = None

        return ValidationResult(
            status=status,
            original_text=response_text,
            cleaned_text=cleaned_text,
            confidence_score=confidence_score,
            error_messages=errors,
            metadata=metadata
        )

    def _extract_last_final_answer(self, text: str) -> Optional[str]:
        """Extract the LAST occurrence of final answer patterns (prefer last answer)."""
        # Look for all patterns and find the last one
        all_matches = []
        
        # Pattern 1: Final answer: Yes/No (preferred)
        for match in re.finditer(r'Final answer:\s*(Yes|No)', text, re.IGNORECASE):
            all_matches.append((match.end(), match.group(1).capitalize()))
        
        # Pattern 2: The answer is **Yes/No**
        for match in re.finditer(r'The answer is\s*\*\*\s*(Yes|No)\s*\*\*', text, re.IGNORECASE):
            all_matches.append((match.end(), match.group(1).capitalize()))
        
        # Pattern 3: **Answer**: Yes/No
        for match in re.finditer(r'\*\*Answer\*\*\s*:?\s*(Yes|No)', text, re.IGNORECASE):
            all_matches.append((match.end(), match.group(1).capitalize()))
        
        # Return the last match (highest position)
        if all_matches:
            # Sort by position and return the last answer
            all_matches.sort(key=lambda x: x[0])
            return all_matches[-1][1]
        
        return None

    def _parse_structure(self, text: str) -> Optional[Dict[str, Any]]:
        """Parse response into structured components."""
        # Try Q&A format first (new format)
        qa_match = self._parse_qa_format(text)
        if qa_match:
            return qa_match

        # Fallback to traditional Answer/Questions format
        answer = None
        questions = []

        # Extract answer using last answer preference
        answer = self._extract_last_final_answer(text)
        
        # If no final answer found, try other patterns
        if not answer:
            for pattern in self.answer_patterns[1:]:  # Skip first pattern (already tried)
                matches = list(re.finditer(pattern, text, re.IGNORECASE | re.MULTILINE))
                if matches:
                    # Get the last match
                    answer = matches[-1].group(1).strip().capitalize()
                    break

        # Extract questions
        for pattern in self.question_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE | re.MULTILINE)
            if matches:
                questions.extend([q.strip() for q in matches if q.strip()])
                break

        if answer or questions:
            return {
                'answer': answer,
                'questions': questions,
                'reasoning': text,  # Full text as reasoning
                'format_type': 'traditional'
            }

        return None

    def _parse_qa_format(self, text: str) -> Optional[Dict[str, Any]]:
        """Parse new Q&A interleaved format."""
        # Look for Question N: ... Answer N: ... pattern
        qa_pattern = r'Question\s+(\d+)\s*:?\s*(.+?)(?=Answer\s+\1|$).*?Answer\s+\1\s*:?\s*(.+?)(?=Question\s+\d+|Therefore|\*\*Answer\*\*|Final answer|$)'
        matches = re.findall(qa_pattern, text, re.DOTALL | re.IGNORECASE)

        if matches:
            questions = []
            reasoning_parts = []

            for num, question, answer in matches:
                questions.append(question.strip())
                reasoning_parts.append(f"Q{num}: {question.strip()}\nA{num}: {answer.strip()}")

            # Extract final answer using last answer preference
            final_answer = self._extract_last_final_answer(text)
            
            # Fallback to old patterns if needed
            if not final_answer:
                final_answer_match = re.search(r'(?:Therefore|The answer is|\*\*Answer\*\*)\s*:?\s*(Yes|No|Uncertain)', text, re.IGNORECASE)
                final_answer = final_answer_match.group(1).capitalize() if final_answer_match else None

            return {
                'answer': final_answer,
                'questions': questions,
                'reasoning': '\n\n'.join(reasoning_parts),
                'format_type': 'qa_interleaved',
                'qa_pairs': len(matches)
            }

        return None

    def _detect_qa_format(self, text: str) -> bool:
        """Detect if text uses Q&A interleaved format."""
        qa_pattern = r'(?:Question\s+\d+\s*:.*?Answer\s+\d+\s*:|Answer\s*:.*?Questions?\s*:)'
        return bool(re.search(qa_pattern, text, re.DOTALL | re.IGNORECASE))

    def _assess_content_quality(self, parsed_data: Dict[str, Any], full_text: str) -> Dict[str, float]:
        """Assess quality of parsed content."""
        # Answer quality
        answer = parsed_data.get('answer', '')
        answer_quality = 1.0 if answer in ['Yes', 'No'] else 0.0  # Ban Uncertain answers

        # Questions quality
        questions = parsed_data.get('questions', [])
        if questions:
            # Quality based on question count, length, and question marks
            question_count_score = min(len(questions) / 3, 1.0)  # Optimal: 2-3 questions
            has_question_marks = sum(1 for q in questions if '?' in q) / len(questions)
            avg_length = sum(len(q.split()) for q in questions) / len(questions)
            length_score = min(avg_length / 8, 1.0)  # Optimal: ~8 words

            questions_quality = (question_count_score + has_question_marks + length_score) / 3
        else:
            questions_quality = 0.0

        # Reasoning quality (basic heuristics)
        reasoning = parsed_data.get('reasoning', '')
        reasoning_length = len(reasoning.split())
        reasoning_quality = min(reasoning_length / 50, 1.0)  # Rough heuristic

        # Overall quality
        overall_quality = (answer_quality * 0.4 + questions_quality * 0.4 + reasoning_quality * 0.2)

        return {
            'overall': overall_quality,
            'answer_quality': answer_quality,
            'questions_quality': questions_quality,
            'reasoning_quality': reasoning_quality
        }

    def _assess_qa_quality(self, text: str) -> float:
        """Assess quality specific to Q&A format."""
        qa_pairs = re.findall(r'Question\s+(\d+)\s*:.*?Answer\s+\1\s*:', text, re.DOTALL | re.IGNORECASE)
        if not qa_pairs:
            return 0.0

        # Quality based on number of Q&A pairs and coherence
        pair_count_score = min(len(qa_pairs) / 3, 1.0)  # Optimal: 2-3 pairs

        # Check for logical progression (sequential numbering)
        sequential_score = 1.0
        for i, pair_num in enumerate(qa_pairs, 1):
            if int(pair_num) != i:
                sequential_score -= 0.2

        return max(0.0, (pair_count_score + max(0.0, sequential_score)) / 2)

    def _attempt_correction(self, text: str) -> Optional[str]:
        """Attempt to correct common formatting issues."""
        # Common corrections
        corrected = text

        # Fix missing colons
        corrected = re.sub(r'\b(Answer|Questions?)\b(?!:)', r'\1:', corrected)

        # Fix answer format
        corrected = re.sub(r'\b(yes|no)\b(?!\w)', lambda m: m.group(1).capitalize(), corrected, flags=re.IGNORECASE)

        # Add missing questions if only answer present
        if 'Answer:' in corrected and 'Question' not in corrected:
            corrected += '\nQuestions: What additional context would help clarify this?'

        return corrected if corrected != text else None

    def _calculate_confidence(self, parsing_success: bool, quality_metrics: Dict[str, float], error_count: int) -> float:
        """Calculate overall confidence score."""
        base_confidence = self._calculate_base_confidence(parsing_success, quality_metrics['overall'])

        # Penalize for errors
        error_penalty = min(error_count * 0.1, 0.3)

        # Bonus for high-quality answers
        quality_bonus = 0.1 if quality_metrics['answer_quality'] >= 0.9 else 0.0

        return max(0.0, min(1.0, base_confidence - error_penalty + quality_bonus))

    def _format_output(self, parsed_data: Dict[str, Any]) -> str:
        """Format parsed data into clean output."""
        answer = parsed_data.get('answer', 'No')  # Removed 'Uncertain' fallback - force decisive answer
        questions = parsed_data.get('questions', [])

        if parsed_data.get('format_type') == 'qa_interleaved':
            # Preserve Q&A format with final answer
            reasoning = parsed_data.get('reasoning', '')
            return f"{reasoning}\n\nFinal answer: {answer}"
        else:
            # Traditional format
            questions_text = '; '.join(questions) if questions else 'What additional information is needed?'
            return f"Answer: {answer}\nQuestions: {questions_text}"

# Legacy function wrappers for backward compatibility
def clean_student_draft(draft_text):
    """Legacy wrapper - now uses professional validation."""
    validator = StudentDraftValidator()
    result = validator.validate(draft_text)

    if result.is_valid():
        return result.cleaned_text
    else:
        # Fallback to original extraction for compatibility
        return extract_answer_and_questions(draft_text)

def extract_answer_and_questions(text):
    """Enhanced extraction with better pattern matching."""
    validator = StudentDraftValidator()
    parsed = validator._parse_structure(text)

    if parsed:
        return validator._format_output(parsed)

    # Ultimate fallback - removed 'Uncertain', force 'No' 
    return "Answer: No\nQuestions: What additional information is needed?"

# Load and clean student drafts with enhanced validation
student_drafts_path_full = os.path.join(parent_dir, STUDENT_DRAFTS_PATH)
print(f"Loading student drafts from {student_drafts_path_full}")

with open(student_drafts_path_full, 'r', encoding='utf-8') as f:
    student_drafts = [json.loads(line) for line in f]

print(f"Total student drafts: {len(student_drafts)}")

# Initialize professional validation
validator = StudentDraftValidator()
cleaned_drafts = []
validation_statistics = {
    'total': len(student_drafts),
    'valid': 0,
    'corrected': 0,
    'invalid': 0,
    'high_confidence': 0,  # >= 0.8
    'medium_confidence': 0,  # 0.5 - 0.8
    'low_confidence': 0,  # < 0.5
    'qa_format_detected': 0,
    'traditional_format': 0,
    'quality_scores': []
}

print("\n=== PROFESSIONAL STUDENT DRAFT VALIDATION ===")
print("Using multi-layered validation with Q&A format support...")

for i, draft in enumerate(student_drafts):
    original_text = draft['student_draft']

    # Apply professional validation
    result = validator.validate(original_text)

    # Update statistics
    validation_statistics['quality_scores'].append(result.confidence_score)

    if result.status == ValidationStatus.VALID:
        validation_statistics['valid'] += 1
    elif result.status == ValidationStatus.CORRECTED:
        validation_statistics['corrected'] += 1
    else:
        validation_statistics['invalid'] += 1

    # Confidence tiers
    if result.confidence_score >= HIGH_CONFIDENCE_THRESHOLD:
        validation_statistics['high_confidence'] += 1
    elif result.confidence_score >= MEDIUM_CONFIDENCE_THRESHOLD:
        validation_statistics['medium_confidence'] += 1
    else:
        validation_statistics['low_confidence'] += 1

    # Format detection
    if result.metadata.get('qa_format_detected', False):
        validation_statistics['qa_format_detected'] += 1
    else:
        validation_statistics['traditional_format'] += 1

    # Process based on confidence tier
    if result.confidence_score >= VALIDATION_ACCEPTANCE_THRESHOLD:  # Accept high and medium confidence
        cleaned_draft = {
            'question': draft['question'],
            'student_draft': result.cleaned_text,
            'validation_metadata': {
                'confidence_score': result.confidence_score,
                'status': result.status.value,
                'quality_metrics': result.metadata.get('quality_metrics', {}),
                'errors': result.error_messages
            }
        }
        cleaned_drafts.append(cleaned_draft)
    # Note: Low confidence responses are rejected

    if (i + 1) % 100 == 0:
        print(f"Processed {i + 1}/{len(student_drafts)} drafts...")

# Calculate success rates
total_processed = validation_statistics['total']
success_rate = (validation_statistics['valid'] + validation_statistics['corrected']) / total_processed * 100
average_confidence = sum(validation_statistics['quality_scores']) / len(validation_statistics['quality_scores'])

print(f"\n=== VALIDATION RESULTS ===")
print(f"Total processed: {total_processed}")
print(f"✅ Valid: {validation_statistics['valid']} ({validation_statistics['valid']/total_processed*100:.1f}%)")
print(f"🔧 Corrected: {validation_statistics['corrected']} ({validation_statistics['corrected']/total_processed*100:.1f}%)")
print(f"❌ Invalid: {validation_statistics['invalid']} ({validation_statistics['invalid']/total_processed*100:.1f}%)")
print(f"📊 Overall Success Rate: {success_rate:.1f}% (Target: 95%+)")
print(f"🎯 Average Confidence: {average_confidence:.3f}")

print(f"\n=== CONFIDENCE DISTRIBUTION ===")
print(f"High (≥0.8): {validation_statistics['high_confidence']} ({validation_statistics['high_confidence']/total_processed*100:.1f}%)")
print(f"Medium (0.5-0.8): {validation_statistics['medium_confidence']} ({validation_statistics['medium_confidence']/total_processed*100:.1f}%)")
print(f"Low (<0.5): {validation_statistics['low_confidence']} ({validation_statistics['low_confidence']/total_processed*100:.1f}%)")

print(f"\n=== FORMAT DETECTION ===")
print(f"Q&A Interleaved: {validation_statistics['qa_format_detected']} ({validation_statistics['qa_format_detected']/total_processed*100:.1f}%)")
print(f"Traditional: {validation_statistics['traditional_format']} ({validation_statistics['traditional_format']/total_processed*100:.1f}%)")

print(f"\n=== DATA RETENTION ===")
print(f"Kept after validation: {len(cleaned_drafts)}/{total_processed} ({len(cleaned_drafts)/total_processed*100:.1f}%)")

# Save cleaned data with metadata
cleaned_path_full = os.path.join(parent_dir, CLEANED_STUDENT_DRAFTS_PATH)
with open(cleaned_path_full, 'w', encoding='utf-8') as f:
    for draft in cleaned_drafts:
        f.write(json.dumps(draft) + '\n')

print(f"\n✅ Enhanced cleaned student drafts saved to {cleaned_path_full}")
print(f"🚀 Professional validation complete! Success rate: {success_rate:.1f}%")

if success_rate >= 95:
    print("🎯 SUCCESS: Achieved target validation rate of 95%+")
else:
    print(f"⚠️  Below target: {95 - success_rate:.1f}pp improvement needed")