Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .sampo/changesets/trim-whitespace-config-values.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pypi/posthog: patch
---

Trim surrounding whitespace from API keys and host config before using them.
26 changes: 21 additions & 5 deletions posthog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
)
from posthog.poller import Poller
from posthog.request import (
DEFAULT_HOST,
APIError,
QuotaLimitError,
RequestsConnectionError,
Expand All @@ -54,6 +53,7 @@
determine_server_host,
flags,
get,
normalize_host,
remote_config,
)
from posthog.types import (
Expand Down Expand Up @@ -221,14 +221,14 @@ def __init__(
self.queue = queue.Queue(max_queue_size)

# api_key: This should be the Team API Key (token), public
self.api_key = project_api_key
self.api_key = project_api_key.strip()

self.on_error = on_error
self.debug = debug
self.send = send
self.sync_mode = sync_mode
# Used for session replay URL generation - we don't want the server host here.
self.raw_host = host or DEFAULT_HOST
self.raw_host = normalize_host(host)
self.host = determine_server_host(host)
self.gzip = gzip
self.timeout = timeout
Expand Down Expand Up @@ -278,7 +278,11 @@ def __init__(
self.project_root = project_root

# personal_api_key: This should be a generated Personal API Key, private
self.personal_api_key = personal_api_key
self.personal_api_key = (
personal_api_key.strip()
if isinstance(personal_api_key, str)
else personal_api_key
) or None
if debug:
# Ensures that debug level messages are logged when debug mode is on.
# Otherwise, defaults to WARNING level. See https://docs.python.org/3/howto/logging.html#what-happens-if-no-configuration-is-provided
Expand All @@ -287,6 +291,11 @@ def __init__(
else:
self.log.setLevel(logging.WARNING)

if not self.api_key:
self.log.error(
"api_key is empty after trimming whitespace; check your project API key"
)

self._set_before_send(before_send)

if self.enable_exception_autocapture:
Expand Down Expand Up @@ -1288,12 +1297,19 @@ def _load_feature_flags(self):

def _fetch_feature_flags_from_api(self):
"""Fetch feature flags from the PostHog API."""
personal_api_key = self.personal_api_key
if personal_api_key is None:
self.log.warning(
"[FEATURE FLAGS] You have to specify a personal_api_key to use feature flags."
)
return

try:
# Store old flags to detect changes
old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {}

response = get(
self.personal_api_key,
personal_api_key,
f"/flags/definitions?token={self.api_key}&send_cohorts",
self.host,
timeout=10,
Expand Down
16 changes: 13 additions & 3 deletions posthog/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,17 @@ def disable_connection_reuse() -> None:
USER_AGENT = "posthog-python/" + VERSION


def normalize_host(host: Optional[str]) -> str:
"""Normalize a configured host, defaulting blank values to DEFAULT_HOST."""
normalized_host = (host or "").strip()
if not normalized_host:
return DEFAULT_HOST
return normalized_host


def determine_server_host(host: Optional[str]) -> str:
"""Determines the server host to use."""
host_or_default = host or DEFAULT_HOST
host_or_default = normalize_host(host)
trimmed_host = remove_trailing_slash(host_or_default)
if trimmed_host in ("https://app.posthog.com", "https://us.posthog.com"):
return US_INGESTION_ENDPOINT
Expand All @@ -190,7 +198,8 @@ def post(
log = logging.getLogger("posthog")
body = kwargs
body["sentAt"] = datetime.now(tz=tzutc()).isoformat()
url = remove_trailing_slash(host or DEFAULT_HOST) + path
trimmed_host = remove_trailing_slash(normalize_host(host))
url = trimmed_host + path
body["api_key"] = api_key
data = json.dumps(body, cls=DatetimeSerializer)
log.debug("making request: %s to url: %s", data, url)
Expand Down Expand Up @@ -330,7 +339,8 @@ def get(
- not_modified=False and data=response if server returns 200
"""
log = logging.getLogger("posthog")
full_url = remove_trailing_slash(host or DEFAULT_HOST) + url
trimmed_host = remove_trailing_slash(normalize_host(host))
full_url = trimmed_host + url
headers = {"Authorization": "Bearer %s" % api_key, "User-Agent": USER_AGENT}

if etag:
Expand Down
32 changes: 32 additions & 0 deletions posthog/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,38 @@ def setUp(self):
def test_requires_api_key(self):
self.assertRaises(TypeError, Client)

@parameterized.expand(
[
("valid_key", " \nphc_validkey\t ", "phc_validkey", False),
("whitespace_only", " \n\t ", "", True),
]
)
def test_trims_api_key_whitespace(
self, _, raw_api_key, expected_api_key, expect_error_log
):
with mock.patch.object(Client.log, "error") as mock_error:
client = Client(raw_api_key, send=False)

self.assertEqual(client.api_key, expected_api_key)
if expect_error_log:
mock_error.assert_called_once_with(
"api_key is empty after trimming whitespace; check your project API key"
)
else:
mock_error.assert_not_called()

def test_trims_host_and_personal_api_key_whitespace(self):
client = Client(
FAKE_TEST_API_KEY,
host=" \nhttps://eu.posthog.com/\t ",
personal_api_key=" \n\t ",
send=False,
)

self.assertEqual(client.raw_host, "https://eu.posthog.com/")
self.assertEqual(client.host, "https://eu.i.posthog.com")
self.assertIsNone(client.personal_api_key)

def test_empty_flush(self):
self.client.flush()

Expand Down
2 changes: 2 additions & 0 deletions posthog/test/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def test_get_removes_trailing_slash_from_host(self, mock_get):
("https://app.posthog.com/", "https://us.i.posthog.com"),
("https://eu.posthog.com/", "https://eu.i.posthog.com"),
("https://us.posthog.com/", "https://us.i.posthog.com"),
(" \nhttps://eu.posthog.com/\t ", "https://eu.i.posthog.com"),
(" \n\t ", "https://us.i.posthog.com"),
(None, "https://us.i.posthog.com"),
],
)
Expand Down
Loading