17
17
18
18
package org .apache .seatunnel .connectors .seatunnel .file .sink .writer ;
19
19
20
+ import org .apache .seatunnel .api .table .type .ArrayType ;
20
21
import org .apache .seatunnel .api .table .type .BasicType ;
22
+ import org .apache .seatunnel .api .table .type .DecimalType ;
23
+ import org .apache .seatunnel .api .table .type .MapType ;
21
24
import org .apache .seatunnel .api .table .type .SeaTunnelDataType ;
22
25
import org .apache .seatunnel .api .table .type .SeaTunnelRow ;
26
+ import org .apache .seatunnel .api .table .type .SeaTunnelRowType ;
23
27
import org .apache .seatunnel .connectors .seatunnel .file .sink .config .TextFileSinkConfig ;
24
28
25
29
import lombok .NonNull ;
28
32
import org .apache .orc .OrcFile ;
29
33
import org .apache .orc .TypeDescription ;
30
34
import org .apache .orc .Writer ;
35
+ import org .apache .orc .storage .common .type .HiveDecimal ;
31
36
import org .apache .orc .storage .ql .exec .vector .BytesColumnVector ;
32
37
import org .apache .orc .storage .ql .exec .vector .ColumnVector ;
38
+ import org .apache .orc .storage .ql .exec .vector .DecimalColumnVector ;
33
39
import org .apache .orc .storage .ql .exec .vector .DoubleColumnVector ;
40
+ import org .apache .orc .storage .ql .exec .vector .ListColumnVector ;
34
41
import org .apache .orc .storage .ql .exec .vector .LongColumnVector ;
42
+ import org .apache .orc .storage .ql .exec .vector .MapColumnVector ;
43
+ import org .apache .orc .storage .ql .exec .vector .StructColumnVector ;
44
+ import org .apache .orc .storage .ql .exec .vector .TimestampColumnVector ;
35
45
import org .apache .orc .storage .ql .exec .vector .VectorizedRowBatch ;
36
46
37
47
import java .io .IOException ;
48
+ import java .math .BigDecimal ;
38
49
import java .math .BigInteger ;
39
50
import java .nio .charset .StandardCharsets ;
51
+ import java .sql .Timestamp ;
52
+ import java .time .LocalDate ;
53
+ import java .time .LocalDateTime ;
54
+ import java .time .LocalTime ;
55
+ import java .time .temporal .ChronoField ;
40
56
import java .util .HashMap ;
57
+ import java .util .List ;
41
58
import java .util .Map ;
42
59
43
60
public class OrcWriteStrategy extends AbstractWriteStrategy {
@@ -109,37 +126,53 @@ private Writer getOrCreateWriter(@NonNull String filePath) {
109
126
}
110
127
111
128
private TypeDescription buildFieldWithRowType (SeaTunnelDataType <?> type ) {
112
- if (BasicType .BOOLEAN_TYPE .equals (type )) {
113
- return TypeDescription .createBoolean ();
129
+ switch (type .getSqlType ()) {
130
+ case ARRAY :
131
+ BasicType <?> elementType = ((ArrayType <?, ?>) type ).getElementType ();
132
+ return TypeDescription .createList (buildFieldWithRowType (elementType ));
133
+ case MAP :
134
+ SeaTunnelDataType <?> keyType = ((MapType <?, ?>) type ).getKeyType ();
135
+ SeaTunnelDataType <?> valueType = ((MapType <?, ?>) type ).getValueType ();
136
+ return TypeDescription .createMap (buildFieldWithRowType (keyType ), buildFieldWithRowType (valueType ));
137
+ case STRING :
138
+ return TypeDescription .createString ();
139
+ case BOOLEAN :
140
+ return TypeDescription .createBoolean ();
141
+ case TINYINT :
142
+ return TypeDescription .createByte ();
143
+ case SMALLINT :
144
+ return TypeDescription .createShort ();
145
+ case INT :
146
+ return TypeDescription .createInt ();
147
+ case BIGINT :
148
+ return TypeDescription .createLong ();
149
+ case FLOAT :
150
+ return TypeDescription .createFloat ();
151
+ case DOUBLE :
152
+ return TypeDescription .createDouble ();
153
+ case DECIMAL :
154
+ int precision = ((DecimalType ) type ).getPrecision ();
155
+ int scale = ((DecimalType ) type ).getScale ();
156
+ return TypeDescription .createDecimal ().withScale (scale ).withPrecision (precision );
157
+ case BYTES :
158
+ return TypeDescription .createBinary ();
159
+ case DATE :
160
+ return TypeDescription .createDate ();
161
+ case TIME :
162
+ case TIMESTAMP :
163
+ return TypeDescription .createTimestamp ();
164
+ case ROW :
165
+ TypeDescription struct = TypeDescription .createStruct ();
166
+ SeaTunnelDataType <?>[] fieldTypes = ((SeaTunnelRowType ) type ).getFieldTypes ();
167
+ for (int i = 0 ; i < fieldTypes .length ; i ++) {
168
+ struct .addField (((SeaTunnelRowType ) type ).getFieldName (i ), buildFieldWithRowType (fieldTypes [i ]));
169
+ }
170
+ return struct ;
171
+ case NULL :
172
+ default :
173
+ String errorMsg = String .format ("Orc file not support this type [%s]" , type .getSqlType ());
174
+ throw new UnsupportedOperationException (errorMsg );
114
175
}
115
- if (BasicType .SHORT_TYPE .equals (type )) {
116
- return TypeDescription .createShort ();
117
- }
118
- if (BasicType .INT_TYPE .equals (type )) {
119
- return TypeDescription .createInt ();
120
- }
121
- if (BasicType .LONG_TYPE .equals (type )) {
122
- return TypeDescription .createLong ();
123
- }
124
- if (BasicType .FLOAT_TYPE .equals (type )) {
125
- return TypeDescription .createFloat ();
126
- }
127
- if (BasicType .DOUBLE_TYPE .equals (type )) {
128
- return TypeDescription .createDouble ();
129
- }
130
- if (BasicType .BYTE_TYPE .equals (type )) {
131
- return TypeDescription .createByte ();
132
- }
133
- if (BasicType .STRING_TYPE .equals (type )) {
134
- return TypeDescription .createString ();
135
- }
136
- if (BasicType .VOID_TYPE .equals (type )) {
137
- return TypeDescription .createString ();
138
- }
139
-
140
- // TODO map struct array
141
-
142
- return TypeDescription .createString ();
143
176
}
144
177
145
178
private TypeDescription buildSchemaWithRowType () {
@@ -169,9 +202,101 @@ private void setColumn(Object value, ColumnVector vector, int row) {
169
202
BytesColumnVector bytesColumnVector = (BytesColumnVector ) vector ;
170
203
setByteColumnVector (value , bytesColumnVector , row );
171
204
break ;
205
+ case DECIMAL :
206
+ DecimalColumnVector decimalColumnVector = (DecimalColumnVector ) vector ;
207
+ setDecimalColumnVector (value , decimalColumnVector , row );
208
+ break ;
209
+ case TIMESTAMP :
210
+ TimestampColumnVector timestampColumnVector = (TimestampColumnVector ) vector ;
211
+ setTimestampColumnVector (value , timestampColumnVector , row );
212
+ break ;
213
+ case LIST :
214
+ ListColumnVector listColumnVector = (ListColumnVector ) vector ;
215
+ setListColumnVector (value , listColumnVector , row );
216
+ break ;
217
+ case MAP :
218
+ MapColumnVector mapColumnVector = (MapColumnVector ) vector ;
219
+ setMapColumnVector (value , mapColumnVector , row );
220
+ break ;
221
+ case STRUCT :
222
+ StructColumnVector structColumnVector = (StructColumnVector ) vector ;
223
+ setStructColumnVector (value , structColumnVector , row );
224
+ break ;
172
225
default :
173
- throw new RuntimeException ("Unexpected ColumnVector subtype" );
226
+ throw new RuntimeException ("Unexpected ColumnVector subtype " + vector .type );
227
+ }
228
+ }
229
+ }
230
+
231
+ private void setStructColumnVector (Object value , StructColumnVector structColumnVector , int row ) {
232
+ if (value instanceof SeaTunnelRow ) {
233
+ SeaTunnelRow seaTunnelRow = (SeaTunnelRow ) value ;
234
+ Object [] fields = seaTunnelRow .getFields ();
235
+ for (int i = 0 ; i < fields .length ; i ++) {
236
+ setColumn (fields [i ], structColumnVector .fields [i ], row );
237
+ }
238
+ } else {
239
+ throw new RuntimeException ("SeaTunnelRow type expected for field" );
240
+ }
241
+
242
+ }
243
+
244
+ private void setMapColumnVector (Object value , MapColumnVector mapColumnVector , int row ) {
245
+ if (value instanceof Map ) {
246
+ Map <?, ?> map = (Map <?, ?>) value ;
247
+
248
+ mapColumnVector .offsets [row ] = mapColumnVector .childCount ;
249
+ mapColumnVector .lengths [row ] = map .size ();
250
+ mapColumnVector .childCount += map .size ();
251
+
252
+ int i = 0 ;
253
+ for (Map .Entry <?, ?> entry : map .entrySet ()) {
254
+ int mapElem = (int ) mapColumnVector .offsets [row ] + i ;
255
+ setColumn (entry .getKey (), mapColumnVector .keys , mapElem );
256
+ setColumn (entry .getValue (), mapColumnVector .values , mapElem );
257
+ ++i ;
174
258
}
259
+ } else {
260
+ throw new RuntimeException ("Map type expected for field" );
261
+ }
262
+ }
263
+
264
+ private void setListColumnVector (Object value , ListColumnVector listColumnVector , int row ) {
265
+ Object [] valueArray ;
266
+ if (value instanceof Object []) {
267
+ valueArray = (Object []) value ;
268
+ } else if (value instanceof List ) {
269
+ valueArray = ((List <?>) value ).toArray ();
270
+ } else {
271
+ throw new RuntimeException ("List and Array type expected for field" );
272
+ }
273
+ listColumnVector .offsets [row ] = listColumnVector .childCount ;
274
+ listColumnVector .lengths [row ] = valueArray .length ;
275
+ listColumnVector .childCount += valueArray .length ;
276
+
277
+ for (int i = 0 ; i < valueArray .length ; i ++) {
278
+ int listElem = (int ) listColumnVector .offsets [row ] + i ;
279
+ setColumn (valueArray [i ], listColumnVector .child , listElem );
280
+ }
281
+ }
282
+
283
+ private void setDecimalColumnVector (Object value , DecimalColumnVector decimalColumnVector , int row ) {
284
+ if (value instanceof BigDecimal ) {
285
+ decimalColumnVector .set (row , HiveDecimal .create ((BigDecimal ) value ));
286
+ } else {
287
+ throw new RuntimeException ("BigDecimal type expected for field" );
288
+ }
289
+ }
290
+
291
+ private void setTimestampColumnVector (Object value , TimestampColumnVector timestampColumnVector , int row ) {
292
+ if (value instanceof Timestamp ) {
293
+ timestampColumnVector .set (row , (Timestamp ) value );
294
+ } else if (value instanceof LocalDateTime ) {
295
+ timestampColumnVector .set (row , Timestamp .valueOf ((LocalDateTime ) value ));
296
+ } else if (value instanceof LocalTime ) {
297
+ timestampColumnVector .set (row , Timestamp .valueOf (((LocalTime ) value ).atDate (LocalDate .ofEpochDay (0 ))));
298
+ } else {
299
+ throw new RuntimeException ("Time series type expected for field" );
175
300
}
176
301
}
177
302
@@ -186,10 +311,12 @@ private void setLongColumnVector(Object value, LongColumnVector longVector, int
186
311
} else if (value instanceof BigInteger ) {
187
312
BigInteger bigInt = (BigInteger ) value ;
188
313
longVector .vector [row ] = bigInt .longValue ();
189
- } else if (value instanceof Short ) {
190
- longVector .vector [row ] = (Short ) value ;
191
314
} else if (value instanceof Byte ) {
192
315
longVector .vector [row ] = (Byte ) value ;
316
+ } else if (value instanceof Short ) {
317
+ longVector .vector [row ] = (Short ) value ;
318
+ } else if (value instanceof LocalDate ) {
319
+ longVector .vector [row ] = ((LocalDate ) value ).getLong (ChronoField .EPOCH_DAY );
193
320
} else {
194
321
throw new RuntimeException ("Long or Integer type expected for field" );
195
322
}
0 commit comments