In [None]:
import os, json, time, re
from google import genai
from prompts import *

# Read the Gemini API key from the GEMINI_API_KEY environment variable instead
# of embedding it in the notebook, so the secret never ends up in the saved
# file. The empty-string default keeps the previous value when it is unset.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", ""))

#### Zero-shot

In [2]:
def run_on_data(data_path, output_file, *, prompt_template=None, model=None,
                llm_client=None, max_test=5000, delay=1.0):
    """Query the model once per JSONL record and append the parsed verdicts.

    Every non-blank line of ``data_path`` is parsed as a JSON object. Its
    ``commit_message`` and ``func`` values are substituted into the prompt
    template, the prompt is sent through ``generate_content``, and the reply
    is split into a leading ``Yes``/``No`` verdict plus the remaining text.
    One JSON line per processed record is appended to ``output_file``.

    Args:
        data_path: Path to the input JSONL file.
        output_file: Path of the JSONL results file. It is opened in append
            mode, so repeated runs add to existing content.
        prompt_template: ``str.format`` template with ``{commit_message}`` and
            ``{func}`` placeholders. When omitted it is chosen from the global
            ``prompting`` ("few_shot" or "zero_shot").
        model: Model identifier passed to ``generate_content``. Defaults to
            the global ``model_name``.
        llm_client: Object exposing ``models.generate_content(model=...,
            contents=...)``. Defaults to the global ``client``.
        max_test: Maximum number of records to send to the model.
        delay: Seconds to sleep after each request.

    Raises:
        ValueError: If no template is given and the global ``prompting`` is
            neither "few_shot" nor "zero_shot".
    """
    # Anything the caller did not pass is resolved from the notebook-level
    # globals, which is what the plain two-argument call relies on.
    if prompt_template is None:
        if prompting == "few_shot":
            prompt_template = get_few_shot_prompt()
        elif prompting == "zero_shot":
            prompt_template = get_zero_shot_prompt()
        else:
            # Fail up front instead of hitting an unbound local mid-run.
            raise ValueError(
                f"Unsupported prompting mode {prompting!r}; "
                "expected 'few_shot' or 'zero_shot'."
            )
    if model is None:
        model = model_name
    if llm_client is None:
        llm_client = client

    # Ensure the output directory exists
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    count = 0
    with open(data_path, "r", encoding="utf-8") as f, \
            open(output_file, "a", encoding="utf-8") as output_f:
        for line in f:
            # Stop reading once the quota is reached rather than scanning the
            # remainder of the file without doing any work.
            if count >= max_test:
                break
            if not line.strip():
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                # Malformed lines are reported and skipped; they do not count
                # towards max_test.
                print(f"Error decoding JSON: {e}")
                continue

            # Extract fields from the JSON line
            project = data.get("project", "")
            commit_message = data.get("commit_message", "")
            func = data.get("func", "")
            target = data.get("target", "")
            cwe = data.get("cwe", [])
            cve = data.get("cve", "")
            cve_desc = data.get("cve_desc", "")

            final_prompt = prompt_template.format(
                commit_message=commit_message, func=func
            )

            response = llm_client.models.generate_content(
                model=model, contents=final_prompt
            )

            # NOTE(review): the SDK appears to return None for `text` when the
            # reply has no text part - confirm. It is treated as an empty reply
            # so re.match does not raise TypeError.
            reply_text = response.text or ""
            match = re.match(r'^(Yes|No)\b[.]*\s*(.*)', reply_text, re.DOTALL)

            if match:
                result = match.group(1)  # Captures 'Yes' or 'No'
                cot = match.group(2).strip()
            else:
                # -1 marks a reply that does not start with a Yes/No verdict;
                # the raw text is kept for later inspection.
                result = -1
                cot = reply_text

            output_data = {
                "project": project,
                "commit_message": commit_message,
                "func": func,
                "target": target,
                "cwe": cwe,
                "cve": cve,
                "cve_desc": cve_desc,
                "result": result,
                "cot": cot
            }

            output_f.write(json.dumps(output_data) + "\n")
            # Push each record to the OS right away so an interrupted run
            # keeps the rows it already finished.
            output_f.flush()

            # Pause briefly and increment count
            if delay > 0:
                time.sleep(delay)
            count += 1
            if count % 100 == 0:
                print("Count", count)

    print(f"Processing complete. Results written to {output_file}")
    
        
# def main(data_path, model_id, prompting, output_file):
#     if prompting == "few_shot":
#         prompt_template = get_few_shot_prompt()
#     elif prompting == "zero_shot":
#         prompt_template = get_zero_shot_prompt()
    
#     generator = initiate_model(model_id)
#     run_on_data(data_path, prompt_template, generator, output_file)
    

# if __name__ == "__main__":
    
    # parser = argparse.ArgumentParser(description="Run PrimeVul processing with specified parameters.")
    
    # parser.add_argument("--data_path", type=str,
    #                     help="Path to the input JSONL data file.")
    # parser.add_argument("--model_id", type=str, default="bigcode/starcoder2-3b",
    #                     help="Identifier of the model to be used (e.g., bigcode/starcoder2-3b).")
    # parser.add_argument("--prompting", type=str, default="zero_shot",  # few_shot, zero_shot
    #                     help="Type of prompting to use (default: zero_shot).")
    # parser.add_argument("--output_file", type=str, default=None,
    #                     help="Path to the output file. If not provided, a default path is generated.")
    
# args = parser.parse_args()
# Run configuration for the zero-shot experiment. `model_name` and `prompting`
# are read as globals inside run_on_data, so both must be bound before the call.

# Input JSONL file (the author's note lists: primevul_valid, sample_data).
data_path = "/speed-scratch/ra_mdash/PrimeVul_v0.1/primevul_valid.jsonl"

prompting = "zero_shot"
model_name = "gemini-2.0-flash"

# The results file is named after the model and the prompting mode.
output_file = f"/speed-scratch/ra_mdash/results/prime_vul/{model_name}_{prompting}.jsonl"

run_on_data(data_path, output_file)


Count 100
Count 200
Count 300
Count 400
Count 500
Count 600
Count 700
Count 800
Count 900
Count 1000
Count 1100
Count 1200
Count 1300
Count 1400
Count 1500
Count 1600
Count 1700
Count 1800
Count 1900
Count 2000
Count 2100
Count 2200
Count 2300
Count 2400
Count 2500
Count 2600
Count 2700
Count 2800
Count 2900
Count 3000
Count 3100
Count 3200
Count 3300
Count 3400
Count 3500
Count 3600
Count 3700
Count 3800
Count 3900
Count 4000
Count 4100
Count 4200
Count 4300
Count 4400
Count 4500
Count 4600
Count 4700
Count 4800
Count 4900
Count 5000
Processing complete. Results written to /speed-scratch/ra_mdash/results/prime_vul/gemini-2.0-flash_zero_shot.jsonl


#### Few-shot

In [None]:
def run_on_data(data_path, output_file, *, prompt_template=None, model=None,
                llm_client=None, max_test=5000, delay=1.0):
    """Query the model once per sample, prefixing each sample's few-shot examples.

    ``data_path`` is read as a single JSON document and iterated; each item is
    accessed with ``.get``, so items are presumably dicts (verify against the
    data file). For every item the zero-shot query is built from
    ``commit_message`` and ``func``, the item's ``few_shot_samples`` text is
    placed in front of it, and the combined prompt is sent through
    ``generate_content``. The reply is split into a leading ``Yes``/``No``
    verdict (optionally preceded by ``Answer:``) plus the remaining text, and
    one JSON line per sample is appended to ``output_file``.

    This definition shares its name with the zero-shot variant defined earlier
    in the notebook and replaces it once this cell has run.

    Args:
        data_path: Path to the input JSON file.
        output_file: Path of the JSONL results file. It is opened in append
            mode, so repeated runs add to existing content.
        prompt_template: ``str.format`` template with ``{commit_message}`` and
            ``{func}`` placeholders. Defaults to ``get_zero_shot_prompt()``.
        model: Model identifier passed to ``generate_content``. Defaults to
            the global ``model_name``.
        llm_client: Object exposing ``models.generate_content(model=...,
            contents=...)``. Defaults to the global ``client``.
        max_test: Maximum number of samples to send to the model.
        delay: Seconds to sleep after each request.
    """
    # Anything the caller did not pass is resolved from the notebook-level
    # names, which is what the plain two-argument call relies on. The
    # zero-shot template is used on purpose: the few-shot context comes from
    # each sample's own "few_shot_samples" field.
    if prompt_template is None:
        prompt_template = get_zero_shot_prompt()
    if model is None:
        model = model_name
    if llm_client is None:
        llm_client = client

    # Ensure the output directory exists
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(data_path, "r", encoding="utf-8") as f:
        samples = json.load(f)

    count = 0
    with open(output_file, "a", encoding="utf-8") as output_f:
        for data in samples:
            # Stop once the quota is reached rather than walking the rest of
            # the samples without doing any work.
            if count >= max_test:
                break

            # Extract fields from the sample
            project = data.get("project", "")
            commit_message = data.get("commit_message", "")
            func = data.get("func", "")
            target = data.get("target", "")
            cwe = data.get("cwe", [])
            cve = data.get("cve", "")
            cve_desc = data.get("cve_desc", "")
            fallback = data.get("fallback", "")

            few_shot_examples = data.get("few_shot_samples", "")
            zero_shot_query = prompt_template.format(
                commit_message=commit_message,
                func=func
            )
            # Few-shot examples go first so they act as context for the query.
            final_prompt = f"{few_shot_examples}\n\n{zero_shot_query}"

            response = llm_client.models.generate_content(
                model=model, contents=final_prompt
            )

            # NOTE(review): the SDK appears to return None for `text` when the
            # reply has no text part - confirm. It is treated as an empty reply
            # so re.match does not raise TypeError.
            reply_text = response.text or ""
            match = re.match(r'(?:Answer:\s*)?(Yes|No)\b[.]*\s*(.*)', reply_text, re.DOTALL)

            if match:
                result = match.group(1)  # Captures 'Yes' or 'No'
                cot = match.group(2).strip()
            else:
                # -1 marks a reply that does not start with a Yes/No verdict;
                # the raw text is kept for later inspection.
                result = -1
                cot = reply_text

            output_data = {
                "project": project,
                "commit_message": commit_message,
                "func": func,
                "target": target,
                "cwe": cwe,
                "cve": cve,
                "cve_desc": cve_desc,
                "result": result,
                "cot": cot,
                "fallback": fallback,
            }

            output_f.write(json.dumps(output_data) + "\n")
            # Push each record to the OS right away so an interrupted run
            # keeps the rows it already finished.
            output_f.flush()

            # Pause briefly and increment count
            if delay > 0:
                time.sleep(delay)
            count += 1
            if count % 100 == 0:
                print("Count", count)

    print(f"Processing complete. Results written to {output_file}")
    
# Run configuration for the few-shot experiment. `model_name` is read as a
# global inside run_on_data, so it must be bound before the call; `prompting`
# is used here to name the results file.

# Input JSON file; each sample carries its own few-shot examples.
data_path = "/speed-scratch/ra_mdash/PrimeVul_v0.1/primevul_with_4_shot.json"

prompting = "few_shot"
model_name = "gemini-2.0-flash"

# The results file is named after the model and the prompting mode.
output_file = f"/speed-scratch/ra_mdash/results/prime_vul/{model_name}_{prompting}.jsonl"

run_on_data(data_path, output_file)


Count 100
