diff --git a/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlQueryProcessor.scala b/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlQueryProcessor.scala index 5e11bf79e..49fc027c5 100644 --- a/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlQueryProcessor.scala +++ b/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlQueryProcessor.scala @@ -63,22 +63,19 @@ class PostgresqlQueryProcessor(postgresRelation: PostgresqlXDRelation, val limit: Option[Int] = logicalPlan.collectFirst { case Limit(Literal(num: Int, _), _) => num } try { - new SQLBuilder(logicalPlan).toSQL.map { sqlQuery => - if (limit.exists(_ == 0)) Array.empty[InternalRow] - else { - - Try(executeQuery(sqlQuery)).getOrElse{ - val sqlWithLimit = s"$sqlText LIMIT ${limit.getOrElse(DefaultLimit)}" - executeQuery(sqlWithLimit) - } - + if (limit.exists(_ == 0)) Some(Array.empty[InternalRow]) + else { + lazy val sqlWithLimit = s"$sqlText LIMIT ${limit.getOrElse(DefaultLimit)}" + lazy val executeDirectQuery = Some(executeQuery(sqlWithLimit)) + new SQLBuilder(logicalPlan).toSQL.fold(executeDirectQuery){ sqlQuery => + Try(Some(executeQuery(sqlQuery))).getOrElse{executeDirectQuery} } } } catch { case exc: Exception => log.warn(s"Exception executing the native query $logicalPlan", exc); None } } -//spark code + //spark code private def getValue(idx: Int, rs: ResultSet, schema: StructType) : Any = { val metadata = schema.fields(idx).metadata val rsIdx= idx+1 diff --git a/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlXDRelation.scala b/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlXDRelation.scala index bcd735474..7153c3a14 100644 --- a/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlXDRelation.scala +++ b/postgresql/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/PostgresqlXDRelation.scala @@ -97,7 +97,7 @@ class PostgresqlXDRelation( url: String, case _ => false } case bn: BinaryNode => bn match { - case _: Join => true + case Join(_, _,_, _) | Union(_, _) | Intersect(_, _) | Except(_, _) => true case _ => false } case unsupportedLogicalPlan =>logDebug(s"LogicalPlan $unsupportedLogicalPlan cannot be executed natively"); false diff --git a/postgresql/src/test/scala/com/stratio/crossdata/connector/postgresql/PostgresqlJoinIT.scala b/postgresql/src/test/scala/com/stratio/crossdata/connector/postgresql/PostgresqlJoinIT.scala index 5081bb0c6..660962dcf 100644 --- a/postgresql/src/test/scala/com/stratio/crossdata/connector/postgresql/PostgresqlJoinIT.scala +++ b/postgresql/src/test/scala/com/stratio/crossdata/connector/postgresql/PostgresqlJoinIT.scala @@ -77,4 +77,31 @@ class PostgresqlJoinIT extends PostgresqlWithSharedContext { result should have length 20 } + it should s"support a UNION natively" in { + assumeEnvironmentIsUpAndRunning + + val df = sql(s"SELECT id FROM $postgresqlSchema.$Table UNION ALL SELECT id FROM $postgresqlSchema.$aggregationTable") + val result = df.collect(ExecutionType.Native) + + result should have length 30 + } + + it should s"support a INTERSECT natively" in { + assumeEnvironmentIsUpAndRunning + + val df = sql(s"SELECT id FROM $postgresqlSchema.$Table INTERSECT SELECT id FROM $postgresqlSchema.$aggregationTable") + val result = df.collect(ExecutionType.Native) + + result should have length 10 + } + + it should s"support a EXCEPT natively" in { + assumeEnvironmentIsUpAndRunning + + val df = sql(s"SELECT id FROM $postgresqlSchema.$Table EXCEPT SELECT id FROM $postgresqlSchema.$aggregationTable") + val result = df.collect(ExecutionType.Native) + + result should have length 0 + } + } diff --git a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStreamsAPIIT.scala b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStreamsAPIIT.scala index 46dbbddad..9aef545ea 100644 --- a/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStreamsAPIIT.scala +++ b/testsIT/src/test/scala/com/stratio/crossdata/driver/DriverStreamsAPIIT.scala @@ -55,7 +55,7 @@ class DriverStreamsAPIIT extends EndToEndTest with ScalaFutures { .requestNext(Row(2, "Fuse")) .request(1).expectComplete() - }(PatienceConfig(timeout = 2 seconds)) + }(PatienceConfig(timeout = 4 seconds)) } }