#### BFS + FLATTEN + LCS_LENGTH Config ---> 4090

In [1]:
import json
from typing import Optional
from dataclasses import dataclass, field
from pathlib import Path

import torch
import transformers
from peft import PeftModel
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    GenerationConfig, 
    HfArgumentParser, 
    BitsAndBytesConfig,
)
from tqdm import tqdm

# If you need to use a specific GPU, you can set it here
# if torch.cuda.is_available():
#     # Set GPU:1 as the device
#     torch.cuda.set_device(1)
#     print(f"Using GPU: {torch.cuda.current_device()}")
# else:
#     print("CUDA is not available.")

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-hf",
    torch_dtype=torch.float32,
    # load_in_8bit=True,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0
    ),
)

LORA = '16'
EPOCH = 1
model_folder_path = 'F:/My_APR/Experiment_CodeLlama/repairllama/model_CodeLlama/'
lora_folder_path = 'model_Lora{}/'.format(LORA)
epoch_nums = 'checkpoint-epoch-{}.0/'.format(EPOCH)

model = PeftModel.from_pretrained(
    model,
    model_folder_path + lora_folder_path + epoch_nums,
    torch_dtype=torch.float32,
)
model.config.pad_token = tokenizer.pad_token = tokenizer.unk_token
model.to(device)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32016, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_pro

In [2]:
def flatten_BFS_Beam_Search(buggy_Code, BEAM_NUM):

    inputs = tokenizer(buggy_Code, return_tensors="pt")
    inputs_len = inputs["input_ids"].shape[1]
    inputs_ids = inputs["input_ids"].to(device)

    generation_config = GenerationConfig(
        num_beams=BEAM_NUM,
        max_length = 512,
        early_stopping=True,
    )

    outputs = model.generate(
        input_ids=inputs_ids,
        max_new_tokens=256,
        num_return_sequences=BEAM_NUM, 
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        generation_config=generation_config,
    )

    output_ids = outputs[:, inputs_len:]
    output_patch = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return output_patch

#### BFS + Flatten + LCS_LENGTH Beam Search

In [3]:
import json
import os
import subprocess
import tempfile
import shutil
from tqdm import tqdm

def readJsonLine(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line))
    return data

def createFolder(folder_path):
    try:
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.mkdir(folder_path)
    except:
        print('remove {} error'.format(folder_path))

def checkJavaFormat(java_code, jar_path, folder_path, patchFileName, buggy_ID):
    script_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
    os.chdir(script_dir)

    if not os.path.isfile(jar_path):
        return (f"Google Java Format JAR file not found: {jar_path}")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".java") as temp_file:
        temp_filename = temp_file.name

        full_java_code = f"""
        public class {patchFileName} {{
            {java_code}
        }}
        """
        temp_file.write(full_java_code.encode())

    if not os.path.isfile(temp_filename):
        raise FileNotFoundError(f"Temporary file not found: {temp_filename}")

    result = subprocess.run(
        ["java", "-jar", jar_path, "--replace", temp_filename],
        capture_output=True,
        text=True
    )

    if result.returncode != 0:
        return (f"Google Java Format Error: {result.stderr}")

    with open(temp_filename, "r") as f:
        formatted_code = f.read()

    os.remove(temp_filename)

    formatted_code = importContent(buggy_ID) + '\n' + formatted_code

    print("PATH:", folder_path + '/' + patchFileName + '.java')


    with open(folder_path + '/' + patchFileName + '.java', 'w', encoding='utf-8') as file:
        file.write(formatted_code)

    # print(formatted_code)

    if 'Node' in formatted_code and 'WeightedEdge' in formatted_code:
        return 'Node WeightedEdge'
    if 'Node' in formatted_code:
        return 'Node'
    if 'WeightedEdge' in formatted_code:
        return 'WeightedEdge'
    
    return 'Java Format Check Successfully'

def checkJavaCompile(patchFilePath, javaFormatResult):
    try:
        output_dir = './class_file/'
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        java_files = [patchFilePath]

        if 'Node' in javaFormatResult:
            java_files.append('F:/My_APR/QuixBugTest/dataStructure/Node.java')
        
        if 'WeightedEdge' in javaFormatResult:
            java_files.append('F:/My_APR/QuixBugTest/dataStructure/WeightedEdge.java')

        result = subprocess.run(['javac', '-d', output_dir] + java_files, capture_output=True, text=True)
        
        if result.returncode == 0:
            return True
        else:
            return False
    except FileNotFoundError:
        print("Error: javac is not installed or not found in PATH.")
        return False

def importContent(fileName):
    import_folder_path = "F:/My_APR/QuixBugs_Program/eachFileImport/"
    import_content = ""
    file_import_path = import_folder_path + fileName + '_ImportInfo.java'

    with open(file_import_path, 'r', encoding='utf-8') as importFile:
        for line in importFile:
            import_content += line

    return import_content

if __name__ == '__main__':

    LORA = '16'
    PATCH = '01'
    EPOCH = '1'
    
    file_path = '../QuixBugs_Lora{}/QuixBugs_Lora{}_Patch{}/QuixBugs_Lora{}_E{}_Patch{}.jsonl'.format(LORA, LORA, PATCH, LORA, EPOCH, PATCH)
    google_java_format_path = "F:/My_APR/util/javaFormat/google-java-format-1.15.0-all-deps.jar"

    data = readJsonLine(file_path)

    pendingList = ['BREADTH_FIRST_SEARCH', 'FLATTEN']# 'LCS_LENGTH']
    # pendingList = ['LCS_LENGTH']

    
    with tqdm(total=len(data), desc="Processing Patches") as pbar:

        for item in data:
            index = 0

            buggy_ID = item['bug_id']
            buggy_Code = item['buggy_code']
            folder_path = 'F:/My_APR/Experiment_CodeLlama/repairllama/Verification_QuixBugs_Output/Analysis/Module_{}'.format(buggy_ID)

            if buggy_ID not in pendingList:
                pbar.update(1)
                continue
            
            print("folder_path:", folder_path)
            print("LORA:{} PATCH:{} EPOCH:{}".format(LORA, PATCH, EPOCH))
            createFolder(folder_path)        
                
            print(buggy_ID)

            BEAM_NUM = len(item['output'])

            for i in range(BEAM_NUM):
                patch = item['output'][str(i)]['output_patch']
                patch = patch.replace('</s>', '')
                patch = patch.strip()
                patchCode = buggy_Code.replace('<FILL_ME>', patch, 1)
                patchCode = patchCode.replace('// buggy code', '', 1)

                # print("Patch Code:", patchCode)

                results = flatten_BFS_Beam_Search(patchCode, BEAM_NUM)

                print("i:",i,"results:",results)

                for result in results:
                    patchFileName = buggy_ID + '_TEST_' + str(index)
                    patchFilePath = folder_path + '/' + patchFileName + '.java'


                    patchCodeTwice = patchCode.replace('<FILL_ME>', result.replace('</s>','').strip(), 1)
                    javaFormatResult = checkJavaFormat(patchCodeTwice, google_java_format_path, folder_path, patchFileName, buggy_ID)

                    if javaFormatResult.startswith('Google Java Format Error'):
                        # print(javaFormatResult)
                        continue
                    
                    checkCompileResult = checkJavaCompile(folder_path + '/' + patchFileName + '.java', javaFormatResult)
                    # print(checkCompileResult)
                    # print(index,folder_path + '/' + patchFileName + '.java')
                    
                    if checkCompileResult is False:
                        os.remove(patchFilePath)

                    index = index + 1
                
            pbar.update(1)

    print("============================ Step1 LORA:{} PATCH:{} EPOCH:{} Done =================================".format(LORA, PATCH, EPOCH))


FileNotFoundError: [Errno 2] No such file or directory: '../QuixBugs_Lora16/QuixBugs_Lora16_Patch01/QuixBugs_Lora16_E1_Patch01.jsonl'