Skip to content

Commit

Permalink
Fix MySQL time grain issue (#1590)
Browse files Browse the repository at this point in the history
* Fix MySql time grain issue

* linting

* linting
  • Loading branch information
mistercrunch committed Nov 15, 2016
1 parent 84b98c2 commit 99b0d4c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 66 deletions.
112 changes: 60 additions & 52 deletions superset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@
from superset.viz import viz_types
from superset.jinja_context import get_template_processor
from superset.utils import (
flasher, MetricPermException, DimSelector, wrap_clause_in_parens
flasher, MetricPermException, DimSelector, wrap_clause_in_parens,
DTTM_ALIAS,
)


config = app.config

QueryResult = namedtuple('namedtuple', ['df', 'query', 'duration'])
Expand Down Expand Up @@ -1013,61 +1015,36 @@ def query( # sqla

if granularity:

# TODO: sqlalchemy 1.2 release should be doing this on its own.
# Patch only if the column clause is specific for DateTime set and
# granularity is selected.
@compiles(ColumnClause)
def visit_column(element, compiler, **kw):
"""Patch for sqlalchemy bug
TODO: sqlalchemy 1.2 release should be doing this on its own.
Patch only if the column clause is specific for DateTime
set and granularity is selected.
"""
text = compiler.visit_column(element, **kw)
try:
if element.is_literal and hasattr(element.type, 'python_type') and \
type(element.type) is DateTime:

if (
element.is_literal and
hasattr(element.type, 'python_type') and
type(element.type) is DateTime
):
text = text.replace('%%', '%')
except NotImplementedError:
pass # Some elements raise NotImplementedError for python_type
# Some elements raise NotImplementedError for python_type
pass
return text

dttm_col = cols[granularity]
dttm_expr = dttm_col.sqla_col.label('timestamp')
timestamp = dttm_expr

# Transforming time grain into an expression based on configuration
time_grain_sqla = extras.get('time_grain_sqla')
if time_grain_sqla:
db_engine_spec = self.database.db_engine_spec
if dttm_col.python_date_format == 'epoch_s':
dttm_expr = \
db_engine_spec.epoch_to_dttm().format(col=dttm_expr)
elif dttm_col.python_date_format == 'epoch_ms':
dttm_expr = \
db_engine_spec.epoch_ms_to_dttm().format(col=dttm_expr)
udf = self.database.grains_dict().get(time_grain_sqla, '{col}')
timestamp_grain = literal_column(
udf.function.format(col=dttm_expr), type_=DateTime).label('timestamp')
else:
timestamp_grain = timestamp
time_grain = extras.get('time_grain_sqla')
timestamp = dttm_col.get_timestamp_expression(time_grain)

if is_timeseries:
select_exprs += [timestamp_grain]
groupby_exprs += [timestamp_grain]

outer_from = text(dttm_col.dttm_sql_literal(from_dttm))
outer_to = text(dttm_col.dttm_sql_literal(to_dttm))

time_filter = [
timestamp >= outer_from,
timestamp <= outer_to,
]
inner_time_filter = copy(time_filter)
if inner_from_dttm:
inner_time_filter[0] = timestamp >= text(
dttm_col.dttm_sql_literal(inner_from_dttm))
if inner_to_dttm:
inner_time_filter[1] = timestamp <= text(
dttm_col.dttm_sql_literal(inner_to_dttm))
else:
inner_time_filter = []
select_exprs += [timestamp]
groupby_exprs += [timestamp]

time_filter = dttm_col.get_time_filter(from_dttm, to_dttm)

select_exprs += metrics_exprs
qry = select(select_exprs)
Expand Down Expand Up @@ -1104,7 +1081,7 @@ def visit_column(element, compiler, **kw):
having_clause_and += [wrap_clause_in_parens(
template_processor.process_template(having))]
if granularity:
qry = qry.where(and_(*(time_filter + where_clause_and)))
qry = qry.where(and_(*([time_filter] + where_clause_and)))
else:
qry = qry.where(and_(*where_clause_and))
qry = qry.having(and_(*having_clause_and))
Expand All @@ -1123,7 +1100,11 @@ def visit_column(element, compiler, **kw):
inner_select_exprs += [main_metric_expr]
subq = select(inner_select_exprs)
subq = subq.select_from(tbl)
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
inner_time_filter = dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
)
subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
subq = subq.group_by(*inner_groupby_exprs)
ob = main_metric_expr
if timeseries_limit_metric_expr is not None:
Expand Down Expand Up @@ -1437,6 +1418,31 @@ def sqla_col(self):
col = literal_column(self.expression).label(name)
return col

def get_time_filter(self, start_dttm, end_dttm):
    """Return a SQLA predicate bounding this datetime column inclusively
    between *start_dttm* and *end_dttm*.

    Both bounds are rendered through ``dttm_sql_literal`` so the literal
    matches the column's configured date format.
    """
    time_col = self.sqla_col.label('__time')
    lower_bound = text(self.dttm_sql_literal(start_dttm))
    upper_bound = text(self.dttm_sql_literal(end_dttm))
    return and_(time_col >= lower_bound, time_col <= upper_bound)

def get_timestamp_expression(self, time_grain):
    """Return the SQLA expression for the time component of the query.

    :param time_grain: key into the database's grain dict (e.g. a
        truncation level); falsy means no grain is applied.
    :returns: a labeled column expression aliased to ``DTTM_ALIAS``.
    """
    expr = self.expression or self.column_name
    if not self.expression and not time_grain:
        # Plain physical column with no grain: reference it directly.
        return column(expr, type_=DateTime).label(DTTM_ALIAS)
    if time_grain:
        pdf = self.python_date_format
        if pdf in ('epoch_s', 'epoch_ms'):
            # if epoch, translate to DATE using db specific conf
            db_spec = self.table.database.db_engine_spec
            if pdf == 'epoch_s':
                expr = db_spec.epoch_to_dttm().format(col=expr)
            else:  # 'epoch_ms'
                expr = db_spec.epoch_ms_to_dttm().format(col=expr)
        grain = self.table.database.grains_dict().get(time_grain)
        # Bug fix: the previous default of the plain string '{col}' caused
        # an AttributeError on the .function access below whenever the
        # requested grain was not in the grains dict. Only apply the grain
        # wrapper when a real grain object was found.
        if grain:
            expr = grain.function.format(col=expr)
    return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

@classmethod
def import_obj(cls, column_to_import):
session = db.session
Expand Down Expand Up @@ -2070,19 +2076,21 @@ def recursive_get_fields(_conf):
query_str += json.dumps(
client.query_builder.last_query.query_dict, indent=2)
df = client.export_pandas()
df.columns = [
DTTM_ALIAS if c == 'timestamp' else c for c in df.columns]
if df is None or df.size == 0:
raise Exception(_("No data was returned."))

if (
not is_timeseries and
granularity == "all" and
'timestamp' in df.columns):
del df['timestamp']
DTTM_ALIAS in df.columns):
del df[DTTM_ALIAS]

# Reordering columns
cols = []
if 'timestamp' in df.columns:
cols += ['timestamp']
if DTTM_ALIAS in df.columns:
cols += [DTTM_ALIAS]
cols += [col for col in groupby if col in df.columns]
cols += [col for col in metrics if col in df.columns]
df = df[cols]
Expand All @@ -2093,7 +2101,7 @@ def increment_timestamp(ts):
dt = utils.parse_human_datetime(ts).replace(
tzinfo=config.get("DRUID_TZ"))
return dt + timedelta(milliseconds=time_offset)
if 'timestamp' in df.columns and time_offset:
if DTTM_ALIAS in df.columns and time_offset:
df.timestamp = df.timestamp.apply(increment_timestamp)

return QueryResult(
Expand Down
1 change: 1 addition & 0 deletions superset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@


EPOCH = datetime(1970, 1, 1)
DTTM_ALIAS = '__timestamp'


class SupersetException(Exception):
Expand Down
27 changes: 13 additions & 14 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from superset import app, utils, cache
from superset.forms import FormFactory
from superset.utils import flasher
from superset.utils import flasher, DTTM_ALIAS

config = app.config

Expand Down Expand Up @@ -177,15 +177,14 @@ def get_df(self, query_obj=None):
if df is None or df.empty:
raise utils.NoDataException("No data.")
else:
if 'timestamp' in df.columns:
if DTTM_ALIAS in df.columns:
if timestamp_format in ("epoch_s", "epoch_ms"):
df.timestamp = pd.to_datetime(
df.timestamp, utc=False)
df[DTTM_ALIAS] = pd.to_datetime(df[DTTM_ALIAS], utc=False)
else:
df.timestamp = pd.to_datetime(
df.timestamp, utc=False, format=timestamp_format)
df[DTTM_ALIAS] = pd.to_datetime(
df[DTTM_ALIAS], utc=False, format=timestamp_format)
if self.datasource.offset:
df.timestamp += timedelta(hours=self.datasource.offset)
df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset)
df.replace([np.inf, -np.inf], np.nan)
df = df.fillna(0)
return df
Expand Down Expand Up @@ -449,8 +448,8 @@ def get_df(self, query_obj=None):
df = super(TableViz, self).get_df(query_obj)
if (
self.form_data.get("granularity") == "all" and
'timestamp' in df):
del df['timestamp']
DTTM_ALIAS in df):
del df[DTTM_ALIAS]
return df

def get_data(self):
Expand Down Expand Up @@ -507,8 +506,8 @@ def get_df(self, query_obj=None):
df = super(PivotTableViz, self).get_df(query_obj)
if (
self.form_data.get("granularity") == "all" and
'timestamp' in df):
del df['timestamp']
DTTM_ALIAS in df):
del df[DTTM_ALIAS]
df = df.pivot_table(
index=self.form_data.get('groupby'),
columns=self.form_data.get('columns'),
Expand Down Expand Up @@ -1041,7 +1040,7 @@ def get_df(self, query_obj=None):
raise Exception("Pick a time granularity for your time series")

df = df.pivot_table(
index="timestamp",
index=DTTM_ALIAS,
columns=form_data.get('groupby'),
values=form_data.get('metrics'))

Expand Down Expand Up @@ -1108,7 +1107,7 @@ def to_series(self, df, classed='', title_suffix=''):
ys = series[name]
if df[name].dtype.kind not in "biufc":
continue
df['timestamp'] = pd.to_datetime(df.index, utc=False)
df[DTTM_ALIAS] = pd.to_datetime(df.index, utc=False)
if isinstance(name, string_types):
series_title = name
else:
Expand All @@ -1125,7 +1124,7 @@ def to_series(self, df, classed='', title_suffix=''):
"classed": classed,
"values": [
{'x': ds, 'y': ys[ds] if ds in ys else None}
for ds in df.timestamp
for ds in df[DTTM_ALIAS]
],
}
chart_data.append(d)
Expand Down

0 comments on commit 99b0d4c

Please sign in to comment.