In [1]:
# Install all required libraries (safe reinstallation)
#!pip install -q tree-sitter==0.19.0 tree-sitter-java==0.19.1 datasets boto3 smart_open


In [2]:
!pip uninstall tree-sitter -y
!pip install tree-sitter==0.20.1
!pip install boto3 smart_open  datasets

Found existing installation: tree_sitter 0.20.1
Uninstalling tree_sitter-0.20.1:
  Successfully uninstalled tree_sitter-0.20.1
Collecting tree-sitter==0.20.1
  Using cached tree_sitter-0.20.1-cp311-cp311-linux_x86_64.whl
Installing collected packages: tree-sitter
Successfully installed tree-sitter-0.20.1


In [3]:
import os
import tempfile
import boto3
import smart_open
from botocore.config import Config
from botocore import UNSIGNED
from tree_sitter import Language, Parser
from datasets import load_dataset

# --- Build Tree-sitter Java Parser ---
print("Building Java parser...")
temp_dir = tempfile.mkdtemp()
java_repo_path = os.path.join(temp_dir, "tree-sitter-java")
build_path = os.path.join(temp_dir, "build")
os.makedirs(build_path, exist_ok=True)

if not os.path.exists(os.path.join(java_repo_path, "src")):
    !git clone --quiet https://github.com/tree-sitter/tree-sitter-java.git {java_repo_path}
    !cd {java_repo_path} && git checkout v0.19.1

# Build
Language.build_library(
    os.path.join(build_path, "java.so"),
    [java_repo_path]
)

JAVA_LANGUAGE = Language(os.path.join(build_path, "java.so"), "java")
print("Java parser built successfully!")

# --- Helper Functions ---

def make_parser():
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)
    return parser

def node_to_string(src_bytes, node):
    return src_bytes[node.start_byte:node.end_byte].decode('utf8')

# --- Setup Unsigned S3 Access ---

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))

def download_contents(blob_id, src_encoding):
    """Download Java file content from Software Heritage S3"""
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    try:
        with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
            content = fin.read()
            if not content:
                raise Exception("Empty download")
            try:
                return content.decode(src_encoding)
            except UnicodeDecodeError:
                return content.decode('utf-8', errors='ignore')
    except Exception as e:
        print(f"Download error: {e}")
        return ""

# --- Sanity Checks ---

print("Testing parser and downloader...")

# Check 1: Parser creation
try:
    _test_parser = make_parser()
    print("Parser creation: OK")
except Exception as e:
    print(f"Parser creation failed: {e}")

# Check 2: Dummy content download
TEST_BLOB_ID = "1a70a9a76b9d67354f3edcaa2e9a5092184d26ba"  # Small random blob_id
test_content = download_contents(TEST_BLOB_ID, "utf-8")

if test_content and len(test_content) > 20:
    print(f"Sample content download: {test_content[:100]}...")
else:
    print("Content download failed or empty.")

print("Initial setup complete!")


Building Java parser...
Note: switching to 'v0.19.1'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -c with the switch command. Example:

  git switch -c <new-branch-name>

Or undo this operation with:

  git switch -

Turn off this advice by setting config variable advice.detachedHead to false

HEAD is now at 11de4cf 0.19.1
Java parser built successfully!
Testing parser and downloader...
Parser creation: OK
Sample content download: package Programms;
...
Initial setup complete!


In [4]:
import signal
import logging
from multiprocessing import Pool
from datasets import load_dataset, Dataset

# Set up logging
logging.basicConfig(
    filename='java_extraction.log',  # Log to a file
    level=logging.INFO,  # Set log level to INFO;
    format='%(asctime)s - %(levelname)s - %(message)s',
)

logger = logging.getLogger()


JAVA_METHOD_QUERY = JAVA_LANGUAGE.query("""
(method_declaration) @method.def
""")

def get_methods(src_bytes, tree):
    captures = JAVA_METHOD_QUERY.captures(tree.root_node)
    methods = []
    for node, _ in captures:
        method_text = node_to_string(src_bytes, node)
        methods.append(method_text)
    return methods

def parse_ex(parser, ex):
    try:
        # Download from blob_id
        ex_content = download_contents(ex["blob_id"], ex["src_encoding"])
        if not ex_content.strip():
            return []

        buf = bytes(ex_content, "utf8")
        tree = parser.parse(buf)
        return get_methods(buf, tree)
    except Exception as e:
        logger.error(f"Parse error for blob_id {ex.get('blob_id', 'N/A')}: {e}")
        return []

def process_chunk(idx_and_chunk):
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_methods = set()
    for ex in chunk:
        methods = parse_ex(parser, ex)
        chunk_methods.update(methods)
    return chunk_methods

def main(args):
    global PARSERS
    logger.info("Starting Java extraction...")

    try:
        ds = load_dataset(
            args.dataset,
            data_dir=args.data_dir,
            split="train",
            streaming=True,
        )
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        return

    methods = set()
    PARSERS = [make_parser() for _ in range(args.num_workers)]
    CHUNK_SIZE = 1000 * args.num_workers

    logger.info(f"Chunk size per batch: {CHUNK_SIZE}")

    chunk = []
    p = Pool(args.num_workers)

    sample_count = 0
    max_samples = getattr(args, "max_samples", None)

    for ex in ds:
        if sample_count % 1000 == 0:
            logger.info(f"Processed {sample_count} samples...")

        try:
            chunk.append(ex)
            sample_count += 1

            if len(chunk) == CHUNK_SIZE or (max_samples and sample_count >= max_samples):
                logger.info(f"Processing chunk at {sample_count} samples")

                subchunk_size = max(1, len(chunk) // args.num_workers)
                subchunks = [chunk[i:i + subchunk_size] for i in range(0, len(chunk), subchunk_size)]

                new_methods_iter = p.imap(
                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)]
                )

                len_before = len(methods)
                while True:
                    try:
                        def timeout_handler(_, __): raise KeyboardInterrupt
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(60)
                        methods.update(next(new_methods_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        logger.warning("Timeout/Interrupt: Restarting workers...")
                        p.terminate()
                        p = Pool(args.num_workers)
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        logger.error(f"Error during chunk processing: {e}")
                        break

                signal.alarm(0)

                PARSERS = [make_parser() for _ in range(args.num_workers)]

                logger.info(f"Done chunk. New methods found: {len(methods) - len_before}")

                chunk = []

            if max_samples and sample_count >= max_samples:
                logger.info("Reached max_samples limit. Stopping...")
                break

        except Exception as e:
            logger.error(f"Main loop error: {e}")
            chunk = []

    # Finalize
    p.close()
    p.join()

    # Create Huggingface dataset
    new_ds_dict = {
        "content": list(methods),
        "id": list(range(len(methods)))
    }

    new_ds = Dataset.from_dict(new_ds_dict)
    logger.info("Extraction complete!")
    return new_ds


In [5]:
NUMWORKERS = os.cpu_count()

In [None]:

from datasets import load_dataset
from huggingface_hub import login
login('your_login_token')  # Replace with your Hugging Face token



# --- Load the Java dataset ---
print("Loading Java dataset (streaming mode)...")

ds = load_dataset(
    "bigcode/the-stack-v2-dedup",     # Dataset name
    "Java",                           # Subset = Java
    cache_dir="/content/Project/Stack",  # Your local cache
    streaming=True,                   # Stream because it's huge
    split="train"                     # Training split
)

print("Java dataset loaded (streaming mode)!")


Loading Java dataset (streaming mode)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Resolving data files:   0%|          | 0/757 [00:00<?, ?it/s]

Java dataset loaded (streaming mode)!


In [7]:

from multiprocessing import Pool

# --- Initialize necessary variables ---
funs = set()  # To collect all extracted Java methods
PARSERS = [make_parser() for _ in range(NUMWORKERS)]  # Create one parser per worker

# In streaming mode we can't call len(ds)
# CHUNK_SIZE is manually based on number of CPU workers
CHUNK_SIZE = 1000 * NUMWORKERS

print(f"Chunk size: {CHUNK_SIZE}")

chunk = []  # Temporary holder for samples
p = Pool(NUMWORKERS)  # Multiprocessing pool


Chunk size: 12000


In [8]:
if __name__ == "__main__":
    import multiprocessing
    from multiprocessing.dummy import Pool
    import signal

    multiprocessing.set_start_method("spawn", force=True)

    PARSERS = [make_parser() for _ in range(NUMWORKERS)]
    p = Pool(NUMWORKERS)

    sample_count = 0
    MAX_SAMPLES = 1000
    chunk = []
    funs = set()

    FINAL_CHUNK_SIZE = 4000
    print(f"Chunk size: {FINAL_CHUNK_SIZE}")

    for ex in iter(ds):
        if sample_count % 1000 == 0:
            print(f"Processed {sample_count} samples...")

        try:
            chunk.append(ex)
            sample_count += 1

            if len(chunk) == FINAL_CHUNK_SIZE or sample_count >= MAX_SAMPLES:
                print(f"Processing chunk at {sample_count} samples...")

                subchunk_size = max(1, len(chunk) // NUMWORKERS)
                subchunks = [chunk[i:i + subchunk_size] for i in range(0, len(chunk), subchunk_size)]

                new_funs_iter = p.imap(
                    process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)]
                )

                len_before = len(funs)
                failed = False  #Track if timeout happened

                while True:
                    try:
                        def timeout_handler(_, __): raise KeyboardInterrupt
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(300)  #5 min timeout
                        funs.update(next(new_funs_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        print("Timeout/Interrupt detected. Restarting pool...")
                        failed = True  #Mark chunk as failed
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        print(f"Error while processing subchunk: {e}")
                        break

                if failed:
                    print("Reprocessing failed chunk...")
                    # Restart pool
                    p.close()
                    p.join()
                    p = Pool(NUMWORKERS)
                    PARSERS = [make_parser() for _ in range(NUMWORKERS)]

                    #Re-run the same chunk safely again
                    subchunk_size = max(1, len(chunk) // NUMWORKERS)
                    subchunks = [chunk[i:i + subchunk_size] for i in range(0, len(chunk), subchunk_size)]

                    new_funs_iter = p.imap(
                        process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)]
                    )

                    while True:
                        try:
                            signal.signal(signal.SIGALRM, timeout_handler)
                            signal.alarm(300)
                            funs.update(next(new_funs_iter))
                            signal.alarm(0)
                        except (KeyboardInterrupt, StopIteration):
                            signal.alarm(0)
                            break
                        except Exception as e:
                            print(f"Error while retrying: {e}")
                            break

                signal.alarm(0)

                print(f"Done processing chunk. New functions added: {len(funs) - len_before}")
                chunk = []

            if sample_count >= MAX_SAMPLES:
                print(f"Reached MAX_SAMPLES ({MAX_SAMPLES}). Breaking...")
                break

        except Exception as e:
            print(f"Main loop error: {e}")
            chunk = []

    p.close()
    p.join()

    new_ds_dict = {
        "content": list(funs),
        "id": list(range(len(funs)))
    }

    new_ds = Dataset.from_dict(new_ds_dict)


Chunk size: 4000
Processed 0 samples...
Processing chunk at 1000 samples...




Error while processing subchunk: list index out of range
Done processing chunk. New functions added: 5954
Reached MAX_SAMPLES (1000). Breaking...


In [9]:
ds = new_ds

In [10]:
ds

Dataset({
    features: ['content', 'id'],
    num_rows: 5954
})

### SEED GATHERING HIGH-QUALITY SUBSET

In [11]:
import signal
import hashlib
import os
import tempfile
from typing import List, Dict
from tqdm import tqdm

#Reuse the Java parser we built
RETURN_QUERY = JAVA_LANGUAGE.query("""
(return_statement) @return
""")

#Parse using Java parser, check if method has a return value
def does_have_return(src: str, parser=None) -> bool:
    if parser is None:
        parser = make_parser()
    tree = parser.parse(bytes(src, "utf8"))
    root = tree.root_node
    captures = RETURN_QUERY.captures(root)
    for node, _ in captures:
        if len(node.children) <= 1:
            continue
        else:
            return True  # Found return with a value
    return False


# Helper: batch checking (List[str] -> Dict[str, str])
def typecheck_batch(files: List[str]) -> Dict[str, str]:
    filemap: Dict[str, str] = {}

    for src in tqdm(files, desc="Checking returns"):
        try:
            has_return = does_have_return(src)
            if has_return:
                file_hash = hashlib.md5(src.encode('utf-8')).hexdigest()
                filemap[file_hash] = src
        except Exception as e:
            print(f"Return check error: {e}")
            continue

    return filemap

print("Java return checker ready!")


Java return checker ready!


In [12]:
print("Filtering to only methods with return statements...")

# Filter the dataset using multiprocessing
ds = ds.filter(
    lambda ex: does_have_return(ex["content"]),
    num_proc=os.cpu_count(),   # Use all CPU cores available
    desc="Filtering methods"
)

print(f"Finished filtering. Remaining samples: {len(ds)}")


  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Filtering to only methods with return statements...


Filtering methods (num_proc=12):   0%|          | 0/5954 [00:00<?, ? examples/s]

Finished filtering. Remaining samples: 2286


In [13]:
ds

Dataset({
    features: ['content', 'id'],
    num_rows: 2286
})

In [14]:
import datasets
from tqdm import tqdm

batch = []
max_i = len(ds) - 1

# Initialize new dataset structure
final_ds = {
    "content": [],
    "sha1": [],
    "id": [],
}

e_id = 0

print("Typechecking methods and building new dataset...")

for i, ex in enumerate(tqdm(ds, total=len(ds))):
    try:
        code = ex["content"]
        batch.append(code)

        # Process batch
        if len(batch) == 250 or i == max_i:
            filemap = typecheck_batch(batch)
            for sha1, contents in filemap.items():
                final_ds["content"].append(contents)
                final_ds["sha1"].append(sha1)
                final_ds["id"].append(e_id)
                e_id += 1
            batch = []

    except Exception as e:
        print(f"There was an error processing example {i}: {e}")
        continue

# Convert dictionary to Huggingface Dataset
new_ds_hf = datasets.Dataset.from_dict(final_ds)

print(f"Built new dataset with {len(new_ds_hf)} Java methods after typechecking!")


Typechecking methods and building new dataset...


  0%|          | 0/2286 [00:00<?, ?it/s]
Checking returns: 100%|██████████| 250/250 [00:00<00:00, 6906.52it/s]

Checking returns: 100%|██████████| 250/250 [00:00<00:00, 4959.00it/s]
 22%|██▏       | 500/2286 [00:00<00:00, 4049.70it/s]
Checking returns: 100%|██████████| 250/250 [00:00<00:00, 6739.22it/s]

Checking returns: 100%|██████████| 250/250 [00:00<00:00, 5829.73it/s]
 44%|████▎     | 1000/2286 [00:00<00:00, 4138.20it/s]
Checking returns: 100%|██████████| 250/250 [00:00<00:00, 6757.33it/s]

Checking returns: 100%|██████████| 250/250 [00:00<00:00, 7241.55it/s]
 66%|██████▌   | 1500/2286 [00:00<00:00, 4287.72it/s]
Checking returns: 100%|██████████| 250/250 [00:00<00:00, 5880.04it/s]

Checking returns: 100%|██████████| 250/250 [00:00<00:00, 6719.19it/s]
 87%|████████▋ | 2000/2286 [00:00<00:00, 4301.17it/s]
Checking returns: 100%|██████████| 250/250 [00:00<00:00, 7410.38it/s]

Checking returns: 100%|██████████| 36/36 [00:00<00:00, 12695.05it/s]
100%|██████████| 2286/2286 [00:00<00:00,

Built new dataset with 2286 Java methods after typechecking!





In [15]:
print(new_ds_hf['content'][0])

@CaseAttributes(isSupportAutoTest = true)
    public boolean CASE_setExtPinpadPortMode_CNPinpadMode() {
        int ret = API07_setExtPinpadPortMode(ExternalSerialConst.MODE_PP1000V3_PINPAD);
        if (ret == ExternalSerialConst.MODE_PP1000V3_PINPAD) {
            return true;
        }

        return false;
    }


In [16]:
save_dir = "/content/Project/Datasets/Seed2"

In [17]:

# Save the final HuggingFace dataset to disk
new_ds_hf.save_to_disk(save_dir)

print(f"Dataset successfully saved to {save_dir}")


Saving the dataset (0/1 shards):   0%|          | 0/2286 [00:00<?, ? examples/s]

Dataset successfully saved to /content/Project/Datasets/Seed2


### SEED GATHERING FILTER DATASET

In [18]:
!pip install vllm



In [19]:
import datasets
import os
from tqdm import tqdm
import torch
import argparse
from vllm import LLM, SamplingParams
import random
import torch
from tree_sitter import Language, Parser

# Make parser for Java
def make_parser():
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)
    return parser

INFO 05-02 17:43:16 [__init__.py:239] Automatically detected platform cuda.


In [20]:
FN_BLOCK_QUERY = JAVA_LANGUAGE.query("""
(method_declaration
  body: (block) @fn-block)
""")


def template_few_shot(code, comment, answer, rationale):
    doc = ""  # Java methods often don't have inline docstrings
    assert answer == "No" or answer == "Yes"

    prompt = f"""<issue_start>username_0: I have a method in Java and I'd like someone to check my description of this method.
I'm doing this so that I can write a good comment for this method.

Here is the code for the method:
```java
{code}
```

Here is my description of this method:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""

    return prompt


FEW_SHOTS= [
    (
        '''/**
 * Converts a JSON string to a Config object.
 */
public Config parseConfig(String json) throws IOException {
    ObjectMapper mapper = new ObjectMapper();
    return mapper.readValue(json, Config.class);
}''',
        "Parses a JSON string and returns a Config object.",
        "Yes",
        "The comment clearly states what the method does and matches its actual behavior."
    ),
    (
        '''/**
 * Calculates discount based on user type and purchase amount.
 */
public double calculateDiscount(String userType, double amount) {
    if ("premium".equalsIgnoreCase(userType)) {
        return amount * 0.2;
    } else if ("regular".equalsIgnoreCase(userType)) {
        return amount * 0.1;
    }
    return 0;
}''',
        "Applies shipping tax based on product weight.",
        "No",
        "The comment does not match the logic — the method calculates user discounts, not shipping taxes."
    ),
    (
        '''/**
 * Returns a sorted list of usernames in alphabetical order.
 */
public List<String> sortUsernames(List<String> usernames) {
    return usernames.stream()
        .sorted()
        .collect(Collectors.toList());
}''',
        "Sorts the usernames alphabetically.",
        "Yes",
        "The description and implementation align well — the function sorts and returns usernames."
    ),
    (
        '''/**
 * Attempts to connect to the given URL and returns true if successful.
 */
public boolean isReachable(String urlStr) {
    try {
        URL url = new URL(urlStr);
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("HEAD");
        conn.setConnectTimeout(3000);
        conn.connect();
        return conn.getResponseCode() < 400;
    } catch (IOException e) {
        return false;
    }
}''',
        "Checks if a URL is reachable via a HEAD request.",
        "Yes",
        "The comment matches both the method's name and its detailed implementation."
    ),
    (
        '''/**
 * Logs the current memory usage in MB.
 */
public void logMemoryUsage() {
    Runtime runtime = Runtime.getRuntime();
    long usedMem = (runtime.totalMemory() - runtime.freeMemory()) / (1024 * 1024);
    System.out.println("Used Memory: " + usedMem + " MB");
}''',
        "Cleans up unused memory by forcing garbage collection.",
        "No",
        "The comment claims the method forces GC, but it only logs memory usage."
    )
]


def prompt_fmt(code):
    random.shuffle(FEW_SHOTS)
    buf = ""
    for few in FEW_SHOTS:
        buf += template_few_shot(*few)

    buf += f"""<issue_start>username_0: I have a method in Java and I'd like someone to check my description of this method.
I'm doing this so that I can write a good comment for this method.

Here is the code for the method:
```java
{code}
```

Here is my description of this method:
```

```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is:"""

    return buf


def auto_dtype():
    if torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "auto"

def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks


In [21]:
dataset = new_ds_hf

In [22]:
!pip install benchmark_data


[31mERROR: Could not find a version that satisfies the requirement benchmark_data (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for benchmark_data[0m[31m
[0m

In [23]:
print(f"Loaded {len(new_ds_hf)} examples. Running pre-filtering...")

# Words indicating incomplete or unsafe code
BAD_WORDS = ["TODO", "FIXME", "BUG", "HACK"]

# Java imports or APIs we might want to filter out
BAD_IMPORTS = [
    "java.io", "java.net", "java.nio", "java.util.concurrent",
    "System.exit", "Runtime.getRuntime", "ProcessBuilder"
]

# Match both import forms
BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + \
              [f"from {b}" for b in BAD_IMPORTS]

# Merge them all
BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS

# Placeholder for known benchmark overlap filtering
all_bench = []  # No HumanEval/MBPP for Java



Loaded 2286 examples. Running pre-filtering...


In [24]:
def pre_filtering(ex):
    code = ex["content"]
    code_bytes = code.encode('utf-8')

    # filter out bad substrings
    lower = code.lower()
    for word in BAD_SUBSTRINGS:
        if word in lower:
            return False

    for b in all_bench:
        if b in code:  # contaminated sample!
            return False

    lines = code.split("\n")
    if len(lines) > 150:
        return False

    # filter functions which don't have an argument
    # 1. find first def statement in lines
    # 2. check if contains ():
    for line in lines:
        if line.startswith("def ") and "():" in line:
            return False

    # filter out functions with no return statement
    if not does_have_return(code):
        return False

    #No docstring checking for Java

    return True  # all good!

threads = os.cpu_count() - 1  # type: ignore
dataset = dataset.filter(pre_filtering, num_proc=threads)


  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Filter (num_proc=11):   0%|          | 0/2286 [00:00<?, ? examples/s]

In [25]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 2279
})

In [26]:
import torch
print(torch.cuda.is_available())  # Should return True if GPU is available


True


In [27]:
!apt-get update
!apt-get install --reinstall cuda-toolkit-12-1

Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:2 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Hit:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Reading package lists... Done
Building dependency tree... Done
Reading

In [28]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [29]:
!pip install pyzmq --upgrade



In [30]:
model = LLM(f"bigcode/starcoder", dtype=torch.float16,
            gpu_memory_utilization=0.95, tensor_parallel_size=1)


INFO 05-02 17:43:37 [config.py:2968] Downcasting torch.float32 to torch.float16.
INFO 05-02 17:43:51 [config.py:717] This model supports multiple tasks: {'score', 'generate', 'classify', 'reward', 'embed'}. Defaulting to 'generate'.
INFO 05-02 17:43:51 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 05-02 17:43:53 [core.py:58] Initializing a V1 LLM engine (v0.8.5) with config: model='bigcode/starcoder', speculative_config=None, tokenizer='bigcode/starcoder', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config

model-00003-of-00007.safetensors:   0%|          | 0.00/9.85G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/9.86G [00:00<?, ?B/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/9.85G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/4.08G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/9.86G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/9.86G [00:00<?, ?B/s]

INFO 05-02 17:46:52 [weight_utils.py:281] Time spent downloading weights for bigcode/starcoder: 176.727273 seconds


model.safetensors.index.json:   0%|          | 0.00/38.2k [00:00<?, ?B/s]

Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]


INFO 05-02 17:50:44 [loader.py:458] Loading weights took 231.64 seconds
INFO 05-02 17:50:45 [gpu_model_runner.py:1347] Model loading took 28.9427 GiB and 409.694893 seconds
INFO 05-02 17:50:54 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/3284d1dd63/rank_0_0 for vLLM's torch.compile
INFO 05-02 17:50:54 [backends.py:430] Dynamo bytecode transform time: 9.38 s
INFO 05-02 17:50:57 [backends.py:136] Cache the graph of shape None for later use
INFO 05-02 17:51:19 [backends.py:148] Compiling a graph for general shape takes 23.78 s
INFO 05-02 17:51:56 [monitor.py:33] torch.compile takes 33.16 s in total
INFO 05-02 17:51:58 [kv_cache_utils.py:634] GPU KV cache size: 355,712 tokens
INFO 05-02 17:51:58 [kv_cache_utils.py:637] Maximum concurrency for 8,192 tokens per request: 43.42x
INFO 05-02 17:52:38 [gpu_model_runner.py:1686] Graph capturing finished in 40 secs, took 0.56 GiB
INFO 05-02 17:52:38 [core.py:159] init engine (profile, create kv cache, warmup model)

In [31]:
tokenizer = model.get_tokenizer()

In [32]:
print(f"Now running stage 3 filtering on {len(dataset)} examples...")

Now running stage 3 filtering on 2279 examples...


In [33]:
def unindent(s):
    lines = s.splitlines()
    non_blank_lines = [line for line in lines if line.strip()]
    min_indent = min(len(line) - len(line.lstrip())
                     for line in non_blank_lines) if non_blank_lines else 0
    unindented_lines = [line[min_indent:] if len(
        line) >= min_indent else line for line in lines]
    return '\n'.join(unindented_lines)


def java_extract_docstring(code):
    first_doc = code.find('"""')
    assert first_doc != -1
    first_doc = first_doc + 3
    second_doc = code[first_doc+1:].find('"""')
    assert second_doc != -1
    second_doc = second_doc + first_doc + 1
    doc = code[first_doc:second_doc]
    doc = unindent(doc).strip()
    code = code[:first_doc-3] + code[second_doc+3:]
    return doc, code


In [34]:
dummy = 'public void dummy() {\n    // TODO\n}'
dummy_prompt = prompt_fmt(dummy)
few_shot_toks = len(tokenizer.encode(dummy_prompt)) - len(tokenizer.encode(dummy))
print(f"Few-shot prompt has {few_shot_toks} tokens")


Few-shot prompt has 1438 tokens


In [35]:
prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"):
    code = ex["content"]
    toks = len(tokenizer.encode(code)) + few_shot_toks
    if toks > 16380:
        print(f"Skipping example with {toks} tokens")
        prompts.append(dummy_prompt)
        continue
    p = prompt_fmt(code)
    prompts.append(p)

responses = []
for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
    outs = model.generate(chunk, SamplingParams(
        temperature=0.0, stop="\n", max_tokens=5))
    contents = [o.outputs[0].text for o in outs]
    for c in contents:
        yes_count = c.lower().count("yes")
        no_count = c.lower().count("no")
        if yes_count > no_count:
            responses.append(True)
        elif yes_count < no_count:
            responses.append(False)
        else:
            # default to No
            responses.append(False)



Generating prompts: 100%|██████████| 2279/2279 [00:01<00:00, 2183.99it/s]
Generating responses:   0%|          | 0/5 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Generating responses:  20%|██        | 1/5 [00:26<01:47, 26.82s/it]

Processed prompts:   0%|          | 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Generating responses:  40%|████      | 2/5 [00:41<00:59, 19.70s/it]

Processed prompts:   0%|          | 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Generating responses:  60%|██████    | 3/5 [00:56<00:35, 17.66s/it]

Processed prompts:   0%|          | 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Generating responses:  80%|████████  | 4/5 [01:12<00:16, 16.80s/it]

Processed prompts:   0%|          | 0/231 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Generating responses: 100%|██████████| 5/5 [01:18<00:00, 15.67s/it]


In [36]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 2279
})

In [37]:
subset = dataset.select(range(min(75000, len(dataset))))


In [38]:
subset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 2279
})

In [39]:
subset_responses = responses

def filter_fn(ex, i):
    if i >= len(subset_responses):
        return False
    return subset_responses[i] and "def dummy()" not in ex["content"]

new_ds = subset.filter(filter_fn, with_indices=True)

print(f"Filtered {len(subset) - len(new_ds)} examples")


Filter:   0%|          | 0/2279 [00:00<?, ? examples/s]

Filtered 1126 examples


In [40]:
print("Total prompts generated:", len(prompts))
print("Total responses received:", len(responses))
print("How many 'yes':", sum(responses))
print("How many 'no':", len(responses) - sum(responses))


Total prompts generated: 2279
Total responses received: 2279
How many 'yes': 1153
How many 'no': 1126


In [41]:
new_ds.save_to_disk("/content/Project/Datasets/Seed3")

Saving the dataset (0/1 shards):   0%|          | 0/1153 [00:00<?, ? examples/s]

In [42]:
new_ds

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 1153
})

In [43]:
from datasets import Dataset

dataset = Dataset.load_from_disk('/content/Project/Datasets/Seed3/')
dataset


Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 1153
})

In [44]:
dataset.to_pandas().head()

Unnamed: 0,content,sha1,id
0,@CaseAttributes(isSupportAutoTest = true)\r\n ...,55e6ee80451fb6e81882fd927a8ce894,0
1,@Override\n protected String getConfFile() ...,74778c570c7e1cb4972eb6126f632f7e,4
2,@Override\n public boolean onKey(Di...,d2ebcc9aa4a46d9e3e177873a2174cb8,5
3,@Override\n public Document findById(String...,23a69de38312702d0dde94541faf3665,6
4,private P getInstance() {\n try {\n ...,5065494f61b5ec97077775323cb01089,9
