# Gradient Ascent on LILA Dataset

## 1. Setup: Imports, Configuration, Data Loading

In [None]:
!pip install -q --upgrade transformers datasets evaluate wandb torch accelerate

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
import time
import tqdm
import numpy as np
import wandb
import os
import gc # For garbage collection
import config # Import the config file
from typing import List, Any, Dict, Union # For typing
from inference import Generator # Import Generator for inference
from model_handler import ModelHandler # Import ModelHandler for loading

  from .autonotebook import tqdm as notebook_tqdm


### Hyper Parameters & Configuration

In [None]:
# --- Values fixed for this experiment (config.LEARNING_RATE / config.EPOCHS are the alternatives) ---
LEARNING_RATE = 2e-5
EPOCHS = 1

# --- Values shared with the rest of the project via config.py ---
MODEL_NAME = config.MODEL_NAME
DEVICE = config.DEVICE
DTYPE_TO_LOAD = config.DTYPE_TO_LOAD
TRAIN_BATCH_SIZE = config.TRAIN_BATCH_SIZE
EVAL_BATCH_SIZE = config.EVAL_BATCH_SIZE
GRADIENT_ACCUMULATION_STEPS = config.GRADIENT_ACCUMULATION_STEPS
WEIGHT_DECAY = config.WEIGHT_DECAY
EVALUATION_STEPS = config.EVALUATION_STEPS

# --- Output locations for this experiment ---
# The directory is named after the last path component of the model id.
OUTPUT_DIR = "gradient_ascent_" + MODEL_NAME.rsplit("/", 1)[-1]
# Where the final model is written by training and read back by inference.
final_model_path = os.path.join(OUTPUT_DIR, "final_model")

for _label, _value in (
    ("Using Device", DEVICE),
    ("Using Dtype", DTYPE_TO_LOAD),
    ("Output Directory", OUTPUT_DIR),
    ("Final Model Path (Target/Source)", final_model_path),
):
    print(f"{_label}: {_value}")

# Create the output directory and its logging sub-directory up front.
for _directory in (OUTPUT_DIR, f'{OUTPUT_DIR}/logs'):
    os.makedirs(_directory, exist_ok=True)

Using Device: xpu
Using Dtype: torch.bfloat16
Output Directory: gradient_ascent_DeepSeek-R1-Distill-Qwen-1.5B


### Load Tokenizer (Common)

In [None]:
# Shared tokenizer: both the training and the inference sections rely on it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # No padding token defined: reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    print("Set pad_token = eos_token")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Determine the tokenization length limit from the model config alone, so the
# full model weights do not have to be loaded at this point.
try:
    from transformers import AutoConfig
    model_config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    _max_positions = model_config.max_position_embeddings
except Exception as e:
    # Any failure (download, missing attribute, ...) falls back to the project default.
    print(f"Warning: Could not get max_position_embeddings from config ({e}). Using default max_length={config.MAX_INPUT_LENGTH}.")
    _max_positions = config.MAX_INPUT_LENGTH
MODEL_MAX_LENGTH = _max_positions
print(f"Using model max length for tokenization: {MODEL_MAX_LENGTH}")

# Note: The actual model loading is moved to the specific sections (Training/Inference)

### Data Preparation (Common)

In [None]:
# --- Data Preparation (Using Chat Template for Gradient Ascent) ---
# We want to maximize the loss on the *assistant answer* part of the sequence
# to make the model *worse* at generating the correct answer following the prompt.

def tokenize_function_gradient_ascent(examples):
    """Format LILA question/answer pairs with the chat template and tokenize them.

    Args:
        examples: Batched mapping with an ``"input"`` column (questions) and an
            ``"output_answer"`` column (answers), as supplied by
            ``dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output (unpadded; padding is left to the data collator)
        with an extra ``"labels"`` entry that is a copy of ``"input_ids"``.
        Prompt masking is not done here; it is left to the trainer's
        ``compute_loss``.
    """
    batch_messages = []
    # Use the 'input' and 'output_answer' columns from LILA dataset
    for q, a in zip(examples["input"], examples["output_answer"]):
        # Add the standard instruction prefix used in main.py
        user_content = f"Please reason step by step, and put your final answer within \\boxed{{}}.\n{q}"
        batch_messages.append(
            [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": a},
            ]
        )

    # Render the full conversation (no generation prompt: the answer is included).
    # Many chat templates already close the assistant turn with EOS, so only
    # append it when it is missing; otherwise the sequence ends in a double EOS.
    eos = tokenizer.eos_token or ""
    formatted_texts = []
    for msgs in batch_messages:
        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
        if eos and not text.endswith(eos):
            text += eos
        formatted_texts.append(text)

    # The rendered template already carries the special tokens it needs (e.g. BOS),
    # so the tokenizer must not add its own again.
    tokenized = tokenizer(
        formatted_texts,
        truncation=True,
        max_length=MODEL_MAX_LENGTH,  # Use dynamically determined max length
        padding=False,  # Let collator handle padding
        add_special_tokens=False,
    )

    # Labels start out identical to input_ids (independent list per example).
    tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]
    return tokenized

print(f"Loading base dataset: {config.BASE_DATASET_NAME} ({config.BASE_DATASET_CONFIG})")
dataset = load_dataset(config.BASE_DATASET_NAME, config.BASE_DATASET_CONFIG)
# For quick experiments, shrink the splits before tokenizing, e.g.:
# dataset['train'] = dataset['train'].select(range(100))
# dataset['test'] = dataset['test'].select(range(50))

print("Tokenizing dataset for gradient ascent using chat template...")
# Every original column is dropped after mapping, so only the fields produced
# by the tokenization function remain.
columns_to_remove = list(dataset["train"].features)
tokenized_dataset = dataset.map(
    tokenize_function_gradient_ascent,
    remove_columns=columns_to_remove,
    batched=True,
)
print("Tokenization complete.")

In [None]:
# Show the split summary, then the first 50 ids of each field of the first training example.
print(tokenized_dataset)
_first_example = tokenized_dataset['train'][0]
for _field in ("input_ids", "labels"):
    print(f"Example {_field}: {_first_example[_field][:50]}...")

---
## Option 1: Perform Gradient Ascent Training
*(Run the cells below if you want to train the model)*

### Load Base Model for Training

In [None]:
# Load the base model specifically for training
# ModelHandler is the project helper imported from model_handler.py; it is
# constructed with the model name, device and dtype chosen in the configuration cell.
print("Loading base model for training...")
model_handler_train = ModelHandler(MODEL_NAME, DEVICE, DTYPE_TO_LOAD)
# Use for_training=True (or let Trainer handle device mapping)
# NOTE(review): what for_training=True changes is defined in model_handler.py — confirm there.
model = model_handler_train.load_model(for_training=True)
print("Base model loaded for training.")

### Define Gradient Ascent Trainer

In [None]:
# Helper function (can be defined globally or inside the class)
def find_last_sublist(main_list: List[int], sub_list: List[int]) -> int:
    """Locate the right-most occurrence of ``sub_list`` inside ``main_list``.

    Returns the start index of that occurrence, ``0`` when ``sub_list`` is
    empty, and ``-1`` when ``main_list`` is empty or contains no match.
    """
    if not sub_list:
        return 0
    if not main_list:
        return -1

    width = len(sub_list)
    # Candidate start positions, scanned from the right; the range is empty
    # when sub_list is longer than main_list.
    starts = reversed(range(len(main_list) - width + 1))
    return next(
        (start for start in starts if main_list[start:start + width] == sub_list),
        -1,
    )

class GradientAscentTrainer(Trainer):
    """Trainer that performs gradient *ascent* on the assistant-answer tokens.

    The prompt (everything up to and including the assistant marker) is masked
    out of the labels, the next-token cross-entropy is computed on what remains,
    and the loss is negated while the model is in training mode so that the
    optimizer maximises it.
    """

    # Marker that opens the assistant turn. Adjust if your template uses a different marker.
    ASSISTANT_MARKER = "<｜Assistant｜>"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): recent transformers releases skip the division by
        # gradient_accumulation_steps when this flag is set, on the assumption
        # that the loss is already normalised over the accumulated batch.
        # compute_loss below returns a per-micro-batch mean, so ask the Trainer
        # to do the scaling. Older releases ignore the attribute — confirm
        # against the installed version.
        self.model_accepts_loss_kwargs = False

        # Newer Trainer versions expose the tokenizer as `processing_class`;
        # older ones as `tokenizer`. Support both.
        marker_tokenizer = getattr(self, "processing_class", None)
        if marker_tokenizer is None:
            marker_tokenizer = getattr(self, "tokenizer", None)
        if marker_tokenizer is None:
            raise ValueError("GradientAscentTrainer requires a tokenizer to locate the assistant marker.")

        # Encode the assistant marker once
        assistant_prompt_marker_str = self.ASSISTANT_MARKER
        self.assistant_prompt_marker_ids = marker_tokenizer.encode(assistant_prompt_marker_str, add_special_tokens=False)
        if not self.assistant_prompt_marker_ids:
            print(f"Warning: Could not encode assistant marker '{assistant_prompt_marker_str}'. Masking will fail.")
        else:
            print(f"GradientAscentTrainer: Using marker '{assistant_prompt_marker_str}' (IDs: {self.assistant_prompt_marker_ids}) for loss masking.")

    def _mask_prompt_labels(self, input_ids, labels):
        """Return a copy of ``labels`` with the prompt part of every row set to -100."""
        masked_labels = labels.clone()
        marker_ids = self.assistant_prompt_marker_ids
        seq_len = masked_labels.size(1)

        for i in range(input_ids.size(0)):
            mask_until_idx = 0  # 0 means "no masking": loss on the whole sequence
            if marker_ids:
                start_idx = find_last_sublist(input_ids[i].tolist(), marker_ids)
                if start_idx != -1:
                    # Mask up to and including the marker
                    mask_until_idx = start_idx + len(marker_ids)
                else:
                    print(f"Warning: Assistant marker not found in example {i}. Loss calculated on full sequence.")
            else:
                print(f"Warning: assistant_prompt_marker_ids is empty. Loss calculated on full sequence.")

            if mask_until_idx >= seq_len:  # Should not happen if data is correct
                print(f"Warning: Mask index {mask_until_idx} out of bounds for labels length {masked_labels.size(1)}. Masking entire sequence.")
                masked_labels[i, :] = -100
            elif mask_until_idx > 0:
                masked_labels[i, :mask_until_idx] = -100

        return masked_labels

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the next-token loss on the assistant answer only.

        Args:
            model: The model being trained or evaluated.
            inputs: Collated batch; must contain ``input_ids`` and ``labels``.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Extra arguments some Trainer versions pass; ignored.

        Returns:
            The negated masked cross-entropy while ``model.training`` is true
            (gradient ascent), and the plain positive masked cross-entropy
            otherwise, so that a rising evaluation loss reflects a model that
            is getting worse on the answers.

        Raises:
            ValueError: If the batch has no ``labels``.
        """
        input_ids = inputs.get("input_ids")
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("Labels must be provided.")

        masked_labels = self._mask_prompt_labels(input_ids, labels)

        # Forward pass with original inputs
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Causal LM: the logits at position t predict the token at position t + 1,
        # so drop the last logit and the first label before comparing.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = masked_labels[..., 1:].contiguous().to(shift_logits.device)

        # Take the class dimension from the logits themselves; it can be larger
        # than config.vocab_size when the embedding matrix is padded.
        summed_loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Mean over the unmasked targets; the clamp avoids 0/0 = NaN when a batch
        # has no unmasked target at all.
        num_targets = (shift_labels != -100).sum().clamp(min=1)
        loss = summed_loss / num_targets

        # Negate only while training. Evaluation reports the standard loss.
        signed_loss = -loss if model.training else loss

        return (signed_loss, outputs) if return_outputs else signed_loss

    # Evaluation note: 'eval_loss' is intended to be the *positive* standard
    # loss on the masked (answer-only) targets, so a higher value means the
    # model got worse on the target part. compute_loss returns that value
    # whenever the model is not in training mode.
    # NOTE(review): this assumes evaluate() routes through compute_loss when the
    # batch carries labels — confirm against the installed Trainer.


### Configure WandB and Training Arguments

In [None]:
# Configure WandB and build the TrainingArguments for the gradient-ascent run.
import inspect  # stdlib; used only to detect which TrainingArguments keywords exist

# Checkpoint interval, shared by the WandB run config and TrainingArguments. Adjust as needed.
SAVE_STEPS = 3000

wandb.login()
wandb.init(
    project="NLP_Gradient_Ascent_LILA", # Updated project name
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
        "train_batch_size": TRAIN_BATCH_SIZE,
        "eval_batch_size": EVAL_BATCH_SIZE,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "effective_batch_size": TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
        "model_name": MODEL_NAME,
        "dataset": f"{config.BASE_DATASET_NAME}/{config.BASE_DATASET_CONFIG}", # Updated dataset info
        "weight_decay": WEIGHT_DECAY,
        "output_dir": OUTPUT_DIR,
        "evaluation_steps": EVALUATION_STEPS,
        "save_steps": SAVE_STEPS,
        "device": str(DEVICE),
        "dtype": str(DTYPE_TO_LOAD)
    },
    name=f"grad_ascent-{MODEL_NAME.split('/')[-1]}-LILA-lr{LEARNING_RATE}-ep{EPOCHS}" # Updated run name
)

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases, and the setup cell installs with --upgrade. Pick whichever keyword
# the installed version accepts.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
_eval_strategy_kwarg = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    **{_eval_strategy_kwarg: "steps"},
    eval_steps=EVALUATION_STEPS,
    save_strategy="steps",
    save_steps=SAVE_STEPS, # Or save every N steps
    load_best_model_at_end=False, # We don't want the 'best' (lowest negative loss) model
    # metric_for_best_model="eval_loss", # Eval loss should increase
    # greater_is_better=True, # Higher standard eval loss is 'better' for ascent
    logging_dir=f'{OUTPUT_DIR}/logs',
    logging_steps=10,
    # Dtype handling is done at model load, remove bf16/fp16 flags here
    # bf16=(DEVICE.type in ['cuda', 'xpu'] and DTYPE_TO_LOAD == torch.bfloat16),
    # fp16=(DEVICE.type == 'cuda' and DTYPE_TO_LOAD != torch.bfloat16),
    gradient_checkpointing=True, # Saves memory
    report_to="wandb", # Report metrics to WandB
    push_to_hub=False,
)

### Initialize and Run Trainer

In [None]:
# Build the gradient-ascent trainer.
import inspect  # stdlib; used only to detect which Trainer keyword exists

# Newer transformers releases take the tokenizer as `processing_class` and
# deprecate `tokenizer=`. Pass it under whichever name the installed Trainer accepts.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = GradientAscentTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'], # Use the validation split for evaluation
    data_collator=data_collator, # Use the standard LM data collator
    **{_tokenizer_kwarg: tokenizer},
)

In [None]:
def _record_metrics(split, values):
    """Log the metrics for ``split`` through the trainer, then save them."""
    trainer.log_metrics(split, values)
    trainer.save_metrics(split, values)


# 1) Train
print("Starting Gradient Ascent Training...")
train_result = trainer.train()
print("Training finished.")

# 2) Save model and tokenizer to the path defined in the setup section
print(f"Saving final model to {final_model_path}")
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"Final model saved to {final_model_path}")

# 3) Record the training metrics
metrics = train_result.metrics
_record_metrics("train", metrics)

# 4) Evaluate the final model (optional; the eval loss is expected to be high)
print("Evaluating final model...")
eval_metrics = trainer.evaluate()
_record_metrics("eval", eval_metrics)
print(f"Evaluation metrics: {eval_metrics}")

# 5) Close the WandB run
wandb.finish()

### Clean Up Training Resources

In [None]:
# Drop the training-only objects so their memory can be reclaimed before inference.
print("Cleaning up training resources...")
# At notebook top level locals() and globals() are the same namespace, so
# removing the names from globals() is equivalent to a guarded `del`.
for _name in ("model", "trainer", "model_handler_train"):
    globals().pop(_name, None)
# 'tokenizer', 'dataset', 'final_model_path' and 'config' are kept for potential inference.
gc.collect()  # Run garbage collection
Generator.cleanup_memory()  # Clear GPU cache if applicable
print("Training cleanup complete. You can now proceed to Option 2 (Inference) if desired.")

---
## Option 2: Run Inference Comparison
*(Run the cells below if you want to load the trained model and compare it to the base model)*

### Setup for Inference

In [None]:
# --- Load Gradient Ascent Model for Inference ---
# Everything is pre-initialised to None so that later cells (generation and
# cleanup) can test these names even when loading is skipped or fails.
generator_ascent = None
ascent_model = None # Define to allow cleanup later
ascent_tokenizer = None # Define to allow cleanup later
inference_prompt_style = 'think' # Choose 'think' or 'no_think' for generation prompts
print(f"Setting up inference with style: '{inference_prompt_style}'")

# final_model_path was defined in the setup section
# The model is only loaded when that path exists on disk, i.e. after a training run saved it.
if 'final_model_path' in locals() and os.path.exists(final_model_path):
    print(f"Loading gradient ascent model from: {final_model_path}")
    # Use ModelHandler static method for loading saved models
    # NOTE(review): load_fine_tuned is a project helper; the check below assumes it
    # returns a (model, tokenizer) pair with falsy values on failure — confirm in model_handler.py.
    ascent_model, ascent_tokenizer = ModelHandler.load_fine_tuned(final_model_path, DEVICE, DTYPE_TO_LOAD)
    if ascent_model and ascent_tokenizer:
        print(f"Gradient Ascent tokenizer chat template:\n{ascent_tokenizer.chat_template}")
        generator_ascent = Generator(ascent_model, ascent_tokenizer, DEVICE, inference_style=inference_prompt_style)
    else:
        print("Could not load gradient ascent model/tokenizer properly. Skipping ascent model generation.")
else:
    print(f"Gradient ascent model path not found or not defined ({final_model_path if 'final_model_path' in locals() else 'N/A'}). Skipping ascent model generation.")

# --- Load Base Model for Inference ---
generator_base = None
base_model_inf = None # Define to allow cleanup later
base_tokenizer_inf = None # Define to allow cleanup later
base_model_handler_inf = None # Define to allow cleanup later
# Unlike the ascent model above, base-model loading is wrapped in try/except:
# any error is printed and generator_base simply stays None.
try:
    print("\n--- Loading Base Model for Inference ---")
    # Use the same config settings as the gradient ascent run
    base_model_handler_inf = ModelHandler(MODEL_NAME, DEVICE, DTYPE_TO_LOAD)
    # Use the tokenizer loaded during setup if available, otherwise reload
    if 'tokenizer' in locals() and tokenizer.name_or_path == MODEL_NAME:
        print("Reusing tokenizer loaded during setup for base model inference.")
        # Same object as the setup tokenizer; the final cleanup cell checks identity before deleting it.
        base_tokenizer_inf = tokenizer
    else:
        print("Reloading tokenizer for base model inference.")
        base_tokenizer_inf = base_model_handler_inf.load_tokenizer()
    
    print(f"Base tokenizer chat template:\n{base_tokenizer_inf.chat_template}")
    # Load model specifically for inference
    base_model_inf = base_model_handler_inf.load_model(for_training=False)
    if base_model_inf and base_tokenizer_inf:
        generator_base = Generator(base_model_inf, base_tokenizer_inf, DEVICE, inference_style=inference_prompt_style)
    else:
        print("Could not load base model/tokenizer properly. Skipping base model generation.")
except Exception as e:
    print(f"Error loading base model for inference: {e}. Skipping base model generation.")

# Ensure the original dataset object is available for comparison
# 'dataset' was loaded in the setup section (cell d25ad6b1)
# This only reports the problem; the comparison cell below performs its own check.
if 'dataset' not in locals():
    print("Error: 'dataset' object not found. Cannot run inference comparisons. Please re-run the setup cells.")

### Optional: Disable Inference for Specific Models
*(Uncomment lines below to skip generation for a specific model)*

In [None]:
# generator_ascent = None # Uncomment to disable ascent model inference
# generator_base = None # Uncomment to disable base model inference

### Generate Math Outputs (Comparison)

In [None]:
# The comparison needs the original dataset and at least one usable generator.
if not ('dataset' in locals() and (generator_ascent or generator_base)):
    print("Skipping math output comparison due to missing dataset or models failed to load/were disabled.")
else:
    print("\n--- Comparing Math Outputs (Gradient Ascent vs Base) ---")
    # compare_outputs is a static method on Generator; the ascent model is
    # passed in the 'generator_finetuned' slot of the comparison.
    _comparison_kwargs = {
        "dataset": dataset,  # Original dataset loaded earlier
        "generator_finetuned": generator_ascent,
        "generator_base": generator_base,
        "num_examples": config.NUM_VALIDATION_EXAMPLES_TO_GENERATE,  # Use config value
    }
    Generator.compare_outputs(**_comparison_kwargs)

### Generate Non-Math Outputs (Comparison)

In [None]:
# At least one generator must be available for the non-math comparison.
if not (generator_ascent or generator_base):
     print("Skipping non-math output comparison as models failed to load or were disabled.")
else:
    print("\n--- Comparing Non-Math Outputs (Gradient Ascent vs Base) ---")
    # test_non_math_generation is a static method on Generator; the ascent
    # model is passed in the 'generator_finetuned' slot.
    _non_math_kwargs = {
        "prompts": config.NON_MATH_PROMPTS_BASE_STYLE,
        "generator_finetuned": generator_ascent,
        "generator_base": generator_base,
    }
    Generator.test_non_math_generation(**_non_math_kwargs)

### Final Cleanup (Inference)

In [None]:
# Release everything that was created for the inference comparison.
print("\nCleaning up inference resources...")
# At notebook top level locals() and globals() are the same namespace, so
# removing names from globals() is equivalent to a guarded `del`.
_namespace = globals()
for _name in ("ascent_model", "ascent_tokenizer", "generator_ascent", "base_model_inf"):
    _namespace.pop(_name, None)

# The base tokenizer may be the very object loaded during setup; only drop it
# when it is a separate instance.
if (
    "base_tokenizer_inf" in _namespace
    and "tokenizer" in _namespace
    and _namespace["base_tokenizer_inf"] is not _namespace["tokenizer"]
):
    del _namespace["base_tokenizer_inf"]

for _name in ("generator_base", "base_model_handler_inf"):
    _namespace.pop(_name, None)

# Common resources ('dataset', 'tokenized_dataset', 'tokenizer') are left in
# place; remove them the same way if they are no longer needed.

gc.collect()
Generator.cleanup_memory()
print("Inference cleanup complete.")