In [7]:
# Header templates used by reconstruct_func, keyed by language name. Each one
# renders a one-line comment naming the language, then the docstring, then the
# function name immediately followed by its parameter text; the body is
# appended by the caller. Python is an empty placeholder because
# reconstruct_python_func builds Python functions itself. Ruby is the only
# template that ends with a newline.
formats = {
    "Python": "",
    **dict.fromkeys(
        ("JavaScript", "TypeScript", "Java", "C", "C++", "C#"),
        "// {language}\n{docstring}\n{name}{params}",
    ),
    "Ruby": "# {language}\n{docstring}\n{name}{params}\n",
    "Shell": "# {language}\n{docstring}\n{name}{params}",
    "PHP": "// {language}\n{docstring}\n{name}{params}",
}

def reconstruct_python_func(example):
    """Rebuild a Python function's source with language and docstring as ``#`` comments.

    Output layout: ``# <language>``, one ``# `` comment per docstring line
    (blank lines included), then ``name(params):`` and the body with its
    first line dropped.

    Args:
        example: dict with ``"language"``, ``"docstring"``, ``"name"``,
            ``"params"`` and ``"body"`` strings. When ``"params"`` starts with
            ``[`` it is treated as a list literal of parameter strings
            (e.g. ``"['a', 'b: int']"``). Otherwise it is appended verbatim
            before the trailing colon.

    Returns:
        The reconstructed function source as a string.
    """
    import ast

    header = f"# {example['language']}\n"
    header += "".join(f"# {line}\n" for line in example["docstring"].split("\n"))
    header += example["name"]

    params = example["params"]
    # startswith (instead of params[0]) keeps an empty params string from
    # raising IndexError.
    if params.startswith("["):
        # Parse the list literal properly so parameters that themselves contain
        # commas (e.g. "x: Dict[str, int]" or "b=(1, 2)") stay intact.
        try:
            parsed = ast.literal_eval(params)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        if not (isinstance(parsed, list) and all(isinstance(p, str) for p in parsed)):
            # Not a valid list-of-strings literal: fall back to the naive comma
            # split, dropping the quote character around each item.
            parsed = [p.strip()[1:-1] for p in params[1:-1].split(",")]
        header += "(" + ", ".join(parsed) + "):\n"
    else:
        header += params + ":\n"

    # The body's first line is dropped. Presumably it repeats the signature;
    # TODO confirm against the dataset.
    body = "\n".join(example["body"].split("\n")[1:])
    return header + body


def reconstruct_func(example):
    """Rebuild a function's source text from its dataset fields.

    Python examples are delegated to ``reconstruct_python_func``. Every other
    language fills its ``formats`` header template (language comment,
    docstring, name + params) and then appends ``example["body"]`` unchanged.

    Raises:
        KeyError: if the language has no ``formats`` entry or a field is missing.
    """
    language = example["language"]
    if language == "Python":
        return reconstruct_python_func(example)

    header = formats[language].format(
        language=language,
        docstring=example["docstring"],
        name=example["name"],
        params=example["params"],
    )
    return header + example["body"]

In [None]:
from transformers import AutoTokenizer


# Base checkpoint name: the tokenizer is loaded from it here, and run_gen loads
# the model from it and names the output files after it.
model_name = "TinyLlama-1.1B-Chat-v1.0"


def init_tokenizer(model_name):
    """Load and return the pretrained tokenizer for ``model_name``."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_tokenizer


tokenizer = init_tokenizer(model_name)

# Function delimiters: the prompt builders start every prompt with bos_token,
# and completions are cut at the first eos_token.
tokenizer.bos_token, tokenizer.eos_token = "<func>", "</func>"
tokenizer.pad_token = "</s>"

In [None]:
def reformat_prompt(prompt, gen_type, bos_token=None):
    """Rewrite an evalplus problem into the plain comment-header prompt format.

    The result is the bos token, a ``# Python`` line, every non-blank line of
    the task description as a ``# `` comment, and finally the target
    function's signature line (``def ...:`` plus newline) for the model to
    continue.

    Args:
        prompt: evalplus problem dict. ``"prompt"`` is always read, and
            ``"canonical_solution"`` is read for MBPP. ``"entry_point"``, when
            present, names the target function.
        gen_type: ``"mbpp"`` takes the signature from the canonical solution
            and the description from ``prompt["prompt"]``. Any other value
            treats ``prompt["prompt"]`` as a HumanEval-style stub (signature
            followed by its docstring).
        bos_token: string prepended to the prompt. Defaults to the
            notebook-level ``tokenizer.bos_token``.

    Returns:
        The reformatted prompt string.

    Raises:
        ValueError: if the source text contains no top-level ``def`` line.
    """
    if bos_token is None:
        bos_token = tokenizer.bos_token

    source = prompt["canonical_solution"] if gen_type == "mbpp" else prompt["prompt"]

    # Record (name, offset) for every top-level ``def`` line. Matching whole
    # lines (rather than splitting on the substring "def") avoids cutting at
    # words such as "defined" inside a docstring. It also lets us skip helper
    # functions that some prompts define before the target.
    defs = []
    offset = 0
    for line in source.splitlines(keepends=True):
        if line.startswith("def "):
            defs.append((line[4:].split("(", 1)[0].strip(), offset))
        offset += len(line)
    if not defs:
        raise ValueError(f"no top-level 'def' found in the {gen_type} source text")

    entry_point = prompt.get("entry_point")
    named = [start for name, start in defs if name == entry_point]
    if named:
        def_start = named[0]
    elif gen_type == "mbpp":
        def_start = defs[0][1]  # no entry point: first function of the solution
    else:
        def_start = defs[-1][1]  # a HumanEval stub ends with the function to complete

    # The signature ends at the first ":\n" after the ``def`` keyword.
    header_end = source.find(":\n", def_start)
    if header_end < 0:
        header_end = len(source)
    definition = source[def_start:header_end] + ":\n"

    if gen_type == "mbpp":
        docstring = prompt["prompt"]
    else:
        # Everything after the signature is the stub's docstring. Splitting on
        # ":\n" would truncate it at lines such as "Example:".
        docstring = source[header_end + 2:]

    docstring = docstring.replace('"""', "").replace("'''", "")
    comments = "".join(f"# {line}\n" for line in docstring.split("\n") if line.strip())
    return bos_token + "\n# Python\n" + comments + definition

In [10]:
def reformat_prompt_tokens(prompt, gen_type, bos_token=None):
    """Rewrite an evalplus problem into the tagged prompt format.

    The layout is:
    the bos token, ``# <|language|> Python <|/language|>``,
    ``# <|docstring|>``, each non-blank task-description line as a ``# ``
    comment, ``# <|/docstring|>``, ``# <|head|>``, the target signature line,
    ``# <|/head|>``, and ``# <|body|>``. Every piece ends with a newline.

    Args:
        prompt: evalplus problem dict. ``"prompt"`` is always read, and
            ``"canonical_solution"`` is read for MBPP. ``"entry_point"``, when
            present, names the target function.
        gen_type: ``"mbpp"`` takes the signature from the canonical solution
            and the description from ``prompt["prompt"]``. Any other value
            treats ``prompt["prompt"]`` as a HumanEval-style stub.
        bos_token: string prepended to the prompt. Defaults to the
            notebook-level ``tokenizer.bos_token``.

    Returns:
        The tagged prompt string.

    Raises:
        ValueError: if the source text contains no top-level ``def`` line.
    """
    if bos_token is None:
        bos_token = tokenizer.bos_token

    source = prompt["canonical_solution"] if gen_type == "mbpp" else prompt["prompt"]

    # Record (name, offset) for every top-level ``def`` line. Whole-line
    # matching avoids splitting on "def" inside words such as "defined" and
    # lets us skip helper functions defined before the target.
    defs = []
    offset = 0
    for line in source.splitlines(keepends=True):
        if line.startswith("def "):
            defs.append((line[4:].split("(", 1)[0].strip(), offset))
        offset += len(line)
    if not defs:
        raise ValueError(f"no top-level 'def' found in the {gen_type} source text")

    entry_point = prompt.get("entry_point")
    named = [start for name, start in defs if name == entry_point]
    if named:
        def_start = named[0]
    elif gen_type == "mbpp":
        def_start = defs[0][1]  # no entry point: first function of the solution
    else:
        def_start = defs[-1][1]  # a HumanEval stub ends with the function to complete

    # The signature ends at the first ":\n" after the ``def`` keyword.
    header_end = source.find(":\n", def_start)
    if header_end < 0:
        header_end = len(source)
    definition = source[def_start:header_end] + ":\n"

    if gen_type == "mbpp":
        docstring = prompt["prompt"]
    else:
        # Everything after the signature is the docstring. Splitting on ":\n"
        # would truncate it at lines such as "Example:".
        docstring = source[header_end + 2:]

    docstring = docstring.replace('"""', "").replace("'''", "")
    comments = "".join(f"# {line}\n" for line in docstring.split("\n") if line.strip())

    # The language tag previously contained "Python\ ", an invalid escape that
    # put a literal backslash into the prompt.
    return (
        bos_token
        + "\n# <|language|> Python <|/language|>\n# <|docstring|>\n"
        + comments
        + "# <|/docstring|>\n# <|head|>\n"
        + definition
        + "# <|/head|>\n# <|body|>\n"
    )

In [None]:
def generate_one_completion(id, prompt, code_generator, gen_type):
    """Produce one completion string for a single evalplus problem.

    Args:
        id: task id; printed as a progress marker.
        prompt: evalplus problem dict, rewritten with ``reformat_prompt``.
        code_generator: text-generation pipeline to sample from.
        gen_type: benchmark name forwarded to ``reformat_prompt``.

    Returns:
        ``"# "`` plus the pipeline's generated text, truncated before the first
        ``tokenizer.eos_token`` when that marker appears.
    """
    print(id)
    model_input = reformat_prompt(prompt, gen_type)
    outputs = code_generator(model_input, max_new_tokens=512, truncation=True)

    # reformat_prompt puts the bare bos token on the first line. Presumably the
    # pipeline echoes the prompt (return_full_text default), so the "# " prefix
    # turns that line into a Python comment.
    completion = "# " + outputs[0]["generated_text"]

    eos = tokenizer.eos_token
    if eos in completion:
        completion = completion.partition(eos)[0]
    return completion

In [None]:
from datasets import load_dataset, concatenate_datasets
from transformers import pipeline,AutoModelForCausalLM
from peft import PeftModel
from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl

gen_type = "human-eval"

problems = []
match gen_type:
    case "human-eval":
        problems = get_human_eval_plus()
    case "mbpp":
        problems = get_mbpp_plus()


num_samples_per_task = 1

def run_gen(path, device="cuda"):
    """Generate completions for every loaded problem with one adapter checkpoint.

    Loads the base ``model_name`` model, applies the PEFT adapter stored at
    ``path``, merges it into the base weights, and builds a text-generation
    pipeline. It then calls ``generate_one_completion``
    ``num_samples_per_task`` times for each task in ``problems``.

    Args:
        path: location of the PEFT adapter checkpoint.
        device: device passed to ``pipeline`` (default ``"cuda"``).

    Returns:
        List of ``{"task_id": ..., "completion": ...}`` dicts, ordered by task
        and then by sample.
    """
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    # merge_and_unload folds the adapter into the base weights and returns the
    # underlying model.
    model = PeftModel.from_pretrained(base_model, path).merge_and_unload()

    code_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer, device=device)
    return [
        dict(
            task_id=task_id,
            completion=generate_one_completion(task_id, problem, code_generator, gen_type),
        )
        for task_id, problem in problems.items()
        for _ in range(num_samples_per_task)
    ]

import os

# Adapter checkpoints (results/tagged/checkpoint-<num>) to evaluate.
for num in [2, 6]:
    path = f"results/tagged/checkpoint-{num}"
    out_path = f"{gen_type}/results/tagged/{model_name}-{num}.jsonl"
    # Create the output directory before the long generation run. write_jsonl
    # presumably opens the path directly, so a missing directory would
    # otherwise only fail after all samples were generated.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    samples = run_gen(path)
    write_jsonl(out_path, samples)

HumanEval/0
HumanEval/1
HumanEval/2
HumanEval/3
HumanEval/4
HumanEval/5
HumanEval/6
HumanEval/7
HumanEval/8
HumanEval/9
HumanEval/10
HumanEval/11
HumanEval/12
HumanEval/13
HumanEval/14
HumanEval/15
HumanEval/16
HumanEval/17
HumanEval/18
HumanEval/19
HumanEval/20
HumanEval/21
HumanEval/22
HumanEval/23
HumanEval/24
HumanEval/25
HumanEval/26
HumanEval/27
HumanEval/28
HumanEval/29
HumanEval/30
HumanEval/31
HumanEval/32
HumanEval/33
HumanEval/34
HumanEval/35
HumanEval/36
HumanEval/37
HumanEval/38
HumanEval/39
HumanEval/40
HumanEval/41
HumanEval/42
HumanEval/43
HumanEval/44
HumanEval/45
HumanEval/46
HumanEval/47
HumanEval/48
HumanEval/49
HumanEval/50
HumanEval/51
HumanEval/52
HumanEval/53
HumanEval/54
HumanEval/55
HumanEval/56
HumanEval/57
HumanEval/58
HumanEval/59
HumanEval/60
HumanEval/61
HumanEval/62
HumanEval/63
HumanEval/64
HumanEval/65
HumanEval/66
HumanEval/67
HumanEval/68
HumanEval/69
HumanEval/70
HumanEval/71
HumanEval/72
HumanEval/73
HumanEval/74
HumanEval/75
HumanEval/76
HumanEval