move expression build logic to AstBuilder
mgaido91 committed Sep 6, 2018
1 parent 148f477 commit 7d3cf0c
Showing 3 changed files with 51 additions and 44 deletions.
@@ -552,11 +552,6 @@ object CatalogTypes {
*/
type TablePartitionSpec = Map[String, String]

/**
* Specifications of table partition filters. Seq of column name, comparison operator and value.
*/
type PartitionFiltersSpec = Seq[(String, String, String)]

/**
* Initialize an empty spec.
*/
@@ -297,18 +297,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
* Create a partition specification map with filters.
*/
override def visitDropPartitionSpec(
ctx: DropPartitionSpecContext): Seq[(String, String, String)] = {
ctx: DropPartitionSpecContext): Seq[Expression] = {
withOrigin(ctx) {
ctx.dropPartitionVal().asScala.map { pFilter =>
if (pFilter.identifier() == null || pFilter.constant() == null ||
pFilter.comparisonOperator() == null) {
throw new ParseException(s"Invalid partition spec: ${pFilter.getText}", ctx)
}
val partition = pFilter.identifier().getText
val value = visitStringConstant(pFilter.constant())
// We cannot use UnresolvedAttribute because resolution is performed after Analysis, when
// the command is run. The type is not relevant here; it is replaced during the real resolution.
val partition =
AttributeReference(pFilter.identifier().getText, StringType)()
val value = Literal(visitStringConstant(pFilter.constant()))
val operator = pFilter.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
val stringOperator = SqlBaseParser.VOCABULARY.getSymbolicName(operator.getSymbol.getType)
(partition, stringOperator, value)
buildComparison(partition, value, operator)
}
}
}
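
For illustration, a minimal sketch, assuming a clause such as ALTER TABLE t DROP PARTITION (p < '3'), of the kind of expression the new visitDropPartitionSpec builds: a placeholder string-typed attribute compared against a string literal (the column name and value below are made up for the example).

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, LessThan, Literal}
import org.apache.spark.sql.types.StringType

// Placeholder attribute: the real data type is substituted when the command runs.
val partitionCol = AttributeReference("p", StringType)()
// One Expression per partition filter, built through buildComparison.
val dropFilter = LessThan(partitionCol, Literal("3"))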
@@ -1035,6 +1037,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val left = expression(ctx.left)
val right = expression(ctx.right)
val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
buildComparison(left, right, operator)
}

/**
* Creates a comparison expression. The following comparison operators are supported:
* - Equal: '=' or '=='
* - Null-safe Equal: '<=>'
* - Not Equal: '<>' or '!='
* - Less than: '<'
* - Less than or Equal: '<='
* - Greater than: '>'
* - Greater than or Equal: '>='
*/
private def buildComparison(
left: Expression,
right: Expression,
operator: TerminalNode): Expression = {
operator.getSymbol.getType match {
case SqlBaseParser.EQ =>
EqualTo(left, right)
@@ -29,10 +29,10 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
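
The remaining cases of buildComparison are not shown in this hunk. A standalone sketch of how the mapping presumably continues, assuming the same operator-to-expression correspondence used by the string-based code removed from AlterTableDropPartitionCommand below:

import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser

def buildComparisonSketch(
    left: Expression,
    right: Expression,
    operator: TerminalNode): Expression = {
  operator.getSymbol.getType match {
    case SqlBaseParser.EQ => EqualTo(left, right)                             // '=' or '=='
    case SqlBaseParser.NSEQ => EqualNullSafe(left, right)                     // '<=>'
    case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => Not(EqualTo(left, right))  // '<>' or '!='
    case SqlBaseParser.LT => LessThan(left, right)                            // '<'
    case SqlBaseParser.LTE => LessThanOrEqual(left, right)                    // '<='
    case SqlBaseParser.GT => GreaterThan(left, right)                         // '>'
    case SqlBaseParser.GTE => GreaterThanOrEqual(left, right)                 // '>='
  }
}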

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.{PartitionFiltersSpec, TablePartitionSpec}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Cast, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Not}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
@@ -521,7 +521,7 @@ case class AlterTableRenamePartitionCommand(
*/
case class AlterTableDropPartitionCommand(
tableName: TableIdentifier,
partitionsFilters: Seq[PartitionFiltersSpec],
partitionsFilters: Seq[Seq[Expression]],
ifExists: Boolean,
purge: Boolean,
retainData: Boolean)
@@ -548,7 +548,8 @@ case class AlterTableDropPartitionCommand(
ifExists)
} else {
val partitionSpec = filtersSpec.map {
case (key, _, value) => key -> value
case EqualTo(key: Attribute, Literal(value, StringType)) =>
key.name -> value.toString
}.toMap
PartitioningUtils.normalizePartitionSpec(
partitionSpec,
@@ -567,42 +568,34 @@
Seq.empty[Row]
}
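
As a small illustration of the pattern match above (column name and value invented for the example): when every filter is a plain equality on a string literal, the expressions collapse back into the classic column -> value partition spec.

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.StringType

val filtersSpecExample = Seq(EqualTo(AttributeReference("ds", StringType)(), Literal("2018-09-06")))
val partitionSpecExample = filtersSpecExample.map {
  case EqualTo(key: Attribute, Literal(value, StringType)) => key.name -> value.toString
}.toMap
// partitionSpecExample == Map("ds" -> "2018-09-06")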

def hasComplexFilters(partitionFilterSpec: PartitionFiltersSpec): Boolean = {
!partitionFilterSpec.forall(_._2 == "EQ")
def hasComplexFilters(partitionFilterSpec: Seq[Expression]): Boolean = {
partitionFilterSpec.exists(!_.isInstanceOf[EqualTo])
}
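
A minimal illustration (made-up filters) of the check performed by the new hasComplexFilters: only non-equality comparisons count as complex, which per the surrounding code sends the command through generatePartitionSpec rather than a plain partition spec.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, LessThan, Literal}
import org.apache.spark.sql.types.StringType

val p = AttributeReference("p", StringType)()
Seq(EqualTo(p, Literal("1"))).exists(!_.isInstanceOf[EqualTo])                             // false: simple spec
Seq(EqualTo(p, Literal("1")), LessThan(p, Literal("5"))).exists(!_.isInstanceOf[EqualTo])  // true: complex filter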

def generatePartitionSpec(
partitionFilterSpec: PartitionFiltersSpec,
partitionFilterSpec: Seq[Expression],
partitionColumns: Seq[String],
partitionAttributes: Map[String, Attribute],
tableIdentifier: TableIdentifier,
catalog: SessionCatalog,
resolver: Resolver,
timeZone: Option[String],
ifExists: Boolean): Seq[TablePartitionSpec] = {
val filters = partitionFilterSpec.map { case (partitionColumn, operator, value) =>
val normalizedPartition = PartitioningUtils.normalizePartitionColumn(
partitionColumn,
partitionColumns,
tableIdentifier.quotedString,
resolver)
val partitionAttr = partitionAttributes(normalizedPartition)
val castedLiteralValue = Cast(Literal(value), partitionAttr.dataType, timeZone)
operator match {
case "EQ" =>
EqualTo(partitionAttr, castedLiteralValue)
case "NSEQ" =>
EqualNullSafe(partitionAttr, castedLiteralValue)
case "NEQ" | "NEQJ" =>
Not(EqualTo(partitionAttr, castedLiteralValue))
case "LT" =>
LessThan(partitionAttr, castedLiteralValue)
case "LTE" =>
LessThanOrEqual(partitionAttr, castedLiteralValue)
case "GT" =>
GreaterThan(partitionAttr, castedLiteralValue)
case "GTE" =>
GreaterThanOrEqual(partitionAttr, castedLiteralValue)
val filters = partitionFilterSpec.map { pFilter =>
pFilter.transform {
// Resolve the partition attributes
case partitionCol: Attribute =>
val normalizedPartition = PartitioningUtils.normalizePartitionColumn(
partitionCol.name,
partitionColumns,
tableIdentifier.quotedString,
resolver)
partitionAttributes(normalizedPartition)
}.transform {
// Cast the partition value to the data type of the corresponding partition attribute
case cmp @ BinaryComparison(partitionAttr, value)
if !partitionAttr.dataType.sameType(value.dataType) =>
cmp.withNewChildren(Seq(partitionAttr, Cast(value, partitionAttr.dataType, timeZone)))
}
}
val partitions = catalog.listPartitionsByFilter(tableIdentifier, filters)
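
A rough sketch (column name, value, and types invented for the example) of what the two transform passes above do to a single filter: the parser's placeholder string attribute is replaced by the table's actual partition attribute, and the literal is then cast to that attribute's data type.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Cast, GreaterThan, Literal}
import org.apache.spark.sql.types.{IntegerType, StringType}

val parsed = GreaterThan(AttributeReference("part", StringType)(), Literal("3"))
val resolved = parsed.transform {
  // Swap the placeholder attribute for the (here invented) resolved partition attribute.
  case a: AttributeReference if a.name == "part" => AttributeReference("part", IntegerType)()
}.transform {
  // Align the literal's type with the partition column's type.
  case cmp @ BinaryComparison(attr, value) if !attr.dataType.sameType(value.dataType) =>
    cmp.withNewChildren(Seq(attr, Cast(value, attr.dataType)))
}
// resolved is roughly: GreaterThan(part: int, cast('3' as int))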
@@ -623,15 +616,15 @@ object AlterTableDropPartitionCommand {
purge: Boolean,
retainData: Boolean): AlterTableDropPartitionCommand = {
AlterTableDropPartitionCommand(tableName,
specs.map(tablePartitionToPartitionFiltersSpec),
specs.map(tablePartitionToPartitionFilters),
ifExists,
purge,
retainData)
}

def tablePartitionToPartitionFiltersSpec(spec: TablePartitionSpec): PartitionFiltersSpec = {
def tablePartitionToPartitionFilters(spec: TablePartitionSpec): Seq[Expression] = {
spec.map {
case (key, value) => (key, "EQ", value)
case (key, value) => EqualTo(AttributeReference(key, StringType)(), Literal(value))
}.toSeq
}
}
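
For example (spec values invented), tablePartitionToPartitionFilters turns a classic key -> value spec into equality expressions over placeholder string attributes:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.StringType

val spec = Map("year" -> "2018", "month" -> "09")
val filters = spec.map {
  case (key, value) => EqualTo(AttributeReference(key, StringType)(), Literal(value))
}.toSeq
// filters: one EqualTo per entry, e.g. year = '2018' and month = '09'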
