## Overview

This notebook builds a long-context benchmark, as discussed by Federico, Arjun,
Harm, and Leandro. Unlike typical Code LLM benchmarks, this is a test
generation benchmark: we prompt the model with the implementation of a Python
function (and its docstring) and ask for a test suite. The result is scored in
two steps: if any test in the test suite fails, the score is zero. Otherwise,
the tests are scored based on their coverage of the function's
implementation. To make the problem harder, we add several other functions to
the prompt to serve as distractors. There are enough distractors to exercise
models with very long context lengths (up to 128K tokens).

We use two datasets: HumanEval and MultiPL-T. Both contain Python functions.
The HumanEval functions should be decontaminated before training: their
docstrings should not appear in the training data. The MultiPL-T functions are
extracted from the Stack v1.2. Thus they are very likely to appear in models'
training data, but they are merely distractors.
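
The two-step scoring rule described above can be sketched as follows. This is only an illustration, not the benchmark's actual scorer: the function name `score_test_suite`, its parameters, the use of line coverage as the coverage measure, and the choice to score an empty test suite as zero are all assumptions made for the example.

```python
from typing import List


def score_test_suite(test_results: List[bool],
                     covered_lines: int,
                     total_lines: int) -> float:
    """Hypothetical scorer: zero if any test fails, otherwise coverage.

    test_results holds one boolean per generated test (True means it passed).
    covered_lines / total_lines is assumed to be line coverage of the
    target function's implementation.
    """
    if total_lines <= 0:
        raise ValueError("total_lines must be positive")
    # Step 1: a single failing test zeroes the score.
    # Assumption: an empty suite also scores zero.
    if not test_results or not all(test_results):
        return 0.0
    # Step 2: all tests passed, so the score is the coverage fraction.
    return covered_lines / total_lines


print(score_test_suite([True, True, False], 5, 5))  # 0.0
print(score_test_suite([True, True], 3, 4))         # 0.75
```

Under this rule a suite with full coverage but one wrong assertion scores lower than a small suite that is entirely correct.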

In [14]:
import datasets
import random
import os
from typing import List
import itertools

In [24]:
# In case you're in an environment where you really want this to be set.
print(os.getenv("HF_DATASETS_CACHE"))

random.seed(42)

# This is likely an overestimate, but it should be close enough: we don't need
# to be exact.
CHARS_PER_TOKEN = 3.5

None


In [8]:
humaneval = datasets.load_dataset("openai_humaneval", split="test")
# The MultiPL-T dataset is currently private, but will be public soon.
multiplt = datasets.load_dataset("nuprl/stack-dedup-python-testgen-starcoder-filter-inferred-v2", split="train")


Found cached dataset openai_humaneval (/home/arjun/.cache/huggingface/datasets/openai_humaneval/openai_humaneval/1.0.0/2955cebd73602e828fa8c0a424c594e5fab4ec863b316ca98f3d8fdb6a626e75)
Found cached dataset parquet (/home/arjun/.cache/huggingface/datasets/nuprl___parquet/nuprl--stack-dedup-python-testgen-starcoder-filter-inferred-v2-8a147987b4874669/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


In [73]:
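def get_distractors(approximate_token_count: int) -> List[str]:
    """Sample MultiPL-T functions at random (with replacement) until their
    total length reaches roughly approximate_token_count tokens."""
    result = []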
    result_chars = 0
    target_chars = int(approximate_token_count * CHARS_PER_TOKEN)
    while result_chars < target_chars:
        fn = random.choice(multiplt)["content"]
        result.append(fn)
        result_chars += len(fn)
    return result

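def build_prompt(
        approximate_token_count: int,
        humaneval_problem_index: int,
        insert_where: str):
    """Build one benchmark item: a HumanEval function hidden among distractors.

    insert_where is "first half" or "second half" and controls where the
    target function is placed in the list of distractors.
    """
    distractors = get_distractors(approximate_token_count)
    target_problem = humaneval[humaneval_problem_index]
    # The target is the HumanEval prompt (signature and docstring) plus its solution.
    target_function = target_problem["prompt"] + target_problem["canonical_solution"]
    # random.randint is inclusive at both ends, so the midpoint is a valid
    # position for either choice.
    if insert_where == "first half":
        insert_index = random.randint(0, len(distractors) // 2)
    elif insert_where == "second half":
        insert_index = random.randint(len(distractors) // 2, len(distractors))
    else:
        raise ValueError(f"Unknown insert_where: {insert_where}")
    distractors.insert(insert_index, target_function)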
    return { 
        "prompt": "\n\n".join(distractors),
        "target_function": target_function,
        "humaneval_task_id": target_problem["task_id"],
        "task_id": f"LongBench_{target_problem['task_id']}_{approximate_token_count}_{insert_where}",
        "approx_token_count": approximate_token_count,
        "target_function_name": target_problem["entry_point"]
    }
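
To see how `insert_where` constrains the position of the target function, the index selection in `build_prompt` can be isolated and run on its own. This is a minimal sketch: the helper name `pick_insert_index` and the use of a separate `random.Random` instance are assumptions for illustration, and no datasets are needed.

```python
import random


def pick_insert_index(num_distractors: int, insert_where: str,
                      rng: random.Random) -> int:
    """Hypothetical helper mirroring the index choice in build_prompt."""
    half = num_distractors // 2
    if insert_where == "first half":
        return rng.randint(0, half)
    if insert_where == "second half":
        return rng.randint(half, num_distractors)
    raise ValueError(f"Unknown insert_where: {insert_where}")


rng = random.Random(0)
num_distractors = 6
first = [pick_insert_index(num_distractors, "first half", rng) for _ in range(1000)]
second = [pick_insert_index(num_distractors, "second half", rng) for _ in range(1000)]

# Observed range of insert positions for each setting.
print(min(first), max(first))
print(min(second), max(second))

# With no distractors, both settings can only pick index 0.
print(pick_insert_index(0, "first half", rng), pick_insert_index(0, "second half", rng))
```

Because `randint` includes both endpoints, the two ranges share the midpoint, so the split is approximate rather than strict.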


## Example Prompts

Some examples of prompts that we can construct.

With 0 as the number of target tokens, we get no distractors.

In [28]:
print(build_prompt(0, 53, "first half")["prompt"])



def add(x: int, y: int):
    """Add two numbers x and y
    >>> add(2, 3)
    5
    >>> add(5, 7)
    12
    """
    return x + y



With 400 target tokens, we get a few distractors and the `insert_where`
argument starts to make sense.

In [33]:
print(build_prompt(400, 53, "first half")["prompt"])

def get_cli_fname(lon, lat, scenario=0):
    """Get the climate file name for the given lon, lat, and scenario"""
    # The trouble here is relying on rounding is problematic, so we just
    # truncate
    lon = round(lon, 2)
    lat = round(lat, 2)
    return "/i/%s/cli/%03ix%03i/%06.2fx%06.2f.cli" % (
        scenario,
        0 - lon,
        lat,
        0 - lon,
        lat,
    )



def add(x: int, y: int):
    """Add two numbers x and y
    >>> add(2, 3)
    5
    >>> add(5, 7)
    12
    """
    return x + y


def remove_single_characters(text):
    """
    Remove any remaining single-character words
    :text: string
    :return: string
    """
    return ' '.join([word for word in text.split() if len(word) > 1])

def get_thread_id_from_suggestion_id(suggestion_id):
    """Gets the thread_id from the suggestion_id.

    Args:
        suggestion_id: str. The ID of the suggestion.

    Returns:
        str. The thread ID linked to the suggestion.
    """
    return suggestion_id

In [36]:
print(build_prompt(400, 53, "second half")["prompt"])

def clean_fields(fields):
    """
    Clean and return a ``fields`` list of Deb822Field.
    """
    for hf in (fields or []):
        hf.rstrip()
    return fields

def Compact2D(m):
    """
    Decodes the 64 bit morton code into a 32 bit number in the 2D space using
    a divide and conquer approach for separating the bits. 
    1 bit is not used because the integers are not unsigned
    
    Args:
        n (int): a 64 bit morton code
        
    Returns:
        int: a dimension in 2D space
        
    Raises:
        Exception: ERROR: Morton code is always positive
    """
    if m < 0:
        raise Exception("""ERROR: Morton code is always positive""")
    m &= 0x5555555555555555
    m = (m ^ (m >> 1))  & 0x3333333333333333
    m = (m ^ (m >> 2))  & 0x0f0f0f0f0f0f0f0f
    m = (m ^ (m >> 4))  & 0x00ff00ff00ff00ff
    m = (m ^ (m >> 8))  & 0x0000ffff0000ffff
    m = (m ^ (m >> 16)) & 0x00000000ffffffff
    return m



def add(x: int, y: int):
    """Add two numbers x and y
    

## The Benchmark

There are a number of trivial problems in HumanEval, such as #53 shown above.
We want a subset of problems that have a range of difficulties. The following
ten problems have varying difficulty in several programming languages
and were picked by Francesca Lucchetti for MultiPL-T.

- HumanEval_100_make_a_pile
- HumanEval_13_greatest_common_divisor
- HumanEval_152_compare
- HumanEval_157_right_angle_triangle
- HumanEval_27_flip_case
- HumanEval_40_triples_sum_to_zero
- HumanEval_55_fib
- HumanEval_66_digitSum
- HumanEval_72_will_it_fly
- HumanEval_74_total_match
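
The numeric part of each name above can be turned into a list of problem indices. The sketch below assumes that this number is the problem's position in the HumanEval test split; the helper name `humaneval_index` is made up for the example.

```python
from typing import List

PROBLEM_NAMES: List[str] = [
    "HumanEval_100_make_a_pile",
    "HumanEval_13_greatest_common_divisor",
    "HumanEval_152_compare",
    "HumanEval_157_right_angle_triangle",
    "HumanEval_27_flip_case",
    "HumanEval_40_triples_sum_to_zero",
    "HumanEval_55_fib",
    "HumanEval_66_digitSum",
    "HumanEval_72_will_it_fly",
    "HumanEval_74_total_match",
]


def humaneval_index(name: str) -> int:
    """Extract the number between the first and second underscore."""
    return int(name.split("_")[1])


indices = [humaneval_index(name) for name in PROBLEM_NAMES]
print(indices)  # [100, 13, 152, 157, 27, 40, 55, 66, 72, 74]
```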

In [74]:
APPROXIMATE_TOKEN_COUNTS = [0, 8_000, 64_000, 128_000]
HUMANEVAL_PROBLEM_INDICES = [100, 13, 152, 157, 27, 40, 55, 66, 72, 74]
INSERT_WHERES = ["first half", "second half"]

benchmark = datasets.Dataset.from_list([
    build_prompt(*x)
    for x in itertools.product(
        APPROXIMATE_TOKEN_COUNTS, HUMANEVAL_PROBLEM_INDICES, INSERT_WHERES)
])
benchmark

Dataset({
    features: ['prompt', 'target_function', 'humaneval_task_id', 'task_id', 'approx_token_count', 'target_function_name'],
    num_rows: 80
})

In [75]:
benchmark.to_json("benchmark.jsonl", lines=True)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  9.66ba/s]


14995862
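
As a rough sanity check on the numbers above, the size of the benchmark grid and the total distractor budget can be recomputed without loading any data. This sketch only reuses the constants from the notebook. Since `get_distractors` stops once the character budget is reached, the total is a lower bound on distractor characters. Reading the `14995862` above as the number of bytes written by `to_json` is an assumption.

```python
import itertools

APPROXIMATE_TOKEN_COUNTS = [0, 8_000, 64_000, 128_000]
NUM_PROBLEMS = 10  # the ten HumanEval problems listed above
INSERT_WHERES = ["first half", "second half"]
CHARS_PER_TOKEN = 3.5

# One row per (token count, problem, insert position) combination.
grid = list(itertools.product(
    APPROXIMATE_TOKEN_COUNTS, range(NUM_PROBLEMS), INSERT_WHERES))
num_rows = len(grid)

# Minimum number of distractor characters across all prompts.
min_distractor_chars = sum(
    int(tokens * CHARS_PER_TOKEN) for tokens, _, _ in grid)

print(num_rows)              # 80
print(min_distractor_chars)  # 14000000
```

The 14,000,000 character lower bound is in line with a file of roughly 15 MB, given that each row also stores the target function, metadata, and JSON escaping.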

In [76]:
benchmark[0]

{'prompt': '\ndef make_a_pile(n):\n    """\n    Given a positive integer n, you have to make a pile of n levels of stones.\n    The first level has n stones.\n    The number of stones in the next level is:\n        - the next odd number if n is odd.\n        - the next even number if n is even.\n    Return the number of stones in each level in a list, where element at index\n    i represents the number of stones in the level (i+1).\n\n    Examples:\n    >>> make_a_pile(3)\n    [3, 5, 7]\n    """\n    return [n + 2*i for i in range(n)]\n',
 'target_function': '\ndef make_a_pile(n):\n    """\n    Given a positive integer n, you have to make a pile of n levels of stones.\n    The first level has n stones.\n    The number of stones in the next level is:\n        - the next odd number if n is odd.\n        - the next even number if n is even.\n    Return the number of stones in each level in a list, where element at index\n    i represents the number of stones in the level (i+1).\n\n    Exa