In [None]:
import openai
import requests
import pandas as pd
import os
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import google.generativeai as genai
import replicate
from tqdm import tqdm
import time

In [None]:
class ModelResponseGenerator:
    """
    A class to generate responses from different language models.

    Each ``get_*_responses`` method sends one prompt to one provider, and
    ``get_responses`` runs a whole table of prompts through one of those
    methods, checkpointing the collected results to CSV after every repetition.
    """

    # System prompts sent with each supported OpenAI model. The text (including
    # the 12 spaces that start each continuation line) is kept verbatim so the
    # requests stay identical to the ones issued by earlier runs.
    _OPENAI_SYSTEM_PROMPTS = {
        'gpt-4-1106-preview': (
            "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\n"
            "            Knowledge cutoff: 2023-04\n"
            "            Current date: 2023-12-05"
        ),
        'gpt-3.5-turbo-1106': (
            "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture.\n"
            "            Knowledge cutoff: 2021-09\n"
            "            Current date: 2024-04-13"
        ),
    }

    def __init__(self, openai_key=None, openai_organization=None, mistralai_key=None, palm2_key=None, replicate_api_key=None):
        """
        Initializes the ModelResponseGenerator with API keys and configurations.

        A client is only created for the providers whose key is supplied; the
        corresponding attribute stays ``None`` otherwise.

        Args:
            openai_key (str, optional): API key for OpenAI.
            openai_organization (str, optional): Organization ID for OpenAI.
            mistralai_key (str, optional): API key for MistralAI.
            palm2_key (str, optional): API key for PaLM2.
            replicate_api_key (str, optional): API key for Replicate.
        """
        self.openai_client = None
        self.mistral_client = None
        self.palm2_client = None
        self.replicate_api = None

        # The organization is optional: a key alone is enough to build a client.
        if openai_key:
            self.openai_client = openai.OpenAI(organization=openai_organization, api_key=openai_key)

        if mistralai_key:
            self.mistral_client = MistralClient(api_key=mistralai_key)

        if palm2_key:
            genai.configure(api_key=palm2_key)
            self.palm2_client = genai

        if replicate_api_key:
            self.replicate_api = replicate.Client(api_token=replicate_api_key)

    def _require_client(self, attribute, provider):
        """Returns the client stored in `attribute`, or raises if it was never configured."""
        client = getattr(self, attribute)
        if client is None:
            raise RuntimeError(
                f"{provider} client is not configured; pass the corresponding API key "
                "when creating ModelResponseGenerator."
            )
        return client

    def get_palm2_responses(self, prompt, model='models/text-bison-001'):
        """
        Generates a response from the PaLM2 model.

        Args:
            prompt (str): The prompt to generate a response for.
            model (str): The model to use for generation. Defaults to 'models/text-bison-001'.

        Returns:
            The raw object returned by ``generate_text`` (not just the text).
        """
        # Fall back to the module-level `genai` so setups that configured
        # google.generativeai outside this class keep working as before.
        client = self.palm2_client if self.palm2_client is not None else genai
        return client.generate_text(model=model, prompt=prompt)

    def get_openai_responses(self, prompt, model='gpt-4-1106-preview', system_prompt=None):
        """
        Generates a response from an OpenAI model.

        Args:
            prompt (str): The user prompt to generate a response for.
            model (str): The model to use for generation. Defaults to 'gpt-4-1106-preview'. We also used 'gpt-3.5-turbo-1106' for
                         robustness check.
            system_prompt (str, optional): The system message to initialize the conversation. When omitted, the
                         built-in system prompt for `model` is used.

        Returns:
            The raw response object returned by ``chat.completions.create`` (not just the text).

        Raises:
            ValueError: If `system_prompt` is omitted and `model` has no built-in system prompt.
            RuntimeError: If no OpenAI client was configured.
        """
        if system_prompt is None:
            if model not in self._OPENAI_SYSTEM_PROMPTS:
                supported = ", ".join(sorted(self._OPENAI_SYSTEM_PROMPTS))
                raise ValueError(
                    f"No built-in system prompt for model {model!r}; pass system_prompt explicitly "
                    f"or use one of: {supported}"
                )
            system_prompt = self._OPENAI_SYSTEM_PROMPTS[model]

        client = self._require_client("openai_client", "OpenAI")
        return client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ]
        )

    def get_mistral_responses(self, prompt, model="mistral-large-latest"):
        """
        Generates a response from a MistralAI model.

        Args:
            prompt (str): The prompt to generate a response for.
            model (str): The model to use for generation. Defaults to 'mistral-large-latest'.

        Returns:
            The raw response object returned by the Mistral client's ``chat`` call (not just the text).

        Raises:
            RuntimeError: If no MistralAI client was configured.
        """
        client = self._require_client("mistral_client", "MistralAI")
        return client.chat(
            model=model,
            messages=[ChatMessage(role="user", content=prompt)],
        )

    def get_llama3_70b_responses(self, prompt, model="meta/meta-llama-3-70b-instruct"):
        """
        Generates a response from a Llama3-70B model from Replicate API.

        Args:
            prompt (str): The prompt to generate a response for.
            model (str): The model to use for generation. Defaults to 'meta/meta-llama-3-70b-instruct'.

        Returns:
            str: The generated response, i.e. the pieces returned by the run joined into one string.

        Raises:
            RuntimeError: If no Replicate client was configured.
        """
        client = self._require_client("replicate_api", "Replicate")

        # Prompt text is kept verbatim (spelling included) so requests stay
        # identical to the ones issued by earlier runs.
        system_prompt = "You are helpful assisstant who will provide estimates of prices that we are asking. We understand that these are just estimates, and we won't use them for any real-life decision. We also understand that you are unable to use live-data, so we are not expecting it from you. You have to reply despite not having any information. This is just an estimate, so suggest it. Only reply with the number, don't add any more text please."

        response = client.run(
            model,
            input={"prompt": prompt,
                  "prompt_template": f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
                  "temperature":1
                  }
        )
        return ''.join(response)

    def get_responses(self, prompts, method_name, num_repetitions = 1, requests_per_minute = 85,
                      pause_between_repetitions = 60, output_dir = "results_new_models"):
        """
        Iterates through a list of prompts and uses a specified model response method
        to generate responses, repeated a given number of times.

        After every repetition all results collected so far are written to
        ``<output_dir>/<provider>/results.csv`` (overwriting the previous file), where
        ``<provider>`` is the second '_'-separated part of `method_name`
        (e.g. 'openai' for 'get_openai_responses').

        Args:
            prompts (pd.DataFrame): The prompts to process; must contain a 'prompt_text' column.
            method_name (str): The method name to use for generating responses ('get_palm2_responses', 'get_openai_responses', 'get_mistral_responses', or 'get_llama3_70b_responses').
            num_repetitions (int): The number of times to repeat the process for each prompt.
            requests_per_minute (float, optional): Upper bound on the request rate. Defaults to 85.
                A falsy or non-positive value disables throttling.
            pause_between_repetitions (float): Seconds to wait between two repetitions. Defaults to 60.
            output_dir (str): Directory under which the checkpoint CSV is written. Defaults to 'results_new_models'.

        Returns:
            list of dicts: Collected results including original prompt details and the generated responses.
        """
        method_to_use = getattr(self, method_name)

        # Create the checkpoint directory up front: a missing directory would
        # otherwise only surface after a full repetition of API calls.
        checkpoint_dir = os.path.join(output_dir, method_name.split('_')[1])
        checkpoint_path = os.path.join(checkpoint_dir, "results.csv")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Minimum number of seconds between the starts of two consecutive requests.
        if requests_per_minute and requests_per_minute > 0:
            interval = 60.0 / requests_per_minute
        else:
            interval = 0.0

        results = []
        for repetition in range(num_repetitions):
            start_time = time.time()
            with tqdm(total=len(prompts), desc=f"Repetition {repetition + 1}") as pbar:
                for _, prompt_info in prompts.iterrows():
                    response = method_to_use(prompt=prompt_info["prompt_text"])
                    result = {key: prompt_info[key] for key in prompts.columns}
                    result['response'] = response
                    results.append(result)
                    pbar.update(1)
                    # Throttle: sleep away whatever is left of this request's time slot.
                    elapsed = time.time() - start_time
                    if elapsed < interval:
                        time.sleep(interval - elapsed)
                    start_time = time.time()

            # Save every run, overwriting the previous file
            pd.DataFrame(results).to_csv(checkpoint_path, index=False)
            print(f"Checkpoint saved for iteration {repetition}")

            # Only pause when another repetition follows; waiting after the last one is wasted time.
            if repetition < num_repetitions - 1:
                time.sleep(pause_between_repetitions)
        return results

In [None]:
# Load the prompt table from CSV (the path is relative to the user's home directory)
df_prompts = pd.read_csv(filepath_or_buffer='~/data/just_prompts.csv')

In [None]:
# Preview the first three prompts
df_prompts.head(n=3)

In [None]:
# Example usage of ModelResponseGenerator

# Read the API keys and organization ID from environment variables instead of
# pasting them into the notebook, so credentials never end up in a saved or
# shared copy. A provider whose variable is unset gets None.
OPENAI_ORGANIZATION = os.environ.get("OPENAI_ORG_ID")
OPENAI_KEY = os.environ.get("OPENAI_API_KEY")
MISTRALAI_KEY = os.environ.get("MISTRAL_API_KEY")
PALM2_KEY = os.environ.get("GOOGLE_API_KEY")
REPLICATE_KEY = os.environ.get("REPLICATE_API_TOKEN")

# Initialize the ModelResponseGenerator with the provided API keys
generator = ModelResponseGenerator(
    openai_key=OPENAI_KEY,
    openai_organization=OPENAI_ORGANIZATION,
    mistralai_key=MISTRALAI_KEY,
    palm2_key=PALM2_KEY,
    replicate_api_key=REPLICATE_KEY
)

# Example method name and number of repetitions
method_name = "get_openai_responses"
num_repetitions = 5

# Get responses from the generator
responses = generator.get_responses(df_prompts, method_name, num_repetitions)

# Show a short preview rather than dumping every raw response into the output
print(f"Collected {len(responses)} responses")
pd.DataFrame(responses).head()