# SFT on News Memorization

This notebook performs supervised fine-tuning (SFT) of a base language model — using LoRA adapters on a 4-bit quantized model — for the news memorization task.

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datasets import load_dataset, Dataset
from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer

import matplotlib.pyplot as plt

2024-12-31 01:42:45.937930: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-31 01:42:46.139543: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-31 01:42:46.139580: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-31 01:42:46.176912: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-31 01:42:46.257761: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Set up the Hugging Face access token (the Llama 3 base weights are gated).
# The token is read from a local JSON file kept outside the notebook so it never
# appears in the notebook itself; if that file is missing, fall back to the
# standard `HF_TOKEN` environment variable.
import json
import os

try:
    with open('../apikeys.json', 'r') as f:
        apikeys = json.load(f)
    hf_key = apikeys['hf_api_key']
except FileNotFoundError:
    hf_key = os.environ.get("HF_TOKEN")
    if not hf_key:
        raise RuntimeError(
            "No Hugging Face token found: create ../apikeys.json with an "
            "'hf_api_key' entry or set the HF_TOKEN environment variable."
        )


## Set Up the SFT Training Config

Before running the training, we need to set up the configuration for SFT training.


In [3]:
from transformers import BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig
from trl import DPOConfig

In [4]:
# --- SFT general configuration ---
# Base checkpoint to fine-tune (downloaded with `hf_key`).
sft_model_name = "meta-llama/Meta-Llama-3-8B"
# Path to the memorization CSV (a file path despite the `_dir` suffix).
sft_data_dir = "../datasets/latest_news/latest_news_memorization.csv"
# Output locations: checkpoints/final model, and the Trainer's logging directory.
sft_output_dir, sft_log_dir = (
    "./sft_models/latest_news_memorization",
    "./sft_logs/latest_news_memorization",
)

In [6]:
# bitsandbytes quantization: load the base weights in 4-bit NF4 format and
# use bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [7]:
# LoRA adapter configuration: rank-32 adapters with alpha=16 on the attention
# query / value / key projections; bias terms are not trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj"],
    bias="none",
)

In [8]:
# Training Arguments for the Hugging Face Trainer (used by SFTTrainer).
training_args = TrainingArguments(
    per_device_train_batch_size=2,  # batch size per device
    gradient_accumulation_steps=2,  # micro-batches accumulated per optimizer update (effective per-device batch = 2 * 2)
    gradient_checkpointing=False,  # do not ask the Trainer to enable gradient checkpointing
    max_grad_norm=0.3,  # gradient-norm clipping threshold
    num_train_epochs=100,  # number of training epochs
    save_steps=100,  # save a checkpoint every 100 optimizer steps (one step = one update after accumulation)
    learning_rate=2e-4,  # learning rate
    bf16=True,  # use bf16 for training
    save_total_limit=2,  # keep only the 2 most recent checkpoints (no best-model tracking is configured)
    # No `evaluation_strategy` here: the trainer is built without an `eval_dataset`,
    # and any strategy other than "no" then makes the Trainer raise (or fail at the
    # first evaluation). Memorization accuracy is measured by `evaluate_accuracy`.
    output_dir=sft_output_dir,  # output directory
    logging_dir=sft_log_dir,  # logging directory
    optim="paged_adamw_32bit",  # optimizer
    lr_scheduler_type="cosine",  # learning rate scheduler type
    warmup_ratio=0.05,  # warmup ratio
    remove_unused_columns=False,  # keep dataset columns as-is; the train set is already tokenized
)

generate_max_length = 64  # generation budget (new tokens) used when evaluating accuracy
tokenizer_max_length = 512  # training sequences are padded / truncated to this many tokens

## Load Dataset

Then we need to load the dataset. Specifically, the dataset is a CSV file with the following columns: `id`, `prompt`, `answer`, `article_title`, `question`, `fact`, `article_text`, `used_in_analysis`. For SFT we only need the `prompt` and `answer` columns: each training example is the prompt text immediately followed by its answer.

In [9]:
def load_train_dataset():
    dataset = pd.read_csv(sft_data_dir)
    dataset = Dataset.from_pandas(dataset)  # convert to huggingface dataset
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    tokenizer.pad_token = tokenizer.eos_token  # set pad token to eos token
    
    def tokenize_function(examples):
        return tokenizer(examples["prompt"] + examples["answer"], 
                         padding="max_length", 
                         max_length=tokenizer_max_length,
                         truncation=True,
                         return_tensors="pt",
                         )
        
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["id", "prompt", "answer", "article_title", "question", "fact", "article_text", "used_in_analysis"])    
    return tokenized_dataset

In [None]:
def load_eval_dataset():
    dataset = pd.read_csv(sft_data_dir)
    dataset = Dataset.from_pandas(dataset)  # convert to huggingface dataset
    # remove all the columns except for `prompt`, `answer`
    dataset = dataset.remove_columns(["id", "article_title", "question", "fact", "article_text", "used_in_analysis"])
    return dataset

## Load Model

Then we need to load the model. We use the `meta-llama/Meta-Llama-3-8B` model as the base model to be fine-tuned.

In [10]:
def load_model():
    """Load the 4-bit quantized base model, prepared for k-bit training, and its tokenizer.

    Returns:
        (model, tokenizer): the prepared causal LM (placed with device_map="auto")
        and the matching tokenizer with its pad token set to EOS.
    """
    load_kwargs = dict(
        quantization_config=bnb_config,
        device_map="auto",  # let accelerate place the weights (GPU when available)
        token=hf_key,
    )
    base_model = AutoModelForCausalLM.from_pretrained(sft_model_name, **load_kwargs)
    # NOTE(review): PEFT's prepare_model_for_kbit_training defaults to
    # use_gradient_checkpointing=True, which may enable gradient checkpointing even
    # though training_args sets gradient_checkpointing=False — confirm intent.
    prepared_model = prepare_model_for_kbit_training(base_model)

    tokenizer = AutoTokenizer.from_pretrained(sft_model_name, token=hf_key)
    # Reuse EOS as the padding token (presumably the base tokenizer defines none).
    tokenizer.pad_token = tokenizer.eos_token
    return prepared_model, tokenizer

## Evaluation (Accuracy)

We need to evaluate the model's accuracy on the dataset: an example counts as correct when the model's generated answer matches the ground-truth answer (case-insensitive, ignoring surrounding whitespace).

In [None]:
def evaluate_accuracy(eval_dataset, model, tokenizer, max_new_tokens=None, match_prefix=False):
    """Fraction of examples whose generated answer matches the reference answer.

    For each example the prompt is tokenized and the model decodes greedily. Only
    the newly generated tokens (the continuation after the prompt) are decoded.
    They are compared with `example["answer"]` after stripping whitespace and
    lower-casing both sides.

    Args:
        eval_dataset: sized iterable of mappings with "prompt" and "answer" strings.
        model: causal LM exposing `generate`, `device`, `training`, `eval()`, `train()`.
        tokenizer: tokenizer matching `model`.
        max_new_tokens: number of tokens to generate per example, on top of the
            prompt (so long prompts are never truncated by the budget). Defaults
            to the notebook-level `generate_max_length`.
        match_prefix: if True, an example is correct when the generated text
            *starts with* the answer. This helps when the model keeps generating
            past the answer instead of stopping.

    Returns:
        Accuracy in [0, 1]; 0.0 for an empty dataset.
    """
    total = len(eval_dataset)
    if total == 0:
        return 0.0
    if max_new_tokens is None:
        max_new_tokens = generate_max_length

    was_training = model.training
    model.eval()  # disable dropout (e.g. LoRA dropout) while generating
    correct = 0
    try:
        for example in eval_dataset:
            reference = example["answer"].strip().lower()

            # Keep the attention mask and put the inputs on the model's device.
            inputs = tokenizer(example["prompt"], return_tensors="pt", truncation=True).to(model.device)
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # greedy: memorization should not depend on sampling
                    pad_token_id=tokenizer.pad_token_id,
                )

            # generate() returns prompt + continuation; score only the continuation.
            prompt_len = inputs["input_ids"].shape[1]
            generated = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip().lower()

            is_match = generated.startswith(reference) if match_prefix else generated == reference
            if is_match:
                correct += 1
    finally:
        if was_training:
            model.train()  # restore the caller's mode (e.g. mid-training)

    return correct / total

In [None]:
def plot_accuracy(epochs, accuracies):
    """Plot accuracy against epoch, save it as `accuracy.png` in `sft_output_dir`, and show it.

    Args:
        epochs: x values (epoch numbers).
        accuracies: y values, one accuracy per entry in `epochs`.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, accuracies, label='Accuracy', marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('News Memorization Accuracy v.s. Epochs')
    ax.legend()
    ax.grid(True)

    fig.savefig(f"{sft_output_dir}/accuracy.png")
    plt.show()

## Main Training Function

In [None]:
def main():
    """Fine-tune the LoRA-wrapped model, tracking memorization accuracy per epoch.

    Training runs once for `training_args.num_train_epochs` epochs. A Trainer
    callback measures accuracy on the eval set at the end of every epoch. The
    trained model is saved to `sft_output_dir` and the accuracy curve is plotted.
    """
    from transformers import TrainerCallback

    model, tokenizer = load_model()
    train_dataset = load_train_dataset()
    eval_dataset = load_eval_dataset()
    model = get_peft_model(model, peft_config)  # apply peft to the model to add LoRA layers

    epochs = []
    accuracies = []

    class _EpochAccuracyCallback(TrainerCallback):
        """Record memorization accuracy at the end of each training epoch."""

        def on_epoch_end(self, args, state, control, **kwargs):
            epoch = int(round(state.epoch))  # state.epoch is a float (e.g. 3.0)
            accuracy = evaluate_accuracy(eval_dataset, model, tokenizer)
            epochs.append(epoch)
            accuracies.append(accuracy)
            print(f"Accuracy after epoch {epoch} / {args.num_train_epochs}: {accuracy:.4f}")

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        tokenizer=tokenizer,
        max_seq_length=tokenizer_max_length,
        callbacks=[_EpochAccuracyCallback()],
    )

    # Train exactly once. Each trainer.train() call runs all
    # training_args.num_train_epochs epochs, so calling it in a per-epoch loop
    # multiplies the total amount of training.
    trainer.train()
    trainer.save_model(sft_output_dir)
    plot_accuracy(epochs, accuracies)