In [None]:
import os

# Shared Hugging Face cache location on scratch storage.
_HF_CACHE_DIR = "/scratch/.cache"

# Runtime configuration, applied in this first cell so that it is in place before
# torch/transformers are imported in the next cell:
#   - expose only GPU 1, numbering devices by PCI bus order;
#   - verbose torch.distributed debug output;
#   - route the transformers / datasets / hub caches to scratch.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
    "TORCH_DISTRIBUTED_DEBUG": "INFO",
    "TRANSFORMERS_CACHE": _HF_CACHE_DIR,
    "HF_DATASETS_CACHE": _HF_CACHE_DIR,
    "HF_HOME": _HF_CACHE_DIR,
})

In [None]:
from transformers import AutoTokenizer
import transformers
import torch
import jsonlines
from tqdm.auto import tqdm
import re

# Hugging Face Hub id of the model being evaluated.
model = "bigcode/octocoder"

# Generation below samples (do_sample=True), so seed every RNG up front to make
# a Restart & Run All reproduce the same outputs.
SEED = 42
transformers.set_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model)
# Hand the already-loaded tokenizer to the pipeline instead of letting it load
# a second copy from `model`.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)


In [None]:


# Refactoring instruction in a "Question: ... Answer:" layout (presumably OctoCoder's
# instruction format — confirm against the model card). `{input_code}` is filled
# with str.format, so the template itself must not contain other braces.
prompt_template = '''Question: You are an expert programmer. Your task is to refactor the given code to improve readability, maintainability, and performance without changing functionality. 

The code to refactor is provided between `<code>` and `</code>` tags.

Refactor the following code:
<code>{input_code}</code>

Please return the refactored code without any explanation.
Answer:'''


def run_llm(prompt, max_new_tokens=512, temperature=0.1, top_k=10, top_p=0.95):
    """Generate a single sampled completion for `prompt` with the global `pipeline`.

    Args:
        prompt: Complete prompt text (e.g. `prompt_template` with the code filled in).
        max_new_tokens: Maximum number of tokens to generate; longer answers are cut off.
        temperature: Sampling temperature (the low default keeps output near-greedy).
        top_k: Sample only from the `top_k` most likely tokens.
        top_p: Nucleus-sampling probability mass.
        The defaults are the values previously hard-coded in this notebook.

    Returns:
        The generated continuation only — `return_full_text=False` drops the prompt.
    """
    response = pipeline(
        prompt,
        do_sample=True,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=1,
        # Stop generation at the tokenizer's end-of-sequence token.
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
    )
    # One prompt with num_return_sequences=1 -> a single result dict.
    return response[0]['generated_text']



def get_clean_code(response):
    """Extract the code inside the first Markdown code fence of an LLM response.

    The opening fence may carry any info string (empty, ``python``, ``c++``,
    ``objective-c`` ...). The closing fence must start its own line (leading
    spaces/tabs allowed), so backticks inside a line of code do not end the block.
    If the closing fence is missing — e.g. generation stopped at
    ``max_new_tokens`` — everything after the opening fence is taken.

    Args:
        response: Raw generated text.

    Returns:
        The stripped fenced code (possibly ``''`` for an empty block), or
        ``response`` unchanged when it contains no opening fence. Never None.
    """
    # [^\n]* accepts any info string; a \w+ group would reject tags such as "c++"
    # and let the fenced text (backticks included) fall through unparsed.
    pattern = r'```[^\n]*\n(.*?)(?:^[ \t]*```|\Z)'
    code_match = re.search(pattern, response, re.DOTALL | re.MULTILINE)
    return code_match.group(1).strip() if code_match else response

## Running on HumanEval

In [None]:


import json
import os

data_path = '../dataset/HumanEvalPlusOrig.jsonl'
output_path = '../results/octocoder/HumanEvalRefactored.jsonl'

# Load every HumanEval record (one JSON object per line).
with open(data_path, 'r', encoding='utf-8') as fp:
    humaneval = list(jsonlines.Reader(fp))
print(len(humaneval))

# NOTE(review): only sample['prompt'] is sent for refactoring; confirm this field holds
# the complete code in HumanEvalPlusOrig (in upstream HumanEval it is only the
# signature + docstring, with the body in 'canonical_solution').
for sample in tqdm(humaneval, desc='HumanEval'):
    prompt = prompt_template.format(input_code=sample['prompt'])
    response = run_llm(prompt=prompt)
    refactored_code = get_clean_code(response)
    # get_clean_code falls back to the raw response instead of returning None, so a
    # blank result is the actual "nothing usable" case; None is still guarded.
    if refactored_code is None or not refactored_code.strip():
        print("Code Parsing ERROR: ", response)
        refactored_code = ''
    sample['refactored_code'] = refactored_code

# The results directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as fp:
    for sample in humaneval:
        fp.write(json.dumps(sample) + '\n')
print('Output Saved!')

In [None]:
# Report how many HumanEval samples ended up with an empty refactored result;
# every non-empty result is echoed first so it can be inspected by eye.
count = 0
for sample in humaneval:
    refactored = sample['refactored_code']
    if len(refactored):
        print(refactored, '\n')
    else:
        count += 1
print(count)

## Running on GuruSamples

In [None]:
import glob
import os

data_folder = '../dataset/GuruSamples'
output_folder = '../results/octocoder/GuruSamples'

# Every sub-folder of data_folder holds `<name>_before.<ext>` inputs; each one is
# refactored and written to output_folder/<sub-folder>/<name>_refactored.<ext>.
# Sorted iteration keeps the processing order (and thus sampling) reproducible.
for folder_name in sorted(os.listdir(data_folder)):
    folder_path = os.path.join(data_folder, folder_name)
    # Stray files are skipped so the glob below is only ever built from a real
    # sub-folder (never an undefined or stale pattern).
    if not os.path.isdir(folder_path):
        continue
    print(folder_name)

    output_folder_path = os.path.join(output_folder, folder_name)
    # makedirs also creates output_folder itself, which a plain os.mkdir cannot.
    os.makedirs(output_folder_path, exist_ok=True)

    for file_path in sorted(glob.glob(os.path.join(folder_path, '*_before.*'))):
        filename = os.path.basename(file_path)
        with open(file_path, 'r', encoding='utf-8') as file:
            code_content = file.read()

        prompt = prompt_template.format(input_code=code_content)
        response = run_llm(prompt=prompt)
        refactored_code = get_clean_code(response)
        # get_clean_code falls back to the raw response rather than returning None,
        # so treat a blank result as the parse failure (None is still guarded).
        if refactored_code is None or not refactored_code.strip():
            refactored_code = ''
            print('Parsing Error!, ', response)

        # Rewrite only the last '_before' (the suffix marker matched by the glob),
        # leaving any other occurrence in the file name intact.
        stem, _, suffix = filename.rpartition('_before')
        output_filename = stem + '_refactored' + suffix
        output_file_path = os.path.join(output_folder_path, output_filename)
        with open(output_file_path, 'w', encoding='utf-8') as output_file:
            output_file.write(refactored_code)

164


164
