Skip to content
Permalink
Browse files
DRILL-6472: Prevent using zero precision in CAST function
- Add check for the correctness of scale value;
- Add check for fitting the value to the value with the concrete scale and precision;
- Implement negative UDF for VarDecimal
- Add unit tests for new checks and UDF.
  • Loading branch information
vvysotskyi authored and arina-ielchiieva committed Jul 13, 2018
1 parent eb90ebd commit c39ba74796d1f47887306ae81aa70ccf454effb3
Showing 14 changed files with 227 additions and 104 deletions.
@@ -64,17 +64,20 @@ public void setup() {
public void eval() {
java.math.BigDecimal bd =
<#if type.from == "Decimal9" || type.from == "Decimal18">
java.math.BigDecimal.valueOf(in.value)
java.math.BigDecimal.valueOf(in.value);
<#else>
org.apache.drill.exec.util.DecimalUtility
<#if type.from.contains("Sparse")>
.getBigDecimalFromDrillBuf(in.buffer, in.start, in.nDecimalDigits, in.scale, true)
.getBigDecimalFromDrillBuf(in.buffer, in.start, in.nDecimalDigits, in.scale, true);
<#elseif type.from == "VarDecimal">
.getBigDecimalFromDrillBuf(in.buffer, in.start, in.end - in.start, in.scale)
.getBigDecimalFromDrillBuf(in.buffer, in.start, in.end - in.start, in.scale);
</#if>
</#if>
.setScale(scale.value, java.math.RoundingMode.HALF_UP)
.round(new java.math.MathContext(precision.value, java.math.RoundingMode.HALF_UP));

org.apache.drill.exec.util.DecimalUtility.checkValueOverflow(bd, precision.value, scale.value);

bd = bd.setScale(scale.value, java.math.RoundingMode.HALF_UP);

out.scale = scale.value;
out.precision = precision.value;
out.start = 0;
@@ -71,12 +71,11 @@ public void eval() {

out.start = 0;
java.math.BigDecimal bd =
new java.math.BigDecimal(
String.valueOf(in.value),
new java.math.MathContext(
precision.value,
java.math.RoundingMode.HALF_UP))
.setScale(scale.value, java.math.RoundingMode.HALF_UP);
new java.math.BigDecimal(String.valueOf(in.value));

org.apache.drill.exec.util.DecimalUtility.checkValueOverflow(bd, precision.value, scale.value);

bd = bd.setScale(scale.value, java.math.RoundingMode.HALF_UP);

byte[] bytes = bd.unscaledValue().toByteArray();
int len = bytes.length;
@@ -67,9 +67,11 @@ public void eval() {

out.start = 0;
out.buffer = buffer;
java.math.BigDecimal bd = new java.math.BigDecimal(in.value,
new java.math.MathContext(precision.value, java.math.RoundingMode.HALF_UP))
.setScale(out.scale, java.math.BigDecimal.ROUND_DOWN);
java.math.BigDecimal bd = new java.math.BigDecimal(in.value);

org.apache.drill.exec.util.DecimalUtility.checkValueOverflow(bd, precision.value, scale.value);

bd = bd.setScale(out.scale, java.math.BigDecimal.ROUND_DOWN);

byte[] bytes = bd.unscaledValue().toByteArray();
int len = bytes.length;
@@ -93,12 +93,12 @@ public void eval() {
byte[] buf = new byte[in.end - in.start];
in.buffer.getBytes(in.start, buf, 0, in.end - in.start);
String s = new String(buf, com.google.common.base.Charsets.UTF_8);
java.math.BigDecimal bd =
new java.math.BigDecimal(s,
new java.math.MathContext(
precision.value,
java.math.RoundingMode.HALF_UP))
.setScale(scale.value, java.math.RoundingMode.HALF_UP);
java.math.BigDecimal bd = new java.math.BigDecimal(s);

org.apache.drill.exec.util.DecimalUtility.checkValueOverflow(bd, precision.value, scale.value);

bd = bd.setScale(scale.value, java.math.RoundingMode.HALF_UP);

byte[] bytes = bd.unscaledValue().toByteArray();
int len = bytes.length;
out.buffer = buffer.reallocIfNeeded(len);
@@ -193,16 +193,18 @@ public void eval() {

</#list>

<#list ["Abs", "Ceil", "Floor", "Trunc", "Round"] as functionName>
<#list ["Abs", "Ceil", "Floor", "Trunc", "Round", "Negative"] as functionName>
<#if functionName == "Ceil">
@FunctionTemplate(names = {"ceil", "ceiling"},
<#elseif functionName == "Trunc">
@FunctionTemplate(names = {"trunc", "truncate"},
<#elseif functionName == "Negative">
@FunctionTemplate(names = {"negative", "u-", "-"},
<#else>
@FunctionTemplate(name = "${functionName?lower_case}",
</#if>
scope = FunctionTemplate.FunctionScope.SIMPLE,
<#if functionName == "Abs">
<#if functionName == "Abs" || functionName == "Negative">
returnType = FunctionTemplate.ReturnType.DECIMAL_MAX_SCALE,
<#elseif functionName == "Ceil" || functionName == "Floor"
|| functionName == "Trunc" || functionName == "Round">
@@ -226,6 +228,9 @@ public void eval() {
.getBigDecimalFromDrillBuf(in.buffer, in.start, in.end - in.start, in.scale)
<#if functionName == "Abs">
.abs();
result.scale = in.scale;
<#elseif functionName == "Negative">
.negate();
result.scale = in.scale;
<#elseif functionName == "Ceil">
.setScale(0, java.math.BigDecimal.ROUND_CEILING);
@@ -86,6 +86,7 @@

import com.google.common.base.Joiner;
import org.apache.drill.exec.store.ColumnExplorer;
import org.apache.drill.exec.util.DecimalUtility;

/**
* Class responsible for managing parsing, validation and toRel conversion for sql statements.
@@ -562,10 +563,22 @@ public RexNode makeCast(RelDataType type, RexNode exp, boolean matchNullability)
// that differs from the value from specified RelDataType, cast cannot be removed
// TODO: remove this code when CALCITE-1468 is fixed
if (type.getSqlTypeName() == SqlTypeName.DECIMAL && exp instanceof RexLiteral) {
if (type.getPrecision() < 1) {
throw UserException.validationError()
.message("Expected precision greater than 0, but was %s.", type.getPrecision())
.build(logger);
}
if (type.getScale() > type.getPrecision()) {
throw UserException.validationError()
.message("Expected scale less than or equal to precision, " +
"but was scale %s and precision %s.", type.getScale(), type.getPrecision())
.build(logger);
}
RexLiteral literal = (RexLiteral) exp;
Comparable value = literal.getValueAs(Comparable.class);
if (value instanceof BigDecimal) {
BigDecimal bigDecimal = (BigDecimal) value;
DecimalUtility.checkValueOverflow(bigDecimal, type.getPrecision(), type.getScale());
if (bigDecimal.scale() != type.getScale() || bigDecimal.precision() != type.getPrecision()) {
return makeAbstractCast(type, exp);
}
@@ -953,7 +953,7 @@ public void testNegate() throws Exception {
.sqlQuery(query)
.unOrdered()
.baselineColumns("col1")
.baselineValues(-1.1)
.baselineValues(new BigDecimal("-1.1"))
.go();
}
}
@@ -18,29 +18,35 @@
package org.apache.drill.exec.fn.impl;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.time.LocalDate;
import java.util.List;
import java.util.Map;

import org.apache.drill.categories.SqlFunctionTest;
import org.apache.drill.categories.UnlikelyTest;
import org.apache.drill.common.exceptions.UserRemoteException;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.test.BaseTestQuery;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import mockit.integration.junit4.JMockit;

import static org.hamcrest.CoreMatchers.containsString;

@RunWith(JMockit.class)
@Category({UnlikelyTest.class, SqlFunctionTest.class})
public class TestCastFunctions extends BaseTestQuery {

@Rule
public ExpectedException thrown = ExpectedException.none();

@Test
public void testVarbinaryToDate() throws Exception {
testBuilder()
@@ -380,22 +386,14 @@ public void testCastIntAndBigIntToDecimal() throws Exception {
.baselineValues(new BigDecimal(1), new BigDecimal(1), new BigDecimal(1), new BigDecimal(1))
.baselineValues(new BigDecimal(-1), new BigDecimal(-1), new BigDecimal(-1), new BigDecimal(-1))

.baselineValues(new BigDecimal(Integer.MAX_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP),
.baselineValues(new BigDecimal(Integer.MAX_VALUE),
new BigDecimal(Integer.MAX_VALUE),
new BigDecimal(Long.MAX_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP),
new BigDecimal(Long.MAX_VALUE),
new BigDecimal(Long.MAX_VALUE))

.baselineValues(new BigDecimal(Integer.MIN_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP),
.baselineValues(new BigDecimal(Integer.MIN_VALUE),
new BigDecimal(Integer.MIN_VALUE),
new BigDecimal(Long.MIN_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP),
new BigDecimal(Long.MIN_VALUE),
new BigDecimal(Long.MIN_VALUE))

.baselineValues(new BigDecimal(123456789),
@@ -421,21 +419,13 @@ public void testCastDecimalToIntAndBigInt() throws Exception {
.baselineValues(0, 0, 0L, 0L)
.baselineValues(1, 1, 1L, 1L)
.baselineValues(-1, -1, -1L, -1L)
.baselineValues(new BigDecimal(Integer.MAX_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP).intValue(),
.baselineValues(Integer.MAX_VALUE,
(int) Long.MAX_VALUE,
new BigDecimal(Integer.MAX_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP).longValue(),
(long) Integer.MAX_VALUE,
Long.MAX_VALUE)
.baselineValues(new BigDecimal(Integer.MIN_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP).intValue(),
.baselineValues(Integer.MIN_VALUE,
(int) Long.MIN_VALUE,
new BigDecimal(Integer.MIN_VALUE)
.round(new MathContext(9, RoundingMode.HALF_UP))
.setScale(0, RoundingMode.HALF_UP).longValue(),
(long) Integer.MIN_VALUE,
Long.MIN_VALUE)
.baselineValues(123456789, 123456789, 123456789L, 123456789L)
.go();
@@ -604,4 +594,74 @@ public void testCastDecimalLiteral() throws Exception {
.baselineValues(new BigDecimal("100.00"))
.go();
}

@Test
public void testCastDecimalZeroPrecision() throws Exception {
String query = "select cast('123.0' as decimal(0, 5))";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Expected precision greater than 0, but was 0"));

test(query);
}

@Test
public void testCastDecimalGreaterScaleThanPrecision() throws Exception {
String query = "select cast('123.0' as decimal(3, 5))";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Expected scale less than or equal to precision, but was scale 5 and precision 3"));

test(query);
}

@Test
public void testCastIntDecimalOverflow() throws Exception {
String query = "select cast(i1 as DECIMAL(4, 0)) as s1 from (select cast(123456 as int) as i1)";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Value 123456 overflows specified precision 4 with scale 0"));

test(query);
}

@Test
public void testCastBigIntDecimalOverflow() throws Exception {
String query = "select cast(i1 as DECIMAL(4, 0)) as s1 from (select cast(123456 as bigint) as i1)";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Value 123456 overflows specified precision 4 with scale 0"));

test(query);
}

@Test
public void testCastFloatDecimalOverflow() throws Exception {
String query = "select cast(i1 as DECIMAL(4, 0)) as s1 from (select cast(123456.123 as float) as i1)";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Value 123456.123 overflows specified precision 4 with scale 0"));

test(query);
}

@Test
public void testCastDoubleDecimalOverflow() throws Exception {
String query = "select cast(i1 as DECIMAL(4, 0)) as s1 from (select cast(123456.123 as double) as i1)";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Value 123456.123 overflows specified precision 4 with scale 0"));

test(query);
}

@Test
public void testCastVarCharDecimalOverflow() throws Exception {
String query = "select cast(i1 as DECIMAL(4, 0)) as s1 from (select cast(123456.123 as varchar) as i1)";

thrown.expect(UserRemoteException.class);
thrown.expectMessage(containsString("VALIDATION ERROR: Value 123456.123 overflows specified precision 4 with scale 0"));

test(query);
}
}

0 comments on commit c39ba74

Please sign in to comment.