Skip to content

Commit

Permalink
Fixed finding postaggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
Mogball committed Dec 6, 2017
1 parent defe678 commit ac2cb53
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 148 deletions.
174 changes: 112 additions & 62 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,73 +786,123 @@ def granularity(period_name, timezone=None, origin=None):
return granularity

@staticmethod
def _metrics_and_post_aggs(metrics, metrics_dict):
all_metrics = []
post_aggs = {}

def recursive_get_fields(_conf):
_type = _conf.get('type')
_field = _conf.get('field')
_fields = _conf.get('fields')

field_names = []
if _type in ['fieldAccess', 'hyperUniqueCardinality',
'quantile', 'quantiles']:
field_names.append(_conf.get('fieldName', ''))
def get_post_agg(mconf):
    """
    Build the pydruid post aggregator described by a `postagg` metric.

    `mconf` is the metric's JSON configuration; its `type` entry selects
    which post aggregator class is instantiated. Types that are not
    handled explicitly are wrapped in a `CustomPostAggregator` that
    receives the whole configuration.
    """
    kind = mconf.get('type')
    if kind == 'javascript':
        return JavascriptPostAggregator(
            name=mconf.get('name', ''),
            field_names=mconf.get('fieldNames', []),
            function=mconf.get('function', ''))
    if kind == 'quantile':
        return Quantile(
            mconf.get('name', ''),
            mconf.get('probability', ''),
        )
    if kind == 'quantiles':
        return Quantiles(
            mconf.get('name', ''),
            mconf.get('probabilities', ''),
        )
    if kind == 'fieldAccess':
        return Field(mconf.get('name'))
    if kind == 'constant':
        return Const(
            mconf.get('value'),
            output_name=mconf.get('name', ''),
        )
    if kind == 'hyperUniqueCardinality':
        return HyperUniqueCardinality(
            mconf.get('name'),
        )
    if kind == 'arithmetic':
        return Postaggregator(
            mconf.get('fn', '/'),
            mconf.get('fields', []),
            mconf.get('name', ''))
    # Anything else is passed through verbatim as a custom post aggregator
    return CustomPostAggregator(
        mconf.get('name', ''),
        mconf)

if _field:
field_names += recursive_get_fields(_field)
@staticmethod
def find_postaggs_for(postagg_names, metrics_dict):
    """Return the metrics named in `postagg_names` that are post aggregations.

    `postagg_names` is a mutable collection of names (a list or a set)
    and is modified in place: every name that resolves to a metric whose
    `metric_type` is 'postagg' is removed from it, so on return it only
    holds the names that are not post aggregations.

    Names that have no entry in `metrics_dict` are left in `postagg_names`
    instead of raising `KeyError`; deciding what to do with them is up to
    the caller.

    :param postagg_names: mutable collection of candidate metric names
    :param metrics_dict: mapping of metric name to metric object
    :return: list of the metric objects that are post aggregations
    """
    postagg_metrics = []
    # Iterate over a snapshot because names are removed while scanning
    for name in list(postagg_names):
        metric = metrics_dict.get(name)
        if metric is not None and metric.metric_type == 'postagg':
            postagg_metrics.append(metric)
            # Remove post aggregations that were found
            postagg_names.remove(name)
    return postagg_metrics

if _fields:
for _f in _fields:
field_names += recursive_get_fields(_f)
@staticmethod
def recursive_get_fields(_conf):
    """Collect the distinct field names referenced by a post aggregation.

    Walks `_conf` and the configurations nested under its `field` and
    `fields` entries. For every configuration whose `type` is one of
    'fieldAccess', 'hyperUniqueCardinality', 'quantile' or 'quantiles',
    its `fieldName` is recorded (an empty string when the key is absent).

    Returns a de-duplicated list; its order is not significant.
    """
    leaf_types = ('fieldAccess', 'hyperUniqueCardinality',
                  'quantile', 'quantiles')
    names = []
    if _conf.get('type') in leaf_types:
        names.append(_conf.get('fieldName', ''))
    # A single nested configuration
    nested = _conf.get('field')
    if nested:
        names.extend(DruidDatasource.recursive_get_fields(nested))
    # A list of nested configurations
    for child in _conf.get('fields') or ():
        names.extend(DruidDatasource.recursive_get_fields(child))
    return list(set(names))

return list(set(field_names))
@staticmethod
def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict):
    """Register `postagg` and everything it depends on.

    The fields required by the post aggregation are taken from its JSON
    configuration (nested field references plus `fieldNames`). Required
    names that are neither already-known aggregations nor already-visited
    post aggregations are looked up in `metrics_dict`: those that are post
    aggregations are resolved recursively first, the remaining ones are
    added to `agg_names`. Finally the post aggregation itself is stored in
    `post_aggs` under its metric name.

    `post_aggs`, `agg_names` and `visited_postaggs` are updated in place;
    nothing is returned.
    """
    conf = postagg.json_obj
    candidates = set(
        DruidDatasource.recursive_get_fields(conf)
        + conf.get('fieldNames', []))
    # Keep only the fields that are not yet covered by an aggregation
    # or by a post aggregation seen earlier
    unresolved = {
        name for name in candidates
        if not (name in visited_postaggs or name in agg_names)
    }
    if unresolved:
        # find_postaggs_for strips the post aggregations out of
        # `unresolved`, leaving the plain aggregation names behind
        dependencies = DruidDatasource.find_postaggs_for(
            unresolved, metrics_dict)
        agg_names.update(unresolved)
        # Mark every dependency as visited before recursing so that
        # cyclically dependent post aggregations cannot recurse forever
        visited_postaggs.update(
            dependency.metric_name for dependency in dependencies)
        for dependency in dependencies:
            DruidDatasource.resolve_postagg(
                dependency, post_aggs, agg_names,
                visited_postaggs, metrics_dict)
    post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(
        postagg.json_obj)

@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict):
# Separate metrics into those that are aggregations
# and those that are post aggregations
agg_names = set()
postagg_names = []
for metric_name in metrics:
metric = metrics_dict[metric_name]
if metric.metric_type != 'postagg':
all_metrics.append(metric_name)
if metrics_dict[metric_name].metric_type != 'postagg':
agg_names.add(metric_name)
else:
mconf = metric.json_obj
all_metrics += recursive_get_fields(mconf)
all_metrics += mconf.get('fieldNames', [])
if mconf.get('type') == 'javascript':
post_aggs[metric_name] = JavascriptPostAggregator(
name=mconf.get('name', ''),
field_names=mconf.get('fieldNames', []),
function=mconf.get('function', ''))
elif mconf.get('type') == 'quantile':
post_aggs[metric_name] = Quantile(
mconf.get('name', ''),
mconf.get('probability', ''),
)
elif mconf.get('type') == 'quantiles':
post_aggs[metric_name] = Quantiles(
mconf.get('name', ''),
mconf.get('probabilities', ''),
)
elif mconf.get('type') == 'fieldAccess':
post_aggs[metric_name] = Field(mconf.get('name'))
elif mconf.get('type') == 'constant':
post_aggs[metric_name] = Const(
mconf.get('value'),
output_name=mconf.get('name', ''),
)
elif mconf.get('type') == 'hyperUniqueCardinality':
post_aggs[metric_name] = HyperUniqueCardinality(
mconf.get('name'),
)
elif mconf.get('type') == 'arithmetic':
post_aggs[metric_name] = Postaggregator(
mconf.get('fn', '/'),
mconf.get('fields', []),
mconf.get('name', ''))
else:
post_aggs[metric_name] = CustomPostAggregator(
mconf.get('name', ''),
mconf)
return all_metrics, post_aggs
postagg_names.append(metric_name)
# Create the post aggregations, maintain order since postaggs
# may depend on previous ones
post_aggs = OrderedDict()
visited_postaggs = set()
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
return list(agg_names), post_aggs

def values_for_column(self,
column_name,
Expand Down Expand Up @@ -940,7 +990,7 @@ def run_query( # noqa / druid

columns_dict = {c.column_name: c for c in self.columns}

all_metrics, post_aggs = self._metrics_and_post_aggs(
all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
metrics,
metrics_dict)

Expand Down

0 comments on commit ac2cb53

Please sign in to comment.