In [50]:
import base64
import time
from typing import Any

import openai
from openai import OpenAI
import numpy as np

# Generic system prompt for plain API-style conversations.
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."

# ChatGPT-style system prompt: an identity line followed by the knowledge-cutoff
# and current-date lines, separated by newlines.
OPENAI_SYSTEM_MESSAGE_CHATGPT = "\n".join(
    [
        "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.",
        "Knowledge cutoff: 2023-12",
        "Current date: 2024-04-01",
    ]
)


class ChatCompletionSampler():
    """
    Sample from OpenAI's chat completion API
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        base_url=None,
        api_key=None,
        logprobs = False
    ):
        self.api_key_name = "OPENAI_API_KEY"
        if base_url and any(provider in base_url for provider in ["google", "databricks", "together"]):
            self.client = OpenAI(base_url=base_url, api_key=api_key)
        else:
            self.client = OpenAI()
        # using api_key=os.environ.get("OPENAI_API_KEY")  # please set your API_KEY
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"
        self.logprobs = logprobs
        self.top_logprobs = None

    def _handle_image(
        self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list) -> str:
        if self.system_message:
            message_list = [self._pack_message("system", self.system_message)] + message_list
        trial = 0
        while True:
            try:
                if self.logprobs:
                    try:
                        response = self.client.chat.completions.create(
                            model=self.model,
                            messages=message_list,
                            temperature=self.temperature,
                            max_tokens=self.max_tokens,
                            logprobs=self.logprobs,
                            top_logprobs=10
                        )
                    except:
                        response = self.client.chat.completions.create(
                            model=self.model,
                            messages=message_list,
                            temperature=self.temperature,
                            max_tokens=self.max_tokens,
                            logprobs=10
                        )
                    try:
                        self.top_logprobs = [t.top_logprobs for t in response.choices[0].logprobs.content]
                        return response.choices[0].message.content, float(np.exp(np.array([t.logprob for t in response.choices[0].logprobs.content])).mean()), [t.logprob for t in response.choices[0].logprobs.content]
                    except:
                        return response.choices[0].message.content, float(np.exp(response.choices[0].logprobs.token_logprobs).mean()), response.choices[0].logprobs.token_logprobs
                else:
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=message_list,
                        temperature=self.temperature,
                        max_tokens=self.max_tokens
                    )
                    return response.choices[0].message.content, None, None
                
            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = min(2**trial, 60)  # expontial back off
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception


In [51]:
# Build the sampler with token log-probabilities switched on from the start.
c = ChatCompletionSampler(logprobs=True)

In [52]:
# Send a single user message. With logprobs enabled, a successful call returns
# (reply text, mean per-token probability, per-token logprobs).
r = c([{"role": "user", "content": "hi"}])

In [63]:
# Alternative tokens recorded for each generated position by the call above.
c.top_logprobs

[[TopLogprob(token='Hello', bytes=[72, 101, 108, 108, 111], logprob=-0.000225947),
  TopLogprob(token='Hi', bytes=[72, 105], logprob=-8.637153),
  TopLogprob(token=' Hello', bytes=[32, 72, 101, 108, 108, 111], logprob=-10.439324),
  TopLogprob(token='Hey', bytes=[72, 101, 121], logprob=-11.673375),
  TopLogprob(token='Good', bytes=[71, 111, 111, 100], logprob=-12.9118395),
  TopLogprob(token='\n', bytes=[10], logprob=-12.924041),
  TopLogprob(token='hello', bytes=[104, 101, 108, 108, 111], logprob=-12.944734),
  TopLogprob(token='Greetings', bytes=[71, 114, 101, 101, 116, 105, 110, 103, 115], logprob=-13.420112),
  TopLogprob(token='\n\n', bytes=[10, 10], logprob=-14.000659),
  TopLogprob(token=' \n\n', bytes=[32, 10, 10], logprob=-15.666173)],
 [TopLogprob(token='!', bytes=[33], logprob=-0.0001726703),
  TopLogprob(token=',', bytes=[44], logprob=-9.273631),
  TopLogprob(token=' there', bytes=[32, 116, 104, 101, 114, 101], logprob=-9.48706),
  TopLogprob(token='.', bytes=[46], logprob=