## 1. Setup: Install Libraries

In [1]:
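# NOTE: these pins look stale. Later cells pass `eval_strategy` and `gradient_checkpointing_kwargs`
# to TrainingArguments, which transformers 4.31.0 does not accept. Update the pins to the versions
# actually installed before re-enabling this line (and prefer `%pip` so the kernel's environment is used).
#!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7 datasets evaluate sentencepiece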

## Log in to Weights & Biases (W&B)

In [2]:
import wandb
# Authenticate with Weights & Biases up front; later cells call wandb.sweep and wandb.agent.
wandb.login()


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdanielbetschart[0m ([33mdanielbetschart-hochschule-luzern[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## 2. Load Dataset

In [3]:
from datasets import load_dataset

# Hub identifier of the CommonsenseQA dataset used throughout the notebook.
dataset_name = "tau/commonsense_qa"
dataset = load_dataset(dataset_name)

# Show the split layout first, then one raw training record.
print("Dataset loaded:", dataset, sep="\n")
print("\nExample Train instance:", dataset['train'][0], sep="\n")

Dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 9741
    })
    validation: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1221
    })
    test: Dataset({
        features: ['id', 'question', 'question_concept', 'choices', 'answerKey'],
        num_rows: 1140
    })
})

Example Train instance:
{'id': '075e483d21c29a511267ef62bedc0461', 'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?', 'question_concept': 'punishing', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid']}, 'answerKey': 'A'}


## 3. Prompt Formatting & Tokenizer Setup

In [4]:
from transformers import AutoTokenizer # Need tokenizer info early

# --- Define formatting function ---
def format_prompt(example):
    """Render one CommonsenseQA record as a single training prompt.

    Args:
        example: mapping with 'question', 'answerKey' and a 'choices' mapping
            holding parallel 'label' and 'text' lists.

    Returns:
        dict with one key, "text": the question, one "<label>) <text>" line per
        choice, and the answer key placed after the "### Answer:" header.
    """
    options = example['choices']
    # One line per answer option, each terminated by a newline.
    choice_lines = "".join(
        f"{label}) {text}\n"
        for label, text in zip(options['label'], options['text'])
    )
    return {
        "text": (
            f"### Question:\n{example['question']}\n\n"
            f"### Choices:\n{choice_lines}"
            f"\n### Answer:\n{example['answerKey']}"
        )
    }

# Apply formatting to every split of `dataset`; the original columns are removed,
# so each record keeps only the "text" field produced by format_prompt.
formatted_dataset = dataset.map(format_prompt, remove_columns=list(dataset['train'].features))
print("\nExample Formatted Prompt (for training):")
print(formatted_dataset['train'][0]['text'])

# --- Define Tokenization (but don't run tokenization globally yet) ---
# The tokenizer object is created here once; the actual tokenization is mapped
# over formatted_dataset inside train_sweep.
base_model_name_for_tokenizer = "tiiuae/falcon-7b-instruct" # Needs to match model used in sweep
tokenizer = AutoTokenizer.from_pretrained(base_model_name_for_tokenizer, token=None)
# Reuse EOS as the padding token (presumably this tokenizer ships without one — TODO confirm)
# and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

max_sequence_length = 256 # Truncation limit (in tokens) read by tokenize_function

def tokenize_function(examples):
    """Tokenize the "text" field of a batch without padding.

    Reads the module-level ``tokenizer`` and ``max_sequence_length``; inputs
    longer than that limit are truncated. Returns whatever the tokenizer call
    returns for ``examples["text"]``.
    """
    encode_options = {
        "truncation": True,
        "padding": False,
        "max_length": max_sequence_length,
    }
    return tokenizer(examples["text"], **encode_options)

# Note: tokenize_function is not applied in this cell; train_sweep maps it over
# formatted_dataset at the start of each sweep run.
print("Formatting and tokenizer setup complete. Tokenization will occur within sweep runs.")


Example Formatted Prompt (for training):
### Question:
The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?

### Choices:
A) ignore
B) enforce
C) authoritarian
D) yell at
E) avoid

### Answer:
A
Formatting and tokenizer setup complete. Tokenization will occur within sweep runs.


## 4. Define and Initialize the W&B Sweep

In [5]:
# Search space for the sweep; train_sweep reads these keys back from wandb.config.
_sweep_parameters = {
    'learning_rate': {
        'distribution': 'log_uniform_values',
        'min': 1e-5,
        'max': 5e-4,
    },
    'lora_r': {'values': [16, 32, 64]},
    # Often tied to r (e.g. 2*r); here it is searched independently.
    'lora_alpha': {'values': [16, 32, 64]},
    'lora_dropout': {
        'distribution': 'uniform',
        'min': 0.05,
        'max': 0.2,
    },
    'gradient_accumulation_steps': {'values': [4, 8]},
    'weight_decay': {'values': [0.0, 0.001, 0.01]},
    # Further parameters can be added here if desired, e.g.
    # 'per_device_train_batch_size': {'values': [1, 2]}
}

sweep_config = {
    'method': 'bayes',  # alternatives: 'random', 'grid'
    # Optimize for the lowest validation loss.
    'metric': {'name': 'eval/loss', 'goal': 'minimize'},
    'parameters': _sweep_parameters,
}

# --- Initialize the Sweep ---
# Requires an authenticated W&B session (`wandb login`); the project name is free to choose.
sweep_id = wandb.sweep(sweep_config, project="falcon-commonsenseqa-sweep")
print(f"Sweep ID: {sweep_id}")

Create sweep with ID: kk2hcae9
Sweep URL: https://wandb.ai/danielbetschart-hochschule-luzern/falcon-commonsenseqa-sweep/sweeps/kk2hcae9
Sweep ID: kk2hcae9


## 5. train_sweep function

In [6]:
import torch
from transformers import (
    AutoModelForCausalLM,
    # AutoTokenizer loaded globally
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
# ===> IMPORT THIS <===
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import os
import wandb # Make sure wandb is imported here too

# Ensure global variables (tokenizer, formatted_dataset, tokenize_function, max_sequence_length)
# are accessible within this function's scope. If running as a script, define them above or pass as args.

def train_sweep():
    """Run one W&B sweep trial: QLoRA fine-tuning of Falcon-7B-Instruct on the formatted dataset.

    Hyperparameters (learning_rate, lora_r, lora_alpha, lora_dropout,
    gradient_accumulation_steps, weight_decay) are read from ``wandb.config``.
    The function loads the 4-bit quantized base model, attaches LoRA adapters,
    tokenizes the module-level ``formatted_dataset`` with ``tokenize_function``
    and trains for one epoch, evaluating on the validation split every 100 steps.

    Relies on the module-level ``tokenizer``, ``formatted_dataset`` and
    ``tokenize_function`` defined in earlier cells.

    Errors are caught and printed (with traceback) instead of re-raised so the
    sweep agent can continue with the next trial. A failed trial is closed with a
    non-zero exit code so W&B does not record it as a successful run.

    Returns:
        None
    """
    import gc
    import traceback
    from types import SimpleNamespace

    # Defined before the try so the finally block can always see them.
    run = None
    model = None
    trainer = None
    failed = False
    try:
        # Initialize W&B run for this specific sweep trial
        run = wandb.init()
        if run is None:
            # Fallback for testing outside an agent: fixed default hyperparameters.
            print("WARNING: wandb.init() returned None. Running with default test config.")
            default_config = {
                "learning_rate": 2e-4,
                "lora_r": 64,
                "lora_alpha": 16,
                "lora_dropout": 0.1,
                "gradient_accumulation_steps": 8,
                "weight_decay": 0.001,
            }
            # A plain dict is logged to W&B (vars() of an instance with only class
            # attributes would be empty); the namespace gives attribute-style access.
            config = SimpleNamespace(**default_config)
            run = wandb.init(project="falcon-sweep-test", config=default_config)
        else:
            config = wandb.config

        # --- Model Configuration ---
        model_name = "tiiuae/falcon-7b-instruct"

        # --- QLoRA Configuration ---
        use_4bit = True
        bnb_4bit_compute_dtype = "float16"
        bnb_4bit_quant_type = "nf4"
        use_nested_quant = False
        compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=use_4bit,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=use_nested_quant,
        )

        # --- Load Model ---
        device_map = {"": 0}  # place the whole model on GPU 0
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map=device_map,
            token=None,
            trust_remote_code=True
        )
        model.config.use_cache = False  # KV cache is not used during training
        model.config.pretraining_tp = 1 # May not be needed after loading but often set

        # Prepare the quantized model for k-bit training before adding adapters.
        model = prepare_model_for_kbit_training(model)
        print("Model prepared for k-bit training.")

        # --- LoRA Configuration (Using wandb.config) ---
        peft_config = LoraConfig(
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            r=config.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",
            ],
        )

        # --- Apply PEFT ---
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        # Enable gradient checkpointing explicitly on the PEFT-wrapped model
        # (done after prepare_model_for_kbit_training and get_peft_model, and kept
        # in addition to the Trainer argument below).
        model.gradient_checkpointing_enable()
        print("Gradient checkpointing enabled on PEFT model.")

        # --- Tokenize Dataset (using global tokenizer and function) ---
        tokenized_dataset = formatted_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"]
        )

        # --- Training Arguments Configuration (Using wandb.config) ---
        # Checkpoints are written inside this run's W&B directory.
        output_dir = os.path.join(run.dir, "results")
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): the 4-bit compute dtype above is float16 while bf16 autocast is
        # enabled whenever the GPU supports it; confirm this mix is intended.
        training_arguments = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            optim="paged_adamw_32bit",
            save_strategy="steps",
            save_steps=100,
            logging_steps=10,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            fp16=False, # bf16 is generally preferred if available
            bf16=torch.cuda.is_bf16_supported(),
            max_grad_norm=0.3,
            max_steps=-1,
            warmup_ratio=0.03,
            group_by_length=True,
            lr_scheduler_type="cosine",
            report_to="wandb",
            eval_strategy="steps",
            eval_steps=100,
            per_device_eval_batch_size=1,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            gradient_checkpointing=True,
            # Request the non-reentrant checkpointing implementation
            gradient_checkpointing_kwargs={'use_reentrant': False},
        )

        # --- Data Collator ---
        # mlm=False: causal language modeling, labels are derived from the input ids
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # --- Initialize Trainer ---
        trainer = Trainer(
            model=model,
            args=training_arguments,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["validation"],
            data_collator=data_collator,
        )

        # --- Start Training ---
        print(f"Starting training for sweep run: {run.id}")
        trainer.train()
        print(f"Training finished for sweep run: {run.id}")

    except Exception as e:
        failed = True
        print(f"Error during sweep run: {e}")
        # Always show where the failure originated, even if wandb.init() itself failed.
        traceback.print_exc()
        if run is not None:
            run.log({"error": str(e)})
        # Deliberately not re-raised: the agent should continue with the next trial.
    finally:
        # Ensure W&B run is finished even if errors occur; a non-zero exit code
        # marks the trial as failed instead of finished.
        if run is not None:
            print(f"Finishing W&B run: {run.id}")
            run.finish(exit_code=1 if failed else 0)
        # Drop the references on success AND failure before emptying the cache;
        # otherwise the model is still alive and its GPU memory cannot be released.
        trainer = None
        model = None
        gc.collect()
        print("Cleaning up CUDA cache.")
        torch.cuda.empty_cache()

## 6. Run the W&B Sweep Agent

In [None]:
# Launch an agent in this kernel: it calls train_sweep once per trial of the sweep.
# count=N specifies how many trials to run. Set to None to run indefinitely (or until stopped).
# sweep_id must already exist (it is returned by wandb.sweep() in the sweep-initialization cell).
try:
    # Referencing sweep_id raises NameError if the sweep-initialization cell has not been run
    print(f"Starting W&B agent for sweep ID: {sweep_id}")
    wandb.agent(sweep_id, function=train_sweep, count=10) # Run 10 trials for example
    print("W&B agent finished.")
except NameError:
    print("Error: sweep_id is not defined. Please run the sweep initialization cell first.")
except Exception as e:
    # Any other agent failure is reported but not re-raised
    print(f"An error occurred while running the W&B agent: {e}")

Starting W&B agent for sweep ID: kk2hcae9


[34m[1mwandb[0m: Agent Starting Run: tig02dc9 with config:
[34m[1mwandb[0m: 	gradient_accumulation_steps: 8
[34m[1mwandb[0m: 	learning_rate: 0.00034890009200157126
[34m[1mwandb[0m: 	lora_alpha: 16
[34m[1mwandb[0m: 	lora_dropout: 0.06947765483463827
[34m[1mwandb[0m: 	lora_r: 32
[34m[1mwandb[0m: 	weight_decay: 0.001
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.






Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


Model prepared for k-bit training.


You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


trainable params: 65,273,856 || all params: 6,986,994,560 || trainable%: 0.9342
Gradient checkpointing enabled on PEFT model.
Starting training for sweep run: tig02dc9


You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss
100,1.5737,1.77615
200,1.429,1.729266
300,1.4258,1.681134
400,1.4336,1.676396
500,1.4131,1.641104
600,1.4739,1.612625
700,1.373,1.598969
800,1.3706,1.566399
900,1.3103,1.554614
1000,1.2672,1.545612


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



## 7. Template Code to Continue Training a Specific Run Later

%%script false
# Template: continue training one finished sweep run from a saved checkpoint.
# Expects format_prompt and tokenize_function from the earlier cells to be defined.

# --- Imports (redundant if run in the same session, but keep the cell self-contained) ---
import os

import torch
import wandb
from datasets import load_dataset # Reload dataset if needed
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)

# --- Resume Configuration ---
RUN_ID_TO_RESUME = "abcdef12" # <--- REPLACE with the actual Run ID from W&B
# Checkpoints of a sweep run live under that run's files directory, e.g.
# wandb/sweep-<id>/run-<id>/files/results/checkpoint-XXX
CHECKPOINT_PATH_TO_RESUME = "wandb/sweep-XXXX/run-YYYY/files/results/checkpoint-ZZZ" # <--- REPLACE with the full path
PROJECT_NAME = "falcon-commonsenseqa-sweep" # Should match the project used for the sweep
ENTITY_NAME = None # Or your W&B entity/username if needed
ORIGINAL_EPOCHS = 1 # num_train_epochs used by train_sweep for the run being resumed
ADDITIONAL_EPOCHS = 2 # How many *more* epochs to train

# --- Need Hyperparameters from the Resumed Run ---
# The same hyperparameters (LoRA config, quantization, etc.) as the resumed run must be used.
# They are fetched through the W&B API rather than copied by hand.
print(f"Fetching config for run: {RUN_ID_TO_RESUME}")
api = wandb.Api()
entity = ENTITY_NAME or api.default_entity
resumed_run_api = api.run(f"{entity}/{PROJECT_NAME}/{RUN_ID_TO_RESUME}")
cfg = resumed_run_api.config
print(f"Fetched Config: {cfg}")

RESUMED_LORA_R = cfg.get("lora_r", 64) # Provide default if key missing
RESUMED_LORA_ALPHA = cfg.get("lora_alpha", 16)
RESUMED_LORA_DROPOUT = cfg.get("lora_dropout", 0.1)
RESUMED_LR = cfg.get("learning_rate", 2e-4)
RESUMED_WEIGHT_DECAY = cfg.get("weight_decay", 0.001)
RESUMED_GRAD_ACCUM_STEPS = cfg.get("gradient_accumulation_steps", 8)
# ... fetch any other tuned parameters used in peft_config or training_args ...

# --- W&B Init for Resuming ---
print(f"Attempting to resume W&B run: {RUN_ID_TO_RESUME}")
wandb.login() # Ensure logged in
resumed_run = wandb.init(
    project=PROJECT_NAME,
    entity=ENTITY_NAME, # Same entity the config was fetched from (None = default entity)
    id=RUN_ID_TO_RESUME,
    resume="must" # Tell W&B to resume this specific run
)
print(f"Successfully resumed W&B run: {resumed_run.id} at {resumed_run.path}")

# --- Reload Model and Tokenizer (as done in sweep function) ---
model_name = "tiiuae/falcon-7b-instruct"
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
device_map = {"": 0}

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
    token=None,
    trust_remote_code=True
)
base_model.config.use_cache = False # Important for training
base_model.config.pretraining_tp = 1

# Same preparation step train_sweep applies before attaching adapters; without it
# the quantized base model is not set up for training with gradient checkpointing.
base_model = prepare_model_for_kbit_training(base_model)

# --- Load PEFT Model from Checkpoint ---
# The checkpoint should contain the adapter weights and config.
print(f"Loading PEFT model from checkpoint: {CHECKPOINT_PATH_TO_RESUME}")
# is_trainable=True ensures LoRA layers are unfrozen for continued training
model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH_TO_RESUME, is_trainable=True)
model.gradient_checkpointing_enable()
model.print_trainable_parameters()

# --- Load Tokenizer (from checkpoint if present, otherwise from the base model) ---
# train_sweep builds its Trainer without a tokenizer, so its checkpoints may not
# contain tokenizer files; fall back to the base model's tokenizer in that case.
checkpoint_has_tokenizer = os.path.isfile(
    os.path.join(CHECKPOINT_PATH_TO_RESUME, "tokenizer_config.json")
)
tokenizer_source = CHECKPOINT_PATH_TO_RESUME if checkpoint_has_tokenizer else model_name
print(f"Loading tokenizer from: {tokenizer_source}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=None)
# Same padding setup as the tokenizer used during the sweep
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print("Tokenizer loaded.")

# --- Reload and Tokenize Data (ensure it's the same as before) ---
print("Loading and preprocessing dataset...")
dataset = load_dataset("tau/commonsense_qa")
# format_prompt and tokenize_function must be in scope (defined in the earlier cells);
# tokenize_function reads the global `tokenizer` and `max_sequence_length` set here.
formatted_dataset = dataset.map(format_prompt, remove_columns=list(dataset['train'].features))
max_sequence_length = 256 # Same as before
tokenized_dataset = formatted_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
print("Dataset ready.")

# --- Define Training Arguments for Resuming ---
# Use a subdirectory within the W&B run's directory for new checkpoints
resume_output_dir = os.path.join(resumed_run.dir, "resume_results")
os.makedirs(resume_output_dir, exist_ok=True)
print(f"Resume output directory: {resume_output_dir}")

training_arguments = TrainingArguments(
    output_dir=resume_output_dir, # Save new checkpoints here
    # num_train_epochs is the TOTAL epoch count: resume_from_checkpoint restores the
    # step counter, so the epochs already trained must be included here.
    num_train_epochs=ORIGINAL_EPOCHS + ADDITIONAL_EPOCHS,
    per_device_train_batch_size=1, # Match original run if possible, check cfg
    gradient_accumulation_steps=RESUMED_GRAD_ACCUM_STEPS, # MUST match resumed run config
    optim="paged_adamw_32bit",
    save_strategy="steps",
    save_steps=100, # Continue saving checkpoints
    logging_steps=10,
    learning_rate=RESUMED_LR, # Continue with the same LR or decay schedule
    weight_decay=RESUMED_WEIGHT_DECAY, # Match resumed run config
    fp16=False,
    bf16=torch.cuda.is_bf16_supported(),
    max_grad_norm=0.3,
    max_steps=-1, # Train for specified epochs
    warmup_ratio=0.03, # This might restart warmup, consider setting warmup_steps=0 if resuming after warmup
    group_by_length=True,
    lr_scheduler_type="cosine", # Match resumed run config
    report_to="wandb", # Continue reporting to the same run
    eval_strategy="steps", # Same argument name as used in train_sweep
    eval_steps=100,
    per_device_eval_batch_size=1,
    save_total_limit=2,
    load_best_model_at_end=True, # Load the best model *within this resumed training phase*
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False}, # Same as train_sweep
)

# --- Data Collator ---
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# --- Initialize Trainer ---
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
)

# --- Resume Training ---
print(f"Resuming training from checkpoint: {CHECKPOINT_PATH_TO_RESUME}")
# The `resume_from_checkpoint` argument tells Trainer to load optimizer, scheduler, etc. state
train_result = trainer.train(resume_from_checkpoint=CHECKPOINT_PATH_TO_RESUME)

print("Resumed training finished.")
# Save the final model state after this resumed training phase
final_save_path = os.path.join(resume_output_dir, "final_model_after_resume")
trainer.save_model(final_save_path)
print(f"Final model saved to: {final_save_path}")
trainer.log_metrics("resume_train", train_result.metrics)
trainer.save_metrics("resume_train", train_result.metrics)

# --- Finish W&B Run ---
resumed_run.finish()
print(f"Finished and closed W&B run: {resumed_run.id}")

# Clean up memory
del model
del base_model
del trainer
torch.cuda.empty_cache()
print("CUDA cache cleared after resuming.")

## 8. Inference Example (After Training)

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig # Added for generation parameters
)
from peft import PeftModel
from datasets import load_dataset
import random
import os # For path joining

# Single-sample inference: load the 4-bit base model plus a trained LoRA adapter,
# prompt it with one random validation question and compare the generated key
# with the reference answer.

# --- Configuration ---
base_model_name = "tiiuae/falcon-7b-instruct"

# ===> IMPORTANT: UPDATE THIS PATH <===
# Point this to the checkpoint or final saved adapter directory you want to use for inference
# Example from sweep: "wandb/sweep-SWEEPID/run-RUNID/files/results/checkpoint-XXX"
# Example from resume: "wandb/sweep-SWEEPID/run-RUNID/files/resume_results/final_model_after_resume"
# Example from original non-sweep run: "./results_llama2_7b_commonsenseqa/final_adapter"
adapter_model_dir = "./results_llama2_7b_commonsenseqa/final_adapter" # <--- REPLACE Appropriately

# Fail fast: without the adapter there is nothing to evaluate, and continuing would
# only load the 7B base model before failing later with a less obvious error.
if not os.path.isdir(adapter_model_dir):
    raise FileNotFoundError(f"Adapter directory not found: {adapter_model_dir}")


# --- Reload Quantization Config (Must match training) ---
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

device_map = {"": 0} # Load model on default GPU

# --- Load Base Model ---
print(f"Loading base model: {base_model_name}")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    # low_cpu_mem_usage=True, # Can help if CPU RAM is limited
    return_dict=True,
    torch_dtype=compute_dtype, # Use compute_dtype
    device_map=device_map,
    trust_remote_code="falcon" in base_model_name, # Needed for Falcon
    token=None # Ensure no token is used for open models
)
print("Base model loaded.")

# --- Load Tokenizer ---
# Prefer a tokenizer saved with the adapter. Checkpoints written by the sweep's
# Trainer (which is built without a tokenizer) may not contain one, so fall back
# to the base model's tokenizer in that case.
adapter_has_tokenizer = os.path.isfile(os.path.join(adapter_model_dir, "tokenizer_config.json"))
tokenizer_source = adapter_model_dir if adapter_has_tokenizer else base_model_name
print(f"Loading tokenizer from: {tokenizer_source}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=None)
# Ensure padding token is set (EOS is reused, as in training)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    # Padding side only matters for padded batches; the single prompt below is not padded.
    tokenizer.padding_side = "right"
print("Tokenizer loaded.")

# --- Load PEFT Adapter ---
print(f"Loading PEFT adapter from: {adapter_model_dir}")
# Load the LoRA adapter onto the base model
model = PeftModel.from_pretrained(base_model, adapter_model_dir)
print("PEFT adapter loaded.")

# --- Merge Adapter (Optional, for faster inference but uses more RAM initially) ---
# print("Merging adapter weights...")
# model = model.merge_and_unload() # This replaces the model object
# print("Adapter merged.")

# --- Set to Evaluation Mode ---
model.eval()
print("Model set to evaluation mode.")

# --- Load Original Dataset (for picking a sample) ---
print("Loading original CommonsenseQA dataset...")
original_dataset = load_dataset("tau/commonsense_qa")
validation_data = original_dataset['validation']
print("Dataset loaded.")

# --- Select a Random Validation Sample ---
random_index = random.randint(0, len(validation_data) - 1)
sample = validation_data[random_index]

# --- Prepare Prompt for Inference (WITHOUT the answer) ---
question = sample['question']
choices_text = sample['choices']['text']
choices_labels = sample['choices']['label']
true_answer_key = sample['answerKey']

# Format prompt exactly as used in training, but stop before the answer
inference_prompt = f"### Question:\n{question}\n\n### Choices:\n"
for label, text in zip(choices_labels, choices_text):
    inference_prompt += f"{label}) {text}\n"
inference_prompt += "\n### Answer:\n" # Model generates what comes next

# --- Tokenize the Inference Prompt ---
device = model.device # Get the device the model is on
inputs = tokenizer(inference_prompt, return_tensors="pt", padding=False).to(device)

# --- Generate the Answer ---
print("\n--- Running Inference ---")
print(f"Sample Index: {random_index}")
print("\nInput Prompt Sent to Model:")
print("---------------------------")
print(inference_prompt)
print("---------------------------")
print(f"Actual Answer Key: {true_answer_key}")
print("\nGenerating...")

# Configuration for generation.
# Greedy decoding: sampling-only knobs (temperature, top_p) are intentionally not
# set, because they have no effect when do_sample=False and only trigger warnings.
generation_config = GenerationConfig(
    max_new_tokens=5,       # Generate only a few tokens (A, B, C, D, E + maybe newline/EOS)
    do_sample=False,        # Use greedy decoding (most likely token)
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    # repetition_penalty=1.1 # Optional: discourage repetition
)

with torch.no_grad(): # Disable gradient calculation for inference
    outputs = model.generate(**inputs, generation_config=generation_config)

# --- Decode and Display Results ---
# The output contains the prompt tokens followed by the new tokens; keep only the latter.
input_token_len = inputs['input_ids'].shape[1]
generated_token_ids = outputs[0][input_token_len:]
generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

print("\n--- Results ---")
print(f"Raw Generated Text: '{generated_text}'")

# Attempt to parse the prediction: the first non-whitespace character, uppercased
predicted_answer = generated_text.strip().upper()
parsed_key = None
if predicted_answer and predicted_answer[0] in ['A', 'B', 'C', 'D', 'E']:
    parsed_key = predicted_answer[0]
    print(f"Parsed Predicted Key: {parsed_key}")
    if parsed_key == true_answer_key:
        print("Outcome: CORRECT")
    else:
        print("Outcome: INCORRECT")
else:
    print(f"Outcome: Could not parse a valid key (A-E) from generation: '{predicted_answer}'")

# (Optional) Clean up GPU memory if running multiple inferences
# del model
# del base_model
# torch.cuda.empty_cache()

## 9. Evaluation on Validation Set (After Training)

In [3]:
import torch
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
)
from peft import PeftModel
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.notebook import tqdm # Progress bar
import numpy as np
import wandb
import os # For path joining

# --- Configuration (Ensure these match your trained model!) ---
base_model_name = "tiiuae/falcon-7b-instruct"

# ===> IMPORTANT: UPDATE THIS PATH <===
# Point this to the checkpoint or final saved adapter directory you want to evaluate
adapter_model_dir = "./wandb/run-20250501_133606-tig02dc9/files/results/checkpoint-1200"

# Fail fast: both the tokenizer and the PEFT adapter are loaded from this directory further
# down, so there is no point in loading the base model first if the directory is missing.
if not os.path.isdir(adapter_model_dir):
    raise FileNotFoundError(f"Adapter directory not found: {adapter_model_dir}")


dataset_name = "tau/commonsense_qa"
# NOTE(review): the public "test" split appears to ship without answer keys - confirm before
# switching, otherwise every sample is treated as having an invalid true label.
split_to_evaluate = "validation" # Or "test" if you want final test metrics
# Optional: Initialize W&B if you want to log evaluation metrics to a specific run
# EVAL_RUN_ID = "abcdef12" # ID of the run you want to associate these metrics with
# wandb.init(project="falcon-commonsenseqa-sweep", id=EVAL_RUN_ID, resume="allow")


# --- Quantization settings (keep identical to the ones used for training) ---
use_4bit = True
use_nested_quant = False
bnb_4bit_quant_type = "nf4"
bnb_4bit_compute_dtype = "float16"
# Turn the dtype name into the corresponding torch dtype object.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=use_nested_quant,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    load_in_4bit=use_4bit,
)
# Single-entry device map: everything goes to device index 0.
device_map = {"": 0}

# --- Base model and tokenizer ---
print(f"Loading base model: {base_model_name}")
# Remote code is only enabled for Falcon checkpoints.
needs_remote_code = "falcon" in base_model_name
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    token=None,
    trust_remote_code=needs_remote_code,
    device_map=device_map,
    torch_dtype=compute_dtype,
    return_dict=True,
    quantization_config=bnb_config,
)

print("Loading tokenizer...")
# The tokenizer is read from the adapter directory, i.e. the copy saved alongside the adapter.
tokenizer = AutoTokenizer.from_pretrained(adapter_model_dir, token=None)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

# --- PEFT adapter on top of the quantized base model ---
print(f"Loading PEFT adapter from: {adapter_model_dir}")
model = PeftModel.from_pretrained(base_model, adapter_model_dir)
model.eval()  # evaluation mode
device = model.device
print("Model ready for evaluation.")

# --- Load Dataset Split ---
print(f"Loading dataset split: {split_to_evaluate}")
eval_dataset = load_dataset(dataset_name, split=split_to_evaluate)

# --- Prepare for Evaluation ---
# Answer key -> numeric class id; -1 is used elsewhere in this cell for "invalid / unparsed".
label_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
reverse_label_map = {v: k for k, v in label_map.items()} # For debugging
y_true_eval = []          # numeric true labels (one per sample)
y_pred_eval = []          # numeric predicted labels (one per sample)
y_pred_parsed_keys = [] # Store parsed keys (A, B, C, D, E) or None
y_true_keys = [] # Store true keys

# Generation config (deterministic for evaluation).
# Greedy decoding (do_sample=False) does not use sampling parameters, so `temperature` is
# intentionally not set here: it had no effect on the output and contradicts do_sample=False.
generation_config = GenerationConfig(
    max_new_tokens=5, # Enough for the letter + maybe EOS/newline
    do_sample=False,  # Greedy decoding
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# --- Evaluation Loop ---
# For every sample: build the prompt, generate greedily, parse the answer key from the
# generated text and record true / predicted labels. Failures are recorded as -1 / None.
print(f"Running evaluation on {len(eval_dataset)} samples from '{split_to_evaluate}' split...")
num_parse_errors = 0
for i in tqdm(range(len(eval_dataset))):
    sample = eval_dataset[i]
    question = sample['question']
    choices_text = sample['choices']['text']
    choices_labels = sample['choices']['label']
    true_answer_key = sample['answerKey']
    true_label_numeric = label_map.get(true_answer_key, -1) # Handle potential invalid keys in data

    y_true_keys.append(true_answer_key)
    y_true_eval.append(true_label_numeric)

    # Format prompt
    # NOTE(review): presumably this must mirror the prompt format used for training - verify
    # against the formatting function used there.
    inference_prompt = f"### Question:\n{question}\n\n### Choices:\n"
    for label, text in zip(choices_labels, choices_text):
        inference_prompt += f"{label}) {text}\n"
    inference_prompt += "\n### Answer:\n"

    # Tokenize (one sample at a time, so no padding is needed)
    inputs = tokenizer(inference_prompt, return_tensors="pt", padding=False).to(device)

    # Generate
    predicted_key_numeric = -1 # Default to -1 for parse failure
    predicted_parsed_key = None
    try:
        with torch.no_grad():
            outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode generated part only (everything after the prompt tokens)
        input_token_len = inputs['input_ids'].shape[1]
        generated_token_ids = outputs[0][input_token_len:]
        generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

        # Parse prediction.
        # Only a *standalone* leading key counts: the first character must be A-E and must
        # not be followed by another letter. Otherwise generations such as "Answer: B",
        # "Because ..." or the choice text "avoid" would be scored as keys A / B / A.
        # Slices keep this safe for empty or one-character generations.
        parsed_pred = generated_text.strip().upper()
        leading_char = parsed_pred[:1]
        following_char = parsed_pred[1:2]
        if leading_char in label_map and not following_char.isalpha():
            predicted_parsed_key = leading_char
            predicted_key_numeric = label_map[predicted_parsed_key]
        else:
            num_parse_errors += 1
            # Optional: Log the failed parse
            # print(f"Parse Error (Index {i}): Prompt:\n{inference_prompt}\nGenerated: '{generated_text}' Parsed: '{parsed_pred}'")

    except Exception as e:
        # Best effort: report the failure, count it as a parse error and keep evaluating.
        print(f"\nError during generation/parsing for index {i}: {e}")
        num_parse_errors += 1
        # Keep predicted_key_numeric as -1 and predicted_parsed_key as None

    y_pred_eval.append(predicted_key_numeric)
    y_pred_parsed_keys.append(predicted_parsed_key)


print("Evaluation loop finished.")
print(f"Total parse errors: {num_parse_errors}")

# --- Calculate Metrics ---
# Accuracy is computed over all samples: a prediction of -1 (parse failure) or a true label
# of -1 (invalid key) can never be counted as correct.
total_samples = len(y_true_eval)
correct_predictions = sum(
    1
    for gold, guess in zip(y_true_eval, y_pred_eval)
    if gold != -1 and gold == guess
)
if total_samples > 0:
    accuracy = correct_predictions / total_samples
else:
    accuracy = 0.0

# Precision / recall / F1 use only the samples where both the prediction and the true
# label are valid (neither is -1).
valid_indices = [
    idx
    for idx, pair in enumerate(zip(y_pred_eval, y_true_eval))
    if -1 not in pair
]

if not valid_indices:
    print("Warning: Could not parse any valid predictions (A-E), or no valid true labels found. Precision/Recall/F1 will be 0.")
    precision, recall, f1 = 0.0, 0.0, 0.0
    num_parsed = 0
else:
    filtered_y_true = [y_true_eval[idx] for idx in valid_indices]
    filtered_y_pred = [y_pred_eval[idx] for idx in valid_indices]

    # Macro average: unweighted mean over the classes; zero_division=0 yields 0 instead of
    # a warning when a class has no predictions / true labels.
    precision, recall, f1, _ = precision_recall_fscore_support(
        filtered_y_true,
        filtered_y_pred,
        average='macro',
        zero_division=0,
    )
    num_parsed = len(valid_indices)


# --- Print Results ---
# Guard the ratio the same way accuracy is guarded: an empty split must not raise
# ZeroDivisionError while printing the summary.
parsed_fraction = num_parsed / total_samples if total_samples > 0 else 0.0

print("\n--- Evaluation Metrics ---")
print(f"Split Evaluated:        {split_to_evaluate}")
print(f"Total Samples:          {total_samples}")
print(f"Successfully Parsed:    {num_parsed} ({parsed_fraction:.1%} of total, excluding invalid true labels)")
print(f"Accuracy (overall):     {accuracy:.4f} (Correct predictions / Total samples)")
print(f"Precision (macro, parsed only): {precision:.4f}")
print(f"Recall (macro, parsed only):    {recall:.4f}")
print(f"F1 Score (macro, parsed only):  {f1:.4f}")

# --- Send the metrics to W&B, but only when a run is active ---
if not wandb.run:
    print("W&B run not active. Metrics not logged.")
else:
    # Every key is prefixed with the evaluated split name.
    eval_metrics = {
        f"{split_to_evaluate}_accuracy": accuracy,
        f"{split_to_evaluate}_precision_macro": precision,
        f"{split_to_evaluate}_recall_macro": recall,
        f"{split_to_evaluate}_f1_macro": f1,
        f"{split_to_evaluate}_parsed_count": num_parsed,
        f"{split_to_evaluate}_total_count": total_samples,
        f"{split_to_evaluate}_parse_errors": num_parse_errors,
    }
    wandb.log(eval_metrics)
    print(f"Metrics logged to W&B run: {wandb.run.path}")
    # Optional: Finish the run if it was only for evaluation
    # wandb.finish()


# (Optional) Drop the model references, then release cached CUDA memory
del model, base_model
torch.cuda.empty_cache()
print("CUDA cache cleared after evaluation.")

Loading base model: tiiuae/falcon-7b-instruct


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading tokenizer...
Loading PEFT adapter from: ./wandb/run-20250501_133606-tig02dc9/files/results/checkpoint-1200
Model ready for evaluation.
Loading dataset split: validation
Running evaluation on 1221 samples from 'validation' split...




  0%|          | 0/1221 [00:00<?, ?it/s]



Evaluation loop finished.
Total parse errors: 0

--- Evaluation Metrics ---
Split Evaluated:        validation
Total Samples:          1221
Successfully Parsed:    1221 (100.0% of total, excluding invalid true labels)
Accuracy (overall):     0.6470 (Correct predictions / Total samples)
Precision (macro, parsed only): 0.6466
Recall (macro, parsed only):    0.6466
F1 Score (macro, parsed only):  0.6465
W&B run not active. Metrics not logged.
CUDA cache cleared after evaluation.
