In [7]:
import subprocess
from typing import List, Optional

def compute_diff(
    buggy_code: str, fixed_code: str, context_len: Optional[int] = None
) -> List[str]:
    """
    Computes the diff between the buggy and fixed code.
    """
    context_len = (
        context_len
        if context_len is not None
        else max(len(buggy_code), len(fixed_code))
    )
    with open("/tmp/buggy.java","w") as f: f.write(buggy_code+"\n")
    with open("/tmp/fixed_code.java","w") as f: f.write(fixed_code+"\n")
    # we want to ignore whitespace changes with -w which does not exist in difflib.unified_diff
    # with git diff, we even the name of the changed function in the diff, which helps a lot
    cmd = ["git","diff","--patience",f"-U{context_len}", "-w","/tmp/buggy.java","/tmp/fixed_code.java"]
    return subprocess.run(cmd, capture_output=True).stdout.decode("utf-8")

In [8]:
from datasets import load_dataset

megadiff_sf = load_dataset("ASSERT-KTH/megadiff-single-function")
megadiff_sf = megadiff_sf.map(lambda x: {"short_diff": compute_diff(x["buggy_function"], x["fixed_function"], context_len=3)})

In [13]:
def user_prompt(fixed_function: str) -> str:
    return f"""
### Fixed Function
```java
{fixed_function}
```
"""

def system_prompt() -> str:
    return """You are an assistant designed to help generate synthetic data for fine-tuning a language model for automatic program repair. You will receive a correctly implemented function as input. Your task is to generate:

1. A buggy version of the provided function.
2. A unit test method that exposes the behavioral difference between the buggy and fixed versions.
3. The stack trace or error message resulting from executing that unit test method with the buggy function.

Please ensure the following when generating the output:

- The buggy function should have a realistic and plausible error that one might encounter in actual programming.
- The unit test MUST be a method and directly target the introduced bug. ONLY the unit test method should be provided, not the entire test suite.
- The stack trace or error message must be realistic and show the exact execution of the unit test.
- Do NOT add any comment in the code about the bug, the bug injection or any other information.

Format the output in the following structure:
### Buggy Function
```java
<Buggy function code>
```

### Unit Test
```java
<Unit test code>
```

### Error Message or Stack Trace
```
<Error message or stack trace>
```
"""

def user_prompt_v2(diff: str) -> str:
    return f"""
### Diff between fixed and buggy functions
```diff
{diff}
```
"""

def system_prompt_v2() -> str:
    return """You are an assistant designed to help generate synthetic data for fine-tuning a language model for automatic program repair. You will receive a bug-fixing diff as input. Your task is to generate:

1. A unit test method that exposes the behavioral difference between the buggy and fixed versions.
2. The stack trace or error message resulting from executing the unit test method.

Please ensure the following when generating the output:

- The unit test MUST be a method and directly target the introduced bug. ONLY the unit test method should be provided, not the entire test suite.
- The stack trace or error message must be realistic and show the exact execution of the unit test.
- Do NOT add any comment in the code about the bug, the bug injection or any other information.

Format the output in the following structure:
### Unit Test
```java
<Unit test code>
```

### Error Message or Stack Trace
```
<Error message or stack trace>
```
"""

In [14]:
import tiktoken

from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()
client = OpenAI()

def call_openai(system_prompt: str, user_prompt: str, model: str = "gpt-4o-mini") -> str:
    enc = tiktoken.encoding_for_model(model)
    n_tokens = len(enc.encode(system_prompt)) + len(enc.encode(user_prompt))
    if n_tokens >= 128000:
        return None

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    return completion

def extract_output(completion):
    """
    Extracts the output from the completion object.
    
    Returns:
    - The generated buggy function
    - The generated test case
    - The generated test error or stack trace
    """
    message = completion.choices[0].message.content

    try:
        buggy_function = message.split("### Buggy Function")[1].split("```java")[1].split("```")[0].strip()
        test_case = message.split("### Unit Test")[1].split("```java")[1].split("```")[0].strip()
        error_message = message.split("### Error Message or Stack Trace")[1].split("```")[1].strip()
        return buggy_function, test_case, error_message
    except Exception as e:
        return None, None, None

def extract_output_v2(completion):
    """
    Extracts the output from the completion object.
    
    Returns:
    - The generated test case
    - The generated test error or stack trace
    """
    message = completion.choices[0].message.content

    try:
        test_case = message.split("### Unit Test")[1].split("```java")[1].split("```")[0].strip()
        error_message = message.split("### Error Message or Stack Trace")[1].split("```")[1].strip()
        return test_case, error_message
    except Exception as e:
        return None, None

In [15]:
completion = call_openai(system_prompt(), user_prompt(megadiff_sf["train"][0]["fixed_function"]), model="gpt-4o-mini")


buggy_function, test_case, error_message = extract_output(completion)

diff = compute_diff(megadiff_sf["train"][0]["fixed_function"].strip(), buggy_function.strip())

print("Generated Bug")
print(diff)

print("Generated Test Case")
print(test_case)

print("Generated Error Message")
print(error_message)

Generated Bug
diff --git a/tmp/buggy.java b/tmp/fixed_code.java
index 0fd41ef..e61758b 100644
--- a/tmp/buggy.java
+++ b/tmp/fixed_code.java
@@ -1,24 +1,24 @@
 private void ping() throws IOException, InterruptedException {
         Future<?> f = channel.callAsync(new Ping());
         long start = System.currentTimeMillis();
         long end = start +timeout;
 
         long remaining;
         do {
             remaining = end-System.currentTimeMillis();
             try {
-                f.get(Math.max(0,remaining),MILLISECONDS);
+                f.get(remaining, MILLISECONDS);
                 return;
             } catch (ExecutionException e) {
                 if (e.getCause() instanceof RequestAbortedException)
                     return; // connection has shut down orderly.
                 onDead(e);
                 return;
             } catch (TimeoutException e) {
                 // get method waits "at most the amount specified in the timeout",
                 // so 

In [16]:
completion = call_openai(system_prompt_v2(), user_prompt_v2(megadiff_sf["train"][1]["short_diff"]), model="gpt-4o-mini")
test_case, error_message = extract_output_v2(completion)

print(compute_diff(megadiff_sf["train"][1]["buggy_function"], megadiff_sf["train"][1]["fixed_function"], context_len=3))

print(test_case)

print(error_message)

diff --git a/tmp/buggy.java b/tmp/fixed_code.java
index 675f029..0289ad2 100644
--- a/tmp/buggy.java
+++ b/tmp/fixed_code.java
@@ -7,13 +7,13 @@
         boolean isBetter = false;
         switch (policy) {
             case MINIMIZE:
-                if (bestVal > val) {
+                if (bestVal > val || nbSol==1) {
                     bestVal = val;
                     isBetter = true;
                 }
                 break;
             case MAXIMIZE:
-                if (bestVal < val) {
+                if (bestVal < val || nbSol==1) {
                     bestVal = val;
                     isBetter = true;
                 }

@Test
public void testPolicyWithSingleSolution() {
    double bestVal = 10.0;
    double val = 5.0;
    int nbSol = 1; // Only one solution
    String policy = "MINIMIZE"; // Testing MINIMIZE policy

    // Call the buggy version of the method
    boolean result = buggyPolicyFunction(bestVal, val, nbSol, policy);

    // Assert that the bestVal has

In [126]:
import concurrent.futures

completions = []
tasks = []

with concurrent.futures.ThreadPoolExecutor() as executor:
    for i in range(len(megadiff_sf["train"][:1000]["short_diff"])):
        task = executor.submit(call_openai, system_prompt_v2(), user_prompt_v2(megadiff_sf["train"][i]["short_diff"]), model="gpt-4o-mini")
        tasks.append(task)

    completions = [future.result() for future in tasks]

  def _eval_type(t, globalns, localns, recursive_guard=frozenset()):


In [158]:
outputs = [extract_output_v2(completion) for completion in completions]

megadiff_sf_plus = megadiff_sf["train"].select(range(1000)).add_column("generated_test_case", [output[0] for output in outputs]).add_column("generated_error_message", [output[1] for output in outputs]).add_column("completion", [completion for completion in completions])
megadiff_sf_plus.save_to_disk("megadiff_sf_plus")

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [2]:
from datasets import load_from_disk

megadiff_sf_plus = load_from_disk("megadiff_sf_plus")

In [5]:
def build_prompt(buggy_code: str, failing_test_case: str, failing_test_case_error: str) -> str:
    return f"""You are an automatic program repair tool. Your task is to fix the provided buggy code.

The following code contains a buggy function:
```java
{buggy_code}
```

The code fails the following test:
```java
{failing_test_case}
```

With the following test error:
```
{failing_test_case_error}
```

Please provide a fixed version of the buggy function, and only that function:
"""

def build_answer(fixed_code: str) -> str:
    return f"```java\n{fixed_code}\n```"

In [6]:
# Create new columns with the prompt and answer
megadiff_sf_plus = megadiff_sf_plus.map(lambda x: {"prompt": build_prompt(x["buggy_function"], x["generated_test_case"], x["generated_error_message"])})
megadiff_sf_plus = megadiff_sf_plus.map(lambda x: {"answer": build_answer(x["fixed_function"])})

Map:   0%|          | 0/72393 [00:00<?, ? examples/s]

Map:   0%|          | 0/72393 [00:00<?, ? examples/s]

In [9]:
print(megadiff_sf_plus[0]["prompt"])
print(megadiff_sf_plus[0]["answer"])

You are an automatic program repair tool. Your task is to fix the provided buggy code.

The following code contains a buggy function:
```java
    private void ping() throws IOException, InterruptedException {
        Future<?> f = channel.callAsync(new Ping());
        long start = System.currentTimeMillis();
        long end = start +timeout;

        long remaining;
        do {
            remaining = end-System.currentTimeMillis();
            try {
                f.get(Math.max(0,remaining),MILLISECONDS);
                return;
            } catch (ExecutionException e) {
                if (e.getCause() instanceof RequestAbortedException)
                    return; // connection has shut down orderly.
                onDead(e);
                return;
            } catch (TimeoutException e) {
                // get method waits "at most the amount specified in the timeout",
                // so let's make sure that it really waited enough
            }
        } while(remain

In [10]:
megadiff_sf_plus.save_to_disk("megadiff_sf_plus_processed")

Saving the dataset (0/6 shards):   0%|          | 0/72393 [00:00<?, ? examples/s]

In [34]:
from datasets import load_from_disk

megadiff_sf_plus = load_from_disk("megadiff_sf_plus_processed")

In [35]:
# Deduplicate the dataset based on the exact match the short_diff and prompt
from functools import partial

def is_unique(elem , column: str, memory: set) -> bool:
    if elem[column] in memory:
        return False
    else:
        memory.add(elem[column])
        return True

memory = set()
megadiff_sf_plus_shortdiff_exactdedup = megadiff_sf_plus.filter(partial(is_unique, column="short_diff", memory=memory))

print(f"Number of duplicates (short_diff): {len(megadiff_sf_plus) - len(megadiff_sf_plus_shortdiff_exactdedup)}")

memory = set()
megadiff_sf_plus_prompt_exactdedup = megadiff_sf_plus_shortdiff_exactdedup.filter(partial(is_unique, column="prompt", memory=memory))

print(f"Number of duplicates (prompt): {len(megadiff_sf_plus_shortdiff_exactdedup) - len(megadiff_sf_plus_prompt_exactdedup)}")

Number of duplicates (short_diff): 6277
Number of duplicates (prompt): 0


In [59]:
import re

class Shingler:
    def __init__(self, k: int = 3):
        self.k = k
        # A very approximate tokenization for most programming languages
        self.NON_ALPHA = re.compile("[^A-Za-z_0-9]")

    def __process_doc(self, document):
        return [t for t in self.NON_ALPHA.split(document) if len(t.strip()) > 0]

    def get_shingles(self, document):
        shingles = set()
        document = self.__process_doc(document)
        for i in range(0, len(document)-self.k+1 ):
            shingles.add(" ".join(document[i:i+self.k]))
        return shingles

In [None]:
from datasketch import MinHash

import tqdm

min_hashes = []

# Compute minhashes for the prompt
for i, row in tqdm.tqdm(enumerate(megadiff_sf_plus_prompt_exactdedup), total=len(megadiff_sf_plus_prompt_exactdedup)):
    min_hash = MinHash(num_perm=128)
    shingles = Shingler().get_shingles(row["prompt"])
    for shingle in shingles:
        min_hash.update(shingle.encode('utf-8'))
    min_hashes.append(min_hash)

In [93]:
from datasketch import MinHashLSH

lsh = MinHashLSH(threshold=0.85, num_perm=128)

for i, min_hash in enumerate(min_hashes):
    lsh.insert(i, min_hash)

prompt_duplicates = []
for i, min_hash in enumerate(min_hashes):
    idxs = lsh.query(min_hash)
    if len(idxs) > 1:
        prompt_duplicates.append(idxs)

print(f"Number of duplicates: {sum(len(dups) for dups in prompt_duplicates)}")

Number of duplicates: 4470


In [85]:
short_diff_min_hashes = []

# Compute minhashes for the short_diff
for i, row in tqdm.tqdm(enumerate(megadiff_sf_plus_prompt_exactdedup), total=len(megadiff_sf_plus_prompt_exactdedup)):
    min_hash = MinHash(num_perm=128)
    shingles = Shingler().get_shingles(row["short_diff"])
    for shingle in shingles:
        min_hash.update(shingle.encode('utf-8'))
    short_diff_min_hashes.append(min_hash)

100%|██████████| 66116/66116 [01:08<00:00, 962.87it/s]


In [94]:
lsh = MinHashLSH(threshold=0.85, num_perm=128)

for i, min_hash in enumerate(short_diff_min_hashes):
    lsh.insert(i, min_hash)

shortdiff_duplicates = []
for i, min_hash in enumerate(short_diff_min_hashes):
    idxs = lsh.query(min_hash)
    if len(idxs) > 1:
        shortdiff_duplicates.append(idxs)

print(f"Number of duplicates: {sum(len(dups) for dups in shortdiff_duplicates)}")

Number of duplicates: 774


In [100]:
# Remove all samples that are duplicated (naive approach)
duplicates = set(sum(prompt_duplicates, [])) | set(sum(shortdiff_duplicates, []))

megadiff_sf_plus_dedup = megadiff_sf_plus_prompt_exactdedup.select(
    [i for i in range(len(megadiff_sf_plus_prompt_exactdedup)) if i not in duplicates]
)

megadiff_sf_plus_dedup.save_to_disk("megadiff_sf_plus_dedup")

Saving the dataset (0/5 shards):   0%|          | 0/64388 [00:00<?, ? examples/s]

In [103]:
megadiff_sf_plus_dedup.push_to_hub("ASSERT-KTH/megadiff-sf-synthetic_test_error")

Uploading the dataset shards:   0%|          | 0/5 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ASSERT-KTH/megadiff-sf-synthetic_test_error/commit/ac8f2cbc294c4b37f8c71560b5684bcd1644e741', commit_message='Upload dataset', commit_description='', oid='ac8f2cbc294c4b37f8c71560b5684bcd1644e741', pr_url=None, pr_revision=None, pr_num=None)

In [9]:
from datasets import load_from_disk

import re

ds = load_from_disk("megadiff_sf_plus_dedup")

# remove comments from answer
def remove_java_comments(source: str) -> str:
    # Define states
    NORMAL, SINGLE_COMMENT, MULTI_COMMENT, STRING_LITERAL, CHAR_LITERAL = range(5)

    state = NORMAL
    result = []
    i = 0

    while i < len(source):
        # Check the current state and process accordingly
        if state == NORMAL:
            if source[i : i + 2] == "//":
                state = SINGLE_COMMENT
                i += 2
            elif source[i : i + 2] == "/*":
                state = MULTI_COMMENT
                i += 2
            elif source[i] == '"':
                state = STRING_LITERAL
                result.append(source[i])
                i += 1
            elif source[i] == "'":
                state = CHAR_LITERAL
                result.append(source[i])
                i += 1
            else:
                result.append(source[i])
                i += 1
        elif state == SINGLE_COMMENT:
            if source[i] == "\n":
                state = NORMAL
                result.append(source[i])
                i += 1
            else:
                i += 1
        elif state == MULTI_COMMENT:
            if source[i : i + 2] == "*/":
                state = NORMAL
                i += 2
            else:
                i += 1
        elif state == STRING_LITERAL:
            if source[i] == "\\":
                result.append(source[i])
                i += 1
                result.append(source[i])
                i += 1
            elif source[i] == '"':
                state = NORMAL
                result.append(source[i])
                i += 1
            else:
                result.append(source[i])
                i += 1
        elif state == CHAR_LITERAL:
            if source[i] == "\\":
                result.append(source[i])
                i += 1
                result.append(source[i])
                i += 1
            elif source[i] == "'":
                state = NORMAL
                result.append(source[i])
                i += 1
            else:
                result.append(source[i])
                i += 1

    return "".join(result)

def remove_empty_lines(source):
    """Remove all empty lines from Java source code."""
    return re.sub(r"^\s*$\n", "", source, flags=re.MULTILINE)

ds_no_comments = ds.map(lambda x: {"answer": "```java\n" + remove_empty_lines(remove_java_comments(x["answer"].split("```java")[1].split("```")[0])) + "```"})

Map:   0%|          | 0/64388 [00:00<?, ? examples/s]

In [13]:
ds_no_comments.push_to_hub("ASSERT-KTH/megadiff-sf-synthetic_test_error_no_comments")

Uploading the dataset shards:   0%|          | 0/4 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ASSERT-KTH/megadiff-sf-synthetic_test_error_no_comments/commit/d3114a397ac24e0dce1c69789b439a91f2d7ce34', commit_message='Upload dataset', commit_description='', oid='d3114a397ac24e0dce1c69789b439a91f2d7ce34', pr_url=None, pr_revision=None, pr_num=None)