Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make /decide endpoint more flexible (pt. 2) #1592

Merged
merged 6 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 5 additions & 54 deletions posthog/api/capture.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,16 @@
import base64
import gzip
import json
import re
import secrets
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import lzstring # type: ignore
from dateutil import parser
from django.http import HttpResponse, JsonResponse
from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt
from sentry_sdk import push_scope

from posthog.models import Team
from posthog.tasks.process_event import process_event
from posthog.utils import PersonalAPIKeyAuthentication, cors_response, get_ip_address


def _load_data(request) -> Optional[Union[Dict, List]]:
if request.method == "POST":
if request.content_type == "application/json":
data = request.body
else:
data = request.POST.get("data")
else:
data = request.GET.get("data")
if not data:
return None

# add the data in sentry's scope in case there's an exception
with push_scope() as scope:
scope.set_context("data", data)

compression = (
request.GET.get("compression") or request.POST.get("compression") or request.headers.get("content-encoding", "")
)
compression = compression.lower()

if compression == "gzip":
data = gzip.decompress(data)

if compression == "lz64":
if isinstance(data, str):
data = lzstring.LZString().decompressFromBase64(data.replace(" ", "+"))
else:
data = lzstring.LZString().decompressFromBase64(data.decode().replace(" ", "+"))

# Is it plain json?
try:
data = json.loads(data)
except json.JSONDecodeError:
# if not, it's probably base64 encoded from other libraries
data = json.loads(
base64.b64decode(data.replace(" ", "+") + "===")
.decode("utf8", "surrogatepass")
.encode("utf-16", "surrogatepass")
)
# FIXME: data can also be an array, function assumes it's either None or a dictionary.
return data
from posthog.utils import PersonalAPIKeyAuthentication, cors_response, get_ip_address, load_data_from_request


def _datetime_from_seconds_or_millis(timestamp: str) -> datetime:
Expand Down Expand Up @@ -114,7 +66,8 @@ def _get_distinct_id(data: Dict[str, Any]) -> str:
def get_event(request):
now = timezone.now()
try:
data = _load_data(request)
data_from_request = load_data_from_request(request)
data = data_from_request["data"]
except TypeError:
return cors_response(
request,
Expand All @@ -139,11 +92,9 @@ def get_event(request):
token = _get_token(data, request)
is_personal_api_key = False
if not token:
personal_api_key_with_source = PersonalAPIKeyAuthentication().find_key(
request, data if isinstance(data, dict) else None
token = PersonalAPIKeyAuthentication.find_key(
request, data_from_request["body"], data if isinstance(data, dict) else None
)
if personal_api_key_with_source:
token = personal_api_key_with_source[0]
is_personal_api_key = True
if not token:
return cors_response(
Expand Down
50 changes: 36 additions & 14 deletions posthog/api/decide.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,62 @@
import base64
import json
import secrets
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

from django.conf import settings
from django.http import HttpRequest, JsonResponse
from django.views.decorators.csrf import csrf_exempt

from posthog.models import FeatureFlag, Team
from posthog.utils import cors_response
from posthog.utils import PersonalAPIKeyAuthentication, base64_to_json, cors_response, load_data_from_request


def _load_data(data: str) -> Dict[str, Any]:
return json.loads(
base64.b64decode(data.replace(" ", "+") + "===")
.decode("utf8", "surrogatepass")
.encode("utf-16", "surrogatepass")
)
def _load_data(request) -> Optional[Union[Dict[str, Any], List]]:
# JS Integration reloadFeatureFlags call
if request.content_type == "application/x-www-form-urlencoded":
return base64_to_json(request.POST["data"])

return load_data_from_request(request)


def _get_token(data, request):
if request.POST.get("api_key"):
return request.POST["api_key"]
if request.POST.get("token"):
return request.POST["token"]
if "token" in data:
return data["token"] # JS reloadFeatures call
if "api_key" in data:
return data["api_key"] # server-side libraries like posthog-python and posthog-ruby
return None


def feature_flags(request: HttpRequest) -> Dict[str, Any]:
feature_flags_data = {"flags_enabled": [], "has_malformed_json": False}
if request.method != "POST" or not request.POST.get("data"):
return feature_flags_data
try:
data = _load_data(request.POST["data"])
data_from_request = load_data_from_request(request)
data = data_from_request["data"]
except (json.decoder.JSONDecodeError, TypeError):
feature_flags_data["has_malformed_json"] = True
return feature_flags_data

team = Team.objects.get_cached_from_token(data["token"])
flags_enabled = []
if not data:
return feature_flags_data

token = _get_token(data, request)
is_personal_api_key = False
if not token:
token = PersonalAPIKeyAuthentication.find_key(
request, data_from_request["body"], data if isinstance(data, dict) else None
)
is_personal_api_key = True
if not token:
return feature_flags_data
team = Team.objects.get_cached_from_token(token, is_personal_api_key)
flags_enabled = []
feature_flags = FeatureFlag.objects.filter(team=team, active=True, deleted=False)
for feature_flag in feature_flags:
# distinct_id will always be a string, but data can have non-string values ("Any")
if feature_flag.distinct_id_matches(data["distinct_id"]):
flags_enabled.append(feature_flag.key)
feature_flags_data["flags_enabled"] = flags_enabled
Expand Down
14 changes: 0 additions & 14 deletions posthog/api/test/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,25 +158,11 @@ def test_empty_request_returns_an_error(self, patch_process_event):
# Empty GET
response = self.client.get("/e/?data=", content_type="application/json", HTTP_ORIGIN="https://localhost",)
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.json(),
{
"code": "validation",
"message": "No data found. Make sure to use a POST request when sending the payload in the body of the request.",
},
)
self.assertEqual(patch_process_event.call_count, 0)

# Empty POST
response = self.client.post("/e/", {}, content_type="application/json", HTTP_ORIGIN="https://localhost",)
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.json(),
{
"code": "validation",
"message": "No data found. Make sure to use a POST request when sending the payload in the body of the request.",
},
)
self.assertEqual(patch_process_event.call_count, 0)

@patch("posthog.models.team.TEAM_CACHE", {})
Expand Down
16 changes: 15 additions & 1 deletion posthog/api/test/test_decide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from unittest.mock import patch

from posthog.models import FeatureFlag, Person
from posthog.models import FeatureFlag, Person, PersonalAPIKey

from .base import BaseTest

Expand Down Expand Up @@ -99,3 +99,17 @@ def test_feature_flags(self):
HTTP_ORIGIN="http://127.0.0.1:8000",
).json()
self.assertEqual(len(response["featureFlags"]), 0)

def test_feature_flags_with_personal_api_key(self):
key = PersonalAPIKey(label="X", user=self.user, team=self.team)
key.save()
Person.objects.create(team=self.team, distinct_ids=["example_id"])
FeatureFlag.objects.create(
team=self.team, rollout_percentage=100, name="Test", key="test", created_by=self.user,
)
response = self.client.post(
"/decide/",
{"data": json.dumps({"distinct_id": "example_id", "personal_api_key": key.value})},
HTTP_ORIGIN="http://127.0.0.1:8000",
).json()
self.assertEqual(len(response["featureFlags"]), 1)
4 changes: 2 additions & 2 deletions posthog/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ class CsrfOrKeyViewMiddleware(CsrfViewMiddleware):
"""Middleware accepting requests that either contain a valid CSRF token or a personal API key."""

def process_view(self, request, callback, callback_args, callback_kwargs):
result = super().process_view(request, callback, callback_args, callback_kwargs)
result = super().process_view(request, callback, callback_args, callback_kwargs) # None if request accepted
# if super().process_view did not find a valid CSRF token, try looking for a personal API key
if result is not None and PersonalAPIKeyAuthentication().find_key(request) is not None:
if result is not None and PersonalAPIKeyAuthentication().find_key_with_source(request) is not None:
return self._accept(request)
return result

Expand Down
7 changes: 0 additions & 7 deletions posthog/test/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ def test_ip_range(self):

response = self.client.get("/batch/", REMOTE_ADDR="10.0.0.1",)

self.assertEqual(
response.json(),
{
"code": "validation",
"message": "No data found. Make sure to use a POST request when sending the payload in the body of the request.",
},
)
self.assertEqual(
response.status_code, status.HTTP_400_BAD_REQUEST
) # Check for a bad request exception because it means the middleware didn't block the request
Expand Down
Loading