Skip to content

Commit

Permalink
fix: add new config to allow for specific import data urls (#22942)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed Feb 6, 2023
1 parent 79114bc commit 7a0f350
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 6 deletions.
7 changes: 7 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,13 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Prevents unsafe default endpoints to be registered on datasets.
PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True

# Define a list of allowed URLs for dataset data imports (v1).
# Simple example to only allow URLs that belong to certain domains:
# ALLOWED_IMPORT_URL_DOMAINS = [
# r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"
# ]
DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
Expand Down
4 changes: 4 additions & 0 deletions superset/datasets/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,7 @@ class DatasetAccessDeniedError(ForbiddenError):

class DatasetDuplicateFailedError(CreateFailedError):
message = _("Dataset could not be duplicated.")


class DatasetForbiddenDataURI(ForbiddenError):
message = _("Data URI is not allowed.")
32 changes: 31 additions & 1 deletion superset/datasets/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlalchemy.sql.visitors import VisitableType

from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
from superset.models.core import Database

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,6 +76,28 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
}


def validate_data_uri(data_uri: str) -> None:
"""
Validate that the data URI is configured on DATASET_IMPORT_ALLOWED_URLS
has a valid URL.
:param data_uri:
:return:
"""
allowed_urls = current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"]
for allowed_url in allowed_urls:
try:
match = re.match(allowed_url, data_uri)
except re.error:
logger.exception(
"Invalid regular expression on DATASET_IMPORT_ALLOWED_URLS"
)
raise
if match:
return
raise DatasetForbiddenDataURI()


def import_dataset(
session: Session,
config: Dict[str, Any],
Expand Down Expand Up @@ -139,7 +162,6 @@ def import_dataset(
table_exists = True

if data_uri and (not table_exists or force_data):
logger.info("Downloading data from %s", data_uri)
load_data(data_uri, dataset, dataset.database, session)

if hasattr(g, "user") and g.user:
Expand All @@ -151,6 +173,14 @@ def import_dataset(
def load_data(
data_uri: str, dataset: SqlaTable, database: Database, session: Session
) -> None:
"""
Load data from a data URI into a dataset.
:raises DatasetUnAllowedDataURI: If a dataset is trying
to load data from a URI that is not allowed.
"""
validate_data_uri(data_uri)
logger.info("Downloading data from %s", data_uri)
data = request.urlopen(data_uri) # pylint: disable=consider-using-with
if data_uri.endswith(".gz"):
data = gzip.open(data)
Expand Down
128 changes: 123 additions & 5 deletions tests/unit_tests/datasets/commands/importers/v1/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@

import copy
import json
import re
import uuid
from typing import Any, Dict
from unittest.mock import Mock, patch

import pytest
from flask import current_app
from sqlalchemy.orm.session import Session

from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
from superset.datasets.commands.importers.v1.utils import validate_data_uri


def test_import_dataset(session: Session) -> None:
"""
Test importing a dataset.
"""
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database

engine = session.get_bind()
Expand Down Expand Up @@ -340,13 +346,85 @@ def test_import_column_extra_is_string(session: Session) -> None:
assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'


@patch("superset.datasets.commands.importers.v1.utils.request")
def test_import_column_allowed_data_url(request: Mock, session: Session) -> None:
"""
Test importing a dataset when using data key to fetch data from a URL.
"""
import io

from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database

request.urlopen.return_value = io.StringIO("col1\nvalue1\nvalue2\n")

engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member

database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()

dataset_uuid = uuid.uuid4()
yaml_config: Dict[str, Any] = {
"version": "1.0.0",
"table_name": "my_table",
"main_dttm_col": "ds",
"description": "This is the description",
"default_endpoint": None,
"offset": -8,
"cache_timeout": 3600,
"schema": None,
"sql": None,
"params": {
"remote_id": 64,
"database_name": "examples",
"import_time": 1606677834,
},
"template_params": None,
"filter_select_enabled": True,
"fetch_values_predicate": None,
"extra": None,
"uuid": dataset_uuid,
"metrics": [],
"columns": [
{
"column_name": "col1",
"verbose_name": None,
"is_dttm": False,
"is_active": True,
"type": "TEXT",
"groupby": False,
"filterable": False,
"expression": None,
"description": None,
"python_date_format": None,
"extra": None,
}
],
"database_uuid": database.uuid,
"data": "https://some-external-url.com/data.csv",
}

# the Marshmallow schema should convert strings to objects
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
_ = import_dataset(session, dataset_config, force_data=True)
session.connection()
assert [("value1",), ("value2",)] == session.execute(
"SELECT * FROM my_table"
).fetchall()


def test_import_dataset_managed_externally(session: Session) -> None:
"""
Test importing a dataset that is managed externally.
"""
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database
from tests.integration_tests.fixtures.importexport import dataset_config

Expand All @@ -357,7 +435,6 @@ def test_import_dataset_managed_externally(session: Session) -> None:
session.add(database)
session.flush()

dataset_uuid = uuid.uuid4()
config = copy.deepcopy(dataset_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_table"
Expand All @@ -366,3 +443,44 @@ def test_import_dataset_managed_externally(session: Session) -> None:
sqla_table = import_dataset(session, config)
assert sqla_table.is_managed_externally is True
assert sqla_table.external_url == "https://example.org/my_table"


@pytest.mark.parametrize(
"allowed_urls, data_uri, expected, exception_class",
[
([r".*"], "https://some-url/data.csv", True, None),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain1.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host2.domain1.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain2.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain3.com/data.csv",
False,
DatasetForbiddenDataURI,
),
([], "https://host1.domain3.com/data.csv", False, DatasetForbiddenDataURI),
(["*"], "https://host1.domain3.com/data.csv", False, re.error),
],
)
def test_validate_data_uri(allowed_urls, data_uri, expected, exception_class):
current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] = allowed_urls
if expected:
validate_data_uri(data_uri)
else:
with pytest.raises(exception_class):
validate_data_uri(data_uri)

0 comments on commit 7a0f350

Please sign in to comment.