In [1]:
import os
import json
import re
import glob

In [2]:
def extract_imports(code):
    """Return every Java-style ``import ...;`` statement found in ``code``.

    Args:
        code: Source text to scan. ``None`` or an empty string yields ``[]``.

    Returns:
        list[str]: The matched statements in order of appearance. Each one
        runs from the keyword ``import`` up to and including the first ``;``
        that follows it on the same line.

    Note:
        The scan is purely textual. Statements that appear inside comments or
        string literals are still matched, and a statement that is split
        across several lines is not matched because ``.`` does not match a
        newline.
    """
    if not code:
        # A missing/null "java_code" value means there is nothing to extract.
        return []
    # \b keeps identifiers that merely end in "import" (e.g. "reimport = 1;")
    # from being reported as import statements.
    return re.findall(r"\bimport .*?;", code)

In [3]:
# Function for numerical sort
def numericalSort(value):
    """Build a sort key that orders embedded digit runs numerically.

    The string is split on runs of digits; the digit runs are converted to
    ``int`` while the text between them is kept as ``str``.

    Args:
        value: The string to turn into a key (e.g. a file path).

    Returns:
        list: Alternating text and integer pieces, always starting with a
        text piece (which may be empty).
    """
    pieces = re.split(r'(\d+)', value)
    # With a capturing group, the digit runs land on the odd indices.
    return [int(piece) if index % 2 else piece for index, piece in enumerate(pieces)]

In [4]:
def get_correct_imports(outputFiles):
    """Load the reference import statements from a sequence of JSON files.

    Each file must contain a JSON object with a ``total_imports`` list of
    names. Every name is wrapped as ``"import <name>;"``. All occurrences of
    ``"import gen.R;"`` are dropped from the result.

    Args:
        outputFiles: Iterable of paths to the JSON files, in the order the
            results should be returned.

    Returns:
        list[list[str]]: One list of import statements per input file, in the
        same order as ``outputFiles``.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a file is not valid JSON.
        KeyError: If a file has no ``total_imports`` key.
    """
    excluded = "import gen.R;"
    correct_outputs = []
    for output in outputFiles:
        # Context manager so the handle is closed even if parsing fails.
        with open(output, encoding="utf-8") as handle:
            names = json.load(handle)['total_imports']
        statements = ["import " + name + ";" for name in names]
        # Filter instead of list.remove(): remove() would only drop the first
        # occurrence and leave any duplicate behind.
        correct_outputs.append([line for line in statements if line != excluded])

    return correct_outputs

In [5]:
# Build the ground-truth import lists (true_list) from the benchmark logs.
source_path = "/home/azmain/snr_all_json/"
os.chdir(source_path)

all_files = glob.glob(os.path.join(source_path, "*.benchmark_log.json"))

# Leave out any file whose name contains "HibernateUtil".
filtered_files = [f for f in all_files if "HibernateUtil" not in os.path.basename(f)]

# Natural ordering, so that e.g. "2" sorts before "10".
sorted_files = sorted(filtered_files, key=numericalSort)

# Plain copy of the sorted paths; no element-by-element loop is needed.
outputFiles = list(sorted_files)

# One list of "import ...;" statements per file, in sorted-file order.
true_list = get_correct_imports(outputFiles)

In [6]:
def compare_import_statements(predicted_list, true_list):
    """Tally how each predicted import list relates to its reference list.

    Lists are paired positionally and compared as sets, so order and
    duplicates inside a single list are ignored. Each pair lands in exactly
    one bucket, checked in this order:

    * ``"Same"``      - the two sets are equal (including both empty).
    * ``"None"``      - the predicted set is empty.
    * ``"Missing"``   - the predicted set is a subset of the reference.
    * ``"Extra"``     - the predicted set is a superset of the reference.
    * ``"Different"`` - anything else.

    Args:
        predicted_list: Sequence of predicted import-statement lists.
        true_list: Sequence of reference import-statement lists.

    Returns:
        dict[str, int]: Bucket name -> number of pairs in that bucket.

    Note:
        A length mismatch only prints a message; pairing then stops at the
        end of the shorter sequence.
    """
    if len(predicted_list) != len(true_list):
        print("The two lists must have the same length.")

    def _bucket(pred_set, truth_set):
        # Order matters: equality and emptiness are tested before subset.
        if pred_set == truth_set:
            return "Same"
        if not pred_set:
            return "None"
        if pred_set <= truth_set:
            return "Missing"
        if pred_set >= truth_set:
            return "Extra"
        return "Different"

    counts = dict.fromkeys(("Same", "Different", "Extra", "Missing", "None"), 0)

    for predicted, ground_truth in zip(predicted_list, true_list):
        counts[_bucket(set(predicted), set(ground_truth))] += 1

    return counts

In [7]:
def result_calc(source_path):
    """Score the JSON prediction logs in ``source_path`` against ``true_list``.

    Changes the working directory to ``source_path``, collects every
    ``*.json`` file there except those whose name contains "HibernateUtil",
    orders them with ``numericalSort``, extracts the import statements from
    each file's ``"java_code"`` field and compares the resulting lists with
    the notebook-level ``true_list``.

    Args:
        source_path: Directory that holds the JSON log files.

    Returns:
        The category counts produced by ``compare_import_statements``.
    """
    os.chdir(source_path)

    json_paths = [
        path
        for path in glob.glob(os.path.join(source_path, "*.json"))
        # Filter out the "HibernateUtil" file based on the filename
        if "HibernateUtil" not in os.path.basename(path)
    ]
    json_paths.sort(key=numericalSort)

    predicted_list = []
    for json_path in json_paths:
        with open(json_path, 'r') as handle:
            record = json.load(handle)
        # A log without "java_code" contributes an empty source string.
        predicted_list.append(extract_imports(record.get("java_code", "")))

    return compare_import_statements(predicted_list, true_list)

### For SnR

In [8]:
# For SnR
# Unlike the other evaluations, the predictions here are read directly from
# .java source files rather than from JSON logs, so result_calc is not used.

print("For SnR:\n")

source_path = "/home/azmain/snr_fixed/"
os.chdir(source_path)

all_files = glob.glob(os.path.join(source_path, "*.java"))

# Filter out the "HibernateUtil" file based on the filename
filtered_files = [f for f in all_files if "HibernateUtil" not in os.path.basename(f)]


# Natural ordering of the paths. The comparison below pairs files with
# true_list by position.
# NOTE(review): assumes this order lines up one-to-one with the files used to
# build true_list -- confirm.
sorted_files = sorted(filtered_files, key=numericalSort)
predicted_list = []

# From .java file
for filepath in sorted_files:
    with open(filepath, 'r') as file:
        java_code = file.read()  # the whole file is treated as plain Java source text (not JSON)
        imports = extract_imports(java_code)
        predicted_list.append(imports)

# Count how the extracted import lists compare with the ground truth.
result = compare_import_statements(predicted_list, true_list)
print(result)

For SnR:

{'Same': 143, 'Different': 78, 'Extra': 4, 'Missing': 35, 'None': 7}


### For Base Prompt

In [9]:
# For Base Prompt
# Score the JSON logs in this directory against true_list via result_calc
# and print the per-category counts.

print("For Base Prompt:\n")
source_path = "/home/azmain/GitHub Codes/base_prompt_combined_logs/"
print(result_calc(source_path))

For Base Prompt:

{'Same': 173, 'Different': 55, 'Extra': 14, 'Missing': 19, 'None': 6}


### For Zero-shot Self-consistency with 10 sample 5 attempt 

In [10]:
# For Zero-shot Self-consistency with 10 sample 5 attempt
# Score the JSON logs in this directory against true_list via result_calc
# and print the per-category counts.

print("For Zero-shot Self-consistency with 10 sample 5 attempt:\n")
source_path = "/home/azmain/GitHub Codes/zero-shot_self_c_10_sample_5_attempt_combined/"
print(result_calc(source_path))

For Zero-shot Self-consistency with 10 sample 5 attempt:

{'Same': 202, 'Different': 35, 'Extra': 16, 'Missing': 14, 'None': 0}


### For Zero-shot Self-consistency from 10 Sample

In [11]:
# For Zero-shot Self-consistency from 10 Sample
# Score the JSON logs in this directory against true_list via result_calc
# and print the per-category counts.

print("For Zero-shot Self-consistency from 10 Sample:\n")
source_path = "/home/azmain/GitHub Codes/zero_shot_combined_logs/"
print(result_calc(source_path))

For Zero-shot Self-consistency from 10 Sample:

{'Same': 191, 'Different': 48, 'Extra': 14, 'Missing': 12, 'None': 2}


### For Zero-shot without Self-consistency (temp 0.5)

In [12]:
# For Zero-shot without Self-consistency (temp 0.5)
# Score the JSON logs in this directory against true_list via result_calc
# and print the per-category counts.

print("For Zero-shot without Self-consistency (temp 0.5):\n")
source_path = "/home/azmain/GitHub Codes/temp_0.5_no_self_c_combined_logs/"
print(result_calc(source_path))

For Zero-shot without Self-consistency (temp 0.5):

{'Same': 180, 'Different': 58, 'Extra': 14, 'Missing': 14, 'None': 1}


### For Source SO Dataset with Compile Fixing

In [13]:
# For Source SO Dataset with Compile Fixing
# Score the JSON logs in this directory against true_list via result_calc
# and print the per-category counts.

print("For Source SO Dataset with Compile Fixing:\n")
source_path = "/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/logs/all_so_logs/"
print(result_calc(source_path))

For Source SO Dataset with Compile Fixing:

{'Same': 175, 'Different': 48, 'Extra': 22, 'Missing': 19, 'None': 3}
