diff --git a/.sampo/changesets/trim-whitespace-config-values.md b/.sampo/changesets/trim-whitespace-config-values.md new file mode 100644 index 00000000..83f3a805 --- /dev/null +++ b/.sampo/changesets/trim-whitespace-config-values.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: patch +--- + +Trim surrounding whitespace from API keys and host config before using them. diff --git a/posthog/client.py b/posthog/client.py index 9672d9e4..c0c5491c 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -45,7 +45,6 @@ ) from posthog.poller import Poller from posthog.request import ( - DEFAULT_HOST, APIError, QuotaLimitError, RequestsConnectionError, @@ -54,6 +53,7 @@ determine_server_host, flags, get, + normalize_host, remote_config, ) from posthog.types import ( @@ -221,14 +221,14 @@ def __init__( self.queue = queue.Queue(max_queue_size) # api_key: This should be the Team API Key (token), public - self.api_key = project_api_key + self.api_key = project_api_key.strip() self.on_error = on_error self.debug = debug self.send = send self.sync_mode = sync_mode # Used for session replay URL generation - we don't want the server host here. - self.raw_host = host or DEFAULT_HOST + self.raw_host = normalize_host(host) self.host = determine_server_host(host) self.gzip = gzip self.timeout = timeout @@ -278,7 +278,11 @@ def __init__( self.project_root = project_root # personal_api_key: This should be a generated Personal API Key, private - self.personal_api_key = personal_api_key + self.personal_api_key = ( + personal_api_key.strip() + if isinstance(personal_api_key, str) + else personal_api_key + ) or None if debug: # Ensures that debug level messages are logged when debug mode is on. # Otherwise, defaults to WARNING level. See https://docs.python.org/3/howto/logging.html#what-happens-if-no-configuration-is-provided @@ -287,6 +291,11 @@ def __init__( else: self.log.setLevel(logging.WARNING) + if not self.api_key: + self.log.error( + "api_key is empty after trimming whitespace; check your project API key" + ) + self._set_before_send(before_send) if self.enable_exception_autocapture: @@ -1288,12 +1297,19 @@ def _load_feature_flags(self): def _fetch_feature_flags_from_api(self): """Fetch feature flags from the PostHog API.""" + personal_api_key = self.personal_api_key + if personal_api_key is None: + self.log.warning( + "[FEATURE FLAGS] You have to specify a personal_api_key to use feature flags." + ) + return + try: # Store old flags to detect changes old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {} response = get( - self.personal_api_key, + personal_api_key, f"/flags/definitions?token={self.api_key}&send_cohorts", self.host, timeout=10, diff --git a/posthog/request.py b/posthog/request.py index 3d335d17..fa54fac7 100644 --- a/posthog/request.py +++ b/posthog/request.py @@ -165,9 +165,17 @@ def disable_connection_reuse() -> None: USER_AGENT = "posthog-python/" + VERSION +def normalize_host(host: Optional[str]) -> str: + """Normalize a configured host, defaulting blank values to DEFAULT_HOST.""" + normalized_host = (host or "").strip() + if not normalized_host: + return DEFAULT_HOST + return normalized_host + + def determine_server_host(host: Optional[str]) -> str: """Determines the server host to use.""" - host_or_default = host or DEFAULT_HOST + host_or_default = normalize_host(host) trimmed_host = remove_trailing_slash(host_or_default) if trimmed_host in ("https://app.posthog.com", "https://us.posthog.com"): return US_INGESTION_ENDPOINT @@ -190,7 +198,8 @@ def post( log = logging.getLogger("posthog") body = kwargs body["sentAt"] = datetime.now(tz=tzutc()).isoformat() - url = remove_trailing_slash(host or DEFAULT_HOST) + path + trimmed_host = remove_trailing_slash(normalize_host(host)) + url = trimmed_host + path body["api_key"] = api_key data = json.dumps(body, cls=DatetimeSerializer) log.debug("making request: %s to url: %s", data, url) @@ -330,7 +339,8 @@ def get( - not_modified=False and data=response if server returns 200 """ log = logging.getLogger("posthog") - full_url = remove_trailing_slash(host or DEFAULT_HOST) + url + trimmed_host = remove_trailing_slash(normalize_host(host)) + full_url = trimmed_host + url headers = {"Authorization": "Bearer %s" % api_key, "User-Agent": USER_AGENT} if etag: diff --git a/posthog/test/test_client.py b/posthog/test/test_client.py index 9cbc206c..587e7005 100644 --- a/posthog/test/test_client.py +++ b/posthog/test/test_client.py @@ -41,6 +41,38 @@ def setUp(self): def test_requires_api_key(self): self.assertRaises(TypeError, Client) + @parameterized.expand( + [ + ("valid_key", " \nphc_validkey\t ", "phc_validkey", False), + ("whitespace_only", " \n\t ", "", True), + ] + ) + def test_trims_api_key_whitespace( + self, _, raw_api_key, expected_api_key, expect_error_log + ): + with mock.patch.object(Client.log, "error") as mock_error: + client = Client(raw_api_key, send=False) + + self.assertEqual(client.api_key, expected_api_key) + if expect_error_log: + mock_error.assert_called_once_with( + "api_key is empty after trimming whitespace; check your project API key" + ) + else: + mock_error.assert_not_called() + + def test_trims_host_and_personal_api_key_whitespace(self): + client = Client( + FAKE_TEST_API_KEY, + host=" \nhttps://eu.posthog.com/\t ", + personal_api_key=" \n\t ", + send=False, + ) + + self.assertEqual(client.raw_host, "https://eu.posthog.com/") + self.assertEqual(client.host, "https://eu.i.posthog.com") + self.assertIsNone(client.personal_api_key) + def test_empty_flush(self): self.client.flush() diff --git a/posthog/test/test_request.py b/posthog/test/test_request.py index c87af36b..3529d907 100644 --- a/posthog/test/test_request.py +++ b/posthog/test/test_request.py @@ -348,6 +348,8 @@ def test_get_removes_trailing_slash_from_host(self, mock_get): ("https://app.posthog.com/", "https://us.i.posthog.com"), ("https://eu.posthog.com/", "https://eu.i.posthog.com"), ("https://us.posthog.com/", "https://us.i.posthog.com"), + (" \nhttps://eu.posthog.com/\t ", "https://eu.i.posthog.com"), + (" \n\t ", "https://us.i.posthog.com"), (None, "https://us.i.posthog.com"), ], )