In [169]:
import os
import random
import re
from glob import glob
from pathlib import Path

import dotenv
from langchain.chat_models import ChatOpenAI
from tqdm import tqdm

from utils import get_all_function, extract_def_block

# Load variables through python-dotenv before the API key is read below.
dotenv.load_dotenv()

# Chat model client shared by the cells that query the LLM.
open_ai_key = os.environ.get("OPENAI_API_KEY")
llm = ChatOpenAI(model_name="gpt-4-1106-preview", openai_api_key=open_ai_key)

# Output root for generated fixes; created eagerly so later cells can write into it.
generated_code_dir = Path('generated_code/gpt4')
generated_code_dir.mkdir(exist_ok=True, parents=True)


In [132]:
def filter_comment(code):
    """Strip triple-double-quoted blocks (docstrings) from source text.

    Args:
        code: Source code as a string.

    Returns:
        ``code`` with every triple-double-quoted span removed and
        leading/trailing whitespace stripped.
    """
    # Non-greedy so each block is matched on its own: a greedy match would run
    # from the first opening quotes to the last closing quotes and delete any
    # code sitting between two docstrings. DOTALL lets a block span lines.
    without_docstrings = re.sub(r'""".*?"""', '', code, flags=re.DOTALL)
    return without_docstrings.strip()


def generate_fixed_code(code):
    """Ask the module-level ``llm`` to repair ``code`` and return its reply.

    Args:
        code: Source text of the program to be fixed; it is embedded verbatim
            in the prompt.

    Returns:
        The ``content`` attribute of the model response.
    """
    prompt_lines = [
        'Please fix the code below:',
        f'{code}',
        'The fixed code:',
    ]
    request = '\n'.join(prompt_lines)

    # timeout=20 is forwarded to the model call (presumably seconds -- confirm).
    return llm.invoke(request, timeout=20).content



In [85]:
# NOTE(review): not used by any visible cell; kept in case other cells rely on it.
all_code = []

# Collect the QuixBugs programs, skipping the test code (files ending in
# "test.py") and any file whose name contains "node". The checks look at the
# file name / suffix only, so they cannot be tripped by the directory part of
# the path or by a stray "test?py" substring.
all_code_file = sorted(
    path
    for path in glob('datasets/QuixBugs/python_programs/*.py')
    if not path.endswith('test.py') and 'node' not in Path(path).name
)

# Fail with a readable message instead of randint's "empty range" error.
assert all_code_file, 'no programs found under datasets/QuixBugs/python_programs'

# Preview one randomly chosen program with its docstring removed.
index = random.randint(0, len(all_code_file) - 1)
code_file_name = all_code_file[index]

with open(code_file_name, 'r') as f:
    code = f.read()

code = filter_comment(code)
print(code)


def knapsack(capacity, items):
    from collections import defaultdict
    memo = defaultdict(int)

    for i in range(1, len(items) + 1):
        weight, value = items[i - 1]

        for j in range(1, capacity + 1):
            memo[i, j] = memo[i - 1, j]

            if weight < j:
                memo[i, j] = max(
                    memo[i, j],
                    value + memo[i - 1, j - weight]
                )

    return memo[len(items), capacity]


In [172]:
# Number of candidate fixes requested per program. Shared by the resume check
# and the generation loop so the two cannot drift apart.
NUM_SAMPLES = 4

# Entries of generated_code_dir whose names contain no dot, i.e. the
# per-program output directories created by the loop below.
dir_done = sorted(x for x in os.listdir(generated_code_dir) if '.' not in x)

function_list = get_all_function()

# Resume point: restart at the alphabetically last started program, or move
# past it when all of its samples already exist.
# NOTE(review): the index is looked up in function_list but applied to
# all_code_file below; this assumes both lists share the same order -- confirm.
if len(dir_done) == 0:
    last_index = 0
else:
    last_index = function_list.index(dir_done[-1])

    if len(list((generated_code_dir / dir_done[-1]).glob('*.py'))) >= NUM_SAMPLES:
        last_index += 1

for i in tqdm(range(last_index, len(all_code_file))):
    code_file_name = all_code_file[i]

    with open(code_file_name, 'r') as f:
        code = f.read()
    code = filter_comment(code)
    # Path.stem drops the directory and the ".py" suffix regardless of the
    # platform's path separator.
    function_name = Path(code_file_name).stem

    output_dir = generated_code_dir / function_name
    output_dir.mkdir(parents=True, exist_ok=True)

    prompt = (
        f'{code}\n'
        'Please fix the code above.\n'
        '\n'
        'The fixed code:'
    )
    with open(output_dir / (function_name + '_input.txt'), 'w') as f:
        f.write(prompt)

    for j in range(NUM_SAMPLES):
        generated_file_path = output_dir / (function_name + f'_{j + 1}.py')
        # Samples written by an earlier run are kept, which makes the cell resumable.
        if generated_file_path.exists():
            continue

        llm_response = llm.invoke(prompt, timeout=60)
        fixed_code = llm_response.content

        # Keep the raw model reply next to the extracted code for inspection.
        with open(generated_file_path.with_suffix('.txt'), 'w') as f:
            f.write(fixed_code)

        # Write the extracted code (not the raw reply) to the .py file; fall
        # back to the raw reply when extraction yields nothing.
        # NOTE(review): assumes extract_def_block returns a string -- confirm.
        cleaned_code = extract_def_block(fixed_code)
        with open(generated_file_path, 'w') as f:
            f.write(cleaned_code or fixed_code)



100%|██████████| 39/39 [1:17:15<00:00, 118.86s/it]
