To generate the data, run the following cells repeatedly, changing the values of `domain` and `model_temp` each time so that every combination of the two variables is covered.

In [None]:
import datetime
import logging
from dataclasses import dataclass
from typing import Optional

import pandas as pd

from ChatGPT_Data_Generation import OpenAIMisinfoBatchManager, process_batch_output_file
from creds import OPENAI_KEY

@dataclass
class BatchDataContainer:
    """
    A dataclass to hold information about a batch of prompts to be sent to the OpenAI API.

    Only ``df`` is known when a batch is created. The remaining fields stay
    ``None`` until the batch request has been sent; the send loop then fills
    them in from the engine that sent the batch.
    """
    # The slice of prompt rows that make up this batch.
    df: pd.DataFrame
    # Path of the JSONL file that receives this batch's responses.
    output_filepath: Optional[str] = None
    # Path of the JSONL file that receives this batch's error records.
    error_filepath: Optional[str] = None
    # Batch ID reported by the engine after the request was sent.
    batch_id: Optional[str] = None
    # Engine that sent the batch; reused later to retrieve its results.
    misinfo_engine: Optional[OpenAIMisinfoBatchManager] = None


def setup_logger():
    """
    Configure the root logger to write INFO-and-above records both to a fresh
    timestamped log file in the current directory and to the terminal.

    Safe to call repeatedly, e.g. when a notebook cell is re-run. Handlers
    installed by an earlier call are removed and closed first, so each message
    is emitted only once and the previous log file is released.

    Returns:
        logging.Logger: the configured root logger.
    """
    # Marker attribute identifying handlers this function installed, so a
    # re-run removes only its own handlers and leaves any others in place.
    owner_tag = "_installed_by_setup_logger"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Without this, every call stacks another file+stream handler pair on the
    # root logger and leaks an open file handle per call.
    for old_handler in [h for h in logger.handlers if getattr(h, owner_tag, False)]:
        logger.removeHandler(old_handler)
        old_handler.close()

    # A new log file per call, named after the current local time.
    log_filename = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S.log")
    file_handler = logging.FileHandler(log_filename)
    stream_handler = logging.StreamHandler()

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (file_handler, stream_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        setattr(handler, owner_tag, True)
        logger.addHandler(handler)

    logger.info("Logger setup complete.")
    return logger

setup_logger()

prompts_csv_path = "data/prompts.csv"
domain = "paraphrase"  # "rewrite" or "paraphrase" or "open_ended"
model_temp = 1.4
batch_size = 50
output_filepath_template = "{}_misinfo_responses_{}.jsonl"
error_filepath_template = "{}_misinfo_errors_{}.jsonl"
results_filename_template = "model_{}___temp_{}___domain_{}___{}.csv"

# Validate the configuration before touching the filesystem, so a typo in
# `domain` fails fast instead of after reading and filtering the CSV.
if domain not in ["rewrite", "paraphrase", "open_ended"]:
    raise ValueError("Domain must be one of 'rewrite', 'paraphrase', or 'open_ended'.")
if batch_size < 1:
    raise ValueError("batch_size must be a positive integer.")

# Read from the configured path rather than repeating the literal.
raw_csv_df = pd.read_csv(prompts_csv_path)
full_prompts_df = raw_csv_df[raw_csv_df["type"] == domain]
if full_prompts_df.empty:
    logging.warning("No prompts of type %r found in %s; no batches will be sent.",
                    domain, prompts_csv_path)

# Break up the prompts into batches of at most batch_size rows
# (explicitly positional, since the filtered index has gaps).
batch_list = [
    BatchDataContainer(df=full_prompts_df.iloc[start:start + batch_size])
    for start in range(0, len(full_prompts_df), batch_size)
]

for i, batch_data in enumerate(batch_list):
    batch_df = batch_data.df
    misinfo_engine = OpenAIMisinfoBatchManager(temp=model_temp,
                                               top_p=.9,
                                               api_key=OPENAI_KEY)
    misinfo_engine.send_batch_misinfo_request(batch_df)
    output_filepath = output_filepath_template.format(domain, misinfo_engine.batch_id)
    error_filepath = error_filepath_template.format(domain, misinfo_engine.batch_id)

    # Update the dataclass instance with the new information
    batch_data.output_filepath = output_filepath
    batch_data.error_filepath = error_filepath
    batch_data.batch_id = misinfo_engine.batch_id
    batch_data.misinfo_engine = misinfo_engine

    logging.info("Sent batch %d/%d (batch_id=%s).", i + 1, len(batch_list), misinfo_engine.batch_id)
    print("Misinformation generation request sent.")


In [None]:
for i, batch_data in enumerate(batch_list):
    misinfo_engine = batch_data.misinfo_engine

    # A batch whose request was never sent (for example because the send cell
    # raised part-way through its loop) still has misinfo_engine=None; skip it
    # instead of failing with an AttributeError on None.
    if misinfo_engine is None:
        print(f"Batch {i} was never sent; skipping.")
        continue

    # Use the dataclass fields for output_filepath and error_filepath
    misinfo_engine.retrieve_batch_results(output_filepath=batch_data.output_filepath,
                                          error_filepath=batch_data.error_filepath,
                                          max_wait_time=24 * 60 * 60,  # Wait for up to 24 hours
                                          status_check_interval=.5 * 60)  # Check status every 30 seconds

    print("Misinformation generation request completed.")

    # Process the output file using the path stored in the dataclass
    output_df = process_batch_output_file(file_path=batch_data.output_filepath)

    # Label the results with the domain this batch was built from: every row of
    # the batch was selected by the same "type" value. Don't use the global
    # `domain`, which can be reassigned between cells.
    batch_domain = batch_data.df["type"].iloc[0]

    # NOTE(review): this writes to the current working directory, while the
    # consolidation cell scans the 'data' directory; confirm the files are moved.
    output_df.to_csv(results_filename_template.format(misinfo_engine.model,
                                                      misinfo_engine.temp,
                                                      batch_domain,
                                                      misinfo_engine.batch_id), index=False)


Run the next cell once for each model temperature, setting `target_temp` to each value used for `model_temp` (e.g. 0, 0.7, 1.4), to consolidate the generated data into a single file.

In [None]:
import os
import pandas as pd

import math

# Directory containing the per-batch results CSV files
directory = 'data'
data_master_path = 'data/data_master.csv'
# The temperature value you're filtering by (e.g., 0, 0.7, 1.4)
target_temp = 0  # Set the temp you want to filter by here
# Column holding the generated response in each results file.
model_column = 'gpt-4o-2024-05-13'


def _parse_results_filename(filename):
    """Split a results filename into (model, temp, domain), or return None.

    Expected shape (see results_filename_template):
    model_<model>___temp_<temp>___domain_<domain>___<batch_id>.csv
    Only the first '_' of each part is treated as the separator, so values
    such as the domain 'open_ended' are kept intact.
    """
    if not filename.endswith('.csv'):
        return None
    parts = filename.split('___')
    if len(parts) < 4:
        return None
    values = []
    for part, prefix in zip(parts[:3], ('model_', 'temp_', 'domain_')):
        if not part.startswith(prefix):
            return None
        values.append(part[len(prefix):])
    model, temp_text, file_domain = values
    try:
        temp = float(temp_text)
    except ValueError:
        return None
    return model, temp, file_domain


# Load the master data file
master_df = pd.read_csv(data_master_path)

# Sorted so the result does not depend on the filesystem's listing order.
for filename in sorted(os.listdir(directory)):
    parsed = _parse_results_filename(filename)
    if parsed is None:
        continue
    _, temp_value, _ = parsed

    # Check if the temp matches the target temp
    if not math.isclose(temp_value, target_temp):
        continue

    data_df = pd.read_csv(os.path.join(directory, filename))

    # One response per hash. Keep the last occurrence if a file repeats a
    # hash, so the lookup cannot duplicate master rows (a left merge would).
    responses = (data_df.drop_duplicates(subset='hash', keep='last')
                        .set_index('hash')[model_column])
    new_values = master_df['hash'].map(responses)

    column_name = f'{model_column}_temp_{temp_value}'
    # Each results file covers only one domain/batch, so map() yields NaN for
    # every other row. Fill into the existing column instead of replacing it;
    # otherwise each later file at this temperature wipes out earlier files' data.
    if column_name in master_df.columns:
        master_df[column_name] = new_values.combine_first(master_df[column_name])
    else:
        master_df[column_name] = new_values

# Save the updated master file
master_df.to_csv(data_master_path, index=False)

print(f"{data_master_path} updated successfully.")

