In [None]:
# # --- Hugging Face Login ---
# from google.colab import userdata
# from huggingface_hub import notebook_login

# hf_token = userdata.get('HF_TOKEN')
# if not hf_token:
#     raise ValueError("HF_TOKEN not found in Colab Secrets. Please complete the prerequisite steps.")
# notebook_login(hf_token)

# !pip install -Uq transformers==4.53.2
# !pip install -Uq peft
# !pip install -Uq trl
# !pip install -Uq accelerate
# !pip install -Uq datasets
# !pip install -Uq bitsandbytes

# # Install Flash Attention 2
# !pip install flash-attn==2.7.4.post1 \
#   --extra-index-url https://download.pytorch.org/whl/cu124 \
#   --no-build-isolation

In [None]:
# ===== EXPERIMENT CONFIGURATION =====
# Central settings for the run; later cells read these keys through CONFIG[...].
CONFIG = {
    # Core experiment parameters
    "experiment_type": "generative",  # "discriminative" or "generative"
    "classification_type": "ternary",   # "binary" (correct / flawed) or "ternary" (correct / conceptual_error / computational_error)
    "dataset_strategy": "3N",          # "4N" or "3N"; also selects the CSV file base_{strategy}_dataset_sanitized.csv for every experiment type
    "include_explanation": True,      # add an "explanation" field to the target JSON (generative only)
    "include_eln": True,              # add the erroneous line ("nl") or line number ("dict") to the target JSON (generative only)
    "solution_format": "nl",        # "dict" ({"L1": ..., "FA": ...} line mapping) or "nl" (raw solution text) (generative only)
    "model_name": "microsoft/phi-4-mini-instruct",  # or "Qwen/Qwen3-4B"
    
    # Prompting configuration
    "system_prompt": None,  # None -> filled by generate_system_prompt(CONFIG); or set a custom string
    "include_examples": False,  # prepend few-shot user/assistant pairs chosen by ExampleManager
    "num_examples": 3,  # 3N: problems shown with all 3 versions; 4N: split between conceptual and computational pairs
    "example_strategy": "balanced",  # "balanced", "error_focused", "custom" -- NOTE(review): not read by ExampleManager in this notebook
    
    # Training parameters
    "learning_rate": 2e-4,
    "num_epochs": 3,
    "batch_size": 8,  # per-device train and eval batch size
    "max_length": 1600,  # max tokens per formatted conversation; longer ones are truncated during tokenization
    "gradient_accumulation_steps": 4,  # effective train batch = batch_size * gradient_accumulation_steps (per device)

    # LoRA params (passed to peft.LoraConfig as r / lora_alpha / lora_dropout)
    "lora_rank": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    
    # Paths and tokens
    # "base_dataset_dir": "/content/drive/MyDrive/sft_datasets",
    "base_dataset_dir": "../data/base-datasets-sanitized",
    # Run outputs go under {output_base_dir}/{experiment_id}; this default is a Colab Drive path.
    "output_base_dir": "/content/drive/MyDrive/sft_experiments",
    # Never hard-code tokens here; use Colab Secrets / environment variables instead.
    # "hf_token": "your_huggingface_token_here",
    # "wandb_project": "math_error_classification",
    
    # Experiment tracking
    # NOTE(review): these flags are not read in this part of the notebook
    # (setup_trainer uses report_to="none").
    "save_to_hf": True,
    "save_locally": True,
    "use_wandb": False
}

# Generate experiment ID:
# "{type}_{classes}[_{strategy}]_{exp|no_exp}_{eln|no_eln}[_{format}]_{model}_{YYYYmmdd_HHMMSS}".
# Strategy and solution format are only included for generative experiments;
# empty components are skipped.
import datetime

# Short model tag. Qwen and Phi-4 keep their historical tags; any other model uses
# its (filesystem-safe) repo name instead of being mislabelled "phi4".
model_name_lower = CONFIG["model_name"].lower()
if "qwen" in model_name_lower:
    model_tag = "qwen"
elif "phi-4" in model_name_lower:
    model_tag = "phi4"
else:
    repo_name = model_name_lower.rstrip("/").split("/")[-1]
    model_tag = "".join(ch if ch.isalnum() else "-" for ch in repo_name)

experiment_components = [
    CONFIG["experiment_type"][:4],  # "gene" or "disc"
    CONFIG["classification_type"][:3],  # "bin" or "ter"
    CONFIG["dataset_strategy"] if CONFIG["experiment_type"] == "generative" else "",
    "exp" if CONFIG["include_explanation"] else "no_exp",
    "eln" if CONFIG["include_eln"] else "no_eln",
    CONFIG["solution_format"] if CONFIG["experiment_type"] == "generative" else "",
    model_tag,
]
experiment_id = "_".join([c for c in experiment_components if c]) + "_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
CONFIG["experiment_id"] = experiment_id

print(f"Experiment ID: {experiment_id}")
print("Configuration loaded successfully!")

def setup_output_directory(config):
    """Create the experiment output directory tree and save the config into it.

    Layout: {output_base_dir}/{experiment_id}/ with subfolders baseline/,
    training/, final/ and checkpoints/, plus config.json (values that are not
    JSON-serializable are written via str()).

    Args:
        config: dict with at least "output_base_dir" and "experiment_id".

    Returns:
        pathlib.Path of {output_base_dir}/{experiment_id}.
    """
    # Imported locally: on a fresh kernel this cell runs before the cell that
    # imports json / Path at module level, so relying on globals raised NameError.
    import json
    from pathlib import Path

    output_dir = Path(config["output_base_dir"]) / config["experiment_id"]

    for subdir in ("baseline", "training", "final", "checkpoints"):
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    config_path = output_dir / "config.json"
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, default=str)

    print(f"Output directory created: {output_dir}")
    return output_dir

# # Setup output directory
# output_dir = setup_output_directory(CONFIG)
# CONFIG["output_dir"] = str(output_dir)

import torch
import random
import numpy as np

# Set random seeds for reproducibility
def set_seeds(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device when present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seeds(42)
print("Dependencies imported and seeds set!")

In [None]:
def generate_system_prompt(config):
    """Build the default system prompt from the experiment config.

    Discriminative experiments get a fixed one-line instruction. Generative
    experiments get:
      - the allowed verdict labels (binary: 'correct'/'flawed'; otherwise
        'correct'/'conceptual_error'/'computational_error'),
      - optional instructions for the erroneous line (number in "dict" format,
        quoted text in "nl" format) and for an explanation,
      - the exact JSON keys the training targets use and the metrics parse
        ("verdict", "erroneous_line_number"/"erroneous_line", "explanation").
        Without them a zero-shot model has to guess the schema.

    Args:
        config: dict with "experiment_type" and, for generative runs,
            "classification_type", "include_eln", "solution_format" and
            "include_explanation".

    Returns:
        The system prompt string.
    """
    if config["experiment_type"] == "discriminative":
        return "You are a mathematics tutor. Classify the given solution."

    parts = ["You are a mathematics tutor. Analyze the given solution and provide your assessment in JSON format."]
    json_keys = ['"verdict"']

    # Verdict labels
    if config["classification_type"] == "binary":
        parts.append("Determine if the solution is 'correct' or 'flawed'.")
    else:
        parts.append("Classify as 'correct', 'conceptual_error', or 'computational_error'.")

    # Optional fields
    fields = []
    if config["include_eln"]:
        if config["solution_format"] == "dict":
            fields.append("identify the erroneous line number (e.g., 'L1', 'FA')")
            json_keys.append('"erroneous_line_number"')
        else:
            fields.append("quote the full erroneous line text")
            json_keys.append('"erroneous_line"')

    if config["include_explanation"]:
        fields.append("provide a brief explanation of any error")
        json_keys.append('"explanation"')

    if fields:
        parts.append(f"Also {', and '.join(fields)}.")

    parts.append("Respond only with valid JSON.")
    parts.append(f"Use exactly these keys: {', '.join(json_keys)}.")
    if config["include_eln"]:
        # Targets store null for the erroneous-line field of correct solutions.
        parts.append("Set the erroneous line field to null when the solution is correct.")

    return " ".join(parts)

# Use the generated default prompt unless a custom one was already set in CONFIG.
if CONFIG["system_prompt"] is None:
    CONFIG["system_prompt"] = generate_system_prompt(CONFIG)

# Show the prompt that will be used, followed by a blank line.
print("System Prompt:", CONFIG["system_prompt"], "", sep="\n")

# Hint: the prompt can still be replaced by assigning to CONFIG afterwards.
print("To customize the system prompt, run:", 'CONFIG["system_prompt"] = "Your custom prompt here"', sep="\n")

In [None]:
class ExampleManager:
    """Selects few-shot demonstration samples from the base dataset.

    Samples are grouped by problem ``index`` and ``error_type`` so that a
    demonstration can show several versions (correct and erroneous) of the same
    problem side by side.
    """

    def __init__(self, base_dataset, config):
        """
        Args:
            base_dataset: a DataFrame (anything with ``to_dict``) or a list of
                dicts; every record needs "index" and "error_type" keys.
            config: dict providing "include_examples", "num_examples" and
                "dataset_strategy".
        """
        if hasattr(base_dataset, 'to_dict'):  # pandas DataFrame
            self.samples = base_dataset.to_dict('records')
        else:  # already a list of dicts
            self.samples = base_dataset

        self.config = config
        self._prepare_examples_by_problem()

    def _prepare_examples_by_problem(self):
        """Build self.problems_by_type[error_type][problem_index] -> list of samples.

        Raises:
            KeyError: for an error_type other than the three known labels.
        """
        self.problems_by_type = {
            "correct": {},
            "conceptual_error": {},
            "computational_error": {}
        }

        for sample in self.samples:
            by_index = self.problems_by_type[sample["error_type"]]
            by_index.setdefault(sample["index"], []).append(sample)

        print(f"Problems by type: correct={len(self.problems_by_type['correct'])}, "
              f"conceptual={len(self.problems_by_type['conceptual_error'])}, "
              f"computational={len(self.problems_by_type['computational_error'])}")

    @staticmethod
    def _sample_problems(candidates, k, label):
        """Randomly pick up to k problem indices from `candidates` using the global `random` state.

        Requests larger than the candidate pool are clamped (with a warning)
        instead of raising ValueError from random.sample.
        """
        import random

        # Sort first: set iteration order is not guaranteed to be stable across
        # interpreter runs (e.g. str indices under hash randomization), which
        # would make seeded sampling irreproducible.
        try:
            pool = sorted(candidates)
        except TypeError:  # mixed, non-comparable index types
            pool = sorted(candidates, key=str)

        k = max(k, 0)
        if k > len(pool):
            print(f"Warning: requested {k} {label} example problems but only {len(pool)} qualify; using {len(pool)}.")
            k = len(pool)
        return random.sample(pool, k)

    def get_examples(self):
        """Return the few-shot samples for the configured dataset strategy.

        - "3N": for each of `num_examples` distinct problems that have all three
          versions, its correct, conceptual_error and computational_error samples.
        - "4N": floor(n/2) problems contribute a (correct, conceptual_error) pair
          and ceil(n/2) *different* problems a (correct, computational_error)
          pair, so no correct solution is shown twice.

        Returns [] when examples are disabled or the strategy is unknown.
        """
        if not self.config["include_examples"]:
            return []

        num_examples = self.config["num_examples"]
        dataset_strategy = self.config["dataset_strategy"]
        correct = self.problems_by_type["correct"]
        conceptual = self.problems_by_type["conceptual_error"]
        computational = self.problems_by_type["computational_error"]
        examples = []

        if dataset_strategy == "3N":
            candidates = correct.keys() & conceptual.keys() & computational.keys()
            for problem_index in self._sample_problems(candidates, num_examples, "3N"):
                examples.append(correct[problem_index][0])
                examples.append(conceptual[problem_index][0])
                examples.append(computational[problem_index][0])
            return examples

        if dataset_strategy == "4N":
            n_conceptual = num_examples // 2                 # floor(n/2)
            n_computational = num_examples - n_conceptual    # ceil(n/2)

            chosen_conceptual = self._sample_problems(
                correct.keys() & conceptual.keys(), n_conceptual, "conceptual")
            for problem_index in chosen_conceptual:
                examples.append(correct[problem_index][0])
                examples.append(conceptual[problem_index][0])

            # Exclude problems already used so their correct version is not repeated.
            remaining = (correct.keys() & computational.keys()) - set(chosen_conceptual)
            chosen_computational = self._sample_problems(remaining, n_computational, "computational")
            for problem_index in chosen_computational:
                examples.append(correct[problem_index][0])
                examples.append(computational[problem_index][0])
            return examples

        print(f"Warning: Unknown dataset strategy '{dataset_strategy}'")
        return []


print("Updated ExampleManager class loaded!")

In [None]:
import json
import pandas as pd
from pathlib import Path

def load_base_dataset():
    """Read the sanitized base CSV selected by CONFIG["dataset_strategy"] into a DataFrame.

    The file is {base_dataset_dir}/base_{strategy}_dataset_sanitized.csv.
    """
    strategy = CONFIG["dataset_strategy"]
    csv_path = Path(CONFIG["base_dataset_dir"]).joinpath(f"base_{strategy}_dataset_sanitized.csv")
    frame = pd.read_csv(csv_path)
    print(f"Loaded base {strategy} dataset with {len(frame)} samples")
    return frame

def make_solution_mapping(solution_text: str):
    """Map line labels to the lines of a solution.

    Every line except the last is keyed "L{n}" by its 1-based position in the
    stripped text. Blank lines are dropped but still consume a number. The last
    line is always stored under "FA".
    """
    *body_lines, final_line = solution_text.strip().split('\n')
    mapping = {}
    for number, raw_line in enumerate(body_lines, start=1):
        cleaned = raw_line.strip()
        if cleaned:
            mapping[f"L{number}"] = cleaned
    mapping["FA"] = final_line.strip()
    return mapping

def format_solution(sample):
    """Return the solution shown to the model for `sample`.

    Correct samples show their "correct_answer"; all other samples show their
    "wrong_answer". With CONFIG["solution_format"] == "dict" the text becomes a
    {"L1": ..., "FA": ...} mapping via make_solution_mapping(); otherwise the
    stripped text is returned.
    """
    column = "correct_answer" if sample["error_type"] == "correct" else "wrong_answer"
    raw_value = sample.get(column)
    # Missing CSV cells arrive as None (HF Dataset rows) or float NaN (pandas rows);
    # treat both as empty text instead of crashing on `.strip()`.
    if raw_value is None or (isinstance(raw_value, float) and raw_value != raw_value):
        raw_value = ""
    solution_text = str(raw_value).strip()

    if CONFIG["solution_format"] == "dict":
        return make_solution_mapping(solution_text)
    return solution_text

def format_expected_output(sample):
    """Build the assistant target for one sample as a JSON string.

    Keys, in order:
      - "verdict": sample["error_type"], collapsed to "flawed" for every
        non-correct sample when CONFIG["classification_type"] == "binary".
      - "erroneous_line_number" ("dict" format) or "erroneous_line" ("nl"
        format, the quoted line text), only when CONFIG["include_eln"]; null
        for correct samples.
      - "explanation", only when CONFIG["include_explanation"]; null when the
        dataset value is missing.

    Raises:
        KeyError: in "nl" format, if the sample's erroneous_line_number does not
            name a non-blank line of its wrong_answer.
    """
    def _clean(value):
        # Empty CSV cells come through pandas as float NaN; json.dumps would emit
        # the non-standard token NaN, so emit null instead.
        if isinstance(value, float) and value != value:
            return None
        return value

    output = {"verdict": sample["error_type"]}
    if CONFIG["classification_type"] == "binary" and sample["error_type"] != "correct":
        output["verdict"] = "flawed"

    # Erroneous line (number or text)
    if CONFIG["include_eln"]:
        eln_key = "erroneous_line_number" if CONFIG["solution_format"] == "dict" else "erroneous_line"
        if sample["error_type"] == "correct":
            output[eln_key] = None
        elif CONFIG["solution_format"] == "dict":
            output[eln_key] = _clean(sample["erroneous_line_number"])
        else:
            eln = sample["erroneous_line_number"]
            solution_mapping = make_solution_mapping(sample["wrong_answer"])
            if eln not in solution_mapping:
                raise KeyError(
                    f"erroneous_line_number {eln!r} not found among wrong_answer lines "
                    f"{list(solution_mapping)} (index={sample.get('index', 'N/A')})"
                )
            output[eln_key] = solution_mapping[eln]

    # Explanation
    if CONFIG["include_explanation"]:
        output["explanation"] = _clean(sample["explanation"])

    # ensure_ascii=False keeps math symbols (e.g. "×", "√") literal, so a quoted
    # line matches the text shown in the prompt instead of \uXXXX escapes.
    return json.dumps(output, ensure_ascii=False)

# def format_user_message(sample):
#     """Format a sample into a user message."""
#     return f"### Question:\n{sample['question']}\n\n### Answer:\n{format_solution(sample)}"

def format_user_message(sample):
    """Render a sample as the user turn: a "### Question:" section, a blank line, then "### Answer:".

    A missing or empty question renders as an empty string rather than "None".
    """
    question_text = sample.get('question', '') or ''
    solution_text = format_solution(sample)
    sections = ("### Question:", f"{question_text}", "", "### Answer:", f"{solution_text}")
    return "\n".join(sections)

def create_sample_messages(sample, examples):
    """Build the prompt conversation for `sample`.

    Order: the system prompt, then (when CONFIG["include_examples"]) one
    user/assistant pair per few-shot example, then the sample's own user turn.
    The sample's assistant reply is not included.
    """
    conversation = [{"role": "system", "content": CONFIG["system_prompt"]}]

    if CONFIG["include_examples"]:
        for demo in examples:
            conversation.extend([
                {"role": "user", "content": format_user_message(demo)},
                {"role": "assistant", "content": format_expected_output(demo)},
            ])

    conversation.append({"role": "user", "content": format_user_message(sample)})
    return conversation


print("Updated formatting functions loaded!")

In [None]:
from transformers import AutoTokenizer

def load_tokenizer(model_name):
    """Load the tokenizer for `model_name`, give it a pad token if missing, and pad on the left."""
    print(f"Loading tokenizer: {model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        # No dedicated pad token: reuse EOS so batches can be padded.
        tok.pad_token = tok.eos_token

    # Pad before the prompt (left side) for this causal model.
    tok.padding_side = "left"

    print("✓ Tokenizer loaded successfully!")
    return tok

def apply_chat_template(
        messages, 
        tokenizer, 
        add_generation_prompt=False, 
        tokenize=True, 
        **kwargs
    ):
    """
    Applies the tokenizer's chat template to messages with a consistent interface.

    Args:
        messages: List of message dictionaries with 'role' and 'content' keys
        tokenizer: The tokenizer to use for formatting
        add_generation_prompt: Whether to add the generation prompt (for inference)
        tokenize: Whether to return tokens (True) or text (False)
        **kwargs: Extra arguments for the tokenizer call (e.g. return_tensors,
            max_length). ``add_special_tokens`` defaults to False here; pass it
            explicitly to override.

    Returns:
        If tokenize=True: tokenizer output dict with input_ids, attention_mask, etc.
        If tokenize=False: formatted text string
    """
    # Qwen3 templates accept enable_thinking; disable the thinking block for Qwen models.
    template_kwargs = {}
    if CONFIG["model_name"].startswith("Qwen"):
        template_kwargs['enable_thinking'] = False

    # Render the conversation to text first.
    formatted_text = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=add_generation_prompt,
        **template_kwargs
    )

    if not tokenize:
        return formatted_text

    # Chat templates typically already emit the model's special tokens (e.g. BOS);
    # the HF chat-templating docs recommend add_special_tokens=False when tokenizing
    # rendered text, otherwise they can be added a second time.
    kwargs.setdefault("add_special_tokens", False)
    return tokenizer(formatted_text, **kwargs)

In [None]:
from datasets import Dataset, DatasetDict
import pandas as pd

def step1_create_conversations(sample, examples):
    """[STEP 1] Build the full training conversation for one sample.

    Returns:
        {"conversation": messages}: the prompt messages from
        create_sample_messages() followed by the assistant target from
        format_expected_output().

    Raises:
        TypeError: if any message content is None (chat templates cannot render it).
    """
    messages = create_sample_messages(sample, examples)
    messages.append({"role": "assistant", "content": format_expected_output(sample)})

    for position, message in enumerate(messages):
        if message['content'] is None:
            # Rows in this dataset are identified by their problem "index" column.
            sample_id = sample.get('id', sample.get('index', 'N/A'))
            raise TypeError(
                f"Message content is None at index {position} for sample ID {sample_id} "
                f"(error_type={sample.get('error_type', 'N/A')}). Message: {message}"
            )
    return {"conversation": messages}

def step2_apply_chat_template(sample, tokenizer):
    """[STEP 2] Render sample["conversation"] into one training string (no generation prompt appended)."""
    rendered = apply_chat_template(
        sample["conversation"],
        tokenizer,
        tokenize=False,
        add_generation_prompt=False,
    )
    return dict(text=rendered)

def step3_tokenize_text(sample, tokenizer):
    """[STEP 3] Tokenize sample["text"] (the chat-template output of step 2).

    No padding is applied here (the data collator pads per batch). Sequences
    longer than CONFIG["max_length"] tokens are truncated.

    NOTE(review): with the tokenizer's default truncation side, over-long
    conversations lose their end, which is where the assistant JSON target is;
    consider checking how many samples exceed max_length.
    """
    return tokenizer(
        sample["text"],
        truncation=True,
        max_length=CONFIG["max_length"],
        padding=False,
        # The rendered chat template already carries its special tokens;
        # adding them again can duplicate e.g. BOS.
        add_special_tokens=False,
    )


# New Cell after the modular functions

def _group_split_rows(groups, test_size=0.2, seed=42):
    """Split row positions into (train_rows, test_rows) so rows sharing a group value stay together.

    Distinct groups are ordered deterministically, shuffled with a seeded RNG,
    and the first ceil(test_size * n_groups) of them (at least 1, at most
    n_groups - 1) form the test side.

    Raises:
        ValueError: if fewer than two distinct groups exist.
    """
    import math
    import random

    groups = list(groups)
    unique_groups = sorted(set(groups), key=str)
    if len(unique_groups) < 2:
        raise ValueError("Need at least two distinct problems to build a train/eval split.")
    random.Random(seed).shuffle(unique_groups)

    n_test = min(max(1, math.ceil(len(unique_groups) * test_size)), len(unique_groups) - 1)
    test_groups = set(unique_groups[:n_test])
    train_rows = [row for row, group in enumerate(groups) if group not in test_groups]
    test_rows = [row for row, group in enumerate(groups) if group in test_groups]
    return train_rows, test_rows


def prepare_dataset(config, tokenizer):
    """Run the three-step preprocessing pipeline and split into train / eval.

    Steps: (1) build conversations, (2) render them with the chat template,
    (3) tokenize. Rows sharing a problem "index" (the correct and erroneous
    versions of one problem) are kept on the same side of the 80/20 split, so
    evaluation measures unseen problems. If the dataset has no "index" column,
    a random row-level split is used instead.

    Returns:
        (train_dataset, eval_dataset, examples): tokenized datasets and the
        few-shot examples embedded in every prompt.

    Raises:
        ValueError: if config["system_prompt"] is None.
    """
    base_df = load_base_dataset()
    raw_dataset = Dataset.from_pandas(base_df)
    example_manager = ExampleManager(base_df, config)
    examples = example_manager.get_examples()
    system_prompt = config["system_prompt"]
    if system_prompt is None:
        raise ValueError("System prompt is None! Check cell 3.")

    print("Executing Step 1: Creating conversations...")
    ds_step1 = raw_dataset.map(
        lambda x: step1_create_conversations(x, examples)
    )
    print("✅ Step 1 complete.")

    print("\nExecuting Step 2: Applying chat template...")
    ds_step2 = ds_step1.map(
        lambda x: step2_apply_chat_template(x, tokenizer)
    )
    print("✅ Step 2 complete.")

    print("\nExecuting Step 3: Tokenizing text...")
    processed_dataset = ds_step2.map(
        lambda x: step3_tokenize_text(x, tokenizer),
        remove_columns=ds_step2.column_names
    )
    print("✅ Step 3 complete.")

    if "index" in raw_dataset.column_names:
        # A random row split would place e.g. the correct version of a problem in
        # train and its erroneous version in eval, leaking the question text.
        # map() keeps row order, so raw row positions index processed_dataset.
        train_rows, test_rows = _group_split_rows(raw_dataset["index"], test_size=0.2, seed=42)
        train_dataset = processed_dataset.select(train_rows)
        eval_dataset = processed_dataset.select(test_rows)
    else:
        split_dataset = processed_dataset.train_test_split(test_size=0.2, seed=42)
        train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']

    # NOTE(review): few-shot `examples` are drawn from the full dataset, so with
    # include_examples=True they may also be rows of the eval split.
    print(f"\nDataset prepared: {len(train_dataset)} training, {len(eval_dataset)} evaluation samples")

    return train_dataset, eval_dataset, examples

In [None]:
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import BitsAndBytesConfig

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType, 
    prepare_model_for_kbit_training
)

def load_model(model_name):
    """Load `model_name` with 4-bit NF4 quantization (double quant) and attach LoRA adapters.

    Precision and attention kernel follow the GPU:
      - compute capability >= 8.0 (Ampere or newer): bfloat16 compute, and
        FlashAttention-2 if the `flash_attn` package is installed;
      - otherwise: float16 compute with PyTorch SDPA attention
        (FlashAttention-2 does not support older GPUs).

    LoRA hyperparameters come from CONFIG ("lora_rank", "lora_alpha",
    "lora_dropout") and target all linear layers.

    Returns:
        The model wrapped by peft.get_peft_model.
    """
    import importlib.util

    print(f"Loading model: {model_name}")

    ampere_or_newer = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    compute_dtype = torch.bfloat16 if ampere_or_newer else torch.float16
    use_flash_attn = ampere_or_newer and importlib.util.find_spec("flash_attn") is not None
    attn_implementation = "flash_attention_2" if use_flash_attn else "sdpa"

    # Configure quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True
    )

    # Configure LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=CONFIG["lora_rank"],
        lora_alpha=CONFIG["lora_alpha"],
        lora_dropout=CONFIG["lora_dropout"],
        target_modules="all-linear",
        bias="none"
    )

    # Load model (non-quantized weights in the same dtype as the 4-bit compute)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation=attn_implementation
    )

    # Prepare model for 4-bit training with LoRA
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()
    print("✓ Model loaded successfully!")
    print(f"✓ Attention: {attn_implementation}, compute dtype: {compute_dtype}")
    print(f"✓ Model device: {next(model.parameters()).device}")

    return model

In [None]:
import json
import re
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import EvalPrediction

def normalize_text(text):
    """Canonicalize a value for lenient equality checks.

    ``None`` maps to ""; anything else is stringified, lower-cased and stripped
    of every whitespace character, so "L 3" and "l3" compare equal.
    """
    if text is None:
        return ""
    lowered = str(text).lower()
    # Removing all whitespace also covers leading/trailing whitespace.
    return re.sub(r'\s+', '', lowered)

def extract_json_from_response(response):
    """Extract the first JSON object from a model response.

    The whole (stripped) response is tried first. Then every "{" is tried
    with ``json.JSONDecoder.raw_decode``, so the object is found inside
    ```json fences, after a preamble, or before trailing text that itself
    contains braces.

    Returns:
        The parsed dict, or {} when the response is empty or contains no JSON
        object. Top-level non-object JSON such as "correct" or 42 also yields {},
        so callers can always use ``.get``.
    """
    if not response:
        return {}
    response = response.strip()

    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    decoder = json.JSONDecoder()
    pos = response.find("{")
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(response, pos)
            return obj  # decoding from "{" can only produce a dict
        except json.JSONDecodeError:
            pos = response.find("{", pos + 1)
    return {}

def _extract_last_json_object(text):
    """Return the last top-level JSON object found in `text`, or {} if none parses.

    Decoded evaluation sequences hold the whole conversation (system prompt,
    question, any few-shot answers) before the target answer, and math questions
    often contain braces. A single first-"{"-to-last-"}" match therefore rarely
    parses. Instead, every "{" is tried with ``json.JSONDecoder.raw_decode``,
    skipping over objects already consumed, and the last object that decodes is kept.
    """
    if not text:
        return {}
    decoder = json.JSONDecoder()
    last_obj = {}
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
            continue
        last_obj = obj  # decoding from "{" can only produce a dict
        pos = text.find("{", end)
    return last_obj


def compute_metrics_for_trainer(eval_pred: EvalPrediction, tokenizer):
    """Compute verdict and erroneous-line metrics from evaluation predictions.

    Args:
        eval_pred: ``EvalPrediction`` or a ``(predictions, labels)`` pair.
            ``predictions`` may be token IDs of shape (batch, seq) (generated IDs
            or argmax-ed logits) or raw logits of shape (batch, seq, vocab), which
            are argmax-ed here. Positions equal to -100 (ignored labels / padding
            added when the Trainer concatenates batches) are replaced by the pad
            token before decoding, because fast tokenizers cannot decode -100.
        tokenizer: tokenizer used to decode the IDs.

    Returns:
        dict with verdict_accuracy and macro-averaged verdict_precision,
        verdict_recall and verdict_f1, plus parse_failures and total_samples.
        Also includes eln_accuracy when CONFIG["include_eln"] is set.
    """
    import numpy as np

    predictions, labels = eval_pred[0], eval_pred[1]
    if isinstance(predictions, tuple):  # model outputs like (logits, extra...)
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    if predictions.ndim == 3:  # raw logits from a plain Trainer forward pass
        predictions = predictions.argmax(axis=-1)

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    # np.where builds new arrays, so the caller's arrays are not mutated.
    predictions = np.where(predictions == -100, pad_id, predictions)
    labels = np.asarray(labels)
    labels = np.where(labels == -100, pad_id, labels)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    eln_key = "erroneous_line" if CONFIG["solution_format"] == "nl" else "erroneous_line_number"
    verdict_expected, verdict_predicted = [], []
    eln_expected, eln_predicted = [], []
    parse_failures = 0

    for pred_text, label_text in zip(decoded_preds, decoded_labels):
        # The target answer is the last JSON object of each sequence.
        pred_json = _extract_last_json_object(pred_text)
        expected_json = _extract_last_json_object(label_text)

        if not pred_json:
            parse_failures += 1

        # str() keeps sklearn happy if a garbled prediction has a non-string verdict.
        verdict_expected.append(str(expected_json.get("verdict", "unknown")))
        verdict_predicted.append(str(pred_json.get("verdict", "unknown")))

        if CONFIG["include_eln"]:
            eln_expected.append(normalize_text(str(expected_json.get(eln_key, ""))))
            eln_predicted.append(normalize_text(str(pred_json.get(eln_key, ""))))

    verdict_accuracy = accuracy_score(verdict_expected, verdict_predicted)
    precision, recall, f1, _ = precision_recall_fscore_support(
        verdict_expected, verdict_predicted, average='macro', zero_division=0
    )

    metrics = {
        "verdict_accuracy": verdict_accuracy,
        "verdict_precision": precision,
        "verdict_recall": recall,
        "verdict_f1": f1,
        "parse_failures": parse_failures,
        "total_samples": len(decoded_labels),
    }

    if CONFIG["include_eln"]:
        metrics["eln_accuracy"] = accuracy_score(eln_expected, eln_predicted)

    return metrics

def print_metrics(metrics, stage_name="Evaluation"):
    """Print a short summary of a metrics dict.

    Requires parse_failures and the verdict_* keys. "total_samples" is optional
    (shown as "n/a" when absent, since not every metrics producer adds it), and
    the ELN line is printed only when "eln_accuracy" is present.
    """
    total_samples = metrics.get('total_samples', 'n/a')
    print(f"\n{stage_name} Results:")
    print(f"Total samples: {total_samples} (Parse failures: {metrics['parse_failures']})")
    print(f"Verdict - Accuracy: {metrics['verdict_accuracy']:.3f}, Precision: {metrics['verdict_precision']:.3f}, Recall: {metrics['verdict_recall']:.3f}, F1: {metrics['verdict_f1']:.3f}")

    if "eln_accuracy" in metrics:
        print(f"ELN Accuracy: {metrics['eln_accuracy']:.3f}")

# Simple evaluation function
def evaluate_results(results, tokenizer, stage_name="Evaluation"):
    """Compute metrics for `results`, print a summary, and return the metrics.

    Args:
        results: a ``(predictions, labels)`` pair or EvalPrediction, passed
            straight to compute_metrics_for_trainer.
        tokenizer: tokenizer used to decode the token IDs.
        stage_name: heading used in the printed summary.

    Returns:
        The metrics dict from compute_metrics_for_trainer.
    """
    metrics = compute_metrics_for_trainer(results, tokenizer)
    print_metrics(metrics, stage_name)
    return metrics


print("✅ Simplified metrics functions loaded (with text normalization)!")
print("\nTo evaluate your baseline results, run:")
# The second argument is the tokenizer (it is used to decode token IDs), not CONFIG.
print("baseline_metrics = evaluate_results(baseline_results, tokenizer, 'Baseline')")

In [None]:
from transformers import (
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling, 
    EarlyStoppingCallback
)
from functools import partial

def setup_trainer(model, tokenizer, train_dataset, eval_dataset):
    """Build the Hugging Face ``Trainer`` for LoRA fine-tuning with periodic evaluation.

    Evaluation and checkpointing run every 25 steps. At the end, the checkpoint with
    the best "eln_accuracy" is restored ("verdict_accuracy" when
    CONFIG["include_eln"] is False, since eln_accuracy is then never computed).
    Training stops early after 10 evaluations without improvement.

    Mixed precision: bf16 on GPUs with compute capability >= 8.0, fp16 on
    older GPUs, and none on CPU. TrainingArguments rejects fp16 and bf16 both
    enabled, so exactly one or neither is set.

    Returns:
        The configured ``Trainer``.
    """
    # Pads batches dynamically; with mlm=False the labels are the input_ids
    # (padding positions set to -100), i.e. a plain causal-LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # CONFIG["output_dir"] is only set once setup_output_directory() has run;
    # otherwise fall back to the same {output_base_dir}/{experiment_id} layout.
    experiment_dir = CONFIG.get("output_dir") or Path(CONFIG["output_base_dir"]) / CONFIG["experiment_id"]
    output_dir = Path(experiment_dir) / "training"

    on_gpu = torch.cuda.is_available()
    use_bf16 = on_gpu and torch.cuda.get_device_capability()[0] >= 8
    use_fp16 = on_gpu and not use_bf16

    best_metric = "eln_accuracy" if CONFIG["include_eln"] else "verdict_accuracy"

    training_args = TrainingArguments(
        output_dir=str(output_dir),

        # Basic training parameters
        optim="paged_adamw_8bit",
        num_train_epochs=CONFIG["num_epochs"],
        per_device_train_batch_size=CONFIG["batch_size"],
        gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
        learning_rate=CONFIG["learning_rate"],
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",

        # Evaluation and saving strategy (save_steps must be a multiple of
        # eval_steps for load_best_model_at_end)
        eval_strategy="steps",
        eval_steps=25,
        save_strategy="steps",
        save_steps=25,
        save_total_limit=1,
        per_device_eval_batch_size=CONFIG["batch_size"],
        eval_accumulation_steps=CONFIG["gradient_accumulation_steps"],
        load_best_model_at_end=True,
        metric_for_best_model=best_metric,
        greater_is_better=True,

        # Other settings
        logging_steps=25,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to="none",
        seed=42
    )

    # Bind the tokenizer so the Trainer can call compute_metrics(eval_pred).
    compute_metrics_with_config = partial(
        compute_metrics_for_trainer,
        tokenizer=tokenizer
    )

    def _logits_to_token_ids(logits, labels):
        """Argmax on-device so evaluation accumulates (batch, seq) IDs, not (batch, seq, vocab) logits."""
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics_with_config,
        preprocess_logits_for_metrics=_logits_to_token_ids,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
    )

    print("✓ Trainer initialized successfully!")
    return trainer

## Main execution

In [None]:
# Step 1 of the main run: tokenizer for CONFIG["model_name"] (pad token ensured, left padding).
tokenizer = load_tokenizer(model_name=CONFIG["model_name"])

In [None]:
# Sanity check: run the three preprocessing steps on the first few rows and print every stage.
N_PREVIEW_SAMPLES = 5

df = load_base_dataset()

# Build the few-shot pool once, as prepare_dataset() does. Creating the
# ExampleManager inside the loop regrouped the whole dataset and resampled a
# different example set for every previewed row.
example_manager = ExampleManager(df, CONFIG)
examples = example_manager.get_examples()

for i in range(min(N_PREVIEW_SAMPLES, len(df))):
    sample = df.iloc[i].to_dict()

    # Step 1: conversation dict ({"conversation": [messages...]})
    conversation = step1_create_conversations(sample, examples)
    print("Conversation:")
    print(conversation["conversation"])

    # Step 2: render the conversation into a single prompt string
    formatted_text = step2_apply_chat_template(conversation, tokenizer)
    print("Formatted Text:")
    print(formatted_text["text"])

    # Step 3: tokenize
    tokenizer_output = step3_tokenize_text(formatted_text, tokenizer)
    print("Tokenized Output:")
    print(tokenizer_output)

    print(len(tokenizer_output['input_ids']), "tokens generated.")
    print(len(tokenizer_output['attention_mask']), "attention mask tokens generated.")

In [None]:
# Step 2: tokenized train/eval splits plus the few-shot examples embedded in each prompt.
train_dataset, eval_dataset, examples = prepare_dataset(config=CONFIG, tokenizer=tokenizer)

In [None]:
# Step 3: 4-bit quantized base model wrapped with LoRA adapters.
model = load_model(model_name=CONFIG["model_name"])