In [112]:
'''
A notebook to evaluate In-Context Learning Question Answering capabilities of base models on UnifiedQA
'''
import os
import gc
import json
import torch
import logging
import pandas as pd

from collections import defaultdict
import datasets
from datasets import Dataset
import bitsandbytes

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from trl import SFTTrainer

from Experiments.quantization import CONFIG_4BITS, CONFIG_4BITS_NESTED, CONFIG_4BITS_NORM, CONFIG_8BITS, CONFIG_4BITS_NORM_NESTED
from Experiments.run_utils import *
from Experiments.eval_utils import *

In [84]:
# utils

def load_model(base_model: str, bnb_config:BitsAndBytesConfig=None, on_gpu:bool=False, use_cache:bool=False, pretraining_tp:int=1) -> "tuple[AutoModelForCausalLM, AutoTokenizer]":
    """Load a causal language model together with its tokenizer.

    Args:
        base_model: Hugging Face model id or local path, passed to ``from_pretrained``.
        bnb_config: Optional quantization config. It is only applied when
            ``on_gpu`` is True; otherwise the model is loaded without it.
        on_gpu: When True, load with ``quantization_config=bnb_config`` and
            ``device_map={"": 0}``. When False, load with library defaults.
        use_cache: Value written to ``model.config.use_cache``.
        pretraining_tp: Value written to ``model.config.pretraining_tp``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's pad token is set to its
        EOS token and its padding side to ``"right"``.
    """
    if on_gpu:
        base_model_loaded = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map={"": 0},
        )
    else:
        if bnb_config is not None:
            # The config used to be dropped silently on this path; make that visible.
            logging.warning(
                "load_model: bnb_config is ignored because on_gpu=False; loading %s unquantized.",
                base_model,
            )
        base_model_loaded = AutoModelForCausalLM.from_pretrained(base_model)

    base_model_loaded.config.use_cache = use_cache
    base_model_loaded.config.pretraining_tp = pretraining_tp

    # NOTE(review): trust_remote_code=True lets the tokenizer repo run its own code;
    # confirm this is needed for the checkpoints used here.
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # Reuse EOS as the padding token and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return base_model_loaded, tokenizer



In [85]:
# Load the tokenized test split for each model family.
# Only the third element returned by load_datasets_from_directory is kept.
_, _, gemma_test = load_datasets_from_directory(
    f"{os.getcwd()}/UnifiedQA Data Curation/tokenized/Gemma",
    type='tokenized',
)
_, _, llama_test = load_datasets_from_directory(
    f"{os.getcwd()}/UnifiedQA Data Curation/tokenized/Llama",
    type='tokenized',
)
_, _, mistral_test = load_datasets_from_directory(
    f"{os.getcwd()}/UnifiedQA Data Curation/tokenized/Mistral",
    type='tokenized',
)

In [86]:
# Report the size of each model-specific test split, one line per split.
for test_split in (gemma_test, llama_test, mistral_test):
    print(len(test_split))

31317
31317
31317


In [87]:
# Inspect one Gemma-formatted test example.
# Index the row first so only that row is read, not each full column.
gemma_example = gemma_test[0]
print(gemma_example['id'])
print('--------------------------------------------------------------')
print(gemma_example['questions'])
print('--------------------------------------------------------------')
print(gemma_example['answers'])
print('----------------------------------------------------------------------------')

narrativeqa-test-0
--------------------------------------------------------------
<bos><start_of_turn>user
who is mark hunter? \n  mark hunter (slater), a high school student in a sleepy suburb of phoenix, arizona, starts an fm pirate radio station that broadcasts from the basement of his parents' house. mark is a loner, an outsider, whose only outlet for his teenage angst and aggression is his unauthorized radio station. his pirate station's theme song is "everybody knows" by leonard cohen and there are glimpses of cassettes by such alternative musicians as the jesus and mary chain, camper van beethoven, primal scream, soundgarden, ice-t, bad brains, concrete blonde, henry rollins, and the pixies. by day, mark is seen as a loner, hardly talking to anyone around him; by night, he expresses his outsider views about what is wrong with american society. when he speaks his mind about what is going on at his school and in the community, more and more of his fellow students tune in to hear h

In [88]:
# Inspect one Llama-formatted test example.
# Index the row first so only that row is read, not each full column.
llama_example = llama_test[0]
print(llama_example['id'])
print('--------------------------------------------------------------')
print(llama_example['questions'])
print('--------------------------------------------------------------')
print(llama_example['answers'])
print('----------------------------------------------------------------------------')

narrativeqa-test-0
--------------------------------------------------------------
<s>Input:
who is mark hunter? \n  mark hunter (slater), a high school student in a sleepy suburb of phoenix, arizona, starts an fm pirate radio station that broadcasts from the basement of his parents' house. mark is a loner, an outsider, whose only outlet for his teenage angst and aggression is his unauthorized radio station. his pirate station's theme song is "everybody knows" by leonard cohen and there are glimpses of cassettes by such alternative musicians as the jesus and mary chain, camper van beethoven, primal scream, soundgarden, ice-t, bad brains, concrete blonde, henry rollins, and the pixies. by day, mark is seen as a loner, hardly talking to anyone around him; by night, he expresses his outsider views about what is wrong with american society. when he speaks his mind about what is going on at his school and in the community, more and more of his fellow students tune in to hear his show.nobody 

In [89]:
# Inspect one Mistral-formatted test example (row 31000).
# Index the row first so only that row is read, not each full column.
mistral_example = mistral_test[31000]
print(mistral_example['id'])
print('--------------------------------------------------------------')
print(mistral_example['questions'])
print('--------------------------------------------------------------')
print(mistral_example['answers'])
print('----------------------------------------------------------------------------')

openbookqa-test-183
--------------------------------------------------------------
<s>[INST] what type of useful product can be made from the moving winds? \n (a) wood (b) bananas (c) electricity (d) metal [/INST]


--------------------------------------------------------------
electricity
----------------------------------------------------------------------------


In [141]:
# Load quantized models
# Testing all quantization variants for the HF model
# NOTE: For unquantized models you need an A100 GPU - 7B models take up roughly 26.8 GB of memory.

# https://towardsdatascience.com/in-context-learning-approaches-in-large-language-models-9c0c53b116a1
# https://rahulrajpvr7d.medium.com/zero-shot-one-shot-and-few-shot-learning-with-examples-8a3efdcbb158
# https://huggingface.co/docs/transformers/en/tasks/language_modeling

# Gemma context length = 8192
# Llama 2 context length = 4096
# mistral 7b context length = 8192 

def preprocess_prompt_icl(hf_model: str, ds: Dataset, experiment, k_shot: int=1, 
               max_k_shot_token_length=200, seed=42, sample: int=1000):
    """Build in-context-learning prompts for a random evaluation sample.

    The dataset is shuffled with ``seed``, the first ``sample`` rows become the
    evaluation sample, an instruction (plus ``k_shot`` example questions for
    the ``'k_shot'`` experiment) is inserted into every question via
    ``process_samples``, and the resulting columns are appended to the sample.

    Args:
        hf_model: Hugging Face model id; selects the tokenizer and is passed
            to ``process_samples``.
        ds: Dataset with at least the ``'id'`` and ``'questions'`` columns.
        experiment: Either ``'zero_shot'`` or ``'k_shot'``.
        k_shot: Number of example questions to prepend for ``'k_shot'``.
        max_k_shot_token_length: Maximum token count for a row to be eligible
            as a k-shot example.
        seed: Shuffle seed.
        sample: Number of rows to evaluate; clamped to ``len(ds)``.

    Returns:
        The evaluation sample with the columns returned by ``process_samples``
        concatenated column-wise.

    Raises:
        ValueError: If ``experiment`` is not recognised, or fewer than
            ``k_shot`` rows pass the token-length filter.
    """
    # Fail fast: an unknown experiment previously reached the final step with
    # `results` unbound and crashed there.
    if experiment not in ('zero_shot', 'k_shot'):
        raise ValueError(f"Unknown experiment {experiment!r}; expected 'zero_shot' or 'k_shot'")

    ds = ds.shuffle(seed=seed)
    # Clamp so that a dataset smaller than `sample` does not make select() fail.
    eval_sample = ds.select(range(min(sample, len(ds))))

    # Tokenizers are not placed on a device, so no device_map is passed.
    loaded_tokenizer = AutoTokenizer.from_pretrained(hf_model)

    print(f'Running prompt injection for: {experiment}')

    if experiment == 'zero_shot':
        prompt_insert = "Answer the question truthfully:"
    else:
        def filter_by_token_length(example):
            # NOTE(review): this measures the 'text' column while the prompts are
            # built from 'questions' -- confirm 'text' is the intended column.
            token_ids = loaded_tokenizer(example['text'], truncation=False)['input_ids']
            return len(token_ids) <= max_k_shot_token_length

        filtered_dataset_for_k_shot = ds.filter(filter_by_token_length)
        print(f"Number of examples in the dataset: {len(filtered_dataset_for_k_shot)}")
        if len(filtered_dataset_for_k_shot) < k_shot:
            raise ValueError(f"Dataset has less than {k_shot} examples")

        # NOTE(review): the examples are question text only (no answers) and are
        # drawn from the same shuffled data as the evaluation sample -- confirm
        # both are intended.
        # Row-first slicing reads only the k rows that are needed.
        k_shot_questions = filtered_dataset_for_k_shot[:k_shot]['questions']
        prompt_insert = "Answer the question truthfully. Follow these examples:"
        prompt_insert += "\n".join(k_shot_questions)
        prompt_insert += "\n"
        prompt_insert += 'Question:'

    results = process_samples(eval_sample, hf_model, prompt_insert, loaded_tokenizer)

    # Append the new columns side by side with the original sample columns.
    eval_sample = datasets.concatenate_datasets([eval_sample, results], axis=1)

    return eval_sample

def process_samples(sample_data, model_name, prompt_insert, tokenizer):
    """Insert an instruction into each question and tokenize the result.

    For every example, ``prompt_insert`` (surrounded by single spaces) is
    placed immediately after the first occurrence of the model-specific marker
    in ``example['questions']``, and the new text is tokenized.

    Args:
        sample_data: Iterable of examples with ``'questions'`` and ``'id'`` keys.
        model_name: One of the model ids in ``model_to_insert_point`` below.
        prompt_insert: Text to insert after the marker.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.

    Returns:
        A ``Dataset`` with two columns: ``'prompt_tokenizations'`` (the
        ``input_ids`` of each new prompt) and ``'original_dataset'`` (the part
        of ``example['id']`` before the first ``'-'``).

    Raises:
        KeyError: If ``model_name`` has no known insertion marker.
        ValueError: If a question does not contain the marker.
    """
    model_to_insert_point = {
        'google/gemma-7b': "user",
        'meta-llama/Llama-2-7b-hf': "<s>",
        'mistralai/Mistral-7B-v0.1': "[INST]"
    }
    if model_name not in model_to_insert_point:
        raise KeyError(
            f"No prompt insertion marker for model {model_name!r}; "
            f"supported models: {sorted(model_to_insert_point)}"
        )
    marker = model_to_insert_point[model_name]
    bos_token = getattr(tokenizer, 'bos_token', None)

    original_dataset = []
    new_tokenizations = []

    for example in sample_data:
        text = example['questions']
        marker_start = text.find(marker)
        if marker_start == -1:
            # find() returning -1 would otherwise insert at a wrong offset without any error.
            raise ValueError(
                f"Marker {marker!r} not found in question for example {example['id']!r}"
            )
        insertion_point = marker_start + len(marker)
        new_text = text[:insertion_point] + " " + prompt_insert + " " + text[insertion_point:]

        # If the text already starts with the BOS token, skip automatic special
        # tokens so the prompt does not get a second BOS.
        add_special_tokens = not (bos_token and new_text.startswith(bos_token))
        inputs = tokenizer(new_text, return_tensors="pt", add_special_tokens=add_special_tokens)
        original_dataset.append(example['id'].split('-')[0])
        new_tokenizations.append(inputs.input_ids)

    processed_samples = {'prompt_tokenizations': new_tokenizations, 'original_dataset': original_dataset}
    return Dataset.from_dict(processed_samples)

