In [1]:
!pip install --upgrade openai



In [2]:
import os
import openai
import glob
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

In [5]:
def _build_debug_prompt(code_content, few_shot, CoT):
    """Assemble the user prompt for the requested prompting configuration.

    Args:
        code_content: Java source text to embed in a fenced ``java`` block.
        few_shot: '0' for zero-shot, 'f' for few-shot.
        CoT: 'n' for no Chain-of-Thought, 'c' for Chain-of-Thought.

    Returns:
        The full prompt string: instructions followed by the fenced code.

    Raises:
        ValueError: if ``few_shot`` or ``CoT`` is not one of the accepted flags.
        NotImplementedError: for the few-shot modes, which have no prompt yet.
    """
    if few_shot not in ('0', 'f') or CoT not in ('n', 'c'):
        raise ValueError("few_shot should be '0' or 'f', CoT should be 'n' or 'c'.")
    if few_shot == 'f':
        # Fail loudly instead of reaching the API call with no prompt defined.
        raise NotImplementedError("Few-shot prompts (few_shot='f') are not implemented yet.")

    # Adjacent string literals are used instead of backslash continuations so
    # that source indentation does not leak into the prompt text.
    if CoT == 'n':
        instructions = (
            "The provided Java code is buggy. Fix the bug, using minimal changes. "
            "Do not reorganize. Do not optimize. Do not provide explanation or justification. "
            "Format your code in markdown."
        )
    else:
        instructions = (
            "The provided Java code is buggy. Review the Java code and identify the bug. "
            "Explain the reasoning process, thinking step-by-step, for identifying and fixing the bug. "
            "Apply the fix using minimal changes. Do not reorganize or optimize the code. "
            "Format your code in markdown."
        )
    return instructions + "\n```java\n" + code_content + "\n```"


def debug_code_with_gpt(code_content, API_KEY, few_shot, CoT,
                        model="gpt-3.5-turbo", max_tokens=512):
    """Ask an OpenAI chat model to repair a buggy Java snippet.

    Args:
        code_content: Java source text to be fixed.
        API_KEY: OpenAI API key used to construct the client.
        few_shot: '0' for zero-shot, 'f' for few-shot (not implemented).
        CoT: 'n' for no Chain-of-Thought, 'c' for Chain-of-Thought.
        model: Chat model name passed to the completions endpoint.
        max_tokens: Upper bound on generated tokens; longer replies are cut off.

    Returns:
        The content of the first choice's message exactly as returned by the
        API (no post-processing is applied), or ``None`` if the API call
        raised an exception.

    Raises:
        ValueError: if ``few_shot`` or ``CoT`` is not an accepted flag.
        NotImplementedError: if a few-shot mode is requested.
    """
    # Built outside the try block so configuration mistakes propagate to the
    # caller instead of being reported as API failures.
    prompt = _build_debug_prompt(code_content, few_shot, CoT)
    try:
        client = openai.OpenAI(api_key=API_KEY)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": prompt}
            ],
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        # Best-effort: a failed request is reported and signalled with None so
        # that the caller can skip this file.
        print("Error during API call:", e)
        return None

def process_file(file_path, target_directory, api_key, few_shot, CoT):
    """Send one Java file to the model and save the reply under the same name.

    Args:
        file_path: Path of the buggy Java source file to read.
        target_directory: Existing directory that receives the output file.
        api_key: OpenAI API key forwarded to ``debug_code_with_gpt``.
        few_shot: Few-shot flag forwarded to ``debug_code_with_gpt``.
        CoT: Chain-of-Thought flag forwarded to ``debug_code_with_gpt``.

    Returns:
        None. When ``debug_code_with_gpt`` returns ``None`` or an empty
        string, no output file is written.
    """
    # UTF-8 is pinned explicitly so results do not depend on the platform's
    # locale encoding.
    with open(file_path, 'r', encoding='utf-8') as source_file:
        file_content = source_file.read()

    debugged_content = debug_code_with_gpt(file_content, api_key, few_shot, CoT)

    if not debugged_content:
        return

    target_file_path = os.path.join(target_directory, os.path.basename(file_path))
    # Model replies may contain non-ASCII characters, so write as UTF-8 too.
    with open(target_file_path, 'w', encoding='utf-8') as target_file:
        target_file.write(debugged_content)

def process_files_in_parallel(source_directory, target_directory, api_key, few_shot, CoT, max_workers=45):
    """Run ``process_file`` over every ``*.java`` file in a directory using threads.

    Args:
        source_directory: Directory searched (non-recursively) for ``*.java`` files.
        target_directory: Directory for the outputs; created if missing.
        api_key: OpenAI API key forwarded to ``process_file``.
        few_shot: Few-shot flag forwarded to ``process_file``.
        CoT: Chain-of-Thought flag forwarded to ``process_file``.
        max_workers: Size of the thread pool.

    Raises:
        Whatever ``process_file`` raises for the first failing file observed.
        On such a failure, or on an interrupt, tasks that have not started
        yet are cancelled; tasks already running are allowed to finish.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(target_directory, exist_ok=True)

    file_paths = glob.glob(os.path.join(source_directory, '*.java'))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_file, file_path, target_directory, api_key, few_shot, CoT)
            for file_path in file_paths
        ]
        try:
            for future in tqdm(as_completed(futures), total=len(file_paths)):
                future.result()  # Re-raises any exception raised inside process_file
        except BaseException:
            # Leaving the `with` block waits for the pool to drain. Without
            # cancelling, every queued file would still be processed before
            # the error (or KeyboardInterrupt) reached the caller.
            for pending in futures:
                pending.cancel()
            raise

In [6]:
if __name__ == "__main__":

    # Prompting mode: '0' selects zero-shot, 'f' selects few-shot
    few_shot = "0"
    # Reasoning mode: 'n' disables Chain-of-Thought, 'c' enables it
    CoT = "c"

    # The OpenAI key is loaded from a local text file rather than written in the notebook
    with open("API_KEY.txt") as file:
        api_key = file.read().strip()
    # Input folder holding the buggy Java files
    source_dir = "data/raw/bug_codes"
    # Output folder, named after the selected prompting configuration
    target_dir = f"data/raw/debugged_codes_{few_shot}_{CoT}"

    process_files_in_parallel(source_dir, target_dir, api_key, few_shot, CoT)

  0%|                                       | 1/1784 [00:30<15:00:08, 30.29s/it]


KeyboardInterrupt: 