# Homework - Fine-tuning leads to forgetting
In this homework you will apply what you have learned about fine-tuning large language models (LLMs) to improve a model's performance on the GSM8K mathematics dataset. You will implement a few-shot learning approach, evaluate the fine-tuned model on the public GSM8K test set, and check for forgetting on a sample of the BoolQ dataset.
Because this homework may be challenging, most of the code is already provided. Your task is to improve the model's accuracy on the fine-tuning task while avoiding forgetting on the evaluation task.

## Instructions
Submit your assignment as a Jupyter notebook with all relevant execution outputs visible. Use comments throughout your code to indicate and motivate each section and the specific steps you perform. Complete results are preferred, but partial grades may be given for incomplete results when the steps are clearly marked and the work is well documented.

Report your scores on the GSM8K test set and the BoolQ test set in the notebook.
For each task, explain why you chose your parameters and tactics. If you measured a score improvement, report it as well.

Task list (you may complete the tasks in any order):

1. Set up the LoRA configuration. (2 points)
2. Set up fixed few-shot example selection. (You may even use your own few-shot examples, but make sure to document them in the notebook) (3 points)
3. Set the number of few-shot examples during training and evaluation. (1 point)
4. Add weight decay to the optimizer. (1 point)
5. Add a learning rate scheduler. (1 point)
6. Add a reasonable warmup ratio. (1 point)
7. Decrease the learning rate. (1 point)
8. Alter the number of training epochs if necessary. (1 point)

### After training

9. Set up greedy decoding for inference. (1 point)
10. Set the inference max_new_tokens parameter to a suitable value. (2 points)
11. Set checkpoint for loading the inference model. (1 point)
12. Set the test number of few-shot examples. (1 point)
13. For full evaluation, increase the ``EVAL_LIMIT`` to ``None``. (1 point)


## Objectives

- A. Fine-tune the model on the GSM8K dataset using the LoRA configuration.
- B. Evaluate the model on the GSM8K test set and tweak your setup to get the accuracy above 50%. (5 points if above threshold)
- C. Evaluate the model on the 100 samples of the BoolQ test set and tweak your setup to avoid forgetting. Thus make sure the accuracy does not significantly decrease from the baseline accuracy of 66% on the 100 samples of the BoolQ test set while achieving good performance on the GSM8K test set. (5 points if above threshold)
- D. Report the final scores on both test sets in the notebook and explain your choices for each task in a markdown cell. **Make sure to properly format your explanation and scores in the answer cell below. Unclear answers will be assigned 0 points.**


ANSWER CELL

**Student name:**  
**Student email:**  
**Date:**  

---

**Task 1:**
- Choice:
- Motivation:

**Task 2:**
- Choice:
- Motivation:

**Task 3:**
- Choice:
- Motivation:

**Task i:**
- Choice:
- Motivation:

Use the above task format for all tasks.

> **Submission checklist (student):**  
> - [ ] All “Choice” fields filled with the chosen approach
> - [ ] All “Motivation” fields (2–4 sentences)  
> - [ ] All code is reproducible

**Objective B:**
- Result:
- Explanation: 

**Objective C:**
- Result:
- Explanation: 

> **Submission checklist (student):**  
> - [ ] All “Result” fields filled with the obtained results
> - [ ] All “Explanation” fields (5–10 sentences) filled with an analysis of the results
> - [ ] All code is reproducible



## 1. Environment and data setup

Check GPU availability, download datasets, and authenticate with Hugging Face.

In [None]:
## Make sure we have a GPU available
!nvidia-smi

### Download datasets

The code block below downloads the relevant datasets.

In [None]:
import urllib.request
import ssl

# Create an SSL context that doesn't verify certificates
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE

urls = [
    ('https://www.csie.ntu.edu.tw/~b10902031/gsm8k_train.jsonl', 'gsm8k_train.jsonl'),
    ('https://www.csie.ntu.edu.tw/~b10902031/gsm8k_test_public.jsonl', 'gsm8k_test_public.jsonl'),
]

for url, filename in urls:
    print(f"Downloading {filename}...")
    try:
        # Create a request with custom SSL context
        request = urllib.request.Request(url)
        with urllib.request.urlopen(request, context=ssl_context) as response:
            with open(filename, 'wb') as f:
                f.write(response.read())
        print(f"Downloaded {filename}")
    except Exception as e:
        print(f"Error downloading {filename}: {e}")

### Authenticate to Hugging Face and get model access

Log in to Hugging Face and make sure you have been granted access to the Llama-3.2-1B-Instruct model. You can request access on the [Hugging Face model page](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) by clicking the "Request Access" button. Once you have access, the code below authenticates your session so that the model and tokenizer can be downloaded later in the notebook.

In [None]:
# Load Hugging Face API key from environment (do NOT hardcode your token here).
import os
import logging, warnings
from transformers import logging as hf_logging

# Silence transformers/TRL logs early
hf_logging.set_verbosity_error()
logging.getLogger("trl").setLevel(logging.ERROR)

# Hide specific noisy warnings
warnings.filterwarnings(
    "ignore",
    message=r".*loss_type=None.*ForCausalLMLoss.*",
    category=UserWarning,
)
warnings.filterwarnings(
    "ignore",
    message=r".*cuDNN SDPA backward got grad_output\.strides\(\) != output\.strides\(\).*",
    category=UserWarning,
)
os.environ["TQDM_NOTEBOOK"] = "0"  

from huggingface_hub import login
from dotenv import load_dotenv

# Load .env file (if present)
load_dotenv()
hf_key = os.environ.get("HUGGINGFACE_API_KEY")
if hf_key:
    login(hf_key)
else:
    raise EnvironmentError("HUGGINGFACE_API_KEY not found. Copy .env.template to .env and add your token. See Instruction.md")

### Imports

Below are the necessary imports for the homework.

In [None]:
from transformers import (
    AutoModelForCausalLM, # imports the model for causal language modeling
    AutoTokenizer, # imports the tokenizer for the model
    pipeline # imports the pipeline for text generation
)
from peft import (
    LoraConfig, # imports the configuration for LoRA
    get_peft_model, # imports the function to get the PEFT model
    PeftModel # imports the PEFT model
)
import json
import torch
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # Sets the CUDA device to use
device = torch.device('cuda:0') # Creates a CUDA device object
from datasets import Dataset # Imports the Dataset class from the datasets library
from trl import SFTConfig, SFTTrainer # Imports the SFTConfig and SFTTrainer classes from the trl library
import random
random.seed(42) # Sets the random seed for reproducibility
from tqdm import tqdm # Imports the tqdm library for progress bars
import copy

## 2. Model, LoRA configuration, and data preparation

Load the base model, configure LoRA, and prepare few-shot data.

### Model and LoRA configuration

Load the model and set up a reasonable LoRA configuration.

**Task 1: Set up the LoRA configuration. (2 points)**
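
When choosing a LoRA rank, it can help to estimate how many trainable parameters each rank adds. The sketch below is only an illustration: the layer shapes and layer count are assumptions believed to match Llama-3.2-1B (hidden size 2048, 16 layers, `v_proj` output size 512) and should be verified against `sft_model.config`. The helper `lora_param_count` is hypothetical and not part of `peft`. In `LoraConfig`, the corresponding settings are `r`, `lora_alpha`, and `lora_dropout`.

```python
# Sketch: count the extra trainable parameters LoRA adds for a given rank.
# ASSUMPTION: the shapes below are believed to match Llama-3.2-1B;
# verify them against sft_model.config before relying on the numbers.

def lora_param_count(rank: int, shapes: dict, num_layers: int) -> int:
    """LoRA adds two matrices per target module: A (rank x d_in) and B (d_out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers

# (d_in, d_out) for each target module listed in target_modules (assumed values)
ASSUMED_SHAPES = {"q_proj": (2048, 2048), "v_proj": (2048, 512)}
ASSUMED_NUM_LAYERS = 16

for rank in (4, 8, 16, 32):
    n_params = lora_param_count(rank, ASSUMED_SHAPES, ASSUMED_NUM_LAYERS)
    print(f"r={rank:>2}: {n_params:,} trainable LoRA parameters")
```

The count grows linearly with the rank, so doubling `r` doubles the adapter size. A smaller adapter changes the base model less, which may help limit forgetting, at the possible cost of lower GSM8K accuracy.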

In [None]:
sft_model_name = 'meta-llama/Llama-3.2-1B-Instruct'  # No quantization
sft_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=sft_model_name,
    torch_dtype=torch.float16,  # No quantization
    low_cpu_mem_usage=True,
)
base_model = copy.deepcopy(sft_model)  # Keep a copy of the base model for evaluation
sft_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=sft_model_name,
)
sft_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
peft_config = LoraConfig( ## TODO: Add LoRA parameters
    bias='none',
    task_type='CAUSAL_LM',
    target_modules=['q_proj', 'v_proj']
)
peft_model = get_peft_model(sft_model, peft_config)

### Data helpers and few-shot strategy

Below are helpers to read the training set and set up the few-shot learning approach. The current approach selects the few-shot examples by random sampling, so every prompt sees different examples. Replace it with a fixed, more sophisticated selection strategy.

**Task 2: Set up fixed few-shot example selection. (3 points)**
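
One possible way to make the selection fixed is to draw the examples once from a locally seeded random generator, so every call returns the same examples. The sketch below is only an illustration under stated assumptions: the helper `select_fixed_examples`, the `max_answer_chars` heuristic, and the toy pool are hypothetical and not part of the provided code. Hand-picked examples are equally valid, as long as they are documented.

```python
import random

def select_fixed_examples(pool: list, n: int, seed: int = 0, max_answer_chars: int = 400) -> list:
    """Return the same n examples on every call (deterministic for a given pool and seed)."""
    # Hypothetical heuristic: prefer short worked solutions to keep prompts cheap.
    candidates = [ex for ex in pool if len(ex["answer"]) <= max_answer_chars]
    if len(candidates) < n:
        candidates = list(pool)  # fall back to the full pool if the filter is too strict
    rng = random.Random(seed)  # local generator: does not touch the global random state
    return rng.sample(candidates, n)

# Toy data standing in for gsm8k_train (illustration only)
toy_pool = [
    {"question": f"What is {i} + {i}?", "answer": f"{i} + {i} = {2 * i}\n#### {2 * i}"}
    for i in range(1, 21)
]

first_call = select_fixed_examples(toy_pool, n=3)
second_call = select_fixed_examples(toy_pool, n=3)
print([ex["question"] for ex in first_call])
print("Identical across calls:", first_call == second_call)
```

With a helper like this, the examples could be selected once, outside `nshot_chats`, and reused for every prompt in both training and evaluation.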

In [None]:
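from typing import Optional

def load_jsonlines(file_name: str) -> list:
    # Read one JSON object per line; the context manager closes the file afterwards
    with open(file_name, 'r') as f:
        return [json.loads(line) for line in f]

def nshot_chats(nshot_data: list, n: int, question: str, answer: Optional[str], mode: str) -> list: # Function to create n-shot chats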
    if mode not in ['train', 'test']:
        raise AssertionError('Undefined Mode!!!')

    chats = []
    # TODO: Use fixed few-shot examples
    for qna in random.sample(nshot_data, n): # Samples n examples from the n-shot data
        chats.append(
            {
                'role': 'user',
                'content': f'Q: {qna["question"]}' # Creates a user message with the question
            }
        )
        chats.append(
            {
                'role': 'assistant',
                'content': f'A: {qna["answer"]}' # Creates an assistant message with the answer
            }
        )

    chats.append(
        {
            'role': 'user',
            'content': f'Q: {question} Let\'s think step by step. At the end, you MUST write the answer as an integer after \'####\'.' # Creates a user message with the question and instructions
        }
    )
    if mode == 'train':
        chats.append(
            {
                'role': 'assistant',
                'content': f'A: {answer}' # Creates an assistant message with the answer
            }
        )

    return chats # Returns the list of chats

### Prepare training data and choose N-shot examples

You may choose the number of few-shot examples yourself to maximize performance and minimize unnecessary compute. While experimenting, limit the dataset to a small number of examples to speed up the process.

**Task 3: Set the number of few-shot examples during training and evaluation. (1 point)**

In [None]:
gsm8k_train = load_jsonlines('gsm8k_train.jsonl') # You can use refined gsm8k_train_self-instruct.jsonl for fine-tuning

formatted_gsm8k = []
TRAIN_N_SHOT = 0 # TODO: Give model more examples
TRAIN_LIMIT = None # Limit the number of examples for testing
max_token_len = 0 # Record token length of dataset and prevent data from truncation
for i, qna in enumerate(gsm8k_train): # Iterates over the GSM8K training data
    chats = nshot_chats(nshot_data=gsm8k_train, n=TRAIN_N_SHOT, question=qna['question'], answer=qna['answer'], mode='train') # Creates n-shot chats for the current example
    train_sample = sft_tokenizer.apply_chat_template(chats, tokenize=False) # Applies the chat template to the chats

    train_sample = train_sample[train_sample.index("<|eot_id|>") + len("<|eot_id|>"):] # Remove Cutting Knowledge Date in prompt template
    formatted_gsm8k.append( # Appends the formatted example to the list
        {
            'text': train_sample # Adds the text of the example
        }
    )
    max_token_len = max(max_token_len, len(sft_tokenizer(train_sample)['input_ids'])) # Updates the maximum token length
    if TRAIN_LIMIT is not None and i + 1 >= TRAIN_LIMIT: # Stop once exactly TRAIN_LIMIT examples have been collected
        break
    
print(f"Last example: {train_sample}")
formatted_gsm8k = Dataset.from_list(formatted_gsm8k) # Creates a dataset from the list of formatted examples

## 3. Training (SFT + LoRA)

Configure and run supervised fine-tuning with LoRA.

### Configure training

The code block below sets up the training. Configure reasonable parameters.


**Task 4: Add weight decay to the optimizer. (1 point)**

**Task 5: Add a learning rate scheduler. (1 point)**

**Task 6: Add a reasonable warmup ratio. (1 point)**

**Task 7: Decrease the learning rate. (1 point)**

**Task 8: Alter the number of training epochs if necessary. (1 point)**
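
To build intuition for Tasks 5 to 7, the sketch below computes a learning rate schedule with linear warmup followed by cosine decay in plain Python. It is an approximation for illustration only: the helper `lr_at_step` is hypothetical, the exact rounding of warmup steps in the trainer may differ, and the values in `example_kwargs` are placeholders rather than recommended settings. The keyword names `weight_decay`, `lr_scheduler_type`, and `warmup_ratio` are the ones accepted by `SFTConfig` through `TrainingArguments`.

```python
import math

def lr_at_step(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero (approximate sketch)."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Placeholder values for illustration only; tune them for your own run.
example_kwargs = dict(
    learning_rate=1e-4,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
)

total_steps = 1000
for step in (0, 50, 100, 500, 1000):
    lr = lr_at_step(step, total_steps, example_kwargs["learning_rate"], example_kwargs["warmup_ratio"])
    print(f"step {step:>4}: lr = {lr:.2e}")
```

A short warmup avoids large updates while the optimizer statistics are still unreliable, and a lower peak learning rate keeps the adapter closer to the base model, which may reduce forgetting on BoolQ.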

In [None]:
# trainer
training_arguments = SFTConfig( # Configuration for the SFT trainer
    seed=1126,
    data_seed=1126,
    output_dir=f"sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    num_train_epochs=1, # TODO: If you use fixed few-shot examples, possibly increase epochs
    logging_strategy="steps",
    logging_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    learning_rate=2e-4, # TODO: Decrease learning rate  # TODO: Add weight decay # TODO: Add warmup ratio # TODO: Add lr scheduler 
    bf16=True,
    group_by_length=True,
    max_seq_length=max_token_len,
    dataset_text_field='text',
    report_to='none',
)
print(f"On device: {device}")
trainer = SFTTrainer( # Creates the SFT trainer
    model=peft_model,
    train_dataset=formatted_gsm8k,
    peft_config=peft_config,
    processing_class=sft_tokenizer,
    args=training_arguments,
)
trainer.train() # Starts the training process

## 4. Inference and checkpoints

Load checkpoints, set greedy decoding, and define helper functions.

### Load checkpoint and set up inference

Now load a checkpoint of your choice and set up inference. Switch the inference to greedy decoding instead of sampling.

**Task 9: Set up greedy decoding for inference. (1 point)**

**Task 10: Set the inference max_new_tokens parameter to a suitable value. (2 points)**

**Task 11: Set checkpoint for loading the inference model. (1 point)**
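
The difference between greedy decoding and sampling can be seen on a toy example. The sketch below uses made-up logits for three tokens and hypothetical helpers (`softmax`, `greedy_pick`, `sample_pick`); it does not call the model. In the `pipeline` call, greedy decoding corresponds to `do_sample=False`, in which case `temperature` and `top_p` no longer apply.

```python
import math
import random

def softmax(logits: list, temperature: float = 1.0) -> list:
    """Convert logits to probabilities, sharpened or flattened by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits: list) -> int:
    """Greedy decoding: always take the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample_pick(logits: list, temperature: float, rng: random.Random) -> int:
    """Sampling: draw a token according to the softmax probabilities."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

toy_logits = [2.0, 1.5, 0.1]  # made-up scores for three candidate tokens
rng = random.Random(0)
greedy_runs = {greedy_pick(toy_logits) for _ in range(100)}
sampled_runs = {sample_pick(toy_logits, 0.6, rng) for _ in range(100)}
print("Tokens chosen by greedy decoding:", sorted(greedy_runs))
print("Tokens chosen by sampling:", sorted(sampled_runs))
```

Greedy decoding always returns the same token for the same input, which makes evaluation scores reproducible. Sampling can pick a less likely token at any step, and a single wrong digit is enough to change a GSM8K answer.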

In [None]:
generator = pipeline( # Creates a text generation pipeline
    'text-generation',
    model=sft_model,
    tokenizer=sft_tokenizer,
    pad_token_id=sft_tokenizer.eos_token_id,
    max_new_tokens=64, # TODO: Increase max_new_tokens for longer output
    # TODO: Use greedy decoding strategy
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
adapter_path = 'sft/checkpoint-1122' # TODO: Evaluate different checkpoints
generator.model = PeftModel.from_pretrained( # Loads the adapter checkpoint into the generator (not the `pipeline` factory function)
    sft_model,
    adapter_path
)

### Inference helpers

Utility functions for generation and answer extraction.
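
When inspecting failures, it can be useful to check whether a response contains a parsable final answer at all. The sketch below is a stricter, regex-based alternative for debugging, not a replacement for the provided `extract_ans_from_response`: the helper `extract_final_number` is hypothetical, and unlike the provided function it returns `None` when no `####` marker or no number is found.

```python
import re

# An optionally negative number with optional thousands separators and decimals
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def extract_final_number(response: str):
    """Return the first number after the last '####' as a string, or None if absent."""
    if "####" not in response:
        return None  # the model never produced the answer marker
    tail = response.split("####")[-1]
    match = _NUMBER.search(tail)
    if match is None:
        return None  # marker present, but no number follows it
    return match.group(0).replace(",", "")

examples = [
    "She sells 9 eggs at $2 each. 9 * 2 = 18\n#### 18",
    "The total is #### $1,234 dollars.",
    "I think the answer is 42.",  # no marker: likely truncated or off-format
]
for text in examples:
    print(repr(extract_final_number(text)))
```

Many `None` results on the test set may indicate that `max_new_tokens` is too small for the model to reach the `####` marker.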

In [None]:

def get_response(chats: list): # Function to get the response from the model
    # Apply chat template just like in training data preparation
    formatted_prompt = sft_tokenizer.apply_chat_template(chats, tokenize=False, add_generation_prompt=True)
    # Remove system prompt part to match training format
    if "<|eot_id|>" in formatted_prompt:
        formatted_prompt = formatted_prompt[formatted_prompt.index("<|eot_id|>") + len("<|eot_id|>"):]

    gen_text = generator(formatted_prompt)[0]  # Generate from formatted prompt
    
    return gen_text['generated_text'][len(formatted_prompt):].strip() # Return only the new generated text

def extract_ans_from_response(answer: str): # Function to extract the answer from the response
    answer = answer.split('####')[-1].strip() # Splits the answer by '####' and takes the last part

    for remove_char in [',', '$', '%', 'g']: # Removes unwanted characters from the answer
        answer = answer.replace(remove_char, '')

    return answer # Returns the extracted answer

## 5. Evaluation

Evaluate on GSM8K (fine-tuned task) and BoolQ (forgetting check).

### Evaluate on GSM8K

Great! You have now set up the training and inference. The next step is to evaluate the model on the fine-tuning task. Configure your number of shots and limit the number of examples when playing around to speed up the process. You must report the final accuracy on the fine-tuning task on the whole set when you submit your homework.

**Task 12: Set the test number of few-shot examples. (1 point)**

**Task 13: For full evaluation, increase the ``EVAL_LIMIT`` to ``None``. (1 point)**

In [None]:
gsm8k_predictions = []
TEST_N_SHOT = 0 # TODO: give model more examples
EVAL_LIMIT = 20 # TODO: Change to None for full evaluation or keep it for testing
gsm8k_test_public = load_jsonlines('gsm8k_test_public.jsonl') # Loads the GSM8K public test data
gsm8k_total = len(gsm8k_test_public) # Gets the total number of examples in the public test data
gsm8k_progress_bar = tqdm(total=gsm8k_total, desc='GSM8K Public Test Data Evaluation', postfix='Current Accuracy = 0.000') # Creates a progress bar for the public test data evaluation

correct = 0
for i, qna in enumerate(gsm8k_test_public): # Iterates over the public test data
    if EVAL_LIMIT is not None and i >= EVAL_LIMIT: # Stop after exactly EVAL_LIMIT examples
        break
    
    messages = nshot_chats(nshot_data=gsm8k_train, n=TEST_N_SHOT, question=qna['question'], answer=None, mode='test') # Creates n-shot chats for the current example
    response = get_response(messages) # Gets the response from the model
    pred_ans = extract_ans_from_response(response) # Extracts the predicted answer from the response
    true_ans = extract_ans_from_response(qna["answer"]) # Extracts the true answer from the example
    if i < 3:
        print(f"Example {i+1}/{gsm8k_total}: Question: {qna['question']}") # Prints the question for the current example
        print(f"Example {i+1}/{gsm8k_total}: Response: {response}") # Prints the response for the current example
        print(f"Example {i+1}/{gsm8k_total}: Predicted Answer: {pred_ans}, True Answer: {true_ans}") # Prints the predicted and true answers for the current example
        print("")
    if pred_ans == true_ans: # Checks if the predicted answer is correct
        correct += 1 # Increments the correct count if the prediction is correct
    gsm8k_predictions.append(pred_ans) # Appends the predicted answer to the list of predictions
    
    gsm8k_progress_bar.set_postfix_str(f'Current Accuracy = {correct/(i+1):.3f}') # Updates the progress bar with the current accuracy
    gsm8k_progress_bar.update() # Updates the progress bar


print(f"Predicted last answer: {pred_ans}") # Prints the last predicted answer
print(f"True last answer: {true_ans}") # Prints the true answer of the last example
gsm8k_progress_bar.close() # Closes the progress bar

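num_evaluated = len(gsm8k_predictions) # Counts the examples actually evaluated (also works when EVAL_LIMIT is None)
print(f'GSM8K Public Test Data Evaluation Complete, Total Accuracy: {correct/max(num_evaluated, 1):.3f}') # Prints the total accuracy on the evaluated examples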

### Evaluate on BoolQ

Next, evaluate how much the model has forgotten on a task it was not fine-tuned on. The evaluation task is BoolQ, a binary (yes/no) question answering dataset. Because the model has not been trained on this dataset, it must answer the questions from its general knowledge. Your submission should report the accuracy on the 100 BoolQ examples sampled below. The base model scores 66% on this sample.
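
With only 100 examples, part of any accuracy difference is sampling noise. The sketch below estimates the binomial standard error around the 66% baseline to help judge whether a drop is notable. It is a rough rule of thumb under stated assumptions, not the grading criterion: the helpers `accuracy_standard_error` and `drop_is_notable` and the threshold `z=1.96` are hypothetical choices for illustration.

```python
import math

def accuracy_standard_error(accuracy: float, n: int) -> float:
    """Binomial standard error of an accuracy measured on n examples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n)

def drop_is_notable(baseline: float, finetuned: float, n: int, z: float = 1.96) -> bool:
    """Flag a drop larger than z standard errors of the baseline (rule of thumb only)."""
    return (baseline - finetuned) > z * accuracy_standard_error(baseline, n)

BASELINE_ACCURACY = 0.66  # base model score on the 100 sampled BoolQ examples
SAMPLE_SIZE = 100

se = accuracy_standard_error(BASELINE_ACCURACY, SAMPLE_SIZE)
print(f"Standard error at baseline: {se:.3f}")
for finetuned in (0.64, 0.60, 0.50):
    flagged = drop_is_notable(BASELINE_ACCURACY, finetuned, SAMPLE_SIZE)
    print(f"fine-tuned accuracy {finetuned:.2f}: notable drop = {flagged}")
```

Under this rule of thumb, a drop of a few percentage points on 100 examples is hard to distinguish from noise, while a drop of more than roughly nine points would stand out.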

In [None]:
from datasets import load_dataset
import pandas as pd
boolq = load_dataset("boolq")
boolq_df = pd.DataFrame(boolq['validation'])

sample_size = min(100, len(boolq_df))
df_sample = boolq_df.sample(n=sample_size, random_state=45)

correct = 0
results = []

boolq_progress_bar = tqdm(total=sample_size, desc='BoolQ Evaluation')

for i, (idx, row) in enumerate(df_sample.iterrows()):
    question = row['question']
    correct_answer = row['answer']  # True/False
    
    # Create chat format just like in training
    messages = [{'role': 'user', 'content': f"Question: {question}\n\nAnswer with just 'Yes' or 'No':"}]
    response = get_response(messages)  # This will now apply chat template consistently
    
    # Extract Yes/No from response
    response_lower = response.lower()
    if 'yes' in response_lower[:30]:  # Look in first 30 chars
        predicted = True
    elif 'no' in response_lower[:30]:
        predicted = False
    else:
        predicted = None
        
    is_correct = predicted == correct_answer
    if is_correct:
        correct += 1
        
    results.append({
        'question': question,
        'correct_answer': correct_answer,
        'predicted': predicted,
        'is_correct': is_correct,
        'response': response
    })
    if i < 3:
        print(f"Example {i+1}:")
        print(f"  Q: {question}")
        print(f"  Correct: {correct_answer}, Predicted: {predicted}, Right: {is_correct}")
        print(f"  Response: {response}")
        print()
    
    boolq_progress_bar.update()

boolq_progress_bar.close()

accuracy = correct / sample_size

print(f"\n=== BOOLQ EVALUATION RESULTS ===")
print(f"Total questions: {sample_size}")
print(f"Correct answers: {correct}")
print(f"Accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)")