# Perplexity analysis of Llama-2-13B-Chat when fine-tuned on separate AITA class partitions of the Reddit AITA dataset

In [None]:
!pip install transformers datasets tqdm accelerate bitsandbytes torch evaluate

In [None]:
from huggingface_hub import login

# Authenticate this session with the Hugging Face Hub.
# NOTE(review): presumably needed to download the dataset/model repos used
# in the later cells — confirm whether they are gated or private.
login()

In [None]:
from datasets import load_dataset

dataset = load_dataset("MattBoraske/Reddit-AITA-2018-to-2022")
test_dataset = dataset['test']

# Split the test set into one subset per AITA verdict, keyed by the verdict
# label found in the top comment's keyword-derived class column.
test_datasets = {
    verdict: test_dataset.filter(
        # Bind the verdict as a default so each predicate keeps its own label.
        lambda x, verdict=verdict: x['top_comment_1_AITA_class_by_keyword'] == verdict
    )
    for verdict in ('NTA', 'YTA', 'ESH', 'NAH')
}


def create_input_text(example):
    """Add an ``input_text`` field combining the submission title and body.

    The title and the text are joined with a single space. The given
    ``example`` mapping is updated in place and returned.
    """
    pieces = (example['submission_title'], example['submission_text'])
    example['input_text'] = " ".join(pieces)
    return example

# Give every per-class split the combined title + body `input_text` column.
test_datasets = {
    label: split.map(create_input_text)
    for label, split in test_datasets.items()
}

def get_llama2_training_instruction(sample):
    """Format a sample as a single Llama-2 ``[INST]`` instruction string.

    The sample's ``input_text`` is placed between ``[INST]`` and ``[/INST]``
    and its ``top_comment_1`` follows as the response, all wrapped in
    ``<s>`` / ``</s>``. Returns a one-key dict holding the string under
    ``llama2_instruction``.
    """
    prompt = sample['input_text']
    response = sample['top_comment_1']
    return {"llama2_instruction": f"<s>[INST] {prompt} [/INST] {response} </s>"}

# Add the formatted `llama2_instruction` column to every per-class split.
test_datasets = {
    label: split.map(get_llama2_training_instruction)
    for label, split in test_datasets.items()
}

def drop_columns(dataset, columns):
    """Return ``dataset`` with the given ``columns`` removed.

    Delegates to ``dataset.remove_columns(columns)`` and hands back
    whatever that call returns.
    """
    trimmed = dataset.remove_columns(columns)
    return trimmed

# Every column except `llama2_instruction`: the submission metadata, the ten
# top comments, and the keyword-derived AITA class of each of those comments.
columns_to_drop = [
    'input_text',
    'submission_title',
    'submission_text',
    'submission_score',
    'submission_url',
    'submission_date',
    *(f'top_comment_{rank}' for rank in range(1, 11)),
    *(f'top_comment_{rank}_AITA_class_by_keyword' for rank in range(1, 11)),
]

# Strip the unneeded columns from each per-class split, then keep only the
# `llama2_instruction` column values for that class.
test_datasets = {
    label: drop_columns(split, columns_to_drop)['llama2_instruction']
    for label, split in test_datasets.items()
}

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub repository of the checkpoint under evaluation; the tokenizer and the
# model weights are both loaded from it.
model_id = "MattBoraske/llama-2-13b-chat-reddit-AITA-NAH"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.cuda()  # move the weights onto the GPU

In [None]:
import torch
from tqdm import tqdm

# Per-class perplexity of the loaded model on the formatted test samples.
#
# Each sample is scored on its own (batch size 1). The loss the model returns
# when the inputs are also passed as labels is used as that sample's negative
# log-likelihood. A class's score is exp(mean of its per-sample losses), so
# every sample carries equal weight regardless of how many tokens it has.
#
# `avg_perplexities` ends up as a list with one `{AITA_class: perplexity}`
# dict (perplexity as a plain float) per class that had at least one sample.

# Token limit applied when tokenizing.
# NOTE(review): presumably chosen to match the model's context window — confirm.
MAX_SEQ_LEN = 4096

avg_perplexities = []

# Send inputs to wherever the model's weights live rather than hard-coding CUDA.
device = model.device

for AITA_class, test_data in test_datasets.items():
    nlls = []

    for sample in tqdm(test_data, desc=AITA_class):
        # `max_length` on its own does not shorten anything; `truncation=True`
        # is what actually cuts over-long posts down to MAX_SEQ_LEN tokens.
        input_ids = tokenizer(
            sample,
            max_length=MAX_SEQ_LEN,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)

        with torch.no_grad():
            # Labels are the inputs themselves, so the returned loss covers the
            # whole prompt + response string.
            outputs = model(input_ids, labels=input_ids)

        # Keep the loss in float32 so the averaging below is not done in a
        # reduced-precision dtype if the model was loaded that way.
        nlls.append(outputs.loss.float())

    if not nlls:
        # torch.stack() raises on an empty list; report and move on instead.
        print(f"No samples for {AITA_class}; skipping perplexity computation")
        continue

    ppl = torch.exp(torch.stack(nlls).mean()).item()
    print(f"Avg Perplexity for {AITA_class} samples: {ppl}")
    # Record as a mapping; the original `{AITA_class, ppl}` built an unordered
    # set of a string and a tensor, losing the class -> score association.
    avg_perplexities.append({AITA_class: ppl})
