In [1]:
import pandas as pd
import numpy as np
import copy

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, Adafactor

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Prompt engineering https://community.openai.com/t/prompt-engineering-for-rag/621495

In [None]:
# Fine-tuning is based on the following:
#    A base model
#    A base tokenizer
#    A set of desired (input, output) pairs
#        Importantly, there is some nuance with how the chat template is applied to the input, output pairs
#        This notebook provides a framework for fine-tuning with a system prompt
#            and a fixed Yes/No question with a fixed Yes/No output.

In [2]:
cache_dir = "../assets/models"
model_path = "meta-llama/Llama-3.2-1B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_path, cache_dir=cache_dir, use_safetensors=True
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path, cache_dir=cache_dir, use_safetensors=True, padding_side="left"
)

tokenizer.pad_token = tokenizer.eos_token

In [3]:
AVeriTeC = pd.read_json('../data/AVeriTeC/train.json')
AVeriTeC = AVeriTeC.rename(columns={"claim": "sentence"})
AVeriTeC['label'] = 'Yes'
AVeriTeC = AVeriTeC.filter(items=['sentence','label'])
AVeriTeC.head()

Unnamed: 0,sentence,label
0,Hunter Biden had no experience in Ukraine or i...,Yes
1,Donald Trump delivered the largest tax cuts in...,Yes
2,"In Nigeria … in terms of revenue share, 20% go...",Yes
3,Biden has pledged to stop border wall construc...,Yes
4,"After the police shooting of Jacob Blake, Gov....",Yes


In [4]:
class BinaryClassificationTuner:
    def __init__(self, model, tokenizer, train_dataset, messages):
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.messages = messages

    def train(self, epochs):
        optimizer = Adafactor(model.parameters(), weight_decay=0.01)
        train_dataset = self._prepare_train_data()
        for train_instance in train_dataset:
            for _ in range(epochs):
                logits = model(train_instance['chat_template_input_ids'], use_cache=False)['logits']
                loss = self._calculate_loss(logits, train_instance['label_input_ids']).mean()
                
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                print("loss: ", loss.item())

        return 
        
    def _prepare_train_data(self):
        """"
        Returns the train dataset as a list of dictionaries, where each is a record with chat
        """
        train_dataset = self.train_dataset.to_dict(orient='records')
        train_dataset_prepared = []

        for train_instance in train_dataset:
            messages = copy.deepcopy(self.messages)
            for message in messages:
                if message['role'] == "user":
                    message['content'] = message['content'].replace('__SENTENCE__',train_instance['sentence'])
                    break
                
            chat_template_input_ids = tokenizer.apply_chat_template(messages, tokenize=True, continue_final_message=True, add_generation_prompt=False, return_tensors="pt")
            chat_template_input_ids = chat_template_input_ids[0, :-1].reshape(1,-1)
            
            label_input_ids = tokenizer(train_instance['label'], add_special_tokens=False, return_tensors="pt", padding="max_length", max_length=chat_template_input_ids.shape[1])['input_ids']
            label_input_ids = torch.where(label_input_ids != tokenizer.pad_token_id, label_input_ids, -100)

            train_dataset_prepared.append({'chat_template_input_ids': chat_template_input_ids,'label_input_ids': label_input_ids})

        return train_dataset_prepared

    def _calculate_loss(self, logits, labels):
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        cross_entropy_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        return cross_entropy_loss

## Demo

In [5]:
messages = [
    {"role": "system", "content": "You are an AI agent used to determine whether or not a sentence is a factual claim. Only respond with Yes or No",},
    {"role": "user", "content": "Is the following sentence a factual claim? __SENTENCE__"},
    {"role": "assistant", "content": ""}
]
bct = BinaryClassificationTuner(model, tokenizer, AVeriTeC.head(2), messages)
bct.train(epochs=2)

[{'role': 'system', 'content': 'You are an AI agent used to determine whether or not a sentence is a factual claim. Only respond with Yes or No'}, {'role': 'user', 'content': 'Is the following sentence a factual claim? __SENTENCE__'}, {'role': 'assistant', 'content': ''}]
loss:  0.006558963563293219
loss:  0.005200252402573824
loss:  3.193731390638277e-05
loss:  0.0


In [None]:
messages = [
    {"role": "system", "content": "You are a yes/no answering bot. Only respond to questions with Yes or No",},
    {"role": "user", "content": "Is the capital of New York state New York City?"},
    {"role": "assistant", "content": ""}
]
answer = "Yes"
chat_template = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)
chat_template_input_ids = tokenizer.apply_chat_template(messages, tokenize=True, continue_final_message=True, add_generation_prompt=False, return_tensors="pt")
chat_template_input_ids = chat_template_input_ids[0, :-1].reshape(1,-1)

label_tokenized = tokenizer([answer], add_special_tokens=False, return_tensors="pt", padding="max_length", max_length=chat_template_input_ids.shape[1])['input_ids']

# -100 comes from the Llama documentation, recommendation for loss
label_tokenized_fixed = torch.where(label_tokenized != tokenizer.pad_token_id, label_tokenized, -100)

# You can use the following to test what the geneartion would complete
#print(tokenizer.batch_decode(model.generate(chat_template_input_ids, max_new_tokens = 1))[0])

In [None]:
# You can use the following to test what the geneartion would complete
# Test a before resposne
print(tokenizer.batch_decode(model.generate(chat_template_input_ids, max_new_tokens = 1))[0])

In [None]:
optimizer = Adafactor(model.parameters(), weight_decay=0.01)

In [None]:
def calculate_loss(logits, labels):
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    cross_entropy_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    return cross_entropy_loss

In [None]:
for _ in range(3):
    logits = model(chat_template_input_ids, use_cache=False)["logits"]
    loss = calculate_loss(logits, label_tokenized_fixed).mean()

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print("loss: ", loss.item())


In [None]:
print(tokenizer.batch_decode(model.generate(chat_template_input_ids, max_new_tokens = 1))[0])