In [1]:
import os
from xml.etree import ElementTree
import numpy as np
import torch
import json
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from io import StringIO
from contextlib import redirect_stdout
from termcolor import colored

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import dataset_handler as dh

In [3]:
cur_dir = os.getcwd()
os.chdir(os.path.join(cur_dir, 'data'))
!git clone https://gitlab.cs.washington.edu/ALGES/TACL2015.git
!git clone https://github.com/chaochun/nlu-asdiv-dataset.git
!git clone https://github.com/openai/grade-school-math.git
os.chdir(cur_dir)

fatal: destination path 'TACL2015' already exists and is not an empty directory.
fatal: destination path 'nlu-asdiv-dataset' already exists and is not an empty directory.
fatal: destination path 'grade-school-math' already exists and is not an empty directory.


In [4]:
def preproc_gen_toks(gen_toks, input_len, tok=None):
    """
    Turn generated token sequences into the text of their first completion block.

    For every sequence the first ``input_len`` tokens (the prompt) are dropped,
    the remainder is decoded, and everything from the first blank line
    ("\\n\\n") onwards is discarded.

    :param gen_toks: iterable of token-id sequences (prompt + generated tokens)
    :param input_len: number of leading prompt tokens to strip from each sequence
    :param tok: object with a ``decode(ids, skip_special_tokens=...)`` method;
        defaults to the notebook-level ``tokenizer`` when omitted
    :return: list with one decoded completion string per input sequence
    """
    if tok is None:
        # Backward-compatible default: the tokenizer created in the model-loading cell.
        tok = tokenizer

    list_out = []
    for gen_tok in gen_toks:
        completion_ids = gen_tok[input_len:]
        # Sequences that stop early are padded with the EOS id; skipping special
        # tokens keeps that padding text out of the extracted completion.
        generated_text = tok.decode(completion_ids, skip_special_tokens=True)
        list_out.append(generated_text.split("\n\n")[0])
    return list_out

def pass_at_k(n, c, k):
    """
    Unbiased pass@k estimate, i.e. 1 - C(n - c, k) / C(n, k): the probability
    that at least one of k samples drawn without replacement from the n
    generated samples is correct.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    :return: the estimate as a float in [0, 1]
    :raises ValueError: if k > n (the estimator is undefined and would report
        1.0 even with zero correct samples) or if c is outside [0, n]
    """
    if k > n:
        raise ValueError(f"k ({k}) must not exceed the number of samples n ({n})")
    if c < 0 or c > n:
        raise ValueError(f"c ({c}) must lie between 0 and n ({n})")

    if n - c < k:
        # Fewer than k incorrect samples: every size-k subset contains a correct one.
        return 1.0
    # prod_{i=n-c+1..n} (1 - k/i) equals C(n-c, k) / C(n, k) without large factorials.
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

In [5]:
# --- Reproducibility: seed torch and numpy before anything is sampled ---------
torch.manual_seed(0)
np.random.seed(5)

# --- Hugging Face model identifiers --------------------------------------------
# genji_model = ""
gptj_model = "EleutherAI/gpt-j-6B"
codeparrot_model = "lvwerra/codeparrot"

# --- Dataset files, relative to the notebook's working directory ---------------
asdiv_path = "data/nlu-asdiv-dataset/dataset/ASDiv.xml"
gsm8k_path = "data/grade-school-math/grade_school_math/data/train.jsonl"
singleEq_path = "data/TACL2015/questions.json"

# --- Choose the dataset you want to test ---------------------------------------
# Alternatives (swap the active assignment below for one of these):
#   dataset_path = gsm8k_path
#   dataset_path = asdiv_path
dataset_path = singleEq_path

# --- Load the priming text to add to the prompt and sample questions -----------
# Each dataset handler receives the data file, a priming-text file and a label.
# Alternatives kept for reference:
#   current_dataset = dh.asdiv_dataset(asdiv_path, "data/priming_texts/asdiv.txt", "asdiv")
#   current_dataset = dh.gsm8k_datatset(dataset_path, "data/priming_texts/gsm8k_fewer.txt", "gsm8k")
# Older helper-based flow, no longer used:
#   priming_text = read_string_from_file("data/priming_texts/gsm8k.txt")
#   priming_text = read_string_from_file("data/priming_texts/singleEq.txt")
#   priming_text = read_string_from_file("data/priming_texts/asdiv.txt")
#   sample_q, sample_a = sample_gsm8k(dataset_path)
#   sample_q, sample_a = sample_singleEq(dataset_path)
#   sample_q_list, sample_a_list = sample_asdiv(dataset_path, 25)
current_dataset = dh.singleEq_dataset(
    dataset_path,
    "data/priming_texts/singleEq.txt",
    "singleEq",
)

# Draw 10 question/answer pairs that the evaluation loop will iterate over.
sample_q_list, sample_a_list = current_dataset.sample_n_for_prompting(10)

# Sanity check: show one dataset entry (index 45).
current_dataset.print_entry_from_idx(45)

[33mSam went to 14 football games this year. He went to 29 games  last year. How many football games did Sam go to in all ?[0m
[32m43.0[0m

----------------------------------------------------------------------------------------------------



In [6]:
"""GPT-J and codeparrot models run in HFTest venv"""
# Load the tokenizer and the causal LM for the checkpoint named by `gptj_model`,
# then put the model in evaluation mode and move it to the GPU.
tokenizer = AutoTokenizer.from_pretrained(gptj_model)
model = AutoModelForCausalLM.from_pretrained(gptj_model)
model = model.eval().cuda()

In [7]:
# Re-seed immediately before generation, since sampling (do_sample=True) is stochastic.
torch.manual_seed(42)
np.random.seed(42)

# n: completions sampled per question; k: the k in pass@k.
# (A previous configuration used n = 4, k = 3.)
n = 3
k = 3

# Generation settings that are identical for every prompt; max_length depends
# on the prompt and is supplied inside the loop.
sampling_kwargs = dict(
    use_cache=True,
    do_sample=True,
    top_k=50,
    temperature=0.4,
    top_p=0.9,
    min_length=1,
    num_return_sequences=n,
    pad_token_id=tokenizer.eos_token_id,
)

pass_k_list = []

for sample_q, sample_a in zip(sample_q_list, sample_a_list):
    # Prompt = priming text, a blank line, then the question prefixed with '#'.
    prompt = f"{current_dataset.priming_text}\n\n#{sample_q}"

    tokens = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = len(tokens[0])

    # Allow up to 100 tokens beyond the prompt for each of the n completions.
    generated_tokens = model.generate(
        tokens.long().cuda(),
        max_length=prompt_len + 100,
        **sampling_kwargs
    )

    # Strip the prompt and keep only the first block of each completion.
    list_outputs = preproc_gen_toks(generated_tokens, prompt_len)

    is_correct_list = [
        current_dataset.verify_pred_from_output(output, sample_q, sample_a)
        for output in list_outputs
    ]

    # c = number of completions judged correct for this question.
    c = is_correct_list.count(True)

    pass_k = pass_at_k(n, c, k)
    pass_k_list.append(pass_k)

7.111111111111111
5.333333333333333
7.111111111111111
5.333333333333333
1111111111.0
5.333333333333333
16.0
1.84
16.0
1.84
126.0
1.84
750.0
218.0
750.0
218.0
750.0
218.0
7.0
7.0
7.0
7.0
7.0
7.0
35.0
21.0
105.0
21.0
35.0
21.0
672.0
56.0
96.0
56.0
672.0
56.0
984.0
984.0
984.0
984.0
984.0
984.0
4.0
4.0
10.0
4.0
10.0
4.0
7.0
5.0
7.0
5.0
7.0
5.0
14.02
5.98
14.02
5.98
280.4
5.98


In [8]:
# Average the per-question pass@k estimates gathered in pass_k_list and report it.
mean_pass_k = np.asarray(pass_k_list).mean()
print(f"Pass@{k} = {mean_pass_k}")

Pass@3 = 0.3
