Skip to content
19 changes: 19 additions & 0 deletions ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations

import logging
import os
from dataclasses import dataclass
from typing import Optional
Expand All @@ -14,6 +16,8 @@
_ENV_AWS_REGION = "AI21_AWS_REGION"
_ENV_LOG_LEVEL = "AI21_LOG_LEVEL"

_logger = logging.getLogger(__name__)


@dataclass
class _AI21EnvConfig:
Expand Down Expand Up @@ -80,5 +84,20 @@ def log_level(self) -> Optional[str]:
self._log_level = os.getenv(_ENV_LOG_LEVEL, self._log_level)
return self._log_level

def log(self, with_secrets: bool = False) -> None:
env_vars = {
_ENV_API_VERSION: self.api_version,
_ENV_API_HOST: self.api_host,
_ENV_TIMEOUT_SEC: self.timeout_sec,
_ENV_NUM_RETRIES: self.num_retries,
_ENV_AWS_REGION: self.aws_region,
_ENV_LOG_LEVEL: self.log_level,
}

if with_secrets:
env_vars[_ENV_API_KEY] = self.api_key

_logger.debug(f"AI21 environment configuration: {env_vars}")


AI21EnvConfig = _AI21EnvConfig.from_env()
2 changes: 1 addition & 1 deletion ai21/http_client/async_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def _request(
) -> httpx.Response:
timeout = self._timeout_sec
headers = self._headers
logger.debug(f"Calling {method} {url} {headers} {params}")
logger.debug(f"Calling {method} {url} {headers} {params} {body}")

if method == "GET":
request = self._client.build_request(
Expand Down
8 changes: 3 additions & 5 deletions ai21/http_client/base_http_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from __future__ import annotations

import json

from typing import Generic, TypeVar, Union, Any, Optional, Dict, BinaryIO
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union, Any, Optional, Dict, BinaryIO

import httpx

from ai21.stream.stream import Stream
from ai21.stream.async_stream import AsyncStream
from ai21.errors import (
BadRequest,
Unauthorized,
Expand All @@ -18,11 +15,12 @@
ServiceUnavailable,
AI21APIError,
)
from ai21.stream.async_stream import AsyncStream
from ai21.stream.stream import Stream

_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


DEFAULT_TIMEOUT_SEC = 300
DEFAULT_NUM_RETRIES = 0
RETRY_BACK_OFF_FACTOR = 0.5
Expand Down
2 changes: 1 addition & 1 deletion ai21/http_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _request(
) -> httpx.Response:
timeout = self._timeout_sec
headers = self._headers
logger.debug(f"Calling {method} {url} {headers} {params}")
logger.debug(f"Calling {method} {url} {headers} {params} {body}")

if method == "GET":
request = self._client.build_request(
Expand Down
64 changes: 62 additions & 2 deletions ai21/logger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,64 @@
import logging
import os
import re

from ai21.ai21_env_config import AI21EnvConfig

_verbose = False

logger = logging.getLogger("ai21")
httpx_logger = logging.getLogger("httpx")


class CensorSecretsFormatter(logging.Formatter):
Copy link
Contributor

@pazshalev pazshalev Jun 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using logging.Filter instead of logging.Formatter, seems simpler IMO in this case, like so:

class SensitiveInfoFilter(logging.Filter):
    def filter(self, record):
        record.msg = re.sub(r'<PATTERN>', '*********', record.msg)
        return True


logger.addFilter(SensitiveInfoFilter())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite what we need here, as loggingFilter filter the log so it won't be logged. I do want to log it, simply without some of the data that is being written

def format(self, record: logging.LogRecord) -> str:
# Get original log message
message = super().format(record)

if not get_verbose():
return self._censor_secrets(message)

return message

def _censor_secrets(self, message: str) -> str:
# Regular expression to find the Authorization key and its value
pattern = r"('Authorization':\s*'[^']*'|'api-key':\s*'[^']*'|'X-Amz-Security-Token':\s*'[^']*')"

def replacement(match):
return match.group(0).split(":")[0] + ": '**************'"

# Substitute the Authorization value with **************
return re.sub(pattern, replacement, message)


def set_verbose(value: bool) -> None:
"""
Use this function if you want to log additional, more sensitive data like - secrets and environment variables.
Log level will be set to DEBUG if verbose is set to True.
"""
global _verbose
_verbose = value

set_debug(_verbose)

AI21EnvConfig.log(with_secrets=value)


def set_debug(value: bool) -> None:
"""
Additional way to set log level to DEBUG.
"""
if value:
os.environ["AI21_LOG_LEVEL"] = "debug"
else:
os.environ["AI21_LOG_LEVEL"] = "info"

setup_logger()


def get_verbose() -> bool:
global _verbose
return _verbose


def _basic_config() -> None:
Expand All @@ -14,8 +70,12 @@ def _basic_config() -> None:

def setup_logger() -> None:
_basic_config()
# Set the root handler with the censor formatter
logger.root.handlers[0].setFormatter(CensorSecretsFormatter())

if AI21EnvConfig.log_level == "debug":
if AI21EnvConfig.log_level.lower() == "debug":
logger.setLevel(logging.DEBUG)
elif AI21EnvConfig.log_level == "info":
httpx_logger.setLevel(logging.DEBUG)
elif AI21EnvConfig.log_level.lower() == "info":
logger.setLevel(logging.INFO)
httpx_logger.setLevel(logging.INFO)