### SEED GATHERING GET CONTENT

In [1]:
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
# from tree_sitter import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
#import os
import boto3
import smart_open
#from datasets import load_dataset,Dataset
from botocore import UNSIGNED
from botocore.config import Config

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
        content = fin.read().decode(src_encoding)
    
    return content

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query("""
(
    (function_definition
      name: (identifier)
      body: (block .
        (expression_statement
            (string
                (string_start) @docstring.start
                (string_content)
                (string_end) @docstring.end)))) @function.def
    (#eq? @docstring.start "\\\"\\\"\\\"")
    (#eq? @docstring.end "\\\"\\\"\\\"")
)
""")


def get_fns_with_docstrings(src, tree):
    captures = TOPLEVEL_DOCSTRING_QUERY.captures(tree.root_node)
    res = []
    for capture in captures:
        node, ty = capture
        if ty != "function.def":
            continue
        # if the starting col is not 0, then it's not a top-level fn
        _, col = node.start_point
        if col != 0:
            continue
        res.append(node_to_string(src, node))
    return res


def parse_ex(parser, ex):
    #ex = ex["content"]
    ex = download_contents(ex["blob_id"], ex["src_encoding"])
    try:
        buf = bytes(ex, "utf8")
        tree = parser.parse(buf)
        return get_fns_with_docstrings(buf, tree)
    except:
        return []


# if one parser segfaults, we can just make a new one and other parsers will still be fine
# WE LOVE TREE SITTER!
PARSERS = None


def process_chunk(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_funs = set()
    for ex in chunk:
        chunk_new_funs.update(parse_ex(parser, ex))
    return chunk_new_funs


def main(args):
    global PARSERS
    # ds = datasets.load_dataset(
    #     args.dataset,
    #     data_dir=args.data_dir,
    #     split="train",
    # )
    ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Python", cache_dir=f"../nk569/stack", streaming=False, split="train")
    funs = set()
    PARSERS = [make_parser() for _ in range(args.num_workers)]
    total_len = len(ds)
    CHUNK_SIZE = 1000 * args.num_workers

    print(f"Total length: {total_len}")
    print(f"Chunk size: {CHUNK_SIZE}")

    chunk = []
    p = Pool(args.num_workers)
    for i, ex in enumerate(ds):
        if i % (total_len // 100) == 0:
            print(f"{i}/{total_len}")
        try:
            chunk.append(ex)
            if len(chunk) == CHUNK_SIZE or i == total_len - 1:
                print(f"Processing chunk {i // CHUNK_SIZE}")
                # divide the chunk into NUM_WORKERS chunks
                subchunk_size = len(chunk) // args.num_workers
                subchunks = [chunk[i:i + subchunk_size]
                             for i in range(0, len(chunk), subchunk_size)]
                new_funs_iter = p.imap(
                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
                print("Getting new functions")
                len_before = len(funs)
                while True:
                    try:
                        def timeout_handler(_, __):
                            raise KeyboardInterrupt  # it's fineeeeeee
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(60)
                        funs.update(next(new_funs_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        print("Keyboard interrupt. Terminating pool")
                        p.terminate()
                        p = Pool(args.num_workers)
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        print(e)

                signal.alarm(0)

                PARSERS = [make_parser() for _ in range(args.num_workers)]

                print(
                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

                chunk = []
        except Exception as e:
            print(e)
            chunk = []

        if i == total_len - 1:
            break

    p.close()

    new_ds_dict = {
        "content": list(funs),
        "id": list(range(len(funs)))
    }

    new_ds = datasets.Dataset.from_dict(new_ds_dict)
    #new_ds.push_to_hub(args.push, private=True)

In [13]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", type=int, default=os.cpu_count())
parser.add_argument("--dataset", type=str,
                    default="bigcode/the-stack-dedup")
parser.add_argument("--data_dir", type=str, default="data/python")
parser.add_argument("--push", type=str, default="data/python")
args = parser.parse_args([])
main(args)

Total length: 47272886
Chunk size: 20000
0/47272886
Processing chunk 0
Getting new functions
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
module 'signal' has no attribute 'SIGALRM'
modu

KeyboardInterrupt: 

In [8]:
ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Python", cache_dir=f"../nk569/stack", streaming=False, split="train")

Generating train split: 100%|██████████| 47272886/47272886 [03:15<00:00, 241748.34 examples/s]


In [9]:
funs = set()
NUMWORKERS = os.cpu_count()
PARSERS = [make_parser() for _ in range(NUMWORKERS)]
total_len = len(ds)
CHUNK_SIZE = 1000 * NUMWORKERS

print(f"Total length: {total_len}")
print(f"Chunk size: {CHUNK_SIZE}")

chunk = []
p = Pool(NUMWORKERS)

Total length: 47272886
Chunk size: 20000


In [None]:
for i, ex in enumerate(iter(ds)):
    if i % (total_len // 100) == 0:
        print(f"{i}/{total_len}")
    try:
        chunk.append(ex)
        if len(chunk) == CHUNK_SIZE or i == total_len - 1:
            print(f"Processing chunk {i // CHUNK_SIZE}")
            # divide the chunk into NUM_WORKERS chunks
            subchunk_size = len(chunk) // NUMWORKERS
            subchunks = [chunk[i:i + subchunk_size]
                         for i in range(0, len(chunk), subchunk_size)]
            new_funs_iter = p.imap(
                process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
            print("Getting new functions")
            len_before = len(funs)
            while True:
                try:
                    def timeout_handler(_, __):
                        raise KeyboardInterrupt  # it's fineeeeeee
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(60)
                    funs.update(next(new_funs_iter))
                    signal.alarm(0)
                except KeyboardInterrupt:
                    signal.alarm(0)
                    print("Keyboard interrupt. Terminating pool")
                    p.terminate()
                    p = Pool(NUMWORKERS)
                    break
                except StopIteration:
                    break
                except Exception as e:
                    print(e)

            signal.alarm(0)

            PARSERS = [make_parser() for _ in range(NUMWORKERS)]

            print(
                f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

            chunk = []
    except Exception as e:
        print(e)
        chunk = []

    if i == total_len - 1:
        break


p.close()
new_ds_dict = {
    "content": list(funs),
    "id": list(range(len(funs)))
}

new_ds = datasets.Dataset.from_dict(new_ds_dict)

0/47272886
Processing chunk 0
Getting new functions
Done processing chunk 0. Got 32645 new functions
Processing chunk 1
Getting new functions
Done processing chunk 1. Got 32222 new functions
Processing chunk 2
Getting new functions
Done processing chunk 2. Got 32158 new functions
472728/47272886
Processing chunk 3
Getting new functions
Done processing chunk 3. Got 31775 new functions
Processing chunk 4
Getting new functions
Done processing chunk 4. Got 31002 new functions
Processing chunk 5
Getting new functions
Done processing chunk 5. Got 31746 new functions
Processing chunk 6
Getting new functions
Done processing chunk 6. Got 30393 new functions
945456/47272886
Processing chunk 7
Getting new functions
Done processing chunk 7. Got 33611 new functions
Processing chunk 8
Getting new functions
Done processing chunk 8. Got 32778 new functions
Processing chunk 9
Getting new functions
Done processing chunk 9. Got 33583 new functions
Processing chunk 10
Getting new functions
unable to acces

In [35]:
ds = new_ds

In [36]:
ds

Dataset({
    features: ['content', 'id'],
    num_rows: 397062
})

### SEED GATHERING HIGH-QUALITY SUBSET

In [63]:
import subprocess
import tempfile
import signal
import hashlib
import os
import argparse
from typing import List, Dict
from tqdm import tqdm
from tree_sitter_parser import LANGUAGE, global_parser

RETURN_QUERY = LANGUAGE.query("""
(return_statement) @return
""")

def does_have_return(src):
    tree = global_parser.parse(bytes(src, "utf8"))
    root = tree.root_node
    captures = RETURN_QUERY.captures(root)
    for node, _ in captures:
        # if it doesn't have an argument, it's not a return with a value
        if len(node.children) <= 1:  # includes "return" itself
            continue
        else:
            return True
    return False

# runs mypy in the given directory, returns stdout
# then, it logs the number of errors for each file
def run_mypy(d):
    try:
        outs = subprocess.run(
            ["mypy", "."],
            cwd=d,
            capture_output=True,
            timeout=120,
            text=True,
        ).stdout
    except Exception as e:
        print(e)
        return None

    filemap = {}
    lines = outs.split("\n")
    for line in lines:
        if line.strip():
            parts = line.split(":")
            if len(parts) >= 2:
                file = parts[0].split("/")[-1]
                if file not in filemap:
                    filemap[file] = 0
                if "error:" in line:
                    filemap[file] += 1

    return filemap

def typecheck_batch(files: List[str]) -> Dict[str, str]:
    # Create a temporary directory using the tempfile module
    filemap: Dict[str, str] = {}
    with tempfile.TemporaryDirectory() as tempdir:
        for contents in files:
            hash_object = hashlib.sha1(bytes(contents, "utf8"))
            hex_dig = hash_object.hexdigest()
            filemap[hex_dig] = contents
            name = os.path.join(tempdir, hex_dig + ".py")
            with open(name, "w") as f:
                f.write(contents)

        # Run mypy in the temporary directory
        typecheck_map = run_mypy(tempdir)
        print(typecheck_map)

        if typecheck_map is None:
            return {}

        for contents, errors in typecheck_map.items():
            no_py = contents.replace(".py", "")
            if errors == 0:
                continue
            if no_py in filemap:
                del filemap[no_py]

        print(f"Pass rate: {len(filemap)}/{len(files)}")
        return filemap

def infer_imports(code: str) -> str:
    import autoimport
    try:
        def handler(signum, frame):
            raise Exception("Timeout")
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(10)
        inferred = autoimport.fix_code(code)
        signal.alarm(0)
        return inferred
    except Exception as e:
        signal.alarm(0)
        print(f"Error while inferring imports: {e}")
        return code

In [37]:
print("Filtering to only functions with return statements")
ds = ds.filter(lambda ex: does_have_return(
    ex["content"]), num_proc=os.cpu_count())




Filtering to only functions with return statements


Filter (num_proc=128):   0%|          | 0/397062 [00:00<?, ? examples/s]

In [42]:
ds

Dataset({
    features: ['content', 'id'],
    num_rows: 279231
})

In [56]:
# if args.infer_imports:
#     print("Inferring imports for functions")
#     ds = ds.map(lambda ex: {"content": infer_imports(
#         ex["content"])}, num_proc=os.cpu_count())

batch = []
max_i = len(ds) - 1

new_ds = {
    "content": [],
    "sha1": [],
    "id": [],
}

e_id = 0

for i, ex in enumerate(tqdm(ds, total=len(ds))):
    try:
        code = ex["content"]

        batch.append(code)

        if len(batch) == 250 or i == max_i:
            filemap = typecheck_batch(batch)
            for sha1, contents in filemap.items():
                new_ds["content"].append(contents)
                new_ds["sha1"].append(sha1)
                new_ds["id"].append(e_id)
                e_id += 1
            batch = []
            
    except Exception as e:
        print(f"There was an error: {e}")
        continue

new_ds_hf = datasets.Dataset.from_dict(new_ds)

In [61]:
print(new_ds_hf['content'][0])

def gibberish(*args):
    """Concatenate strings in *args together."""
    
    # Initialize an empty string: hodgepodge
    hodgepodge = ''

    # Concatenate the strings in args
    for word in args:
        hodgepodge += word

    # Return hodgepodge
    return(hodgepodge)


In [64]:
save_dir = "../datasets/seed2"

In [65]:
new_ds_hf.save_to_disk(save_dir)

Saving the dataset (0/1 shards):   0%|          | 0/278081 [00:00<?, ? examples/s]

### SEED GATHERING FILTER DATASET

In [66]:
import datasets
import os
from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser
import benchmark_data
from tqdm import tqdm
import torch
import argparse
from vllm import LLM, SamplingParams
import random

In [67]:
FN_BLOCK_QUERY = LANGUAGE.query("""
(function_definition
  body: (block) @fn-block)
""")


def template_few_shot(code, answer, rationale):
    doc, code = py_extract_docstring(code)
    assert answer == "No" or answer == "Yes"
    prompt = f"""<issue_start>username_0: I have a function in Python and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```py
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""
    return prompt


FEW_SHOTS = [
    (
        '''def simple_scan_network():
    """
    Do a simple network scan, which only works if your network configuration
    is 192.168.1.x
    """
    base_ip = "192.168.1."
    addresses = ['127.0.0.1']

    for index in range(1, 255):
        addresses.extend([base_ip + str(index)])

    return addresses''',
        "No",
        "The simple_scan_network function you have provided seems to generate addresses that then would be used for a network scan, but does not actually perform it, unlike the function claims.",
    ),
    (
        '''import pandas


def coerce_integer(df):
    """
    Loop through the columns of a df, if it is numeric,
    convert it to integer and fill nans with zeros.
    This is somewhat heavy-handed in an attempt to force
    Esri to recognize sparse columns as integers.
    """
    # Numeric columns to not coerce to integer
    EXCEPT = ["latitude", "longitude", "zipCode"]

    def numeric_column_to_int(series):
        return (
            series.fillna(0).astype(int)
            if pandas.api.types.is_numeric_dtype(series) and series.name not in EXCEPT
            else series
        )

    return df.transform(numeric_column_to_int, axis=0)''',
        "Yes",
        "The docstring does seem to match the implementation! The function loops through the columns of a df and coerces it as explained.",
    ),
    ('''def __trans_df_into_dict(data):
    """Converte DataFrame to dictionary.

    Args:
        data (pandas.DataFrame): DataFrame.

    Returns:
        dict: Name dictionary.
    """
    data["en_name"] = data["en_name"].str.upper()
    data["en_name_f"] = data["en_name"].str.split(" ", expand=True)[0]
    data["en_name_l"] = data["en_name"].str.split(" ", expand=True)[1]
    data["jp_name_f"] = data["jp_name"].str.split("・", expand=True)[0]
    data["jp_name_l"] = data["jp_name"].str.split("・", expand=True)[1]
    fullname_dict = dict(zip(data["en_name"], data["jp_name"]))
    fname_dict = dict(zip(data["en_name_f"], data["jp_name_f"]))
    lname_dict = dict(zip(data["en_name_l"], data["jp_name_l"]))
    return fullname_dict, fname_dict, lname_dict''',
     "No",
     "The function__trans_df_into_dict  does indeed convert a dataframe into a dictionary, however, it converts various columns that were not described in the docstring.\nFor instance, nowhere in the docstring it mentions handling japanese characters or the name of the column.",
     ),
    (
        '''def inchesToMeters(inches):
    """Convert inches to meters."""
    return inches * 0.0254''',
        "Yes",
        "inchesToMeters is a very simple function, the doccstring explains concisely its purpose, which is of converting inches to meters.",
    ),
    ('''def square_crop(im, target_size=None):
  """ Crop image to `target_size`. If that's None the image is squared
  to the smallest size
  """

  w = im.size[0]
  h = im.size[1]

  target_size = target_size if target_size else min(w, h)

  dx = (w - target_size) / 2
  dy = (h - target_size) / 2

  return im.crop((dx, dy, dx + target_size, dy + target_size))''',
     "Yes",
     "Following the standard description for docstrings for functions and methods, the square_crop function description tells exactly what the function does."
     ),
    ('''def _setup_motifs_files(args):
    """convenience fn, make sure setup is same across
    multiplicity/orientation/spacing workflows
    """
    motifs_files = {}
    motifs_files["early"] = "{}/{}/ggr.scanmotifs.h5".format(
        args.inputs["inference"][args.cluster]["scanmotifs_dir"],
        args.inputs["inference"][args.cluster]["scanmotifs_early_dir"])
    motifs_files["mid"] = "{}/{}/ggr.scanmotifs.h5".format(
        args.inputs["inference"][args.cluster]["scanmotifs_dir"],
        args.inputs["inference"][args.cluster]["scanmotifs_mid_dir"])
    motifs_files["late"] = "{}/{}/ggr.scanmotifs.h5".format(
        args.inputs["inference"][args.cluster]["scanmotifs_dir"],
        args.inputs["inference"][args.cluster]["scanmotifs_late_dir"])

    return motifs_files''',
     "No",
     "The docstring for _setup_motifs_files just says this is a convenience function. There is definitely not enough information to re-implement this function from the docstring alone.",
     ),
    ('''def trip(u, v):
    """
    Returns the scalar triple product of vectors u and v and z axis.
    The convention is z dot (u cross v). Dotting with the z axis simplifies
    it to the z component of the u cross v
    The product is:
        positive if v is to the left of u, that is,
          the shortest right hand rotation from u to v is ccw
        negative if v is to the right of u, that is,
          the shortest right hand rotation from u to v is cw
        zero if v is colinear with u
    Essentially trip is the z component of the cross product of u x v
    """
    return (u[0] * v[1] - u[1] * v[0])''',
     "Yes",
     "The docstring for the trip function is very detailed and describes the function's purpose and the mathematical formula used to calculate the scalar triple product.",
     )
]


def prompt_fmt(code):
    doc, code = py_extract_docstring(code)
    random.shuffle(FEW_SHOTS)
    buf = ""
    for few in FEW_SHOTS:
        buf += template_few_shot(*few)
    buf += f"""<issue_start>username_0: I have a function in Python and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```py
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is:"""
    return buf


def auto_dtype():
    if torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "auto"


def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks


In [71]:
dataset = new_ds_hf

In [72]:
print(f"Loaded {len(dataset)} examples. Running pre-filtering...")

BAD_WORDS = ["todo", "fixme", "bug"]
BAD_IMPORTS = ["argparse", "os", "subprocess", "sys", "setuptools",
               "distutils", "matplotlib", "seaborn"]
BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + \
    [f"from {b}" for b in BAD_IMPORTS]
BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS

bench_filter = benchmark_data.filter_out()
all_bench = bench_filter["human_eval_docstrings"] + \
    bench_filter["human_eval_solutions"] + \
    bench_filter["mbpp_docstrings"] + \
    bench_filter["mbpp_solutions"]

Loaded 278081 examples. Running pre-filtering...


Downloading readme:   0%|          | 0.00/9.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/60.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.72k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/257 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating prompt split:   0%|          | 0/7 [00:00<?, ? examples/s]

Downloading readme:   0%|          | 0.00/6.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/83.9k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/164 [00:00<?, ? examples/s]

num strings from mbpp_docstrings: 120
num strings from mbpp_solutions: 120
num strings from human_eval_docstrings: 164
num strings from human_eval_solutions: 161


In [74]:
def pre_filtering(ex):
    code = ex["content"]
    code_bytes = code.encode('utf-8')

    # filter out bad substrings
    lower = code.lower()
    for word in BAD_SUBSTRINGS:
        if word in lower:
            return False

    for b in all_bench:
        if b in code:  # contaminated sample!
            return False

    # too many lines of code -- say 150
    lines = code.split("\n")
    if len(lines) > 150:
        return False

    # filter functions which don't have an argument
    # 1. find first def statement in lines
    # 2. check if contains ():
    for line in lines:
        if line.startswith("def ") and "():" in line:
            return False

    # filter out functions with no return statement
    parser = make_parser()
    if not does_have_return(code, parser=parser):
        return False

    try:
        tree = global_parser.parse(code_bytes)
        block, _ = FN_BLOCK_QUERY.captures(tree.root_node)[0]

        # get the docstring, filter if not a docstring
        exp = block.children[0]
        if not exp.type == 'expression_statement' and not exp.children[0].type == 'string':
            return False

        docstring = exp.children[0]
        docstring_text = docstring.text.decode('utf-8')
        if not docstring_text.startswith('"""') and not docstring_text.endswith('"""'):
            return False
    except Exception as e:
        print(f"Error in filtering: {e}")
        return False

    return True  # all good!


threads = os.cpu_count() - 1  # type: ignore
dataset = dataset.filter(pre_filtering, num_proc=threads)

Filter (num_proc=127):   0%|          | 0/278081 [00:00<?, ? examples/s]

In [75]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 236926
})

In [76]:
model = LLM(f"../../../StarCoder", dtype=auto_dtype(),
            gpu_memory_utilization=0.95, tensor_parallel_size=1)


INFO 10-11 13:21:28 config.py:1430] Downcasting torch.float32 to torch.bfloat16.
INFO 10-11 13:21:28 llm_engine.py:176] Initializing an LLM engine (v0.5.3) with config: model='../../../StarCoder', speculative_config=None, tokenizer='../../../StarCoder', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=../../../StarCoder, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 10-11 13:21:30 selector.py:170] Cannot use FlashAttention-2 backend d

Loading safetensors checkpoint shards:   0% Completed | 0/14 [00:00<?, ?it/s]


INFO 10-11 13:27:56 model_runner.py:692] Loading model weights took 29.7278 GB
INFO 10-11 13:28:08 gpu_executor.py:102] # GPU blocks: 35262, # CPU blocks: 3276
INFO 10-11 13:28:10 model_runner.py:980] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 10-11 13:28:10 model_runner.py:984] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 10-11 13:28:32 model_runner.py:1181] Graph capturing finished in 23 secs.


In [77]:
tokenizer = model.get_tokenizer()

In [78]:
print(f"Now running stage 3 filtering on {len(dataset)} examples...")

Now running stage 3 filtering on 236926 examples...


In [79]:
def unindent(s):
    lines = s.splitlines()
    non_blank_lines = [line for line in lines if line.strip()]
    min_indent = min(len(line) - len(line.lstrip())
                     for line in non_blank_lines) if non_blank_lines else 0
    unindented_lines = [line[min_indent:] if len(
        line) >= min_indent else line for line in lines]
    return '\n'.join(unindented_lines)


def py_extract_docstring(code):
    first_doc = code.find('"""')
    assert first_doc != -1
    first_doc = first_doc + 3
    second_doc = code[first_doc+1:].find('"""')
    assert second_doc != -1
    second_doc = second_doc + first_doc + 1
    doc = code[first_doc:second_doc]
    doc = unindent(doc).strip()
    code = code[:first_doc-3] + code[second_doc+3:]
    return doc, code



In [80]:
dummy = 'def dummy(): \n    """\n    """\n pass'
dummy_prompt = prompt_fmt(dummy)
few_shot_toks = len(tokenizer.encode(
    dummy_prompt)) - len(tokenizer.encode(dummy))
print(f"Few-shot prompt has {few_shot_toks} tokens")

Few-shot prompt has 2573 tokens


In [None]:
prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"):
    code = ex["content"]
    toks = len(tokenizer.encode(code)) + few_shot_toks
    if toks > 16380:
        print(f"Skipping example with {toks} tokens")
        # to skip, just add dummy prompt
        prompts.append(dummy_prompt)
        continue
    p = prompt_fmt(code)
    prompts.append(p)

responses = []
for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
    outs = model.generate(chunk, SamplingParams(
        temperature=0.0, stop="\n", max_tokens=5))
    contents = [o.outputs[0].text for o in outs]
    for c in contents:
        yes_count = c.lower().count("yes")
        no_count = c.lower().count("no")
        if yes_count > no_count:
            responses.append(True)
        elif yes_count < no_count:
            responses.append(False)
        else:
            # default to No
            responses.append(False)



Generating prompts:  15%|██████████████████████▉                                                                                                                                  | 35479/236926 [00:21<02:06, 1588.67it/s]

Skipping example with 29237 tokens


Generating prompts:  35%|█████████████████████████████████████████████████████▊                                                                                                   | 83335/236926 [00:49<01:36, 1595.44it/s]

Skipping example with 20286 tokens


Generating prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 236926/236926 [02:22<00:00, 1666.63it/s]
Generating responses:   0%|                                                                                                                                                                        | 0/463 [00:00<?, ?it/s]
Processed prompts:   0%|                                                                                                                       | 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][A
Processed prompts:   0%|▏                                                                                                          | 1/512 [01:10<10:04:01, 70.92s/it, est. speed input: 37.29 toks/s, output: 0.03 toks/s][A
Processed prompts:  39%|████████████████████████████████████████▉                                                 

In [90]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 236926
})

In [91]:
subset = dataset.select(range(75000))

In [92]:
subset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 75000
})

In [93]:
new_ds = subset.filter(  # horrible hack!
    lambda ex, i: responses[i] and "def dummy()" not in ex["content"], with_indices=True)
print(f"Filtered {len(dataset) - len(new_ds)} examples")

Filter:   0%|          | 0/75000 [00:00<?, ? examples/s]

Filtered 233319 examples


In [97]:
new_ds.save_to_disk("../datasets/seed3")

Saving the dataset (0/1 shards):   0%|          | 0/3607 [00:00<?, ? examples/s]

In [98]:
new_ds

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 3607
})