Skip to content

Commit

Permalink
More complete handling for preventing XSS attacks.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamghill committed Oct 9, 2021
1 parent a28a81c commit 3a832a9
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 63 deletions.
6 changes: 4 additions & 2 deletions django_unicorn/components/unicorn_template_response.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging

from django.template.response import TemplateResponse
from django.utils.safestring import mark_safe

import orjson
from bs4 import BeautifulSoup
from bs4.dammit import EntitySubstitution
from bs4.element import Tag
from bs4.formatter import HTMLFormatter

Expand All @@ -22,6 +22,9 @@ class UnsortedAttributes(HTMLFormatter):
Prevent beautifulsoup from re-ordering attributes.
"""

def __init__(self):
super().__init__(entity_substitution=EntitySubstitution.substitute_html)

def attributes(self, tag: Tag):
for k, v in tag.attrs.items():
yield k, v
Expand Down Expand Up @@ -115,7 +118,6 @@ def render(self):
root_element.insert_after(t)

rendered_template = UnicornTemplateResponse._desoupify(soup)
rendered_template = mark_safe(rendered_template)
self.component.rendered(rendered_template)

response.content = rendered_template
Expand Down
21 changes: 0 additions & 21 deletions django_unicorn/components/unicorn_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Model
from django.http import HttpRequest
from django.utils.html import conditional_escape
from django.views.generic.base import TemplateView

from cachetools.lru import LRUCache
Expand Down Expand Up @@ -341,14 +340,6 @@ def get_frontend_context_variables(self) -> str:
if field_name in frontend_context_variables:
del frontend_context_variables[field_name]

safe_fields = []
# Keep a list of fields that are safe to not sanitize from `frontend_context_variables`
if hasattr(self, "Meta") and hasattr(self.Meta, "safe"):
if isinstance(self.Meta.safe, Sequence):
for field_name in self.Meta.safe:
if field_name in frontend_context_variables:
safe_fields.append(field_name)

# Add cleaned values to `frontend_content_variables` based on the widget in form's fields
form = self._get_form(attributes)

Expand All @@ -372,18 +363,6 @@ def get_frontend_context_variables(self) -> str:
):
frontend_context_variables[key] = value

for (
frontend_context_variable_key,
frontend_context_variable_value,
) in frontend_context_variables.items():
if (
isinstance(frontend_context_variable_value, str)
and frontend_context_variable_key not in safe_fields
):
frontend_context_variables[frontend_context_variable_key] = escape(
frontend_context_variable_value
)

encoded_frontend_context_variables = serializer.dumps(
frontend_context_variables
)
Expand Down
17 changes: 16 additions & 1 deletion django_unicorn/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import copy
import logging
from functools import wraps
from typing import Dict
from typing import Dict, Sequence

from django.core.cache import caches
from django.http import HttpRequest, JsonResponse
from django.http.response import HttpResponseNotModified
from django.utils.safestring import mark_safe
from django.views.decorators.csrf import csrf_protect
from django.views.decorators.http import require_POST

Expand Down Expand Up @@ -126,6 +127,20 @@ def _process_component_request(
# Re-load frontend context variables to deal with non-serializable properties
component_request.data = orjson.loads(component.get_frontend_context_variables())

# Get set of attributes that should be marked as `safe`
safe_fields = []
if hasattr(component, "Meta") and hasattr(component.Meta, "safe"):
if isinstance(component.Meta.safe, Sequence):
for field_name in component.Meta.safe:
if field_name in component._attributes().keys():
safe_fields.append(field_name)

# Mark safe attributes as such before rendering
for field_name in safe_fields:
value = getattr(component, field_name)
if isinstance(value, str):
setattr(component, field_name, mark_safe(value))

# Send back all available data for reset or refresh actions
updated_data = component_request.data

Expand Down
33 changes: 0 additions & 33 deletions tests/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,39 +82,6 @@ def test_get_frontend_context_variables(component):
assert frontend_context_variables_dict.get("name") == "World"


def test_get_frontend_context_variables_xss(component):
# Set component.name to a potential XSS attack
component.name = '<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>'

frontend_context_variables = component.get_frontend_context_variables()
frontend_context_variables_dict = orjson.loads(frontend_context_variables)
assert len(frontend_context_variables_dict) == 1
assert (
frontend_context_variables_dict.get("name")
== "&lt;a&gt;&lt;style&gt;@keyframes x{}&lt;/style&gt;&lt;a style=&quot;animation-name:x&quot; onanimationend=&quot;alert(1)&quot;&gt;&lt;/a&gt;"
)


def test_get_frontend_context_variables_safe(component):
# Set component.name to a potential XSS attack
component.name = '<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>'

class Meta:
safe = [
"name",
]

setattr(component, "Meta", Meta())

frontend_context_variables = component.get_frontend_context_variables()
frontend_context_variables_dict = orjson.loads(frontend_context_variables)
assert len(frontend_context_variables_dict) == 1
assert (
frontend_context_variables_dict.get("name")
== '<a><style>@keyframes x{}</style><a style="animation-name:x" onanimationend="alert(1)"></a>'
)


def test_get_context_data(component):
context_data = component.get_context_data()
assert (
Expand Down
16 changes: 15 additions & 1 deletion tests/components/test_unicorn_template_response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
from bs4 import BeautifulSoup

from django_unicorn.components.unicorn_template_response import get_root_element
from django_unicorn.components.unicorn_template_response import (
UnicornTemplateResponse,
get_root_element,
)


def test_get_root_element():
Expand Down Expand Up @@ -44,3 +47,14 @@ def test_get_root_element_no_element():
actual = get_root_element(soup)

assert str(actual) == expected


def test_desoupify():
html = "<div>&lt;a&gt;&lt;style&gt;@keyframes x{}&lt;/style&gt;&lt;a style=&quot;animation-name:x&quot; onanimationend=&quot;alert(1)&quot;&gt;&lt;/a&gt;!\n</div>\n\n<script type=\"application/javascript\">\n window.addEventListener('DOMContentLoaded', (event) => {\n Unicorn.addEventListener('updated', (component) => console.log('got updated', component));\n });\n</script>"
expected = "<div>&lt;a&gt;&lt;style&gt;@keyframes x{}&lt;/style&gt;&lt;a style=\"animation-name:x\" onanimationend=\"alert(1)\"&gt;&lt;/a&gt;!\n</div>\n<script type=\"application/javascript\">\n window.addEventListener('DOMContentLoaded', (event) => {\n Unicorn.addEventListener('updated', (component) => console.log('got updated', component));\n });\n</script>"

soup = BeautifulSoup(html, "html.parser")

actual = UnicornTemplateResponse._desoupify(soup)

assert expected == actual
2 changes: 1 addition & 1 deletion tests/templates/test_component_kwargs.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
<div>
->{{ hello }}<-
<b>{{ hello }}</b>
</div>
3 changes: 3 additions & 0 deletions tests/templates/test_component_kwargs_with_html_entity.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<div>
->{{ hello }}<-
</div>
3 changes: 3 additions & 0 deletions tests/templates/test_component_variable.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
<div>
{{ hello }}
</div>
25 changes: 23 additions & 2 deletions tests/templatetags/test_unicorn_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ def __init__(self, *args, **kwargs):
self.hello = kwargs.get("test_kwarg")


class FakeComponentKwargsWithHtmlEntity(UnicornView):
template_name = "templates/test_component_kwargs_with_html_entity.html"
hello = "world"

def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.hello = kwargs.get("test_kwarg")


class FakeComponentModel(UnicornView):
template_name = "templates/test_component_model.html"
model_id = None
Expand Down Expand Up @@ -55,7 +64,7 @@ def test_unicorn_render_kwarg():
context = {}
actual = unicorn_node.render(context)

assert "->tested!<-" in actual
assert "<b>tested!</b>" in actual


def test_unicorn_render_context_variable():
Expand All @@ -67,7 +76,19 @@ def test_unicorn_render_context_variable():
context = {"test_var": {"nested": "variable!"}}
actual = unicorn_node.render(context)

assert "->variable!<-" in actual
assert "<b>variable!</b>" in actual


def test_unicorn_render_with_invalid_html():
token = Token(
TokenType.TEXT,
"unicorn 'tests.templatetags.test_unicorn_render.FakeComponentKwargsWithHtmlEntity' test_kwarg=test_var.nested",
)
unicorn_node = unicorn(None, token)
context = {"test_var": {"nested": "variable!"}}
actual = unicorn_node.render(context)

assert "-&gt;variable!&lt;-" in actual


def test_unicorn_render_parent(settings):
Expand Down
2 changes: 0 additions & 2 deletions tests/views/message/test_sync_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import orjson

from tests.views.message.utils import post_and_get_response


Expand Down
51 changes: 51 additions & 0 deletions tests/views/test_process_component_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from django_unicorn.components import UnicornView
from tests.views.message.utils import post_and_get_response


class FakeComponent(UnicornView):
template_name = "templates/test_component_variable.html"

hello = ""


class FakeComponentSafe(UnicornView):
template_name = "templates/test_component_variable.html"

hello = ""

class Meta:
safe = ("hello",)


def test_html_entities_encoded(client):
data = {"hello": "test"}
action_queue = [
{"payload": {"name": "hello", "value": "<b>test1</b>"}, "type": "syncInput",}
]
response = post_and_get_response(
client,
url="/message/tests.views.test_process_component_request.FakeComponent",
data=data,
action_queue=action_queue,
)

assert not response["errors"]
assert response["data"].get("hello") == "<b>test1</b>"
assert "&lt;b&gt;test1&lt;/b&gt;" in response["dom"]


def test_safe_html_entities_not_encoded(client):
data = {"hello": "test"}
action_queue = [
{"payload": {"name": "hello", "value": "<b>test1</b>"}, "type": "syncInput",}
]
response = post_and_get_response(
client,
url="/message/tests.views.test_process_component_request.FakeComponentSafe",
data=data,
action_queue=action_queue,
)

assert not response["errors"]
assert response["data"].get("hello") == "<b>test1</b>"
assert "<b>test1</b>" in response["dom"]

0 comments on commit 3a832a9

Please sign in to comment.