# Fine-tune Llama 2 on iSarcasmEval dataset
### This notebook is  inspired by [Fine-tune Llama 2 for sentiment analysis](https://www.kaggle.com/code/lucamassaron/fine-tune-llama-2-for-sentiment-analysis/notebook) fine-tune-llama-2-for-sentiment-analysis by **Luca Massaron** and [Fine-Tuning LLaMA 2](https://www.datacamp.com/tutorial/fine-tuning-llama-2)

In [1]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from datasets import Dataset
from peft import LoraConfig, PeftConfig
from trl import SFTTrainer
from trl import setup_chat_format
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          TrainingArguments,
                          pipeline,
                          logging)
from sklearn.metrics import (accuracy_score,
                             classification_report,
                             confusion_matrix)
from sklearn.model_selection import train_test_split

## Preparing Data

In [3]:
filename ='iSarcasmEval/train/train.En.csv'
X = pd.read_csv(filename,
                 usecols=["tweet", "sarcastic"],
                 encoding="utf-8", encoding_errors="replace")

def generate_prompt(data_point):
    return f"""
            Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,
            and return the answer as the corresponding label "1" for sarcastic or "0" for non-sarcastic.

            [{data_point["tweet"]}] = {data_point["sarcastic"]}
            """.strip()

X = pd.DataFrame(X.apply(generate_prompt, axis=1),
                       columns=["text"])

X = X.sample(frac=1)

train_size = int(len(X)*0.8)
val_size = len(X)-train_size

X_train = X[:train_size]
X_val = X[train_size:]

In [4]:
X_train

Unnamed: 0,text
442,Determine the whether the tweet enclosed in sq...
2132,Determine the whether the tweet enclosed in sq...
917,Determine the whether the tweet enclosed in sq...
322,Determine the whether the tweet enclosed in sq...
876,Determine the whether the tweet enclosed in sq...
...,...
3416,Determine the whether the tweet enclosed in sq...
1001,Determine the whether the tweet enclosed in sq...
411,Determine the whether the tweet enclosed in sq...
3442,Determine the whether the tweet enclosed in sq...


In [5]:
X_val

Unnamed: 0,text
370,Determine the whether the tweet enclosed in sq...
314,Determine the whether the tweet enclosed in sq...
3319,Determine the whether the tweet enclosed in sq...
3403,Determine the whether the tweet enclosed in sq...
2996,Determine the whether the tweet enclosed in sq...
...,...
1595,Determine the whether the tweet enclosed in sq...
2147,Determine the whether the tweet enclosed in sq...
2842,Determine the whether the tweet enclosed in sq...
3088,Determine the whether the tweet enclosed in sq...


In [6]:
test_filename ='iSarcasmEval/test/task_A_En_test.csv'

X_test = pd.read_csv(test_filename,
                 usecols=["text", "sarcastic"],
                #  names = ["tweet", "sarcastic"],
                 encoding="utf-8", encoding_errors="replace")

X_test = X_test.rename(columns={'text':'tweet'})

X_test = X_test.sample(frac=1)

def generate_test_prompt(data_point):
    return f"""
            Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,
            and return the answer as the corresponding label "1" for sarcastic or "0" for non-sarcastic.

            [{data_point["tweet"]}] = """.strip()


y_test = X_test['sarcastic']
X_test = pd.DataFrame(X_test.apply(generate_test_prompt, axis=1), columns=["text"])

In [7]:
X_test

Unnamed: 0,text
1351,Determine the whether the tweet enclosed in sq...
774,Determine the whether the tweet enclosed in sq...
247,Determine the whether the tweet enclosed in sq...
981,Determine the whether the tweet enclosed in sq...
90,Determine the whether the tweet enclosed in sq...
...,...
708,Determine the whether the tweet enclosed in sq...
614,Determine the whether the tweet enclosed in sq...
284,Determine the whether the tweet enclosed in sq...
67,Determine the whether the tweet enclosed in sq...


## Functions

In [46]:
def evaluate(y_true, y_pred):
    # Calculate accuracy
    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
    print(f'Accuracy: {accuracy:.3f}')

    # Generate accuracy report
    unique_labels = set(y_true)  # Get unique labels

    for label in unique_labels:
        label_indices = [i for i in range(len(y_true))
                         if y_true[i] == label]
        label_y_true = [y_true[i] for i in label_indices]
        label_y_pred = [y_pred[i] for i in label_indices]
        accuracy = accuracy_score(label_y_true, label_y_pred)
        print(f'Accuracy for label {label}: {accuracy:.3f}')

    # Generate classification report
    class_report = classification_report(y_true=y_true, y_pred=y_pred)
    print('\nClassification Report:')
    print(class_report)

    # Generate confusion matrix
    conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1])
    print('\nConfusion Matrix:')
    print(conf_matrix)
    return [class_report, conf_matrix]

In [9]:
compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)


In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"working on {device}")

working on cuda:0


In [11]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf",trust_remote_code=True,)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             device_map=device,
                                              torch_dtype=compute_dtype,
                                              quantization_config=bnb_config,)

model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model, tokenizer = setup_chat_format(model, tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [123]:
def predict(test, model, tokenizer):
    y_pred = []
    for i in tqdm(range(len(test))):
        prompt = test.iloc[i]["text"]
        pipe = pipeline(task="text-generation",
                        model=model,
                        tokenizer=tokenizer,
                        max_new_tokens=2
                       )
        result = pipe(prompt)
        answer = result[0]['generated_text'].split("=")[-1]
        if "1" in answer:
            y_pred.append(1)
        elif "0" in answer:
            y_pred.append(0)
        else:
            y_pred.append(0)
    return y_pred

## Fine-tuning

In [13]:
train_data = Dataset.from_pandas(X_train)
eval_data = Dataset.from_pandas(X_val)

In [14]:
output_dir="trained_weigths"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

training_arguments = TrainingArguments(
    output_dir=output_dir,                    # directory to save and repository id
    num_train_epochs=4,                       # number of training epochs
    per_device_train_batch_size=4,            # batch size per device during training
    gradient_accumulation_steps=1,            # number of steps before performing a backward/update pass
    gradient_checkpointing=True,              # use gradient checkpointing to save memory
    optim="paged_adamw_32bit",
    save_steps=0,
    logging_steps=25,                         # log every 10 steps
    learning_rate=2e-4,                       # learning rate, based on QLoRA paper
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,                        # max gradient norm based on QLoRA paper
    max_steps=-1,
    warmup_ratio=0.03,                        # warmup ratio based on QLoRA paper
    group_by_length=True,
    lr_scheduler_type="cosine",               # use cosine learning rate scheduler
    report_to="tensorboard",                  # report metrics to tensorboard
    evaluation_strategy="epoch"               # save checkpoint every epoch
)



In [18]:
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=1024,
    packing=False,
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": False,
    }
)

Map:   0%|          | 0/2774 [00:00<?, ? examples/s]

Map:   0%|          | 0/694 [00:00<?, ? examples/s]

In [19]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,1.1476,0.986571
2,0.9367,0.988708
3,0.8009,1.012863
4,0.56,1.040632


TrainOutput(global_step=2776, training_loss=0.8819873001136594, metrics={'train_runtime': 2297.7443, 'train_samples_per_second': 4.829, 'train_steps_per_second': 1.208, 'total_flos': 4.127955442434048e+16, 'train_loss': 0.8819873001136594, 'epoch': 4.0})

## Evaluate

In [171]:
y_pred = predict(X_test, model, tokenizer)

100%|██████████████████████████████████████████████████████████████████████████████| 1400/1400 [06:19<00:00,  3.69it/s]


In [172]:
y_true = y_test.values

In [173]:
report = classification_report(y_true, y_pred)

In [174]:
from io import StringIO

report_str = '              precision    recall  f1-score   support\n\n           0       0.96      0.82      0.89      1200\n           1       0.43      0.79      0.55       200\n\n    accuracy                           0.82      1400\n   macro avg       0.69      0.81      0.72      1400\nweighted avg       0.88      0.82      0.84      1400\n'


report_df = pd.read_csv(StringIO(report_str), sep='\s{2,}', engine='python', index_col=0)




In [175]:
report_df

Unnamed: 0,precision,recall,f1-score,support
0,0.96,0.82,0.89,1200.0
1,0.43,0.79,0.55,200.0
accuracy,0.82,1400.0,,
macro avg,0.69,0.81,0.72,1400.0
weighted avg,0.88,0.82,0.84,1400.0


In [169]:
c = pd.read_csv('FN.csv')
c

Unnamed: 0.1,Unnamed: 0,text
0,0,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [I've just had the BEST day ever sorting out my referencing and annotations (-,-)...zzzZZZ] ="
1,5,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [I woke this morning to another Monday of fine bright warm sunshine here in Norfolk. Just what you need to put you in the right frame of mind for the week.] ="
2,75,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Love how politicians are exempt from social distancing and all other restrictions placed on the general public ] ="
3,76,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Sure, vaccines make no difference to Covid spread; that's why the hospitals are full of unvaccinated people. Are you able to do the maths?] ="
4,77,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Boris Johnson is not a clown. The conservative party is not corrupt.] ="
5,83,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Top 10 pools in my book: 1. Swimming pool 2. Paddling pool 3. Above-ground pool 4. Family pool 5. Architectural pool 6. Indoor pool 7. Lap pool 8. Olympic size pool 9. Natural pool 10. Salt water pool Sorry Liverpool you are not top 10 pools in my book 😭😭😭] ="
6,149,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Really happy that the weather has stayed like this for the whole weekend] ="
7,150,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Hi John, thank you SO MUCH :-) for the text last night. ordering (!) me back to work off holiday for this morning as someone else is sick. Sooooo Sorry :-( that I am in Beijing with an 8 hours time difference and couldn't make it! I'll try harder next time.........] ="
8,156,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [I think women should be able to join men's sports teams and vice versa. I really don't see any problem with for example having a 5 foot 3 inches slim women versus a 300 pound heavyweight boxer because i feel like women are very much equal to any man. Even much faster and stronger men. If anybody takes offense to my point then they need to reevaluate their lives as their are no man that could do anything better than a lady.] ="
9,178,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [The irony of being asked to write an ""immaginary"" sarcastic tweet with ""imaginary"" incorrectly spelled...] ="


In [170]:
d = pd.read_csv('FP.csv')
d

Unnamed: 0.1,Unnamed: 0,text
0,8,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Salah and Mane can’t get a goal today, nobody knows knows how Minamino or Origi would be able to get one] ="
1,9,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [It could dislodge this gang of lethal crooks so for the minute it's hot.] ="
2,11,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [They been ON US since the layover lol] ="
3,16,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [That purple patch must be continuing for Larne players. Tiernan hoping it ends soon] ="
4,18,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [And now doubled with further rises in the pipeline.] ="
...,...,...
207,1366,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [My language matching our performance.\r\nOne more outburst and I'm have an early bath] ="
208,1370,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Yes, I am waiting for the bus, thank you.] ="
209,1382,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [He bottled throwing the headband on the floor again 😂] ="
210,1386,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [About time that Number 10 cancelled their Christmas Party this year after last year's frivolities when in lockdown.] ="


In [176]:
np.sum(y_pred == y_true)

1152

In [178]:
len(y_pred)

1400

In [152]:
def get_mismatch_csv(y_true,y_pred,real_label,data,file_name):
    mismatch_indices = np.where((y_true != y_pred) & (y_true == real_label))[0]
    
    data.index = range(0, len(data))
    
    mismatched_samples = data.iloc[mismatch_indices]
    mismatched_samples.to_csv(file_name)

In [153]:
X = pd.read_csv(filename,
                 usecols=["tweet", "sarcastic"],
                 encoding="utf-8", encoding_errors="replace")


In [154]:
y_true = X.iloc[X_train.index]['sarcastic']

In [155]:
len(y_true)

2774

In [156]:
y_pred  = predict(X_train, model, tokenizer)

100%|██████████████████████████████████████████████████████████████████████████████| 2774/2774 [12:30<00:00,  3.70it/s]


In [157]:
len(y_pred)

2774

In [158]:
get_mismatch_csv(y_true,y_pred,0,X_train,'FP_train.csv')

In [159]:
get_mismatch_csv(y_true,y_pred,1,X_train,'FN_train.csv')

In [160]:
report = classification_report(y_true, y_pred)
report

'              precision    recall  f1-score   support\n\n           0       0.69      0.87      0.77      1907\n           1       0.31      0.13      0.18       867\n\n    accuracy                           0.64      2774\n   macro avg       0.50      0.50      0.47      2774\nweighted avg       0.57      0.64      0.58      2774\n'

In [161]:
from io import StringIO

report_str = report


report_df = pd.read_csv(StringIO(report_str), sep='\s{2,}', engine='python', index_col=0)

report_df 


Unnamed: 0,precision,recall,f1-score,support
0,0.69,0.87,0.77,1907.0
1,0.31,0.13,0.18,867.0
accuracy,0.64,2774.0,,
macro avg,0.5,0.5,0.47,2774.0
weighted avg,0.57,0.64,0.58,2774.0


In [162]:
np.sum(y_pred == y_true)

1766

In [163]:
a = pd.read_csv('FN_train.csv')

In [164]:
b = pd.read_csv('FP_train.csv')

In [165]:
a

Unnamed: 0.1,Unnamed: 0,text
0,0,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [First day of school and spectrum internet has already shit itself. Amazing] = 1"
1,1,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Beautiful exhibition of origami peace cranes at Edinburgh's St John's Church. https://t.co/XCwDzMLjsz] = 0"
2,2,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Home made pizza is in the oven! We've made one each so there'll be leftovers for lunch tomorrow. 😋] = 0"
3,4,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Why do I have a doctorate and miss the restaurant industry SO MUCH] = 0"
4,5,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [going to class! https://t.co/VgCWGl9YTG] = 0"
...,...,...
750,861,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [how am i doing? well if you must know, i’m watching all the vids of harry singing fine line at his vegas concert and crying bc i can’t see him at american airlines.] = 0"
751,863,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [talking about the capitol all day like i'm in the hunger games or something] = 0"
752,864,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Dan Smith playing oblivion on his IG live has absolutely topped my week off 🥺] = 0"
753,865,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [settings is such an interesting place] = 0"


In [166]:
b

Unnamed: 0.1,Unnamed: 0,text
0,874,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Not telling anyone how I voted in case it doesn't come true #EUref] = 1"
1,886,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Woke up at 3 am with a MAD craving for pickles. I guess that means I'm expecting!] = 1"
2,891,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Sitting next to the world's loudest mouth-breather on the first day of class did a really good job of demystifying grad school for me today] = 1"
3,892,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [I'd like to thank middle aged men watching the Olympics for solving the mental health crisis. Getting someone to shout ""get some perspective"" at every athlete struggling with anxiety will definitely solve their problem.] = 1"
4,897,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Yay - snow day, I can work from home - oh wait....] = 1"
...,...,...
248,2752,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [oh sure, the mic’d up pitcher at the MLB all-star game can say “god dammit” but when my 13 year old softball player says it she gets a “language warning” 😂] = 1"
249,2753,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [always remember thank your Alexa, just in case the AI take over one day] = 1"
250,2755,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [My gf has been in the Hamptons nannying for a few hours and she has already been to a POLO match. Is she in GOSSIP GIRL?] = 1"
251,2764,"Determine the whether the tweet enclosed in square brackets is sarcastic or non-sarcastic,\n and return the answer as the corresponding label ""1"" for sarcastic or ""0"" for non-sarcastic.\n\n [Overthinking everything in life really takes a toll on you] = 1"


In [167]:
len(np.where(y_true != y_pred)[0])

1008