### SEED GATHERING GET CONTENT

In [1]:
#++++++++++++++++++++++++++++ Run this first time if you haven't installed from requirements.txt file/cloned the repo+++++++++++++++++++++
# !pip install tree-sitter==0.20.4
# !git clone https://github.com/tree-sitter/tree-sitter-cpp

In [2]:
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
#import os
import boto3
import smart_open
#from datasets import load_dataset,Dataset
from botocore import UNSIGNED
from botocore.config import Config

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
        content = fin.read().decode(src_encoding)
    
    return content

In [3]:
TOPLEVEL_DOC_COMMENT_QUERY = LANGUAGE.query("""
(
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @fn-name
    )
    body: (compound_statement
      (comment) @doc.comment
    )
  ) @function.def
)
""")

'''
def get_fns_with_docstrings(src, tree):
    captures = TOPLEVEL_DOC_COMMENT_QUERY.captures(tree.root_node)
    res = []
    for capture in captures:
        node, ty = capture
        if ty != "function.def":
            continue
        # if the starting col is not 0, then it's not a top-level fn
        _, col = node.start_point
        if col != 0:
            continue
        res.append(node_to_string(src, node))
    return res
'''

def get_fns_with_docstrings(src, tree):
    captures = TOPLEVEL_DOC_COMMENT_QUERY.captures(tree.root_node)
    res = []
    current = {"function_node": None, "name": None, "doc": None}

    for node, capture_name in captures:
        if capture_name == "fn-name":
            current["name"] = node_to_string(src, node)

        elif capture_name == "doc.comment":
            current["doc"] = node_to_string(src, node)

        elif capture_name == "function.def":
            current["function_node"] = node

            # Build the result once we have everything
            if current["name"] and current["doc"]:
                full_func_text = node_to_string(src, current["function_node"])
                res.append({
                    "function_name": current["name"],
                    "docstring": current["doc"],
                    "code": full_func_text
                })

            # Reset for next
            current = {"function_node": None, "name": None, "doc": None}

    return res

def parse_ex(parser, ex):
    #ex = ex["content"]
    ex = download_contents(ex["blob_id"], ex["src_encoding"])
    try:
        buf = bytes(ex, "utf8")
        tree = parser.parse(buf)
        return get_fns_with_docstrings(buf, tree)
    except:
        return []


# if one parser segfaults, we can just make a new one and other parsers will still be fine
# WE LOVE TREE SITTER!
PARSERS = None


def process_chunk(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_funs = set()
    for ex in chunk:
        chunk_new_funs.update(parse_ex(parser, ex))
    return chunk_new_funs


def main(args):
    global PARSERS
    ds = datasets.load_dataset(
        args.dataset,
        data_dir=args.data_dir,
        split="train",
    )
    funs = set()
    PARSERS = [make_parser() for _ in range(args.num_workers)]
    total_len = len(ds)
    CHUNK_SIZE = 1000 * args.num_workers

    print(f"Total length: {total_len}")
    print(f"Chunk size: {CHUNK_SIZE}")

    chunk = []
    p = Pool(args.num_workers)
    for i, ex in enumerate(ds):
        if i % (total_len // 100) == 0:
            print(f"{i}/{total_len}")
        try:
            chunk.append(ex)
            if len(chunk) == CHUNK_SIZE or i == total_len - 1:
                print(f"Processing chunk {i // CHUNK_SIZE}")
                # divide the chunk into NUM_WORKERS chunks
                subchunk_size = len(chunk) // args.num_workers
                subchunks = [chunk[i:i + subchunk_size]
                             for i in range(0, len(chunk), subchunk_size)]
                new_funs_iter = p.imap(
                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
                print("Getting new functions")
                len_before = len(funs)
                while True:
                    try:
                        def timeout_handler(_, __):
                            raise KeyboardInterrupt  # it's fineeeeeee
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(60)
                        funs.update(next(new_funs_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        print("Keyboard interrupt. Terminating pool")
                        p.terminate()
                        p = Pool(args.num_workers)
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        print(e)

                signal.alarm(0)

                PARSERS = [make_parser() for _ in range(args.num_workers)]

                print(
                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

                chunk = []
        except Exception as e:
            print(e)
            chunk = []

        if i == total_len - 1:
            break

    p.close()

    new_ds_dict = {
        "content": list(funs),
        "id": list(range(len(funs)))
    }

    new_ds = datasets.Dataset.from_dict(new_ds_dict)
    #new_ds.push_to_hub(args.push, private=True)




In [4]:
# NUMWORKERS = os.cpu_count()
NUMWORKERS = 2

In [5]:
ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "C++", cache_dir=f"./cache/stack", streaming=False, split="train[:1000]")

Resolving data files:   0%|          | 0/757 [00:00<?, ?it/s]

In [6]:
# from itertools import islice

# small_subset = islice(ds, 50)

# # Convert to list if you want to materialize it (use with caution, as this loads into memory)
# ds = list(small_subset)

In [7]:
# Setup a single Parser
funs = set()

parser = make_parser()

for i in range(5):
# Take one example manually
    ex = ds[i]  # First example (directly)

    # Download content if needed
    content = download_contents(ex["blob_id"], ex["src_encoding"])

    # Parse
    src = bytes(content, "utf8")
    tree = parser.parse(src)

    # Extract functions
    functions = get_fns_with_docstrings(src, tree)
    #funs.update(functions)
    # Print results
    print(f"Extracted functions in {i} and doc-comments:")
    for fn in functions:
        print(fn)

Extracted functions in 0 and doc-comments:
Extracted functions in 1 and doc-comments:
Extracted functions in 2 and doc-comments:
Extracted functions in 3 and doc-comments:
Extracted functions in 4 and doc-comments:


In [8]:
# Setup
def process_chunk(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_funs = set()
    
    for ex in chunk:
        functions = parse_ex(parser, ex)
        for fn in functions:
            chunk_new_funs.add(str(fn))  # <=== Fix here

    return chunk_new_funs




NUMWORKERS = 1
CHUNK_SIZE = 10  # Adjust if needed

PARSERS = [make_parser() for _ in range(NUMWORKERS)]

funs = set()
chunk = []

total_len = len(ds)

print(f"Total dataset size: {total_len}")

# Loop over dataset
for i, ex in enumerate(iter(ds)):
    if i % (max(total_len // 100, 1)) == 0:
        print(f"{i}/{total_len}")

    chunk.append(ex)

    if len(chunk) >= CHUNK_SIZE or i == total_len - 1:
        print(f"\nProcessing chunk {i // CHUNK_SIZE}...")

        # Split the chunk into subchunks
        subchunk_size = max(1, len(chunk) // NUMWORKERS)
        subchunks = [chunk[j:j + subchunk_size] for j in range(0, len(chunk), subchunk_size)]

        len_before = len(funs)

        # Sequentially process each subchunk using process_chunk
        for idx, subchunk in enumerate(subchunks):
            chunk_funs = process_chunk((idx, subchunk))
            funs.update(chunk_funs)

        print(f"✅ Done chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions.")

        chunk = []  # Reset chunk
        PARSERS = [make_parser() for _ in range(NUMWORKERS)]  # Rebuild parsers if needed

# Final dataset creation
print(f"\nTotal unique functions collected: {len(funs)}")

new_ds_dict = {
    "content": list(funs),
    "id": list(range(len(funs)))
}

new_ds = datasets.Dataset.from_dict(new_ds_dict)


Total dataset size: 1000
0/1000

Processing chunk 0...
✅ Done chunk 0. Got 14 new functions.
10/1000

Processing chunk 1...
✅ Done chunk 1. Got 0 new functions.
20/1000

Processing chunk 2...
✅ Done chunk 2. Got 1 new functions.
30/1000

Processing chunk 3...
✅ Done chunk 3. Got 0 new functions.
40/1000

Processing chunk 4...
✅ Done chunk 4. Got 4 new functions.
50/1000

Processing chunk 5...
✅ Done chunk 5. Got 1 new functions.
60/1000

Processing chunk 6...
✅ Done chunk 6. Got 0 new functions.
70/1000

Processing chunk 7...
✅ Done chunk 7. Got 10 new functions.
80/1000

Processing chunk 8...
✅ Done chunk 8. Got 3 new functions.
90/1000

Processing chunk 9...
✅ Done chunk 9. Got 0 new functions.
100/1000

Processing chunk 10...
✅ Done chunk 10. Got 0 new functions.
110/1000

Processing chunk 11...
✅ Done chunk 11. Got 0 new functions.
120/1000

Processing chunk 12...
✅ Done chunk 12. Got 0 new functions.
130/1000

Processing chunk 13...
✅ Done chunk 13. Got 0 new functions.
140/1000



In [9]:
new_ds.save_to_disk("./extracted_functions_cpp")
new_ds[7]

Saving the dataset (0/1 shards):   0%|          | 0/210 [00:00<?, ? examples/s]

{'content': "{'function_name': 'systrace_init', 'docstring': '// already initialized', 'code': 'int systrace_should_trace(const char *module)\\n{\\n    // hack this if you want to temporarily omit some traces.\\n    return 1;\\n}'}",
 'id': 7}

In [10]:
from datasets import load_from_disk

ds = load_from_disk("./extracted_functions_cpp")

# for example in ds:
#     print(example)

In [11]:
import os
from datasets import Dataset
from tree_sitter_parser import LANGUAGE, make_parser, does_have_return

# Define return-statement query for C++
RETURN_QUERY = LANGUAGE.query("""
(return_statement) @return
""")

# Use a Tree-sitter parser set to C++
parser = make_parser()

# Filter dataset to only functions with meaningful return values
def filter_cpp_functions_with_return(ds: Dataset) -> Dataset:
    filtered_ds = []
    for i in ds:
        if does_have_return(i["content"], parser):
            filtered_ds.append(i)
    return filtered_ds


In [12]:
print("Filtering to only C++ functions with return statements...")
filtered_ds = filter_cpp_functions_with_return(ds)
print(f"✅ Filtered dataset size: {len(filtered_ds)}")


Filtering to only C++ functions with return statements...
✅ Filtered dataset size: 73


In [13]:
from datasets import Dataset

# Convert the list into a Dataset
filtered_ds = Dataset.from_list(filtered_ds)

# Now you can save it to disk
filtered_ds.save_to_disk("./functions_with_return_cpp")

Saving the dataset (0/1 shards):   0%|          | 0/73 [00:00<?, ? examples/s]

In [14]:
#Code to filter out functions that have valid types is next but I'm not sure we need it for c++ as functions are already typed
#Should check with TA if it is needed

In [15]:
ds = load_from_disk("./functions_with_return_cpp")

In [16]:
import ast

# content = ast.literal_eval(ds[0]["content"])
# print(content['code'], '\n')

for i in range(73):
    try:
        content = ast.literal_eval(ds[i]["content"])
        print(content['function_name'])
        print(content['code'])
    except Exception as e:
        print(f"Error parsing content: {e}")
        continue



FunkData_Temp_PWM
boolean FunkData_clickColor(int RadA, int RadB)
{
  // Send data:
  RF24NetworkHeader header(FunkMasterSwitchcabinet);   // Address where the data is going
  dataOutgoing.header = 2; // Farbrad
  dataOutgoing.val1 = RadA; // RAD A
  dataOutgoing.val2 = RadB; // RAD B
  dataOutgoing.val3 = 0; // NA
  bool ok = network.write(header, &dataOutgoing, sizeof(dataOutgoing)); // Send the data

  if (ok == true)
  {
    Serial.println("Funk DATA Color A!");
    return true;
  }
  else
  {
    return false;
  }
}
Merge
int main()
{
    int a[] = {6, 2, 3, 1, 9, 10, 15, 13, 12, 17}; // creating an array of integers.
    int n;
    n = sizeof(a) / sizeof(a[0]);
    MergeSort(a, n);

    for (int atom : a)
    {
        cout << atom << " ";
    }
    return 0;
}
systrace_init
int systrace_should_trace(const char *module)
{
    // hack this if you want to temporarily omit some traces.
    return 1;
}
CDC_Receive_FS
uint8_t CDC_Transmit_FS(uint8_t* Buf, uint16_t Len)
{
    uint8_t r

In [17]:
#You have to get the code by ast.literal_eval(ds[i]["content"])['code]

#######################################################################
#                     Part 3
########################################################################

In [18]:
import datasets
import os
from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser
#import benchmark_data
from tqdm import tqdm
import torch
import argparse
#from vllm import LLM, SamplingParams
import random

In [19]:
FN_BLOCK_QUERY = LANGUAGE.query("""
(function_definition
  body: (compound_statement) @fn-block)
""")


def template_few_shot(code, answer, rationale):
    doc, code = cpp_extract_docstring(code)
    assert answer == "No" or answer == "Yes"
    prompt = f"""<issue_start>username_0: I have a function in C++ and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```py
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""
    return prompt


FEW_SHOTS = [
    (''' 
    /**
    * Transposes a given 2D matrix of integers.
    * Input: A 2D vector representing the matrix.
    * Output: A new 2D vector with rows and columns swapped.
    */
    std::vector<std::vector<int>> transposeMatrix(const std::vector<std::vector<int>>& matrix) {
        if (matrix.empty()) return {};

        size_t rows = matrix.size();
        size_t cols = matrix[0].size();
        std::vector<std::vector<int>> transposed(cols, std::vector<int>(rows));

        for (size_t i = 0; i < rows; ++i)
            for (size_t j = 0; j < cols; ++j)
                transposed[j][i] = matrix[i][j];

        return transposed;
    }''',
    "Yes",
    "The docstring clearly describes the input, output, and behavior of the function without ambiguity."
    ),
    ('''
    /**
    * Helper function to open input and output file streams.
    */
    void initFileStreams(const std::string& inputFile, const std::string& outputFile,
                        std::ifstream& inStream, std::ofstream& outStream) {
        inStream.open(inputFile);
        outStream.open(outputFile);
        if (!inStream.is_open() || !outStream.is_open()) {
            std::cerr << "Failed to open files.\n";
        }
    }''',
    "No",
    "The docstring only says it's a helper without describing what files are opened, how, or error behavior. It’s vague."
    ),
    (''' 
    /**
    * Reads a text file and counts the frequency of each word (case-insensitive).
    * Returns a map where keys are words and values are frequency counts.
    */
    std::unordered_map<std::string, int> countWordFrequencies(const std::string& filePath) {
        std::unordered_map<std::string, int> wordCount;
        std::ifstream file(filePath);
        std::string word;

        if (!file.is_open()) {
            std::cerr << "Error opening file.\n";
            return wordCount;
        }

        while (file >> word) {
            // Convert to lowercase
            for (char& c : word) c = std::tolower(c);
            // Remove punctuation
            word.erase(std::remove_if(word.begin(), word.end(),
                                    [](char ch) { return std::ispunct(ch); }),
                    word.end());
            if (!word.empty())
                ++wordCount[word];
        }

        return wordCount;
    }''',
    "Yes",
    "The docstring clearly describes the purpose, behavior, and output of the function, including key aspects like case-insensitivity."
        ),
    ('''
     /**
    * Checks whether the given string is a palindrome.
    * Ignores case and non-alphanumeric characters.
    */
    bool isPalindrome(const std::string& s) {
        int left = 0;
        int right = s.length() - 1;

        while (left < right) {
            while (left < right && !std::isalnum(s[left])) ++left;
            while (left < right && !std::isalnum(s[right])) --right;

            if (std::tolower(s[left]) != std::tolower(s[right]))
                return false;

            ++left;
            --right;
        }

        return true;
    }''',
    "Yes",
    "The docstring is concise but accurately specifies the functionality and relevant behavior like ignoring non-alphanumerics and case."
    ),
    ('''
     /**
    * Resizes a raw RGB image buffer by scaling it to a target width and height.
    */
    unsigned char* resizeImage(const unsigned char* inputBuffer, int width, int height,
                            int newWidth, int newHeight) {
        if (!inputBuffer || width <= 0 || height <= 0 || newWidth <= 0 || newHeight <= 0)
            return nullptr;

        unsigned char* output = new unsigned char[newWidth * newHeight * 3];

        for (int y = 0; y < newHeight; ++y) {
            for (int x = 0; x < newWidth; ++x) {
                int srcX = x * width / newWidth;
                int srcY = y * height / newHeight;
                for (int c = 0; c < 3; ++c) {
                    output[(y * newWidth + x) * 3 + c] =
                        inputBuffer[(srcY * width + srcX) * 3 + c];
                }
            }
        }

        return output;
    }
    ''',
    "No",
    "The docstring mentions resizing but omits that this is a naive nearest-neighbor implementation with no memory management or edge handling info."
    )
]

import re

# def cpp_extract_docstring(code):
#     """
#     Extracts the first C++-style docstring (/** ... */) from the given code string.
#     Returns a tuple of (docstring_content, remaining_code).
#     """
#     # Look for /** ... */ using regex
#     match = re.search(r'/\*\*(.*?)\*/', code, re.DOTALL)
#     assert match, "No C++-style docstring found"

#     doc = match.group(1)

#     # Clean up leading * characters and whitespace
#     lines = doc.strip().split('\n')
#     cleaned_lines = []
#     for line in lines:
#         line = line.strip()
#         if line.startswith("*"):
#             line = line[1:].lstrip()
#         cleaned_lines.append(line)

#     cleaned_doc = "\n".join(cleaned_lines).strip()
#     #remaining_code = code[:match.start()] + code[match.end():]

#     return cleaned_doc

def cpp_extract_docstring(code):
    """
    Extracts the first C++-style docstring (/** ... */) from the given code string.
    Returns a tuple of (docstring_content, remaining_code).
    """
    match = re.search(r'/\*\*(.*?)\*/', code, re.DOTALL)
    if not match:
        #logging.warning("No C++-style docstring found.")
        return "No Docstring found", code  # Return empty docstring and original code

    doc = match.group(1)

    # Clean up leading * characters and whitespace
    lines = doc.strip().split('\n')
    cleaned_lines = []
    for line in lines:
        line = line.strip()
        if line.startswith("*"):
            line = line[1:].lstrip()
        cleaned_lines.append(line)

    cleaned_doc = "\n".join(cleaned_lines).strip()
    remaining_code = code[:match.start()] + code[match.end():]

    return cleaned_doc, remaining_code


def prompt_fmt(code):
    doc, code = cpp_extract_docstring(code)
    random.shuffle(FEW_SHOTS)
    buf = ""
    for few in FEW_SHOTS:
        buf += template_few_shot(*few)
    buf += f"""<issue_start>username_0: I have a function in C++ and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```c++
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is:"""
    return buf


def auto_dtype():
    if torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "auto"


def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks


In [20]:
dataset = ds

print(f"Loaded {len(dataset)} examples. Running pre-filtering...")

BAD_COMMENT_KEYWORDS = ["todo", "fixme", "bug", "hack"]
BAD_SUBSTRINGS = [f"// {kw}" for kw in BAD_COMMENT_KEYWORDS]

def cpp_pre_filter(entry):
    try:
        content = entry["content"]
        content_dict = ast.literal_eval(content)
        code = content_dict.get("code", "").strip().lower()

        # 1. Filter out undesirable comments like TODOs or FIXMEs
        BAD_COMMENT_KEYWORDS = ["todo", "fixme", "bug", "hack"]
        if any(f"// {kw}" in code for kw in BAD_COMMENT_KEYWORDS):
            return False

        # 2. Skip functions that are too long
        if len(code.splitlines()) > 200:
            return False

        # 3. Skip overly short functions
        if len(code) < 5:
            return False

        # 4. Skip known problematic includes or system calls
        if any(x in code for x in ["#include <windows.h>", "system("]):
            return False

        # 5. Skip if signature shows no arguments (like void fn())
        for line in code.splitlines():
            if "(" in line and ")" in line and "()" in line.split(")")[0]:
                return False

        return True
    except Exception as e:
        print(f"Filter error: {e}")
        return False


Loaded 73 examples. Running pre-filtering...


In [21]:
import ast 

threads = 1 #os.cpu_count() - 1  # type: ignore
dataset = ds
dataset = dataset.filter(cpp_pre_filter, num_proc=threads)

Filter:   0%|          | 0/73 [00:00<?, ? examples/s]

In [22]:
dataset

Dataset({
    features: ['content', 'id'],
    num_rows: 70
})

In [23]:
# model = LLM(f"../../../StarCoder", dtype=auto_dtype(), gpu_memory_utilization=0.95, tensor_parallel_size=1)
#++++++++++++++++++TA has used above model but vLLM is an ass in Windows so we are going for alternatives
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
#model_name = "bigcode/starcoder"       #Bigger and better model. not to run on local system

model_name = "bigcode/starcoderbase-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # 🔧 Fix for padding
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

In [24]:
dummy = '/*Dummy Docstring*/\n int dummy(){ \n    /*\n    */\n return -1;}'
dummy_prompt = prompt_fmt(dummy)
few_shot_toks = len(tokenizer.encode(
    dummy_prompt)) - len(tokenizer.encode(dummy))
print(f"Few-shot prompt has {few_shot_toks} tokens")

Few-shot prompt has 1979 tokens


In [25]:
dummy_prompt

'<issue_start>username_0: I have a function in C++ and I\'d like someone to check my description of this function.\nI\'m doing this so that I can write a good docstring for this function.\n\nHere is the code for the function:\n```py\n\n     \n    unsigned char* resizeImage(const unsigned char* inputBuffer, int width, int height,\n                            int newWidth, int newHeight) {\n        if (!inputBuffer || width <= 0 || height <= 0 || newWidth <= 0 || newHeight <= 0)\n            return nullptr;\n\n        unsigned char* output = new unsigned char[newWidth * newHeight * 3];\n\n        for (int y = 0; y < newHeight; ++y) {\n            for (int x = 0; x < newWidth; ++x) {\n                int srcX = x * width / newWidth;\n                int srcY = y * height / newHeight;\n                for (int c = 0; c < 3; ++c) {\n                    output[(y * newWidth + x) * 3 + c] =\n                        inputBuffer[(srcY * width + srcX) * 3 + c];\n                }\n            }\

In [26]:
prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating C++ prompts"):
    code = ex["content"]

    # Optional: if your content is wrapped in a dict as string, parse it
    if isinstance(code, str) and code.strip().startswith("{"):
        try:
            code = ast.literal_eval(code).get("code", "")
        except Exception as e:
            print(f"Failed to parse code: {e}")
            prompts.append(dummy_prompt)
            continue

    toks = len(tokenizer.encode(code)) + few_shot_toks
    if toks > 16380:
        print(f"Skipping example with {toks} tokens")
        prompts.append(dummy_prompt)
        continue

    try:
        p = prompt_fmt(code)
        prompts.append(p)
    except Exception as e:
        print(f"Failed to generate prompt: {e}")
        prompts.append(dummy_prompt)

# Generate responses
responses = []
for chunk in tqdm(chunkify(prompts[:10], 1), desc="Generating responses"):
    inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=5, temperature=0.0, eos_token_id=tokenizer.eos_token_id)
    contents = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    for c in contents:
        yes_count = c.lower().count("my answer is: yes")
        no_count = c.lower().count("my answer is: no")
        if yes_count > no_count:
            responses.append(True)
        else:
            responses.append(False)  # Default to "No"


Generating C++ prompts: 100%|██████████| 70/70 [00:00<00:00, 493.26it/s]
Generating responses:   0%|          | 0/10 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Generating responses:  10%|█         | 1/10 [00:18<02:46, 18.49s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Generating responses:  20%|██        | 2/10 [00:34<02:18, 17.28s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Generating responses:  30%|███       | 3/10 [00:50<01:55, 16.43s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Generating responses:  40%|████      | 4/10 [01:35<02:47, 27.90s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Generating responses:  50%|█████     | 5/10 [01:55<02:03, 24.79s/it]Setting `pad_token_id` to `eos_

In [27]:
new_ds = dataset.filter(  # horrible hack!
    lambda ex, i: responses[i], with_indices=True)

new_ds.save_to_disk("./functions_with_return_cpp_filtered")

Filter:   0%|          | 0/70 [00:00<?, ? examples/s]

IndexError: list index out of range

In [None]:
new_ds
responses

[False, True, True, True, True, True, False]