/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
package com.vesoft.exchange.common.processor

import com.vesoft.exchange.common.VidType
import com.vesoft.exchange.common.utils.{HDFSUtils, NebulaPartitioner, NebulaUtils}
import com.vesoft.exchange.common.utils.NebulaUtils.DEFAULT_EMPTY_VALUE
import com.vesoft.nebula.{
Coordinate,
Date,
DateTime,
Geography,
LineString,
NullType,
Point,
Polygon,
PropertyType,
Time,
Value
}
import org.apache.log4j.Logger
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

/**
  * com.vesoft.exchange.common.processor is a converter.
  * It is responsible for converting the dataframe row data into Nebula Graph's vertices
  * or edges, and for submitting the data to the writer.
  */
trait Processor extends Serializable {
  @transient
  private[this] lazy val LOG = Logger.getLogger(this.getClass)

/**
* process dataframe to vertices or edges
*/
def process(): Unit
  /**
    * handle special types of attributes
    *
    * String type: wrap the attribute value in double quotes; if the value contains
    * escape symbols, keep them escaped.
    *
    * Date type: add the date() function for the attribute value.
    * eg: convert attribute value 2020-01-01 to date("2020-01-01")
    *
    * Time type: add the time() function for the attribute value.
    * eg: convert attribute value 12:12:12.1111 to time("12:12:12.1111")
    *
    * DateTime type: add the datetime() function for the attribute value.
    * eg: convert attribute value 2020-01-01T22:30:40 to datetime("2020-01-01T22:30:40")
    */
def extraValueForClient(row: Row, field: String, fieldTypeMap: Map[String, Int]): Any = {
val index = row.schema.fieldIndex(field)
if (row.isNullAt(index)) return null
PropertyType.findByValue(fieldTypeMap(field)) match {
case PropertyType.STRING | PropertyType.FIXED_STRING => {
var value = row.get(index).toString
if (value.equals(DEFAULT_EMPTY_VALUE)) {
value = ""
}
val result = NebulaUtils.escapeUtil(value).mkString("\"", "", "\"")
result
}
case PropertyType.DATE => "date(\"" + row.get(index) + "\")"
case PropertyType.DATETIME => "datetime(\"" + row.get(index) + "\")"
case PropertyType.TIME => "time(\"" + row.get(index) + "\")"
case PropertyType.TIMESTAMP => {
val value = row.get(index).toString
if (NebulaUtils.isNumic(value)) {
value
} else {
"timestamp(\"" + row.get(index) + "\")"
}
}
case PropertyType.GEOGRAPHY => "ST_GeogFromText(\"" + row.get(index) + "\")"
case _ => row.get(index)
}
}
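
  // Illustrative sketch (assumed column names and values, not from the original source):
  // in client mode the converted values are nGQL literals that can be embedded directly
  // in an INSERT statement, e.g.
  //   extraValueForClient(row, "name", fieldTypeMap)     // => "Tom"  (quoted and escaped)
  //   extraValueForClient(row, "birthday", fieldTypeMap) // => date("2020-01-01")

  /**
    * convert the dataframe field value to a native client type for SST file generation,
    * e.g. INT64 to Long, DATE to com.vesoft.nebula.Date; a null column becomes a Value
    * holding NullType.__NULL__
    */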
def extraValueForSST(row: Row, field: String, fieldTypeMap: Map[String, Int]): Any = {
val index = row.schema.fieldIndex(field)
if (row.isNullAt(index)) {
val nullVal = new Value()
nullVal.setNVal(NullType.__NULL__)
return nullVal
}
PropertyType.findByValue(fieldTypeMap(field)) match {
case PropertyType.UNKNOWN =>
        throw new IllegalArgumentException("data type in nebula is UNKNOWN.")
case PropertyType.STRING | PropertyType.FIXED_STRING => {
val value = row.get(index).toString
if (value.equals(DEFAULT_EMPTY_VALUE)) "" else value
}
case PropertyType.BOOL => row.get(index).toString.toBoolean
case PropertyType.DOUBLE => row.get(index).toString.toDouble
case PropertyType.FLOAT => row.get(index).toString.toFloat
case PropertyType.INT8 => row.get(index).toString.toByte
case PropertyType.INT16 => row.get(index).toString.toShort
case PropertyType.INT32 => row.get(index).toString.toInt
case PropertyType.INT64 | PropertyType.VID => row.get(index).toString.toLong
case PropertyType.TIME => {
val values = row.get(index).toString.split(":")
if (values.size < 3) {
throw new UnsupportedOperationException(
s"wrong format for time value: ${row.get(index)}, correct format is 12:00:00:0000")
}
val secs: Array[String] = values(2).split("\\.")
val sec: Byte = secs(0).toByte
val microSec: Int = if (secs.length == 2) secs(1).toInt else 0
new Time(values(0).toByte, values(1).toByte, sec, microSec)
}
case PropertyType.DATE => {
val values = row.get(index).toString.split("-")
if (values.size < 3) {
throw new UnsupportedOperationException(
s"wrong format for date value: ${row.get(index)}, correct format is 2020-01-01")
}
new Date(values(0).toShort, values(1).toByte, values(2).toByte)
}
case PropertyType.DATETIME => {
val rowValue = row.get(index).toString
var dateTimeValue: Array[String] = null
if (rowValue.contains("T")) {
dateTimeValue = rowValue.split("T")
} else if (rowValue.trim.contains(" ")) {
dateTimeValue = rowValue.trim.split(" ")
} else {
throw new UnsupportedOperationException(
s"wrong format for datetime value: $rowValue, " +
s"correct format is 2020-01-01T12:00:00.0000 or 2020-01-01 12:00:00.0000")
}
if (dateTimeValue.size < 2) {
throw new UnsupportedOperationException(
s"wrong format for datetime value: $rowValue, " +
s"correct format is 2020-01-01T12:00:00.0000 or 2020-01-01 12:00:00.0000")
}
val dateValues = dateTimeValue(0).split("-")
val timeValues = dateTimeValue(1).split(":")
if (dateValues.size < 3 || timeValues.size < 3) {
throw new UnsupportedOperationException(
s"wrong format for datetime value: $rowValue, " +
s"correct format is 2020-01-01T12:00:00.0000 or 2020-01-01 12:00:00")
}
val secs: Array[String] = timeValues(2).split("\\.")
val sec: Byte = secs(0).toByte
val microsec: Int = if (secs.length == 2) secs(1).toInt else 0
new DateTime(dateValues(0).toShort,
dateValues(1).toByte,
dateValues(2).toByte,
timeValues(0).toByte,
timeValues(1).toByte,
sec,
microsec)
}
case PropertyType.TIMESTAMP => {
val value = row.get(index).toString
if (!NebulaUtils.isNumic(value)) {
throw new IllegalArgumentException(
s"timestamp only support long type, your value is ${value}")
}
row.get(index).toString.toLong
}
      case PropertyType.GEOGRAPHY => {
        val wkt     = row.get(index).toString
        val jtsGeom = new org.locationtech.jts.io.WKTReader().read(wkt)
        convertJTSGeometryToGeography(jtsGeom)
      }
      // fail explicitly instead of hitting a MatchError on an unhandled property type
      case other =>
        throw new IllegalArgumentException(s"unsupported data type $other for field $field")
    }
  }
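
  // Illustrative sketch (assumed values): the same columns in SST mode become native
  // client objects instead of nGQL literals, e.g.
  //   extraValueForSST(row, "age", fieldTypeMap)      // => 27L (a Scala Long for INT64)
  //   extraValueForSST(row, "birthday", fieldTypeMap) // => new Date(2020, 1, 1)

  /**
    * read a previously saved offset from the given HDFS path
    */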
def fetchOffset(path: String): Long = {
HDFSUtils.getContent(path).toLong
}
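
  /**
    * convert a JTS geometry parsed from WKT into Nebula's Geography value; Point,
    * LineString and Polygon (exterior shell plus interior holes) are supported
    */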
def convertJTSGeometryToGeography(jtsGeom: org.locationtech.jts.geom.Geometry): Geography = {
jtsGeom.getGeometryType match {
case "Point" => {
val jtsPoint = jtsGeom.asInstanceOf[org.locationtech.jts.geom.Point]
val jtsCoord = jtsPoint.getCoordinate
Geography.ptVal(new Point(new Coordinate(jtsCoord.x, jtsCoord.y)))
}
case "LineString" => {
val jtsLineString = jtsGeom.asInstanceOf[org.locationtech.jts.geom.LineString]
val jtsCoordList = jtsLineString.getCoordinates
val coordList = new ListBuffer[Coordinate]()
for (jtsCoord <- jtsCoordList) {
coordList += new Coordinate(jtsCoord.x, jtsCoord.y)
}
Geography.lsVal(new LineString(coordList.asJava))
}
case "Polygon" => {
val jtsPolygon = jtsGeom.asInstanceOf[org.locationtech.jts.geom.Polygon]
val coordListList = new java.util.ArrayList[java.util.List[Coordinate]]()
val jtsShell = jtsPolygon.getExteriorRing
val coordList = new ListBuffer[Coordinate]()
for (jtsCoord <- jtsShell.getCoordinates) {
coordList += new Coordinate(jtsCoord.x, jtsCoord.y)
}
coordListList.add(coordList.asJava)
val jtsHolesNum = jtsPolygon.getNumInteriorRing
for (i <- 0 until jtsHolesNum) {
val coordList = new ListBuffer[Coordinate]()
val jtsHole = jtsPolygon.getInteriorRingN(i)
for (jtsCoord <- jtsHole.getCoordinates) {
coordList += new Coordinate(jtsCoord.x, jtsCoord.y)
}
coordListList.add(coordList.asJava)
}
Geography.pgVal(new Polygon(coordListList))
      }
      // JTS may also yield MultiPoint/MultiLineString/MultiPolygon; fail explicitly
      // instead of hitting a MatchError
      case other =>
        throw new UnsupportedOperationException(s"unsupported geometry type: $other")
    }
  }
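
  // Illustrative sketch (assumed WKT input): a polygon with one interior hole keeps the
  // exterior shell first, followed by each hole, in the resulting coordinate lists, e.g.
  //   val wkt  = "POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))"
  //   val geom = new org.locationtech.jts.io.WKTReader().read(wkt)
  //   convertJTSGeometryToGeography(geom) // => Geography.pgVal(new Polygon(...))

  /**
    * in streaming mode log the message as a warning; in batch mode fail fast via assert
    */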
def printChoice(streamFlag: Boolean, context: String): Unit = {
if (streamFlag) LOG.warn(context)
else assert(assertion = false, context)
}
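
  /**
    * repartition the encoded key-value data with NebulaPartitioner so that each Spark
    * partition holds the data of exactly one Nebula storage partition, as needed when
    * generating SST files
    */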
def customRepartition(spark: SparkSession,
data: Dataset[(Array[Byte], Array[Byte])],
partitionNum: Int): Dataset[(Array[Byte], Array[Byte])] = {
import spark.implicits._
data.rdd
.partitionBy(new NebulaPartitioner(partitionNum))
.map(kv => SSTData(kv._1, kv._2))
.toDF()
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))
}
}
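
/**
  * encoded key-value pair of vertex or edge data, used as the intermediate row type
  * during SST repartitioning
  */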
case class SSTData(key: Array[Byte], value: Array[Byte])