# Single Iteration C Code APR Using LLM Experiments

Fine-tune existing LLMs to detect vulnerabilities, modify the LLM engine's architecture to repair vulnerable AI code, and then integrate and demonstrate it within our repair framework.
1. Design APR prompts.
2. Use the following hardware: University CSF Facility.
3. Use ~~Llama, Falcon, GPT-2, Safecoder-beta/Starchat-beta~~ GPT-3.5 Turbo, as the CSF 3 servers were not functional.

## Methodology

Randomly select 100 samples of Expanded NeuroCodeBench that have a `VERIFICATION FAILED` result, using a cryptographically secure random source (`secrets.choice`).

Run each sample through GPT-3.5 Turbo with each prompt.

Total API calls: `10 * 100 = 1000`

## Data Collection

Run ~100 samples per LLM per prompt and record the following metrics:
* ~~How easy it is to setup & run~~
* Cost of running the LLM
* How long it takes to repair samples
* Number of samples that are Verification Successful / Do Not Compile / Verification Failed
* ~~How many attempts each LLM takes to repair a sample~~

## Obtaining Samples

Filenames are selected with a cryptographically secure random source. First run the `get_filenames.sh` script, which saves the filenames of all the `VERIFICATION FAILED` samples to `./all_sample_names`. The code below then randomly selects 100 samples, writes their names to `./selected_sample_names`, and writes the matching ESBMC output filenames to `./selected_esbmc_output_names`.

In [2]:
from secrets import choice

In [81]:
with open("./all_sample_names", "r") as file:
    lines: list[str] = file.readlines()

selected_lines: set[str] = set()
while len(selected_lines) < 100:
    selected_lines.add(choice(lines))

len(selected_lines)

100
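The loop above keeps drawing until 100 distinct lines have been collected, so it assumes `all_sample_names` contains at least 100 unique filenames; otherwise it never terminates. A minimal alternative sketch draws the filenames without replacement in one call using `secrets.SystemRandom` (the same OS-backed randomness source as `secrets.choice`). The helper name `select_samples` is hypothetical.

```python
from secrets import SystemRandom


def select_samples(lines: list[str], k: int = 100) -> list[str]:
    """Draws k distinct filenames without replacement using OS-provided randomness."""
    # Deduplicate (and drop blank lines) so the same filename cannot be picked twice.
    unique_lines: list[str] = sorted(set(line.strip() for line in lines if line.strip()))
    if len(unique_lines) < k:
        raise ValueError(f"Only {len(unique_lines)} unique filenames available, need {k}.")
    return SystemRandom().sample(unique_lines, k)


# Example usage with stand-in filenames.
names = [f"sample_{i}.c\n" for i in range(250)]
selected = select_samples(names, 100)
print(len(selected), len(set(selected)))
```

Failing loudly when there are too few candidates is preferable to a silent infinite loop in a long-running notebook.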

In [84]:
with open("./selected_sample_names", "w") as file:
    for filename in selected_lines:
        # Normalize the trailing newline in case the last line of the input file had none.
        file.write(f"{filename.strip()}\n")

with open("./selected_esbmc_output_names", "w") as file:
    for filename in selected_lines:
        file.write(f"{filename.strip()}.stdout.txt\n")

### Getting Samples

Running `./get_selected_samples` fetches the source code of each file listed in `./selected_sample_names` along with its ESBMC output, placing them in `samples` and `esbmc_output` respectively.
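Before running any experiments it can help to confirm that every sample has a matching ESBMC output. A minimal sketch, assuming the layout used by the loading code further down (`samples/<subdir>/<name>.c` paired with `esbmc_output/<subdir>/<name>.c.stdout.txt`); the function name `find_missing_outputs` is hypothetical.

```python
import os


def find_missing_outputs(samples_dir: str = "samples", esbmc_dir: str = "esbmc_output") -> list[str]:
    """Returns the keys ("subdir/name.c") of samples that have no ESBMC output file."""
    missing: list[str] = []
    for subdir in sorted(os.listdir(samples_dir)):
        subdir_path = os.path.join(samples_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        for file_name in sorted(os.listdir(subdir_path)):
            if not file_name.endswith(".c"):
                continue
            # The ESBMC output is expected next to a mirrored subdirectory.
            output_path = os.path.join(esbmc_dir, subdir, f"{file_name}.stdout.txt")
            if not os.path.isfile(output_path):
                missing.append(f"{subdir}/{file_name}")
    return missing
```

After `./get_selected_samples` has finished, `print(find_missing_outputs())` should print an empty list.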

## Running Experiments with GPT-3.5 Turbo

In [1]:
import os
from io import TextIOWrapper
from time import sleep, time
from typing import Optional

import tiktoken
from dotenv import get_key
from openai import OpenAI

### AI Params + Definitions

In [49]:
MAX_TOKENS: int = 16385

In [47]:
def num_tokens_from_string(string: str, encoding_name: str = "gpt-3.5-turbo") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.encoding_for_model(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens

In [8]:
api_key: Optional[str] = get_key(dotenv_path=".env", key_to_get="OPENAI_API_KEY")
assert api_key
client = OpenAI(api_key=api_key)

In [76]:
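def run_sample(prompt: str, source_code: str, esbmc_output: str, role: str) -> str:
    """Sends a single repair request and returns the raw text of the model's reply."""
    # format_map ignores unused keys, so {role} is only substituted when the prompt contains it.
    # Without the "role" key, every persona prompt would raise KeyError: 'role'.
    message_stack: list[dict[str, str]] = [
        {"role": "system", "content": prompt.format_map({"source": source_code, "esbmc": esbmc_output, "role": role})},
    ]

    response = client.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        messages=message_stack,
    )
    return response.choices[0].message.content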
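`run_sample` makes a single request, so a transient API error (for example a rate limit) costs that sample its result, since the experiment loops catch the exception and move on. A minimal sketch of a generic retry wrapper with exponential backoff follows; `with_retries`, its parameters, and the choice to retry on any exception are assumptions, not part of the original pipeline.

```python
from time import sleep
from typing import Callable, TypeVar

T = TypeVar("T")


def with_retries(fn: Callable[[], T], attempts: int = 3, base_delay: float = 1.0) -> T:
    """Calls fn, retrying on any exception with exponentially growing delays."""
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise  # Out of attempts: surface the last error.
            # Wait base_delay, 2 * base_delay, 4 * base_delay, ... before the next try.
            sleep(base_delay * 2 ** attempt)
    raise ValueError("attempts must be at least 1")
```

Usage would look like `llm_output = with_retries(lambda: run_sample(prompt, trimmed_sc, esbmc_output, role))`. Note that retries make the measured duration include the waiting time, which matters for the repair-time metric.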

### Initializing Directories

In [90]:
dirs: list[str] = ["samples-patched", "results"]
dirs2: list[str] = ["constant", "contextual"]

for dir in dirs:
    if not os.path.exists(dir):
        os.mkdir(dir)

    for dir2 in dirs2:
        if not os.path.exists(f"{dir}/{dir2}"):
            os.mkdir(f"{dir}/{dir2}")
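An equivalent, more compact sketch uses `os.makedirs(..., exist_ok=True)`, which creates missing parent directories and does nothing if the directory already exists. The `init_dirs` wrapper and its `root` parameter are added here only for illustration.

```python
import os


def init_dirs(root: str = ".") -> None:
    """Creates <root>/{samples-patched,results}/{constant,contextual}, skipping existing ones."""
    for top_dir in ["samples-patched", "results"]:
        for strategy in ["constant", "contextual"]:
            # makedirs creates missing parents; exist_ok=True makes re-runs a no-op.
            os.makedirs(os.path.join(root, top_dir, strategy), exist_ok=True)
```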

### Initializing Logger
Everything is logged using these simple custom print-and-write functions. Beware that opening `log.txt` may show an outdated state until the buffer is flushed, and that editing the log file while it is still open may corrupt it until `log_file.close()` is called.

In [4]:
# Initialize logger.
try:
    log_file.flush()
except (NameError, ValueError) as e2:
    log_file: TextIOWrapper = open("log.txt", "a")

def log_str(text: str = "") -> None:
    assert not log_file.closed, "The log file is closed."
    if len(text) == 0:
        log_file.write("\n")
    else:
        log_file.write(f"Log: {time()}: {text}\n")
    
def print_and_log(text: str = "") -> None:
    assert not log_file.closed, "The log file is closed."
    if len(text) == 0:
        log_file.write("\n")
        print()
    else:
        text = str(time()) + ": " + text
        log_file.write("Log: " + text + "\n")
        print(text)

print_and_log("Notice: Starting new logging session.")

1709133253.7469735: Notice: Starting new logging session.
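The standard `logging` module could replace the hand-written helpers: `StreamHandler` flushes after every record (and `FileHandler` inherits this behaviour), so `log.txt` stays current without manual `flush()` calls. A minimal sketch; the logger name and the `"%(created)f: %(message)s"` format are assumptions chosen to mimic the existing `<timestamp>: <text>` lines.

```python
import logging
import sys


def make_logger(log_path: str = "log.txt", name: str = "apr_experiments") -> logging.Logger:
    """Returns a logger that writes "<unix time>: <message>" lines to stdout and to log_path."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Keep records away from any root-logger handlers (e.g. ones configured by Jupyter).
    logger.propagate = False
    # Re-running a notebook cell returns the same logger, so only attach handlers once.
    if not logger.handlers:
        formatter = logging.Formatter("%(created)f: %(message)s")
        for handler in (logging.StreamHandler(sys.stdout), logging.FileHandler(log_path, mode="a")):
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger


# Usage: logger = make_logger(); logger.info("Notice: Starting new logging session.")
```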


### Load and Parse Data

#### ESBMC Parsing Strategies

In [78]:
def esbmc_output_remove_loop_unroll(esbmc_output: str) -> str:
    lines: list[str] = esbmc_output.splitlines()
    kept_lines: list[str] = []
    filter: bool = False
    for line in lines:
        if filter:
            if "Building error trace" in line:
                filter = False
                kept_lines.append(line)
        else:
            if "Starting Bounded Model Checking" in line:
                filter = True
            kept_lines.append(line)
    return "\n".join(kept_lines)

def esbmc_output_remove_dividers(esbmc_output: str) -> str:
    lines: list[str] = esbmc_output.splitlines()
    kept_lines: list[str] = []
    for line in lines:
        if "----------------------------------------------------" not in line:
            kept_lines.append(line)
    return "\n".join(kept_lines)

def compress_esbmc_output(esbmc_output: str) -> str:
    esbmc_output = esbmc_output_remove_loop_unroll(esbmc_output)
    esbmc_output = esbmc_output_remove_dividers(esbmc_output)
    return esbmc_output
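The two passes above can also be fused into a single loop. A minimal sketch with one stated difference: instead of matching one exact run of dashes, it treats any non-empty line made only of `-` characters as a divider, which is an assumption about the ESBMC output format. The function name is hypothetical.

```python
def compress_esbmc_output_single_pass(esbmc_output: str) -> str:
    """Drops the loop-unrolling lines between "Starting Bounded Model Checking" and
    "Building error trace" (both marker lines are kept) plus any all-dash divider line."""
    kept_lines: list[str] = []
    skipping: bool = False
    for line in esbmc_output.splitlines():
        if skipping:
            if "Building error trace" not in line:
                continue  # Still inside the unrolling log.
            skipping = False
        elif "Starting Bounded Model Checking" in line:
            skipping = True
        stripped = line.strip()
        if stripped and set(stripped) == {"-"}:
            continue  # Divider line.
        kept_lines.append(line)
    return "\n".join(kept_lines)
```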

In [79]:
def esbmc_get_violated_property(esbmc_output: str) -> str:
    """Gets the "Violated property:" block (the header plus the two lines after it) of the ESBMC output."""
    # Find "Violated property:" string in ESBMC output
    lines: list[str] = esbmc_output.splitlines()
    for ix, line in enumerate(lines):
        if "Violated property:" == line:
            return "\n".join(lines[ix:ix+3])
    # Do not reference file_name_key here: it is not defined while the data is being loaded.
    raise ValueError('Could not find "Violated property:" in the ESBMC output')

def esbmc_get_counter_example(esbmc_output: str) -> str:
    """Gets ESBMC output after and including [Counterexample]"""
    idx: int = esbmc_output.find("[Counterexample]\n")
    assert idx != -1
    return esbmc_output[idx:]

In [80]:
# Load all the samples and esbmc output.
data_samples: dict[str, str] = {}
data_esbmc_output: dict[str, str] = {}
data_vp_output: dict[str, str] = {}

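# Each sample lives in a subdirectory of samples/ (e.g. samples/reach_prob_density/).
subdirs: list[str] = sorted(d for d in os.listdir("samples") if os.path.isdir(f"samples/{d}"))
for subdir in subdirs: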
    files: list[str] = sorted(os.listdir(f"samples/{subdir}"))
    for file_name in files:
        if not file_name.endswith(".c"):
            continue
            
        with open(f"samples/{subdir}/{file_name}", "r") as file:
            key: str = f"{subdir}/{file_name}"
            data_samples[key] = file.read()
        with open(f"esbmc_output/{subdir}/{file_name}.stdout.txt", "r") as file:
            esbmc_output: str = file.read()
            # data_esbmc_output[key] = esbmc_output
            # Alternatively, remove the lines between "Starting Bounded Model Checking" and "Building error trace":
            # data_esbmc_output[key] = compress_esbmc_output(esbmc_output)
            data_esbmc_output[key] = esbmc_get_counter_example(esbmc_output)
            data_vp_output[key] = esbmc_get_violated_property(esbmc_output)
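A minimal sketch of the same traversal using `pathlib`, which finds the sample subdirectories with a glob instead of listing them by hand. `load_samples` is a hypothetical helper that only loads the sources and the raw ESBMC outputs; the counterexample and violated-property parsing above would still be applied afterwards.

```python
from pathlib import Path


def load_samples(samples_dir: str = "samples", esbmc_dir: str = "esbmc_output") -> tuple[dict[str, str], dict[str, str]]:
    """Maps "subdir/name.c" keys to the sample source and to its raw ESBMC stdout."""
    sources: dict[str, str] = {}
    outputs: dict[str, str] = {}
    # "*/*.c" matches .c files exactly one directory level below samples_dir.
    for path in sorted(Path(samples_dir).glob("*/*.c")):
        key = f"{path.parent.name}/{path.name}"
        sources[key] = path.read_text()
        outputs[key] = (Path(esbmc_dir) / path.parent.name / f"{path.name}.stdout.txt").read_text()
    return sources, outputs
```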

### Define Prompts

The following prompts are going to be iterated through. The prompts `simple_prompts_no_esbmc` and `simple_prompts` are the baseline prompts.

In [83]:
# Simple Prompts

simple_prompts_no_esbmc: list[str] = [
    "The following source code segment might contain a memory vulnerability\n\n{source}\n\nFix the source code segment.",
    "Fix the memory vulnerability that may exist in the source code segment:\n\n{source}",
]

simple_prompts: list[str] = [
    "The following source code contains a memory vulnerability\n\n{source}\n\nThe following is output of ESBMC describing the vulnerability\n\n{esbmc}\n\nFix the source code.",
    "Fix the source code:\n\n{source}\n\n{esbmc}",
]

simple_prompts_flipped: list[str] = [
    "ESBMC output describes a memory vulnerability in the source code, the following is ESBMC output:\n\n{esbmc}\n\nThe following is the vulnerable source code:\n\n{source}\n\nFix the source code.",
    "Fix the source code:\n\n{esbmc}\n\n{source}",
]

# Persona Prompts

persona_prompt_no_esbmc: list[str] = [
    "You’re a {role}. You’ll be shown some C code. Repair the code and display it. The code is {source}",
    "From now on, act as an {role} that repairs AI C code. You will be shown AI C code. Provide the repaired C code as output, as would an {role}. Aside from the corrected source code, do not output any other text. The code is\n\n{source}"
]

persona_prompt: list[str] = [
    "You’re a {role}. You’ll be shown some C code, along with ESBMC output. Repair the code and display it. The code is\n\n{source}\n\n The ESBMC output is\n\n{esbmc}",
    "From now on, act as an {role} that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with what type of error has occurred and its location. Provide the repaired C code as output, as would an {role}. Aside from the corrected source code, do not output any other text. The code is\n\n{source}\n\nThe ESBMC output is\n\n{esbmc}",
]

persona_prompt_flipped: list[str] = [
    "You’re a {role}. You’ll be shown some C code, along with ESBMC output. Repair the code and display it. The ESBMC output is\n\n{esbmc}\n\nThe source code is\n\n{source}",
    "From now on, act as an {role} that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with what type of error has occurred and its location. Provide the repaired C code as output, as would an {role}. Aside from the corrected source code, do not output any other text. The ESBMC output is\n\n{esbmc}\n\nThe source code is\n\n{source}",
]

all_prompts: list[str] = simple_prompts_no_esbmc + simple_prompts + simple_prompts_flipped + persona_prompt_no_esbmc + persona_prompt + persona_prompt_flipped

persona_roles: list[str] = [
    "Programmer with 1 million years of experience",
    "Senior software engineer",
    "Automated code repair tool",
    "Artificial intelligence that specializes in repairing C programs",
    "The smartest human in the universe",
    "Dog",
]
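For reference, the strategy loops below make one request per (prompt, role, sample) combination, where only the prompts containing `{role}` cycle through the roles, and the contextual strategy additionally repeats everything for both ESBMC output types (`ce` and `vp`), even for the prompts that do not use `{esbmc}`. A short calculation using the lists defined above:

```python
NUM_SAMPLES = 100
NUM_ROLES = 6  # len(persona_roles)
PROMPTS_WITHOUT_ROLE = 6  # simple_prompts_no_esbmc + simple_prompts + simple_prompts_flipped
PROMPTS_WITH_ROLE = 6  # persona_prompt_no_esbmc + persona_prompt + persona_prompt_flipped
ESBMC_OUTPUT_TYPES = 2  # "ce" (counterexample) and "vp" (violated property)

# Requests per sample for one ESBMC output type.
calls_per_sample = PROMPTS_WITHOUT_ROLE + PROMPTS_WITH_ROLE * NUM_ROLES
total_calls = ESBMC_OUTPUT_TYPES * NUM_SAMPLES * calls_per_sample
print(calls_per_sample, total_calls)  # 42 8400
```

This is well above the `1000` calls estimated in the Methodology section, which is worth factoring into the cost metric.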

### Splitting Strategies

Together, the source code and the ESBMC output are too large for the LLM's context length. Three strategies are proposed to see whether they alleviate the problem:

1. Constant: split by line or by character, ignoring structure (brutal split).
2. ~~Structural: Split semantically (function by function)~~
3. Contextual: split around the failing line, keeping mainly the code before it.

#### Notation

Let `C` be the source code and `L={l1, l2, l3, ..., ln}` the list of line lengths (in tokens), where `n` is the number of lines in `C` and `lx` is the length of the `x`th line of `C`. So `L[1]=l1` is the length of the first line, and so on. `e` is the index of the line containing the error and `E` is its length, such that `E=L[e]=le`.

In [48]:
def get_code_from_solution(solution: str) -> str:
    """Strip the source code of any leftover text as sometimes the AI model
    will generate text and formatting despite being told not to.

    Adapted from: https://github.com/Yiannis128/esbmc-ai/blob/master/esbmc_ai/solution_generator.py"""
    try:
        code_start: int = solution.index("```") + 3
        # Skip the rest of the opening fence line, because usually there's a language
        # specification after the 3 ticks ```c...
        code_start = solution.index("\n", code_start) + 1
        # The last ``` closes the code block.
        code_end: int = solution.rindex("```")
    except ValueError:
        # No code fence: return the response unchanged.
        return solution
    if code_end < code_start:
        # Only an opening fence was found.
        return solution
    return solution[code_start:code_end]
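A more compact alternative sketch uses a regular expression to take the *first* fenced block, rather than everything between the first and the last fence. The two approaches differ when a response contains several code blocks (for example the fix followed by a usage example). `extract_first_code_block` is a hypothetical name, and the fence string is built from single backticks only to keep this snippet readable.

```python
import re

FENCE = "`" * 3
# Opening fence plus an optional language tag, then the body up to the next fence.
FENCED_BLOCK = re.compile(re.escape(FENCE) + r"[^\n]*\n(.*?)" + re.escape(FENCE), re.DOTALL)


def extract_first_code_block(response: str) -> str:
    """Returns the body of the first fenced code block, or the whole response if there is none."""
    match = FENCED_BLOCK.search(response)
    return match.group(1) if match else response
```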

### Combining Strategies

The LLM response can be merged back into the full source code using the following strategies:

1. Brutal replacement: the lines given to the LLM are simply replaced by whatever code the LLM responds with.
2. ~~Direct replacement: Ask the LLM to supply in a specific format the line(s) to replace. Then simply replace those lines.~~

In [11]:
def apply_patch_brutal_replacement(source_code: str, patch: str, start: int, end: int) -> str:
    """End is non-inclusive"""
    lines: list[str] = source_code.splitlines()
    lines = lines[:start] + [patch] + lines[end:]
    return "\n".join(lines)
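A quick check of the replacement semantics on a toy four-line source; the function is repeated verbatim so the snippet runs on its own.

```python
def apply_patch_brutal_replacement(source_code: str, patch: str, start: int, end: int) -> str:
    """End is non-inclusive"""
    lines: list[str] = source_code.splitlines()
    lines = lines[:start] + [patch] + lines[end:]
    return "\n".join(lines)


source = "int a;\nint b;\nint c;\nint d;"
# Replace the 0-based lines 1 and 2 ("int b;" and "int c;") with a two-line patch.
print(apply_patch_brutal_replacement(source, "int b2;\nint c2;", 1, 3))
```

Because the lines are re-joined with `"\n"`, a patch that itself ends in a newline leaves an extra blank line behind; this is harmless for compilation but shows up in diffs.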

### ~~Constant Splitting Strategy~~

The constant splitting strategy splits the source code into segments of a constant token size, then asks the LLM to repair each segment. The segment size is a percentage of the maximum token count: each segment uses at most `95%` of `MAX_TOKENS`, leaving room for the counterexample.

Calculate the segment boundaries like so:
1. With the budget $T$ set to `0.95 * MAX_TOKENS`, find the largest $i$ such that $S_i = \sum_{j=1}^{i} l_j \le T$.
2. The constraint on $i$ is $0 \le i \le n$. The next segment then starts at line $i+1$, and the same rule is applied to the remaining lines.
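The boundary rule above amounts to greedy packing: keep adding lines while the running token total stays within the budget, then start a new segment. A minimal sketch with the tokenizer passed in as a function so it runs without `tiktoken`; `split_into_segments` and the word-count tokenizer in the example are assumptions for illustration only.

```python
from typing import Callable


def split_into_segments(lines: list[str], budget: float, count_tokens: Callable[[str], int]) -> list[list[str]]:
    """Greedily packs consecutive lines into segments whose token total stays within budget.
    A single line larger than the budget still gets a segment of its own."""
    segments: list[list[str]] = [[]]
    used: int = 0
    for line in lines:
        tokens = count_tokens(line)
        # Start a new segment only if the current one is non-empty and would overflow.
        if segments[-1] and used + tokens > budget:
            segments.append([])
            used = 0
        segments[-1].append(line)
        used += tokens
    return segments


# Toy example: whitespace-separated words stand in for real tokens.
lines = ["int a = 0;\n", "a += 1;\n", "return a;\n", "}\n"]
print(split_into_segments(lines, budget=6, count_tokens=lambda s: len(s.split())))
```

With a budget of 6 "tokens", the first line (4) fills most of the first segment, so the second segment takes the remaining three lines (3 + 2 + 1 = 6).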

In [73]:
SEGMENT_TOKEN_PERCENT: float = 0.95

def get_segments(lines: list[str]) -> list[str]:
    """Splits the lines into segments of at most SEGMENT_TOKEN_PERCENT * MAX_TOKENS tokens
    (a single line longer than the budget becomes its own segment).
    Make sure the lines have their new lines still attached."""
    budget: float = SEGMENT_TOKEN_PERCENT * MAX_TOKENS
    token_count: int = 0
    segments: list[list[str]] = [[]]
    for line in lines:
        line_tokens: int = num_tokens_from_string(line)
        if segments[-1] and token_count + line_tokens > budget:
            # Adding this line would exceed the budget, so start a new segment with it.
            segments.append([])
            token_count = 0
        # Add line to latest segment and count its tokens towards that segment.
        segments[-1].append(line)
        token_count += line_tokens
    # Combine segment lines (list of lines) into a complete string.
    # Then return list of segments.
    return ["".join(segment) for segment in segments]

In [None]:
print_and_log()
print_and_log("Running Constant Splitting Strategy")

# Loop through prompts
for prompt_idx, prompt in enumerate(all_prompts):
    print_and_log()
    print_and_log(f"Running new cycle with the following prompt ({prompt_idx}):\n```\n{prompt}\n```")
    # Try all the roles
    # Check if a {role} tag is in the prompt string and use roles in that case.
    role_count: int
    if "{role}" in prompt:
        print_and_log("Notice: Prompt has roles. Will cycle roles.")
        role_count = len(persona_roles)
    else:
        print_and_log("Notice: Prompt has no roles. Roles will not be cycled or used.")
        role_count = 1
        
    # Loop through files
    for idx, file_name_key in enumerate(data_samples.keys()):
        print_and_log()
        print_and_log(f"Checkpoint {idx}: {file_name_key}")
        
        source_code: str = data_samples[file_name_key]
        esbmc_output: str = data_esbmc_output[file_name_key]

        # Split the source code into lines to divide into segments.
        lines: list[str] = source_code.splitlines(True)
        # Divide into segments. Get back string segments.
        source_code_segments: list[str] = get_segments(lines)
        
        # Line index at which each segment starts (plus the end of the last one), so that
        # each patch can be stitched back over exactly the lines it replaces.
        segment_starts: list[int] = [0]
        for segment in source_code_segments:
            segment_starts.append(segment_starts[-1] + len(segment.splitlines()))

        # Loop through all the segments.
        for segment_idx, sc_segment in enumerate(source_code_segments):
            # Try all the roles, if no roles, then loop will execute once only.
            for role_idx in range(role_count):
                try:
                    delta: float = time()
                    # Pass the segment and the role; prompts without a {role} placeholder ignore it.
                    llm_output = run_sample(prompt, sc_segment, esbmc_output, persona_roles[role_idx])
                    delta = time() - delta

                    print_and_log(f"Duration: {delta}")
                    log_str(f"Raw Response:\n\n{llm_output}")

                    llm_output = get_code_from_solution(llm_output)

                    # Use indices rather than the prompt text in file names; the key's
                    # subdirectory is dropped because it does not exist under results/.
                    file_name: str = f"{prompt_idx}.{role_idx}.{idx}.{segment_idx}_{os.path.basename(file_name_key)}"

                    # Save patch
                    with open(f"results/constant/{file_name}", "w") as file:
                        file.write(llm_output)

                    # Stitch together patch
                    patched_source: str = apply_patch_brutal_replacement(
                        source_code,
                        llm_output,
                        segment_starts[segment_idx],
                        segment_starts[segment_idx + 1],
                    )

                    # Save patched source
                    with open(f"samples-patched/constant/{file_name}", "w") as file:
                        file.write(patched_source)
                except Exception as e:
                    print_and_log(f"Notice: error: {file_name_key}: {e}")
                finally:
                    print_and_log()

### Contextual Strategy

This strategy takes the line at which the error occurred along with a ratio split of the lines before and after it. Here `85%` of `MAX_TOKENS` goes to the lines before the error and `10%` to the lines after it, since we want to give as much information as possible about how the code looked before the error while still including some lines after it for context.

The following variables are declared:
* `LTOKENS=MAX_TOKENS*0.85` - The window of tokens to keep before the error line.
* `UTOKENS=MAX_TOKENS*0.10` - The window of tokens to keep after the error line.

The lower bound line index is calculated like so:
1. We want the largest $i_l$ such that $\sum_{j=0}^{i_l} l_{e-j} \le \text{LTOKENS}$ (the error line itself counts towards this budget).
2. The constraint on $i_l$ is $0 \le i_l \le e-1$, i.e. the window cannot start above the first line.

Similarly, the upper bound line index is calculated like so:
1. We want the largest $i_u$ such that $\sum_{j=1}^{i_u} l_{e+j} \le \text{UTOKENS}$.
2. The constraint on $i_u$ is $0 \le i_u \le n-e$, i.e. the window cannot extend past the last line.

The combined window covers lines $e-i_l$ through $e+i_u$, which fills at most `95%` of `MAX_TOKENS`; the remaining `5%` is allocated to the ESBMC output's counterexample stack trace and/or violated property.
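The split can be checked with a little arithmetic on the `MAX_TOKENS` value used in this notebook. This is only a sketch of the budget; note that the 16,385-token context window is also shared with the prompt template's own text and with the model's reply, so the real input budget is somewhat tighter than these numbers suggest.

```python
MAX_TOKENS = 16385  # Context length used for gpt-3.5-turbo-0125 in this notebook.

LTOKENS = MAX_TOKENS * 0.85  # Budget for the error line and the lines before it.
UTOKENS = MAX_TOKENS * 0.10  # Budget for the lines after the error line.
esbmc_budget = MAX_TOKENS - LTOKENS - UTOKENS  # What is left for the ESBMC output (~5%).

print(int(LTOKENS), int(UTOKENS), int(esbmc_budget))  # 13927 1638 819
```

Roughly 800 tokens is a small allowance for a full counterexample trace, which is one reason the violated-property-only (`vp`) variant is also tried.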

In [85]:
def get_source_code_err_line(esbmc_output: str) -> int:
    """Returns the 1-based line number reported under "Violated property:" in the ESBMC output."""
    # Find "Violated property:" string in ESBMC output
    lines: list[str] = esbmc_output.splitlines()
    for ix, line in enumerate(lines):
        if "Violated property:" == line and ix + 1 < len(lines):
            pos_line_split: list[str] = lines[ix + 1].split(" ")
            for iy, word in enumerate(pos_line_split[:-1]):
                if word == "line":
                    # Get the line number
                    return int(pos_line_split[iy + 1])
            raise ValueError('Could not find a line number after "Violated property:"')
    raise ValueError('Could not find "Violated property:" in the ESBMC output')

def get_lower_bound(source_code_lines: list[str], error_line: int) -> int:
    """Gets the index of the first line to include in the trimmed source code.

    `error_line` is the 1-based line number reported by ESBMC. Walks upwards from the
    error line, adding lines while their token total stays within LTOKENS; the error
    line itself is always kept. Make sure the lines have their new lines still attached."""
    err_idx: int = min(max(error_line - 1, 0), len(source_code_lines) - 1)
    token_counts: int = 0
    for i in range(err_idx, -1, -1):
        token_counts += num_tokens_from_string(source_code_lines[i])
        if token_counts > LTOKENS:
            # Line i would exceed the budget, so the window starts just below it.
            return min(i + 1, err_idx)
    # If we run out of lines then use all of them.
    return 0

def get_upper_bound(source_code_lines: list[str], error_line: int) -> int:
    """Gets the non-inclusive index of the last line to include in the trimmed source code.

    `error_line` is the 1-based line number reported by ESBMC. Walks downwards from the
    line after the error, adding lines while their token total stays within UTOKENS.
    Make sure the lines have their new lines still attached."""
    err_idx: int = min(max(error_line - 1, 0), len(source_code_lines) - 1)
    token_counts: int = 0
    for i in range(err_idx + 1, len(source_code_lines)):
        token_counts += num_tokens_from_string(source_code_lines[i])
        if token_counts > UTOKENS:
            # Line i would exceed the budget, so the window ends just before it.
            return i
    # If we run out of lines then use all of them.
    return len(source_code_lines)

# Token budget for the line with the error and the lines before it.
LTOKENS = MAX_TOKENS * 0.85
# Token budget for the extra context lines after the line with the error.
UTOKENS = MAX_TOKENS * 0.10
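A slightly more defensive variant of `get_source_code_err_line` is sketched below: it tolerates whitespace around the `Violated property:` header, uses a regular expression instead of splitting on single spaces, and returns `None` instead of raising. `find_error_line` is a hypothetical name, and the example ESBMC text is hand-written to mirror the block this notebook parses, not taken from a real run.

```python
import re
from typing import Optional

# "line <number>" as a whole word, e.g. "file x.c line 42 column 5 function main".
LINE_NUMBER = re.compile(r"\bline (\d+)\b")


def find_error_line(esbmc_output: str) -> Optional[int]:
    """Returns the 1-based source line reported under "Violated property:", or None."""
    lines = esbmc_output.splitlines()
    # Skip the last line: the location is expected on the line after the header.
    for ix, line in enumerate(lines[:-1]):
        if line.strip() == "Violated property:":
            match = LINE_NUMBER.search(lines[ix + 1])
            return int(match.group(1)) if match else None
    return None


example = "Violated property:\n  file example.c line 42 column 5 function main\n  dereference failure: NULL pointer"
print(find_error_line(example))  # 42
```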

In [91]:
print_and_log()
print_and_log("Running Contextual Strategy")

# Loop through prompts
for prompt_idx, prompt in enumerate(all_prompts):
    print_and_log()
    print_and_log(f"Notice: Running new cycle with prompt ({prompt_idx})")
    # Try all the roles
    # Check if a {role} tag is in the prompt string and use roles in that case.
    role_count: int
    if "{role}" in prompt:
        print_and_log("Notice: Prompt has roles. Will cycle roles.")
        role_count = len(persona_roles)
    else:
        print_and_log("Notice: Prompt has no roles. Roles will not be cycled or used.")
        role_count = 1

    # Loop through violated property ESBMC output and counterexample ESBMC output.
    for esbmc_output_type in ["ce", "vp"]:
        # Loop through the different roles.
        for role_idx in range(role_count):
            # Loop through files
            for idx, file_name_key in enumerate(data_samples.keys()):
                print_and_log()
                print_and_log(f"Notice: Checkpoint contextual {prompt_idx} {esbmc_output_type} {role_idx} {idx} {file_name_key}")
        
                # Write progress
                with open("progress.txt", "w") as file:
                    file.write(f"contextual {prompt_idx} {esbmc_output_type} {role_idx} {idx} {file_name_key}")
                
                source_code: str = data_samples[file_name_key]
                source_code_lines: list[str] = source_code.splitlines(True)
                # Get CE or VP output for ESBMC.
                esbmc_output: str = data_esbmc_output[file_name_key] if esbmc_output_type == "ce" else data_vp_output[file_name_key]
        
                try:
                    # Compute the window inside the try so a sample without a usable
                    # "Violated property:" block is logged and skipped instead of stopping the run.
                    err_line: int = get_source_code_err_line(esbmc_output)
                    # Trim the source code by lines, keeping the window defined by
                    # LTOKENS (before the error line) and UTOKENS (after it).
                    lower_bound: int = get_lower_bound(source_code_lines, err_line)
                    upper_bound: int = get_upper_bound(source_code_lines, err_line)
                    trimmed_sc: str = "".join(source_code_lines[lower_bound:upper_bound])

                    delta: float = time()
                    # Pass the role; prompts without a {role} placeholder simply ignore it.
                    llm_output = run_sample(prompt, trimmed_sc, esbmc_output, persona_roles[role_idx])
                    delta = time() - delta
        
                    print_and_log(f"Notice: Duration: {delta}")
                    log_str(f"Raw Response:\n\n{llm_output}")
                    
                    llm_output = get_code_from_solution(llm_output)

                    # Name will be sorted by experimental order. Not filename as common experiments can be
                    # found near eachother.
                    file_name: str = f"{prompt_idx}.{esbmc_output_type}.{role_idx}.{idx}_{os.path.basename(file_name_key)}"
                    
                    # Save patch
                    with open(f"results/contextual/{file_name}", "w") as file:
                        file.write(llm_output)
        
                    # Stitch together patch
                    patched_source: str = apply_patch_brutal_replacement(source_code, llm_output, lower_bound, upper_bound)
        
                    # Save patched source
                    with open(f"samples-patched/contextual/{file_name}", "w") as file:
                        file.write(patched_source)
                except Exception as e:
                    print_and_log(f"Notice: error: {file_name_key}: {e}")
                finally:
                    print_and_log()


1709652384.3239503: Running Contextual Strategy

1709652384.3245773: Notice: Running new cycle with prompt (0)
1709652384.324592: Notice: Prompt has no roles. Roles will not be cycled or used.

1709652384.3246138: Notice: Checkpoint contextual 0 ce 0 0 reach_prob_density/gcas_5_safe.c-amalgamation-149.c
1709652385.816039: Notice: Duration: 1.4805569648742676


1709652385.8166823: Notice: Checkpoint contextual 0 ce 0 1 reach_prob_density/gcas_8_safe.c-amalgamation-6.c
1709652390.2595313: Notice: Duration: 4.433193922042847


1709652390.2601917: Notice: Checkpoint contextual 0 ce 0 2 reach_prob_density/robot_5_safe.c-amalgamation-124.c
1709652391.8338907: Notice: Duration: 1.5635185241699219


1709652391.8351376: Notice: Checkpoint contextual 0 ce 0 3 reach_prob_density/robot_5_safe.c-amalgamation-13.c
1709652395.6813545: Notice: Duration: 3.8360230922698975


1709652395.6826513: Notice: Checkpoint contextual 0 ce 0 4 reach_prob_density/robot_6_safe.c-amalgamation-46.c
1709652400.496683

KeyboardInterrupt: 

## Classification

The processing itself will be conducted on the FM servers; the results will then be analysed in this section.
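Data Collection tracks how many samples end up Verification Successful, Not Compile, or Verification Failed. A minimal sketch of that bucketing, assuming the ESBMC stdout of each patched sample is available as a string: `VERIFICATION FAILED` is the verdict string used throughout this notebook and `VERIFICATION SUCCESSFUL` is ESBMC's corresponding success message, while treating every other output as a compilation failure is an assumption (runs that time out or crash would also land in that bucket). The function names are hypothetical.

```python
from collections import Counter


def classify_esbmc_output(esbmc_stdout: str) -> str:
    """Buckets one ESBMC run into one of the three categories tracked in Data Collection."""
    if "VERIFICATION SUCCESSFUL" in esbmc_stdout:
        return "verification_successful"
    if "VERIFICATION FAILED" in esbmc_stdout:
        return "verification_failed"
    # Assumption: a run that reaches neither verdict failed to parse/compile.
    return "not_compiled"


def summarise(outputs: list[str]) -> Counter[str]:
    """Counts how many runs fall into each category."""
    return Counter(classify_esbmc_output(output) for output in outputs)
```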