In [1]:
from datasets import load_dataset
dataset = load_dataset("tau/commonsense_qa")

In [2]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
base_model = "NousResearch/Llama-2-7b-chat-hf"
compute_dtype = getattr(torch, "float16")

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)


# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [3]:
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quant_config,
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [4]:
pipe = pipeline(task="text-generation", model=base_model, tokenizer=base_model, max_new_tokens=10)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
genTest = "The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change? (A) ignore (B) enforce (C) authoritarian (D) yell at (E) avoid. Please select an answer."

result = pipe(genTest)

print(result)


[{'generated_text': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change? (A) ignore (B) enforce (C) authoritarian (D) yell at (E) avoid. Please select an answer.'}]


In [6]:
import pickle

modelName = "Llama-2-7b-chat-hf"

fileName = modelName + "generatedMaxNewTokens10.pickle"


try:
    with open(fileName, 'rb') as file:
        generatedOutputs = pickle.load(file)
        print("loaded " + str(len(generatedOutputs)) + " rows")
except FileNotFoundError:
    print("File not found making new object")
    generatedOutputs = {}

loaded 1786 rows


In [7]:
import time
import re
answerPattern = r'[A-E]\)'

lastSave = time.time()
newAnswers = 0
for q in dataset["train"]:
    if q["question"] not in generatedOutputs:
        prompt = q["question"]
        for i in range(len(q["choices"]["text"])):
            prompt += " (" + q["choices"]["label"][i] + ") " + q["choices"]["text"][i]
        prompt += "."
        #print(newAnswers % 10)
        print("generating result for the prompt: " + prompt)
        result = pipe(prompt)
        #print(result[0]['generated_text'])
        #print(len(result[0]['generated_text']))
        #print(len(prompt))
        #print(len(result) - len(prompt))
        genText = result[0]['generated_text'][len(prompt):]
        print("result is: " + genText)

        print("correct answer is " + q["answerKey"])

        potentialMatches = re.findall(answerPattern, genText)
        print(potentialMatches)
        if  len(potentialMatches) != 0:
            parsedAnswer = re.findall(answerPattern, genText)[0][0]
            print("parsed answer is " + parsedAnswer)
            generatedOutputs[q["question"]] = {"prompt" : prompt, "genText" : genText, "parsed" : parsedAnswer}
            newAnswers += 1
        else:
            #enter some infor indicating that it wasn't found
            generatedOutputs[q["question"]] = {"prompt" : prompt, "genText" : genText, "parsed" : "None"}
            print("not able to parse an answer###################################################################################################################################")
    print("Number of new answers: " + str(newAnswers))
    if newAnswers % 10 == 1:
        with open(fileName, 'wb') as file:
            # Serialize and save the object to the file
            print("saved " + str(newAnswers) + " new answers in " + str(time.time() - lastSave))
            pickle.dump(generatedOutputs, file)
        lastSave = time.time()

0
generating result for the prompt: How do geese normally get from place to place? (A) carried by people (B) guard house (C) fly (D) follow ultralight airplane (E) group together.
result is: 

Answer: C) fly. Geese
correct answer is C
['C)']
parsed answer is C
saved 1 new answers in 14.390068292617798
saved 1 new answers in 0.0005252361297607422
saved 1 new answers in 0.00029730796813964844
saved 1 new answers in 0.0002932548522949219
saved 1 new answers in 0.00029277801513671875
saved 1 new answers in 0.0002837181091308594
saved 1 new answers in 0.0002982616424560547
saved 1 new answers in 0.00031185150146484375
saved 1 new answers in 0.0027272701263427734
saved 1 new answers in 0.0003190040588378906
saved 1 new answers in 0.00030517578125
saved 1 new answers in 0.00033092498779296875
saved 1 new answers in 0.0003314018249511719
saved 1 new answers in 0.0003209114074707031
saved 1 new answers in 0.00031948089599609375
saved 1 new answers in 0.0003082752227783203
saved 1 new answers in

KeyboardInterrupt: 

In [None]:
import re

testText = "asdf\n\n\nThe Answer is (E) Because I think it is not (A)."

answerPattern = r'\([A-E]\)'


print(re.findall(answerPattern, testText)[0][1])

E
