Default frame for running CDM tests

In [None]:
from gltr_ppl.cdm import GLTRPPLCodeDetector
from metrics import MetricsEvaluator
import csv
import time
# import requests
import re

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class CsvProcessor:
    """Run a detector over a window of rows of an input CSV and write the results.

    For every input row whose zero-based position lies in
    ``[checkpoint, checkpoint + limit)`` the processor copies a fixed set of
    columns, appends the "human" columns (either read from a previously
    produced results file or predicted on the fly) and the "GPT" columns
    (always predicted), and writes the combined row to the output CSV.
    """

    # Columns copied verbatim from each input row into the output row.
    # NOTE(review): the original code hinted that some inputs ("goodanswer")
    # expose the first column as "\ufeffindex" (BOM-prefixed) — confirm before
    # feeding such files, since "index" is looked up literally here.
    _PASSTHROUGH_COLUMNS = (
        "index",
        "Source Name",
        "local index",
        "Problem",
        "Python Code",
        "GPT Answer",
        "variant",
    )
    # Total number of attempts made per row before giving up on it.
    _MAX_ATTEMPTS = 2
    # Pause between a failed attempt and the next one, in seconds.
    _RETRY_DELAY_SECONDS = 0.25

    def __init__(self, input_file, limit=100000):
        """Remember the input CSV path and the maximum number of rows per run.

        Args:
            input_file: Path of the CSV file to read rows from.
            limit: Maximum number of rows processed by one ``process_csv``
                call, counted from that call's ``checkpoint``.
        """
        self.input_file = input_file
        self.limit = limit

    def process_csv(
        self,
        api_client,
        human_result_file: str,
        output_file: str,
        gpt_headers: list,
        checkpoint: int,
        run_human_test: bool = False,
        human_headers: "list | None" = None,
    ):
        """Process rows ``checkpoint .. checkpoint + limit - 1`` into ``output_file``.

        Args:
            api_client: Object exposing ``text_predict_tuple(text)``; its return
                value is indexed positionally, one entry per header.
            human_result_file: CSV holding previously computed human-code
                results, read row-by-row in lockstep with the input file. Only
                opened when ``run_human_test`` is False.
            output_file: Path of the CSV to create (overwritten if present).
            gpt_headers: Output column names for the "GPT Answer" predictions.
            checkpoint: Zero-based position of the first row to process.
            run_human_test: When True, ``api_client`` is also asked to predict
                the "Python Code" column instead of reading the human columns
                from ``human_result_file``.
            human_headers: Output column names for the human results; defaults
                to no human columns.
        """
        # A fresh list per call avoids the shared-mutable-default pitfall.
        human_headers = [] if human_headers is None else list(human_headers)
        gpt_headers = list(gpt_headers)
        # Keep the window bound local: writing it back to self.limit would make
        # every later call on the same processor use an ever larger window.
        stop = checkpoint + self.limit

        with open(self.input_file, "r", encoding="utf-8") as csv_input_file:
            input_file_reader = csv.DictReader(csv_input_file)
            if run_human_test:
                # No stored human results are needed, so the file is not opened.
                row_pairs = ((row, None) for row in input_file_reader)
                self._write_results(
                    api_client,
                    input_file_reader,
                    row_pairs,
                    output_file,
                    human_headers,
                    gpt_headers,
                    checkpoint,
                    stop,
                )
            else:
                with open(
                    human_result_file, "r", encoding="utf-8"
                ) as csv_human_file:
                    row_pairs = zip(
                        input_file_reader, csv.DictReader(csv_human_file)
                    )
                    self._write_results(
                        api_client,
                        input_file_reader,
                        row_pairs,
                        output_file,
                        human_headers,
                        gpt_headers,
                        checkpoint,
                        stop,
                    )
        print("------DONE------")

    def _write_results(
        self,
        api_client,
        input_file_reader,
        row_pairs,
        output_file,
        human_headers,
        gpt_headers,
        checkpoint,
        stop,
    ):
        """Write the header and one output row per in-window ``(input, human)`` pair.

        ``row_pairs`` yields ``(input_row, human_res_row)``; ``human_res_row``
        is None when the human columns must be predicted by ``api_client``.
        """
        with open(
            output_file, "w", newline="", encoding="utf-8"
        ) as csv_output_file:
            # fieldnames is None for an empty input file; treat it as no columns.
            csvheaders = (
                list(input_file_reader.fieldnames or [])
                + human_headers
                + gpt_headers
            )
            writer = csv.DictWriter(csv_output_file, fieldnames=csvheaders)
            writer.writeheader()
            print("WORKING ON SOLUTION")
            for index, (input_row, human_res_row) in enumerate(row_pairs):
                if index >= stop:
                    break
                if index < checkpoint:
                    continue
                print(f"checkpoint: {index} ----- {input_row['index']}")

                predict_human = human_res_row is None
                columns = ["GPT Answer"]
                blanks = [[""] * len(gpt_headers)]
                if predict_human:
                    # Human code is predicted first, then the GPT answer.
                    columns.insert(0, "Python Code")
                    blanks.insert(0, [""] * len(human_headers))
                results = self._predict_with_retry(
                    api_client, input_row, columns, blanks
                )
                gpt_result = results[-1]

                new_row = {
                    column: input_row[column]
                    for column in self._PASSTHROUGH_COLUMNS
                }
                for idx, header in enumerate(human_headers):
                    if predict_human:
                        new_row[header] = results[0][idx]
                    else:
                        new_row[header] = human_res_row[header]
                for idx, header in enumerate(gpt_headers):
                    new_row[header] = gpt_result[idx]
                writer.writerow(new_row)

    def _predict_with_retry(self, api_client, row, columns, blanks):
        """Predict every column in ``columns`` for ``row``, retrying on any error.

        Returns one result per column. A column keeps its entry from ``blanks``
        if no attempt ever produced a value for it; a value obtained before a
        later column failed is kept.
        """
        results = list(blanks)
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                for position, column in enumerate(columns):
                    # NOTE(review): str(bytes) yields the "b'...'" repr, not the
                    # decoded text. Kept as-is so new predictions stay
                    # comparable with results produced earlier by this code.
                    payload = str(row[column].encode("utf-8"))
                    results[position] = api_client.text_predict_tuple(payload)
                return results
            except Exception as e:  # best effort: a bad row must not stop the run
                print(f"Error calling API: {e}. Retrying...")
                if attempt >= self._MAX_ATTEMPTS:
                    print(
                        "Max retries reached for row ",
                        row["index"],
                        ". Skipping...",
                    )
                else:
                    # Only pause when another attempt actually follows.
                    time.sleep(self._RETRY_DELAY_SECONDS)
        return results

# Global config for all CDMs

In [4]:
# Prompt variant under evaluation; it selects the input CSV below.
variant = 10
input_file = f"variant_{variant}_full.csv"

# Zero-based row position to resume from (0 starts at the first row).
last_checkpoint = 0

# One processor instance is shared by every CDM run in this notebook.
csv_processor = CsvProcessor(input_file=input_file)

### CDM execution

In [None]:
# Configure one CDM run: detector name, stored human results, output path,
# the columns to append, the resume checkpoint, and the detector itself.
cdm_name = "gltr_ppl"
human_result_file = "gltr_ppl_human.csv"
output_file = f"{cdm_name}_variants_prompt_{variant}.csv"

# Columns appended to each output row: human-code results, then GPT-answer results.
new_human_header = ["GLTR_answer_human_binary", "PPL_answer_human_binary"]
new_gpt_header = ["GLTR_answer_GPT_binary", "PPL_answer_GPT_binary"]

last_checkpoint = 0
api_client = GLTRPPLCodeDetector()
# False: take the human columns from human_result_file instead of re-predicting.
run_human_test = False

csv_processor.process_csv(
    api_client=api_client,
    human_result_file=human_result_file,
    output_file=output_file,
    gpt_headers=new_gpt_header,
    checkpoint=last_checkpoint,
    run_human_test=run_human_test,
    human_headers=new_human_header,
)
# # Evaluate Metrics
# evaluator = MetricsEvaluator(output_file)
# evaluator.calculate(cdm_name,f"{cdm_name}_results.csv")