diff --git a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts index 55721500acf6..f1adec9e4488 100644 --- a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts +++ b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts @@ -128,7 +128,9 @@ describe('Test datatable', () => { cy.get('.ant-empty-description').should('not.exist'); }); it('Datapane loads view samples', () => { - cy.intercept('/api/v1/dataset/*/samples?force=false').as('Samples'); + cy.intercept( + 'api/v1/explore/samples?force=false&datasource_type=table&datasource_id=*', + ).as('Samples'); cy.contains('Samples') .click() .then(() => { diff --git a/superset-frontend/packages/superset-ui-core/src/query/DatasourceKey.ts b/superset-frontend/packages/superset-ui-core/src/query/DatasourceKey.ts index 2fe4bcf13905..38a38e10b13a 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/DatasourceKey.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/DatasourceKey.ts @@ -27,8 +27,8 @@ export default class DatasourceKey { constructor(key: string) { const [idStr, typeStr] = key.split('__'); this.id = parseInt(idStr, 10); - this.type = - typeStr === 'table' ? DatasourceType.Table : DatasourceType.Druid; + this.type = DatasourceType.Table; // default to SqlaTable model + this.type = typeStr === 'query' ? DatasourceType.Query : this.type; } public toString() { diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Column.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Column.ts index 384d876d810b..c2b35f46c41d 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/Column.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/Column.ts @@ -38,7 +38,7 @@ export type PhysicalColumn = string; * Column information defined in datasource. 
*/ export interface Column { - id: number; + id?: number; type?: string; type_generic?: GenericDataType; column_name: string; diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Datasource.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Datasource.ts index 03916dee5ebb..e53a8d05ff65 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/Datasource.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/Datasource.ts @@ -21,7 +21,6 @@ import { Metric } from './Metric'; export enum DatasourceType { Table = 'table', - Druid = 'druid', Query = 'query', Dataset = 'dataset', SlTable = 'sl_table', @@ -47,11 +46,14 @@ export interface Datasource { }; } -export const DEFAULT_METRICS = [ +export const DEFAULT_METRICS: Metric[] = [ { metric_name: 'COUNT(*)', expression: 'COUNT(*)', }, ]; +export const isValidDatasourceType = (datasource: DatasourceType) => + Object.values(DatasourceType).includes(datasource); + export default {}; diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts index 7ec5abab9e4d..9105e5b9c386 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts @@ -334,6 +334,7 @@ export type QueryResults = { expanded_columns: QueryColumn[]; selected_columns: QueryColumn[]; query: { limit: number }; + query_id?: number; }; }; diff --git a/superset-frontend/packages/superset-ui-core/test/query/DatasourceKey.test.ts b/superset-frontend/packages/superset-ui-core/test/query/DatasourceKey.test.ts index 6b1d62e6aa13..d7629bc1b1ca 100644 --- a/superset-frontend/packages/superset-ui-core/test/query/DatasourceKey.test.ts +++ b/superset-frontend/packages/superset-ui-core/test/query/DatasourceKey.test.ts @@ -20,17 +20,10 @@ import { DatasourceKey } from '@superset-ui/core'; describe('DatasourceKey', () => { const tableKey = '5__table'; - const druidKey = '5__druid'; it('should handle table data sources', () => { const datasourceKey = new DatasourceKey(tableKey); expect(datasourceKey.toString()).toBe(tableKey); expect(datasourceKey.toObject()).toEqual({ id: 5, type: 'table' }); }); - - it('should handle druid data sources', () => { - const datasourceKey = new DatasourceKey(druidKey); - expect(datasourceKey.toString()).toBe(druidKey); - expect(datasourceKey.toObject()).toEqual({ id: 5, type: 'druid' }); - }); }); diff --git a/superset-frontend/packages/superset-ui-core/test/query/buildQueryContext.test.ts b/superset-frontend/packages/superset-ui-core/test/query/buildQueryContext.test.ts index 366feeff7a2e..baae438321ae 100644 --- a/superset-frontend/packages/superset-ui-core/test/query/buildQueryContext.test.ts +++ b/superset-frontend/packages/superset-ui-core/test/query/buildQueryContext.test.ts @@ -31,17 +31,6 @@ describe('buildQueryContext', () => { expect(queryContext.result_format).toBe('json'); expect(queryContext.result_type).toBe('full'); }); - it('should build datasource for druid sources and set force to true', () => { - const queryContext = buildQueryContext({ - datasource: '5__druid', - granularity: 'ds', - viz_type: 'table', - force: true, - }); - expect(queryContext.datasource.id).toBe(5); - expect(queryContext.datasource.type).toBe('druid'); - expect(queryContext.force).toBe(true); - }); it('should build datasource for table sources with columns', () => { const queryContext = buildQueryContext( { diff --git 
a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index 339303ba5792..b9b60c898ec3 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -23,7 +23,11 @@ import Button from 'src/components/Button'; import shortid from 'shortid'; import { styled, t, QueryResponse } from '@superset-ui/core'; import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace'; -import { SaveDatasetModal } from 'src/SqlLab/components/SaveDatasetModal'; +import { + ISaveableDatasource, + ISimpleColumn, + SaveDatasetModal, +} from 'src/SqlLab/components/SaveDatasetModal'; import { UserWithPermissionsAndRoles } from 'src/types/bootstrapTypes'; import ProgressBar from 'src/components/ProgressBar'; import Loading from 'src/components/Loading'; @@ -220,6 +224,15 @@ export default class ResultSet extends React.PureComponent< const { showSaveDatasetModal } = this.state; const { query } = this.props; + const datasource: ISaveableDatasource = { + columns: query.columns as ISimpleColumn[], + name: query?.tab || 'Untitled', + dbId: 1, + sql: query.sql, + templateParams: query.templateParams, + schema: query.schema, + }; + return ( {this.props.visualize && this.props.database?.allows_virtual_table_explore && ( this.setState({ showSaveDatasetModal: true })} + onClick={() => { + // There is currently a redux / state issue where sometimes a query will have serverId + // and other times it will not. We need this attribute consistently for this to work + // const qid = this.props?.query?.results?.query_id; + // if (qid) { + // // This will open explore using the query as datasource + // window.location.href = `/explore/?dataset_type=query&dataset_id=${qid}`; + // } else { + // this.setState({ showSaveDatasetModal: true }); + // } + this.setState({ showSaveDatasetModal: true }); + }} /> )} {this.props.csv && ( diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx index d0698e3edf95..b7f3ad8a8a42 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx @@ -17,16 +17,42 @@ * under the License. 
*/ import React from 'react'; -import { QueryResponse, testQuery } from '@superset-ui/core'; -import { SaveDatasetModal } from 'src/SqlLab/components/SaveDatasetModal'; +import { + ISaveableDatasource, + SaveDatasetModal, +} from 'src/SqlLab/components/SaveDatasetModal'; import { render, screen } from 'spec/helpers/testing-library'; +import { DatasourceType } from '@superset-ui/core'; + +const testQuery: ISaveableDatasource = { + name: 'unimportant', + dbId: 1, + sql: 'SELECT *', + columns: [ + { + name: 'Column 1', + type: DatasourceType.Query, + is_dttm: false, + }, + { + name: 'Column 3', + type: DatasourceType.Query, + is_dttm: false, + }, + { + name: 'Column 2', + type: DatasourceType.Query, + is_dttm: true, + }, + ], +}; const mockedProps = { visible: true, onHide: () => {}, buttonTextOnSave: 'Save', buttonTextOnOverwrite: 'Overwrite', - datasource: testQuery as QueryResponse, + datasource: testQuery, }; describe('SaveDatasetModal RTL', () => { @@ -36,6 +62,7 @@ describe('SaveDatasetModal RTL', () => { const saveRadioBtn = screen.getByRole('radio', { name: /save as new unimportant/i, }); + const fieldLabel = screen.getByText(/save as new/i); const inputField = screen.getByRole('textbox'); const inputFieldText = screen.getByDisplayValue(/unimportant/i); diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index 0fa28095cc4f..1189740d63fd 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -30,6 +30,7 @@ import { JsonResponse, JsonObject, QueryResponse, + QueryFormData, } from '@superset-ui/core'; import { useSelector, useDispatch } from 'react-redux'; import moment from 'moment'; @@ -43,14 +44,36 @@ import { DatasetOwner, SqlLabExploreRootState, getInitialState, - ExploreDatasource, SqlLabRootState, } from 'src/SqlLab/types'; import { mountExploreUrl } from 'src/explore/exploreUtils'; import { postFormData } from 'src/explore/exploreUtils/formData'; import { URL_PARAMS } from 'src/constants'; import { SelectValue } from 'antd/lib/select'; -import { isEmpty } from 'lodash'; +import { isEmpty, isString } from 'lodash'; + +interface QueryDatabase { + id?: number; +} + +export type ExploreQuery = QueryResponse & { + database?: QueryDatabase | null | undefined; +}; + +export interface ISimpleColumn { + name?: string | null; + type?: string | null; + is_dttm?: boolean | null; +} + +export interface ISaveableDatasource { + columns: ISimpleColumn[]; + name: string; + dbId: number; + sql: string; + templateParams?: string | object | null; + schema?: string | null; +} interface SaveDatasetModalProps { visible: boolean; @@ -58,7 +81,9 @@ interface SaveDatasetModalProps { buttonTextOnSave: string; buttonTextOnOverwrite: string; modalDescription?: string; - datasource: ExploreDatasource; + datasource: ISaveableDatasource; + openWindow?: boolean; + formData?: Omit; } const Styles = styled.div` @@ -113,6 +138,8 @@ const updateDataset = async ( return data.json.result; }; +const UNTITLED = t('Untitled Dataset'); + // eslint-disable-next-line no-empty-pattern export const SaveDatasetModal: FunctionComponent = ({ visible, @@ -121,13 +148,15 @@ export const SaveDatasetModal: FunctionComponent = ({ buttonTextOnOverwrite, modalDescription, datasource, + openWindow = true, + formData = {}, }) => { const defaultVizType = useSelector( state => state.common?.conf?.DEFAULT_VIZ_TYPE || 'table', ); - const query = 
datasource as QueryResponse; + const getDefaultDatasetName = () => - `${query.tab} ${moment().format('MM/DD/YYYY HH:mm:ss')}`; + `${datasource?.name || UNTITLED} ${moment().format('MM/DD/YYYY HH:mm:ss')}`; const [datasetName, setDatasetName] = useState(getDefaultDatasetName()); const [newOrOverwrite, setNewOrOverwrite] = useState( DatasetRadioState.SAVE_NEW, @@ -145,29 +174,38 @@ export const SaveDatasetModal: FunctionComponent = ({ ); const dispatch = useDispatch<(dispatch: any) => Promise>(); + const createWindow = (url: string) => { + if (openWindow) { + window.open(url, '_blank', 'noreferrer'); + } else { + window.location.href = url; + } + }; + const formDataWithDefaults = { + ...EXPLORE_CHART_DEFAULT, + ...(formData || {}), + }; const handleOverwriteDataset = async () => { const [, key] = await Promise.all([ updateDataset( - query.dbId, - datasetToOverwrite.datasetid, - query.sql, - query.results.selected_columns.map( + datasource?.dbId, + datasetToOverwrite?.datasetid, + datasource?.sql, + datasource?.columns?.map( (d: { name: string; type: string; is_dttm: boolean }) => ({ column_name: d.name, type: d.type, is_dttm: d.is_dttm, }), ), - datasetToOverwrite.owners?.map((o: DatasetOwner) => o.id), + datasetToOverwrite?.owners?.map((o: DatasetOwner) => o.id), true, ), postFormData(datasetToOverwrite.datasetid, 'table', { - ...EXPLORE_CHART_DEFAULT, + ...formDataWithDefaults, datasource: `${datasetToOverwrite.datasetid}__table`, ...(defaultVizType === 'table' && { - all_columns: query.results.selected_columns.map( - column => column.name, - ), + all_columns: datasource?.columns?.map(column => column.name), }), }), ]); @@ -175,7 +213,7 @@ export const SaveDatasetModal: FunctionComponent = ({ const url = mountExploreUrl(null, { [URL_PARAMS.formDataKey.name]: key, }); - window.open(url, '_blank', 'noreferrer'); + createWindow(url); setShouldOverwriteDataset(false); setDatasetName(getDefaultDatasetName()); @@ -225,35 +263,36 @@ export const SaveDatasetModal: FunctionComponent = ({ return; } - const selectedColumns = query.results.selected_columns || []; + const selectedColumns = datasource?.columns ?? []; // The filters param is only used to test jinja templates. // Remove the special filters entry from the templateParams // before saving the dataset. 
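// Illustration only (hypothetical values, not part of the patch): a templateParams
// string of '{"_filters": [{"col": "gender"}], "limit": 100}' would be re-serialized
// as '{"limit": 100}' by the code below, since _filters exists solely for testing
// jinja templates and should never be persisted with the dataset.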
- if (query.templateParams) { - const p = JSON.parse(query.templateParams); + let templateParams; + if (isString(datasource?.templateParams)) { + const p = JSON.parse(datasource.templateParams); /* eslint-disable-next-line no-underscore-dangle */ if (p._filters) { /* eslint-disable-next-line no-underscore-dangle */ delete p._filters; // eslint-disable-next-line no-param-reassign - query.templateParams = JSON.stringify(p); + templateParams = JSON.stringify(p); } } dispatch( createDatasource({ - schema: query.schema, - sql: query.sql, - dbId: query.dbId, - templateParams: query.templateParams, + schema: datasource.schema, + sql: datasource.sql, + dbId: datasource.dbId, + templateParams, datasourceName: datasetName, columns: selectedColumns, }), ) .then((data: { table_id: number }) => postFormData(data.table_id, 'table', { - ...EXPLORE_CHART_DEFAULT, + ...formDataWithDefaults, datasource: `${data.table_id}__table`, ...(defaultVizType === 'table' && { all_columns: selectedColumns.map(column => column.name), @@ -264,7 +303,7 @@ export const SaveDatasetModal: FunctionComponent = ({ const url = mountExploreUrl(null, { [URL_PARAMS.formDataKey.name]: key, }); - window.open(url, '_blank', 'noreferrer'); + createWindow(url); }) .catch(() => { addDangerToast(t('An error occurred saving dataset')); diff --git a/superset-frontend/src/SqlLab/components/SaveQuery/index.tsx b/superset-frontend/src/SqlLab/components/SaveQuery/index.tsx index 2b57ce0cdcc5..77d3b04e88d0 100644 --- a/superset-frontend/src/SqlLab/components/SaveQuery/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveQuery/index.tsx @@ -173,7 +173,7 @@ export default function SaveQuery({ width="620px" show={showSave} title={

{t('Save query')}

} - footer={[ + footer={ <> )} - , - ]} + + } > {renderModalBody()} diff --git a/superset-frontend/src/SqlLab/components/ScheduleQueryButton/index.tsx b/superset-frontend/src/SqlLab/components/ScheduleQueryButton/index.tsx index 8da7ab7e663d..43437d026a10 100644 --- a/superset-frontend/src/SqlLab/components/ScheduleQueryButton/index.tsx +++ b/superset-frontend/src/SqlLab/components/ScheduleQueryButton/index.tsx @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -import React, { FunctionComponent, useState } from 'react'; +import React, { FunctionComponent, useState, useRef } from 'react'; import SchemaForm, { FormProps, FormValidation } from 'react-jsonschema-form'; import { Row, Col } from 'src/components'; import { Input, TextArea } from 'src/components/Input'; import { t, styled } from '@superset-ui/core'; import * as chrono from 'chrono-node'; -import ModalTrigger from 'src/components/ModalTrigger'; +import ModalTrigger, { ModalTriggerRef } from 'src/components/ModalTrigger'; import { Form, FormItem } from 'src/components/Form'; import Button from 'src/components/Button'; @@ -143,7 +143,7 @@ const ScheduleQueryButton: FunctionComponent = ({ const [description, setDescription] = useState(''); const [label, setLabel] = useState(defaultLabel); const [showSchedule, setShowSchedule] = useState(false); - let saveModal: ModalTrigger | null; + const saveModal: ModalTriggerRef | null = useRef() as ModalTriggerRef; const onScheduleSubmit = ({ formData, @@ -159,7 +159,7 @@ const ScheduleQueryButton: FunctionComponent = ({ extra_json: JSON.stringify({ schedule_info: formData }), }; onSchedule(query); - saveModal?.close(); + saveModal?.current?.close(); }; const renderModalBody = () => ( @@ -225,9 +225,7 @@ const ScheduleQueryButton: FunctionComponent = ({ return ( { - saveModal = ref; - }} + ref={saveModal} modalTitle={t('Schedule query')} modalBody={renderModalBody()} triggerNode={ diff --git a/superset-frontend/src/SqlLab/components/TableElement/index.tsx b/superset-frontend/src/SqlLab/components/TableElement/index.tsx index 6e5b18d28c4d..7dea2bd8340c 100644 --- a/superset-frontend/src/SqlLab/components/TableElement/index.tsx +++ b/superset-frontend/src/SqlLab/components/TableElement/index.tsx @@ -152,11 +152,7 @@ const TableElement = ({ table, actions, ...props }: TableElementProps) => { if (table?.indexes?.length) { keyLink = ( - {t('Keys for table')} {table.name} - - } + modalTitle={`${t('Keys for table')} ${table.name}`} modalBody={table.indexes.map((ix, i) => (
{JSON.stringify(ix, null, '  ')}
))} diff --git a/superset-frontend/src/SqlLab/types.ts b/superset-frontend/src/SqlLab/types.ts index 1b7b1d495cd8..0d54f9747667 100644 --- a/superset-frontend/src/SqlLab/types.ts +++ b/superset-frontend/src/SqlLab/types.ts @@ -16,14 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -import { Dataset } from '@superset-ui/chart-controls'; import { JsonObject, Query, QueryResponse } from '@superset-ui/core'; import { SupersetError } from 'src/components/ErrorMessage/types'; import { UserWithPermissionsAndRoles } from 'src/types/bootstrapTypes'; import { ToastType } from 'src/components/MessageToasts/types'; -import { ExploreRootState } from 'src/explore/types'; - -export type ExploreDatasource = Dataset | QueryResponse; +import { RootState } from 'src/dashboard/types'; // Object as Dictionary (associative array) with Query id as the key and type Query as the value export type QueryDictionary = { @@ -74,7 +71,7 @@ export type SqlLabRootState = { }; }; -export type SqlLabExploreRootState = SqlLabRootState | ExploreRootState; +export type SqlLabExploreRootState = SqlLabRootState | RootState; export const getInitialState = (state: SqlLabExploreRootState) => { if (state.hasOwnProperty('sqlLab')) { @@ -84,10 +81,8 @@ export const getInitialState = (state: SqlLabExploreRootState) => { return user; } - const { - explore: { user }, - } = state as ExploreRootState; - return user; + const { user } = state as RootState; + return user as UserWithPermissionsAndRoles; }; export enum DatasetRadioState { diff --git a/superset-frontend/src/components/Chart/Chart.jsx b/superset-frontend/src/components/Chart/Chart.jsx index 624354f1b65e..7d02bf3452e0 100644 --- a/superset-frontend/src/components/Chart/Chart.jsx +++ b/superset-frontend/src/components/Chart/Chart.jsx @@ -203,7 +203,6 @@ class Chart extends React.PureComponent { height, datasetsStatus, } = this.props; - const error = queryResponse?.errors?.[0]; const message = chartAlert || queryResponse?.message; @@ -237,6 +236,7 @@ class Chart extends React.PureComponent { link={queryResponse ? queryResponse.link : null} source={dashboardId ? 
'dashboard' : 'explore'} stackTrace={chartStackTrace} + errorMitigationFunction={this.toggleSaveDatasetModal} /> ); } } diff --git a/superset-frontend/src/components/Chart/chartAction.js b/superset-frontend/src/components/Chart/chartAction.js index b9aeffec4a4b..139d91cd1d70 100644 --- a/superset-frontend/src/components/Chart/chartAction.js +++ b/superset-frontend/src/components/Chart/chartAction.js @@ -598,8 +598,12 @@ export function refreshChart(chartKey, force, dashboardId) { }; } -export const getDatasetSamples = async (datasetId, force) => { - const endpoint = `/api/v1/dataset/${datasetId}/samples?force=${force}`; +export const getDatasourceSamples = async ( + datasourceType, + datasourceId, + force, +) => { + const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`; try { const response = await SupersetClient.get({ endpoint }); return response.json.result; diff --git a/superset-frontend/src/components/ErrorMessage/ErrorMessageWithStackTrace.tsx b/superset-frontend/src/components/ErrorMessage/ErrorMessageWithStackTrace.tsx index c073cf0461b4..44bd560f39bd 100644 --- a/superset-frontend/src/components/ErrorMessage/ErrorMessageWithStackTrace.tsx +++ b/superset-frontend/src/components/ErrorMessage/ErrorMessageWithStackTrace.tsx @@ -18,7 +18,6 @@ */ import React from 'react'; import { t } from '@superset-ui/core'; - import getErrorMessageComponentRegistry from './getErrorMessageComponentRegistry'; import { SupersetError, ErrorSource } from './types'; import ErrorAlert from './ErrorAlert'; @@ -33,6 +32,7 @@ type Props = { copyText?: string; stackTrace?: string; source?: ErrorSource; + errorMitigationFunction?: () => void; }; export default function ErrorMessageWithStackTrace({ diff --git a/superset-frontend/src/components/Modal/Modal.tsx b/superset-frontend/src/components/Modal/Modal.tsx index 3bd21fef8075..c12e6f664202 100644 --- a/superset-frontend/src/components/Modal/Modal.tsx +++ b/superset-frontend/src/components/Modal/Modal.tsx @@ -247,7 +247,13 @@ const CustomModal = ({ const draggableRef = useRef(null); const [bounds, setBounds] = useState(); const [dragDisabled, setDragDisabled] = useState(true); - const modalFooter = isNil(footer) + let FooterComponent; + if (React.isValidElement(footer)) { + // If a footer component is provided, inject a closeModal function + // so the footer can provide a "close" button if desired + FooterComponent = React.cloneElement(footer, { closeModal: onHide }); + } + const modalFooter = isNil(FooterComponent) ? [ , ] - : footer; + : FooterComponent; const modalWidth = width || (responsive ? 
'100vw' : '600px'); const shouldShowMask = !(resizable || draggable); diff --git a/superset-frontend/src/components/ModalTrigger/ModalTrigger.stories.tsx b/superset-frontend/src/components/ModalTrigger/ModalTrigger.stories.tsx index 9c0817c2dbc6..7ab85dd1c49e 100644 --- a/superset-frontend/src/components/ModalTrigger/ModalTrigger.stories.tsx +++ b/superset-frontend/src/components/ModalTrigger/ModalTrigger.stories.tsx @@ -20,11 +20,11 @@ import React from 'react'; import ModalTrigger from '.'; interface IModalTriggerProps { - triggerNode: React.ReactNode; + triggerNode: JSX.Element; dialogClassName?: string; - modalTitle?: React.ReactNode; - modalBody?: React.ReactNode; - modalFooter?: React.ReactNode; + modalTitle?: string; + modalBody?: JSX.Element; + modalFooter?: JSX.Element; beforeOpen?: () => void; onExit?: () => void; isButton?: boolean; diff --git a/superset-frontend/src/components/ModalTrigger/index.jsx b/superset-frontend/src/components/ModalTrigger/index.jsx deleted file mode 100644 index b15000d851fb..000000000000 --- a/superset-frontend/src/components/ModalTrigger/index.jsx +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -import React from 'react'; -import PropTypes from 'prop-types'; -import Modal from 'src/components/Modal'; -import Button from 'src/components/Button'; - -const propTypes = { - dialogClassName: PropTypes.string, - triggerNode: PropTypes.node.isRequired, - modalTitle: PropTypes.node, - modalBody: PropTypes.node, // not required because it can be generated by beforeOpen - modalFooter: PropTypes.node, - beforeOpen: PropTypes.func, - onExit: PropTypes.func, - isButton: PropTypes.bool, - className: PropTypes.string, - tooltip: PropTypes.string, - width: PropTypes.string, - maxWidth: PropTypes.string, - responsive: PropTypes.bool, - resizable: PropTypes.bool, - resizableConfig: PropTypes.object, - draggable: PropTypes.bool, - draggableConfig: PropTypes.object, - destroyOnClose: PropTypes.bool, -}; - -const defaultProps = { - beforeOpen: () => {}, - onExit: () => {}, - isButton: false, - className: '', - modalTitle: '', - resizable: false, - draggable: false, -}; - -export default class ModalTrigger extends React.Component { - constructor(props) { - super(props); - this.state = { - showModal: false, - }; - this.open = this.open.bind(this); - this.close = this.close.bind(this); - } - - close() { - this.setState(() => ({ showModal: false })); - } - - open(e) { - e.preventDefault(); - this.props.beforeOpen(); - this.setState(() => ({ showModal: true })); - } - - renderModal() { - return ( - - {this.props.modalBody} - - ); - } - - render() { - if (this.props.isButton) { - return ( - <> - - {this.renderModal()} - - ); - } - /* eslint-disable jsx-a11y/interactive-supports-focus */ - return ( - <> - - {this.props.triggerNode} - - {this.renderModal()} - - ); - } -} - -ModalTrigger.propTypes = propTypes; -ModalTrigger.defaultProps = defaultProps; diff --git a/superset-frontend/src/components/ModalTrigger/index.tsx b/superset-frontend/src/components/ModalTrigger/index.tsx new file mode 100644 index 000000000000..8b689d640b7c --- /dev/null +++ b/superset-frontend/src/components/ModalTrigger/index.tsx @@ -0,0 +1,130 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import React, { useState } from 'react'; +import Modal from 'src/components/Modal'; +import Button from 'src/components/Button'; + +interface ModalTriggerProps { + dialogClassName?: string; + triggerNode: React.ReactNode; + modalTitle?: string; + modalBody?: React.ReactNode; // not required because it can be generated by beforeOpen + modalFooter?: React.ReactNode; + beforeOpen?: Function; + onExit?: Function; + isButton?: boolean; + className?: string; + tooltip?: string; + width?: string; + maxWidth?: string; + responsive?: boolean; + resizable?: boolean; + resizableConfig?: any; + draggable?: boolean; + draggableConfig?: any; + destroyOnClose?: boolean; +} + +export interface ModalTriggerRef { + current: { + close: Function; + open: Function; + }; +} + +const ModalTrigger = React.forwardRef( + (props: ModalTriggerProps, ref: ModalTriggerRef | null) => { + const [showModal, setShowModal] = useState(false); + const { + beforeOpen = () => {}, + onExit = () => {}, + isButton = false, + resizable = false, + draggable = false, + className = '', + tooltip, + modalFooter, + triggerNode, + destroyOnClose = true, + modalBody, + draggableConfig = {}, + resizableConfig = {}, + modalTitle, + responsive, + width, + maxWidth, + } = props; + + const close = () => { + setShowModal(false); + onExit?.(); + }; + + const open = (e: React.MouseEvent) => { + e.preventDefault(); + beforeOpen?.(); + setShowModal(true); + }; + + if (ref) { + ref.current = { close, open }; // eslint-disable-line + } + + /* eslint-disable jsx-a11y/interactive-supports-focus */ + return ( + <> + {isButton && ( + + )} + {!isButton && ( + + {triggerNode} + + )} + + {modalBody} + + + ); + }, +); + +export default ModalTrigger; diff --git a/superset-frontend/src/dashboard/components/RefreshIntervalModal.tsx b/superset-frontend/src/dashboard/components/RefreshIntervalModal.tsx index 896792ceebf6..54d11bba1458 100644 --- a/superset-frontend/src/dashboard/components/RefreshIntervalModal.tsx +++ b/superset-frontend/src/dashboard/components/RefreshIntervalModal.tsx @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. 
*/ -import React, { RefObject } from 'react'; +import React from 'react'; import Select, { propertyComparator } from 'src/components/Select/Select'; import { t, styled } from '@superset-ui/core'; import Alert from 'src/components/Alert'; import Button from 'src/components/Button'; -import ModalTrigger from 'src/components/ModalTrigger'; +import ModalTrigger, { ModalTriggerRef } from 'src/components/ModalTrigger'; import { FormLabel } from 'src/components/Form'; export const options = [ @@ -71,11 +71,11 @@ class RefreshIntervalModal extends React.PureComponent< refreshWarning: null, }; - modalRef: RefObject; + modalRef: ModalTriggerRef | null; constructor(props: RefreshIntervalModalProps) { super(props); - this.modalRef = React.createRef(); + this.modalRef = React.createRef() as ModalTriggerRef; this.state = { refreshFrequency: props.refreshFrequency, }; @@ -86,7 +86,7 @@ class RefreshIntervalModal extends React.PureComponent< onSave() { this.props.onChange(this.state.refreshFrequency, this.props.editMode); - this.modalRef.current?.close(); + this.modalRef?.current?.close(); this.props.addSuccessToast(t('Refresh interval saved')); } @@ -94,7 +94,7 @@ class RefreshIntervalModal extends React.PureComponent< this.setState({ refreshFrequency: this.props.refreshFrequency, }); - this.modalRef.current?.close(); + this.modalRef?.current?.close(); } handleFrequencyChange(value: number) { diff --git a/superset-frontend/src/dashboard/components/SaveModal.tsx b/superset-frontend/src/dashboard/components/SaveModal.tsx index 913125a16800..3cbaa40ba353 100644 --- a/superset-frontend/src/dashboard/components/SaveModal.tsx +++ b/superset-frontend/src/dashboard/components/SaveModal.tsx @@ -24,7 +24,7 @@ import { Input } from 'src/components/Input'; import Button from 'src/components/Button'; import { t, JsonResponse } from '@superset-ui/core'; -import ModalTrigger from 'src/components/ModalTrigger'; +import ModalTrigger, { ModalTriggerRef } from 'src/components/ModalTrigger'; import Checkbox from 'src/components/Checkbox'; import { SAVE_TYPE_OVERWRITE, @@ -69,7 +69,7 @@ const defaultProps = { class SaveModal extends React.PureComponent { static defaultProps = defaultProps; - modal: ModalTrigger | null; + modal: ModalTriggerRef | null; onSave: ( data: Record, @@ -84,17 +84,13 @@ class SaveModal extends React.PureComponent { newDashName: `${props.dashboardTitle} [copy]`, duplicateSlices: false, }; - this.modal = null; + this.handleSaveTypeChange = this.handleSaveTypeChange.bind(this); this.handleNameChange = this.handleNameChange.bind(this); this.saveDashboard = this.saveDashboard.bind(this); - this.setModalRef = this.setModalRef.bind(this); this.toggleDuplicateSlices = this.toggleDuplicateSlices.bind(this); this.onSave = this.props.onSave.bind(this); - } - - setModalRef(ref: ModalTrigger | null) { - this.modal = ref; + this.modal = React.createRef() as ModalTriggerRef; } toggleDuplicateSlices(): void { @@ -166,14 +162,14 @@ class SaveModal extends React.PureComponent { window.location.href = `/superset/dashboard/${resp.json.id}/`; } }); - this.modal?.close(); + this.modal?.current?.close?.(); } } render() { return ( { - modal: RefObject; + modal: ModalTriggerRef; constructor(props: FilterScopeModalProps) { super(props); - this.modal = React.createRef(); + this.modal = React.createRef() as ModalTriggerRef; this.handleCloseModal = this.handleCloseModal.bind(this); } handleCloseModal(): void { - if (this.modal.current) { - this.modal.current.close(); - } + this?.modal?.current?.close?.(); } render() { diff --git 
a/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx b/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx index 8b1137334b9d..0d1047c51d28 100644 --- a/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx +++ b/superset-frontend/src/explore/components/DataTablesPane/components/SamplesPane.tsx @@ -25,7 +25,7 @@ import { useFilteredTableData, useTableColumns, } from 'src/explore/components/DataTableControl'; -import { getDatasetSamples } from 'src/components/Chart/chartAction'; +import { getDatasourceSamples } from 'src/components/Chart/chartAction'; import { TableControls } from './DataTableControls'; import { SamplesPaneProps } from '../types'; @@ -61,7 +61,7 @@ export const SamplesPane = ({ if (isRequest && !cache.has(datasource)) { setIsLoading(true); - getDatasetSamples(datasource.id, queryForce) + getDatasourceSamples(datasource.type, datasource.id, queryForce) .then(response => { setData(ensureIsArray(response.data)); setColnames(ensureIsArray(response.colnames)); diff --git a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx index 54c04c6003ba..0aa0b03a0677 100644 --- a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx +++ b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx @@ -29,26 +29,35 @@ import { SamplesPane } from '../components'; import { createSamplesPaneProps } from './fixture'; describe('SamplesPane', () => { - fetchMock.get('end:/api/v1/dataset/34/samples?force=false', { - result: { - data: [], - colnames: [], - coltypes: [], + fetchMock.get( + 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34', + { + result: { + data: [], + colnames: [], + coltypes: [], + }, }, - }); + ); - fetchMock.get('end:/api/v1/dataset/35/samples?force=true', { - result: { - data: [ - { __timestamp: 1230768000000, genre: 'Action' }, - { __timestamp: 1230768000010, genre: 'Horror' }, - ], - colnames: ['__timestamp', 'genre'], - coltypes: [2, 1], + fetchMock.get( + 'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35', + { + result: { + data: [ + { __timestamp: 1230768000000, genre: 'Action' }, + { __timestamp: 1230768000010, genre: 'Horror' }, + ], + colnames: ['__timestamp', 'genre'], + coltypes: [2, 1], + }, }, - }); + ); - fetchMock.get('end:/api/v1/dataset/36/samples?force=false', 400); + fetchMock.get( + 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36', + 400, + ); const setForceQuery = jest.spyOn(exploreActions, 'setForceQuery'); diff --git a/superset-frontend/src/explore/components/DatasourcePanel/DatasourcePanel.test.tsx b/superset-frontend/src/explore/components/DatasourcePanel/DatasourcePanel.test.tsx index a9495175d9b9..6e6eb6edd047 100644 --- a/superset-frontend/src/explore/components/DatasourcePanel/DatasourcePanel.test.tsx +++ b/superset-frontend/src/explore/components/DatasourcePanel/DatasourcePanel.test.tsx @@ -22,6 +22,7 @@ import { DndProvider } from 'react-dnd'; import { render, screen, waitFor } from 'spec/helpers/testing-library'; import userEvent from '@testing-library/user-event'; import DatasourcePanel, { + IDatasource, Props as DatasourcePanelProps, } from 'src/explore/components/DatasourcePanel'; import { @@ -31,25 +32,15 @@ import { import { DatasourceType } from '@superset-ui/core'; import DatasourceControl from 
'src/explore/components/controls/DatasourceControl'; -const datasource = { +const datasource: IDatasource = { id: 1, type: DatasourceType.Table, - name: 'birth_names', columns, metrics, - uid: '1__table', database: { - backend: 'mysql', - name: 'main', + id: 1, }, - column_format: { ratio: '.2%' }, - verbose_map: { __timestamp: 'Time' }, - main_dttm_col: 'None', datasource_name: 'table1', - description: 'desc', - owners: [ - { first_name: 'admin', last_name: 'admin', username: 'admin', id: 1 }, - ], }; const mockUser = { @@ -100,7 +91,6 @@ test('should render', () => { test('should display items in controls', () => { render(setup(props), { useRedux: true }); - expect(screen.getByText('birth_names')).toBeInTheDocument(); expect(screen.getByText('Metrics')).toBeInTheDocument(); expect(screen.getByText('Columns')).toBeInTheDocument(); }); @@ -238,6 +228,5 @@ test('should not render a save dataset modal when datasource is not query or dat }), { useRedux: true }, ); - expect(screen.queryByText(/create a dataset/i)).toBe(null); }); diff --git a/superset-frontend/src/explore/components/DatasourcePanel/fixtures.tsx b/superset-frontend/src/explore/components/DatasourcePanel/fixtures.tsx index a4a893fae08f..05909b1e64fc 100644 --- a/superset-frontend/src/explore/components/DatasourcePanel/fixtures.tsx +++ b/superset-frontend/src/explore/components/DatasourcePanel/fixtures.tsx @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -import { ColumnMeta } from '@superset-ui/chart-controls'; import { GenericDataType } from '@superset-ui/core'; -export const columns: ColumnMeta[] = [ +export const columns = [ { column_name: 'bootcamp_attend', description: null, diff --git a/superset-frontend/src/explore/components/DatasourcePanel/index.tsx b/superset-frontend/src/explore/components/DatasourcePanel/index.tsx index 335d2f2ada41..dedffc104455 100644 --- a/superset-frontend/src/explore/components/DatasourcePanel/index.tsx +++ b/superset-frontend/src/explore/components/DatasourcePanel/index.tsx @@ -17,39 +17,69 @@ * under the License. 
*/ import React, { useEffect, useMemo, useRef, useState } from 'react'; -import { css, styled, t, DatasourceType } from '@superset-ui/core'; import { - ControlConfig, - Dataset, - ColumnMeta, -} from '@superset-ui/chart-controls'; -import { debounce } from 'lodash'; + css, + styled, + t, + DatasourceType, + Metric, + QueryFormData, +} from '@superset-ui/core'; + +import { ControlConfig, ColumnMeta } from '@superset-ui/chart-controls'; + +import { debounce, isArray } from 'lodash'; import { matchSorter, rankings } from 'match-sorter'; import Collapse from 'src/components/Collapse'; import Alert from 'src/components/Alert'; import { SaveDatasetModal } from 'src/SqlLab/components/SaveDatasetModal'; +import { getDatasourceAsSaveableDataset } from 'src/utils/datasourceUtils'; import { Input } from 'src/components/Input'; import { FAST_DEBOUNCE } from 'src/constants'; import { FeatureFlag, isFeatureEnabled } from 'src/featureFlags'; import { ExploreActions } from 'src/explore/actions/exploreActions'; import Control from 'src/explore/components/Control'; -import { ExploreDatasource } from 'src/SqlLab/types'; import DatasourcePanelDragOption from './DatasourcePanelDragOption'; import { DndItemType } from '../DndItemType'; import { StyledColumnOption, StyledMetricOption } from '../optionRenderers'; +import { DndItemValue } from './types'; interface DatasourceControl extends ControlConfig { - datasource?: ExploreDatasource; + datasource?: IDatasource; +} + +export interface DataSourcePanelColumn { + is_dttm?: boolean | null; + description?: string | null; + expression?: string | null; + is_certified?: number | null; + column_name?: string | null; + name?: string | null; + type?: string; +} +export interface IDatasource { + metrics: Metric[]; + columns: DataSourcePanelColumn[]; + id: number; + type: DatasourceType; + database: { + id: number; + }; + sql?: string | null; + datasource_name?: string | null; + name?: string | null; + schema?: string | null; } export interface Props { - datasource: Dataset; + datasource: IDatasource; controls: { datasource: DatasourceControl; }; actions: Partial & Pick; // we use this props control force update when this panel resize shouldForceUpdate?: number; + formData?: QueryFormData; } const enableExploreDnd = isFeatureEnabled( @@ -183,20 +213,20 @@ const LabelContainer = (props: { export default function DataSourcePanel({ datasource, + formData, controls: { datasource: datasourceControl }, actions, shouldForceUpdate, }: Props) { const { columns: _columns, metrics } = datasource; - // display temporal column first const columns = useMemo( () => - [..._columns].sort((col1, col2) => { - if (col1.is_dttm && !col2.is_dttm) { + [...(isArray(_columns) ? _columns : [])].sort((col1, col2) => { + if (col1?.is_dttm && !col2?.is_dttm) { return -1; } - if (col2.is_dttm && !col1.is_dttm) { + if (col2?.is_dttm && !col1?.is_dttm) { return 1; } return 0; @@ -236,7 +266,7 @@ export default function DataSourcePanel({ }, { key: item => - [item.description, item.expression].map( + [item?.description ?? '', item?.expression ?? ''].map( x => x?.replace(/[_\n\s]+/g, ' ') || '', ), threshold: rankings.CONTAINS, @@ -257,7 +287,7 @@ export default function DataSourcePanel({ }, { key: item => - [item.description, item.expression].map( + [item?.description ?? '', item?.expression ?? 
''].map( x => x?.replace(/[_\n\s]+/g, ' ') || '', ), threshold: rankings.CONTAINS, @@ -266,8 +296,9 @@ export default function DataSourcePanel({ ], keepDiacritics: true, baseSort: (a, b) => - Number(b.item.is_certified) - Number(a.item.is_certified) || - String(a.rankedValue).localeCompare(b.rankedValue), + Number(b?.item?.is_certified ?? 0) - + Number(a?.item?.is_certified ?? 0) || + String(a?.rankedValue ?? '').localeCompare(b?.rankedValue ?? ''), }), }); }, FAST_DEBOUNCE), @@ -282,23 +313,23 @@ export default function DataSourcePanel({ setInputValue(''); }, [columns, datasource, metrics]); - const sortCertifiedFirst = (slice: ColumnMeta[]) => - slice.sort((a, b) => b.is_certified - a.is_certified); + const sortCertifiedFirst = (slice: DataSourcePanelColumn[]) => + slice.sort((a, b) => (b?.is_certified ?? 0) - (a?.is_certified ?? 0)); const metricSlice = useMemo( () => showAllMetrics - ? lists.metrics - : lists.metrics.slice(0, DEFAULT_MAX_METRICS_LENGTH), - [lists.metrics, showAllMetrics], + ? lists?.metrics + : lists?.metrics?.slice?.(0, DEFAULT_MAX_METRICS_LENGTH), + [lists?.metrics, showAllMetrics], ); const columnSlice = useMemo( () => showAllColumns - ? sortCertifiedFirst(lists.columns) + ? sortCertifiedFirst(lists?.columns) : sortCertifiedFirst( - lists.columns.slice(0, DEFAULT_MAX_COLUMNS_LENGTH), + lists?.columns?.slice?.(0, DEFAULT_MAX_COLUMNS_LENGTH), ), [lists.columns, showAllColumns], ); @@ -308,6 +339,14 @@ export default function DataSourcePanel({ return true; }; + const saveableDatasets = { + query: DatasourceType.Query, + saved_query: DatasourceType.SavedQuery, + }; + + const datasourceIsSaveable = + datasource.type && saveableDatasets[datasource.type]; + const mainBody = useMemo( () => ( <> @@ -322,7 +361,7 @@ export default function DataSourcePanel({ placeholder={t('Search Metrics & Columns')} />
- {datasource.type === DatasourceType.Query && showInfoboxCheck() && ( + {datasourceIsSaveable && showInfoboxCheck() && ( - {t('Metrics')}} - key="metrics" - > -
- {t( - `Showing %s of %s`, - metricSlice.length, - lists.metrics.length, - )} -
- {metricSlice.map(m => ( - - {enableExploreDnd ? ( - - ) : ( - + {metrics?.length && ( + {t('Metrics')}} + key="metrics" + > +
+ {t( + `Showing %s of %s`, + metricSlice?.length, + lists?.metrics.length, )} - - ))} - {lists.metrics.length > DEFAULT_MAX_METRICS_LENGTH ? ( - - - - ) : ( - <> - )} - +
+ {metricSlice?.map?.((m: Metric) => ( + + {enableExploreDnd ? ( + + ) : ( + + )} + + ))} + {lists?.metrics?.length > DEFAULT_MAX_METRICS_LENGTH ? ( + + + + ) : ( + <> + )} +
+ )} {t('Columns')}} key="column" @@ -404,11 +445,11 @@ export default function DataSourcePanel({ > {enableExploreDnd ? ( ) : ( - + )}
))} @@ -426,28 +467,34 @@ export default function DataSourcePanel({
), + // eslint-disable-next-line react-hooks/exhaustive-deps [ columnSlice, inputValue, lists.columns.length, - lists.metrics.length, + lists?.metrics?.length, metricSlice, search, showAllColumns, showAllMetrics, + datasourceIsSaveable, shouldForceUpdate, ], ); return ( - setShowSaveDatasetModal(false)} - buttonTextOnSave={t('Save')} - buttonTextOnOverwrite={t('Overwrite')} - datasource={datasource} - /> + {datasourceIsSaveable && showSaveDatasetModal && ( + setShowSaveDatasetModal(false)} + buttonTextOnSave={t('Save')} + buttonTextOnOverwrite={t('Overwrite')} + datasource={getDatasourceAsSaveableDataset(datasource)} + openWindow={false} + formData={formData} + /> + )} {datasource.id != null && mainBody} diff --git a/superset-frontend/src/explore/components/ExploreAlert.tsx b/superset-frontend/src/explore/components/ExploreAlert.tsx index 34c4cf070e30..0602fc903501 100644 --- a/superset-frontend/src/explore/components/ExploreAlert.tsx +++ b/superset-frontend/src/explore/components/ExploreAlert.tsx @@ -28,7 +28,7 @@ interface ControlPanelAlertProps { secondaryButtonAction?: (e: React.MouseEvent) => void; primaryButtonText?: string; secondaryButtonText?: string; - type: 'info' | 'warning'; + type: 'info' | 'warning' | 'error'; className?: string; } @@ -85,6 +85,11 @@ const Title = styled.p` font-weight: ${({ theme }) => theme.typography.weights.bold}; `; +const typeChart = { + warning: 'warning', + danger: 'danger', +}; + export const ExploreAlert = forwardRef( ( { @@ -114,7 +119,7 @@ export const ExploreAlert = forwardRef( )} + + + + ); +}; + +export default ViewQueryModalFooter; diff --git a/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx b/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx index 0740bd16c968..e75e91bba763 100644 --- a/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx +++ b/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx @@ -107,7 +107,6 @@ export const useExploreAdditionalActionsMenu = ( ); const { datasource } = latestQueryFormData; - const sqlSupported = datasource && datasource.split('__')[1] === 'table'; const shareByEmail = useCallback(async () => { try { @@ -356,7 +355,7 @@ export const useExploreAdditionalActionsMenu = ( responsive /> - {sqlSupported && ( + {datasource && ( {t('Run in SQL Lab')} @@ -373,7 +372,6 @@ export const useExploreAdditionalActionsMenu = ( openSubmenus, showReportSubMenu, slice, - sqlSupported, theme.gridUnit, ], ); diff --git a/superset-frontend/src/explore/types.ts b/superset-frontend/src/explore/types.ts index 85b161058aa6..8518f51097fc 100644 --- a/superset-frontend/src/explore/types.ts +++ b/superset-frontend/src/explore/types.ts @@ -30,7 +30,6 @@ import { } from '@superset-ui/chart-controls'; import { DatabaseObject } from 'src/views/CRUD/types'; import { UserWithPermissionsAndRoles } from 'src/types/bootstrapTypes'; -import { toastState } from 'src/SqlLab/types'; import { Slice } from 'src/types/Chart'; export type ChartStatus = @@ -69,32 +68,6 @@ export type Datasource = Dataset & { is_sqllab_view?: boolean; }; -export type ExploreRootState = { - explore: { - can_add: boolean; - can_download: boolean; - common: object; - controls: object; - controlsTransferred: object; - datasource: object; - datasource_id: number; - datasource_type: string; - force: boolean; - forced_height: object; - form_data: object; - isDatasourceMetaLoading: boolean; - isStarred: boolean; - slice: object; - sliceName: string; - 
standalone: boolean; - timeFormattedColumns: object; - user: UserWithPermissionsAndRoles; - }; - localStorageUsageInKilobytes: number; - messageToasts: toastState[]; - common: {}; -}; - export interface ExplorePageInitialData { dataset: Dataset; form_data: QueryFormData; diff --git a/superset-frontend/src/reduxUtils.ts b/superset-frontend/src/reduxUtils.ts index 245abc23e540..2cf43fa71222 100644 --- a/superset-frontend/src/reduxUtils.ts +++ b/superset-frontend/src/reduxUtils.ts @@ -136,10 +136,11 @@ export function extendArr( export function initEnhancer( persist = true, persistConfig: { paths?: StorageAdapter; config?: string } = {}, + disableDebugger = false, ) { const { paths, config } = persistConfig; const composeEnhancers = - process.env.WEBPACK_MODE === 'development' + process.env.WEBPACK_MODE === 'development' && disableDebugger !== true ? /* eslint-disable-next-line no-underscore-dangle, dot-notation */ window['__REDUX_DEVTOOLS_EXTENSION_COMPOSE__'] ? /* eslint-disable-next-line no-underscore-dangle, dot-notation */ diff --git a/superset-frontend/src/utils/datasourceUtils.js b/superset-frontend/src/utils/datasourceUtils.js new file mode 100644 index 000000000000..4a469e51095a --- /dev/null +++ b/superset-frontend/src/utils/datasourceUtils.js @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +export const getDatasourceAsSaveableDataset = source => ({ + columns: source.columns, + name: source?.datasource_name || 'Untitled', + dbId: source.database.id, + sql: source?.sql || '', + schema: source?.schema, +}); diff --git a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.tsx b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.tsx index 694b49055700..e8d7b5e201e9 100644 --- a/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.tsx +++ b/superset-frontend/src/views/CRUD/data/query/QueryPreviewModal.tsx @@ -115,32 +115,34 @@ function QueryPreviewModal({ onHide={onHide} show={show} title={t('Query preview')} - footer={[ - , - , - , - ]} + footer={ + <> + + + + + } > {t('Tab name')} {query.tab_name} diff --git a/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryPreviewModal.tsx b/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryPreviewModal.tsx index e8250d0fb7f8..29efb634a49e 100644 --- a/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryPreviewModal.tsx +++ b/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryPreviewModal.tsx @@ -95,32 +95,34 @@ const SavedQueryPreviewModal: FunctionComponent = onHide={onHide} show={show} title={t('Query preview')} - footer={[ - , - , - , - ]} + footer={ + <> + + + + + } > {t('Query name')} {savedQuery.label} diff --git a/superset-frontend/src/views/components/RightMenu.tsx b/superset-frontend/src/views/components/RightMenu.tsx index 5473fa069c28..516d920739c9 100644 --- a/superset-frontend/src/views/components/RightMenu.tsx +++ b/superset-frontend/src/views/components/RightMenu.tsx @@ -102,8 +102,8 @@ const RightMenu = ({ const dashboardId = useSelector( state => state.dashboardInfo?.id, ); - - const { roles } = user; + const userValues = user || {}; + const { roles } = userValues; const { CSV_EXTENSIONS, COLUMNAR_EXTENSIONS, @@ -287,7 +287,7 @@ const RightMenu = ({ } icon={} > - {dropdownItems.map(menu => { + {dropdownItems?.map?.(menu => { const canShowChild = menu.childs?.some( item => typeof item === 'object' && !!item.perm, ); @@ -299,7 +299,7 @@ const RightMenu = ({ className="data-menu" title={menuIconAndLabel(menu)} > - {menu.childs.map((item, idx) => + {menu?.childs?.map?.((item, idx) => typeof item !== 'string' && item.name && item.perm ? ( {idx === 2 && } @@ -348,9 +348,9 @@ const RightMenu = ({ title={t('Settings')} icon={} > - {settings.map((section, index) => [ + {settings?.map?.((section, index) => [ - {section.childs?.map(child => { + {section?.childs?.map?.(child => { if (typeof child !== 'string') { return ( diff --git a/superset-frontend/src/views/menu.tsx b/superset-frontend/src/views/menu.tsx index 166516569198..d29f9a3aee36 100644 --- a/superset-frontend/src/views/menu.tsx +++ b/superset-frontend/src/views/menu.tsx @@ -28,8 +28,11 @@ import Menu from 'src/views/components/Menu'; import { theme } from 'src/preamble'; import { Provider } from 'react-redux'; -import { store } from './store'; +import { setupStore } from './store'; +// Disable connecting to the redux debugger so that the React app injected +// below the menu (like SqlLab or Explore) can connect its redux store to the debugger +const store = setupStore(true); const container = document.getElementById('app'); const bootstrapJson = container?.getAttribute('data-bootstrap') ?? 
'{}'; const bootstrap = JSON.parse(bootstrapJson); diff --git a/superset-frontend/src/views/store.ts b/superset-frontend/src/views/store.ts index 9c3b1f625ddc..284a8d966ba8 100644 --- a/superset-frontend/src/views/store.ts +++ b/superset-frontend/src/views/store.ts @@ -16,7 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -import { applyMiddleware, combineReducers, compose, createStore } from 'redux'; +import { + applyMiddleware, + combineReducers, + compose, + createStore, + Store, +} from 'redux'; import thunk from 'redux-thunk'; import messageToastReducer from 'src/components/MessageToasts/reducers'; import { initEnhancer } from 'src/reduxUtils'; @@ -86,8 +92,26 @@ export const rootReducer = combineReducers({ explore, }); -export const store = createStore( +export const store: Store = createStore( rootReducer, {}, compose(applyMiddleware(thunk, logger), initEnhancer(false)), ); + +/* In some cases the jinja template injects two separate React apps into basic.html + * One for the top navigation Menu and one for the application below the Menu + * The first app to connect to the Redux debugger wins, which is the menu, blocking + * the application from being able to connect to the redux debugger. + * setupStore with disableDebugger true enables the menu.tsx component to avoid connecting + * to the redux debugger so the application can connect to the redux debugger + */ +export function setupStore(disableDebugger = false): Store { + return createStore( + rootReducer, + {}, + compose( + applyMiddleware(thunk, logger), + initEnhancer(false, undefined, disableDebugger), + ), + ); +} diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 5899757d528d..0764e19340c9 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -150,7 +150,13 @@ def _get_samples( query_obj.orderby = [] query_obj.metrics = None query_obj.post_processing = [] - query_obj.columns = [o.column_name for o in datasource.columns] + qry_obj_cols = [] + for o in datasource.columns: + if isinstance(o, dict): + qry_obj_cols.append(o.get("column_name")) + else: + qry_obj_cols.append(o.column_name) + query_obj.columns = qry_obj_cols query_obj.from_dttm = None query_obj.to_dttm = None return _get_full(query_context, query_obj, force_cached) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 1e1d16985ad4..dc43d28de9d5 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -82,6 +82,9 @@ def create( # pylint: disable=no-self-use def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: + return DatasourceDAO.get_datasource( - db.session, DatasourceType(datasource["type"]), int(datasource["id"]) + session=db.session, + datasource_type=DatasourceType(datasource["type"]), + datasource_id=int(datasource["id"]), ) diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 4f6e1910fb39..3b174dc7a217 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -116,6 +116,7 @@ def get_df_payload( and col != DTTM_ALIAS ) ] + if invalid_columns: raise QueryObjectValidationError( _( @@ -123,6 +124,7 @@ def get_df_payload( invalid_columns=invalid_columns, ) ) + query_result = self.get_query_result(query_obj) annotation_data = self.get_annotation_data(query_obj) cache.set_query_result( @@ -183,8 +185,16 @@ def get_query_result(self, 
query_object: QueryObject) -> QueryResult: # support multiple queries from different data sources. # The datasource here can be different backend but the interface is common - result = query_context.datasource.query(query_object.to_dict()) - query = result.query + ";\n\n" + # pylint: disable=import-outside-toplevel + from superset.models.sql_lab import Query + + query = "" + if isinstance(query_context.datasource, Query): + # todo(hugh): add logic to manage all sip68 models here + result = query_context.datasource.exc_query(query_object.to_dict()) + else: + result = query_context.datasource.query(query_object.to_dict()) + query = result.query + ";\n\n" df = result.df # Transform the timestamp we received from database to pandas supported diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index b4f4ec25c451..db9b15dc4164 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -61,7 +61,6 @@ from superset import security_manager, sql_parse from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.models.sql_lab import Query from superset.sql_parse import ParsedQuery, Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils @@ -73,6 +72,7 @@ # prevent circular imports from superset.connectors.sqla.models import TableColumn from superset.models.core import Database + from superset.models.sql_lab import Query ColumnTypeMapping = Tuple[ Pattern[str], @@ -885,7 +885,7 @@ def get_all_datasource_names( return all_datasources @classmethod - def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Any, query: "Query", session: Session) -> None: """Handle a live cursor between the execute and fetchall calls The flow works without this method doing anything, but it allows @@ -1501,7 +1501,7 @@ def has_implicit_cancel(cls) -> bool: def get_cancel_query_id( # pylint: disable=unused-argument cls, cursor: Any, - query: Query, + query: "Query", ) -> Optional[str]: """ Select identifiers from the database engine that uniquely identifies the @@ -1519,7 +1519,7 @@ def get_cancel_query_id( # pylint: disable=unused-argument def cancel_query( # pylint: disable=unused-argument cls, cursor: Any, - query: Query, + query: "Query", cancel_query_id: str, ) -> bool: """ diff --git a/superset/explore/api.py b/superset/explore/api.py index 7cce592d361f..237eb67dbbe7 100644 --- a/superset/explore/api.py +++ b/superset/explore/api.py @@ -16,14 +16,22 @@ # under the License. 
 import logging

-from flask import g, request, Response
+import simplejson
+from flask import g, make_response, request, Response
 from flask_appbuilder.api import BaseApi, expose, protect, safe

 from superset.charts.commands.exceptions import ChartNotFoundError
 from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
+from superset.dao.exceptions import DatasourceNotFound
 from superset.explore.commands.get import GetExploreCommand
 from superset.explore.commands.parameters import CommandParameters
-from superset.explore.exceptions import DatasetAccessDeniedError, WrongEndpointError
+from superset.explore.commands.samples import SamplesDatasourceCommand
+from superset.explore.exceptions import (
+    DatasetAccessDeniedError,
+    DatasourceForbiddenError,
+    DatasourceSamplesFailedError,
+    WrongEndpointError,
+)
 from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
 from superset.explore.schemas import ExploreContextSchema
 from superset.extensions import event_logger
@@ -31,13 +39,16 @@
     TemporaryCacheAccessDeniedError,
     TemporaryCacheResourceNotFoundError,
 )
+from superset.utils.core import json_int_dttm_ser, parse_boolean_string

 logger = logging.getLogger(__name__)


 class ExploreRestApi(BaseApi):
     method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
-    include_route_methods = {RouteMethod.GET}
+    include_route_methods = {RouteMethod.GET} | {
+        "samples",
+    }
     allow_browser_login = True
     class_permission_name = "Explore"
     resource_name = "explore"
@@ -135,3 +146,70 @@ def get(self) -> Response:
             return self.response(403, message=str(ex))
         except TemporaryCacheResourceNotFoundError as ex:
             return self.response(404, message=str(ex))
+
+    @expose("/samples", methods=["GET"])
+    @protect()
+    @safe
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
+        log_to_statsd=False,
+    )
+    def samples(self) -> Response:
+        """get samples from a Datasource
+        ---
+        get:
+          description: >-
+            get samples from a Datasource
+          parameters:
+          - in: query
+            schema:
+              type: integer
+            name: datasource_id
+          - in: query
+            schema:
+              type: string
+            name: datasource_type
+          - in: query
+            schema:
+              type: boolean
+            name: force
+          responses:
+            200:
+              description: Datasource samples
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      result:
+                        $ref: '#/components/schemas/ChartDataResponseResult'
+            401:
+              $ref: '#/components/responses/401'
+            403:
+              $ref: '#/components/responses/403'
+            404:
+              $ref: '#/components/responses/404'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            force = parse_boolean_string(request.args.get("force"))
+            rv = SamplesDatasourceCommand(
+                user=g.user,
+                datasource_type=request.args.get("datasource_type", type=str),
+                datasource_id=request.args.get("datasource_id", type=int),
+                force=force,
+            ).run()
+
+            response_data = simplejson.dumps(
+                {"result": rv},
+                default=json_int_dttm_ser,
+                ignore_nan=True,
+            )
+            resp = make_response(response_data, 200)
+            resp.headers["Content-Type"] = "application/json; charset=utf-8"
+            return resp
+        except DatasourceNotFound:
+            return self.response_404()
+        except DatasourceForbiddenError:
+            return self.response_403()
+        except DatasourceSamplesFailedError as ex:
+            return self.response_400(message=str(ex))
diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py
index 80c8744fd926..1b9ec433daba 100644
--- a/superset/explore/commands/get.py
+++ b/superset/explore/commands/get.py
@@ -38,7 +38,6 @@
 )
 from superset.explore.permalink.commands.get import
GetExplorePermalinkCommand from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError -from superset.models.sql_lab import Query from superset.utils import core as utils from superset.views.utils import ( get_datasource_info, @@ -89,7 +88,9 @@ def run(self) -> Optional[Dict[str, Any]]: "Form data not found in cache, reverting to chart metadata." ) elif self._dataset_id: - initial_form_data["datasource"] = f"{self._dataset_id}__table" + initial_form_data[ + "datasource" + ] = f"{self._dataset_id}__{self._dataset_type}" if self._form_data_key: message = _( "Form data not found in cache, reverting to dataset metadata." @@ -152,11 +153,6 @@ def run(self) -> Optional[Dict[str, Any]]: except (SupersetException, SQLAlchemyError): dataset_data = dummy_dataset_data - if dataset: - dataset_data["owners"] = dataset.owners_data - if isinstance(dataset, Query): - dataset_data["columns"] = dataset.columns - return { "dataset": sanitize_datasource_data(dataset_data), "form_data": form_data, diff --git a/superset/explore/commands/samples.py b/superset/explore/commands/samples.py new file mode 100644 index 000000000000..7fda5c1bc150 --- /dev/null +++ b/superset/explore/commands/samples.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
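+#
+# The endpoint wired up in explore/api.py above is called as, for example:
+#   GET /api/v1/explore/samples?force=false&datasource_type=table&datasource_id=1
+# and responds with {"result": {...}} carrying the sample-rows payload.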
+import logging +from typing import Any, Dict, Optional + +from flask_appbuilder.security.sqla.models import User + +from superset import db, security_manager +from superset.commands.base import BaseCommand +from superset.common.chart_data import ChartDataResultType +from superset.common.query_context_factory import QueryContextFactory +from superset.common.utils.query_cache_manager import QueryCacheManager +from superset.constants import CacheRegion +from superset.dao.exceptions import DatasourceNotFound +from superset.datasource.dao import Datasource, DatasourceDAO +from superset.exceptions import SupersetSecurityException +from superset.explore.exceptions import ( + DatasourceForbiddenError, + DatasourceSamplesFailedError, +) +from superset.utils.core import DatasourceType, QueryStatus + +logger = logging.getLogger(__name__) + + +class SamplesDatasourceCommand(BaseCommand): + def __init__( + self, + user: User, + datasource_id: Optional[int], + datasource_type: Optional[str], + force: bool, + ): + self._actor = user + self._datasource_id = datasource_id + self._datasource_type = datasource_type + self._force = force + self._model: Optional[Datasource] = None + + def run(self) -> Dict[str, Any]: + self.validate() + if not self._model: + raise DatasourceNotFound() + + qc_instance = QueryContextFactory().create( + datasource={ + "type": self._model.type, + "id": self._model.id, + }, + queries=[{}], + result_type=ChartDataResultType.SAMPLES, + force=self._force, + ) + results = qc_instance.get_payload() + try: + sample_data = results["queries"][0] + error_msg = sample_data.get("error") + if sample_data.get("status") == QueryStatus.FAILED and error_msg: + cache_key = sample_data.get("cache_key") + QueryCacheManager.delete(cache_key, region=CacheRegion.DATA) + raise DatasourceSamplesFailedError(error_msg) + return sample_data + except (IndexError, KeyError) as exc: + raise DatasourceSamplesFailedError from exc + + def validate(self) -> None: + # Validate/populate model exists + if self._datasource_type and self._datasource_id: + self._model = DatasourceDAO.get_datasource( + session=db.session, + datasource_type=DatasourceType(self._datasource_type), + datasource_id=self._datasource_id, + ) + + # Check ownership + try: + security_manager.raise_for_ownership(self._model) + except SupersetSecurityException as ex: + raise DatasourceForbiddenError() from ex diff --git a/superset/explore/exceptions.py b/superset/explore/exceptions.py index c6b46a66f5a4..ff375de147d7 100644 --- a/superset/explore/exceptions.py +++ b/superset/explore/exceptions.py @@ -16,7 +16,13 @@ # under the License. 
from typing import Optional -from superset.commands.exceptions import CommandException, ForbiddenError +from flask_babel import lazy_gettext as _ + +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + ForbiddenError, +) class DatasetAccessDeniedError(ForbiddenError): @@ -33,3 +39,11 @@ class WrongEndpointError(CommandException): def __init__(self, redirect: str) -> None: self.redirect = redirect super().__init__() + + +class DatasourceSamplesFailedError(CommandInvalidError): + message = _("Samples for datasource could not be retrieved.") + + +class DatasourceForbiddenError(ForbiddenError): + message = _("Changing this datasource is forbidden") diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 051f8fb7c3bf..77707201b507 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -15,35 +15,123 @@ # specific language governing permissions and limitations # under the License. """a collection of model-related helper classes and functions""" +# pylint: disable=too-many-lines import json import logging import re import uuid from datetime import datetime, timedelta from json.decoder import JSONDecodeError -from typing import Any, Dict, List, Optional, Set, Union - +from typing import ( + Any, + cast, + Dict, + List, + Mapping, + NamedTuple, + Optional, + Set, + Text, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import dateutil.parser import humanize +import numpy as np import pandas as pd import pytz import sqlalchemy as sa +import sqlparse import yaml from flask import escape, g, Markup from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.security.sqla.models import User -from sqlalchemy import and_, or_, UniqueConstraint +from flask_babel import lazy_gettext as _ +from jinja2.exceptions import TemplateError +from sqlalchemy import and_, Column, or_, UniqueConstraint from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Mapper, Session from sqlalchemy.orm.exc import MultipleResultsFound +from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause +from sqlalchemy.sql.expression import Label, Select, TextAsFrom +from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType +from superset import app, is_feature_enabled, security_manager +from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus +from superset.constants import EMPTY_STRING, NULL_STRING +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import ( + AdvancedDataTypeResponseError, + QueryClauseValidationException, + QueryObjectValidationError, + SupersetSecurityException, +) +from superset.extensions import feature_flag_manager +from superset.jinja_context import BaseTemplateProcessor +from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause +from superset.superset_typing import ( + AdhocMetric, + FilterValue, + FilterValues, + Metric, + OrderBy, + QueryObjectDict, +) +from superset.utils import core as utils from superset.utils.core import get_user_id +if TYPE_CHECKING: + from superset.connectors.sqla.models import SqlMetric, TableColumn + from superset.db_engine_specs import BaseEngineSpec + from superset.models.core import Database + + +config = app.config logger = logging.getLogger(__name__) +CTE_ALIAS = "__cte" 
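+# VIRTUAL_TABLE_ALIAS names the wrapped SQL of a virtual dataset when it is inlined
+# as a subquery; CTE_ALIAS above is used when the SQL must run as a CTE instead
+# (see ExploreMixin.get_from_clause below).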
+VIRTUAL_TABLE_ALIAS = "virtual_table" +ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] + + +def validate_adhoc_subquery( + sql: str, + database_id: int, + default_schema: str, +) -> str: + """ + Check if adhoc SQL contains sub-queries or nested sub-queries with table. + + If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS + predicates to it. + + :param sql: adhoc sql expression + :raise SupersetSecurityException if sql contains sub-queries or + nested sub-queries with table + """ + statements = [] + for statement in sqlparse.parse(sql): + if has_table_query(statement): + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, + ) + ) + statement = insert_rls(statement, database_id, default_schema) + statements.append(statement) + + return ";\n".join(str(statement) for statement in statements) + def json_to_dict(json_str: str) -> Dict[Any, Any]: if json_str: @@ -544,3 +632,1113 @@ def clone_model( data.update(kwargs) return target.__class__(**data) + + +# todo(hugh): centralize where this code lives +class QueryStringExtended(NamedTuple): + applied_template_filters: Optional[List[str]] + labels_expected: List[str] + prequeries: List[str] + sql: str + + +class SqlaQuery(NamedTuple): + applied_template_filters: List[str] + cte: Optional[str] + extra_cache_keys: List[Any] + labels_expected: List[str] + prequeries: List[str] + sqla_query: Select + + +class ExploreMixin: # pylint: disable=too-many-public-methods + """ + Allows any flask_appbuilder.Model (Query, Table, etc.) + to be used to power a chart inside /explore + """ + + sqla_aggregations = { + "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)), + "COUNT": sa.func.COUNT, + "SUM": sa.func.SUM, + "AVG": sa.func.AVG, + "MIN": sa.func.MIN, + "MAX": sa.func.MAX, + } + + @property + def query(self) -> str: + raise NotImplementedError() + + @property + def database_id(self) -> int: + raise NotImplementedError() + + @property + def owners_data(self) -> List[Any]: + raise NotImplementedError() + + @property + def metrics(self) -> List[Any]: + raise NotImplementedError() + + @property + def uid(self) -> str: + raise NotImplementedError() + + @property + def is_rls_supported(self) -> bool: + raise NotImplementedError() + + @property + def cache_timeout(self) -> int: + raise NotImplementedError() + + @property + def column_names(self) -> List[str]: + raise NotImplementedError() + + @property + def offset(self) -> int: + raise NotImplementedError() + + @property + def main_dttm_col(self) -> Optional[str]: + raise NotImplementedError() + + @property + def dttm_cols(self) -> List[str]: + raise NotImplementedError() + + @property + def db_engine_spec(self) -> Type["BaseEngineSpec"]: + raise NotImplementedError() + + @property + def database(self) -> Type["Database"]: + raise NotImplementedError() + + @property + def schema(self) -> str: + raise NotImplementedError() + + @property + def sql(self) -> str: + raise NotImplementedError() + + @property + def columns(self) -> List[Any]: + raise NotImplementedError() + + @property + def get_fetch_values_predicate(self) -> List[Any]: + raise NotImplementedError() + + @staticmethod + def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + raise NotImplementedError() + + def _process_sql_expression( # type: ignore # pylint: disable=no-self-use 
+ self, + expression: Optional[str], + database_id: int, + schema: str, + template_processor: Optional[BaseTemplateProcessor], + ) -> Optional[str]: + if template_processor and expression: + expression = template_processor.process_template(expression) + if expression: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + return expression + + def make_sqla_column_compatible( + self, sqla_col: ColumnElement, label: Optional[str] = None + ) -> ColumnElement: + """Takes a sqlalchemy column object and adds label info if supported by engine. + :param sqla_col: sqlalchemy column instance + :param label: alias/label that column is expected to have + :return: either a sql alchemy column or label instance if supported by engine + """ + label_expected = label or sqla_col.name + db_engine_spec = self.db_engine_spec + # add quotes to tables + if db_engine_spec.allows_alias_in_select: + label = db_engine_spec.make_label_compatible(label_expected) + sqla_col = sqla_col.label(label) + sqla_col.key = label_expected + return sqla_col + + def mutate_query_from_config(self, sql: str) -> str: + """Apply config's SQL_QUERY_MUTATOR + + Typically adds comments to the query with context""" + sql_query_mutator = config["SQL_QUERY_MUTATOR"] + if sql_query_mutator: + sql = sql_query_mutator( + sql, + user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. + security_manager=security_manager, + database=self.database, + ) + return sql + + @staticmethod + def _apply_cte(sql: str, cte: Optional[str]) -> str: + """ + Append a CTE before the SELECT statement if defined + + :param sql: SELECT statement + :param cte: CTE statement + :return: + """ + if cte: + sql = f"{cte}\n{sql}" + return sql + + @staticmethod + def validate_adhoc_subquery( + sql: str, + database_id: int, + default_schema: str, + ) -> str: + """ + Check if adhoc SQL contains sub-queries or nested sub-queries with table. + + If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS + predicates to it. 
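+        For example, a custom SQL field like ``(SELECT MAX(dttm) FROM logs)``
+        is rejected here unless the ALLOW_ADHOC_SUBQUERY feature flag is enabled.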
+ + :param sql: adhoc sql expression + :raise SupersetSecurityException if sql contains sub-queries or + nested sub-queries with table + """ + + statements = [] + for statement in sqlparse.parse(sql): + if has_table_query(statement): + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, + ) + ) + statement = insert_rls(statement, database_id, default_schema) + statements.append(statement) + + return ";\n".join(str(statement) for statement in statements) + + def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: + sqlaq = self.get_sqla_query(**query_obj) + sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore + sql = self._apply_cte(sql, sqlaq.cte) + sql = sqlparse.format(sql, reindent=True) + sql = self.mutate_query_from_config(sql) + return QueryStringExtended( + applied_template_filters=sqlaq.applied_template_filters, + labels_expected=sqlaq.labels_expected, + prequeries=sqlaq.prequeries, + sql=sql, + ) + + def _normalize_prequery_result_type( + self, + row: pd.Series, + dimension: str, + columns_by_name: Dict[str, "TableColumn"], + ) -> Union[str, int, float, bool, Text]: + """ + Convert a prequery result type to its equivalent Python type. + + Some databases like Druid will return timestamps as strings, but do not perform + automatic casting when comparing these strings to a timestamp. For cases like + this we convert the value via the appropriate SQL transform. + + :param row: A prequery record + :param dimension: The dimension name + :param columns_by_name: The mapping of columns by name + :return: equivalent primitive python type + """ + + value = row[dimension] + + if isinstance(value, np.generic): + value = value.item() + + column_ = columns_by_name[dimension] + db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore + + if column_.type and column_.is_temporal and isinstance(value, str): + sql = self.db_engine_spec.convert_dttm( + column_.type, dateutil.parser.parse(value), db_extra=db_extra + ) + + if sql: + value = self.text(sql) + + return value + + def make_orderby_compatible( + self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + ) -> None: + """ + If needed, make sure aliases for selected columns are not used in + `ORDER BY`. + + In some databases (e.g. Presto), `ORDER BY` clause is not able to + automatically pick the source column if a `SELECT` clause alias is named + the same as a source column. In this case, we update the SELECT alias to + another name to avoid the conflict. + """ + if self.db_engine_spec.allows_alias_to_source_column: + return + + def is_alias_used_in_orderby(col: ColumnElement) -> bool: + if not isinstance(col, Label): + return False + regexp = re.compile(f"\\(.*\\b{re.escape(col.name)}\\b.*\\)", re.IGNORECASE) + return any(regexp.search(str(x)) for x in orderby_exprs) + + # Iterate through selected columns, if column alias appears in orderby + # use another `alias`. The final output columns will still use the + # original names, because they are updated by `labels_expected` after + # querying. 
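+        # For example, with SELECT gender AS name ... ORDER BY (name = 'girl'),
+        # the alias "name" appears inside an ORDER BY expression, so the select
+        # alias is renamed to "name__" below; the final output still uses the
+        # original name via labels_expected.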
+        for col in select_exprs:
+            if is_alias_used_in_orderby(col):
+                col.name = f"{col.name}__"
+
+    def exc_query(self, qry: Any) -> QueryResult:
+        qry_start_dttm = datetime.now()
+        query_str_ext = self.get_query_str_extended(qry)
+        sql = query_str_ext.sql
+        status = QueryStatus.SUCCESS
+        errors = None
+        error_message = None
+
+        def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+            """
+            Some engines change the case or generate bespoke column names, either by
+            default or due to lack of support for aliasing. This function ensures that
+            the column names in the DataFrame correspond to what is expected by
+            the viz components.
+            Sometimes a query may also contain only ORDER BY columns that are not used
+            as metrics or groupby columns, but still need to be present in the SQL
+            `SELECT`; filtering by `labels_expected` makes sure we only return the
+            columns users want.
+            :param df: Original DataFrame returned by the engine
+            :return: Mutated DataFrame
+            """
+            labels_expected = query_str_ext.labels_expected
+            if df is not None and not df.empty:
+                if len(df.columns) < len(labels_expected):
+                    raise QueryObjectValidationError(
+                        _("Db engine did not return all queried columns")
+                    )
+                if len(df.columns) > len(labels_expected):
+                    df = df.iloc[:, 0 : len(labels_expected)]
+                df.columns = labels_expected
+            return df
+
+        try:
+            df = self.database.get_df(
+                sql, self.schema, mutator=assign_column_label  # type: ignore
+            )
+        except Exception as ex:  # pylint: disable=broad-except
+            df = pd.DataFrame()
+            status = QueryStatus.FAILED
+            logger.warning(
+                "Query %s on schema %s failed", sql, self.schema, exc_info=True
+            )
+            error_message = utils.error_msg_from_exception(ex)
+
+        return QueryResult(
+            status=status,
+            df=df,
+            duration=datetime.now() - qry_start_dttm,
+            query=sql,
+            errors=errors,
+            error_message=error_message,
+        )
+
+    def get_rendered_sql(
+        self, template_processor: Optional[BaseTemplateProcessor] = None
+    ) -> str:
+        """
+        Render sql with template engine (Jinja).
+        """
+
+        sql = self.sql
+        if template_processor:
+            try:
+                sql = template_processor.process_template(sql)
+            except TemplateError as ex:
+                raise QueryObjectValidationError(
+                    _(
+                        "Error while rendering virtual dataset query: %(msg)s",
+                        msg=ex.message,
+                    )
+                ) from ex
+        sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
+        if not sql:
+            raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
+        if len(sqlparse.split(sql)) > 1:
+            raise QueryObjectValidationError(
+                _("Virtual dataset query cannot consist of multiple statements")
+            )
+        return sql
+
+    def text(self, clause: str) -> TextClause:
+        return self.db_engine_spec.get_text_clause(clause)
+
+    def get_from_clause(
+        self, template_processor: Optional[BaseTemplateProcessor] = None
+    ) -> Tuple[Union[TableClause, Alias], Optional[str]]:
+        """
+        Return where to select the columns and metrics from. Either a physical table
+        or a virtual table with its own subquery. If the FROM is referencing a
+        CTE, the CTE is returned as the second value in the return tuple.
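+
+        For example, a virtual dataset defined as
+        ``WITH t AS (SELECT 1 AS x) SELECT x FROM t`` cannot be wrapped in a
+        subquery on every engine, so the query selects from the "__cte" alias
+        and the CTE text is returned separately.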
+ """ + + from_sql = self.get_rendered_sql(template_processor) + parsed_query = ParsedQuery(from_sql) + if not ( + parsed_query.is_unknown() + or self.db_engine_spec.is_readonly_query(parsed_query) + ): + raise QueryObjectValidationError( + _("Virtual dataset query must be read-only") + ) + + cte = self.db_engine_spec.get_cte_query(from_sql) + from_clause = ( + sa.table(CTE_ALIAS) + if cte + else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) + ) + + return from_clause, cte + + def adhoc_metric_to_sqla( + self, + metric: AdhocMetric, + columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> ColumnElement: + """ + Turn an adhoc metric into a sqlalchemy column. + + :param dict metric: Adhoc metric definition + :param dict columns_by_name: Columns for the current table + :param template_processor: template_processor instance + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ + expression_type = metric.get("expressionType") + label = utils.get_metric_name(metric) + + if expression_type == utils.AdhocMetricExpressionType.SIMPLE: + metric_column = metric.get("column") or {} + column_name = cast(str, metric_column.get("column_name")) + sqla_column = sa.column(column_name) + sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) + elif expression_type == utils.AdhocMetricExpressionType.SQL: + expression = self._process_sql_expression( # type: ignore + expression=metric["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_metric = literal_column(expression) + else: + raise QueryObjectValidationError("Adhoc metric expressionType is invalid") + + return self.make_sqla_column_compatible(sqla_metric, label) + + @property + def template_params_dict(self) -> Dict[Any, Any]: + return {} + + @staticmethod + def filter_values_handler( # pylint: disable=too-many-arguments + values: Optional[FilterValues], + target_generic_type: utils.GenericDataType, + target_native_type: Optional[str] = None, + is_list_target: bool = False, + db_engine_spec: Optional[ + Type["BaseEngineSpec"] + ] = None, # fix(hughhh): Optional[Type[BaseEngineSpec]] + db_extra: Optional[Dict[str, Any]] = None, + ) -> Optional[FilterValues]: + if values is None: + return None + + def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: + if ( + isinstance(value, (float, int)) + and target_generic_type == utils.GenericDataType.TEMPORAL + and target_native_type is not None + and db_engine_spec is not None + ): + value = db_engine_spec.convert_dttm( + target_type=target_native_type, + dttm=datetime.utcfromtimestamp(value / 1000), + db_extra=db_extra, + ) + value = literal_column(value) + if isinstance(value, str): + value = value.strip("\t\n") + + if target_generic_type == utils.GenericDataType.NUMERIC: + # For backwards compatibility and edge cases + # where a column data type might have changed + return utils.cast_to_num(value) + if value == NULL_STRING: + return None + if value == EMPTY_STRING: + return "" + if target_generic_type == utils.GenericDataType.BOOLEAN: + return utils.cast_to_boolean(value) + return value + + if isinstance(values, (list, tuple)): + values = [handle_single_value(v) for v in values] # type: ignore + else: + values = handle_single_value(values) + if is_list_target and not isinstance(values, (tuple, list)): + values = [values] # type: ignore + elif not 
is_list_target and isinstance(values, (tuple, list)): + values = values[0] if values else None + return values + + def get_query_str(self, query_obj: QueryObjectDict) -> str: + query_str_ext = self.get_query_str_extended(query_obj) + all_queries = query_str_ext.prequeries + [query_str_ext.sql] + return ";\n\n".join(all_queries) + ";" + + def _get_series_orderby( + self, + series_limit_metric: Metric, + metrics_by_name: Mapping[str, "SqlMetric"], + columns_by_name: Mapping[str, "TableColumn"], + ) -> Column: + if utils.is_adhoc_metric(series_limit_metric): + assert isinstance(series_limit_metric, dict) + ob = self.adhoc_metric_to_sqla( + series_limit_metric, columns_by_name # type: ignore + ) + elif ( + isinstance(series_limit_metric, str) + and series_limit_metric in metrics_by_name + ): + ob = metrics_by_name[series_limit_metric].get_sqla_col() + else: + raise QueryObjectValidationError( + _("Metric '%(metric)s' does not exist", metric=series_limit_metric) + ) + return ob + + def adhoc_column_to_sqla( + self, + col: Type["AdhocColumn"], # type: ignore + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> ColumnElement: + """ + Turn an adhoc column into a sqlalchemy column. + + :param col: Adhoc column definition + :param template_processor: template_processor instance + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ + label = utils.get_column_name(col) # type: ignore + expression = self._process_sql_expression( # type: ignore + expression=col["sqlExpression"], # type: ignore + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_column = literal_column(expression) + return self.make_sqla_column_compatible(sqla_column, label) + + def _get_top_groups( + self, + df: pd.DataFrame, + dimensions: List[str], + groupby_exprs: Dict[str, Any], + columns_by_name: Dict[str, "TableColumn"], + ) -> ColumnElement: + groups = [] + for _unused, row in df.iterrows(): + group = [] + for dimension in dimensions: + value = self._normalize_prequery_result_type( + row, + dimension, + columns_by_name, + ) + + group.append(groupby_exprs[dimension] == value) + groups.append(and_(*group)) + + return or_(*groups) + + def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + self, + apply_fetch_values_predicate: bool = False, + columns: Optional[List[Column]] = None, + extras: Optional[Dict[str, Any]] = None, + filter: Optional[ # pylint: disable=redefined-builtin + List[utils.QueryObjectFilterClause] + ] = None, + from_dttm: Optional[datetime] = None, + granularity: Optional[str] = None, + groupby: Optional[List[Column]] = None, + inner_from_dttm: Optional[datetime] = None, + inner_to_dttm: Optional[datetime] = None, + is_rowcount: bool = False, + is_timeseries: bool = True, + metrics: Optional[List[Metric]] = None, + orderby: Optional[List[OrderBy]] = None, + order_desc: bool = True, + to_dttm: Optional[datetime] = None, + series_columns: Optional[List[Column]] = None, + series_limit: Optional[int] = None, + series_limit_metric: Optional[Metric] = None, + row_limit: Optional[int] = None, + row_offset: Optional[int] = None, + timeseries_limit: Optional[int] = None, + timeseries_limit_metric: Optional[Metric] = None, + ) -> SqlaQuery: + """Querying any sqla table from this common interface""" + if granularity not in self.dttm_cols and granularity is not None: + granularity = self.main_dttm_col + + extras = extras or {} + time_grain = 
extras.get("time_grain_sqla") + + template_kwargs = { + "columns": columns, + "from_dttm": from_dttm.isoformat() if from_dttm else None, + "groupby": groupby, + "metrics": metrics, + "row_limit": row_limit, + "row_offset": row_offset, + "time_column": granularity, + "time_grain": time_grain, + "to_dttm": to_dttm.isoformat() if to_dttm else None, + "table_columns": [col.get("column_name") for col in self.columns], + "filter": filter, + } + columns = columns or [] + groupby = groupby or [] + series_column_names = utils.get_column_names(series_columns or []) + # deprecated, to be removed in 2.0 + if is_timeseries and timeseries_limit: + series_limit = timeseries_limit + series_limit_metric = series_limit_metric or timeseries_limit_metric + template_kwargs.update(self.template_params_dict) + extra_cache_keys: List[Any] = [] + template_kwargs["extra_cache_keys"] = extra_cache_keys + removed_filters: List[str] = [] + applied_template_filters: List[str] = [] + template_kwargs["removed_filters"] = removed_filters + template_kwargs["applied_filters"] = applied_template_filters + template_processor = None # self.get_template_processor(**template_kwargs) + db_engine_spec = self.db_engine_spec + prequeries: List[str] = [] + orderby = orderby or [] + need_groupby = bool(metrics is not None or groupby) + metrics = metrics or [] + + # For backward compatibility + if granularity not in self.dttm_cols and granularity is not None: + granularity = self.main_dttm_col + + columns_by_name: Dict[str, "TableColumn"] = { + col.get("column_name"): col + for col in self.columns # col.column_name: col for col in self.columns + } + + if not granularity and is_timeseries: + raise QueryObjectValidationError( + _( + "Datetime column not provided as part table configuration " + "and is required by this type of chart" + ) + ) + if not metrics and not columns and not groupby: + raise QueryObjectValidationError(_("Empty query?")) + + metrics_exprs: List[ColumnElement] = [] + for metric in metrics: + if utils.is_adhoc_metric(metric): + assert isinstance(metric, dict) + metrics_exprs.append( + self.adhoc_metric_to_sqla( + metric=metric, + columns_by_name=columns_by_name, # type: ignore + template_processor=template_processor, + ) + ) + else: + raise QueryObjectValidationError( + _("Metric '%(metric)s' does not exist", metric=metric) + ) + + if metrics_exprs: + main_metric_expr = metrics_exprs[0] + else: + main_metric_expr, label = literal_column("COUNT(*)"), "ccount" + main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) + + # To ensure correct handling of the ORDER BY labeling we need to reference the + # metric instance if defined in the SELECT clause. 
+ # use the key of the ColumnClause for the expected label + metrics_exprs_by_label = {m.key: m for m in metrics_exprs} + metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} + + # Since orderby may use adhoc metrics, too; we need to process them first + orderby_exprs: List[ColumnElement] = [] + for orig_col, ascending in orderby: + col: Union[AdhocMetric, ColumnElement] = orig_col + if isinstance(col, dict): + col = cast(AdhocMetric, col) + if col.get("sqlExpression"): + col["sqlExpression"] = self._process_sql_expression( # type: ignore + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + if utils.is_adhoc_metric(col): + # add adhoc sort by column to columns_by_name if not exists + col = self.adhoc_metric_to_sqla(col, columns_by_name) # type: ignore + # if the adhoc metric has been defined before + # use the existing instance. + col = metrics_exprs_by_expr.get(str(col), col) + need_groupby = True + elif col in columns_by_name: + col = columns_by_name[col].get_sqla_col() + elif col in metrics_exprs_by_label: + col = metrics_exprs_by_label[col] + need_groupby = True + + if isinstance(col, ColumnElement): + orderby_exprs.append(col) + else: + # Could not convert a column reference to valid ColumnElement + raise QueryObjectValidationError( + _("Unknown column used in orderby: %(col)s", col=orig_col) + ) + + select_exprs: List[Union[Column, Label]] = [] + groupby_all_columns = {} + groupby_series_columns = {} + + # filter out the pseudo column __timestamp from columns + columns = [col for col in columns if col != utils.DTTM_ALIAS] + dttm_col = columns_by_name.get(granularity) if granularity else None + + if need_groupby: + # dedup columns while preserving order + columns = groupby or columns + for selected in columns: + if isinstance(selected, str): + # if groupby field/expr equals granularity field/expr + if selected == granularity: + table_col = columns_by_name[selected] + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) + # if groupby field equals a selected column + elif selected in columns_by_name: + if isinstance(columns_by_name[selected], dict): + outer = sa.column(f"{selected}") + outer = self.make_sqla_column_compatible(outer, selected) + else: + outer = columns_by_name[selected].get_sqla_col() + else: + selected = self.validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) + outer = sa.column(f"{selected}") + outer = self.make_sqla_column_compatible(outer, selected) + else: + outer = self.adhoc_column_to_sqla( + col=selected, template_processor=template_processor + ) + groupby_all_columns[outer.name] = outer + if not series_column_names or outer.name in series_column_names: + groupby_series_columns[outer.name] = outer + select_exprs.append(outer) + elif columns: + for selected in columns: + selected = self.validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) + if isinstance(columns_by_name[selected], dict): + select_exprs.append(sa.column(f"{selected}")) + else: + select_exprs.append( + columns_by_name[selected].get_sqla_col() + if selected in columns_by_name + else self.make_sqla_column_compatible(literal_column(selected)) + ) + metrics_exprs = [] + + if granularity: + if granularity not in columns_by_name or not dttm_col: + raise QueryObjectValidationError( + _( + 'Time column "%(col)s" does not exist in dataset', + col=granularity, + ) + ) + time_filters: List[Any] = 
[] + + if is_timeseries: + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) + # always put timestamp as the first column + select_exprs.insert(0, timestamp) + groupby_all_columns[timestamp.name] = timestamp + + # Always remove duplicates by column name, as sometimes `metrics_exprs` + # can have the same name as a groupby column (e.g. when users use + # raw columns as custom SQL adhoc metric). + select_exprs = utils.remove_duplicates( + select_exprs + metrics_exprs, key=lambda x: x.name + ) + + # Expected output columns + labels_expected = [c.key for c in select_exprs] + + # Order by columns are "hidden" columns, some databases require them + # always be present in SELECT if an aggregation function is used + if not db_engine_spec.allows_hidden_ordeby_agg: + select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) + + qry = sa.select(select_exprs) + + tbl, cte = self.get_from_clause(template_processor) + + if groupby_all_columns: + qry = qry.group_by(*groupby_all_columns.values()) + + where_clause_and = [] + having_clause_and = [] + + for flt in filter: # type: ignore + if not all(flt.get(s) for s in ["col", "op"]): + continue + flt_col = flt["col"] + val = flt.get("val") + op = flt["op"].upper() + col_obj: Optional["TableColumn"] = None + sqla_col: Optional[Column] = None + if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: + col_obj = dttm_col + elif utils.is_adhoc_column(flt_col): + sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore + else: + col_obj = columns_by_name.get(flt_col) + filter_grain = flt.get("grain") + + if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): + if utils.get_column_name(flt_col) in removed_filters: + # Skip generating SQLA filter when the jinja template handles it. 
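+                    # (removed_filters is filled in by the Jinja template processor
+                    # when a filter is consumed inside the SQL template itself,
+                    # e.g. via filter_values)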
+ continue + + if col_obj or sqla_col is not None: + if sqla_col is not None: + pass + elif col_obj and filter_grain: + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, template_processor=template_processor + ) + elif col_obj and isinstance(col_obj, dict): + sqla_col = sa.column(col_obj.get("column_name")) + elif col_obj: + sqla_col = col_obj.get_sqla_col() + + if col_obj and isinstance(col_obj, dict): + col_type = col_obj.get("type") + else: + col_type = col_obj.type if col_obj else None + col_spec = db_engine_spec.get_column_spec( + native_type=col_type, + db_extra=self.database.get_extra(), # type: ignore + ) + is_list_target = op in ( + utils.FilterOperator.IN.value, + utils.FilterOperator.NOT_IN.value, + ) + + if col_obj and isinstance(col_obj, dict): + col_advanced_data_type = "" + else: + col_advanced_data_type = ( + col_obj.advanced_data_type if col_obj else "" + ) + + if col_spec and not col_advanced_data_type: + target_generic_type = col_spec.generic_type + else: + target_generic_type = utils.GenericDataType.STRING + eq = self.filter_values_handler( + values=val, + target_generic_type=target_generic_type, + target_native_type=col_type, + is_list_target=is_list_target, + db_engine_spec=db_engine_spec, + db_extra=self.database.get_extra(), # type: ignore + ) + if ( + col_advanced_data_type != "" + and feature_flag_manager.is_feature_enabled( + "ENABLE_ADVANCED_DATA_TYPES" + ) + and col_advanced_data_type in ADVANCED_DATA_TYPES + ): + values = eq if is_list_target else [eq] # type: ignore + bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ + col_advanced_data_type + ].translate_type( + { + "type": col_advanced_data_type, + "values": values, + } + ) + if bus_resp["error_message"]: + raise AdvancedDataTypeResponseError( + _(bus_resp["error_message"]) + ) + + where_clause_and.append( + ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( + sqla_col, op, bus_resp["values"] + ) + ) + elif is_list_target: + assert isinstance(eq, (tuple, list)) + if len(eq) == 0: + raise QueryObjectValidationError( + _("Filter value list cannot be empty") + ) + if len(eq) > len( + eq_without_none := [x for x in eq if x is not None] + ): + is_null_cond = sqla_col.is_(None) + if eq: + cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) + else: + cond = is_null_cond + else: + cond = sqla_col.in_(eq) + if op == utils.FilterOperator.NOT_IN.value: + cond = ~cond + where_clause_and.append(cond) + elif op == utils.FilterOperator.IS_NULL.value: + where_clause_and.append(sqla_col.is_(None)) + elif op == utils.FilterOperator.IS_NOT_NULL.value: + where_clause_and.append(sqla_col.isnot(None)) + elif op == utils.FilterOperator.IS_TRUE.value: + where_clause_and.append(sqla_col.is_(True)) + elif op == utils.FilterOperator.IS_FALSE.value: + where_clause_and.append(sqla_col.is_(False)) + else: + if eq is None: + raise QueryObjectValidationError( + _( + "Must specify a value for filters " + "with comparison operators" + ) + ) + if op == utils.FilterOperator.EQUALS.value: + where_clause_and.append(sqla_col == eq) + elif op == utils.FilterOperator.NOT_EQUALS.value: + where_clause_and.append(sqla_col != eq) + elif op == utils.FilterOperator.GREATER_THAN.value: + where_clause_and.append(sqla_col > eq) + elif op == utils.FilterOperator.LESS_THAN.value: + where_clause_and.append(sqla_col < eq) + elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: + where_clause_and.append(sqla_col >= eq) + elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: + 
where_clause_and.append(sqla_col <= eq) + elif op == utils.FilterOperator.LIKE.value: + where_clause_and.append(sqla_col.like(eq)) + elif op == utils.FilterOperator.ILIKE.value: + where_clause_and.append(sqla_col.ilike(eq)) + else: + raise QueryObjectValidationError( + _("Invalid filter operation type: %(op)s", op=op) + ) + # todo(hugh): fix this w/ template_processor + # where_clause_and += self.get_sqla_row_level_filters(template_processor) + if extras: + where = extras.get("where") + if where: + try: + where = template_processor.process_template( # type: ignore + f"({where})" + ) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in WHERE clause: %(msg)s", + msg=ex.message, + ) + ) from ex + where_clause_and += [self.text(where)] + having = extras.get("having") + if having: + try: + having = template_processor.process_template( # type: ignore + f"({having})" + ) + except TemplateError as ex: + raise QueryObjectValidationError( + _( + "Error in jinja expression in HAVING clause: %(msg)s", + msg=ex.message, + ) + ) from ex + having_clause_and += [self.text(having)] + if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore + qry = qry.where(self.get_fetch_values_predicate()) # type: ignore + if granularity: + qry = qry.where(and_(*(time_filters + where_clause_and))) + else: + qry = qry.where(and_(*where_clause_and)) + qry = qry.having(and_(*having_clause_and)) + + self.make_orderby_compatible(select_exprs, orderby_exprs) + + for col, (orig_col, ascending) in zip(orderby_exprs, orderby): + if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): + # if engine does not allow using SELECT alias in ORDER BY + # revert to the underlying column + col = col.element + + if ( + db_engine_spec.allows_alias_in_select + and db_engine_spec.allows_hidden_cc_in_orderby + and col.name in [select_col.name for select_col in select_exprs] + ): + col = literal_column(col.name) + direction = sa.asc if ascending else sa.desc + qry = qry.order_by(direction(col)) + + if row_limit: + qry = qry.limit(row_limit) + if row_offset: + qry = qry.offset(row_offset) + + if series_limit and groupby_series_columns: + if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: + # some sql dialects require for order by expressions + # to also be in the select clause -- others, e.g. 
vertica, + # require a unique inner alias + inner_main_metric_expr = self.make_sqla_column_compatible( + main_metric_expr, "mme_inner__" + ) + inner_groupby_exprs = [] + inner_select_exprs = [] + for gby_name, gby_obj in groupby_series_columns.items(): + label = utils.get_column_name(gby_name) + inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") + inner_groupby_exprs.append(inner) + inner_select_exprs.append(inner) + + inner_select_exprs += [inner_main_metric_expr] + subq = sa.select(inner_select_exprs).select_from(tbl) + inner_time_filter = [] + + if dttm_col and not db_engine_spec.time_groupby_inline: + inner_time_filter = [ + dttm_col.get_time_filter( + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, + ) + ] + subq = subq.where(and_(*(where_clause_and + inner_time_filter))) + subq = subq.group_by(*inner_groupby_exprs) + + ob = inner_main_metric_expr + direction = sa.desc if order_desc else sa.asc + subq = subq.order_by(direction(ob)) + subq = subq.limit(series_limit) + + on_clause = [] + for gby_name, gby_obj in groupby_series_columns.items(): + # in this case the column name, not the alias, needs to be + # conditionally mutated, as it refers to the column alias in + # the inner query + col_name = db_engine_spec.make_label_compatible(gby_name + "__") + on_clause.append(gby_obj == sa.column(col_name)) + + tbl = tbl.join(subq.alias(), and_(*on_clause)) + + # run prequery to get top groups + prequery_obj = { + "is_timeseries": False, + "row_limit": series_limit, + "metrics": metrics, + "granularity": granularity, + "groupby": groupby, + "from_dttm": inner_from_dttm or from_dttm, + "to_dttm": inner_to_dttm or to_dttm, + "filter": filter, + "orderby": orderby, + "extras": extras, + "columns": columns, + "order_desc": True, + } + + result = self.query(prequery_obj) # type: ignore + prequeries.append(result.query) + dimensions = [ + c + for c in result.df.columns + if c not in metrics and c in groupby_series_columns + ] + top_groups = self._get_top_groups( + result.df, dimensions, groupby_series_columns, columns_by_name + ) + qry = qry.where(top_groups) + + qry = qry.select_from(tbl) + + if is_rowcount: + if not db_engine_spec.allows_subqueries: + raise QueryObjectValidationError( + _("Database does not support subqueries") + ) + label = "rowcount" + col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) + qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) + labels_expected = [label] + + return SqlaQuery( + applied_template_filters=applied_template_filters, + cte=cte, + extra_cache_keys=extra_cache_keys, + labels_expected=labels_expected, + sqla_query=qry, + prequeries=prequeries, + ) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 74c43718ef78..e7fd2c383369 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -17,7 +17,7 @@ """A collection of ORM sqlalchemy models for SQL Lab""" import re from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import simplejson as json import sqlalchemy as sqla @@ -42,25 +42,31 @@ from superset import security_manager from superset.models.helpers import ( AuditMixinNullable, + ExploreMixin, ExtraJSONMixin, ImportExportMixin, ) from superset.models.tags import QueryUpdater from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sqllab.limiting_factor import LimitingFactor -from superset.utils.core import QueryStatus, user_label +from superset.superset_typing 
import ResultSetColumnType
+from superset.utils.core import GenericDataType, QueryStatus, user_label

+if TYPE_CHECKING:
+    from superset.db_engine_specs import BaseEngineSpec

-class Query(Model, ExtraJSONMixin):
+
+class Query(Model, ExtraJSONMixin, ExploreMixin):  # pylint: disable=abstract-method
     """ORM model for SQL query

     Now that SQL Lab supports multi-statement execution, an entry in this table
     may represent multiple SQL statements executed sequentially"""

     __tablename__ = "query"
+    type = "query"
     id = Column(Integer, primary_key=True)
     client_id = Column(String(11), unique=True, nullable=False)
-
+    query_language = "sql"
     database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)

     # Store the tmp table into the DB only if the user asks for it.
@@ -167,8 +173,54 @@ def sql_tables(self) -> List[Table]:
         return list(ParsedQuery(self.sql).tables)

     @property
-    def columns(self) -> List[Table]:
-        return self.extra.get("columns", [])
+    def columns(self) -> List[ResultSetColumnType]:
+        bool_types = ("BOOL",)
+        num_types = (
+            "DOUBLE",
+            "FLOAT",
+            "INT",
+            "BIGINT",
+            "NUMBER",
+            "LONG",
+            "REAL",
+            "NUMERIC",
+            "DECIMAL",
+            "MONEY",
+        )
+        date_types = ("DATE", "TIME")
+        str_types = ("VARCHAR", "STRING", "CHAR")
+        columns = []
+        col_type = ""
+        for col in self.extra.get("columns", []):
+            computed_column = {**col}
+            col_type = col.get("type")
+
+            if col_type and any(map(lambda t: t in col_type.upper(), str_types)):
+                computed_column["type_generic"] = GenericDataType.STRING
+            if col_type and any(map(lambda t: t in col_type.upper(), bool_types)):
+                computed_column["type_generic"] = GenericDataType.BOOLEAN
+            if col_type and any(map(lambda t: t in col_type.upper(), num_types)):
+                computed_column["type_generic"] = GenericDataType.NUMERIC
+            if col_type and any(map(lambda t: t in col_type.upper(), date_types)):
+                computed_column["type_generic"] = GenericDataType.TEMPORAL
+
+            computed_column["column_name"] = col.get("name")
+            computed_column["groupby"] = True
+            columns.append(computed_column)
+        return columns  # type: ignore
+
+    @property
+    def data(self) -> Dict[str, Any]:
+        return {
+            "name": self.tab_name,
+            "columns": self.columns,
+            "metrics": [],
+            "id": self.id,
+            "type": self.type,
+            "sql": self.sql,
+            "owners": self.owners_data,
+            "database": {"id": self.database_id, "backend": self.database.backend},
+        }

     def raise_for_access(self) -> None:
         """
@@ -179,6 +231,53 @@
         security_manager.raise_for_access(query=self)

+    @property
+    def db_engine_spec(self) -> Type["BaseEngineSpec"]:
+        return self.database.db_engine_spec
+
+    @property
+    def owners_data(self) -> List[Dict[str, Any]]:
+        return []
+
+    @property
+    def uid(self) -> str:
+        return f"{self.id}__{self.type}"
+
+    @property
+    def is_rls_supported(self) -> bool:
+        return False
+
+    @property
+    def cache_timeout(self) -> int:
+        return 0
+
+    @property
+    def column_names(self) -> List[Any]:
+        return [col.get("column_name") for col in self.columns]
+
+    @property
+    def offset(self) -> int:
+        return 0
+
+    @property
+    def main_dttm_col(self) -> Optional[str]:
+        for col in self.columns:
+            if col.get("is_dttm"):
+                return col.get("column_name")  # type: ignore
+        return None
+
+    @property
+    def dttm_cols(self) -> List[Any]:
+        return [col.get("column_name") for col in self.columns if col.get("is_dttm")]
+
+    @property
+    def default_endpoint(self) -> str:
+        return ""
+
+    @staticmethod
+    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+        return []
+

 class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin,
ImportExportMixin): """ORM model for SQL query""" diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 571fd94219aa..1b60512a37d7 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -517,6 +517,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca query.rows = result_set.size query.progress = 100 query.set_extra_json_key("progress", None) + query.set_extra_json_key("columns", result_set.columns) if query.select_as_cta: query.select_sql = database.select_star( query.tmp_table_name, diff --git a/superset/utils/core.py b/superset/utils/core.py index 5ce52f9f4b9d..44d76ea533fe 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1674,21 +1674,31 @@ def extract_dataframe_dtypes( "date": GenericDataType.TEMPORAL, } - columns_by_name = ( - {column.column_name: column for column in datasource.columns} - if datasource - else {} - ) + columns_by_name: Dict[str, Any] = {} + if datasource: + for column in datasource.columns: + if isinstance(column, dict): + columns_by_name[column.get("column_name")] = column + else: + columns_by_name[column.column_name] = column + generic_types: List[GenericDataType] = [] for column in df.columns: column_object = columns_by_name.get(column) series = df[column] inferred_type = infer_dtype(series) - generic_type = ( - GenericDataType.TEMPORAL - if column_object and column_object.is_dttm - else inferred_type_map.get(inferred_type, GenericDataType.STRING) - ) + if isinstance(column_object, dict): # type: ignore + generic_type = ( + GenericDataType.TEMPORAL + if column_object and column_object.get("is_dttm") + else inferred_type_map.get(inferred_type, GenericDataType.STRING) + ) + else: + generic_type = ( + GenericDataType.TEMPORAL + if column_object and column_object.is_dttm + else inferred_type_map.get(inferred_type, GenericDataType.STRING) + ) generic_types.append(generic_type) return generic_types @@ -1718,11 +1728,20 @@ def is_test() -> bool: return strtobool(os.environ.get("SUPERSET_TESTENV", "false")) -def get_time_filter_status( +def get_time_filter_status( # pylint: disable=too-many-branches datasource: "BaseDatasource", applied_time_extras: Dict[str, str], ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - temporal_columns = {col.column_name for col in datasource.columns if col.is_dttm} + + temporal_columns: Set[Any] + if datasource.type == "query": + temporal_columns = { + col.get("column_name") for col in datasource.columns if col.get("is_dttm") + } + else: + temporal_columns = { + col.column_name for col in datasource.columns if col.is_dttm + } applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
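Several call sites in this patch branch on whether a column arrives as a plain dict
(Query.columns is built from extra_json) or as a TableColumn ORM object
(SqlaTable.columns). A small helper along these lines could centralize the pattern;
this is a sketch, not part of the patch, and the helper names are illustrative:

    from typing import Any, Optional, Union


    def get_col_name(col: Union[dict, Any]) -> Optional[str]:
        """Return the column name for dict-style or ORM-style columns."""
        return col.get("column_name") if isinstance(col, dict) else col.column_name


    def col_is_dttm(col: Union[dict, Any]) -> bool:
        """Return True when the column is temporal, for either column shape."""
        is_dttm = col.get("is_dttm") if isinstance(col, dict) else col.is_dttm
        return bool(is_dttm)

With helpers like these, e.g. the temporal_columns set in get_time_filter_status above
reduces to a single comprehension over datasource.columns.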