
Commit

Fix ST_Union_Aggr bug, enable Hive Support, and clean up the tests (#284)
jiayuasu committed Oct 30, 2018
1 parent 7ae258e commit 9768147
Showing 12 changed files with 137 additions and 192 deletions.
2 changes: 2 additions & 0 deletions sql/.gitignore
@@ -8,3 +8,5 @@
*.iml
/latest/
/spark-warehouse/
/metastore_db/
*.log
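(These ignore entries are presumably needed because a locally running Hive-enabled session, introduced by the pom.xml and TestBaseScala changes below, creates an embedded Derby metastore under metastore_db/ and writes derby.log into the working directory.)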
6 changes: 6 additions & 0 deletions sql/pom.xml
@@ -27,6 +27,12 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_${scala.compat.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<sourceDirectory>src/main/scala</sourceDirectory>
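The new provided-scope spark-hive_${scala.compat.version} dependency is what lets the shared test session below call enableHiveSupport(). A minimal sketch of the failure mode it avoids; the builder itself is illustrative and not code from this commit:

import org.apache.spark.sql.SparkSession

// Without the spark-hive classes on the classpath, enableHiveSupport() throws
// IllegalArgumentException ("Unable to instantiate SparkSession with Hive
// support because Hive classes are not found.") before any query runs.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("hive-support-check") // hypothetical app name
  .enableHiveSupport()
  .getOrCreate()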
@@ -25,7 +25,7 @@
*/
package org.apache.spark.sql.geosparksql.expressions

import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory, Polygon}
import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.geosparksql.UDT.GeometryUDT
@@ -58,15 +58,15 @@ class ST_Union_Aggr extends UserDefinedAggregateFunction {
}

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val accumulateUnion = buffer.getAs[Polygon](0)
val newPolygon = input.getAs[Polygon](0)
val accumulateUnion = buffer.getAs[Geometry](0)
val newPolygon = input.getAs[Geometry](0)
if (accumulateUnion.getArea == 0) buffer(0) = newPolygon
else buffer(0) = accumulateUnion.union(newPolygon)
}

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val leftPolygon = buffer1.getAs[Polygon](0)
val rightPolygon = buffer2.getAs[Polygon](0)
val leftPolygon = buffer1.getAs[Geometry](0)
val rightPolygon = buffer2.getAs[Geometry](0)
if (leftPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = rightPolygon
else if (rightPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = leftPolygon
else buffer1(0) = leftPolygon.union(rightPolygon)
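The fix reads the accumulated value as the general Geometry type rather than Polygon. A hedged sketch of why the old getAs[Polygon] casts could fail: in JTS, the union of two polygons that do not touch is a MultiPolygon, so an intermediate union result would raise a ClassCastException. The helper below is illustrative, not code from this commit:

import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory}

val factory = new GeometryFactory()

// Build a closed axis-aligned box from its corner coordinates.
def box(minX: Double, minY: Double, maxX: Double, maxY: Double) =
  factory.createPolygon(Array(
    new Coordinate(minX, minY), new Coordinate(maxX, minY),
    new Coordinate(maxX, maxY), new Coordinate(minX, maxY),
    new Coordinate(minX, minY)))

val union: Geometry = box(0, 0, 1, 1).union(box(5, 5, 6, 6))
println(union.getGeometryType) // "MultiPolygon" -- not castable to Polygon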
@@ -143,6 +143,31 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
}


/**
* Return a Geometry from a WKT string
*
* @param inputExpressions This function takes 1 parameter which is the geometry string. The string format must be WKT.
*/
case class ST_GeomFromText(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback with UserDataGeneratator {
override def nullable: Boolean = false

override def eval(inputRow: InternalRow): Any = {
// This is an expression which takes one input expression
assert(inputExpressions.length == 1)
val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
var fileDataSplitter = FileDataSplitter.WKT
var formatMapper = new FormatMapper(fileDataSplitter, false)
var geometry = formatMapper.readGeometry(geomString)
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}

override def dataType: DataType = new GeometryUDT()

override def children: Seq[Expression] = inputExpressions
}


/**
* Return a Geometry from a WKB string
*
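The new ST_GeomFromText constructor parses the same WKT input as ST_GeomFromWKT, mirroring the familiar PostGIS name. A hedged usage sketch once it is registered (the WKT literal and the in-scope sparkSession are assumptions, not part of this commit):

val polygonDf = sparkSession.sql(
  "SELECT ST_GeomFromText('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))') AS geom")
polygonDf.show()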
@@ -28,13 +28,12 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.geosparksql.expressions._

import scala.reflect.runtime.{universe => ru}

object Catalog {
val expressions:Seq[FunctionBuilder] = Seq(
ST_PointFromText,
ST_PolygonFromText,
ST_LineStringFromText,
ST_GeomFromText,
ST_GeomFromWKT,
ST_GeomFromWKB,
ST_Point,
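A note on why the one-line catalog entry is enough, sketched under the assumption that FunctionBuilder is Spark's Seq[Expression] => Expression alias: the compiler-generated companion of a case class whose single parameter is Seq[Expression] already conforms to that function type, so the registrator can hand it straight to Spark's function registry.

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.geosparksql.expressions.ST_GeomFromText

// The case-class companion is a Function1[Seq[Expression], ST_GeomFromText],
// which is a FunctionBuilder by covariance of the result type.
val builder: FunctionBuilder = ST_GeomFromText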
66 changes: 66 additions & 0 deletions sql/src/test/scala/org/datasyslab/geosparksql/TestBaseScala.scala
@@ -0,0 +1,66 @@
/**
* FILE: TestBaseScala
* Copyright (c) 2015 - 2018 GeoSpark Development Team
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package org.datasyslab.geosparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.SparkSession
import org.datasyslab.geospark.serde.GeoSparkKryoRegistrator
import org.datasyslab.geosparksql.utils.GeoSparkSQLRegistrator
import org.scalatest.{BeforeAndAfterAll, FunSpec}

trait TestBaseScala extends FunSpec with BeforeAndAfterAll{
Logger.getLogger("org.apache").setLevel(Level.WARN)
Logger.getLogger("com").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)
Logger.getLogger("org.datasyslab.geospark").setLevel(Level.WARN)

val warehouseLocation = System.getProperty("user.dir") + "/target/"
var sparkSession = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).
config("spark.kryo.registrator", classOf[GeoSparkKryoRegistrator].getName).
master("local[*]").appName("geosparksqlScalaTest")
.config("spark.sql.warehouse.dir", warehouseLocation)
.enableHiveSupport().getOrCreate()

val resourceFolder = System.getProperty("user.dir") + "/src/test/resources/"

val mixedWkbGeometryInputLocation = resourceFolder + "county_small_wkb.tsv"
val mixedWktGeometryInputLocation = resourceFolder + "county_small.tsv"
val shapefileInputLocation = resourceFolder + "shapefiles/dbf"
val geojsonInputLocation = resourceFolder + "testPolygon.json"
val arealmPointInputLocation = resourceFolder + "arealm.csv"
val csvPointInputLocation = resourceFolder + "testpoint.csv"
val csvPolygonInputLocation = resourceFolder + "testenvelope.csv"
val unionPolygonInputLocation = resourceFolder + "testunion.csv"

override def beforeAll(): Unit = {
GeoSparkSQLRegistrator.registerAll(sparkSession)
}

override def afterAll(): Unit = {
//GeoSparkSQLRegistrator.dropAll(spark)
//spark.stop
}
}
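The trait gives every suite one Hive-enabled, Kryo-configured session plus the shared fixture paths, so the individual test classes below shrink to just their test cases. A minimal sketch of a suite built on it; the test body is illustrative, not part of this commit:

package org.datasyslab.geosparksql

class exampleTestScala extends TestBaseScala {
  describe("GeoSpark-SQL example") {
    it("reads the union polygon fixture") {
      // sparkSession and unionPolygonInputLocation come from TestBaseScala.
      val df = sparkSession.read.format("csv")
        .option("delimiter", ",").option("header", "false")
        .load(unionPolygonInputLocation)
      assert(df.count() > 0)
    }
  }
}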
@@ -27,44 +27,18 @@
package org.datasyslab.geosparksql

import com.vividsolutions.jts.geom.Geometry
import org.apache.log4j.{Level, Logger}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.SparkSession
import org.datasyslab.geospark.enums.{FileDataSplitter, GridType, IndexType}
import org.datasyslab.geospark.formatMapper.shapefileParser.ShapefileReader
import org.datasyslab.geospark.serde.GeoSparkKryoRegistrator
import org.datasyslab.geospark.spatialOperator.JoinQuery
import org.datasyslab.geospark.spatialRDD.{CircleRDD, PolygonRDD, SpatialRDD}
import org.datasyslab.geosparksql.utils.{Adapter, GeoSparkSQLRegistrator}
import org.scalatest.{BeforeAndAfterAll, FunSpec}
import org.datasyslab.geosparksql.utils.Adapter

class adapterTestScala extends FunSpec with BeforeAndAfterAll {

var sparkSession: SparkSession = _

override def afterAll(): Unit = {
//UdfRegistrator.dropAll(sparkSession)
//sparkSession.stop
}
class adapterTestScala extends TestBaseScala {

describe("GeoSpark-SQL Scala Adapter Test") {
sparkSession = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).
config("spark.kryo.registrator", classOf[GeoSparkKryoRegistrator].getName).
master("local[*]").appName("readTestScala").getOrCreate()
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)

GeoSparkSQLRegistrator.registerAll(sparkSession.sqlContext)

val resourceFolder = System.getProperty("user.dir") + "/src/test/resources/"

val mixedWktGeometryInputLocation = resourceFolder + "county_small.tsv"
val csvPointInputLocation = resourceFolder + "arealm.csv"
val shapefileInputLocation = resourceFolder + "shapefiles/polygon"
val geojsonInputLocation = resourceFolder + "testPolygon.json"

it("Read CSV point into a SpatialRDD") {
var df = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(csvPointInputLocation)
var df = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(arealmPointInputLocation)
df.show()
df.createOrReplaceTempView("inputtable")
var spatialDf = sparkSession.sql("select ST_PointFromText(inputtable._c0,\",\") as arealandmark from inputtable")
@@ -77,7 +51,7 @@ class adapterTestScala extends FunSpec with BeforeAndAfterAll {
}

it("Read CSV point into a SpatialRDD by passing coordinates") {
var df = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
var df = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
df.show()
df.createOrReplaceTempView("inputtable")
var spatialDf = sparkSession.sql("select ST_Point(cast(inputtable._c0 as Decimal(24,20)),cast(inputtable._c1 as Decimal(24,20))) as arealandmark from inputtable")
@@ -91,7 +65,7 @@ class adapterTestScala extends FunSpec with BeforeAndAfterAll {
}

it("Read CSV point into a SpatialRDD with unique Id by passing coordinates") {
var df = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
var df = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
df.show()
df.createOrReplaceTempView("inputtable")
// Use Column _c0 as the unique Id but the id can be anything in the same row
@@ -142,8 +116,7 @@ class adapterTestScala extends FunSpec with BeforeAndAfterAll {
}

it("Read GeoJSON to DataFrame") {
import org.apache.spark.sql.functions.callUDF
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.{callUDF, col}
var spatialRDD = new PolygonRDD(sparkSession.sparkContext, geojsonInputLocation, FileDataSplitter.GEOJSON, true)
spatialRDD.analyze()
var df = Adapter.toDf(spatialRDD, sparkSession).withColumn("geometry", callUDF("ST_GeomFromWKT", col("geometry")))
@@ -159,7 +132,7 @@ class adapterTestScala extends FunSpec with BeforeAndAfterAll {
polygonRDD.rawSpatialRDD = Adapter.toRdd(polygonDf)
polygonRDD.analyze()

var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
pointCsvDF.createOrReplaceTempView("pointtable")
var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
var pointRDD = new SpatialRDD[Geometry]
@@ -181,7 +154,7 @@ class adapterTestScala extends FunSpec with BeforeAndAfterAll {
}

it("Convert distance join result to DataFrame") {
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
pointCsvDF.createOrReplaceTempView("pointtable")
var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
var pointRDD = new SpatialRDD[Geometry]
@@ -26,39 +26,14 @@

package org.datasyslab.geosparksql

import com.vividsolutions.jts.geom.{Coordinate, GeometryFactory, Polygon}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.SparkSession
import org.datasyslab.geospark.serde.GeoSparkKryoRegistrator
import org.datasyslab.geosparksql.utils.GeoSparkSQLRegistrator
import org.scalatest.{BeforeAndAfterAll, FunSpec}
import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory}

class aggregateFunctionTestScala extends FunSpec with BeforeAndAfterAll {

var sparkSession: SparkSession = _

override def afterAll(): Unit = {
//GeoSparkSQLRegistrator.dropAll(sparkSession)
//sparkSession.stop
}
class aggregateFunctionTestScala extends TestBaseScala {

describe("GeoSpark-SQL Aggregate Function Test") {
sparkSession = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).
config("spark.kryo.registrator", classOf[GeoSparkKryoRegistrator].getName).
master("local[*]").appName("readTestScala").getOrCreate()
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)

GeoSparkSQLRegistrator.registerAll(sparkSession.sqlContext)

val resourceFolder = System.getProperty("user.dir") + "/src/test/resources/"

val csvPolygonInputLocation = resourceFolder + "testunion.csv"
val plainPointInputLocation = resourceFolder + "testpoint.csv"

it("Passed ST_Envelope_aggr") {
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(plainPointInputLocation)
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF.createOrReplaceTempView("pointtable")
var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
pointDf.createOrReplaceTempView("pointdf")
@@ -76,14 +51,14 @@ class aggregateFunctionTestScala extends FunSpec with BeforeAndAfterAll {

it("Passed ST_Union_aggr") {

var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation)
var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(unionPolygonInputLocation)
polygonCsvDf.createOrReplaceTempView("polygontable")
polygonCsvDf.show()
var polygonDf = sparkSession.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
polygonDf.createOrReplaceTempView("polygondf")
polygonDf.show()
var union = sparkSession.sql("select ST_Union_Aggr(polygondf.polygonshape) from polygondf")
assert(union.take(1)(0).get(0).asInstanceOf[Polygon].getArea == 10100)
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
}
}
}
@@ -26,52 +26,23 @@

package org.datasyslab.geosparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.SparkSession
import org.datasyslab.geospark.formatMapper.GeoJsonReader
import org.datasyslab.geospark.formatMapper.shapefileParser.ShapefileReader
import org.datasyslab.geospark.serde.GeoSparkKryoRegistrator
import org.datasyslab.geosparksql.utils.{Adapter, GeoSparkSQLRegistrator}
import org.scalatest.{BeforeAndAfterAll, FunSpec}
import org.datasyslab.geosparksql.utils.Adapter

class constructorTestScala extends FunSpec with BeforeAndAfterAll {

var sparkSession: SparkSession = _


override def afterAll(): Unit = {
//GeoSparkSQLRegistrator.dropAll(sparkSession)
//sparkSession.stop
}
class constructorTestScala extends TestBaseScala {

describe("GeoSpark-SQL Constructor Test") {
sparkSession = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).
config("spark.kryo.registrator", classOf[GeoSparkKryoRegistrator].getName).
master("local[*]").appName("readTestScala").getOrCreate()
Logger.getLogger("org").setLevel(Level.WARN)
Logger.getLogger("akka").setLevel(Level.WARN)

GeoSparkSQLRegistrator.registerAll(sparkSession.sqlContext)

val resourceFolder = System.getProperty("user.dir") + "/src/test/resources/"

val mixedWktGeometryInputLocation = resourceFolder + "county_small.tsv"
val mixedWkbGeometryInputLocation = resourceFolder + "county_small_wkb.tsv"
val plainPointInputLocation = resourceFolder + "testpoint.csv"
val shapefileInputLocation = resourceFolder + "shapefiles/dbf"
val csvPointInputLocation = resourceFolder + "arealm.csv"
val geoJsonGeomInputLocation = resourceFolder + "testPolygon.json"

it("Passed ST_Point") {
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(plainPointInputLocation)
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF.createOrReplaceTempView("pointtable")
var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
assert(pointDf.count() == 1000)
}

it("Passed ST_PointFromText") {
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
pointCsvDF.createOrReplaceTempView("pointtable")
pointCsvDF.show(false)

@@ -83,7 +54,16 @@ class constructorTestScala extends FunSpec with BeforeAndAfterAll {
var polygonWktDf = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(mixedWktGeometryInputLocation)
polygonWktDf.createOrReplaceTempView("polygontable")
polygonWktDf.show()
var polygonDf = sparkSession.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
var polygonDf = sparkSession.sql("select ST_GeomFromWkt(polygontable._c0) as countyshape from polygontable")
polygonDf.show(10)
assert(polygonDf.count() == 100)
}

it("Passed ST_GeomFromText") {
var polygonWktDf = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(mixedWktGeometryInputLocation)
polygonWktDf.createOrReplaceTempView("polygontable")
polygonWktDf.show()
var polygonDf = sparkSession.sql("select ST_GeomFromText(polygontable._c0) as countyshape from polygontable")
polygonDf.show(10)
assert(polygonDf.count() == 100)
}
@@ -98,7 +78,7 @@
}

it("Passed GeoJsonReader to DataFrame") {
var spatialRDD = GeoJsonReader.readToGeometryRDD(sparkSession.sparkContext, geoJsonGeomInputLocation)
var spatialRDD = GeoJsonReader.readToGeometryRDD(sparkSession.sparkContext, geojsonInputLocation)
var spatialDf = Adapter.toDf(spatialRDD, sparkSession)
spatialDf.show()
}