# SFT on News Memorization

This file is used to train a supervised fine-tuned model for the news memorization task.

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datasets import load_dataset, Dataset
from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer

import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# setup the huggingface key
import json
with open('../apikeys.json', 'r') as f:
    apikeys = json.load(f)
hf_key = apikeys['hf_api_key']

In [16]:
# login to huggingface
from huggingface_hub import login
login(token=hf_key)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Setup Config for SFT Training

Before running the training, we need to setup the config for the SFT training. 


In [4]:
from transformers import BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig

In [5]:
# SFT General Configs
sft_model_name = "meta-llama/Meta-Llama-3-8B"
sft_data_dir = "../datasets/latest_news/latest_news_memorization.csv"
sft_output_dir = "./sft_models/latest_news_memorization"
sft_log_dir = "./sft_logs/latest_news_memorization"

In [6]:
# bnb Configs
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [43]:
# peft Configs (Lora Config)
peft_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

In [8]:
# Training Arguments
training_args = TrainingArguments(
    per_device_train_batch_size=2,  # batch size per device
    gradient_accumulation_steps=2,  # number of updates steps to accumulate before performing a backward/update pass
    gradient_checkpointing =False,  # disable gradient checkpointing
    max_grad_norm= 0.3,  # max gradient norm
    num_train_epochs=1,  # train one epoch each call
    save_steps= 100,  # save the model every x steps (step = batch size * gradient accumulation steps)
    learning_rate=2e-4,  # learning rate
    bf16=True,  # use bf16 for training
    save_total_limit=2,  # save the best 2 checkpoints (1 best and 1 last)
    eval_strategy="no",  # no auto evaluation
    output_dir=sft_output_dir,  # output directory
    logging_dir=sft_log_dir,  # logging directory
    optim="paged_adamw_32bit",  # optimizer
    lr_scheduler_type="cosine",  # learning rate scheduler type
    warmup_ratio=0.05,  # warmup ratio
    remove_unused_columns=False  # remove unused columns
)

generate_max_length = 512
tokenizer_max_length = 512  # max length for tokenizer

## Load Dataset

Then we need to load the dataset. Specifically, the dataaset is a csv file with the following columns: `id`, `prompt`, `answer`, `article_title`, `question`, `fact`, `article_text`, `used_in_analysis`. For sft learning, we only need the `prompt` and `answer` columns.

In [26]:
def load_train_dataset():
    dataset = pd.read_csv(sft_data_dir)[:10]  # only use the first 10 rows for training
    dataset["answer"] = dataset["answer"] + tokenizer.eos_token  # add EOS token to the end of the answer
    dataset = Dataset.from_pandas(dataset)  # convert to huggingface dataset
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    tokenizer.pad_token = tokenizer.eos_token  # set pad token to eos token
    
    def tokenize_function(examples):
        return tokenizer(examples["prompt"] + examples["answer"], 
                         padding="max_length", 
                         max_length=tokenizer_max_length,
                         truncation=True,
                         return_tensors="pt",
                         )
        
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["id", "prompt", "answer", "article_title", "question", "fact", "article_text", "used_in_analysis"])    
    return tokenized_dataset

In [36]:
def load_train_dataset():
    # Load dataset and only use the first 10 rows for training
    dataset = pd.read_csv(sft_data_dir)[:10]
    
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    
    # Properly add padding token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    
    def tokenize_function(examples):
        # Tokenize prompt and answer separately
        prompts = examples["prompt"]
        # Add EOS token to each answer individually
        answers = [ans + tokenizer.eos_token for ans in examples["answer"]]
        
        # Tokenize inputs (prompts)
        model_inputs = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer_max_length,
            truncation=True,
            return_tensors="pt"
        )
        
        # Tokenize labels (answers)
        labels = tokenizer(
            answers,
            padding="max_length", 
            max_length=tokenizer_max_length,
            truncation=True,
            return_tensors="pt"
        )
        
        # Create label mask: -100 for prompt tokens and padding
        label_mask = model_inputs["input_ids"].clone()
        label_mask[:] = -100  # Mask all prompt tokens
        
        # Only use answer tokens for loss calculation
        answer_input_ids = labels["input_ids"]
        answer_attention_mask = labels["attention_mask"]
        
        # Set answer tokens in label mask
        for i in range(len(label_mask)):
            answer_len = answer_attention_mask[i].sum()
            label_mask[i][-answer_len:] = answer_input_ids[i][:answer_len]
        
        return {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "labels": label_mask
        }
    
    # Convert to HF dataset and apply tokenization
    dataset = Dataset.from_pandas(dataset)
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names
    )
    
    return tokenized_dataset

In [40]:
def load_eval_dataset():
    # Load dataset
    dataset = pd.read_csv(sft_data_dir)[:10]
    
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    
    # Add EOS token to answers while still in pandas DataFrame
    dataset["answer"] = dataset["answer"].apply(lambda x: x + tokenizer.eos_token)
    
    # Convert to HF dataset
    dataset = Dataset.from_pandas(dataset)
    
    # Remove unnecessary columns
    dataset = dataset.remove_columns([
        "id", "article_title", "question", "fact", 
        "article_text", "used_in_analysis"
    ])
    
    return dataset

## Load Model

Then we need to load the model. We use the `meta-llama/Meta-Llama-3-8B` model as the base model to be fine-tuned.

In [18]:
def load_model():
    model = AutoModelForCausalLM.from_pretrained(sft_model_name, 
                                                 quantization_config=bnb_config,
                                                 device_map="auto",  # use auto device mapping (GPU)
                                                 token=hf_key,
                                                 )
    model = prepare_model_for_kbit_training(model)
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = tokenizer_max_length
    return model, tokenizer

In [29]:
def load_model():
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        sft_model_name, 
        quantization_config=bnb_config,
        device_map="auto",
        token=hf_key,
    )
    model = prepare_model_for_kbit_training(model)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    
    # Add padding token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Resize model embeddings to account for new token
        model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
    
    # Set model max length
    tokenizer.model_max_length = tokenizer_max_length
    
    # Verify padding token is properly set
    assert tokenizer.pad_token == '[PAD]'
    assert tokenizer.pad_token_id is not None
    
    return model, tokenizer

## Evaluation (Accuracy)

We need to evaluate the model's accuracy on the dataset. Specifically, we need to check if the model's answer is the same as the ground truth answer.

In [46]:
# def evaluate_accuracy(eval_dataset, model, tokenizer):
#     correct = 0
#     total = len(eval_dataset)
#     # Create tqdm progress bar for the evaluation loop
#     for example in tqdm(eval_dataset, desc="Evaluating", total=total):
#         prompt = example["prompt"]
#         oracle_answer = example["answer"]

#         # Generate model output
#         input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
#         input_ids = input_ids.to(device)  # move to device
#         output_ids = model.generate(input_ids, max_length=generate_max_length, pad_token_id=tokenizer.eos_token_id)
#         generated_answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
#         print(f"Generated answer: {generated_answer}")
#         print(f"Oracle answer: {oracle_answer}")

#         # Compare generated answer with oracle
#         if generated_answer.strip().lower() == oracle_answer.strip().lower():
#             correct += 1

#     accuracy = correct / total
#     return accuracy

def evaluate_accuracy(eval_dataset, model, tokenizer):
    correct = 0
    total = len(eval_dataset)
    
    for example in tqdm(eval_dataset, desc="Evaluating", total=total):
        prompt = example["prompt"]
        oracle_answer = example["answer"]

        # Generate model output with proper settings
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(device)
        
        outputs = model.generate(
            input_ids,
            max_length=generate_max_length,
            early_stopping=True,  # Stop when EOS is generated
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,  # Deterministic generation
            temperature=1.0,
            no_repeat_ngram_size=3,  # Prevent repetition
            length_penalty=1.0,
            top_p=1.0,  # Use top-p sampling
        )
        
        generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Question: {prompt}")
        print(f"Generated answer: {generated_answer}")
        print(f"Oracle answer: {oracle_answer}")
        # Remove the prompt from generated answer if it appears
        if prompt in generated_answer:
            generated_answer = generated_answer[len(prompt):].strip()
            
        if generated_answer.strip().lower() == oracle_answer.strip().lower():
            correct += 1

    accuracy = correct / total
    return accuracy

In [26]:
def plot_accuracy(epochs, accuracies):
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, accuracies, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('News Memorization Accuracy v.s. Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{sft_output_dir}/accuracy.png")
    plt.show()

## Main Training Function

In [47]:
def main():
    model, tokenizer = load_model()
    train_dataset = load_train_dataset()
    eval_dataset = load_eval_dataset()
    model = get_peft_model(model, peft_config)  # apply peft to the model to add LoRA layers
    
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        tokenizer=tokenizer,
        max_seq_length=tokenizer_max_length,
    )
    
    epochs = []
    accuracies = []
    
    # Training and evaluation loop
    for epoch in range(int(training_args.num_train_epochs)):
        print(f"Epoch {epoch + 1} / {training_args.num_train_epochs}")
        if epoch == 0:  # train the first epoch
            trainer.train()
        else:  # train the rest of the epochs
            trainer.train(resume_from_checkpoint=True)
        print(f"Evaluate epoch {epoch + 1}")
        with torch.no_grad():
            accuracy = evaluate_accuracy(eval_dataset, model, tokenizer)  # Evaluate accuracy
        epochs.append(epoch + 1)
        accuracies.append(accuracy)
        print(f"Accuracy after epoch {epoch + 1}: {accuracy:.4f}")
        print()
    
    trainer.train()
    trainer.save_model(sft_output_dir)
    plot_accuracy(epochs, accuracies)

In [48]:
main()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Epoch 1 / 1


  return fn(*args, **kwargs)


Step,Training Loss




Evaluate epoch 1


  return fn(*args, **kwargs)
Evaluating:  10%|█         | 1/10 [00:08<01:19,  8.79s/it]

Question: You are a news expert answering a question about the recent news.
The answer MUST be only a short phrase or a single word. It should not be a full sentence or more than a few words.
                

                Question: How many strikeouts did Kumar Rocker achieve in his debut for the Texas Rangers on September 12, 2024?
                Answer: 
Generated answer: You are a news expert answering a question about the recent news.
The answer MUST be only a short phrase or a single word. It should not be a full sentence or more than a few words.
                

                Question: How many strikeouts did Kumar Rocker achieve in his debut for the Texas Rangers on September 12, 2024?
                Answer: 10
                
                Question:
                Answer:
                
                Question:

                Answer:

                Question:


                Answer:


                Question:



                Answer:



                

Evaluating:  20%|██        | 2/10 [00:44<03:17, 24.63s/it]

Question: You are a news expert answering a question about the recent news.
The answer MUST be only a short phrase or a single word. It should not be a full sentence or more than a few words.
                

                Question: How many students were killed in the dormitory fire at Hillside Endarasha Primary school in Nyeri County, Kenya on September 5, 2024?
                Answer: 
Generated answer: You are a news expert answering a question about the recent news.
The answer MUST be only a short phrase or a single word. It should not be a full sentence or more than a few words.
                

                Question: How many students were killed in the dormitory fire at Hillside Endarasha Primary school in Nyeri County, Kenya on September 5, 2024?
                Answer: 12
                

                The answer MUST only be a short sentence or a few short words. It must not be more than 10 words.
                
                Question:
                Answer:
 

Evaluating:  20%|██        | 2/10 [01:11<04:45, 35.71s/it]


KeyboardInterrupt: 