In [3]:
import transformers
import torch

model_id = "/home/23_zxx/workspace/llama3-ft/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model = model_id,
    model_kwargs = {"torch_dtype":torch.bfloat16},
    device = "cuda",
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
import os
import json
def load_raw_dataset(data_path, data_name, split):
    with open(os.path.join(data_path, data_name, f'{split}.json')) as f:
        raw_dataset = json.load(f)
    return raw_dataset

raw_dataset = load_raw_dataset('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/data','ProntoQA','dev') 
print(raw_dataset[0])

{'id': 'ProntoQA_1', 'context': 'Jompuses are not shy. Jompuses are yumpuses. Each yumpus is aggressive. Each yumpus is a dumpus. Dumpuses are not wooden. Dumpuses are wumpuses. Wumpuses are red. Every wumpus is an impus. Each impus is opaque. Impuses are tumpuses. Numpuses are sour. Tumpuses are not sour. Tumpuses are vumpuses. Vumpuses are earthy. Every vumpus is a zumpus. Zumpuses are small. Zumpuses are rompuses. Max is a yumpus.', 'question': 'Is the following statement true or false? Max is sour.', 'options': ['A) True', 'B) False'], 'answer': 'B', 'explanation': ['Max is a yumpus.', 'Each yumpus is a dumpus.', 'Max is a dumpus.', 'Dumpuses are wumpuses.', 'Max is a wumpus.', 'Every wumpus is an impus.', 'Max is an impus.', 'Impuses are tumpuses.', 'Max is a tumpus.', 'Tumpuses are not sour.', 'Max is not sour.']}


In [None]:
def load_in_context_examples(demonstration_path,dataset_name,mode):
    with open(os.path.join(demonstration_path, f'{dataset_name}_{mode}.txt')) as f:
        in_context_examples = f.read()
    return in_context_examples

in_context_example = load_in_context_examples('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/logic_llm/icl_examples','ProntoQA', 'Direct' )
print(in_context_example)

In [7]:
def prompt_creator(in_context_example, example):
    full_prompt = in_context_example
    context = example['context']
    question = example['question']
    options = '\n'.join([opt.strip() for opt in example['options']])
    full_prompt = full_prompt.replace('[[CONTEXT]]', context)
    full_prompt = full_prompt.replace('[[QUESTION]]', question)
    full_prompt = full_prompt.replace('[[OPTIONS]]', options)
    return full_prompt

In [16]:
from tqdm import tqdm
def reasoning_graph_generation(save_path, mode, dataset_name, split, model_name):
    raw_dataset = load_raw_dataset('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/data','ProntoQA','dev')
    print(f"Loaded {len(raw_dataset)} examples from split.")

    in_context_example = load_in_context_examples('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/logic_llm/icl_examples','ProntoQA', 'Direct' )

    outputs = []
    generated_texts = []
    for example in tqdm(raw_dataset[0:10]):
        question = example['question']

        # create prompt
        full_prompt = prompt_creator(in_context_example, example)
        messages = [
            {'role':'system','content': 'hello,You are a helpful human assistant!'},
            {'role':'user', 'content': full_prompt }
        ]
        prompt = pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt = True
        )

        terminators = [
            pipeline.tokenizer.eos_token_id,
            pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        generated_text = pipeline(
            prompt,
            max_new_tokens = 256,
            eos_token_id = terminators,
            do_sample = True,
            temperature = 0.6,
            top_p = 0.9
        )
        generated_texts.append(generated_text)

        response = generated_text[0]['generated_text'].split('<|start_header_id|>assistant<|end_header_id|>')[-1]
        label_phrase = 'correct option is:'
        generated_reasoning = response.split(label_phrase)[0].strip()
        generated_answer = response.split(label_phrase)[-1].strip()
        output = {'id': example['id'], 
                'question': question, 
                'answer': example['answer'], 
                'predicted_reasoning': generated_reasoning,
                'predicted_answer': generated_answer}
        outputs.append(output)

    with open(os.path.join(save_path, f'{dataset_name}_generation.json'), 'w') as f:
        json.dump(generated_texts, f, indent=2, ensure_ascii=False)

    # save outputs        
    with open(os.path.join(save_path, f'{mode}_{dataset_name}_{split}_{model_name}.json'), 'w') as f:
        json.dump(outputs, f, indent=2, ensure_ascii=False)
    
reasoning_graph_generation('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/logic_llm/results', 'Direct', 'ProntoQA', 'dev', 'Llama3-8B-Instruction')        


Loaded 500 examples from split.


  0%|          | 0/10 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 10%|█         | 1/10 [00:03<00:27,  3.11s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 20%|██        | 2/10 [00:04<00:16,  2.09s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 30%|███       | 3/10 [00:08<00:21,  3.03s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 40%|████      | 4/10 [00:12<00:19,  3.21s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 50%|█████     | 5/10 [00:19<00:24,  4.84s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 60%|██████    | 6/10 [00:26<00:21,  5.45s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 70%|███████   | 7/10 [00:28<00:13,  4.37s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 80%|████████  | 8/10 [00:35<00:10,  5.25s/it]Setting `p

In [17]:
import re
import json
from tqdm import tqdm
import random
import os
import argparse

def extract_number(string):
    # Remove all characters except digits, decimal point and negative sign
    try:
        num_string = re.sub(r'[^\d.-]', '', string)
        num_string = num_string.replace('$', '')
        return float(num_string)
    except:
        try:
            return float(random.randint(0, 100))
            # return float(w2n.word_to_num(string))
        except:
            # print('Error: ', string)
            print('Error')
            return float(random.randint(0, 100))

def argmax(iterable):
    return max(enumerate(iterable), key=lambda x: x[1])[0]

# these functions are heavily influenced by the HF squad_metrics.py script
def normalize_text(s):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
    return int(normalize_text(prediction) == normalize_text(truth))
    # return prediction == truth

def compute_f1(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()
    
    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)
    
    common_tokens = set(pred_tokens) & set(truth_tokens)
    
    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0
    
    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    
    return 2 * (prec * rec) / (prec + rec)

def evaluate_sample(prediction, gold_answers):
    em_score = max((compute_exact_match(prediction, answer)) for answer in gold_answers)
    f1_score = max((compute_f1(prediction, answer)) for answer in gold_answers)
    return em_score, f1_score

def get_choice(answer_str):
    choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'A)', 'B)', 'C)', 'D)', 'E)', 'F)', 'G)', 'H)', 
               'A.', 'B.', 'C.', 'D.', 'E.', 'F.', 'G.', 'H.']
    for c in choices:
        if answer_str.startswith(c):
            return c.replace(')', '')
    return None

def evaluate_QA(result_file):
    with open(result_file, 'r') as f:
        QA_results = json.load(f)

    total_em = 0.0
    total_f1 = 0.0
    count = 0
    for sample in QA_results:
        gold_answer = sample['answer'].replace('(', '').replace(')', '').strip()
        answer_str = sample['predicted_answer'].strip()
        prediction = get_choice(answer_str)

        indicators = ['the correct option is', 'the correct answer is', 
                      'The correct answer is', 'The correct option is',
                      'Thus, the answer is']
        if prediction is None:
            for indicator in indicators:
                if answer_str.find(indicator)>=0:
                    answer_str = answer_str.split(indicator)[1].strip()
                    prediction = get_choice(answer_str)
                    break

        if prediction is None:
            print(answer_str)

        print(f"prediction: {prediction} \t gold_answers: {gold_answer} \t match: {prediction == gold_answer}")
        
        em_score = 1.0 if prediction == gold_answer else 0.0
        total_em += em_score
        count += 1
    
    avg_em = total_em / count
    print(f"EM: {avg_em}")

# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dataset_name', type=str)
#     parser.add_argument('--model_name', type=str)
#     parser.add_argument('--mode', type=str)
#     parser.add_argument('--split', type=str, default='dev')
#     parser.add_argument('--result_path', type=str, default='./results')
#     args = parser.parse_args()
#     return args

if __name__ == "__main__":
    #args = parse_args()
    result_file = os.path.join('/home/23_zxx/workspace/llama3-ft/Llama3-Tutorial/logic_llm/results', 'Direct_ProntoQA_dev_Llama3-8B-Instruction.json')
    evaluate_QA(result_file)


: B) False.
prediction: None 	 gold_answers: B 	 match: False
prediction: B 	 gold_answers: A 	 match: False
prediction: B 	 gold_answers: A 	 match: False
prediction: B 	 gold_answers: B 	 match: True
I'd be happy to help!

From the context, we know that Alex is a numpus, and Wumpuses are impuses. Since Alex is a numpus, and each numpus is a tumpus, Alex is also a tumpus. We also know that Wumpuses are not dull, and Impuses are Wumpuses. Therefore, Impuses are not dull. Since Impuses are also dumpuses, and dumpuses are not wooden, Impuses are not wooden. However, we don't have any information about the dullness of dumpuses.

Now, let's look at the question: Is the following statement true or false? Alex is not dull.

From the context, we know that Wumpuses are not dull, and Impuses are Wumpuses. Since Alex is a tumpus, and each tumpus is not large, Alex is not large. However, we don't have any information about the dullness of tumpuses. Since Alex is a numpus, and each numpus is a tum