In [1]:
import os

# Pin CUDA device visibility to index 0. This cell runs before the cell that imports torch.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import re
import torch

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

from tqdm import tqdm

from banking_77_constants import banking77_label_map

### Load Model

In [None]:
# Checkpoint to prompt-tune. Swap in the 3B line to run the larger model instead.
# model_name = "meta-llama/Llama-3.2-3B-Instruct"
model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Token IDs passed to generate() as stop tokens: the tokenizer's EOS token
# and the "<|eot_id|>" marker.
eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id, eot_token_id]

### Get Llama reserved special tokens and IDs

In [4]:
# Candidate names of Llama's reserved special tokens, indices 0..250 inclusive.
# NOTE(review): the number of reserved tokens appears to differ between Llama 3.x
# tokenizers — confirm against the loaded checkpoint. A name that is missing from
# the vocabulary would be converted to the tokenizer's unknown-token id (possibly
# None) instead of a real ID, so only names the tokenizer actually defines are kept.
MAX_RESERVED_TOKEN_INDEX = 250

tokenizer_vocab = tokenizer.get_vocab()
prefix_token_strs = [
    token
    for token in (
        f"<|reserved_special_token_{i}|>" for i in range(MAX_RESERVED_TOKEN_INDEX + 1)
    )
    if token in tokenizer_vocab
]

# Same order as prefix_token_strs: prefix_token_ids[k] is the ID of prefix_token_strs[k].
prefix_token_ids = tokenizer.convert_tokens_to_ids(prefix_token_strs)

assert prefix_token_strs, "tokenizer defines no <|reserved_special_token_N|> tokens"

### Load Data

In [None]:
def map_labels(example):
    """Add the string form of a row's numerical label.

    Args:
        example: A dataset row containing a numerical ``label`` field.

    Returns:
        The same row object, with a ``label_str`` field set to the entry of
        ``banking77_label_map`` for that label.
    """
    label_id = example['label']
    example['label_str'] = banking77_label_map[label_id]
    return example

# Fetch the PolyAI/banking77 dataset; its "train" and "test" splits are used below.
dataset = load_dataset("PolyAI/banking77")

# For each split: attach the `label_str` column, then shuffle with a fixed seed.
train_dataset = dataset["train"].map(map_labels).shuffle(seed=42)
test_dataset = dataset["test"].map(map_labels).shuffle(seed=42)

# Preview the first two rows of the processed training split.
train_dataset[:2]

### Define Prompt Template

In [6]:
# One class name per line, in the iteration order of banking77_label_map.
class_listing = "\n".join(banking77_label_map.values())

# The class list is filled in now. The doubled braces around `text` survive this
# .format() call as a single-brace `{text}` placeholder, so the user query can be
# inserted later with prompt_template.format(text=...).
prompt_template = """## Instructions
Classify the provided piece of text into one of the predefined classes.

## Classes
{classes}

## Output Format
Provide your answer in <answer></answer> XML tags. Output the xml tags and answer only.

## Input Text
{{text}}

## Answer""".format(classes=class_listing)

# Uncomment below to view prompt
# print(prompt_template)

In [7]:
# We define an answer parser that extracts the answer based on the format the prompt template defines
def parse_value_from_xml_with_regex(xml_string, tag_name):
    """Return the text enclosed by the first ``<tag_name>...</tag_name>`` pair.

    Args:
        xml_string: Text to search, e.g. a decoded model response.
        tag_name: Tag name without angle brackets. It is matched literally,
            so regex metacharacters in the name carry no special meaning.

    Returns:
        The enclosed text exactly as it appears (it may span multiple lines
        and is not stripped), or "" when no complete tag pair is found.
    """
    # Escape the tag so names such as "a.b" or "a+b" match only themselves.
    tag = re.escape(tag_name)

    # Non-greedy capture stops at the first closing tag; DOTALL lets the
    # captured text contain newlines.
    match = re.search(f'<{tag}>(.*?)</{tag}>', xml_string, re.DOTALL)

    return match.group(1) if match else ""

assert parse_value_from_xml_with_regex("<answer>foo</answer>", "answer") == "foo"

### Create test dataset with prompt template
This is the baseline dataset, so we do not add a soft prefix prompt. See below for the version that adds the prefix tokens.

In [None]:
def create_test_messages(row):
    """Build the prefix-free chat prompt for one dataset row.

    Args:
        row: A dataset row with a ``text`` field holding the user query.

    Returns:
        A dict with a single ``messages`` key: a system turn followed by a
        user turn whose content is ``prompt_template`` filled with the query.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": prompt_template.format(text=row["text"])}
    return {"messages": [system_turn, user_turn]}

# Adds a `messages` column to the test split.
test_dataset_no_prefix = test_dataset.map(create_test_messages)

test_dataset_no_prefix[:2]

## Baseline

In [None]:
# Run the non-finetuned model on a single test row; change i to inspect another row.

i = 0
row = test_dataset_no_prefix[i]
user_query, messages, golden_answer = row["text"], row["messages"], row["label_str"]

# Render the chat messages to token IDs, ending with the assistant header
# so the model continues with its answer.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Sampling with a near-zero temperature and top_p=0 — presumably to make
# decoding effectively deterministic.
outputs = model.generate(
    input_ids,
    max_new_tokens=32,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.001,
    top_p=0,
    pad_token_id=tokenizer.eos_token_id,
)

# Keep only the newly generated tokens (everything after the prompt).
response = outputs[0][input_ids.shape[-1]:]
uncleaned_response = tokenizer.decode(response, skip_special_tokens=False)
cleaned_response = tokenizer.decode(response, skip_special_tokens=True)
parsed_answer = parse_value_from_xml_with_regex(cleaned_response, "answer")

# Each section is a heading line, the value, and (except the last) a blank line.
print("User query:", user_query, "", sep="\n")
print("Model response (with special tokens):", uncleaned_response, "", sep="\n")
print("Parsed model response (without tokens):", parsed_answer, "", sep="\n")
print("Correct answer:", golden_answer, sep="\n")

In [10]:
# Baseline accuracy on a sample of the test dataset: only the first 300 rows are scored.
# To score every row, iterate over range(len(test_dataset_no_prefix)) instead.

pred_ls, golden_ls = [], []

for i in tqdm(range(300)):
    row = test_dataset_no_prefix[i]

    input_ids = tokenizer.apply_chat_template(
        row["messages"],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=128,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.001,
        top_p=0,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated after the prompt, then pull out the <answer> text.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=False)
    pred = parse_value_from_xml_with_regex(response, "answer")

    pred_ls.append(pred)
    golden_ls.append(row["label_str"])

# A prediction counts as correct only on an exact string match with the gold label.
num_total = len(pred_ls)
num_correct = sum(1 for pred, gold in zip(pred_ls, golden_ls) if pred == gold)

accuracy = num_correct / num_total
print(f"Accuracy: {accuracy}")

## Train Model

In [11]:
# Setting hyperparameters

# Number of reserved special tokens whose embeddings form the trainable soft prefix.
NUM_SPECIAL_TOKENS_IN_PREFIX = 32
LEARNING_RATE = 2e-4
BATCH_SIZE = 4
WARMUP_RATIO = 0.1
# Weight decay is disabled on purpose. The whole embedding matrix is the trainable
# parameter, and the gradient hook only zeroes *gradients* of non-prefix rows.
# AdamW's decoupled weight decay scales every entry of a trainable parameter on each
# step regardless of its gradient, so a non-zero value would slowly shrink all
# non-prefix embeddings — contradicting the later check that they are unchanged.
WEIGHT_DECAY = 0.0

In [None]:
# The soft prefix: the first n reserved special token strings, concatenated.
prefix = "".join(prefix_token_strs[:NUM_SPECIAL_TOKENS_IN_PREFIX])


def create_prefix_messages_no_answer(row):
    """Build the prefixed chat prompt (system + user turns) for one row.

    Used for the test set, where the model must produce the answer itself.

    Args:
        row: A dataset row with a ``text`` field holding the user query.

    Returns:
        A dict with a ``messages`` list whose user content is ``prefix``
        followed by the filled-in ``prompt_template``.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": prefix + prompt_template.format(text=row["text"])}
    return {"messages": [system_turn, user_turn]}


def create_prefix_messages(row):
    """Build the prefixed chat prompt plus the gold assistant answer for one row.

    Used for the training set: same system and user turns as
    ``create_prefix_messages_no_answer``, followed by an assistant turn that
    wraps ``row["label_str"]`` in <answer></answer> tags.

    Args:
        row: A dataset row with ``text`` and ``label_str`` fields.

    Returns:
        A dict with a ``messages`` list of three turns (system, user, assistant).
    """
    conversation = create_prefix_messages_no_answer(row)["messages"]
    conversation.append(
        {"role": "assistant", "content": f"<answer>{row['label_str']}</answer>"}
    )
    return {"messages": conversation}


# Training rows include the answer; the prefixed test rows do not.
train_dataset = train_dataset.map(create_prefix_messages)
test_dataset_with_prefix = test_dataset.map(create_prefix_messages_no_answer)

train_dataset[:2]

In [13]:
# Freeze all parameters except the embedding layer
for param in model.parameters():
    param.requires_grad = False

embedding_weight = model.get_input_embeddings().weight
embedding_weight.requires_grad = True

# Token IDs of the prefix tokens: the only embedding rows whose gradient is kept.
# Placed on the embedding matrix's own device so they can index the mask below.
embeddings_to_update = torch.tensor(
    prefix_token_ids[:NUM_SPECIAL_TOKENS_IN_PREFIX], dtype=torch.long
).to(embedding_weight.device)

# Row mask built once, outside the hook: 1.0 for prefix-token rows, 0.0 elsewhere.
# Its shape is (num_embedding_rows, 1), so it broadcasts across the remaining
# dimension of the gradient without allocating a full-size mask on every backward pass.
trainable_row_mask = torch.zeros(
    embedding_weight.shape[0],
    1,
    dtype=embedding_weight.dtype,
    device=embedding_weight.device,
)
trainable_row_mask[embeddings_to_update] = 1.0

def grad_hook(grad):
    """Zero out the gradient of every embedding row that is not a prefix token.

    Args:
        grad: Gradient of the input embedding weight.

    Returns:
        A tensor shaped like ``grad`` in which only the rows listed in
        ``embeddings_to_update`` keep their values; all other rows are zero.
    """
    # .to() is a no-op when the mask already matches the gradient's device and dtype.
    return grad * trainable_row_mask.to(device=grad.device, dtype=grad.dtype)

# Add the hook to zero out non-special token gradients
hook_handle = model.get_input_embeddings().weight.register_hook(grad_hook)

In [None]:
# Only train on completion tokens: the collator is given the assistant header
# as the marker for where the response begins.
response_template = "<|start_header_id|>assistant<|end_header_id|>"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir="outputs",
    seed=3407,
    num_train_epochs=1,  # Set this for 1 full training run.
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_8bit",
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    fp16=False,  # switch these depending on whether your GPU supports BF16
    bf16=True,
    logging_steps=16,
)

trainer = SFTTrainer(
    model,
    train_dataset=train_dataset,
    data_collator=collator,
    args=training_args,
)

In [None]:
trainer.train()

# Training is finished: detach the gradient-masking hook from the embedding weight.
hook_handle.remove()

# Put the model in eval mode before the cells below call generate().
# NOTE(review): the trainer is expected to leave the model in training mode, and
# gradient checkpointing was enabled for training — confirm for the installed version.
# The trailing semicolon keeps the notebook from printing the model repr.
model.eval();

In [None]:
# Running on sample of test dataset; this time with the newly trained prefix.
# Prompts and gold labels are both read from test_dataset_with_prefix, so they
# come from the same row rather than relying on two datasets sharing an order.

pred_ls, golden_ls = [], []
num_correct, num_total = 0, 0

# for i in tqdm(range(len(test_dataset_with_prefix))):
for i in tqdm(range(300)):
    # Fetch the row once instead of indexing the dataset for every field.
    row = test_dataset_with_prefix[i]
    messages = row["messages"]
    golden = row["label_str"]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=128,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.001,
        top_p=0,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated after the prompt, then extract the <answer> text.
    response = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(response, skip_special_tokens=False)
    pred = parse_value_from_xml_with_regex(response, "answer")

    pred_ls.append(pred)
    golden_ls.append(golden)

    # Exact string match against the gold label.
    if pred == golden:
        num_correct += 1
    num_total += 1

accuracy = num_correct / num_total
print(f"Accuracy: {accuracy}")

### Other Benchmarks

| Num Prefix Tokens | 16    | 32    | 64    |
| :---------------- | :---: | :---: | :---: |
| Llama 3B          | 0.79  | 0.83  | 0.8266|
| Llama 1B          | 0.6466| 0.6766| 0.7333|

In [None]:
# Re-run the prefix-free test prompts (first 300 rows) to confirm that the
# non-prefix weights have not been changed by training.
# To score every row, iterate over range(len(test_dataset_no_prefix)) instead.

pred_ls, golden_ls = [], []
num_correct, num_total = 0, 0

for i in tqdm(range(300)):
    row = test_dataset_no_prefix[i]
    gold_label = row["label_str"]

    input_ids = tokenizer.apply_chat_template(
        row["messages"],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=128,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.001,
        top_p=0,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Slice off the prompt so only the generated continuation is decoded.
    prompt_length = input_ids.shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)
    pred = parse_value_from_xml_with_regex(response, "answer")

    pred_ls.append(pred)
    golden_ls.append(gold_label)

    num_total += 1
    if pred == gold_label:
        num_correct += 1

accuracy = num_correct / num_total
print(f"Accuracy: {accuracy}")