# Phi-3 Iterative Try

This notebook is used to try the iterative pipeline without the integration of MongoDB to store the versioning of the dataset

## Imports

In [1]:
import io
import json
import logging
import os
import random
import shutil
import sys
import warnings
from contextlib import redirect_stdout, redirect_stderr

import numpy as np
import torch
from accelerate import Accelerator
from bert_score import score
from datasets import Dataset, load_from_disk
from dotenv import load_dotenv
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments
)

### Dataset and preprocessing

In [2]:
def load_squad_subset(dataset_dir="squad_v2_05percent"):
    """
    Load the saved SQuAD v2 0.5% dataset from disk.
    
    Args:
        dataset_dir (str): Directory where the dataset was saved
        
    Returns:
        DatasetDict: The loaded dataset with train and test splits
    """
    
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory '{dataset_dir}' not found. Please run the extraction script first.")
    
    print(f"Loading dataset from {dataset_dir}...")
    
    # Load the dataset using Hugging Face datasets
    dataset = load_from_disk(dataset_dir)
    
    # Load and display metadata
    metadata_path = os.path.join(dataset_dir, "metadata.json")
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        
        print("Dataset metadata:")
        for key, value in metadata.items():
            print(f"  {key}: {value}")
    
    # Display dataset info
    print(f"\nLoaded dataset splits:")
    for split_name, split_data in dataset.items():
        print(f"  {split_name}: {len(split_data)} examples")
    
    return dataset

In [3]:
dataset = load_squad_subset()
    
# Show examples
print(f"\nExample from each split:")
print(f"Train: {dataset['train'][0]['question']}")
print(f"Test: {dataset['test'][0]['question']}")

# Access specific fields
print(f"\nDataset features: {dataset['train'].features}")

Loading dataset from squad_v2_05percent...
Dataset metadata:
  original_train_size: 130319
  original_validation_size: 11873
  extracted_train_size: 651
  extracted_test_size: 59
  extraction_percentage: 0.5

Loaded dataset splits:
  train: 651 examples
  test: 59 examples

Example from each split:
Train: When did Beyonce start becoming popular?
Test: In what country is Normandy located?

Dataset features: {'id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}


In [None]:
#Define a function to format the dataset examples into a prompt
#The prompt will include the context, question, and answer
def make_prompt(example):
    context = example["context"]
    question = example["question"]
    answer = example["answers"]["text"][0] if example["answers"]["text"] else "No answer"

    prompt = f"[INST] Given the context, answer the question.\n\nContext: {context}\n\nQuestion: {question} [/INST] {answer}"
    return {"prompt": prompt, "reference": answer}

formatted_dataset = {
    split: dataset[split].map(make_prompt)
    for split in dataset.keys()
}

In [5]:
# Print the first formatted prompt and its reference answer
print("Prompt:\n", formatted_dataset["train"][0]["prompt"])
print("\nReference Answer:\n", formatted_dataset["train"][0]["reference"])

Prompt:
 [INST] Given the context, answer the question.

Context: Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".

Question: When did Beyonce start becoming popular? [/INST] in the late 1990s

Reference Answer:
 in the late 1990s


In [6]:
load_dotenv("key.env")
token = os.getenv("HUGGINGFACE_TOKEN")

from huggingface_hub import login
login(token=token)

In [7]:
torch.cuda.empty_cache()

### Training function

In [8]:
# Enhanced logging suppression - including BERTScore sharding messages
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger("transformers.modeling_utils").setLevel(logging.CRITICAL)
logging.getLogger("transformers.configuration_utils").setLevel(logging.CRITICAL)
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL)
logging.getLogger("accelerate").setLevel(logging.CRITICAL)
logging.getLogger("accelerate.utils.modeling").setLevel(logging.CRITICAL)
logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings("ignore")

In [9]:
# Initialize BERTScore silently
print("Initializing BERTScore silently...")
with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
    from bert_score import score
    _ = score(["test"], ["test"], lang="en", verbose=False)
print("BERTScore initialized successfully!")

# Modified BERTScore function with complete output suppression
def silent_bert_score(cands, refs, lang="en"):
    """BERTScore calculation with all output suppressed"""
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    
    sys.stdout = io.StringIO()
    sys.stderr = io.StringIO()
    
    try:
        P, R, F1 = score(cands, refs, lang=lang, verbose=False)
        return P, R, F1
    finally:
        sys.stdout = old_stdout
        sys.stderr = old_stderr

# Custom Early Stopping based on Training Loss
class TrainingLossEarlyStoppingCallback(TrainerCallback):
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.wait_count = 0
        
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if logs is not None and 'train_loss' in logs:
            current_loss = logs['train_loss']
            
            if current_loss < self.best_loss - self.min_delta:
                self.best_loss = current_loss
                self.wait_count = 0
                print(f"📈 Training loss improved to {current_loss:.4f}")
            else:
                self.wait_count += 1
                print(f"📊 No improvement in training loss ({self.wait_count}/{self.patience})")
                
                if self.wait_count >= self.patience:
                    print(f"🛑 Early stopping triggered! Best loss: {self.best_loss:.4f}")
                    control.should_training_stop = True

# Fixed Custom Trainer class with BERTScore loss
class BERTScoreTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_names = ["labels"]
        
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Custom loss function using BERTScore - completely silent
        """
        labels = inputs.get("labels")
        
        # Temporarily disable cache for forward pass
        model.config.use_cache = False
        
        # Forward pass
        outputs = model(**inputs)
        
        # Generate predictions for BERTScore
        with torch.no_grad():
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            
            # Re-enable cache for generation
            model.config.use_cache = True
            
            # Generate text
            try:
                generated = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=50,
                    do_sample=False,
                    pad_token_id=self.processing_class.eos_token_id,
                    use_cache=True
                )
                
                # Decode predictions and references
                pred_texts = self.processing_class.batch_decode(generated, skip_special_tokens=True)
                ref_texts = self.processing_class.batch_decode(labels, skip_special_tokens=True)
                
                # Calculate BERTScore with completely silent function
                P, R, F1 = silent_bert_score(pred_texts, ref_texts, lang="en")
                bert_f1 = F1.mean().item()
                
                # Convert BERTScore to loss
                bert_loss = torch.tensor(1.0 - bert_f1, requires_grad=True, device=input_ids.device)
            except Exception as e:
                # Fallback to standard loss if BERTScore fails
                bert_loss = outputs.loss
            finally:
                # Disable cache again for gradient checkpointing compatibility
                model.config.use_cache = False
        
        # Combine with standard language modeling loss
        standard_loss = outputs.loss
        combined_loss = 0.7 * standard_loss + 0.3 * bert_loss
        
        return (combined_loss, outputs) if return_outputs else combined_loss

# Data preparation function - removed tokenizer parameter since it's not used
def prepare_training_data(tokenized_dataset):
    """Prepare data for training"""
    
    def add_labels(example):
        example["labels"] = example["input_ids"].copy()
        return example

    # Only prepare train split
    train_dataset = tokenized_dataset["train"].map(add_labels)
    
    # Keep only necessary columns
    keep_keys = ["input_ids", "attention_mask", "labels"]
    train_dataset = train_dataset.remove_columns(
        [col for col in train_dataset.column_names if col not in keep_keys]
    )
    
    return {"train": train_dataset}

def cleanup_checkpoints(output_dir):
    """Remove checkpoint directories and files"""
    if os.path.exists(output_dir):
        # Find all checkpoint directories
        checkpoint_dirs = [d for d in os.listdir(output_dir) if d.startswith('checkpoint-')]
        
        for checkpoint_dir in checkpoint_dirs:
            checkpoint_path = os.path.join(output_dir, checkpoint_dir)
            if os.path.isdir(checkpoint_path):
                print(f"🗑️ Removing checkpoint: {checkpoint_path}")
                shutil.rmtree(checkpoint_path)
        
        # Remove any other checkpoint-related files
        checkpoint_files = [f for f in os.listdir(output_dir) if 'checkpoint' in f.lower()]
        for checkpoint_file in checkpoint_files:
            file_path = os.path.join(output_dir, checkpoint_file)
            if os.path.isfile(file_path):
                print(f"🗑️ Removing checkpoint file: {file_path}")
                os.remove(file_path)
        
        print("✅ Checkpoint cleanup completed!")

def configure_model_for_training(model):
    """Configure model for training with proper cache settings"""
    
    # Disable use_cache for training compatibility with gradient checkpointing
    if hasattr(model.config, 'use_cache'):
        model.config.use_cache = False
        print("✅ Set use_cache=False for gradient checkpointing compatibility")
    
    # Enable gradient checkpointing if available
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        print("✅ Enabled gradient checkpointing for memory efficiency")
    
    return model

# Main training function
def train_model(model, tokenized_data, tokenizer, train_args):
    """Training function with BERTScore and early stopping"""
    
    # Configure model for training
    model = configure_model_for_training(model)
    
    # Prepare data
    prepared_data = prepare_training_data(tokenized_data)
    
    # Setup data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        return_tensors="pt"
    )

    # Custom early stopping based on training loss
    early_stopping_callback = TrainingLossEarlyStoppingCallback(
        patience=5,
        min_delta=0.01
    )
    
    # Initialize BERTScore Trainer
    trainer = BERTScoreTrainer(
        model=model,
        args=train_args,
        train_dataset=prepared_data["train"],
        data_collator=data_collator,
        processing_class=tokenizer,
        callbacks=[early_stopping_callback],
    )
    
    # Start training
    print("Starting training with BERTScore optimization...")
    print("Early stopping based on training loss improvement")
    print("Cache disabled for gradient checkpointing compatibility")
    
    trainer.train()
    
    # Re-enable cache for inference after training
    if hasattr(model.config, 'use_cache'):
        model.config.use_cache = True
        print("✅ Re-enabled use_cache for inference")
    
    # Save model
    final_model_path = "./phi3-squad2-final"
    trainer.save_model(final_model_path)
    print(f"✅ Model saved to {final_model_path}")
    
    # Clean up checkpoints after saving the final model
    print("\n🧹 Cleaning up checkpoints...")
    if hasattr(train_args, 'output_dir') and train_args.output_dir:
        cleanup_checkpoints(train_args.output_dir)
    
    # Also clean up from the final model directory if it has checkpoints
    cleanup_checkpoints(final_model_path)
    
    # Clean up any checkpoint directories in the current working directory
    current_dir_checkpoints = [d for d in os.listdir('.') if d.startswith('checkpoint-')]
    for checkpoint_dir in current_dir_checkpoints:
        if os.path.isdir(checkpoint_dir):
            print(f"🗑️ Removing checkpoint: {checkpoint_dir}")
            shutil.rmtree(checkpoint_dir)
    
    print("🎉 Training completed and checkpoints cleaned up!")
    
    return trainer

Initializing BERTScore silently...
BERTScore initialized successfully!


### Evaluation function

In [10]:
def evaluate_model(model, tokenizer, dataset, device="cuda" if torch.cuda.is_available() else "cpu", num_examples=None):
    """
    Comprehensive evaluation on test set with detailed prediction examples
    """
    print("="*60)
    print("STARTING TEST SET EVALUATION WITH EXAMPLES")
    print("="*60)
    
    # Prepare test data using make_prompt function
    print("Preparing test prompts...")
    test_prompts = dataset["test"].map(make_prompt)
    
    # Set model to evaluation mode
    model.eval()
    
    # Initialize lists for predictions and references
    preds = []
    refs = []
    raw_outputs = []
    prompts_list = []
    questions = []
    contexts = []
    
    # Limit examples if specified
    if num_examples:
        test_prompts = test_prompts.select(range(min(num_examples, len(test_prompts))))
    
    print(f"Generating predictions for {len(test_prompts)} test examples...")
    
    # Generate predictions
    for example in tqdm(test_prompts, desc="Evaluating"):
        # Get prompt without answer (remove answer part from make_prompt output)
        full_prompt = example["prompt"]
        if '[/INST]' in full_prompt:
            prompt_without_answer = full_prompt.split('[/INST]')[0] + '[/INST]'
        else:
            prompt_without_answer = full_prompt
        
        # Tokenize input
        inputs = tokenizer(
            prompt_without_answer,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512
        ).to(device)
        
        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs, 
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode output
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        
        # Extract answer (everything after [/INST])
        if '[/INST]' in decoded:
            answer = decoded.split('[/INST]')[-1].strip()
        else:
            answer = decoded.strip()
        
        # Store results
        preds.append(answer)
        refs.append(example["reference"])
        raw_outputs.append(decoded)
        prompts_list.append(example["prompt"])
        
        # Extract question and context for detailed analysis
        if "Question:" in example["prompt"]:
            question_part = example["prompt"].split("Question:")[-1].split("[/INST]")[0].strip()
            questions.append(question_part)
        if "Context:" in example["prompt"]:
            context_part = example["prompt"].split("Context:")[-1].split("Question:")[0].strip()
            contexts.append(context_part[:200] + "..." if len(context_part) > 200 else context_part)
    
    print("Predictions generated! Computing metrics...")
    
    # Compute BERTScore using silent function from your notebook
    print("Computing BERTScore...")
    try:
        P, R, F1 = silent_bert_score(preds, refs, lang="en")
        bert_scores = {
            "precision": P.mean().item(),
            "recall": R.mean().item(),
            "f1": F1.mean().item()
        }
    except Exception as e:
        print(f"BERTScore computation failed: {e}")
        bert_scores = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        P = R = F1 = [0.0] * len(preds)
    
    # Compute exact match accuracy
    exact_matches = []
    for pred, ref in zip(preds, refs):
        if ref != "No answer" and ref.lower().strip() in pred.lower().strip():
            exact_matches.append(1)
        else:
            exact_matches.append(0)
    
    exact_match_score = np.mean(exact_matches)
    
    # Compute F1 score (token overlap)
    f1_scores = []
    for pred, ref in zip(preds, refs):
        pred_tokens = set(pred.lower().split())
        ref_tokens = set(ref.lower().split())
        
        if len(pred_tokens) == 0 and len(ref_tokens) == 0:
            f1_scores.append(1.0)
        elif len(pred_tokens) == 0 or len(ref_tokens) == 0:
            f1_scores.append(0.0)
        else:
            common = len(pred_tokens & ref_tokens)
            precision = common / len(pred_tokens)
            recall = common / len(ref_tokens)
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            f1_scores.append(f1)
    
    f1_score = np.mean(f1_scores)
    
    # Compute semantic similarity (simple word overlap)
    semantic_similarities = []
    for pred, ref in zip(preds, refs):
        pred_words = set(pred.lower().split())
        ref_words = set(ref.lower().split())
        if len(pred_words) == 0 and len(ref_words) == 0:
            semantic_similarities.append(1.0)
        elif len(pred_words | ref_words) == 0:
            semantic_similarities.append(0.0)
        else:
            jaccard = len(pred_words & ref_words) / len(pred_words | ref_words)
            semantic_similarities.append(jaccard)
    
    semantic_similarity = np.mean(semantic_similarities)
    
    # Compute answer length statistics
    pred_lengths = [len(pred.split()) for pred in preds]
    ref_lengths = [len(ref.split()) for ref in refs]
    
    # Print comprehensive results
    print("="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Test Set Size: {len(preds)}")
    print("-"*60)
    print("BERTScore Metrics:")
    print(f"  Precision: {bert_scores['precision']:.4f}")
    print(f"  Recall:    {bert_scores['recall']:.4f}")
    print(f"  F1 Score:  {bert_scores['f1']:.4f}")
    print("-"*60)
    print("Other Metrics:")
    print(f"  Exact Match: {exact_match_score:.4f}")
    print(f"  F1 Score:    {f1_score:.4f}")
    print(f"  Semantic Similarity: {semantic_similarity:.4f}")
    print("-"*60)
    print("Answer Length Statistics:")
    print(f"  Avg Prediction Length: {np.mean(pred_lengths):.2f} words")
    print(f"  Avg Reference Length:  {np.mean(ref_lengths):.2f} words")
    print("="*60)
    
    # Show detailed examples
    print("\n" + "="*80)
    print("DETAILED PREDICTION EXAMPLES")
    print("="*80)
    
    # Select diverse examples: best, worst, and random
    bert_f1_scores = [f.item() if hasattr(f, 'item') else f for f in F1]
    
    # Get indices for different categories
    sorted_indices = sorted(range(len(bert_f1_scores)), key=lambda i: bert_f1_scores[i], reverse=True)
    
    best_indices = sorted_indices[:3]  # Top 3
    worst_indices = sorted_indices[-3:]  # Bottom 3
    random_indices = random.sample(range(len(preds)), min(4, len(preds)))  # Random 4
    
    example_categories = [
        ("BEST PREDICTIONS", best_indices),
        ("WORST PREDICTIONS", worst_indices),
        ("RANDOM PREDICTIONS", random_indices)
    ]
    
    for category_name, indices in example_categories:
        print(f"\n{category_name}:")
        print("-" * 80)
        
        for i, idx in enumerate(indices):
            print(f"\nExample {i+1} (Index {idx}):")
            print(f"BERTScore F1: {bert_f1_scores[idx]:.4f}")
            print(f"Token F1: {f1_scores[idx]:.4f}")
            print(f"Exact Match: {'✓' if exact_matches[idx] else '✗'}")
            
            if idx < len(questions):
                print(f"Question: {questions[idx]}")
            if idx < len(contexts):
                print(f"Context: {contexts[idx]}")
            
            print(f"Reference Answer: {refs[idx]}")
            print(f"Model Prediction: {preds[idx]}")
            
            # Analysis
            pred_words = len(preds[idx].split())
            ref_words = len(refs[idx].split())
            print(f"Length: Pred={pred_words} words, Ref={ref_words} words")
            
            # Simple similarity check
            pred_lower = preds[idx].lower()
            ref_lower = refs[idx].lower()
            common_words = set(pred_lower.split()) & set(ref_lower.split())
            print(f"Common words: {len(common_words)}")
            
            print("-" * 50)
    
    # Create results dictionary matching your expected format
    results = {
        "test_size": len(preds),
        "exact_match": exact_match_score,
        "f1_score": f1_score,
        "bert_score_f1": bert_scores["f1"],
        "semantic_similarity": semantic_similarity,
        "avg_prediction_length": np.mean(pred_lengths),
        "avg_reference_length": np.mean(ref_lengths),
        "predictions": preds,
        "references": refs,
        "questions": questions,
        "individual_scores": {
            "bert_f1": bert_f1_scores,
            "token_f1": f1_scores,
            "exact_match": exact_matches,
            "semantic_similarity": semantic_similarities
        }
    }
    
    # Save detailed results
    print(f"\n{'='*60}")
    print("SAVING RESULTS")
    print("="*60)
    
    # Save detailed examples
    with open("detailed_predictions.txt", "w", encoding="utf-8") as f:
        f.write("DETAILED TEST SET PREDICTIONS\n")
        f.write("="*80 + "\n\n")
        
        for i, (prompt, pred, ref, f1_score, em) in enumerate(zip(prompts_list, preds, refs, f1_scores, exact_matches)):
            f.write(f"Example {i+1}:\n")
            f.write(f"BERTScore F1: {f1_score:.4f}\n")
            f.write(f"Exact Match: {'✓' if em else '✗'}\n")
            
            if "Question:" in prompt:
                question = prompt.split("Question:")[-1].split("Answer:")[0].strip()
                f.write(f"Question: {question}\n")
            else:
                f.write(f"Prompt: {prompt}\n")
            
            f.write(f"Reference: {ref}\n")
            f.write(f"Prediction: {pred}\n")
            f.write("-" * 50 + "\n\n")
    
    print("Results saved to:")
    print("  - test_evaluation_results.json")
    print("  - detailed_predictions.txt")
    print("="*60)
    
    return results

## Iterative training and dataset generation

In [18]:
def generate_synthetic_answers(model, tokenizer, formatted_dataset, device, generation_num=1):
    """
    Generate synthetic answers using the fine-tuned causal language model.
    Takes formatted_dataset with prompts as input.
    """
    
    print(f"Generating synthetic answers (Generation {generation_num})...")
    
    synthetic_data = []
    
    # Use the train split from formatted_dataset
    train_dataset = formatted_dataset['train']
    
    for example in tqdm(train_dataset, desc="Generating answers"):
        # Get the prompt that was created by make_prompt function
        prompt = example['prompt']
        
        # Find where the prompt ends to extract the incomplete part
        # Assuming the prompt format ends with something like "[/INST]" or "### Response:"
        if '[/INST]' in prompt:
            # For instruction format, generate after [/INST]
            generation_prompt = prompt
            stop_sequence = '[/INST]'
        elif '### Response:' in prompt:
            # For alpaca format, generate after ### Response:
            generation_prompt = prompt
            stop_sequence = '### Response:'
        else:
            # Fallback: use the full prompt
            generation_prompt = prompt
            stop_sequence = None
        
        # Tokenize input
        inputs = tokenizer(
            generation_prompt,
            max_length=512,
            truncation=True,
            padding=True,
            return_tensors="pt"
        ).to(device)
        
        # Generate answer using causal LM
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # Increased for better answers
                do_sample=True,      # Changed to True for diversity
                temperature=0.7,     # Added temperature for controlled randomness
                top_p=0.9,          # Added nucleus sampling
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        
        # Decode the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract the new generated part (answer)
        if stop_sequence and stop_sequence in generated_text:
            # Split by the stop sequence and take everything after it
            parts = generated_text.split(stop_sequence)
            if len(parts) > 1:
                synthetic_answer = parts[-1].strip()
            else:
                synthetic_answer = "No answer found"
        else:
            # If no stop sequence, take everything after the original prompt
            if generated_text.startswith(generation_prompt):
                synthetic_answer = generated_text[len(generation_prompt):].strip()
            else:
                synthetic_answer = generated_text.strip()
        
        # Clean up the answer
        if not synthetic_answer or synthetic_answer == generation_prompt:
            synthetic_answer = "No answer found"
        
        # Create new example with synthetic answer
        new_example = example.copy()
        
        # Update the prompt to include the generated answer
        if stop_sequence:
            new_example['prompt'] = generation_prompt + synthetic_answer
        else:
            new_example['prompt'] = generation_prompt + " " + synthetic_answer
        
        # If original data has structured fields, preserve them and update answers
        if 'answers' in example:
            if synthetic_answer != "No answer found":
                # Try to find answer in context if context exists
                context = example.get('context', '')
                answer_start = context.find(synthetic_answer) if context else 0
                if answer_start == -1:
                    answer_start = 0
                
                new_example['answers'] = {
                    'text': [synthetic_answer],
                    'answer_start': [answer_start]
                }
            else:
                new_example['answers'] = {
                    'text': [],
                    'answer_start': []
                }
        
        # Add generation metadata
        new_example['generation_num'] = generation_num
        new_example['synthetic'] = True
        
        synthetic_data.append(new_example)
    
    # Create a new formatted dataset with the synthetic data
    synthetic_dataset = Dataset.from_list(synthetic_data)
    
    # Return in the same format as input
    return {
        'train': synthetic_dataset
    }


def save_synthetic_dataset(dataset_dir, synthetic_formatted_dataset, generation_num):
    """
    Save the synthetic formatted dataset to a new subdirectory.
    
    Args:
        dataset_dir: Base dataset directory
        synthetic_formatted_dataset: Formatted dataset with synthetic answers
        generation_num: Generation number
    """
    
    # Create new subdirectory
    new_dir = os.path.join(dataset_dir, f"generation_{generation_num}")
    os.makedirs(new_dir, exist_ok=True)
    
    # Save the train split
    synthetic_formatted_dataset['train'].save_to_disk(os.path.join(new_dir, "train"))
    
    # Also save as JSON for inspection
    json_file = os.path.join(new_dir, "synthetic_data.json")
    synthetic_formatted_dataset['train'].to_json(json_file)
    
    # Save metadata
    metadata = {
        "generation_number": generation_num,
        "total_examples": len(synthetic_formatted_dataset['train']),
        "generated_from": "fine_tuned_model",
        "description": f"Synthetic answers generated using fine-tuned model (Generation {generation_num})",
        "format": "formatted_dataset_with_prompts"
    }
    
    with open(os.path.join(new_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)
    
    print(f"Saved synthetic formatted dataset to {new_dir}")
    return new_dir

In [None]:
def iterative_training_and_generation(
    base_dataset_dir="squad_v2_05percent",
    model_path = "microsoft/phi-3-mini-128k-instruct",
    num_generations=3,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Perform iterative training and synthetic data generation.
    
    Args:
        base_dataset_dir: Directory containing the original dataset
        num_generations: Number of generations to create
        device: Device for inference
    """
    
    # Accelerator setup
    accelerator = Accelerator()
    
    # Quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )
    
    # Load tokenizer once at the beginning
    base_model_name = "microsoft/phi-3-mini-128k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # Define tokenization function
    def tokenize(example):
        return tokenizer(
            example["prompt"],
            truncation=True,
            padding="max_length",
            max_length=256
        )
    
    print(f"Starting iterative training and generation for {num_generations} generations...")
    
    # Load original dataset
    dataset = None
    formatted_dataset = None
    
    # Main progress bar for generations
    generation_progress = tqdm(
        range(1, num_generations + 1), 
        desc="🔄 Overall Progress", 
        unit="generation",
        position=0,
        leave=True
    )
    
    for generation in generation_progress:
        generation_progress.set_description(f"🔄 Generation {generation}/{num_generations}")
        
        print(f"\n{'='*50}")
        print(f"GENERATION {generation}")
        print(f"{'='*50}")
        
        # Step 1: Fine-tune model with current training dataset
        print("Step 1: Fine-tuning model...")
        
        # Prepare training arguments for this generation
        train_config = {
            "bf16": True,
            "do_eval": False,  # Disable evaluation completely
            "learning_rate": 1.0e-05,
            "log_level": "info",
            "logging_steps": 10,
            "logging_strategy": "steps",
            "lr_scheduler_type": "cosine",
            "num_train_epochs": 3,
            "max_steps": -1,
            "output_dir": f"./phi3-squad2-gen{generation}",  # Update for each generation
            "overwrite_output_dir": True,
            "per_device_train_batch_size": 4,
            "remove_unused_columns": True,
            "save_steps": 50,
            "save_total_limit": 2,
            "seed": 42,
            "gradient_checkpointing": True,
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
            "gradient_accumulation_steps": 2,
            "warmup_ratio": 0.05,
            "save_strategy": "steps",
            "load_best_model_at_end": False,  # No evaluation, so no "best" model
            "disable_tqdm": False,  # Enable tqdm progress bars for training
        }

        train_args = TrainingArguments(**train_config)
        
        # Load model for this generation
        print("📥 Loading model...")
        with tqdm(total=1, desc="🤖 Model Loading", position=1, leave=False) as model_pbar:
            if generation == 1:
                # First generation: load base model with quantization
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    device_map="auto",
                    quantization_config=bnb_config
                )
                # Prepare dataset for first generation
                dataset = load_squad_subset()
                formatted_dataset = {
                    split: dataset[split].map(make_prompt)
                    for split in dataset.keys()
                }
            else:
                # Subsequent generations: load previous model (already fine-tuned, no quantization needed)
                model_path = f"./phi3-squad2-gen{generation-1}-final"
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto"
                )
                # Prepare dataset for subsequent generations
                dataset = load_squad_subset(f"train_{generation-1}_gen")
                formatted_dataset = {
                    split: dataset[split].map(make_prompt)
                    for split in formatted_dataset.keys()
                }

            model_pbar.update(1)
        
        # Apply PEFT configuration
        print("🔧 Applying PEFT configuration...")
        with tqdm(total=1, desc="⚙️ PEFT Setup", position=1, leave=False) as peft_pbar:
            peft_config = {
                "r": 8,  # Reduced from 16 to 8 (fewer parameters)
                "lora_alpha": 16,  # Reduced from 32 to 16
                "lora_dropout": 0.1,  # Slightly increased dropout
                "bias": "none",
                "task_type": "CAUSAL_LM",
                "target_modules": "all-linear",
                "modules_to_save": None,
            }
            lora_config = LoraConfig(**peft_config)
            model = get_peft_model(model, lora_config)
            peft_pbar.update(1)
        
        # Tokenize current training dataset
        print("🔤 Tokenizing dataset...")
        tokenized = {
            split: formatted_dataset[split].map(tokenize, batched=True)
            for split in formatted_dataset.keys()
        }
        
        # Fine-tune the model
        print("🚀 Starting training...")
        trainer = train_model(model, tokenized, tokenizer, train_args)
        
        # Save the fine-tuned model for this generation
        print("💾 Saving model...")
        with tqdm(total=1, desc="💾 Saving Model", position=1, leave=False) as save_pbar:
            final_model_path = f"./phi3-squad2-gen{generation}-final"
            trainer.save_model(final_model_path)
            save_pbar.update(1)
        print(f"✅ Generation {generation} model saved to {final_model_path}")
        
        # Step 2: Generate synthetic answers using the fine-tuned model
        print("Step 2: Generating synthetic answers...")
        model.eval()  # Set to evaluation mode
        
        synthetic_dataset = generate_synthetic_answers(
            model, tokenizer, formatted_dataset, device, generation
        )
        
        # Step 3: Save synthetic dataset
        print("Step 3: Saving synthetic dataset...")
        with tqdm(total=1, desc="💾 Saving Dataset", position=1, leave=False) as dataset_save_pbar:
            new_dir = save_synthetic_dataset(base_dataset_dir, synthetic_dataset, generation)
            dataset_save_pbar.update(1)
        
        # Step 4: Evaluate on test set
        print("Step 4: Evaluating on test set...")
        test_dataset = dataset['test']
        
        # Use the existing evaluation function
        with tqdm(total=1, desc="📊 Evaluating", position=1, leave=False) as eval_pbar:
            evaluation_results = evaluate_model(model, tokenizer, test_dataset, device)
            eval_pbar.update(1)
        
        print(f"📊 Generation {generation} Evaluation Results:")
        print(f"   Exact Match: {evaluation_results['exact_match']:.3f}")
        print(f"   F1 Score: {evaluation_results['f1_score']:.3f}")
        print(f"   BERTScore F1: {evaluation_results['bert_score_f1']:.3f}")
        print(f"   Semantic Similarity: {evaluation_results['semantic_similarity']:.3f}")
        
        # Update current training dataset for next iteration
        current_train_dataset = synthetic_dataset
        
        print(f"Generation {generation} completed!")
        print(f"Synthetic dataset saved to: {new_dir}")
        print(f"Model saved to: {final_model_path}")
        
        # Clean up GPU memory
        torch.cuda.empty_cache()
        
        # Update generation progress
        generation_progress.set_postfix({
            'EM': f"{evaluation_results['exact_match']:.3f}",
            'F1': f"{evaluation_results['f1_score']:.3f}"
        })
    
    generation_progress.close()
    print(f"\n🎉 All {num_generations} generations completed!")
    print("Final models and datasets are ready for use.")

In [22]:
iterative_training_and_generation(
        base_dataset_dir="squad_v2_05percent",
        num_generations=3,
        device="cuda" if torch.cuda.is_available() else "cpu"
    )

Starting iterative training and generation for 3 generations...


🔄 Generation 1/3:   0%|          | 0/3 [00:00<?, ?generation/s]  PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).



GENERATION 1
Step 1: Fine-tuning model...
📥 Loading model...


Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 32000,
  "pad_token_id": 32000
}

target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following TP rules were not applied on any of the layers: {'layers.*.self_attn.qkv_proj': 'colwise_rep', 'layers.*.self_attn.o_proj': 'rowwise_rep', 'layers.*.mlp.gate_up_proj': 'colwise_rep', 'layers.*.mlp.down_proj': 'rowwise_rep'}
The following layers were not sharded: model.norm.weight, model.embed_tokens.weight, model.layers.*.post_attention_layernorm.weight, lm_head.weight, model.layers.*.input_layernorm.weight
loading configuration file generation_config.json from cache at C:\Users\manua\.cache\huggingface\hub\models--microsoft--phi-3-mini-128k-instruct\snapshots\072cb7562cb8c4adf682a8e186aaafa49469eb5d\generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": [
    32000,
    32001,
    32007
  ],
  "pad_token_id": 32000
}



Loading dataset from squad_v2_05percent...
Dataset metadata:
  original_train_size: 130319
  original_validation_size: 11873
  extracted_train_size: 651
  extracted_test_size: 59
  extraction_percentage: 0.5

Loaded dataset splits:
  train: 651 examples
  test: 59 examples
🔧 Applying PEFT configuration...




🔤 Tokenizing dataset...
🚀 Starting training...
✅ Set use_cache=False for gradient checkpointing compatibility
✅ Enabled gradient checkpointing for memory efficiency


Using auto half precision backend
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Starting training with BERTScore optimization...
Early stopping based on training loss improvement
Cache disabled for gradient checkpointing compatibility


***** Running training *****
  Num examples = 651
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 2
  Total optimization steps = 246
  Number of trainable parameters = 12,582,912


Step,Training Loss
10,3.7963
20,3.8681
30,3.8095
40,3.776
50,3.6639
60,3.4943
70,3.1894
80,3.3211
90,3.0415
100,3.1699


The following layers were not sharded: pooler.dense.bias, encoder.layer.*.attention.self.query.weight, encoder.layer.*.attention.output.dense.weight, encoder.layer.*.attention.self.value.weight, embeddings.token_type_embeddings.weight, encoder.layer.*.output.dense.weight, encoder.layer.*.attention.output.LayerNorm.bias, encoder.layer.*.attention.self.key.bias, encoder.layer.*.output.dense.bias, embeddings.word_embeddings.weight, embeddings.position_embeddings.weight, encoder.layer.*.attention.self.query.bias, encoder.layer.*.output.LayerNorm.bias, pooler.dense.weight, embeddings.LayerNorm.weight, encoder.layer.*.attention.self.key.weight, encoder.layer.*.attention.output.LayerNorm.weight, encoder.layer.*.intermediate.dense.weight, embeddings.LayerNorm.bias, encoder.layer.*.intermediate.dense.bias, encoder.layer.*.output.LayerNorm.weight, encoder.layer.*.attention.output.dense.bias, encoder.layer.*.attention.self.value.bias
The following layers were not sharded: pooler.dense.bias, encod

KeyboardInterrupt: 