In [1]:
from datasets import load_dataset
import json
import polars as pl

In [2]:
import libcst as cst
import re
import random
from openbugger.bugger import Bugger, bugger_example
from time import perf_counter

In [3]:
# Names of the openbugger transformers; each one is used below as a column-name prefix
# (e.g. "<name>_user", "<name>_assistant").
all_bugs_names = [
    'ReturningEarlyTransformer',
    'SwapForTransformer',
    'VariableNameTypoTransformer',
    'ForgettingToUpdateVariableTransformer',
    'MutableDefaultArgumentTransformer',
    'UseBeforeDefinitionTransformer',
    'OffByKIndexTransformer',
    'ComparisonSwapTransformer',
    'InfiniteWhileTransformer',
    'MissingArgumentTransformer',
    'IncorrectExceptionHandlerTransformer',
    'IncorrectTypeTransformer',
    'ComparisonTargetTransformer',
    'IncorrectVariableInitializationTransformer',
    'NonExistingMethodTransformer',
]

In [4]:
import libcst as cst
import libcst.matchers as m
from black import format_str, FileMode

class RemoveComments(cst.CSTTransformer):
    """libcst transformer that removes every ``Comment`` node it leaves."""

    def leave_Comment(self, original_node, updated_node):
        """Ask the parent to drop the comment node instead of keeping it."""
        # RemoveFromParent() is libcst's helper that returns RemovalSentinel.REMOVE.
        return cst.RemoveFromParent()

def remove_comments_and_lint(module_str: str) -> str:
    """Return ``module_str`` with its comments removed and reformatted by black.

    The source is parsed with libcst, passed through ``RemoveComments``,
    rendered back to text and finally formatted with black's default
    ``FileMode``.
    """
    comment_free_source = cst.parse_module(module_str).visit(RemoveComments()).code
    return format_str(comment_free_source, mode=FileMode())

def compare_modules(module1: str, module2: str) -> bool:
    """Tell whether two source strings are equal once comments and formatting are normalised.

    Both sources go through ``remove_comments_and_lint`` and the cleaned
    texts are re-parsed with libcst; the result is the ``deep_equals``
    comparison of the two resulting modules.
    """
    # Clean both inputs first, then parse both cleaned texts.
    cleaned_sources = [remove_comments_and_lint(source) for source in (module1, module2)]
    first_tree, second_tree = [cst.parse_module(text) for text in cleaned_sources]
    return first_tree.deep_equals(second_tree)

In [5]:
import libcst as cst
import re
from functools import lru_cache

def is_valid_python(code):
    """Return True when ``code`` is non-empty and libcst can parse it as a module.

    Any exception raised by the parser is treated as "not valid" and yields
    False; a zero-length input is rejected without calling the parser.
    """
    if not len(code):
        return False
    try:
        cst.parse_module(code)
    except Exception:
        return False
    else:
        return True

def is_single_word_line(line):
    """Return True when ``line`` contains exactly one run of word characters.

    Punctuation and whitespace are ignored, so a line such as ``foo()`` or a
    fence line with a language tag counts as a single word, while a line
    without any word character does not.
    """
    word_count = sum(1 for _ in re.finditer(r'\b\w+\b', line))
    return word_count == 1

def find_enclosed_newlines(input_string):
    """Locate newline characters that fall inside quoted-string matches.

    Four regular expressions (double-quoted, single-quoted, triple
    double-quoted, triple single-quoted) are applied one after another to the
    whole input.

    Returns:
        A pair of parallel lists ``(positions, literals)``: the index of each
        newline found inside a match, and the full text of the match that
        contains it. A newline covered by matches of more than one pattern is
        reported once per match.
    """
    string_patterns = [r'"[^"\\]*(\\.[^"\\]*)*"', r"'[^'\\]*(\\.[^'\\]*)*'", r'"""(.*?)"""', r"'''(.*?)'''"]

    newline_positions = []
    owning_literals = []
    for pattern in string_patterns:
        for found in re.finditer(pattern, input_string, re.DOTALL):
            literal = found.group()
            offset = found.start()
            # Walk the matched text itself; absolute index = match start + local index.
            for local_index, character in enumerate(literal):
                if character == '\n':
                    newline_positions.append(offset + local_index)
                    owning_literals.append(literal)

    return newline_positions, owning_literals


def split_input_with_respect_to_enclosed_newlines(input_string):
    """Split ``input_string`` on newlines, except those inside quoted strings.

    Newlines reported by ``find_enclosed_newlines`` are kept inside the line
    they belong to; every other newline ends the current line.

    Returns:
        The list of resulting lines (always at least one element).
    """
    enclosed_newlines, _ = find_enclosed_newlines(input_string)
    # A set makes the per-character membership test O(1); the list may also
    # contain duplicates, which a set collapses without changing the result.
    enclosed_positions = set(enclosed_newlines)
    lines = []
    line_start = 0
    for i, char in enumerate(input_string):
        if char == '\n' and i not in enclosed_positions:
            lines.append(input_string[line_start:i])
            line_start = i+1
    lines.append(input_string[line_start:])  # add the last line
    return lines



def cst_module(code):
    """Parse ``code`` with libcst and return the module, or None if parsing raises."""
    try:
        return cst.parse_module(code)
    except Exception:
        return None


def extract_python_blocks(input_string, start=0, intervals=None):
    """Partition the lines of ``input_string`` into Python and non-Python intervals.

    For the earliest possible first line, the longest span of lines that
    ``is_valid_python`` accepts is recorded as a Python block; the text after
    that block is then processed recursively. Blank lines, lines starting
    with ``#`` and single-word lines are never used as the first or last line
    of a block.

    Args:
        input_string: text to analyse.
        start: line offset added to every recorded index (used by the recursion).
        intervals: accumulator shared across recursive calls; created when None.

    Returns:
        A dict with keys ``"Python"`` and ``"Non-Python"``, each a list of
        inclusive ``(first_line, last_line)`` tuples.
    """
    if intervals is None:
        intervals = {"Python": [], "Non-Python": []}

    lines = input_string.split('\n')
    total = len(lines)

    def can_bound_block(text):
        # Blank, comment-only and single-word lines cannot open or close a block.
        return bool(text.strip()) and not text.lstrip().startswith('#') and not is_single_word_line(text)

    boundary_candidates = [index for index, text in enumerate(lines) if can_bound_block(text)]

    for first in boundary_candidates:
        # Try the longest span first, shrinking towards `first`.
        for last in reversed(boundary_candidates):
            if last < first:
                break
            if not is_valid_python('\n'.join(lines[first:last + 1])):
                continue
            intervals["Python"].append((first + start, last + start))
            if first > 0:
                intervals["Non-Python"].append((start, first + start - 1))
            if last < total - 1:
                tail = '\n'.join(lines[last + 1:])
                return extract_python_blocks(tail, last + start + 1, intervals)
            return intervals

    # No block found in this text: the whole span is non-Python.
    whole_span = (start, start + total - 1)
    if total > 0 and whole_span not in intervals["Non-Python"]:
        intervals["Non-Python"].append(whole_span)
    return intervals

In [6]:



def extract_strings_from_intervals(input_string, intervals):
    """Return the text of all ``intervals['Non-Python']`` line ranges.

    Each inclusive ``(start, end)`` range is joined with newlines, the ranges
    are concatenated in order (one newline after each) and the result is
    stripped of leading and trailing whitespace.
    """
    lines = input_string.split('\n')
    pieces = [
        '\n'.join(lines[first:last + 1]) + '\n'
        for first, last in intervals['Non-Python']
    ]
    return ''.join(pieces).strip()

def extract_python_from_intervals(input_string, intervals):
    """Return the source text of every ``intervals['Python']`` line range.

    Returns:
        None when there is no Python interval at all; otherwise a list with
        one entry per interval, in order: the stripped text of the range, or
        None when that text is empty after stripping.
    """
    python_spans = intervals['Python']
    if len(python_spans) == 0:
        return None

    lines = input_string.split('\n')
    snippets = []
    for first, last in python_spans:
        snippet = ('\n'.join(lines[first:last + 1]) + '\n').strip()
        # An interval holding only whitespace is reported as None, not ''.
        snippets.append(None if snippet == '' else snippet)
    return snippets

In [7]:
def check_markdown(before, after):
    """Classify the fencing of a code block from the lines just before and after it.

    Returns:
        A 5-tuple of booleans
        ``(is_python_markdown, is_plain_markdown, is_other_markdown, is_no_markdown, is_unclosed)``:
        opened by a python fence and followed by a fence; enclosed by two bare
        fences; preceded by a fence but neither of the former; no fence on
        either side; a fence on exactly one side.
    """
    fence = '```'
    head = before.strip()
    tail = after.strip()
    head_is_fence = head.startswith(fence)
    tail_is_fence = tail.startswith(fence)

    is_python_markdown = head.startswith('```python') and tail_is_fence
    # A bare fence before the block can never be a python fence.
    is_plain_markdown = head == fence and tail == fence
    is_other_markdown = head_is_fence and not (is_python_markdown or is_plain_markdown)
    is_no_markdown = not (head_is_fence or tail_is_fence)
    is_unclosed = head_is_fence != tail_is_fence

    return is_python_markdown, is_plain_markdown, is_other_markdown, is_no_markdown, is_unclosed

def detect_markdown_blocks(input_string, intervals):
    """Describe how each Python interval is fenced in ``input_string``.

    For every ``(start, end)`` in ``intervals['Python']`` the line just before
    and the line just after the block are passed to ``check_markdown``; a
    neighbour that falls outside the text is treated as an empty string.

    Returns:
        A dict mapping each interval to a dict with the boolean keys
        ``is_python_markdown``, ``is_plain_markdown``, ``is_other_markdown``,
        ``is_no_markdown`` and ``is_unclosed`` (in that order).
    """
    flag_names = ('is_python_markdown', 'is_plain_markdown', 'is_other_markdown', 'is_no_markdown', 'is_unclosed')
    lines = input_string.split('\n')
    last_index = len(lines) - 1

    markdown_info = {}
    for interval in intervals['Python']:
        first, last = interval
        before = lines[first - 1] if first >= 1 else ''
        # A block ending on the last line has no following line, hence ''.
        after = lines[last + 1] if last < last_index else ''
        markdown_info[interval] = dict(zip(flag_names, check_markdown(before, after)))
    return markdown_info


def uniform_markdown(input_string, markdown_info):
    """Rewrite ``input_string`` so that the described code blocks use python fences.

    Args:
        input_string: the full message text.
        markdown_info: mapping ``(start, end) -> flags`` where ``start``/``end``
            are inclusive line indices of a code block and ``flags`` is a dict
            with the boolean keys ``is_python_markdown``, ``is_plain_markdown``,
            ``is_other_markdown``, ``is_no_markdown`` and ``is_unclosed``.

    Returns:
        The text with
        * a python opening fence and a closing fence added around blocks flagged
          ``is_no_markdown`` or ``is_unclosed``;
        * the neighbouring fence lines of blocks flagged ``is_plain_markdown`` or
          ``is_other_markdown`` replaced by a python opening fence and a bare
          closing fence;
        * blocks flagged only ``is_python_markdown`` left untouched.
    """
    lines = input_string.split('\n')
    for (start, end), info in markdown_info.items():
        needs_new_fences = info['is_no_markdown'] or info['is_unclosed']
        needs_relabel = info['is_plain_markdown'] or info['is_other_markdown']

        if needs_new_fences:
            # Attach the fences to the block's OWN first/last line instead of
            # inserting list items or editing the neighbouring lines. This
            # keeps every line index valid for the remaining intervals (an
            # insert at position 0 used to shift them all, putting the closing
            # fence before the last code line) and prevents a later interval
            # from overwriting or reordering a fence stored on a shared
            # neighbouring line.
            # NOTE(review): for an "unclosed" block both fences are still added
            # even though one neighbour already is a fence line; which block
            # that existing fence belongs to cannot be decided from here.
            lines[start] = '```python' + '\n' + lines[start]
            lines[end] = lines[end] + '\n' + '```'
        elif needs_relabel:
            # Both neighbours are fence lines: normalise them in place.
            if start - 1 >= 0:
                lines[start - 1] = '```python'
            if end + 1 < len(lines):
                lines[end + 1] = '```'
    return '\n'.join(lines)


In [8]:
def check_for_original_code(message,original_code):
    """Return True when some Python block in ``message`` matches ``original_code``.

    Blocks are located with ``extract_python_blocks`` and compared with
    ``compare_modules``; a message without any Python block yields False.
    """
    block_intervals = extract_python_blocks(message)
    snippets = extract_python_from_intervals(message, block_intervals)
    if snippets is None:
        return False
    # Stops at the first matching block, like an explicit loop with early return.
    return any(compare_modules(snippet, original_code) for snippet in snippets)




In [9]:
def check_and_modify(message, original_code,bug,question=False):
    """Make sure ``message`` contains ``original_code`` and uses python fences.

    Args:
        message: the chat message to check; anything that is not a ``str``
            is rejected.
        original_code: the code the message is expected to contain.
        bug: prefix used for the keys of the returned dict.
        question: True for a user question, False for an assistant answer;
            selects the key suffix and the sentence used when code is appended.

    Returns:
        None when ``message`` is not a string. Otherwise a one-element list
        holding a dict with two keys, ``<bug>_user_checked`` /
        ``<bug>_assistant_checked`` (the rewritten message) and the same key
        plus ``_status``, whose value is
        * ``"found"``    - a Python block already matches ``original_code``;
        * ``"nocode"``   - no Python block was found, ``original_code`` was appended;
        * ``"replaced"`` - the longest Python block was replaced by ``original_code``.
    """
    if not isinstance(message, str):
        return None
    added_string = "Here is the corrected code: " if not question else "Here is my code: "
    added_name = "_assistant_checked" if not question else "_user_checked"

    # Extract the blocks once and reuse them for both the match test and the
    # replacement; the extraction parses many candidate spans and is costly.
    python_ids = extract_python_blocks(message)
    python_code = extract_python_from_intervals(message, python_ids)
    snippets = python_code if python_code is not None else []

    if any(compare_modules(code, original_code) for code in snippets):
        out = message
        status = "found"
    elif not snippets:
        newline = chr(10)
        # If no Python code is found in the message, append the original code
        out = message + newline + newline+ added_string + newline  + original_code
        status = "nocode"
    else:
        # If non matching Python code is found in the message, replace the
        # longest block with the original code. The replacement is done on the
        # block's own line interval so that other occurrences of the same text
        # elsewhere in the message are left alone.
        longest = max(range(len(snippets)), key=lambda position: len(snippets[position]))
        first_line, last_line = python_ids["Python"][longest]
        message_lines = message.split('\n')
        message_lines[first_line:last_line + 1] = [original_code]
        out = '\n'.join(message_lines)
        status = "replaced"

    # Re-detect the blocks in the final text and normalise their fences.
    code_blocks = extract_python_blocks(out)
    md_blocks = detect_markdown_blocks(out, code_blocks)

    cleaned_out = uniform_markdown(out, md_blocks)
    dict_out = {bug+added_name: cleaned_out, bug+added_name+"_status": status}
    return [dict_out]


In [10]:
# Load the input dataset (the "dirty" parquet file) into a polars DataFrame.
df = pl.read_parquet("df_complete_leetcode_dirty.parquet")

In [11]:
# Goal of the following cells: check that every bug question contains the
# corresponding bugged code, and that every debugging answer contains the
# corresponding debugged code.
# Pick one row as a worked example for the first transformer.
i = 4
code_example = df["code"][i]
bug_code_example = df["ReturningEarlyTransformer_code"][i]
question_example = df["ReturningEarlyTransformer_user"][i]
answer_example = df["ReturningEarlyTransformer_assistant"][i]

In [12]:
# First transformer name in the list ('ReturningEarlyTransformer').
bug = all_bugs_names[0]

In [13]:
import os
import time

# The corrected frame is written after every bug, so an interrupted run
# resumes from the checkpoint instead of starting over.
corrected_path = "df_complete_leetcode_corrected.parquet"
if os.path.exists(corrected_path):
    updated_df = pl.read_parquet(corrected_path)
else:
    updated_df = df
for bug in all_bugs_names :
    if bug+"_assistant_checked" not in updated_df.columns:
        print("correcting bug "+bug)
        start = time.time()
        # User question: it must show the *bugged* code, stored in "<bug>_code".
        # Checking it against the bug-free "code" column would overwrite the
        # bugged snippet with the solution. Rows without bugged code yield null.
        updated_df = updated_df.with_columns(
            pl.struct([bug+"_code", bug+"_user"])
            .apply(
                lambda row: check_and_modify(row[bug+"_user"], row[bug+"_code"], bug=bug, question=True)
                if row[bug+"_code"] is not None
                else None
            )
            .list.first()
            .alias(bug+"_user_corrected")
        ).unnest(bug+"_user_corrected")
        # Assistant answer: it must show the debugged (original) code.
        updated_df = updated_df.with_columns(
            pl.struct(["code", bug+"_assistant"])
            .apply(lambda row: check_and_modify(row[bug+"_assistant"], row["code"], bug=bug, question=False))
            .list.first()
            .alias(bug+"_assistant_corrected")
        ).unnest(bug+"_assistant_corrected")
        updated_df.write_parquet(corrected_path)
        end = time.time()
        print("time for bug "+bug+": "+str(end-start))
    else:
        print("bug "+bug+" already corrected")


correcting bug ReturningEarlyTransformer


In [None]:
# Distribution of the check outcome for the assistant answers of the first bug.
# `updated_df` is the frame produced by the correction loop; the previous name
# `test` is not defined anywhere in this notebook and fails on a fresh kernel.
updated_df['ReturningEarlyTransformer_assistant_checked_status'].value_counts()

ReturningEarlyTransformer_assistant_checked_status,counts
str,u32
"""found""",1250
,341
"""replaced""",495
"""nocode""",262
