Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add geospatial post processing operations #9661

Merged
merged 8 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ flask-talisman==0.7.0 # via apache-superset (setup.py)
flask-wtf==0.14.2 # via apache-superset (setup.py), flask-appbuilder
flask==1.1.1 # via apache-superset (setup.py), flask-appbuilder, flask-babel, flask-caching, flask-compress, flask-jwt-extended, flask-login, flask-migrate, flask-openid, flask-sqlalchemy, flask-wtf
geographiclib==1.50 # via geopy
geopy==1.20.0 # via apache-superset (setup.py)
geopy==1.21.0 # via apache-superset (setup.py)
gunicorn==20.0.4 # via apache-superset (setup.py)
humanize==0.5.1 # via apache-superset (setup.py)
importlib-metadata==1.4.0 # via jsonschema, kombu
Expand Down
84 changes: 82 additions & 2 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,23 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema
columns = fields.List(
fields.String(),
description="Columns which to select from the input data, in the desired "
"order. If columns are renamed, the old column name should be "
"order. If columns are renamed, the original column name should be "
"referenced here.",
example=["country", "gender", "age"],
required=False,
)
drop = fields.List(
villebro marked this conversation as resolved.
Show resolved Hide resolved
fields.String(),
description="Columns which to drop from selection.",
example=["my_temp_column"],
required=False,
)
rename = fields.List(
fields.Dict(),
description="columns which to rename, mapping source column to target column. "
"For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
example=[{"age": "average_age"}],
required=False,
)


Expand Down Expand Up @@ -335,12 +343,81 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
aggregates = ChartDataAggregateConfigField()


class ChartDataGeohashDecodeOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geohash decode operation config.
"""

geohash = fields.String(
description="Name of source column containing geohash string", required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
)


class ChartDataGeohashEncodeOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geohash encode operation config.
"""

latitude = fields.String(
description="Name of source latitude column", required=True,
)
longitude = fields.String(
description="Name of source longitude column", required=True,
)
geohash = fields.String(
description="Name of target column for encoded geohash string", required=True,
)


class ChartDataGeodeticParseOptionsSchema(
ChartDataPostProcessingOperationOptionsSchema
):
"""
Geodetic point string parsing operation config.
"""

geodetic = fields.String(
description="Name of source column containing geodetic point strings",
required=True,
)
latitude = fields.String(
description="Name of target column for decoded latitude", required=True,
)
longitude = fields.String(
description="Name of target column for decoded longitude", required=True,
)
altitude = fields.String(
description="Name of target column for decoded longitude. If omitted, "
villebro marked this conversation as resolved.
Show resolved Hide resolved
"altitude information in geodetic string is ignored.",
required=False,
)


class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
validate=validate.OneOf(
choices=("aggregate", "pivot", "rolling", "select", "sort")
choices=(
"aggregate",
"geodetic_parse",
"geohash_decode",
"geohash_encode",
"pivot",
"rolling",
"select",
"sort",
)
),
example="aggregate",
)
Expand Down Expand Up @@ -638,4 +715,7 @@ class ChartDataResponseSchema(Schema):
ChartDataRollingOptionsSchema,
ChartDataSelectOptionsSchema,
ChartDataSortOptionsSchema,
ChartDataGeohashDecodeOptionsSchema,
ChartDataGeohashEncodeOptionsSchema,
ChartDataGeodeticParseOptionsSchema,
)
119 changes: 109 additions & 10 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import geohash as geohash_lib
import numpy as np
from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame, NamedAgg

from superset.exceptions import QueryObjectValidationError
Expand Down Expand Up @@ -144,10 +146,7 @@ def _append_columns(
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{
target: append_df[append_df.columns[idx]]
for idx, target in enumerate(columns.values())
}
**{target: append_df[source] for source, target in columns.items()}
)


Expand Down Expand Up @@ -325,23 +324,32 @@ def rolling( # pylint: disable=too-many-arguments

@validate_column_args("columns", "rename")
def select(
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
df: DataFrame,
columns: Optional[List[str]] = None,
drop: Optional[List[str]] = None,
villebro marked this conversation as resolved.
Show resolved Hide resolved
rename: Optional[Dict[str, str]] = None,
) -> DataFrame:
"""
Only select a subset of columns in the original dataset. Can be useful for
removing unnecessary intermediate results, renaming and reordering columns.

:param df: DataFrame on which the rolling period will be based.
:param columns: Columns which to select from the DataFrame, in the desired order.
If columns are renamed, the old column name should be referenced
here.
If left undefined, all columns will be selected. If columns are
renamed, the original column name should be referenced here.
:param drop: columns to drop from selection. If columns are renamed, the new
column name should be referenced here.
:param rename: columns which to rename, mapping source column to target column.
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises ChartDataValidationError: If the request in incorrect
"""
df_select = df[columns]
df_select = df.copy(deep=False)
if columns:
df_select = df_select[columns]
if drop:
df_select = df_select.drop(drop, axis=1)
if rename is not None:
df_select = df_select.rename(columns=rename)
return df_select
Expand All @@ -350,6 +358,7 @@ def select(
@validate_column_args("columns")
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
"""
Calculate row-by-row difference for select columns.

:param df: DataFrame on which the diff will be based.
:param columns: columns on which to perform diff, mapping source column to
Expand All @@ -369,6 +378,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.

:param df: DataFrame on which the cumulative operation will be based.
:param columns: columns on which to perform a cumulative operation, mapping source
Expand All @@ -377,7 +387,7 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:return:
:return: DataFrame with cumulated columns
"""
df_cum = df[columns.keys()]
operation = "cum" + operator
Expand All @@ -388,3 +398,92 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)


def geohash_decode(
df: DataFrame, geohash: str, longitude: str, latitude: str
) -> DataFrame:
"""
Decode a geohash column into longitude and latitude

:param df: DataFrame containing geohash data
:param geohash: Name of source column containing geohash location.
:param longitude: Name of new column to be created containing longitude.
:param latitude: Name of new column to be created containing latitude.
:return: DataFrame with decoded longitudes and latitudes
"""
try:
lonlat_df = DataFrame()
lonlat_df["latitude"], lonlat_df["longitude"] = zip(
*df[geohash].apply(geohash_lib.decode)
)
return _append_columns(
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
)
except ValueError:
raise QueryObjectValidationError(_("Invalid geohash string"))


def geohash_encode(
df: DataFrame, geohash: str, longitude: str, latitude: str,
) -> DataFrame:
"""
Encode longitude and latitude into geohash

:param df: DataFrame containing longitude and latitude data
:param geohash: Name of new column to be created containing geohash location.
:param longitude: Name of source column containing longitude.
:param latitude: Name of source column containing latitude.
:return: DataFrame with decoded longitudes and latitudes
"""
try:
encode_df = df[[latitude, longitude]]
encode_df.columns = ["latitude", "longitude"]
encode_df["geohash"] = encode_df.apply(
lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1,
)
return _append_columns(df, encode_df, {"geohash": geohash})
except ValueError:
QueryObjectValidationError(_("Invalid longitude/latitude"))


def geodetic_parse(
df: DataFrame,
geodetic: str,
longitude: str,
latitude: str,
altitude: Optional[str] = None,
) -> DataFrame:
"""
Parse a column containing a geodetic point string
[Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point).

:param df: DataFrame containing geodetic point data
:param geodetic: Name of source column containing geodetic point string.
:param longitude: Name of new column to be created containing longitude.
:param latitude: Name of new column to be created containing latitude.
:param altitude: Name of new column to be created containing altitude.
:return: DataFrame with decoded longitudes and latitudes
"""

def _parse_location(location: str) -> Tuple[float, float, float]:
"""
Parse a string containing a geodetic point and return latitude, longitude
and altitude
"""
point = Point(location) # type: ignore
return point[0], point[1], point[2]

try:
geodetic_df = DataFrame()
(
geodetic_df["latitude"],
geodetic_df["longitude"],
geodetic_df["altitude"],
) = zip(*df[geodetic].apply(_parse_location))
villebro marked this conversation as resolved.
Show resolved Hide resolved
columns = {"latitude": longitude, "longitude": latitude}
if altitude:
columns["altitude"] = altitude
return _append_columns(df, geodetic_df, columns)
except ValueError:
raise QueryObjectValidationError(_("Invalid geodetic string"))
14 changes: 14 additions & 0 deletions tests/fixtures/dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,17 @@
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)

lonlat_df = DataFrame(
{
"city": ["New York City", "Sydney"],
"geohash": ["dr5regw3pg6f", "r3gx2u9qdevk"],
"latitude": [40.71277496, -33.85598011],
"longitude": [-74.00597306, 151.20666526],
"altitude": [5.5, 0.012],
"geodetic": [
"40.71277496, -74.00597306, 5.5km",
"-33.85598011, 151.20666526, 12m",
],
}
)