Skip to content

Commit

Permalink
Sanitize the conn_id to disallow potential script execution (apache#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
andylamp authored and abhishekbhakat committed Mar 5, 2024
1 parent bdbbd68 commit 294dcb9
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 5 deletions.
37 changes: 35 additions & 2 deletions airflow/models/connection.py
Expand Up @@ -24,6 +24,7 @@
from typing import Any
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit

import re2
from sqlalchemy import Boolean, Column, Integer, String, Text
from sqlalchemy.orm import declared_attr, reconstructor, synonym

Expand All @@ -38,6 +39,13 @@
from airflow.utils.module_loading import import_string

log = logging.getLogger(__name__)
# sanitize the `conn_id` pattern by allowing alphanumeric characters plus
# the symbols #,!,-,_,.,:,\,/ and () requiring at least one match.
#
# You can try the regex here: https://regex101.com/r/69033B/1
RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")
# the conn ID max len should be 250
CONN_ID_MAX_LEN: int = 250


def parse_netloc_to_hostname(*args, **kwargs):
Expand All @@ -46,10 +54,35 @@ def parse_netloc_to_hostname(*args, **kwargs):
return _parse_netloc_to_hostname(*args, **kwargs)


def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
r"""Sanitizes the connection id and allows only specific characters to be within.
Namely, it allows alphanumeric characters plus the symbols #,!,-,_,.,:,\,/ and () from 1 and up to
250 consecutive matches. If desired, the max length can be adjusted by setting `max_length`.
You can try to play with the regex here: https://regex101.com/r/69033B/1
The character selection is such that it prevents the injection of javascript or
executable bits to avoid any awkward behaviour in the front-end.
:param conn_id: The connection id to sanitize.
:param max_length: The max length of the connection ID, by default it is 250.
:return: the sanitized string, `None` otherwise.
"""
# check if `conn_id` or our match group is `None` and the `conn_id` is within the specified length.
if (not isinstance(conn_id, str) or len(conn_id) > max_length) or (
res := re2.match(RE_SANITIZE_CONN_ID, conn_id)
) is None:
return None

# if we reach here, then we matched something, return the first match
return res.group(0)


# Python automatically converts all letters to lowercase in hostname
# See: https://issues.apache.org/jira/browse/AIRFLOW-3615
def _parse_netloc_to_hostname(uri_parts):
"""Parse a URI string to get correct Hostname."""
"""Parse a URI string to get the correct Hostname."""
hostname = unquote(uri_parts.hostname or "")
if "/" in hostname:
hostname = uri_parts.netloc
Expand Down Expand Up @@ -115,7 +148,7 @@ def __init__(
uri: str | None = None,
):
super().__init__()
self.conn_id = conn_id
self.conn_id = sanitize_conn_id(conn_id)
self.description = description
if extra and not isinstance(extra, str):
extra = json.dumps(extra)
Expand Down
4 changes: 2 additions & 2 deletions airflow/www/forms.py
Expand Up @@ -41,7 +41,7 @@
from airflow.providers_manager import ProvidersManager
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from airflow.www.validators import ReadOnly, ValidKey
from airflow.www.validators import ReadOnly, ValidConnID
from airflow.www.widgets import (
AirflowDateTimePickerROWidget,
AirflowDateTimePickerWidget,
Expand Down Expand Up @@ -221,7 +221,7 @@ def process(self, formdata=None, obj=None, **kwargs):

conn_id = StringField(
lazy_gettext("Connection Id"),
validators=[InputRequired(), ValidKey()],
validators=[InputRequired(), ValidConnID()],
widget=BS3TextFieldWidget(),
)
conn_type = SelectField(
Expand Down
28 changes: 27 additions & 1 deletion airflow/www/validators.py
Expand Up @@ -22,6 +22,7 @@

from wtforms.validators import EqualTo, ValidationError

from airflow.models.connection import CONN_ID_MAX_LEN, sanitize_conn_id
from airflow.utils import helpers


Expand Down Expand Up @@ -85,7 +86,7 @@ class ValidKey:
Validates values that will be used as keys.
:param max_length:
The maximum length of the given key
The maximum allowed length of the given key
"""

def __init__(self, max_length=200):
Expand All @@ -108,3 +109,28 @@ class ReadOnly:

def __call__(self, form, field):
field.flags.readonly = True


class ValidConnID:
"""
Validates the connection ID adheres to the desired format.
:param max_length:
The maximum allowed length of the given Connection ID.
"""

message = (
"Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, semicolons, "
"underscores, exclamation marks, and parentheses"
)

def __init__(
self,
max_length: int = CONN_ID_MAX_LEN,
):
self.max_length = max_length

def __call__(self, form, field):
if field.data:
if sanitize_conn_id(field.data, self.max_length) is None:
raise ValidationError(f"{self.message} for 1 and up to {self.max_length} matches")
64 changes: 64 additions & 0 deletions tests/models/test_connection.py
Expand Up @@ -186,3 +186,67 @@ def test_parse_from_uri(
)
def test_get_uri(self, connection, expected_uri):
assert connection.get_uri() == expected_uri

@pytest.mark.parametrize(
"connection, expected_conn_id",
[
# a valid example of connection id
(
Connection(
conn_id="12312312312213___12312321",
conn_type="type",
login="user",
password="pass",
host="host",
port=100,
schema="schema",
extra={"param1": "val1", "param2": "val2"},
),
"12312312312213___12312321",
),
# an invalid example of connection id, which allows potential code execution
(
Connection(
conn_id="<script>alert(1)</script>",
conn_type="type",
host="protocol://host",
port=100,
schema="schema",
extra={"param1": "val1", "param2": "val2"},
),
None,
),
# a valid connection as well
(
Connection(
conn_id="a_valid_conn_id_!!##",
conn_type="type",
login="user",
password="pass",
host="protocol://host",
port=100,
schema="schema",
extra={"param1": "val1", "param2": "val2"},
),
"a_valid_conn_id_!!##",
),
# a valid connection as well testing dashes
(
Connection(
conn_id="a_-.11",
conn_type="type",
login="user",
password="pass",
host="protocol://host",
port=100,
schema="schema",
extra={"param1": "val1", "param2": "val2"},
),
"a_-.11",
),
],
)
# Responsible for ensuring that the sanitized connection id
# string works as expected.
def test_sanitize_conn_id(self, connection, expected_conn_id):
assert connection.conn_id == expected_conn_id

0 comments on commit 294dcb9

Please sign in to comment.