In [3]:
!pip install -q openai networkx
!pip install matplotlib



In [4]:
from dotenv import load_dotenv
import os

In [5]:
# Read variables from a local .env file, then pick up the OpenAI key from the
# process environment.
load_dotenv(override=True)
api_key = os.getenv('OPENAI_API_KEY')

# The whitespace test runs before the prefix test: a key with a leading space
# or tab would otherwise fail startswith("sk-proj-") first and be reported
# with the less helpful "doesn't start sk-proj-" message.
if not api_key:
    print("No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!")
elif api_key.strip() != api_key:
    print("An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook")
elif not api_key.startswith("sk-proj-"):
    print("An API key was found, but it doesn't start sk-proj-; please check you're using the right key - see troubleshooting notebook")
else:
    print("API key found and looks good so far!")

API key found and looks good so far!


In [6]:
from openai import OpenAI
# Build the client from the key loaded above and send one trivial question to
# confirm that requests go through before doing any real work.
openai = OpenAI(api_key=api_key)
response = openai.chat.completions.create(
    messages=[{"role": "user", "content": "What is 2+2?"}],
    model="gpt-4o-mini",
)
print(response.choices[0].message.content)

2 + 2 equals 4.


In [7]:
import networkx as nx
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import json
import math

In [13]:
class TokenPredictor:
    """Stream a chat completion and record the probability of every token.

    The client is expected to expose ``chat.completions.create(...)`` and to
    return an iterable of streamed chunks when called with ``stream=True``.
    """

    # Number of candidate tokens requested from the API for each position.
    TOP_LOGPROBS = 7
    # Number of non-chosen candidates kept in each prediction.
    MAX_ALTERNATIVES = 3
    # Fixed seed sent with every request.
    SEED = 42

    def __init__(self, client, model_name: str, temperature: float):
        """
        Args:
            client: Object exposing ``chat.completions.create``.
            model_name: Model identifier forwarded to the API.
            temperature: Sampling temperature forwarded to the API.
        """
        self.client = client
        # Initialised empty; predict_tokens does not read or write these.
        self.messages = []
        self.predictions = []
        self.model_name = model_name
        self.temperature = temperature

    def predict_tokens(self, prompt: str, max_tokens: int = 100) -> List[Dict]:
        """
        Generate text token by token and track prediction probabilities.

        Args:
            prompt: Sent as a single user message.
            max_tokens: Upper bound on generated tokens, forwarded to the API.

        Returns:
            One dict per generated token, in generation order, with keys:
              'token'        - the token that was emitted,
              'probability'  - exp(logprob) of that token,
              'alternatives' - up to MAX_ALTERNATIVES (token, probability)
                               tuples for the other candidates, most likely
                               first.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=self.temperature,
            logprobs=True,
            seed=self.SEED,
            top_logprobs=self.TOP_LOGPROBS,
            stream=True
        )

        predictions = []
        for chunk in response:
            # Some stream chunks carry no choices at all; indexing [0] on
            # them would raise IndexError.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            if not choice.delta.content:
                continue
            # logprobs may be missing (None) or empty on a chunk.
            entries = getattr(choice.logprobs, "content", None)
            if not entries:
                continue
            # Walk every entry so that a chunk carrying more than one token
            # does not silently lose the extra tokens.
            for entry in entries:
                predictions.append(self._build_prediction(entry))

        return predictions

    def _build_prediction(self, entry) -> Dict:
        """Turn one streamed logprob entry into a prediction dict."""
        candidates = {item.token: item.logprob for item in (entry.top_logprobs or [])}

        alternatives = [
            (alt_token, math.exp(alt_logprob))
            for alt_token, alt_logprob in candidates.items()
            if alt_token != entry.token
        ]
        alternatives.sort(key=lambda pair: pair[1], reverse=True)

        return {
            'token': entry.token,
            # Use the entry's own logprob: the emitted token is not guaranteed
            # to appear among the top candidates, so looking it up there can
            # raise KeyError.
            'probability': math.exp(entry.logprob),
            'alternatives': alternatives[:self.MAX_ALTERNATIVES],
        }

In [14]:
# Settings for this run: which model to query and its sampling temperature.
model_name = "gpt-4o-mini"
temperature = 0.0

# Text sent to the model as the user message.
prompt = "I feel lonely. Reply in 1 short sentence."

predictor = TokenPredictor(client=openai, model_name=model_name, temperature=temperature)
predictions = predictor.predict_tokens(prompt=prompt)

In [15]:
from pprint import pprint

In [16]:
# Pretty-print the full list of per-token predictions gathered above.
pprint(predictions)

[{'alternatives': [("You're", 0.21435394506452804),
                   ("It's", 0.16693902027069935),
                   ('I', 0.015527745848756037)],
  'probability': 0.582674433727415,
  'token': "I'm"},
 {'alternatives': [(' sorry', 0.07577578028593854),
                   (' really', 0.0010808834546355618),
                   (' truly', 3.898272562976294e-06)],
  'probability': 0.9231378962881625,
  'token': ' here'},
 {'alternatives': [(' to', 0.06007491717047663),
                   (' if', 0.0001687379981244691),
                   (',', 1.3850858343846484e-05)],
  'probability': 0.9397297377540044,
  'token': ' for'},
 {'alternatives': [(' support', 3.398267819495071e-09),
                   ('你', 2.335593038799337e-09),
                   ('you', 2.061153622438558e-09)],
  'probability': 1.0,
  'token': ' you'},
 {'alternatives': [('—', 0.38699359425390156),
                   (',', 0.11087553432530178),
                   (' if', 0.001792136099027015)],
  'probability': 0.496