In [None]:
# Uncomment to install Phoenix if it is not already available in this kernel.
# %pip install arize-phoenix

In [None]:
!pip install portpicker

In [None]:
from datetime import datetime, timezone

import phoenix as px

In [None]:
from datasets import load_dataset

# Pull a small, reproducible random sample of the `synthetic_convqa`
# configuration's train split into a pandas DataFrame.
path = "nvidia/ChatQA-Training-Data"
name = "synthetic_convqa"
sample_size = 5
df = (
    load_dataset(path, name, split="train")
    .to_pandas()
    .sample(n=sample_size, random_state=42)  # fixed seed -> same rows on every run
)
df

In [None]:
# Suffix the dataset name with a UTC ISO-8601 timestamp so repeated runs upload
# under distinct names. Reuse `path` rather than repeating the literal, so the
# Phoenix dataset name always matches the Hugging Face dataset that was loaded.
dataset_name = f"{path}-{datetime.now(timezone.utc).isoformat()}"
px.Client().upload_dataset(
    dataset_name=dataset_name,
    dataframe=df,
    input_keys=("messages", "document"),
    output_keys=("answers",),
)

In [None]:
# Read the just-uploaded dataset back from Phoenix by name and show its type
# as a quick sanity check.
phoenix_client = px.Client()
ds = phoenix_client.get_dataset(name=dataset_name)
type(ds)

In [None]:
from contextlib import contextmanager
from threading import Thread
from time import sleep, time
from typing import Awaitable, Callable, Generator

from portpicker import pick_unused_port
from starlette.applications import Starlette
from starlette.responses import JSONResponse, Response
from starlette.routing import Request, Route
from uvicorn import Config, Server


async def hello(_: Request) -> Response:
    """Return a fixed payload shaped like an OpenAI ``chat.completion`` response.

    The incoming request is ignored: every call yields the same canned assistant
    greeting, id, timestamp, model name, and token-usage counts.
    """
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?",
        },
        "finish_reason": "stop",
    }
    usage = {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}
    payload = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-3.5-turbo-0125",
        "system_fingerprint": "fp_44709d6fcb",
        "choices": [choice],
        "usage": usage,
    }
    return JSONResponse(content=payload)


class Receiver:
    """Mock chat-completions server: a Starlette app run by uvicorn in a thread.

    The app exposes a single ``POST /v1/chat/completions`` route that delegates
    to the supplied handler.
    """

    def __init__(self, chat_completion: Callable[[Request], Awaitable[Response]]) -> None:
        """Build the Starlette app.

        Args:
            chat_completion: async handler invoked for each POST to
                ``/v1/chat/completions``; it receives the request and returns
                the response to send.
        """
        self.app = Starlette(
            routes=[
                Route("/v1/chat/completions", chat_completion, methods=["POST"]),
            ]
        )

    def install_signal_handlers(self) -> None:
        """No-op hook, kept for backward compatibility.

        NOTE(review): nothing in this class calls it, and it has no effect on the
        uvicorn ``Server`` created in ``run_in_thread`` (that would require
        overriding the method on a ``Server`` subclass).
        """
        pass

    @contextmanager
    def run_in_thread(
        self, port: int, startup_timeout: float = 5.0
    ) -> Generator[Thread, None, None]:
        """Serve the app in a background thread for the duration of a ``with`` block.

        Args:
            port: TCP port for uvicorn to bind.
            startup_timeout: seconds to wait for the server to report it started.

        Yields:
            The thread running the uvicorn server.

        Raises:
            RuntimeError: if the server thread exits before the server has
                started, or the server has not started within ``startup_timeout``.
        """
        from time import monotonic  # immune to wall-clock adjustments, unlike time()

        config = Config(app=self.app, port=port, loop="asyncio", log_level="critical")
        server = Server(config=config)
        # daemon=True so a server that fails to honour should_exit cannot keep the
        # interpreter alive after the bounded join below gives up.
        thread = Thread(target=server.run, daemon=True)
        thread.start()
        deadline = monotonic() + startup_timeout
        try:
            while not server.started and thread.is_alive() and monotonic() < deadline:
                sleep(1e-3)
            # Decide on server.started itself: the previous time-based check let a
            # thread that died early (e.g. port already in use) pass silently, and
            # could report a timeout for a server that started right at the limit.
            if not server.started:
                if not thread.is_alive():
                    raise RuntimeError("server thread exited before startup")
                raise RuntimeError("server took too long to start")
            yield thread
        finally:
            # Always ask uvicorn to stop, even if the with-body raised.
            server.should_exit = True
            thread.join(timeout=5)

In [None]:
import openai
from phoenix.experiments.types import Example

# Pick a local port that is currently unused for the mock server, and point an
# OpenAI client at it. The API key is a placeholder: the mock handler never
# inspects the request.
port = pick_unused_port()
base_url = f"http://localhost:{port}/v1/"
client = openai.OpenAI(base_url=base_url, api_key="sk-")


def task(record: Example):
    """Send one example's chat messages to the completions endpoint.

    Args:
        record: Phoenix dataset example; only ``record.input["messages"]`` is read.

    Returns:
        The message content of the first choice in the completion response.

    NOTE(review): the ``document`` input uploaded alongside ``messages`` is not
    sent to the model — confirm this is intended.
    """
    prompt_messages = record.input["messages"]
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=prompt_messages,
        max_tokens=20,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
import nest_asyncio
from phoenix.experiments import run_experiment

# nest_asyncio allows re-entering an already-running asyncio loop; presumably
# needed because run_experiment drives its own loop inside Jupyter's — confirm.
nest_asyncio.apply()

# Serve the canned completion on `port` only while the experiment runs. Keep the
# returned experiment object and display it, instead of discarding it.
with Receiver(hello).run_in_thread(port):
    experiment = run_experiment(
        ds,
        task,
    )
experiment