In [33]:
import os
import sys
import json
import pathlib
from tqdm import tqdm
from dotenv import load_dotenv

# Load settings from a .env file into the process environment.
load_dotenv()

# NOTE(review): left disabled — presumably because `__file__` is not defined inside a
# notebook kernel; a later cell appends the project root to sys.path explicitly instead.
# module_path = str(pathlib.Path(os.path.abspath(__file__)).parent.parent)
# sys.path.append(module_path)

True

In [7]:
cd /home/arnaik/OracleProject

/home/arnaik/OracleProject


In [5]:
# Make the project's `src` package importable from this notebook. The membership check
# keeps the cell idempotent: re-running it no longer stacks duplicate sys.path entries.
if "/home/arnaik/OracleProject" not in sys.path:
    sys.path.append("/home/arnaik/OracleProject")

In [39]:
from src.datautils import load_stack_dump, load_ruff_results, load_ruff_idiom_specs, idiom_spec_extractor_for_ruff

In [38]:
# Load the source-file dump stored under data/STACK-V2 (path is relative to the working directory).
# reprocess_data later looks entries up by blob id and reads each entry's 'content' field.
stack_data = load_stack_dump("data/STACK-V2", as_dict=True)

112it [00:14,  7.65it/s]


In [36]:
# Idiom specifications; reprocess_data looks them up by idiom code when building the prompt.
code_idiom_specs = load_ruff_idiom_specs("./data/ruff_pages")
# Ruff output per file; reprocess_data looks entries up by blob id and reads their 'violations'.
ruff_results = load_ruff_results("data/ruff_results", as_dict=True)

94it [00:12,  7.79it/s]


KeyboardInterrupt: 

In [42]:
# Prompt template for the meta-linting task. It is filled in with str.format and has two
# placeholders: {LIST_OF_IDIOM_SPECS} (the idiom specifications, joined by blank lines) and
# {CODE_FILE} (the source file text, optionally prefixed with line numbers).
META_LINTING_PROMPT_V2 = """Look at the following list of code idiom specifications with definitions and examples:
{LIST_OF_IDIOM_SPECS}

Given these idioms, your task is to look at a code file and detect violations of the above idioms, and flag them like a linter. You should also suggest a fix if possible. Report the results per idiom specification mentioned above and just say 'NO VIOLATIONS FOUND' if no violations are found for a given idiom. Do not detect any idioms not specified above.

Code file:
{CODE_FILE}

Violations per idiom:
"""

In [66]:
def _extract_span(file_lines: list[str], start: dict, end: dict) -> str:
    """Return the text of `file_lines` between two positions, one piece per row, joined by newlines.

    `start` and `end` are dicts with 'row' and 'column' keys. Both are treated as
    1-based (one is subtracted before indexing) and the end column is exclusive.
    """
    first_row, last_row = start["row"], end["row"]
    pieces = []
    for lineno in range(first_row, last_row + 1):
        line = file_lines[lineno - 1]
        # Only the first row is clipped on the left and only the last row on the
        # right; rows in between are kept whole.
        left = start["column"] - 1 if lineno == first_row else 0
        right = end["column"] - 1 if lineno == last_row else len(line)
        pieces.append(line[left:right])
    return "\n".join(pieces)


def generate_response_from_violations(violations, stack_file_lines: list[str], meta_task_idiom_codes, include_message: bool=False, add_line_numbers: bool=False):
    """Render violations as the per-idiom response text used as the target output.

    Args:
        violations: iterable of violation dicts with 'code', 'location',
            'end_location' and (optionally) 'fix' / 'message' entries.
        stack_file_lines: the source file split into lines.
        meta_task_idiom_codes: idiom codes to report on; violations with any
            other code are ignored. One section is emitted per code, in this order.
        include_message: when True, add the violation's 'message' to each entry.
        add_line_numbers: when True, prefix each reported line with its
            right-justified (width 3) line number.

    Returns:
        A string with one "**Idiom <code> Violations:**" section per code. A
        section holds either 'NO VIOLATIONS FOUND' or one JSON object per
        violation (keys: line, span, [message], fix), ordered by start position.
        'fix' is a list of {"before", "after"} edits when the violation carries a
        fix whose applicability is "safe", otherwise null.
    """
    # group violations by each idiom in the meta-task, dropping all other codes.
    grouped_violations = {code: [] for code in meta_task_idiom_codes}
    for violation in violations:
        if violation['code'] in grouped_violations:
            grouped_violations[violation['code']].append(violation)

    response = ""
    for code, code_violations in grouped_violations.items():
        if not code_violations:
            response += f"**Idiom {code} Violations:**\n\nNO VIOLATIONS FOUND\n\n"
            continue

        response += f"**Idiom {code} Violations:**\n"
        # sort violations by start position. The sorted list itself must be the one
        # iterated; sorting into a side structure leaves the emitted order unchanged.
        ordered = sorted(code_violations, key=lambda v: (v['location']['row'], v['location']['column']))
        for violation in ordered:
            start, end = violation['location'], violation['end_location']
            rows = range(start['row'], end['row'] + 1)
            if add_line_numbers:
                det_line = [f"{str(lineno).rjust(3)} {stack_file_lines[lineno-1]}" for lineno in rows]
            else:
                det_line = [f"{stack_file_lines[lineno-1]}" for lineno in rows]

            # Key order (line, span, [message], fix) is part of the emitted JSON text.
            det_dict = {"line": "\n".join(det_line), "span": _extract_span(stack_file_lines, start, end)}
            if include_message:
                det_dict["message"] = violation["message"]
            det_dict["fix"] = None

            # Only fixes marked "safe" are reported; a missing 'fix' entry is treated as no fix.
            fix = violation.get("fix")
            if fix is not None and fix.get("applicability") == "safe":
                det_dict["fix"] = [
                    {
                        "before": _extract_span(stack_file_lines, edit["location"], edit["end_location"]),
                        "after": edit["content"],
                    }
                    for edit in fix["edits"]
                ]

            response += f"\n{json.dumps(det_dict)}"
        response += "\n\n"

    return response

def reprocess_data(train_data, code_idiom_specs: dict, ruff_results: dict, stack_data: dict, add_line_numbers: bool=True):
    """Rebuild the prompt and response messages of every record, in place.

    Each record's "id" is split on "_": the text before the first underscore is a
    "-"-separated list of idiom codes and the text after the last underscore is the
    blob id used to look up the file in `stack_data` and `ruff_results`.

    Args:
        train_data: list of records with "id" and "messages" entries; messages[0]
            receives the prompt and messages[1] the response.
        code_idiom_specs: idiom specifications indexed by idiom code.
        ruff_results: per-blob results; the 'violations' entry is used.
        stack_data: per-blob file records; the 'content' entry is used.
        add_line_numbers: when True, number the lines both in the prompt's code
            file and in the reported violation lines.

    Returns:
        The same `train_data` object, with its records modified.
    """
    for record in tqdm(train_data):
        id_parts = record["id"].split("_")
        blob_id = id_parts[-1].strip()
        idiom_codes = id_parts[0].strip().split("-")

        source_text = stack_data[blob_id]['content']
        source_lines = source_text.split("\n")

        response = generate_response_from_violations(
            violations=ruff_results[blob_id]['violations'],
            stack_file_lines=source_lines,
            meta_task_idiom_codes=idiom_codes,
            add_line_numbers=add_line_numbers,
        )

        # The prompt shows the file either verbatim or with 1-based, width-3 line numbers.
        if add_line_numbers:
            code_file = "\n".join(
                f"{str(number).rjust(3)} {text}"
                for number, text in enumerate(source_lines, start=1)
            )
        else:
            code_file = source_text

        idiom_specs = "\n\n".join(
            idiom_spec_extractor_for_ruff(code_idiom_specs[idiom_code])
            for idiom_code in idiom_codes
        )

        messages = record["messages"]
        messages[0]['content'] = META_LINTING_PROMPT_V2.format(LIST_OF_IDIOM_SPECS=idiom_specs, CODE_FILE=code_file)
        messages[1]['content'] = response

    return train_data

In [68]:
split = "train"
# Read through a context manager so the input file handle is closed as soon as parsing ends.
with open(f"./data/ruff_meta_linting/{split}_v4.json") as f:
    data = json.load(f)

# reprocess_data rewrites the records in place and returns the same list, so
# `proc_data` and `data` refer to the same object afterwards.
proc_data = reprocess_data(train_data=data, code_idiom_specs=code_idiom_specs, ruff_results=ruff_results, stack_data=stack_data, add_line_numbers=True)
with open(f"./data/ruff_meta_linting/{split}_v4_new_format_with_lineno.json", "w") as f:
    json.dump(proc_data, f, indent=4)

100%|██████████| 86782/86782 [00:04<00:00, 17852.61it/s]
