### SEED GATHERING GET CONTENT

In [1]:
!pip install datasets huggingface_hub smart_open[s3] boto3 botocore
!pip install tree-sitter tree-sitter-go
!pip install -v tree-sitter-go
!huggingface-cli login --token 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

Using pip 25.0 from /mmfs1/course/2025/spring/ds/680/md748/ygc2/sae_project/envs/nlp_sae_env/lib/python3.11/site-packages/pip (python 3.11)
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
The token `Llama3` has been saved to /home/ygc2/.cache/huggingface/stored_tokens
Your token has been saved to /home/ygc2/.cache/huggingface/token
Login successful.
The current active token is: `Llama3`


In [2]:
from datasets import load_dataset, Dataset
import os
import signal
from multiprocessing import Pool
import boto3
import smart_open
from botocore import UNSIGNED
from botocore.config import Config
from tree_sitter import Language, Parser # Import direct tree_sitter

# Importing the language function directly from the installed tree-sitter-go package
try:
    from tree_sitter_go import language as go_language_capsule_func
    go_capsule = go_language_capsule_func()
    print(f"Got capsule from tree_sitter_go.language(): {type(go_capsule)}")

    GO_LANGUAGE = Language(go_capsule)
    print("Successfully loaded Go language using the capsule.")

except ImportError:
    print("Error: Could not import 'language' from 'tree_sitter_go'.")
    print("Ensure 'tree_sitter_go' is installed correctly (pip install tree-sitter-go).")
    print("Installation requires a C compiler (like GCC or Clang).")
    raise
except Exception as e:
    print(f"An unexpected error occurred while loading the Go language: {e}")
    print("Ensure tree-sitter-go installed successfully (requires a C compiler).")
    print("You might need to reinstall it: pip uninstall tree-sitter-go && pip install tree-sitter-go")
    raise

def make_go_parser():
    parser = Parser()
    parser.language = GO_LANGUAGE # Using the loaded language object
    return parser

def node_to_string(source_bytes, node):
    return source_bytes[node.start_byte:node.end_byte].decode('utf-8')

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    try:
        with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
            content = fin.read().decode(src_encoding)
    except Exception as e_enc:
        try:
            with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
                content = fin.read().decode('utf-8')
        except Exception as e_utf8:
            raise e_utf8
    return content

Got capsule from tree_sitter_go.language(): <class 'PyCapsule'>
Successfully loaded Go language using the capsule.


In [3]:
# Complete (Revised Query/Extraction Logic with Debugging)

import botocore
import traceback
from tree_sitter import Language, Parser # Make sure Parser is imported if not already

# --- Query to find functions (keep it simple for now) ---
# GO_LANGUAGE should be defined from Cell 2 where it was loaded
GO_FUNCTION_QUERY = GO_LANGUAGE.query("""
(source_file
  (function_declaration) @function_def
)
""")

# --- DEBUG FLAG ---
DEBUG_AST = False  # Set to True to print debug info for the first few files
DEBUG_COUNTER = 0
MAX_DEBUG_COUNT = 5 # Increase slightly to get more debug examples

# --- node_to_string function (ensure it's defined) ---
def node_to_string(src_bytes, node):
    """Safely convert a tree-sitter node to a string."""
    if not node or not src_bytes:
        return ""
    start = node.start_byte
    end = node.end_byte
    if 0 <= start <= end <= len(src_bytes):
        return src_bytes[start:end].decode('utf-8', errors='replace')
    else:
        # print(f"Warning: Invalid byte range in node_to_string: start={start}, end={end}, len={len(src_bytes)}")
        return ""


def get_go_functions_with_docs(src_bytes, tree, blob_id_for_debug="N/A"):
    """
    Extracts top-level Go functions that have a preceding comment node.
    Handles whitespace/trivia between comment and function.
    Returns the full text of the comment(s) and the function definition combined.
    """
    global DEBUG_AST, DEBUG_COUNTER, MAX_DEBUG_COUNT # Allow access/modification

    if tree is None or tree.root_node is None or tree.root_node.type != 'source_file':
        return []

    res = []
    last_comments = [] # Track the most recent block of comments encountered
    printed_debug_for_file = False # Local flag for this function call

    # Determine if we should print debug info for this specific file
    should_print_debug = DEBUG_AST and DEBUG_COUNTER < MAX_DEBUG_COUNT

    try:
        if should_print_debug:
             print(f"\n--- DEBUG AST Walk for blob {blob_id_for_debug} ---")
             printed_debug_for_file = True

        # Iterate through direct children of the source_file node
        for child_node in tree.root_node.children:
            node_type = child_node.type
            node_range = f"{child_node.start_point} - {child_node.end_point}"

            if node_type == 'comment':
                last_comments.append(child_node)
                # if should_print_debug: print(f"  Found comment: {node_range}")

            elif node_type == 'function_declaration':
                func_name_node = child_node.child_by_field_name('name')
                func_name_str = node_to_string(src_bytes, func_name_node) if func_name_node else "<?>"
                # if should_print_debug: print(f"  Found function: {func_name_str} {node_range}")

                if last_comments:
                    last_comment_end_line = last_comments[-1].end_point[0]
                    func_start_line = child_node.start_point[0]

                    # Allow for blank lines (adjust threshold if needed)
                    if func_start_line - last_comment_end_line <= 2:
                        if should_print_debug: print(f"    >>> Associated func '{func_name_str}' with {len(last_comments)} comment(s). Adding.")

                        start_byte = last_comments[0].start_byte
                        end_byte = child_node.end_byte
                        if 0 <= start_byte < end_byte <= len(src_bytes):
                           full_text = src_bytes[start_byte:end_byte].decode('utf-8', errors='replace')
                           res.append(full_text)
                        # else: # Log byte range errors if needed
                        #    print(f"    !!! Invalid byte range for '{func_name_str}': start={start_byte}, end={end_byte}, len={len(src_bytes)}")

                    # elif should_print_debug: print(f"    Comments too far for '{func_name_str}' (func line {func_start_line}, last comment line {last_comment_end_line}).")

                # elif should_print_debug: print(f"    No preceding comments tracked for '{func_name_str}'.")

                last_comments = [] # Clear comments after processing a function

            elif child_node.is_named: # Any other named node
                # if should_print_debug and last_comments: print(f"  Clearing comments due to named node: {node_type} {node_range}")
                last_comments = []
            # else: # It's trivia (whitespace, etc.), ignore it, keep comments
                # if should_print_debug: print(f"  Skipping trivia node: {node_type} {node_range}")
                # pass # <--- REMOVED THIS LINE ---

        if printed_debug_for_file:
             print(f"--- END DEBUG AST Walk for blob {blob_id_for_debug} ---")
             # Increment counter only *once* per file debugged
             DEBUG_COUNTER += 1

    except Exception as e:
        print(f"!!! Worker: Error processing children in get_go_functions_with_docs: {type(e).__name__}: {e}")
        # traceback.print_exc()
        # Ensure counter increments if error happens after printing debug header
        if printed_debug_for_file and DEBUG_AST: DEBUG_COUNTER += 1 # Approx increment
        return []

    return res


def parse_ex(parser, ex):
    """Parses a single example, extracts Go functions with docs."""
    global DEBUG_AST, DEBUG_COUNTER, MAX_DEBUG_COUNT # Allow access to debug flags

    content = None
    tree = None
    blob_id = ex.get('blob_id', 'N/A')
    src_encoding = ex.get('src_encoding') if ex.get('src_encoding') else 'utf-8'
    # Track if we printed debug info for this file
    printed_debug_content = False
    should_print_debug_this_file = DEBUG_AST and DEBUG_COUNTER < MAX_DEBUG_COUNT

    try:
        content = download_contents(blob_id, src_encoding)
        if not content: return []

        # --- DEBUG: Print first few lines of content ---
        if should_print_debug_this_file:
            print(f"\n--- DEBUG Content for blob {blob_id} ---")
            lines = content.splitlines()
            for line in lines[:20]: print(line[:120])
            print(f"({len(lines)} total lines)")
            print("--------------------------------------")
            printed_debug_content = True # Mark that we printed header

        buf = bytes(content, "utf8")
        try:
            # Add timeout to parse to prevent hangs (requires parser supporting timeout)
            # Tree-sitter python binding might not directly support timeout easily here.
            # If hangs occur, might need process-level timeout in the main loop instead.
            tree = parser.parse(buf)
        except Exception as parse_e:
            print(f"!!! Worker: Error during parser.parse for blob {blob_id}: {parse_e}")
            tree = None

        if tree is None:
            # If parse failed after printing debug, ensure counter advances
            if printed_debug_content: DEBUG_COUNTER +=1
            return []

        # Pass blob_id for debugging context; counter increment now happens INSIDE get_go_functions_with_docs
        extracted_funcs = get_go_functions_with_docs(buf, tree, blob_id_for_debug=blob_id)

        # If get_go_functions_with_docs didn't print debug (e.g., no functions found),
        # but we printed content, we still need to potentially increment counter.
        # The logic inside get_go_functions_with_docs handles its own increment now.

        return extracted_funcs

    # Error handling
    except botocore.exceptions.ClientError as e:
        # If error happened after printing debug, increment counter approx.
        if printed_debug_content and DEBUG_AST: DEBUG_COUNTER += 1
        error_code = e.response.get('Error', {}).get('Code')
        if error_code != 'NoSuchKey' and '404' not in str(error_code): pass
        return []
    except UnicodeDecodeError as e:
        if printed_debug_content and DEBUG_AST: DEBUG_COUNTER += 1
        return []
    except Exception as e:
        if printed_debug_content and DEBUG_AST: DEBUG_COUNTER += 1
        print(f"\n!!! Worker: UNEXPECTED error in parse_ex for blob {blob_id}: {type(e).__name__}: {e}")
        traceback.print_exc()
        print(f"!!! End Traceback for blob {blob_id}\n")
        return []


# --- PARSERS and process_chunk ---
# These should be correct from the previous iteration

PARSERS = None # Global list for worker parsers

def process_chunk(idx_and_chunk):
    """Processes a chunk of data in a worker process."""
    idx, chunk = idx_and_chunk
    parser = None
    global PARSERS
    if PARSERS is not None and idx < len(PARSERS):
        parser = PARSERS[idx]
    else: # Fallback if PARSERS isn't set right
        # print(f"Worker {idx}: PARSERS fallback, creating new parser.")
        parser = make_go_parser()

    if parser is None:
         print(f"!!! Worker {idx}: CRITICAL - Parser could not be obtained or created.")
         return set()

    chunk_new_funs = set()
    for item_idx, ex in enumerate(chunk):
        if isinstance(ex, dict):
            try:
                new_funcs = parse_ex(parser, ex)
                if new_funcs: # new_funcs is a list
                    chunk_new_funs.update(new_funcs) # Update set with items from list
            except Exception as e_inner:
                print(f"!!! Worker {idx}: UNCAUGHT Error processing item {item_idx} (blob {ex.get('blob_id', 'N/A')}) in chunk: {e_inner}")
        # else: # Skip non-dict items
            # print(f"Warning: Skipping unexpected item type in chunk: {type(ex)}")
    return chunk_new_funs


print("Cell 3 (Query/Extraction Logic) Defined with AST Walk debugging.")

Cell 3 (Query/Extraction Logic) Defined with AST Walk debugging.


In [4]:
NUMWORKERS = os.cpu_count()

In [5]:
import os
import datasets
from datasets import load_dataset # Good practice to import explicitly
from itertools import islice

# --- Configuration ---
YOUR_UCID = "ygc2" # Make sure this is correct

# Define paths for caches on SCRATCH
# NOTE: Using wangj's scratch based on previous context, confirm if this is right!
# If it should be YOUR scratch, change SCRATCH_BASE accordingly.
SCRATCH_BASE = f"/scratch/wangj/{YOUR_UCID}" # Or maybe f"/scratch/{YOUR_UCID}" ??? Please confirm!
MODEL_CACHE_DIR = os.path.join(SCRATCH_BASE, ".cache", "huggingface")
DATASET_CACHE_DIR = os.path.join(SCRATCH_BASE, "datasets_cache_stack_go_full") # Specific cache dir

# --- Ensure Cache Directories Exist ---
# Create the directories if they don't exist
print(f"INFO: Ensuring model cache directory exists: {MODEL_CACHE_DIR}")
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
print(f"INFO: Ensuring dataset cache directory exists: {DATASET_CACHE_DIR}")
os.makedirs(DATASET_CACHE_DIR, exist_ok=True)

# --- Set Environment Variables for Model Cache (Important!) ---
# Still useful if other parts of your workflow use Hugging Face Hub directly
os.environ['HF_HUB_CACHE'] = os.path.join(MODEL_CACHE_DIR, "hub")
os.environ['HF_HOME'] = MODEL_CACHE_DIR
print(f"INFO: Set HF_HUB_CACHE to: {os.environ['HF_HUB_CACHE']}")
print(f"INFO: Set HF_HOME to: {os.environ['HF_HOME']}")

# --- Load FULL Dataset ---
print("\n" + "="*30)
print("WARNING: Attempting to load the FULL Go dataset from the-stack-v2-dedup.")
print("This is VERY large and will take a VERY long time.")
print("Ensure you have sufficient disk space in scratch and consider using a batch job.")
print(f"Using cache directory: {DATASET_CACHE_DIR}")
print("Check available space with: !df -h " + os.path.dirname(DATASET_CACHE_DIR)) # Check parent dir space
print("="*30 + "\n")

try:
    # Directly load the full dataset, assign to 'ds' (or your desired variable)
    ds_full = load_dataset(
        "bigcode/the-stack-v2-dedup",
        data_dir="data/Go",         # Selects the Go language subset
        split="train",              # Use the training split
        cache_dir=DATASET_CACHE_DIR, # Use the SCRATCH cache directory
        # Optional: Use multiple cores for processing if available and memory permits
        # num_proc=8 # Adjust based on your job's CPU allocation
        # keep_in_memory=False # Explicitly keep it disk-based
    )

    ds = list(islice(ds_full, 40000))

    print(f"\nSuccessfully loaded Go dataset.")
    print(f"Total number of examples: {len(ds)}")
    # print(f"Dataset features: {ds.features}") # Optional: See the structure

except Exception as e:
    print(f"\nERROR loading the full dataset: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    print("\nLoad failed. Check disk space, memory, network connection, or potential dataset issues.")
    # Set ds to None or an empty dataset to prevent errors in later cells if needed
    ds = None

# Now 'ds' holds the full dataset (if loading succeeded)
if ds is not None:
    print(f"\nVariable 'ds' now holds the full dataset with {len(ds)} examples.")
else:
    print("\nVariable 'ds' is None due to loading error.")

print("\n--- Dataset Loading Attempt Finished ---")

INFO: Ensuring model cache directory exists: /scratch/wangj/ygc2/.cache/huggingface
INFO: Ensuring dataset cache directory exists: /scratch/wangj/ygc2/datasets_cache_stack_go_full
INFO: Set HF_HUB_CACHE to: /scratch/wangj/ygc2/.cache/huggingface/hub
INFO: Set HF_HOME to: /scratch/wangj/ygc2/.cache/huggingface

This is VERY large and will take a VERY long time.
Ensure you have sufficient disk space in scratch and consider using a batch job.
Using cache directory: /scratch/wangj/ygc2/datasets_cache_stack_go_full
Check available space with: !df -h /scratch/wangj/ygc2


Successfully loaded Go dataset.
Total number of examples: 40000

Variable 'ds' now holds the full dataset with 40000 examples.

--- Dataset Loading Attempt Finished ---


In [6]:
funs = set()
PARSERS = [make_go_parser() for _ in range(NUMWORKERS)]
total_len = len(ds)
CHUNK_SIZE = 1000 * NUMWORKERS

print(f"Total length: {total_len}")
print(f"Chunk size: {CHUNK_SIZE}")

chunk = []
print(f"Initializing Pool with {NUMWORKERS} workers...")
p = Pool(NUMWORKERS)
print("Pool initialized.")

Total length: 40000
Chunk size: 128000
Initializing Pool with 128 workers...
Pool initialized.


In [7]:
print("Starting main processing loop...")

# --- Initialize PARSERS *before* creating the Pool ---
print(f"Initializing {NUMWORKERS} parsers for the pool...")
PARSERS = [make_go_parser() for _ in range(NUMWORKERS)]
print("Parsers initialized.")

# Create the pool *after* PARSERS is initialized
print(f"Initializing Pool with {NUMWORKERS} workers...")
p = Pool(NUMWORKERS)
print("Pool initialized.")

funs = set()
chunk = []
processed_count = 0

try: # Wrap the main loop in try...finally to ensure pool closure
    for i, ex in enumerate(iter(ds)):
        processed_count = i + 1
        # Printing progress periodically (adjust frequency if needed)
        if i > 0 and i % (max(1, total_len // 100)) == 0: # Print roughly 100 times
            print(f"Processed {i}/{total_len} files. Found {len(funs)} unique functions so far.")

        try: # Inner try for processing a single example/adding to chunk
            chunk.append(ex)
            # Check if chunk is full or it's the last item
            if len(chunk) == CHUNK_SIZE or i == total_len - 1:
                current_chunk_index = (i // CHUNK_SIZE) if CHUNK_SIZE > 0 else 0
                print(f"\nProcessing chunk {current_chunk_index} (items {i-len(chunk)+1} to {i})")

                if not chunk: # Should not happen here, but safe check
                   print("Chunk is empty, skipping.")
                   continue

                if NUMWORKERS <= 0 : NUMWORKERS = 1
                subchunk_size = max(1, (len(chunk) + NUMWORKERS - 1) // NUMWORKERS)
                subchunks = []
                start_idx = 0
                for worker_idx in range(NUMWORKERS):
                    end_idx = min(start_idx + subchunk_size, len(chunk))
                    if start_idx < len(chunk):
                        subchunks.append(chunk[start_idx:end_idx])
                    start_idx = end_idx
                    if start_idx >= len(chunk): break

                if not subchunks:
                    print(f"Skipping empty subchunks for chunk index {current_chunk_index}.")
                    chunk = [] # Clear chunk
                    continue

                print(f"Submitting {len(subchunks)} subchunks to {NUMWORKERS} workers...")
                tasks = [(idx, subchunk) for idx, subchunk in enumerate(subchunks)]
                new_funs_iter = p.imap_unordered(process_chunk, tasks) # Use imap_unordered

                print("Getting new functions from workers...")
                len_before = len(funs)
                results_processed_count = 0
                for worker_results_set in new_funs_iter:
                     results_processed_count +=1
                     if worker_results_set: # Check if the set is not empty
                         funs.update(worker_results_set)
                     # Optional: print progress within chunk results
                     # print(f"  Received result {results_processed_count}/{len(subchunks)} for chunk {current_chunk_index}")


                print(f"Finished processing chunk {current_chunk_index}. Added {len(funs) - len_before} new unique functions. Total unique functions: {len(funs)}\n")
                chunk = [] # Clear chunk for the next batch

        except KeyboardInterrupt:
             print("\nKeyboard Interrupt detected. Stopping loop.")
             break # Exit the main for loop

        except Exception as e:
            print(f"\n!!! ERROR in main loop iteration {i} (processing example: {ex.get('blob_id', 'N/A')}): {type(e).__name__}: {e}")
            traceback.print_exc()
            print("  Attempting to clear chunk and continue...")
            chunk = [] # Ensure chunk is cleared

        # Loop terminates naturally after last item

finally: # Ensure pool is closed even if errors occur
    print("\nClosing the processing pool...")
    try:
        p.close()
        p.join()
        print("Pool closed successfully.")
    except Exception as pool_final_e:
        print(f"Error closing/joining the pool: {pool_final_e}")

print(f"\nFinished processing {processed_count}/{total_len} files. Found {len(funs)} unique Go functions.")

# Creating the final dataset dictionary using the "seed" key
print("Creating final dataset dictionary...")
new_ds_dict = {
    "seed": list(funs), # Use "seed" as the key directly
    "id": list(range(len(funs)))
}

print("Converting dictionary to Hugging Face Dataset...")
if not funs:
    print("WARNING: No functions were extracted. Resulting dataset will be empty.")
    new_ds = Dataset.from_dict({"seed": [], "id": []})
else:
    new_ds = Dataset.from_dict(new_ds_dict)

print("Dataset creation complete.")
print("\nFinal Dataset Info:")
print(new_ds)

Starting main processing loop...
Initializing 128 parsers for the pool...
Parsers initialized.
Initializing Pool with 128 workers...
Pool initialized.
Processed 400/40000 files. Found 0 unique functions so far.
Processed 800/40000 files. Found 0 unique functions so far.
Processed 1200/40000 files. Found 0 unique functions so far.
Processed 1600/40000 files. Found 0 unique functions so far.
Processed 2000/40000 files. Found 0 unique functions so far.
Processed 2400/40000 files. Found 0 unique functions so far.
Processed 2800/40000 files. Found 0 unique functions so far.
Processed 3200/40000 files. Found 0 unique functions so far.
Processed 3600/40000 files. Found 0 unique functions so far.
Processed 4000/40000 files. Found 0 unique functions so far.
Processed 4400/40000 files. Found 0 unique functions so far.
Processed 4800/40000 files. Found 0 unique functions so far.
Processed 5200/40000 files. Found 0 unique functions so far.
Processed 5600/40000 files. Found 0 unique functions so fa

In [8]:
# --- Create and Save Dataset for Step 1 ---
print("\nCreating Step 1 dataset object...")
if not funs:
    print("WARNING: No functions extracted in Sub-step 1.")
    seed_step1_ds = Dataset.from_dict({"seed": [], "id": []})
else:
    seed_step1_ds = Dataset.from_dict({
        "seed": list(funs),
        "id": list(range(len(funs)))
    })

print(f"Step 1 Dataset Info: {seed_step1_ds}")

save_path_step1 = "./seed1_extracted_dataset"
print(f"Saving Step 1 dataset to {save_path_step1}...")
try:
    seed_step1_ds.save_to_disk(save_path_step1)
    print("Step 1 dataset saved successfully.")
except Exception as save_e:
    print(f"ERROR saving Step 1 dataset: {save_e}")


Creating Step 1 dataset object...
Step 1 Dataset Info: Dataset({
    features: ['seed', 'id'],
    num_rows: 27218
})
Saving Step 1 dataset to ./seed1_extracted_dataset...


Saving the dataset (0/1 shards):   0%|          | 0/27218 [00:00<?, ? examples/s]

Step 1 dataset saved successfully.


In [9]:
ds = new_ds

In [10]:
ds

Dataset({
    features: ['seed', 'id'],
    num_rows: 27218
})

### SEED GATHERING HIGH-QUALITY SUBSET

In [11]:
# Define Enhanced Go Filtering Functions

import subprocess
import os
import re # Import regex module
from tree_sitter import Language, Parser

print("Defining Enhanced Go-specific filtering functions...")

# --- Tree-sitter Queries ---
# Query for return statements with values (using predicate)
GO_RETURN_QUERY = GO_LANGUAGE.query("""
(return_statement (#min-children! 2)) @return_val
""")

# Query for import spec paths
GO_IMPORT_QUERY = GO_LANGUAGE.query("""
(import_spec path: (interpreted_string_literal) @import_path)
(import_spec path: (raw_string_literal) @import_path)
""")

# Query for function declarations with empty parameter lists
GO_EMPTY_PARAMS_QUERY = GO_LANGUAGE.query("""
(function_declaration parameters: (parameter_list) @params)
""")


# --- Filter Lists ---
GO_BAD_KEYWORDS = ["todo", "fixme", "xxx", "hack", "bug"] # Case-insensitive check
GO_BAD_IMPORTS = [
    "\"unsafe\"", # Standard library unsafe package
    "\"os/exec\"", # For running external commands
    "\"syscall\"", # Low-level system calls
    # Add other potentially problematic imports specific to your needs
    # e.g., GUI libraries if not desired: "\"github.com/fyne-io/fyne/v2\"",
    # e.g., specific CGO related imports if CGO is disallowed
]
MIN_COMMENT_LENGTH = 15 # Minimum character length for the comment part
MIN_FUNCTION_BODY_LINES = 2 # Require at least a couple of lines in the function body itself

# --- Parser Instance ---
# Make sure make_go_parser() is defined (from Cell 2)
go_filter_parser = make_go_parser()

# --- Filtering Functions ---

def go_does_have_return_value(tree):
    """Checks if the parsed tree contains a non-empty return statement."""
    # Takes tree as input to avoid re-parsing
    if tree is None: return False
    try:
        captures = GO_RETURN_QUERY.captures(tree.root_node)
        return len(captures) > 0
    except Exception:
        return False

def go_check_bad_keywords(code_string):
    """Checks for bad keywords (case-insensitive)."""
    lower_code = code_string.lower()
    for keyword in GO_BAD_KEYWORDS:
        # Use regex word boundary \b to avoid matching parts of words
        if re.search(r'\b' + re.escape(keyword) + r'\b', lower_code):
            # print(f"DEBUG: Found bad keyword '{keyword}'") # Debug Optional
            return False # Found a bad keyword
    return True # No bad keywords found

def go_check_bad_imports(tree):
    """Checks for disallowed imports using tree-sitter query."""
    if tree is None: return True # Allow if parsing failed earlier
    try:
        captures = GO_IMPORT_QUERY.captures(tree.root_node)
        imported_paths = set()
        for node, name in captures:
             if name == 'import_path':
                 # node_to_string expects bytes, need to get them from root node's text
                 import_path = node_to_string(tree.root_node.text, node)
                 imported_paths.add(import_path.strip())

        # print(f"DEBUG: Imports found: {imported_paths}") # Debug Optional
        for bad_import in GO_BAD_IMPORTS:
            if bad_import in imported_paths:
                # print(f"DEBUG: Found bad import '{bad_import}'") # Debug Optional
                return False # Found a bad import
        return True # No bad imports found
    except Exception:
        return False # Disallow on error

def go_check_no_arguments(code_string):
     """Checks if the function signature has an empty parameter list '()' using regex."""
     if not code_string: return False # Should not happen if length check passed

     # Regex to find `func FunctionName()` potentially with preceding comments/whitespace
     # - `\s*` matches optional whitespace/newlines
     # - `(?:/\*.*?\*/|//.*?\n)*` matches optional multi-line or single-line comments
     # - `func\s+` matches the func keyword and space
     # - `\w+` matches the function name
     # - `\s*\(\s*\)` matches the empty parentheses, allowing whitespace inside
     # - `\s*\{` matches the opening brace of the function body
     pattern = r"func\s+\w+\s*\(\s*\)\s*\{"

     # We only care about the main function definition in the snippet
     # The regex looks for the specific pattern
     if re.search(pattern, code_string):
          # print(f"DEBUG: Found func() pattern in:\n{code_string[:100]}...") # Debug
          return False # Found empty params, FILTER OUT
     else:
          return True # No empty params found, KEEP


def go_check_comment_and_body(code_string):
    """
    Checks for minimal comment length and function body length.
    Assumes the input string starts with comments followed by the function.
    """
    try:
        # Heuristic: Find the start of the function definition `func `
        func_keyword_pos = code_string.find("func ")
        if func_keyword_pos == -1:
            return False # Cannot find function keyword

        comment_part = code_string[:func_keyword_pos].strip()
        function_part = code_string[func_keyword_pos:]

        # Check comment quality/length
        cleaned_comment = re.sub(r'(//|/\*|\*/)', '', comment_part).strip() # Remove comment markers
        if len(cleaned_comment) < MIN_COMMENT_LENGTH:
            # print(f"DEBUG: Comment too short ({len(cleaned_comment)} chars)") # Debug Optional
            return False
        if not go_check_bad_keywords(cleaned_comment): # Check for TODO etc ONLY in comment
             # print(f"DEBUG: Found bad keyword in comment part") # Debug Optional
             return False

        # Check function body length (lines between { and })
        body_start = function_part.find('{')
        body_end = function_part.rfind('}')
        if body_start == -1 or body_end == -1 or body_start >= body_end:
            return False # Malformed function body
        function_body = function_part[body_start+1:body_end]
        if len(function_body.splitlines()) < MIN_FUNCTION_BODY_LINES:
             # print(f"DEBUG: Function body too short") # Debug Optional
             return False

        return True
    except Exception:
        return False # Error during split/check

def go_length_check(code_string, max_lines=200, max_chars=5000):
    """Checks overall length."""
    if not code_string: return False
    lines = code_string.splitlines()
    if len(lines) > max_lines: return False
    if len(code_string) > max_chars: return False
    return True

# --- Combined Filter Function ---
def apply_go_filters(example):
    """Applies all filters to a dataset example."""
    code = example['seed']
    if not code: return False

    # Basic length check first (cheap)
    if not go_length_check(code): return False

    # Parse once
    try:
        tree = go_filter_parser.parse(bytes(code, "utf8"))
        if tree is None: return False # Parsing failed
    except Exception:
        return False # Parsing failed

    # Apply checks that use the parsed tree
    if not go_does_have_return_value(tree): return False
    if not go_check_bad_imports(tree): return False
    if not go_check_no_arguments(tree): return False # Check for func()

    # Apply checks on the code string itself
    if not go_check_bad_keywords(code): return False # Check keywords in whole snippet
    if not go_check_comment_and_body(code): return False # Check comment/body length/quality

    return True # All checks passed

print("Enhanced Go filtering functions defined.")

Defining Enhanced Go-specific filtering functions...
Enhanced Go filtering functions defined.


In [12]:
# Apply Filters Sequentially (Debugging Step)

from tqdm.auto import tqdm # Use tqdm for progress bars

print(f"Original dataset size: {len(ds)}") # Should be 1153

# --- Filter 1: Apply 'return value' filter ---
print("\nFiltering for functions with return values...")
ds_step1 = ds.filter(lambda example: go_does_have_return_value(go_filter_parser.parse(bytes(example['seed'], "utf8"))), num_proc=os.cpu_count())
print(f"Size after return filter: {len(ds_step1)}") # <<< Check this count! Should be ~900

# --- Filter 2: Apply length filter ---
print("\nFiltering by length (max 200 lines, 5000 chars)...")
ds_step2 = ds_step1.filter(lambda example: go_length_check(example['seed']), num_proc=os.cpu_count())
print(f"Size after length filter: {len(ds_step2)}") # <<< Check this count! Should be ~899

# --- Filter 3: Apply Bad Keywords Filter ---
print("\nFiltering for bad keywords (TODO, FIXME, etc.)...")
ds_step3 = ds_step2.filter(lambda example: go_check_bad_keywords(example['seed']), num_proc=os.cpu_count())
print(f"Size after bad keywords filter: {len(ds_step3)}") # <<< Check this count!

# --- Filter 4: Apply Bad Imports Filter ---
print("\nFiltering for bad imports (unsafe, os/exec, etc.)...")
# Note: Re-parsing here. Less efficient but isolates the filter.
ds_step4 = ds_step3.filter(lambda example: go_check_bad_imports(go_filter_parser.parse(bytes(example['seed'], "utf8"))), num_proc=os.cpu_count())
print(f"Size after bad imports filter: {len(ds_step4)}") # <<< Check this count!

# --- Filter 5: Apply No Arguments Filter ---
print("\nFiltering for functions with no arguments '()'.")
# Pass the code string directly now
ds_step5 = ds_step4.filter(lambda example: go_check_no_arguments(example['seed']), num_proc=os.cpu_count())
print(f"Size after no arguments filter: {len(ds_step5)}") # <<< Check this count!

# --- Filter 6: Apply Comment/Body Quality Filter ---
print("\nFiltering for comment/body quality (min lengths, keywords in comments)...")
ds_step6 = ds_step5.filter(lambda example: go_check_comment_and_body(example['seed']), num_proc=os.cpu_count())
print(f"Size after comment/body filter: {len(ds_step6)}") # <<< Check this count!


# --- Assign Final Filtered Dataset ---
ds_filtered = ds_step6
print("\n--- Sequential Filtering Complete ---")
print(f"Final filtered dataset size: {len(ds_filtered)}")

# Display a sample from the filtered dataset
if len(ds_filtered) > 0:
    print("\nSample filtered function:")
    print("==========================")
    import random
    # Ensure index is valid if dataset becomes empty
    if len(ds_filtered) > 0:
        print(ds_filtered[random.randint(0, len(ds_filtered)-1)]['seed'])
    print("==========================")
else:
    print("No functions remained after filtering.")



Original dataset size: 27218

Filtering for functions with return values...


Filter (num_proc=128):   0%|          | 0/27218 [00:00<?, ? examples/s]

Size after return filter: 21798

Filtering by length (max 200 lines, 5000 chars)...


Filter (num_proc=128):   0%|          | 0/21798 [00:00<?, ? examples/s]

Size after length filter: 21697

Filtering for bad keywords (TODO, FIXME, etc.)...


Filter (num_proc=128):   0%|          | 0/21697 [00:00<?, ? examples/s]

Size after bad keywords filter: 21111

Filtering for bad imports (unsafe, os/exec, etc.)...


Filter (num_proc=128):   0%|          | 0/21111 [00:00<?, ? examples/s]

Size after bad imports filter: 21111

Filtering for functions with no arguments '()'.


Filter (num_proc=128):   0%|          | 0/21111 [00:00<?, ? examples/s]

Size after no arguments filter: 20957

Filtering for comment/body quality (min lengths, keywords in comments)...


Filter (num_proc=128):   0%|          | 0/20957 [00:00<?, ? examples/s]

Size after comment/body filter: 19206

--- Sequential Filtering Complete ---
Final filtered dataset size: 19206

Sample filtered function:
// Cors sets cors headers on each request for the configured origins
func Cors(f HandlerFunc) HandlerFunc {
	return func(ctx context.Context, r events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
		logger := api.StandardLambdaLogger(ctx, pkg.EnvLogLevel)
		origin := ""
		if val, ok := r.Headers["Origin"]; ok {
			origin = val
		} else if val2, ok2 := r.Headers["origin"]; ok2 {
			origin = val2
		}
		logger.Debugw("Found origin value...", log.Fields{
			"origin": origin,
		})
		response, err := f(ctx, r)
		if err != nil {
			// don't set cors headers on a response that errored out
			return response, err
		}

		if origin == "" {
			logger.Debug("Empty 'Origin' header passed in, no cors available")
			return response, err
		}
		normalizedOrigin := normalizeOrigin(origin)
		allowedOrigins := getAllowedOrigins(pkg.EnvCorsAllowedOrig

In [13]:
# Save the Final Filtered Dataset from Sequential Filtering

save_path_filtered = "./seed2_heuristically_filtered_subset" # Define save path

print(f"\nSaving final filtered dataset (size: {len(ds_filtered)}) to {save_path_filtered}...")

if ds_filtered is not None and len(ds_filtered) > 0:
    try:
        ds_filtered.save_to_disk(save_path_filtered)
        print(f"Filtered dataset saved successfully to {save_path_filtered}.")
        # Optional: Copy to Google Drive if mounted
        # drive_save_path = "/content/drive/MyDrive/your_project_folder/go_seed_heuristically_filtered_subset"
        # print(f"Copying to Google Drive: {drive_save_path}...")
        # !mkdir -p "$(dirname "{drive_save_path}")" # Create parent directory in Drive if needed
        # !cp -r {save_path_filtered} "{drive_save_path}"
        # print("Copy to Google Drive complete.")
    except Exception as save_e:
        print(f"ERROR saving filtered dataset: {save_e}")
elif ds_filtered is not None:
     print("Filtered dataset is empty. Not saving.")
else:
     print("Filtered dataset variable 'ds_filtered' is None. Not saving.")


Saving final filtered dataset (size: 19206) to ./seed2_heuristically_filtered_subset...


Saving the dataset (0/1 shards):   0%|          | 0/19206 [00:00<?, ? examples/s]

Filtered dataset saved successfully to ./seed2_heuristically_filtered_subset.


### SEED GATHERING FILTER DATASET

In [14]:
!pip install vllm



In [15]:
# Imports & LLM Setup for LLM-based Filtering

import datasets
from datasets import load_from_disk
import os
import torch
import random
from tqdm.auto import tqdm
import re
import subprocess # Needed for pip install check

# Try importing vLLM, install if necessary
try:
    from vllm import LLM, SamplingParams
    print("vLLM found.")
except ImportError:
    print("vLLM not found. Installing (this may take a while)...")
    # Note: vLLM installation can be complex. Check vLLM docs if issues arise.
    process = subprocess.run(["pip", "install", "vllm", "-q"], capture_output=True)
    if process.returncode == 0:
         print("vLLM potentially installed successfully. Attempting import...")
         try:
             from vllm import LLM, SamplingParams
             print("vLLM imported successfully after installation.")
         except ImportError:
              print("\nERROR: Failed to import vLLM even after pip install command.")
              raise RuntimeError("vLLM import failed after installation.")
    else:
        print("\nERROR: pip install vllm command failed.")
        print("STDERR:", process.stderr.decode())
        raise RuntimeError("Failed to install vLLM.")


print("\n--- Step 1, Sub-step 3: LLM-based Filtering ---")
print("--- Cell 1: LLM Setup ---")

# --- LLM Configuration ---
MODEL_PATH = "bigcode/starcoder2-15b" # Using Hugging Face ID

llm = None
tokenizer = None

if torch.cuda.is_available():
    print(f"CUDA available. Using GPU.")
    llm_dtype = "bfloat16" if torch.cuda.is_bf16_supported() else "auto"
    tensor_parallel_size = 1 # Keep at 1 for single Colab GPU
    gpu_memory_utilization = 0.90 # Can adjust down if OOM errors occur

    print(f"Attempting to load LLM: {MODEL_PATH}")
    print(f"Config: dtype={llm_dtype}, TP size={tensor_parallel_size}, GPU memory util={gpu_memory_utilization}")

    try:
        llm = LLM(
            model=MODEL_PATH,
            dtype=llm_dtype,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            # trust_remote_code=True, # May be needed depending on model version
            max_model_len=4096 # Check model's actual max length if different
        )
        tokenizer = llm.get_tokenizer()
        print("LLM and Tokenizer loaded successfully.")
    except Exception as e:
        print(f"\n--- ERROR Initializing LLM: {e} ---")
        print("Check model path, GPU memory/compatibility, and vLLM installation.")
else:
    print("ERROR: CUDA not available. vLLM requires GPU for this step.")

print("-" * 30)

INFO 05-04 13:16:28 [__init__.py:239] Automatically detected platform cuda.
vLLM found.

--- Step 1, Sub-step 3: LLM-based Filtering ---
--- Cell 1: LLM Setup ---
CUDA available. Using GPU.
Attempting to load LLM: bigcode/starcoder2-15b
Config: dtype=bfloat16, TP size=1, GPU memory util=0.9
INFO 05-04 13:16:32 [config.py:2968] Downcasting torch.float32 to torch.bfloat16.
INFO 05-04 13:16:52 [config.py:717] This model supports multiple tasks: {'generate', 'reward', 'embed', 'classify', 'score'}. Defaulting to 'generate'.
INFO 05-04 13:16:52 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 05-04 13:17:39 [__init__.py:239] Automatically detected platform cuda.
INFO 05-04 13:17:48 [core.py:58] Initializing a V1 LLM engine (v0.8.5) with config: model='bigcode/starcoder2-15b', speculative_config=None, tokenizer='bigcode/starcoder2-15b', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust

Loading safetensors checkpoint shards:   0% Completed | 0/14 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   7% Completed | 1/14 [00:01<00:25,  2.00s/it]
Loading safetensors checkpoint shards:  14% Completed | 2/14 [00:03<00:21,  1.82s/it]
Loading safetensors checkpoint shards:  21% Completed | 3/14 [00:05<00:18,  1.67s/it]
Loading safetensors checkpoint shards:  29% Completed | 4/14 [00:06<00:13,  1.39s/it]
Loading safetensors checkpoint shards:  36% Completed | 5/14 [00:06<00:10,  1.17s/it]
Loading safetensors checkpoint shards:  43% Completed | 6/14 [00:07<00:08,  1.04s/it]
Loading safetensors checkpoint shards:  50% Completed | 7/14 [00:08<00:06,  1.01it/s]
Loading safetensors checkpoint shards:  57% Completed | 8/14 [00:09<00:05,  1.09it/s]
Loading safetensors checkpoint shards:  64% Completed | 9/14 [00:10<00:04,  1.17it/s]
Loading safetensors checkpoint shards:  71% Completed | 10/14 [00:10<00:03,  1.21it/s]
Loading safetensors checkpoint shards:  79% Completed | 11/14

INFO 05-04 13:18:06 [loader.py:458] Loading weights took 14.34 seconds
INFO 05-04 13:18:06 [gpu_model_runner.py:1347] Model loading took 29.7279 GiB and 15.465909 seconds
INFO 05-04 13:18:22 [backends.py:420] Using cache directory: /home/ygc2/.cache/vllm/torch_compile_cache/613d0c5645/rank_0_0 for vLLM's torch.compile
INFO 05-04 13:18:22 [backends.py:430] Dynamo bytecode transform time: 15.59 s
INFO 05-04 13:18:29 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 6.612 s
INFO 05-04 13:18:44 [monitor.py:33] torch.compile takes 15.59 s in total
INFO 05-04 13:18:45 [kv_cache_utils.py:634] GPU KV cache size: 513,184 tokens
INFO 05-04 13:18:45 [kv_cache_utils.py:637] Maximum concurrency for 4,096 tokens per request: 125.29x
INFO 05-04 13:19:12 [gpu_model_runner.py:1686] Graph capturing finished in 27 secs, took 0.59 GiB
INFO 05-04 13:19:12 [core.py:159] init engine (profile, create kv cache, warmup model) took 65.58 seconds
INFO 05-04 13:19:12 [core_c

In [16]:
# Define Go Prompting Logic

import random
import re # Ensure re is imported here too

print("--- Sub-step 3: Defining Go Prompting Logic ---")

# --- Go Comment/Code Extraction ---
def go_extract_comment_code(seed_string):
    """Splits the seed string into preceding comment and the function code."""
    try:
        # Find first 'func ' occurrence
        func_keyword_pos = seed_string.find("func ")
        if func_keyword_pos != -1:
            comment_part = seed_string[:func_keyword_pos].strip()
            # Clean comment markers
            cleaned_comment = comment_part
            if comment_part.startswith("/*") and comment_part.endswith("*/"):
                cleaned_comment = comment_part[2:-2].strip() # Remove /* */
            elif "//" in comment_part:
                # Handle single line comments
                comment_lines = [line.strip()[2:].strip() if line.strip().startswith("//") else line.strip()
                                 for line in comment_part.splitlines()]
                cleaned_comment = "\n".join(line for line in comment_lines if line) # Join non-empty lines

            code_part = seed_string[func_keyword_pos:].strip()
            # Return cleaned comment and the code part starting from "func "
            return cleaned_comment.strip(), code_part
        # Fallback if 'func ' not found (shouldn't happen often with extracted data)
        return "", seed_string
    except Exception as e:
        # print(f"Warning: Error splitting comment/code: {e}")
        return "", seed_string # Fallback

# --- Go Few-Shot Examples ---
# !!! IMPORTANT: Review and improve these examples for Go !!!
GO_FEW_SHOTS = [
    (
        """// Add calculates the sum of two integers.
// It takes two integers x and y as input.
// It returns their sum.
func Add(x, y int) int {
	return x + y
}""", "Yes", "Comment clearly states purpose, inputs, output, matching the implementation."
    ),
    (
        """// Get retrieves data based on ID.
func Get(id string) *Data {
	// Implementation omitted - complex logic here
    if id == "" { return nil }
	data := internal.fetchData(id)
    if data == nil { return nil }
	return data.Process()
}""", "No", "Comment 'Get retrieves data based on ID' is too generic. Doesn't explain data type, source, errors, or processing."
    ),
    (
        """// Crucial function
// TODO: Refactor this later
func processData(data []byte) error {
	// ... complex processing ...
    if len(data) == 0 { return nil }
    // ... more logic ...
	return someInternalCheck(data)
}""", "No", "Comment is insufficient ('Crucial function', 'TODO'). It doesn't describe the processing steps, expected input format, or return conditions."
    ),
     (
        """// findMax finds the maximum value in a slice of integers.
// It returns the maximum value found, or 0 if the slice is empty.
func findMax(nums []int) int {
	if len(nums) == 0 {
		return 0
	}
	maxVal := nums[0]
	for _, v := range nums[1:] {
		if v > maxVal {
			maxVal = v
		}
	}
	return maxVal
}""", "Yes", "Comment accurately describes the goal, input, edge case, and return value."
    ),
]

# --- Go Prompt Formatting ---
def go_template_few_shot(code_seed, answer, rationale):
    doc, code = go_extract_comment_code(code_seed)
    if not doc: return "" # Skip if comment extraction failed
    answer = "Yes" if answer.lower() == "yes" else "No"
    prompt = f"""<issue_start>username_0: I have a function in Go and I'd like someone to check my description (comment) of this function.
    I'm doing this so that I can write a good comment for this function.

    Here is the code for the function:
    ```go
    {code}
    ```
    Here is my description (the preceding comment) of this program:
    ```
    {doc}
    ```
    Do not attempt to execute the function or to judge its correctness beyond basic parsing.
    Answer with "Yes" or "No" depending only on if my description (comment) has enough information alone to re-implement the function's behavior accurately.
    Also, answer with "No" if the description does not match the function's apparent behavior or is clearly insufficient (e.g., just says 'TODO').<issue_comment>username_1: Sure, no problem. I will be able to help.
    My answer is: {answer}
    {rationale}
    Upvotes: 200"""
    return prompt

def go_prompt_fmt(code_seed):
  doc, code = go_extract_comment_code(code_seed)
  if not doc:
    return None # Signal failure if no comment extracted
  examples_to_use = random.sample(GO_FEW_SHOTS, k=len(GO_FEW_SHOTS))
  buf = ""
  for fs_seed, fs_answer, fs_rationale in examples_to_use:
      fs_prompt = go_template_few_shot(fs_seed, fs_answer, fs_rationale)
      if fs_prompt: buf += fs_prompt + "\n\n"

  buf += f"""<issue_start>username_0: I have a function in Go and I'd like someone to check my description (comment) of this function.
  I'm doing this so that I can write a good comment for this function.
  Here is the code for the function:
  ```
  {code}
  ```
  Here is my description (the preceding comment) of this program:
  ```
  {doc}
  ```
  Do not attempt to execute the function or to judge its correctness beyond basic parsing.
  Answer with "Yes" or "No" depending only on if my description (comment) has enough information alone to re-implement the function's behavior accurately.
  Also, answer with "No" if the description does not match the function's apparent behavior or is clearly insufficient (e.g., just says 'TODO').<issue_comment>username_1: Sure, no problem. I will be able to help.
  My answer is:"""
  return buf

# --- Helper for batching ---
def chunkify(lst, n):
  for i in range(0, len(lst), n):
    yield lst[i:i + n]
  print("Go prompting functions defined.")


--- Sub-step 3: Defining Go Prompting Logic ---


In [17]:
# Execute LLM Filtering and Save Final Step 1 Dataset

from datasets import load_from_disk, Dataset # Ensure Dataset is imported
from tqdm.auto import tqdm

print("--- Sub-step 3: Executing LLM-based Filtering ---")

# --- Load Input Dataset ---
load_path_heuristic_filtered = "./seed2_heuristically_filtered_subset" # Saved from previous heuristic step
dataset_to_filter = None # Initialize
ds_filtered_final = None # Initialize

if not os.path.exists(load_path_heuristic_filtered):
     print(f"ERROR: Cannot load heuristically filtered dataset from {load_path_heuristic_filtered}. Please run previous steps.")
elif llm is None or tokenizer is None:
     print("ERROR: LLM or Tokenizer not initialized in Cell 11. Cannot proceed.")
else:
    print(f"Loading heuristically filtered dataset from {load_path_heuristic_filtered}...")
    dataset_to_filter = load_from_disk(load_path_heuristic_filtered)
    print(f"Loaded dataset size for LLM filtering: {len(dataset_to_filter)}")

# --- Proceed only if dataset loaded and LLM is ready ---
if dataset_to_filter is not None and len(dataset_to_filter) > 0 and llm and tokenizer:

    # --- Calculate Prompt Overhead ---
    few_shot_toks = 2000 # Default estimate
    MAX_MODEL_LEN = 4096
    try:
        # Try calculating based on the first few-shot example
        dummy_prompt = go_prompt_fmt(GO_FEW_SHOTS[0][0])
        if dummy_prompt is None: raise ValueError("Dummy prompt format failed")
        # Use tokenizer which should be loaded
        dummy_code_tokens = len(tokenizer.encode(GO_FEW_SHOTS[0][0]))
        dummy_prompt_tokens = len(tokenizer.encode(dummy_prompt))
        few_shot_toks = dummy_prompt_tokens - dummy_code_tokens
        print(f"Calculated few-shot prompt overhead: ~{few_shot_toks} tokens")
        MAX_MODEL_LEN = llm.llm_engine.model_config.max_model_len
    except Exception as e:
        print(f"Warning: Error calculating few-shot overhead: {e}. Using default estimate.")
    MAX_TOKENS_FOR_CODE = MAX_MODEL_LEN - few_shot_toks - 100 # Subtract overhead and buffer for code/output

    # --- Generate Prompts ---
    prompts = []
    indices_to_keep = [] # Store original indices
    skipped_count = 0
    print(f"\nGenerating prompts (Max code tokens: {MAX_TOKENS_FOR_CODE})...")
    for i, ex in enumerate(tqdm(dataset_to_filter, desc="Generating prompts")):
        code_seed = ex["seed"]
        toks_estimate = len(tokenizer.encode(code_seed))
        if toks_estimate > MAX_TOKENS_FOR_CODE:
            skipped_count += 1
            continue
        p = go_prompt_fmt(code_seed)
        if p is None:
             skipped_count += 1
             continue
        prompts.append(p)
        indices_to_keep.append(i) # Keep track of original index

    print(f"Generated {len(prompts)} prompts. Skipped {skipped_count} examples.")

    # --- Run LLM Inference ---
    llm_responses_bool = []
    batch_size = 32 # Smaller batch size for potentially long prompts
    sampling_params = SamplingParams(temperature=0.0, max_tokens=5, stop=["\n", "<", "Upvotes"])

    if not prompts:
        print("No prompts generated, skipping LLM inference.")
    else:
        print(f"\nRunning LLM generation in batches of {batch_size}...")
        for i, chunk_prompts in enumerate(tqdm(chunkify(prompts, batch_size), desc="LLM Batches")):
            # print(f"  Processing batch {i+1}/{ (len(prompts) + batch_size - 1)//batch_size }...") # Less verbose
            try:
                 outputs = llm.generate(chunk_prompts, sampling_params, use_tqdm=False) # Disable inner tqdm
                 for output in outputs:
                     generated_text = output.outputs[0].text.strip().lower()
                     llm_responses_bool.append(generated_text.startswith("yes")) # True if starts with "yes", False otherwise
            except Exception as llm_e:
                 print(f"!! ERROR during LLM generation for batch {i+1}: {llm_e}")
                 print(f"!! Marking items in batch as False.")
                 llm_responses_bool.extend([False] * len(chunk_prompts)) # Add False for failed items

    print(f"Generated {len(llm_responses_bool)} boolean responses from LLM.")

    # --- Filter Dataset based on LLM Responses ---
    if len(llm_responses_bool) != len(indices_to_keep):
        print("\nERROR: Mismatch between number of generated responses and kept indices!")
        print(f"Expected {len(indices_to_keep)} responses, got {len(llm_responses_bool)}.")
        print("Cannot reliably filter based on LLM responses. Using heuristically filtered dataset.")
        # Use the dataset loaded at the start of this cell
        ds_filtered_final = dataset_to_filter
    elif not llm_responses_bool:
         print("\nNo LLM responses generated or all failed. Using heuristically filtered dataset.")
         ds_filtered_final = dataset_to_filter
    else:
        print("\nFiltering dataset based on LLM responses...")
        response_map = {idx: response for idx, response in zip(indices_to_keep, llm_responses_bool)}
        # Filter the dataset loaded at the start of the cell
        ds_filtered_final = dataset_to_filter.filter(
            lambda example, idx: response_map.get(idx, False), # Default False if index missing
            with_indices=True,
            num_proc=os.cpu_count()
        )
        print(f"Sub-step 3 Result: Final dataset size after LLM filter: {len(ds_filtered_final)}")

    # --- Final Step 1 Save ---
    save_path_step1_final = "./seed3_llm_filtered_subset" # Final dataset path for all of Step 1
    print(f"\nSaving final Step 1 LLM-filtered dataset to {save_path_step1_final}...")
    if ds_filtered_final is not None and len(ds_filtered_final) > 0:
        try:
            # Ensure the 'seed' column exists before saving
            if 'seed' not in ds_filtered_final.column_names:
                 # If 'seed' got renamed somehow, attempt to rename it back from 'content' or other likely name
                 potential_col = 'content' # Or whatever it might have been renamed to
                 if potential_col in ds_filtered_final.column_names:
                      print(f"Renaming column '{potential_col}' back to 'seed' before saving.")
                      ds_filtered_final = ds_filtered_final.rename_column(potential_col, "seed")
                 else:
                      print("ERROR: 'seed' column not found and could not be restored. Cannot save.")
                      raise ValueError("Dataset missing 'seed' column")

            ds_filtered_final.save_to_disk(save_path_step1_final)
            print("Final Step 1 LLM-filtered dataset saved successfully.")
            # Display sample
            print("\nSample final LLM-filtered function:")
            import random
            print(ds_filtered_final[random.randint(0, len(ds_filtered_final)-1)]['seed'])
        except Exception as save_e:
            print(f"ERROR saving final Step 1 dataset: {save_e}")
    elif ds_filtered_final is not None:
         print("Final Step 1 dataset is empty. Not saving.")
    else:
         print("Final Step 1 dataset is None. Not saving.")

else:
    print("\nPrerequisites not met (LLM init failed or input dataset missing/empty). Skipping LLM filtering.")

--- Sub-step 3: Executing LLM-based Filtering ---
Loading heuristically filtered dataset from ./seed2_heuristically_filtered_subset...
Loaded dataset size for LLM filtering: 19206
Calculated few-shot prompt overhead: ~1323 tokens

Generating prompts (Max code tokens: 2673)...


Generating prompts:   0%|          | 0/19206 [00:00<?, ?it/s]

Generated 19201 prompts. Skipped 5 examples.

Running LLM generation in batches of 32...


LLM Batches: 0it [00:00, ?it/s]

Go prompting functions defined.
Generated 19201 boolean responses from LLM.

Filtering dataset based on LLM responses...


Filter (num_proc=128):   0%|          | 0/19206 [00:00<?, ? examples/s]

Sub-step 3 Result: Final dataset size after LLM filter: 3511

Saving final Step 1 LLM-filtered dataset to ./seed3_llm_filtered_subset...


Saving the dataset (0/1 shards):   0%|          | 0/3511 [00:00<?, ? examples/s]

Final Step 1 LLM-filtered dataset saved successfully.

Sample final LLM-filtered function:
// SocialNameGT applies the GT predicate on the "social_name" field.
func SocialNameGT(v string) predicate.User {
	return predicate.User(func(s *sql.Selector) {
		s.Where(sql.GT(s.C(FieldSocialName), v))
	})
}


In [18]:
print(ds_filtered_final)

Dataset({
    features: ['seed', 'id'],
    num_rows: 3511
})
