In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


# Put the parent directory first on sys.path so imports resolve against it
# (the `math_rag` imports in later cells presumably rely on this — confirm).
sys.path.insert(0, os.path.abspath('..'))
# Apply nest_asyncio's patch; per that library's README this lets asyncio code
# run inside an already-running event loop, which is the situation in Jupyter.
nest_asyncio.apply()

In [None]:
import logging


# Configure the root logger: INFO and above, rendered as
# "<timestamp> - <level> - <message>". At this level the `logging.debug`
# calls made by later cells are not emitted.
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s'
)

In [None]:
from asyncio import Queue, create_task, sleep
from collections import deque
from time import ctime, perf_counter
from typing import Generic

from pydantic import BaseModel, Field
from tiktoken import get_encoding

from math_rag.application.models.inference import LLMRequest, LLMRequestBatch
from math_rag.application.types.inference import LLMResponseType

In [None]:
class LLMStatusTracker(BaseModel):
    """Mutable counters shared by the dispatch loop and the in-flight requests.

    The dispatch loop increments the "started"/"in progress" counters when it
    creates a request; each request settles its own outcome (succeeded or
    failed) and records the errors it hits.
    """

    # Request lifecycle counters.
    num_tasks_started: int = 0
    num_tasks_in_progress: int = 0
    num_tasks_succeeded: int = 0
    num_tasks_failed: int = 0

    # Error counters, split by kind.
    num_rate_limit_errors: int = 0
    num_api_errors: int = 0

    # `perf_counter()` reading taken when the most recent rate-limit error was
    # seen. That clock returns fractional seconds, so the field must be a float;
    # 0.0 is only a placeholder until the first rate-limit error is recorded.
    time_of_last_rate_limit_error: float = 0.0


class LLMError(BaseModel):
    """Plain-data record of an LLM error: a message plus an optional body.

    NOTE(review): not used in the visible cells; the field names match the
    `message` / `body` attributes printed from the OpenAI exception in the last
    cell, so this is presumably the serialisable form of those errors — confirm.
    """

    # Human-readable error text.
    message: str
    # Error payload; may be None. No default is declared for this field.
    body: object | None

In [None]:
from openai import (
    APIConnectionError,
    APIError,
    APITimeoutError,
    AuthenticationError,
    BadRequestError,
    ConflictError,
    InternalServerError,
    NotFoundError,
    PermissionDeniedError,
    RateLimitError,
    UnprocessableEntityError,
)


class LLMRequestWrapper(LLMRequest[LLMResponseType], Generic[LLMResponseType]):
    """An `LLMRequest` extended with the bookkeeping needed to dispatch and retry it."""

    # Retryable errors collected across attempts, oldest first.
    # NOTE(review): `APIError` is not a pydantic-native type; this presumably
    # relies on `LLMRequest` allowing arbitrary types in its model config — confirm.
    errors: list[APIError] = Field(default_factory=list)

    # NOTE: for status
    # Estimated tokens this request consumes from the per-minute token budget.
    token_consumption: int
    # Attempts still allowed; the dispatcher decrements it before each call.
    attempts_left: int

    async def call_api(  # TODO reuse existing generate method?
        self,
        retry_queue: Queue['LLMRequestWrapper'],
        status_tracker: LLMStatusTracker,
    ):
        """Run one attempt of this request and settle its outcome.

        Args:
            retry_queue: Queue this request is put back on when a retryable
                error occurred and attempts remain.
            status_tracker: Shared counters updated with the result.

        Raises:
            Exception: Any error that is not classified as retryable below
                (e.g. authentication, bad request, not found) is re-raised
                unchanged after the request has been counted as failed.
        """
        error = None

        try:
            response = ...  # call api

        # TODO new ----
        except RateLimitError as e:
            error = e
            status_tracker.time_of_last_rate_limit_error = perf_counter()
            status_tracker.num_rate_limit_errors += 1

        except (APITimeoutError, InternalServerError, UnprocessableEntityError) as e:
            error = e
            status_tracker.num_api_errors += 1

        except Exception:
            # Non-retryable: APIConnectionError, AuthenticationError,
            # BadRequestError, ConflictError, NotFoundError,
            # PermissionDeniedError, or anything unexpected.
            # This coroutine runs as a fire-and-forget task, so the counters
            # must be settled before propagating; otherwise the request stays
            # "in progress" forever and the dispatch loop never terminates.
            status_tracker.num_tasks_in_progress -= 1
            status_tracker.num_tasks_failed += 1
            raise

        # TODO new end----

        if error is None:
            # TODO: save response
            status_tracker.num_tasks_in_progress -= 1
            status_tracker.num_tasks_succeeded += 1
            return

        self.errors.append(error)

        # `> 0` rather than truthiness: a negative count (e.g. a zero
        # `max_attempts` decremented once) must not be read as "retries left".
        if self.attempts_left > 0:
            retry_queue.put_nowait(self)
            return

        logging.error(f'Request {self.id} failed after all attempts')
        # TODO: save self.errors
        status_tracker.num_tasks_in_progress -= 1
        status_tracker.num_tasks_failed += 1

In [None]:
def num_tokens_from_request(request_json: dict, encoding_name: str = 'o200k_base'):
    """Estimate how many tokens a chat-completion request will consume.

    The estimate is the prompt size (every value of every message, encoded,
    plus fixed per-message and per-reply overhead) plus the worst-case
    completion size (`n * max_tokens`).

    Args:
        request_json: Request payload. Must contain `messages` (a list of
            dicts); `max_tokens` (default 15) and `n` (default 1) are optional.
            Message values are passed straight to the encoder, so they are
            assumed to be strings — TODO confirm against callers.
        encoding_name: Name of the tiktoken encoding to count with.
            NOTE(review): the original passed a `...` placeholder here, which
            cannot work at runtime. The default is a guess matching the gpt-4o
            model family used elsewhere in this notebook — confirm.

    Returns:
        Estimated prompt tokens plus maximum completion tokens.
    """
    encoding = get_encoding(encoding_name)

    max_tokens = request_json.get('max_tokens', 15)
    n = request_json.get('n', 1)
    completion_tokens = n * max_tokens

    num_tokens = 0
    for message in request_json['messages']:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == 'name':  # if there's a name, the role is omitted
                num_tokens -= 1  # role is always required and always 1 token
    num_tokens += 2  # every reply is primed with <im_start>assistant

    return num_tokens + completion_tokens

In [None]:
async def concurrent_generate(
    request_batch: LLMRequestBatch,
    max_requests_per_minute: float,
    max_tokens_per_minute: float,
    max_attempts: int,
):
    """Dispatch every request of a batch concurrently while respecting rate limits.

    Request and token budgets start full, refill in proportion to elapsed time
    and are capped at their per-minute maximum. A request is launched as soon
    as both budgets can cover it. Requests that hit a retryable error come back
    through a retry queue and take priority over not-yet-started ones. After a
    rate-limit error the loop pauses before dispatching again.

    Args:
        request_batch: Batch whose `requests` are dispatched in order.
        max_requests_per_minute: Request budget (also the bucket capacity).
        max_tokens_per_minute: Token budget (also the bucket capacity).
        max_attempts: Initial `attempts_left` for each request.

    Returns:
        None. Results are not collected yet (see the TODOs in
        `LLMRequestWrapper.call_api`); failures are summarised via logging.
    """
    # Wall clock, used only to render the cooldown deadline in the log line.
    # `perf_counter()` has an undefined reference point and cannot be handed
    # to `ctime()`.
    from time import time as wall_clock_now

    SECONDS_TO_PAUSE_AFTER_RATE_LIMIT_ERROR = 15
    SECONDS_TO_SLEEP_EACH_LOOP = 0.001

    retry_queue: Queue[LLMRequestWrapper] = Queue()
    status_tracker = LLMStatusTracker()
    next_request: LLMRequestWrapper | None = None

    # The event loop only keeps weak references to tasks, so hold strong ones
    # until each task finishes; otherwise a task may be garbage-collected
    # before it completes.
    in_flight_tasks: set = set()

    available_request_capacity = max_requests_per_minute
    available_token_capacity = max_tokens_per_minute
    last_update_time = perf_counter()

    requests: deque[LLMRequest] = deque(request_batch.requests)

    while True:
        # Pick the next request: retries first, then fresh ones.
        if next_request is None:
            if not retry_queue.empty():
                next_request = retry_queue.get_nowait()
                logging.debug(f'Retrying request {next_request.id}')

            elif requests:
                request = requests.popleft()
                # NOTE(review): `num_tokens_from_request` is written against a
                # dict payload and the wrapper is built with a `request=` field;
                # both depend on the shape of `LLMRequest` — confirm.
                token_consumption = num_tokens_from_request(request)
                next_request = LLMRequestWrapper(
                    request=request,
                    token_consumption=token_consumption,
                    attempts_left=max_attempts,
                )
                status_tracker.num_tasks_started += 1
                status_tracker.num_tasks_in_progress += 1
                logging.debug(f'Reading request {next_request.id}')

        # Refill both budgets for the time elapsed since the last iteration.
        current_time = perf_counter()
        seconds_since_update = current_time - last_update_time
        available_request_capacity = min(
            available_request_capacity
            + max_requests_per_minute * seconds_since_update / 60.0,
            max_requests_per_minute,
        )
        available_token_capacity = min(
            available_token_capacity
            + max_tokens_per_minute * seconds_since_update / 60.0,
            max_tokens_per_minute,
        )
        last_update_time = current_time

        # Launch the pending request once both budgets can cover it.
        if next_request:
            next_request_tokens = next_request.token_consumption

            if (
                available_request_capacity >= 1
                and available_token_capacity >= next_request_tokens
            ):
                available_request_capacity -= 1
                available_token_capacity -= next_request_tokens
                next_request.attempts_left -= 1

                task = create_task(
                    next_request.call_api(
                        retry_queue=retry_queue,
                        status_tracker=status_tracker,
                    )
                )
                in_flight_tasks.add(task)
                task.add_done_callback(in_flight_tasks.discard)
                next_request = None

        if status_tracker.num_tasks_in_progress == 0:
            break

        # Yield to the event loop so in-flight tasks can make progress.
        await sleep(SECONDS_TO_SLEEP_EACH_LOOP)

        # Cool down only if a rate-limit error was actually recorded: the
        # tracker's initial timestamp is a placeholder, not a clock reading.
        if status_tracker.num_rate_limit_errors > 0:
            seconds_since_rate_limit_error = (
                perf_counter() - status_tracker.time_of_last_rate_limit_error
            )

            if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_AFTER_RATE_LIMIT_ERROR:
                remaining_seconds_to_pause = (
                    SECONDS_TO_PAUSE_AFTER_RATE_LIMIT_ERROR
                    - seconds_since_rate_limit_error
                )
                # Announce the pause before it starts, with a wall-clock deadline.
                logging.warning(
                    f'Pausing to cool down until {ctime(wall_clock_now() + remaining_seconds_to_pause)}'
                )
                await sleep(remaining_seconds_to_pause)

    if status_tracker.num_tasks_failed > 0:
        logging.warning(
            f'{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed'
        )

    if status_tracker.num_rate_limit_errors > 0:
        logging.warning(
            f'{status_tracker.num_rate_limit_errors} rate limit errors received'
        )

In [None]:
from openai import APIConnectionError, NotFoundError

from math_rag.infrastructure.containers import InfrastructureContainer


# NOTE(review): RESET is not referenced anywhere in the visible cells.
RESET = False

# Build the infrastructure container and initialise its resources, then take
# the OpenAI LLM wrapper from it.
infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

llm = infrastructure_container.openai_llm()

# Probe what the OpenAI client raises on failure. The model name below is
# presumably invalid on purpose, so the request ends in NotFoundError (see the
# recorded cell output).
try:
    # llm.client.api_key = ""
    completion = await llm.client.chat.completions.create(
        messages=[{'role': 'user', 'content': 'what is 2+2'}],
        model='gpt-4o-mini-NONE',
        temperature=0.0,
    )
except APIConnectionError as e:
    # Connection-level failure: print the exception itself.
    print(e)
except NotFoundError as e:
    # HTTP 404: print the message and the structured body separately.
    print(e.message)
    print(e.body)  # body (can be None) has message, type, param, code

Error code: 404 - {'error': {'message': 'The model `gpt-4o-mini-NONE` does not exist or you do not have access to it.', 'type': 'invalid_request_error', 'param': None, 'code': 'model_not_found'}}
{'message': 'The model `gpt-4o-mini-NONE` does not exist or you do not have access to it.', 'type': 'invalid_request_error', 'param': None, 'code': 'model_not_found'}
