# In [12]:
from typing import List, Optional, TypedDict
import time
from dotenv import load_dotenv
import openai
import os
from backoff import on_exception, expo
from openai.error import RateLimitError, APIError, InvalidRequestError
import pandas as pd
from transformers import GPT2TokenizerFast

# GPT-2 tokenizer shared by the module; used for token counting.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Load variables from a .env file *before* reading the key, otherwise a key
# that is only defined in .env would be read as None.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")


class Message(TypedDict):
    """OpenAI chat message: a mapping holding a role and the message content.

    Instances are plain dicts at runtime; the fields below only describe the
    expected keys for type checkers.
    """

    # Role label of the message (e.g. "system", as used when building the
    # chat context in this file).
    role: str
    # Text body of the message.
    content: str


class OpenAI:
    """Helper around the OpenAI chat-completion API with retry handling.

    Stores the memory text and message history used to build a chat context,
    and offers helpers to create messages, count tokens and request a
    completion.
    """

    def __init__(self, relevant_memory, full_message_history):
        """Store the inputs and initialise the bookkeeping attributes.

        Args:
            relevant_memory: Value interpolated into a system message by
                ``generate_chat_context``.
            full_message_history: Sized collection of earlier messages; only
                its length is read by ``generate_chat_context``.
        """
        self.relevant_memory = relevant_memory
        self.full_message_history = full_message_history
        self.next_message_to_add_index = 0
        self.current_tokens_used = 0
        self.insertion_index = 0
        self.current_context = []
        self.response = None

    # NOTE(review): an invalid request usually fails the same way when it is
    # resent; confirm that retrying on InvalidRequestError is intended.
    @on_exception(expo, InvalidRequestError, max_tries=3)
    def get_openai_response(self, messages, model, temperature, max_tokens):
        """Send a single ChatCompletion request and return the raw response.

        The ``on_exception`` decorator re-invokes this method with
        exponential waits, for at most three tries in total, whenever
        ``InvalidRequestError`` is raised.

        Args:
            messages: Chat messages forwarded as ``messages``.
            model: Model identifier forwarded as ``model``.
            temperature: Sampling temperature chosen by the caller.
            max_tokens: Completion token cap forwarded as ``max_tokens``.

        Returns:
            The object returned by ``openai.ChatCompletion.create``.
        """
        return openai.ChatCompletion.create(
            messages=messages,
            temperature=temperature,
            model=model,
            max_tokens=max_tokens,
        )

    def create_chat_completion(
        self,
        messages: List[Message],  # type: ignore
        model: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Create a chat completion using the OpenAI API

        Up to three attempts are made. A ``RateLimitError`` or an ``APIError``
        with HTTP status 502 triggers a wait of 4 s, then 8 s, before the next
        attempt; no wait happens after the last attempt.

        Args:
            messages (List[Message]): The messages to send to the chat completion
            model (str, optional): The model to use. Defaults to None.
            temperature (float, optional): The temperature to use. Defaults to 0.0.
            max_tokens (int, optional): The max tokens to use. Defaults to None.

        Returns:
            str: The response from the chat completion

        Raises:
            APIError: If the status is not 502, or a 502 occurs on the last
                attempt.
            RuntimeError: If every attempt was rate limited; the last
                ``RateLimitError`` is attached as the cause.
        """
        num_retries = 3
        response = None
        last_error = None
        for attempt in range(num_retries):
            is_last_attempt = attempt == num_retries - 1
            try:
                response = self.get_openai_response(
                    messages, model, temperature, max_tokens)
                break
            except RateLimitError as e:
                print(e)
                last_error = e
            except APIError as e:
                # Only a 502 is retried, and only while attempts remain.
                if e.http_status != 502 or is_last_attempt:
                    raise
                last_error = e
            # Sleeping after the final attempt would only delay the failure.
            if not is_last_attempt:
                time.sleep(2 ** (attempt + 2))
        if response is None:
            raise RuntimeError(
                f"Failed to get response after {num_retries} retries"
            ) from last_error
        return response.choices[0].message["content"]

    def generate_chat_context(self, prompt):
        """Build the initial system messages and related bookkeeping values.

        No history messages are added here and no instance attribute is
        modified; the values are only computed and returned.

        Args:
            prompt: Content of the first system message.

        Returns:
            tuple: ``(next_message_to_add_index, current_tokens_used,
            insertion_index, current_context)`` where

            * ``next_message_to_add_index`` is the index of the last entry of
              ``self.full_message_history`` (``-1`` when it is empty),
            * ``current_tokens_used`` is a list with one token count per
              system message,
            * ``insertion_index`` is the number of system messages,
            * ``current_context`` is the list of system messages.
        """
        current_context = [
            self.create_chat_message("system", prompt),
            self.create_chat_message(
                "system", f"The current time and date is {time.strftime('%c')}"
            ),
            self.create_chat_message(
                "system",
                f"This reminds you of these events from your past:\n{self.relevant_memory}\n\n",
            ),
        ]

        # Start from the most recent history entry.
        next_message_to_add_index = len(self.full_message_history) - 1
        # Position right after the system messages.
        insertion_index = len(current_context)
        # NOTE(review): this is a list of per-message counts, not a total,
        # although __init__ initialises the same-named attribute to 0.
        # Kept as a list so existing callers are unaffected; confirm intent.
        current_tokens_used = [
            self.count_tokens(message["content"]) for message in current_context
        ]
        return (
            next_message_to_add_index,
            current_tokens_used,
            insertion_index,
            current_context,
        )

    def create_chat_message(self, role, content) -> Message:
        """Return a message dict with the given role and content."""
        return {"role": role, "content": content}

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a string using the module tokenizer."""
        return len(tokenizer.encode(text))
