diff --git a/posthog/api/capture.py b/posthog/api/capture.py index 37d88c50d2eb9..bc1b8c9dab4ba 100644 --- a/posthog/api/capture.py +++ b/posthog/api/capture.py @@ -1,64 +1,16 @@ -import base64 -import gzip -import json import re import secrets from datetime import datetime from typing import Any, Dict, List, Optional, Union -import lzstring # type: ignore from dateutil import parser from django.http import HttpResponse, JsonResponse from django.utils import timezone from django.views.decorators.csrf import csrf_exempt -from sentry_sdk import push_scope from posthog.models import Team from posthog.tasks.process_event import process_event -from posthog.utils import PersonalAPIKeyAuthentication, cors_response, get_ip_address - - -def _load_data(request) -> Optional[Union[Dict, List]]: - if request.method == "POST": - if request.content_type == "application/json": - data = request.body - else: - data = request.POST.get("data") - else: - data = request.GET.get("data") - if not data: - return None - - # add the data in sentry's scope in case there's an exception - with push_scope() as scope: - scope.set_context("data", data) - - compression = ( - request.GET.get("compression") or request.POST.get("compression") or request.headers.get("content-encoding", "") - ) - compression = compression.lower() - - if compression == "gzip": - data = gzip.decompress(data) - - if compression == "lz64": - if isinstance(data, str): - data = lzstring.LZString().decompressFromBase64(data.replace(" ", "+")) - else: - data = lzstring.LZString().decompressFromBase64(data.decode().replace(" ", "+")) - - # Is it plain json? - try: - data = json.loads(data) - except json.JSONDecodeError: - # if not, it's probably base64 encoded from other libraries - data = json.loads( - base64.b64decode(data.replace(" ", "+") + "===") - .decode("utf8", "surrogatepass") - .encode("utf-16", "surrogatepass") - ) - # FIXME: data can also be an array, function assumes it's either None or a dictionary. 
- return data +from posthog.utils import PersonalAPIKeyAuthentication, cors_response, get_ip_address, load_data_from_request def _datetime_from_seconds_or_millis(timestamp: str) -> datetime: @@ -114,7 +66,8 @@ def _get_distinct_id(data: Dict[str, Any]) -> str: def get_event(request): now = timezone.now() try: - data = _load_data(request) + data_from_request = load_data_from_request(request) + data = data_from_request["data"] except TypeError: return cors_response( request, @@ -139,11 +92,9 @@ def get_event(request): token = _get_token(data, request) is_personal_api_key = False if not token: - personal_api_key_with_source = PersonalAPIKeyAuthentication().find_key( - request, data if isinstance(data, dict) else None + token = PersonalAPIKeyAuthentication.find_key( + request, data_from_request["body"], data if isinstance(data, dict) else None ) - if personal_api_key_with_source: - token = personal_api_key_with_source[0] is_personal_api_key = True if not token: return cors_response( diff --git a/posthog/api/decide.py b/posthog/api/decide.py index d5d894d2024a1..54aabc3408881 100644 --- a/posthog/api/decide.py +++ b/posthog/api/decide.py @@ -1,7 +1,6 @@ -import base64 import json import secrets -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse from django.conf import settings @@ -9,32 +8,55 @@ from django.views.decorators.csrf import csrf_exempt from posthog.models import FeatureFlag, Team -from posthog.utils import cors_response +from posthog.utils import PersonalAPIKeyAuthentication, base64_to_json, cors_response, load_data_from_request -def _load_data(data: str) -> Dict[str, Any]: - return json.loads( - base64.b64decode(data.replace(" ", "+") + "===") - .decode("utf8", "surrogatepass") - .encode("utf-16", "surrogatepass") - ) +def _load_data(request) -> Optional[Union[Dict[str, Any], List]]: + # JS Integration reloadFeatureFlags call + if request.content_type == 
"application/x-www-form-urlencoded": + return base64_to_json(request.POST["data"]) + + return load_data_from_request(request) + + +def _get_token(data, request): + if request.POST.get("api_key"): + return request.POST["api_key"] + if request.POST.get("token"): + return request.POST["token"] + if "token" in data: + return data["token"] # JS reloadFeatures call + if "api_key" in data: + return data["api_key"] # server-side libraries like posthog-python and posthog-ruby + return None def feature_flags(request: HttpRequest) -> Dict[str, Any]: feature_flags_data = {"flags_enabled": [], "has_malformed_json": False} - if request.method != "POST" or not request.POST.get("data"): - return feature_flags_data try: - data = _load_data(request.POST["data"]) + data_from_request = load_data_from_request(request) + data = data_from_request["data"] except (json.decoder.JSONDecodeError, TypeError): feature_flags_data["has_malformed_json"] = True return feature_flags_data - team = Team.objects.get_cached_from_token(data["token"]) - flags_enabled = [] + if not data: + return feature_flags_data + token = _get_token(data, request) + is_personal_api_key = False + if not token: + token = PersonalAPIKeyAuthentication.find_key( + request, data_from_request["body"], data if isinstance(data, dict) else None + ) + is_personal_api_key = True + if not token: + return feature_flags_data + team = Team.objects.get_cached_from_token(token, is_personal_api_key) + flags_enabled = [] feature_flags = FeatureFlag.objects.filter(team=team, active=True, deleted=False) for feature_flag in feature_flags: + # distinct_id will always be a string, but data can have non-string values ("Any") if feature_flag.distinct_id_matches(data["distinct_id"]): flags_enabled.append(feature_flag.key) feature_flags_data["flags_enabled"] = flags_enabled diff --git a/posthog/api/test/test_capture.py b/posthog/api/test/test_capture.py index 8f930390bdf4e..c4947eab0bdf1 100644 --- a/posthog/api/test/test_capture.py +++ 
b/posthog/api/test/test_capture.py @@ -158,25 +158,11 @@ def test_empty_request_returns_an_error(self, patch_process_event): # Empty GET response = self.client.get("/e/?data=", content_type="application/json", HTTP_ORIGIN="https://localhost",) self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json(), - { - "code": "validation", - "message": "No data found. Make sure to use a POST request when sending the payload in the body of the request.", - }, - ) self.assertEqual(patch_process_event.call_count, 0) # Empty POST response = self.client.post("/e/", {}, content_type="application/json", HTTP_ORIGIN="https://localhost",) self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json(), - { - "code": "validation", - "message": "No data found. Make sure to use a POST request when sending the payload in the body of the request.", - }, - ) self.assertEqual(patch_process_event.call_count, 0) @patch("posthog.models.team.TEAM_CACHE", {}) diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py index 11632b8ea8b99..d23206777d8aa 100644 --- a/posthog/api/test/test_decide.py +++ b/posthog/api/test/test_decide.py @@ -2,7 +2,7 @@ import json from unittest.mock import patch -from posthog.models import FeatureFlag, Person +from posthog.models import FeatureFlag, Person, PersonalAPIKey from .base import BaseTest @@ -99,3 +99,17 @@ def test_feature_flags(self): HTTP_ORIGIN="http://127.0.0.1:8000", ).json() self.assertEqual(len(response["featureFlags"]), 0) + + def test_feature_flags_with_personal_api_key(self): + key = PersonalAPIKey(label="X", user=self.user, team=self.team) + key.save() + Person.objects.create(team=self.team, distinct_ids=["example_id"]) + FeatureFlag.objects.create( + team=self.team, rollout_percentage=100, name="Test", key="test", created_by=self.user, + ) + response = self.client.post( + "/decide/", + {"data": json.dumps({"distinct_id": "example_id", "personal_api_key": key.value})}, + 
HTTP_ORIGIN="http://127.0.0.1:8000", + ).json() + self.assertEqual(len(response["featureFlags"]), 1) diff --git a/posthog/middleware.py b/posthog/middleware.py index 9f422f3e9e9e5..f958660cd6d82 100644 --- a/posthog/middleware.py +++ b/posthog/middleware.py @@ -100,9 +100,9 @@ class CsrfOrKeyViewMiddleware(CsrfViewMiddleware): """Middleware accepting requests that either contain a valid CSRF token or a personal API key.""" def process_view(self, request, callback, callback_args, callback_kwargs): - result = super().process_view(request, callback, callback_args, callback_kwargs) + result = super().process_view(request, callback, callback_args, callback_kwargs) # None if request accepted # if super().process_view did not find a valid CSRF token, try looking for a personal API key - if result is not None and PersonalAPIKeyAuthentication().find_key(request) is not None: + if result is not None and PersonalAPIKeyAuthentication().find_key_with_source(request) is not None: return self._accept(request) return result diff --git a/posthog/test/test_middleware.py b/posthog/test/test_middleware.py index 6961957713071..684e0ce7eb96f 100644 --- a/posthog/test/test_middleware.py +++ b/posthog/test/test_middleware.py @@ -28,13 +28,6 @@ def test_ip_range(self): response = self.client.get("/batch/", REMOTE_ADDR="10.0.0.1",) - self.assertEqual( - response.json(), - { - "code": "validation", - "message": "No data found. 
Make sure to use a POST request when sending the payload in the body of the request.", - }, - ) self.assertEqual( response.status_code, status.HTTP_400_BAD_REQUEST ) # Check for a bad request exception because it means the middleware didn't block the request diff --git a/posthog/utils.py b/posthog/utils.py index 7bc588204c409..cda4922ed1380 100644 --- a/posthog/utils.py +++ b/posthog/utils.py @@ -1,4 +1,6 @@ +import base64 import datetime +import gzip import hashlib import json import os @@ -8,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse, urlsplit +import lzstring # type: ignore import pytz import redis from dateutil import parser @@ -21,6 +24,7 @@ from rest_framework import authentication from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request +from sentry_sdk import push_scope def relative_date_parse(input: str) -> datetime.datetime: @@ -215,7 +219,6 @@ def cors_response(request, response): class PersonalAPIKeyAuthentication(authentication.BaseAuthentication): """A way of authenticating with personal API keys. - Only the first key candidate found in the request is tried, and the order is: 1. Request Authorization header of type Bearer. 2. Request body. 
@@ -224,20 +227,22 @@ class PersonalAPIKeyAuthentication(authentication.BaseAuthentication): keyword = "Bearer" - def find_key( - self, request: Union[HttpRequest, Request], extra_data: Optional[Dict[str, Any]] = None + @classmethod + def find_key_with_source( + cls, + request: Union[HttpRequest, Request], + request_data: Optional[Dict[str, Any]] = None, + extra_data: Optional[Dict[str, Any]] = None, ) -> Optional[Tuple[str, str]]: + """Try to find personal API key in request and return it along with where it was found.""" if "HTTP_AUTHORIZATION" in request.META: - authorization_match = re.match(fr"^{self.keyword}\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"]) + authorization_match = re.match(fr"^{cls.keyword}\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"]) if authorization_match: return authorization_match.group(1).strip(), "Authorization header" - if isinstance(request, Request): + if request_data is None and isinstance(request, Request): data = request.data else: - try: - data = json.loads(request.body) - except json.JSONDecodeError: - data = {} + data = request_data or {} if "personal_api_key" in data: return data["personal_api_key"], "body" if "personal_api_key" in request.GET: @@ -247,8 +252,20 @@ def find_key( return extra_data["personal_api_key"], "query string data" return None - def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]: - personal_api_key_with_source = self.find_key(request) + @classmethod + def find_key( + cls, + request: Union[HttpRequest, Request], + request_data: Optional[Dict[str, Any]] = None, + extra_data: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + """Try to find personal API key in request and return it.""" + key_with_source = cls.find_key_with_source(request, request_data, extra_data) + return key_with_source[0] if key_with_source is not None else None + + @classmethod + def authenticate(cls, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]: + 
personal_api_key_with_source = cls.find_key_with_source(request) if not personal_api_key_with_source: return None personal_api_key, source = personal_api_key_with_source @@ -263,8 +280,9 @@ def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[A personal_api_key_object.save() return personal_api_key_object.user, None - def authenticate_header(self, request) -> str: - return self.keyword + @classmethod + def authenticate_header(cls, request) -> str: + return cls.keyword class TemporaryTokenAuthentication(authentication.BaseAuthentication): @@ -321,3 +339,57 @@ def get_redis_heartbeat() -> Union[str, int]: if worker_heartbeat and (worker_heartbeat == 0 or worker_heartbeat < 300): return worker_heartbeat return "offline" + + +def base64_to_json(data) -> Dict: + return json.loads( + base64.b64decode(data.replace(" ", "+") + "===") + .decode("utf8", "surrogatepass") + .encode("utf-16", "surrogatepass") + ) + + +# Used by non-DRF endpoints from capture.py and decide.py (/decide, /batch, /capture, etc) +def load_data_from_request(request): + data_res: Dict[str, Any] = {"data": {}, "body": None} + if request.method == "POST": + if request.content_type == "application/json": + data = request.body + try: + data_res["body"] = {**json.loads(request.body)} + except: + pass + else: + data = request.POST.get("data") + else: + data = request.GET.get("data") + if not data: + return None + + # add the data in sentry's scope in case there's an exception + with push_scope() as scope: + scope.set_context("data", data) + + compression = ( + request.GET.get("compression") or request.POST.get("compression") or request.headers.get("content-encoding", "") + ) + compression = compression.lower() + + if compression == "gzip": + data = gzip.decompress(data) + + if compression == "lz64": + if isinstance(data, str): + data = lzstring.LZString().decompressFromBase64(data.replace(" ", "+")) + else: + data = lzstring.LZString().decompressFromBase64(data.decode().replace(" ", 
"+")) + + # Is it plain json? + try: + data = json.loads(data) + except json.JSONDecodeError: + # if not, it's probably base64 encoded from other libraries + data = base64_to_json(data) + data_res["data"] = data + # FIXME: data can also be an array, function assumes it's either None or a dictionary. + return data_res