# Gradient Ascent Training Pipeline

## 1. Setup: Imports and Configuration

In [None]:
!pip install -q --upgrade transformers datasets evaluate wandb torch accelerate pandas ipywidgets

In [None]:
import torch
import os
import gc
import importlib

# Import refactored classes
from config_manager import ConfigManager
from model_manager import ModelManager
from dataset_manager import DatasetManager
from gradient_ascent_pipeline import GradientAscentPipeline # Import the GA pipeline
from inference_engine import InferenceEngine
from utils import extract_boxed_answer, compare_math_answers

# --- Experiment Configuration ---
# Parameters that are specific to this gradient ascent run.
experiment_params = dict(
    experiment_name="deepseek_lila",  # Descriptive name (training_type added automatically)
    training_type="gradient_ascent",
    # dataset_json_path is not used for GA data prep, uses base dataset train split
    dataset_json_path=None,
    # --- Optional Overrides (Comment out to use config.py defaults or GA defaults) ---
    # learning_rate=5e-6,  # Example override
    # epochs=1,
    # train_batch_size=1,
    # gradient_accumulation_steps=8,
    # eos_loss_scale_factor=0.05,  # Override GA trainer param
)

# --- Initialize Configuration ---
config_manager = ConfigManager()

# Gradient-ascent defaults. Each row is (run-config key, name looked up through
# get_base_value, second argument handed to get_base_value for that lookup).
ga_defaults = {
    target_key: config_manager.get_base_value(base_name, default_value)
    for target_key, base_name, default_value in (
        ('LEARNING_RATE', 'GA_LEARNING_RATE', 2e-5),
        ('EPOCHS', 'GA_EPOCHS', 1),
        ('EOS_LOSS_SCALE_FACTOR', 'GA_EOS_LOSS_SCALE_FACTOR', 0.1),
        ('ASSISTANT_MARKER_STR', 'GA_ASSISTANT_MARKER_STR', '<｜Assistant｜>'),
    )
}

# Layer the experiment parameters on top of the GA defaults: on a key collision the
# experiment value wins because it is applied last.
# NOTE(review): the default keys are upper-case while the example overrides above are
# lower-case; whether get_config treats them as the same option is not visible here -
# confirm against ConfigManager.
merged_exp_params = dict(ga_defaults)
merged_exp_params.update(experiment_params)
run_config = config_manager.get_config(merged_exp_params)

# Dump the resolved configuration, one "key: value" line per entry.
print("--- Run Configuration (Gradient Ascent) --- ")
for option_name, option_value in run_config.items():
    print(f"{option_name}: {option_value}")
print("-------------------------------------------")

## 2. Initialize and Run Gradient Ascent Pipeline

In [None]:
import traceback

# Initialize the Gradient Ascent pipeline with the specific run configuration
pipeline = GradientAscentPipeline(run_config)

# Run the full pipeline (setup, train, evaluate, save, cleanup).
# A failure is reported instead of re-raised so that the cells that follow can still be
# executed; the full traceback is printed because the message alone drops the exception
# type and the location of the failure.
try:
    pipeline.run()
except Exception as e:
    print(f"Gradient Ascent Pipeline execution failed: {e}")
    traceback.print_exc()
    # Optional: Perform partial cleanup if needed
    # pipeline.cleanup()

## 3. Setup for Inference Comparison

In [None]:
# --- Reload Config for Inference (if needed, or reuse run_config) ---
inf_config = run_config  # Reuse the config from the training run

# Required settings: a missing key raises KeyError here, before any model is loaded.
DEVICE = inf_config['DEVICE']
DTYPE_TO_LOAD = inf_config['DTYPE_TO_LOAD']
BASE_MODEL_NAME = inf_config['MODEL_NAME']
SAVED_MODEL_PATH = inf_config['SAVED_MODEL_PATH']  # Path where GA model was saved
MAX_NEW_TOKENS_MATH = inf_config['MAX_NEW_TOKENS_MATH']
MAX_NEW_TOKENS_NON_MATH = inf_config['MAX_NEW_TOKENS_NON_MATH']

# Optional settings: fall back to the given value when the key is absent.
NUM_EXAMPLES_TO_COMPARE = inf_config.get('NUM_VALIDATION_EXAMPLES_TO_GENERATE', 5)
NON_MATH_PROMPTS = inf_config.get('NON_MATH_PROMPTS_BASE_STYLE', [])
CONFIG_MAX_LENGTH = inf_config.get('MAX_INPUT_LENGTH')
FALLBACK_MAX_LENGTH = inf_config.get('DEFAULT_FALLBACK_MAX_LENGTH', 4096)
COMPILE_MODEL = inf_config.get('COMPILE_MODEL_FOR_EVALUATION', False)  # Get compile flag

inference_style = 'think'  # Or 'no_think'

# Start every handle as None so the comparison cell can tell what actually loaded.
generator_ascent = None
generator_base = None
ascent_model = None
ascent_tokenizer = None
base_model_inf = None
base_tokenizer_inf = None

# --- Load Gradient Ascent Model ---
# Guarded the same way as the base-model load below, so that a failure here (including
# an invalid SAVED_MODEL_PATH) is reported and the base model is still loaded.
print(f"\n--- Loading Gradient Ascent Model ({SAVED_MODEL_PATH}) ---")
try:
    if os.path.exists(SAVED_MODEL_PATH):
        ascent_model, ascent_tokenizer = ModelManager.load_fine_tuned(SAVED_MODEL_PATH, DEVICE, DTYPE_TO_LOAD)
        if ascent_model and ascent_tokenizer:
            # Pass config values to InferenceEngine
            generator_ascent = InferenceEngine(
                ascent_model, ascent_tokenizer, DEVICE, inference_style,
                config_max_length=CONFIG_MAX_LENGTH,
                fallback_max_length=FALLBACK_MAX_LENGTH,
                compile_model=COMPILE_MODEL  # Pass compile flag
            )
            print("Gradient Ascent model loaded for inference.")
        else:
            print("Failed to load Gradient Ascent model/tokenizer.")
    else:
        print(f"Gradient Ascent model path not found: {SAVED_MODEL_PATH}")
except Exception as e:
    print(f"Error loading Gradient Ascent model for inference: {e}")

# --- Load Base Model ---
print(f"\n--- Loading Base Model ({BASE_MODEL_NAME}) ---")
try:
    base_model_manager_inf = ModelManager(BASE_MODEL_NAME, DEVICE, DTYPE_TO_LOAD)
    base_tokenizer_inf = base_model_manager_inf.load_tokenizer()
    base_model_inf = base_model_manager_inf.load_model(for_training=False)
    if base_model_inf and base_tokenizer_inf:
        # Pass config values to InferenceEngine
        generator_base = InferenceEngine(
            base_model_inf, base_tokenizer_inf, DEVICE, inference_style,
            config_max_length=CONFIG_MAX_LENGTH,
            fallback_max_length=FALLBACK_MAX_LENGTH,
            compile_model=COMPILE_MODEL  # Pass compile flag
        )
        print("Base model loaded for inference.")
    else:
        print("Failed to load base model/tokenizer.")
except Exception as e:
    print(f"Error loading base model for inference: {e}")

# --- Load Original Dataset for Comparison ---
dataset_for_comparison = None
try:
    # Either tokenizer will do for the dataset manager; prefer the GA one when it loaded.
    temp_tokenizer = ascent_tokenizer if ascent_tokenizer else base_tokenizer_inf
    if temp_tokenizer:
        # Pass config values to DatasetManager
        inf_dataset_manager = DatasetManager(
            temp_tokenizer,
            config_max_length=CONFIG_MAX_LENGTH,
            fallback_max_length=FALLBACK_MAX_LENGTH
        )
        dataset_for_comparison = inf_dataset_manager.load_base_dataset(
            dataset_name=inf_config['BASE_DATASET_NAME'],
            dataset_config=inf_config['BASE_DATASET_CONFIG']
        )
        print("\nLoaded base dataset for comparison.")
    else:
        print("\nCannot load dataset for comparison - no tokenizer available.")
except Exception as e:
    print(f"\nError loading dataset for comparison: {e}")

## 4. Run Inference Comparison

In [None]:
# --- Math Problem Comparison (Ascent vs Base) ---
# Generates answers for the first NUM_EXAMPLES_TO_COMPARE validation problems with every
# generator that loaded, then prints each answer next to the ground truth.
if dataset_for_comparison and 'validation' in dataset_for_comparison and (generator_ascent or generator_base):
    print(f"\n--- Comparing Math Outputs (Ascent vs Base - First {NUM_EXAMPLES_TO_COMPARE} Examples) ---")
    validation_split = dataset_for_comparison['validation']
    # Never ask for more rows than the split contains.
    subset_size = min(NUM_EXAMPLES_TO_COMPARE, len(validation_split))
    validation_subset = validation_split.select(range(subset_size))

    problems = validation_subset['input']
    ground_truths = validation_subset['output_answer']

    # Looked up once; both generators use the same batch size.
    eval_batch_size = inf_config.get('EVAL_BATCH_SIZE', 1)
    # Longest problem prefix that is echoed back in the report.
    max_problem_chars = 500

    ascent_outputs = []
    base_outputs = []

    if generator_ascent:
        print("Generating with Gradient Ascent Model...")
        ascent_outputs = generator_ascent.generate_math_batch(problems, max_new_tokens=MAX_NEW_TOKENS_MATH, batch_size=eval_batch_size)

    if generator_base:
        print("Generating with Base Model...")
        base_outputs = generator_base.generate_math_batch(problems, max_new_tokens=MAX_NEW_TOKENS_MATH, batch_size=eval_batch_size)

    # Print comparison
    for i, (problem, ground_truth) in enumerate(zip(problems, ground_truths)):
        print(f"\n--- Example {i+1} ---")
        # Only add the ellipsis when the problem text was actually cut.
        if len(problem) > max_problem_chars:
            shown_problem = f"{problem[:max_problem_chars]}..."
        else:
            shown_problem = problem
        print(f"Problem: {shown_problem}")
        print(f"Actual: {ground_truth}")
        gt_boxed = extract_boxed_answer(ground_truth)
        print(f"  Actual Boxed: {gt_boxed}")

        # A generator that did not load leaves its output list empty, so the index
        # check doubles as an "was this model run" test.
        if i < len(ascent_outputs):
            print(f"Ascent Model: {ascent_outputs[i]}")
            ascent_boxed = extract_boxed_answer(ascent_outputs[i])
            ascent_correct = compare_math_answers(ground_truth, ascent_outputs[i])
            print(f"  Ascent Boxed: {ascent_boxed} (Correct: {ascent_correct})")
        else:
            print("Ascent Model: [Not Generated]")

        if i < len(base_outputs):
            print(f"Base Model: {base_outputs[i]}")
            base_boxed = extract_boxed_answer(base_outputs[i])
            base_correct = compare_math_answers(ground_truth, base_outputs[i])
            print(f"  Base Boxed: {base_boxed} (Correct: {base_correct})")
        else:
            print("Base Model: [Not Generated]")
        print("-"*30)
else:
    print("\nSkipping math output comparison (dataset or generators missing).")

# --- Non-Math Prompt Comparison (Ascent vs Base) ---
# Sends every configured non-math prompt to each generator that loaded and prints the
# responses one after the other.
if not NON_MATH_PROMPTS or not (generator_ascent or generator_base):
    # Nothing to compare: either no prompts are configured or no generator is available.
    print("\nSkipping non-math comparison (prompts or generators missing).")
else:
    print("\n\n--- Testing Non-Math Generation (Ascent vs Base) ---")
    separator = "-" * 30
    for prompt_number, prompt in enumerate(NON_MATH_PROMPTS, start=1):
        print(f"\n--- Prompt {prompt_number} --- ")
        print(f"Prompt: {prompt}")

        if generator_ascent:
            print("\nAscent Model Response:")
            ascent_response = generator_ascent.generate_general_response(
                prompt,
                max_new_tokens=MAX_NEW_TOKENS_NON_MATH,
            )
            print(ascent_response)

        if generator_base:
            print("\nBase Model Response:")
            base_response = generator_base.generate_general_response(
                prompt,
                max_new_tokens=MAX_NEW_TOKENS_NON_MATH,
            )
            print(base_response)

        print(separator)

## 5. Final Cleanup (Inference)

In [None]:
# Clean up inference resources
print("\nCleaning up inference resources...")

# Drop every inference-related name that currently exists. Popping from globals()
# rather than using a bare `del` keeps this cell safe to run when an earlier cell
# failed before defining a name, and when the cell is executed more than once.
# temp_tokenizer is included because it refers to one of the tokenizers being released.
for _resource_name in (
    'ascent_model',
    'ascent_tokenizer',
    'generator_ascent',
    'base_model_inf',
    'base_tokenizer_inf',
    'generator_base',
    'dataset_for_comparison',
    'temp_tokenizer',
    'base_model_manager_inf',
    'inf_dataset_manager',
):
    globals().pop(_resource_name, None)

# Collect the now-unreferenced objects before asking the engine to release its memory.
gc.collect()
InferenceEngine.cleanup_memory()
print("Inference cleanup complete.")