In [2]:
cur_dir = os.getcwd()
os.chdir(os.path.join(cur_dir, 'data'))
!git clone https://gitlab.cs.washington.edu/ALGES/TACL2015.git
!git clone https://github.com/chaochun/nlu-asdiv-dataset.git
!git clone https://github.com/openai/grade-school-math.git
os.chdir(cur_dir)

Cloning into 'TACL2015'...
remote: Enumerating objects: 2294, done.[K
remote: Counting objects: 100% (2294/2294), done.[K
remote: Compressing objects: 100% (2234/2234), done.[K
remote: Total 2294 (delta 203), reused 2103 (delta 55), pack-reused 0[K
Receiving objects: 100% (2294/2294), 4.51 MiB | 3.62 MiB/s, done.
Resolving deltas: 100% (203/203), done.
Cloning into 'nlu-asdiv-dataset'...
remote: Enumerating objects: 30, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 30 (delta 6), reused 20 (delta 5), pack-reused 0[K
Unpacking objects: 100% (30/30), 425.56 KiB | 1.67 MiB/s, done.
Cloning into 'grade-school-math'...
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 36 (delta 14), reused 30 (delta 11), pack-reused 0[K
Unpacking objects: 100% (36/36), 3.01 MiB | 4.94 MiB/s, done.


In [1]:
import os
from xml.etree import ElementTree
import numpy as np
import torch
import json
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from io import StringIO
from contextlib import redirect_stdout
from termcolor import colored
import wandb

In [3]:
import dataset_handler as dh
import helper_func as hf
import exp_impl.simple_func_def as exp_impl

# Pretrained-model identifiers. `gptj_model` is what the model-loading cell
# passes to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
gptj_model = "EleutherAI/gpt-j-6B"
# NOTE(review): `codeparrot_model` is not referenced by any cell visible in this notebook.
codeparrot_model = "lvwerra/codeparrot"

# Switch read by the dataset-loading and model-loading cells: "gpt-j" selects the
# GPT-J priming texts and model; the model-loading cell also recognises "codegen".
model_name = "gpt-j"
# model_name = "codegen"

In [6]:
"""Load gsm8k"""

# Build the gsm8k dataset; the priming text (and, for non-GPT-J runs, the
# sampling / prompt-generation helpers from exp_impl) depend on `model_name`.
if model_name != "gpt-j":
    # for codegen
    priming_text_path = "data/priming_texts/gsm8k/codegen/gsm8k_fewer_alt_codegen_func.txt"
    current_dataset = dh.init_dataset_from_name(
        "gsm8k",
        primingtext_path=priming_text_path,
        sample_func=exp_impl.sample_n_for_prompting,
        generate_prompt_func=exp_impl.generate_prompt,
    )
else:
    # for gpt-j
    priming_text_path = "data/priming_texts/gsm8k/gpt-j/gsm8k_fewer_alt.txt"
    current_dataset = dh.init_dataset_from_name("gsm8k", primingtext_path=priming_text_path)


In [5]:
"""Load asdiv"""

# Pick the asdiv priming text for the selected model, then build the dataset.
# Running this cell replaces whatever `current_dataset` was set to before.
priming_text_path = (
    "data/priming_texts/asdiv/asdiv_prefix.txt"  # for gpt-j
    if model_name == "gpt-j"
    else "data/priming_texts/asdiv/asdiv_prefix_codegen.txt"  # for codegen
)

current_dataset = dh.init_dataset_from_name(
    "asdiv",
    primingtext_path=priming_text_path,
)

In [None]:
# Seed before sampling (presumably so the draw below is repeatable), then take
# 5 question/answer pairs from the current dataset to use for prompting.
hf.set_all_seeds()

sample_q_list, sample_a_list = current_dataset.sample_n_for_prompting(5)

# Preview the first sampled pair: question in blue, answer in green.
print(
    colored(sample_q_list[0], "blue"),
    colored(sample_a_list[0], "green"),
    sep="\n",
)

In [5]:
# Load the tokenizer and model selected by `model_name`.
if model_name == "gpt-j":
    # GPT-J and codeparrot models run in HFTest venv
    tokenizer = AutoTokenizer.from_pretrained(gptj_model)
    # Ask for fp16 weights at load time rather than loading in fp32 and calling
    # .half() afterwards, which needs a full-precision copy in host memory first.
    model = (
        AutoModelForCausalLM.from_pretrained(gptj_model, torch_dtype=torch.float16)
        .eval()
        .cuda()
    )
elif model_name == "codegen":
    # CodeGen runs in the venv venv
    model_args = hf.model_args()
    # model_args.model = "codegen-350M-mono"
    model, tokenizer = hf.load_CodeGen(model_args)
else:
    # Without this branch an unrecognised name would silently leave `model` and
    # `tokenizer` undefined (or stale from an earlier run) for the cells below.
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'gpt-j' or 'codegen'"
    )

In [None]:
# Set up for CodeGen
config = hf.codegen_gen_args()

# Generation overrides, applied in this order as plain attribute assignments.
codegen_overrides = {
    "num_return_sequences": 4,  # 4 for gsm8k 5 for asdiv
    "k": 3,
    # NOTE(review): "lenght" is kept exactly as spelled; the attribute is read
    # elsewhere in the project, so confirm the spelling there before renaming.
    "max_lenght_after_input": 250,
    "top_p": 0.95,
    "top_k": 50,
    "temperature": 0.7,
    "min_length": 3,
}
for option_name, option_value in codegen_overrides.items():
    setattr(config, option_name, option_value)

hf.set_all_seeds(model_name)
hf.testing_loop(
    current_dataset,
    tokenizer,
    model,
    sample_q_list,
    sample_a_list,
    config,
    func_def_mod=True,
    print_output=False,
)

In [None]:
# Set up for gpt-j: use the generation arguments from hf.gptj_gen_args() as-is.
config = hf.gptj_gen_args()

# Seed, then run the test loop on the current dataset with the sampled
# question/answer pairs as prompting examples.
hf.set_all_seeds(model_name)
hf.testing_loop(
    current_dataset,
    tokenizer,
    model,
    sample_q_list,
    sample_a_list,
    config,
    print_output=False,
)

In [None]:
# Run the test loop inside a Weights & Biases run and log the resulting score.
# NOTE(review): the run name is hard-coded to the gsm8k / codegen func-def
# experiment; update it when running a different dataset or model.
with wandb.init(
    project="PracticalWork",
    entity="antoniolopardo",
    config=config,
    name="@100-gsm8k-codegen-func-def",
):
    # Seed with the model *name*, matching the other two run cells. The original
    # passed the loaded model object here, which looks like a typo for
    # `model_name` - confirm against hf.set_all_seeds.
    hf.set_all_seeds(model_name)
    pass_at_k = hf.testing_loop(
        current_dataset, tokenizer, model, sample_q_list, sample_a_list, config
    )

    wandb.log({"pass_at_k": pass_at_k})