Skip to content
Permalink
Browse files
[Bug] Fix row type decimal convert bug (#26)
* Fix row type decimal convert bug
  • Loading branch information
aiwenmo committed Apr 15, 2022
1 parent 2ffc9b8 commit 19e24c741e79a24acb480256bbedfe9351c5d4dd
Showing 4 changed files with 65 additions and 30 deletions.
@@ -18,7 +18,6 @@
package org.apache.doris.flink.serialization;

import org.apache.arrow.memory.RootAllocator;

import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DecimalVector;
@@ -36,12 +35,7 @@
import org.apache.doris.flink.exception.DorisException;
import org.apache.doris.flink.rest.models.Schema;
import org.apache.doris.thrift.TScanBatchResult;

import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;
@@ -50,6 +44,9 @@
import java.util.List;
import java.util.NoSuchElementException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* row batch data container.
*/
@@ -243,7 +240,7 @@ public void convertArrowToRowBatch() throws DorisException {
continue;
}
BigDecimal value = decimalVector.getObject(rowIndex).stripTrailingZeros();
addValueToRow(rowIndex, DecimalData.fromBigDecimal(value, value.precision(), value.scale()));
addValueToRow(rowIndex, value);
}
break;
case "DATE":
@@ -261,7 +258,7 @@ public void convertArrowToRowBatch() throws DorisException {
continue;
}
String value = new String(varCharVector.get(rowIndex));
addValueToRow(rowIndex, StringData.fromString(value));
addValueToRow(rowIndex, value);
}
break;
default:
@@ -33,6 +33,8 @@
import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.types.logical.RowType;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -80,7 +82,8 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderCon
.setPassword(options.getPassword())
.setTableIdentifier(options.getTableIdentifier())
.setPartitions(dorisPartitions)
.setReadOptions(readOptions);
.setReadOptions(readOptions)
.setRowType((RowType) physicalSchema.toRowDataType().getLogicalType());
return InputFormatProvider.of(builder.build());
}

@@ -29,16 +29,23 @@
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.io.InputSplitAssigner;
import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import java.io.IOException;
import java.math.BigDecimal;
import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* InputFormat for {@link DorisDynamicTableSource}.
*/
@@ -56,10 +63,13 @@ public class DorisRowDataInputFormat extends RichInputFormat<RowData, DorisTable
private ScalaValueReader scalaValueReader;
private transient boolean hasNext;

public DorisRowDataInputFormat(DorisOptions options, List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions) {
private RowType rowType;

public DorisRowDataInputFormat(DorisOptions options, List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions, RowType rowType) {
this.options = options;
this.dorisPartitions = dorisPartitions;
this.readOptions = readOptions;
this.rowType = rowType;
}

@Override
@@ -136,15 +146,30 @@ public RowData nextRecord(RowData reuse) throws IOException {
return null;
}
List next = (List) scalaValueReader.next();
GenericRowData genericRowData = new GenericRowData(next.size());
for (int i = 0; i < next.size(); i++) {
genericRowData.setField(i, next.get(i));
GenericRowData genericRowData = new GenericRowData(rowType.getFieldCount());
for (int i = 0; i < next.size() && i < rowType.getFieldCount(); i++) {
Object value = deserialize(rowType.getTypeAt(i), next.get(i));
genericRowData.setField(i, value);
}
//update hasNext after we've read the record
hasNext = scalaValueReader.hasNext();
return genericRowData;
}

private Object deserialize(LogicalType type, Object val) {
switch (type.getTypeRoot()) {
case DECIMAL:
final DecimalType decimalType = ((DecimalType) type);
final int precision = decimalType.getPrecision();
final int scala = decimalType.getScale();
return DecimalData.fromBigDecimal((BigDecimal) val, precision, scala);
case VARCHAR:
return StringData.fromString((String) val);
default:
return val;
}
}

@Override
public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
return cachedStatistics;
@@ -182,6 +207,7 @@ public static class Builder {
private DorisOptions.Builder optionsBuilder;
private List<PartitionDefinition> partitions;
private DorisReadOptions readOptions;
private RowType rowType;


public Builder() {
@@ -218,9 +244,14 @@ public Builder setReadOptions(DorisReadOptions readOptions) {
return this;
}

public Builder setRowType(RowType rowType) {
this.rowType = rowType;
return this;
}

public DorisRowDataInputFormat build() {
return new DorisRowDataInputFormat(
optionsBuilder.build(), partitions, readOptions
optionsBuilder.build(), partitions, readOptions, rowType
);
}
}
@@ -44,7 +44,6 @@
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.Lists;
import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.data.StringData;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
@@ -248,10 +247,10 @@ public void testRowBatch() throws Exception {
1L,
(float) 1.1,
(double) 1.1,
StringData.fromString("2008-08-08"),
StringData.fromString("2008-08-08 00:00:00"),
"2008-08-08",
"2008-08-08 00:00:00",
DecimalData.fromBigDecimal(new BigDecimal(12.34), 4, 2),
StringData.fromString("char1")
"char1"
);

List<Object> expectedRow2 = Arrays.asList(
@@ -262,10 +261,10 @@ public void testRowBatch() throws Exception {
2L,
(float) 2.2,
(double) 2.2,
StringData.fromString("1900-08-08"),
StringData.fromString("1900-08-08 00:00:00"),
"1900-08-08",
"1900-08-08 00:00:00",
DecimalData.fromBigDecimal(new BigDecimal(88.88), 4, 2),
StringData.fromString("char2")
"char2"
);

List<Object> expectedRow3 = Arrays.asList(
@@ -276,22 +275,25 @@ public void testRowBatch() throws Exception {
3L,
(float) 3.3,
(double) 3.3,
StringData.fromString("2100-08-08"),
StringData.fromString("2100-08-08 00:00:00"),
"2100-08-08",
"2100-08-08 00:00:00",
DecimalData.fromBigDecimal(new BigDecimal(10.22), 4, 2),
StringData.fromString("char3")
"char3"
);

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow1 = rowBatch.next();
actualRow1.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow1.get(9), 4, 2));
Assert.assertEquals(expectedRow1, actualRow1);

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow2 = rowBatch.next();
actualRow2.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow2.get(9), 4, 2));
Assert.assertEquals(expectedRow2, actualRow2);

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow3 = rowBatch.next();
actualRow3.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow3.get(9), 4, 2));
Assert.assertEquals(expectedRow3, actualRow3);

Assert.assertFalse(rowBatch.hasNext());
@@ -420,16 +422,18 @@ public void testDecimalV2() throws Exception {

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow0 = rowBatch.next();
Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(12.340000000), 11, 9), actualRow0.get(0));
Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(12.340000000), 11, 9),
DecimalData.fromBigDecimal((BigDecimal) actualRow0.get(0), 11, 9));

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow1 = rowBatch.next();

Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(88.880000000), 11, 9), actualRow1.get(0));
Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(88.880000000), 11, 9),
DecimalData.fromBigDecimal((BigDecimal) actualRow1.get(0), 11, 9));

Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow2 = rowBatch.next();
Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(10.000000000),11, 9), actualRow2.get(0));
Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(10.000000000), 11, 9),
DecimalData.fromBigDecimal((BigDecimal) actualRow2.get(0), 11, 9));

Assert.assertFalse(rowBatch.hasNext());
thrown.expect(NoSuchElementException.class);

0 comments on commit 19e24c7

Please sign in to comment.