diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecoders.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecoders.scala
index 6feb545c1..c9eda476a 100644
--- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecoders.scala
+++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecoders.scala
@@ -43,7 +43,7 @@ object StringDecoders {
     * @param improvedNullDetection if true, return null if all bytes are zero
     * @return A string representation of the binary data
     */
-  def decodeEbcdicString(bytes: Array[Byte], trimmingType: Int, conversionTable: Array[Char], improvedNullDetection: Boolean): String = {
+  final def decodeEbcdicString(bytes: Array[Byte], trimmingType: Int, conversionTable: Array[Char], improvedNullDetection: Boolean): String = {
     if (improvedNullDetection && isArrayNull(bytes))
       return null
 
@@ -73,7 +73,7 @@ object StringDecoders {
     * @param improvedNullDetection if true, return null if all bytes are zero
     * @return A string representation of the binary data
     */
-  def decodeAsciiString(bytes: Array[Byte], trimmingType: Int, improvedNullDetection: Boolean): String = {
+  final def decodeAsciiString(bytes: Array[Byte], trimmingType: Int, improvedNullDetection: Boolean): String = {
     if (improvedNullDetection && isArrayNull(bytes))
       return null
 
@@ -105,7 +105,7 @@ object StringDecoders {
     * @param improvedNullDetection if true, return null if all bytes are zero
     * @return A string representation of the binary data
     */
-  def decodeUtf16String(bytes: Array[Byte], trimmingType: Int, isUtf16BigEndian: Boolean, improvedNullDetection: Boolean): String = {
+  final def decodeUtf16String(bytes: Array[Byte], trimmingType: Int, isUtf16BigEndian: Boolean, improvedNullDetection: Boolean): String = {
     if (improvedNullDetection && isArrayNull(bytes))
       return null
 
@@ -132,7 +132,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A HEX string representation of the binary data
     */
-  def decodeHex(bytes: Array[Byte]): String = {
+  final def decodeHex(bytes: Array[Byte]): String = {
     val hexChars = new Array[Char](bytes.length * 2)
     var i = 0
     while (i < bytes.length) {
@@ -150,7 +150,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A string representation of the bytes
     */
-  def decodeRaw(bytes: Array[Byte]): Array[Byte] = bytes
+  final def decodeRaw(bytes: Array[Byte]): Array[Byte] = bytes
 
   /**
     * A decoder for any EBCDIC uncompressed numbers supporting
@@ -165,7 +165,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @param isUnsigned true if the number is unsigned
     * @param improvedNullDetection if true, return null if all bytes are zero
     * @return A string representation of the binary data
     */
-  def decodeEbcdicNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): String = {
+  final def decodeEbcdicNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): String = {
     if (improvedNullDetection && isArrayNull(bytes))
       return null
 
@@ -236,7 +236,8 @@ object StringDecoders {
     * @param improvedNullDetection if true, return null if all bytes are zero
     * @return A string representation of the binary data
     */
-  def decodeAsciiNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): String = {
+  final def decodeAsciiNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): String = {
+    val allowedDigitChars = " 0123456789"
     if (improvedNullDetection && isArrayNull(bytes))
       return null
 
@@ -251,7 +252,10 @@ object StringDecoders {
         if (char == '.' || char == ',') {
           buf.append('.')
         } else {
-          buf.append(char)
+          if (allowedDigitChars.contains(char))
+            buf.append(char)
+          else
+            return null
         }
       }
       i = i + 1
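The two decodeAsciiNumber hunks above are the behavioral core of this patch. Previously, any unrecognized character was appended to the output buffer, so non-numeric input survived until a later typed conversion. With the whitelist, a character outside " 0123456789" (after the sign and decimal-separator cases handled earlier in the loop) aborts decoding and yields null; note the whitelist also contains the blank character, so space-padded numeric fields keep decoding. A minimal sketch of the new contract, with inputs and results taken from the updated StringDecodersSpec further down:

    import za.co.absa.cobrix.cobol.parser.decoders.StringDecoders.decodeAsciiNumber

    // Sign and decimal separators are still normalized as before.
    decodeAsciiNumber("100,00-".getBytes, isUnsigned = false, improvedNullDetection = false)   // "-100.00"
    // Non-digit characters now invalidate the whole field instead of passing through.
    decodeAsciiNumber("AAABBBCCC".getBytes, isUnsigned = false, improvedNullDetection = false) // null (was "AAABBBCCC")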
@@ -269,7 +273,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A boxed integer
     */
-  def decodeEbcdicInt(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): Integer = {
+  final def decodeEbcdicInt(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): Integer = {
     try {
       decodeEbcdicNumber(bytes, isUnsigned, improvedNullDetection).toInt
     } catch {
@@ -283,7 +287,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A boxed integer
     */
-  def decodeAsciiInt(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): Integer = {
+  final def decodeAsciiInt(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): Integer = {
     try {
       decodeAsciiNumber(bytes, isUnsigned, improvedNullDetection).toInt
     } catch {
@@ -297,7 +301,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A boxed long
     */
-  def decodeEbcdicLong(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): java.lang.Long = {
+  final def decodeEbcdicLong(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): java.lang.Long = {
     try {
       decodeEbcdicNumber(bytes, isUnsigned, improvedNullDetection).toLong
     } catch {
@@ -311,7 +315,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A boxed long
     */
-  def decodeAsciiLong(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): java.lang.Long = {
+  final def decodeAsciiLong(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): java.lang.Long = {
     try {
       decodeAsciiNumber(bytes, isUnsigned, improvedNullDetection).toLong
     } catch {
@@ -327,7 +331,7 @@ object StringDecoders {
     * @param scaleFactor Additional zeros to be added before or after the decimal point
     * @return A big decimal containing a big integral number
     */
-  def decodeEbcdicBigNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean, scale: Int = 0, scaleFactor: Int = 0): BigDecimal = {
+  final def decodeEbcdicBigNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean, scale: Int = 0, scaleFactor: Int = 0): BigDecimal = {
     try {
       BigDecimal(BinaryUtils.addDecimalPoint(decodeEbcdicNumber(bytes, isUnsigned, improvedNullDetection), scale, scaleFactor))
     } catch {
@@ -343,7 +347,7 @@ object StringDecoders {
     * @param scaleFactor Additional zeros to be added before or after the decimal point
     * @return A big decimal containing a big integral number
     */
-  def decodeAsciiBigNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean, scale: Int = 0, scaleFactor: Int = 0): BigDecimal = {
+  final def decodeAsciiBigNumber(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean, scale: Int = 0, scaleFactor: Int = 0): BigDecimal = {
     try {
       BigDecimal(BinaryUtils.addDecimalPoint(decodeAsciiNumber(bytes, isUnsigned, improvedNullDetection), scale, scaleFactor))
     } catch {
@@ -358,7 +362,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A big decimal containing a big integral number
     */
-  def decodeEbcdicBigDecimal(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): BigDecimal = {
+  final def decodeEbcdicBigDecimal(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): BigDecimal = {
     try {
       BigDecimal(decodeEbcdicNumber(bytes, isUnsigned, improvedNullDetection))
     } catch {
@@ -373,7 +377,7 @@ object StringDecoders {
     * @param bytes A byte array that represents the binary data
     * @return A big decimal containing a big integral number
     */
-  def decodeAsciiBigDecimal(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): BigDecimal = {
+  final def decodeAsciiBigDecimal(bytes: Array[Byte], isUnsigned: Boolean, improvedNullDetection: Boolean): BigDecimal = {
     try {
       BigDecimal(decodeAsciiNumber(bytes, isUnsigned, improvedNullDetection))
     } catch {
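The remaining hunks in this file only add `final` to the typed wrappers, which is semantically neutral for members of a Scala `object` and is presumably meant as a JVM-inlining hint for these hot decoding paths. Because every ASCII wrapper delegates to decodeAsciiNumber inside a try/catch that maps parse failures to null, the stricter character check propagates to all numeric types uniformly. An illustrative sketch (the "12Z4" input is hypothetical; the scientific-notation case mirrors the updated spec below):

    import za.co.absa.cobrix.cobol.parser.decoders.StringDecoders._

    decodeAsciiInt("12Z4".getBytes, isUnsigned = false, improvedNullDetection = false)           // null: 'Z' is rejected
    decodeAsciiBigDecimal("200E+10".getBytes, isUnsigned = false, improvedNullDetection = false) // null: 'E' is rejected, so
                                                                                                 // "200E+10" no longer parses as 2.00E+12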
diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringTools.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringTools.scala
index 761d96fca..ac6a3f641 100644
--- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringTools.scala
+++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/decoders/StringTools.scala
@@ -25,7 +25,7 @@ object StringTools {
     * @param s A string
     * @return The trimmed string
     */
-  def trimLeft(s: String): String = {
+  final def trimLeft(s: String): String = {
     val len = s.length
     var st = 0
     val v = s.toCharArray
@@ -46,7 +46,7 @@ object StringTools {
     * @param s A string
     * @return The trimmed string
     */
-  def trimRight(s: String): String = {
+  final def trimRight(s: String): String = {
     var len = s.length
     val st = 0
     val v = s.toCharArray
@@ -60,7 +60,7 @@ object StringTools {
     else s
   }
 
-  def isArrayNull(bytes: Array[Byte]): Boolean = {
+  final def isArrayNull(bytes: Array[Byte]): Boolean = {
     var i = 0
     val size = bytes.length
     while (i < size) {
diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecodersSpec.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecodersSpec.scala
index 9c11f753b..1643c482d 100644
--- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecodersSpec.scala
+++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/decoders/StringDecodersSpec.scala
@@ -273,8 +273,8 @@ class StringDecodersSpec extends WordSpec {
       assert(decodeAsciiNumber("100,00-".getBytes, isUnsigned = false, improvedNullDetection = false) == "-100.00")
     }
 
-    "return trimmed string if non-digit characters are encountered" in {
-      assert(decodeAsciiNumber("AAABBBCCC".getBytes, isUnsigned = false, improvedNullDetection = false) == "AAABBBCCC")
+    "return null if non-digit characters are encountered" in {
+      assert(decodeAsciiNumber("AAABBBCCC".getBytes, isUnsigned = false, improvedNullDetection = false) == null)
     }
   }
 
@@ -458,8 +458,8 @@ class StringDecodersSpec extends WordSpec {
       assert(decodeAsciiBigDecimal("12345678901234567890123456.12345678901234567890123456".getBytes, isUnsigned = true, improvedNullDetection = false) == BigDecimal("12345678901234567890123456.12345678901234567890123456"))
     }
 
-    "decode numbers in scientific format" in {
-      assert(decodeAsciiBigDecimal("200E+10".getBytes, isUnsigned = false, improvedNullDetection = false) == 2.00E+12)
+    "not decode numbers in scientific format" in {
+      assert(decodeAsciiBigDecimal("200E+10".getBytes, isUnsigned = false, improvedNullDetection = false) == null)
     }
 
     "return null for malformed numbers" in {
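The new regression suite below exercises the same behavior end-to-end through the Spark data source, feeding it text files whose numeric fields either fit, overflow, or are malformed. Condensed from the suite that follows, the read pattern it relies on is (tmpFileName is supplied by the withTempTextFile fixture):

    val df = spark
      .read
      .format("cobol")
      .option("copybook_contents", copybook)              // copybook supplied inline as a string
      .option("pedantic", "true")                         // fail on unrecognized options
      .option("is_text", "true")                          // line-delimited ASCII text records
      .option("encoding", "ascii")
      .option("schema_retention_policy", "collapse_root")
      .load(tmpFileName)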
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/regression/Test15AsciiNumberOverflow.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/regression/Test15AsciiNumberOverflow.scala
new file mode 100644
index 000000000..1d427ca3b
--- /dev/null
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/regression/Test15AsciiNumberOverflow.scala
@@ -0,0 +1,237 @@
+/*
+ * Copyright 2018 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.cobrix.spark.cobol.source.regression
+
+import org.scalatest.WordSpec
+import org.slf4j.{Logger, LoggerFactory}
+import za.co.absa.cobrix.spark.cobol.source.base.{SimpleComparisonBase, SparkTestBase}
+import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture
+import za.co.absa.cobrix.spark.cobol.utils.SparkUtils
+
+import java.nio.charset.StandardCharsets
+
+/**
+  * This suite checks if Spark is able to read numbers from ASCII files that overflow the expected data type.
+  */
+class Test15AsciiNumberOverflow extends WordSpec with SparkTestBase with BinaryFileFixture with SimpleComparisonBase {
+  private implicit val logger: Logger = LoggerFactory.getLogger(this.getClass)
+
+  "Test ASCII CRLF text file with various numeric fields" should {
+    "decode integral types" when {
+      val copybook1 =
+        """ 01 R.
+             05 N1 PIC 9(1).
+             05 N2 PIC 9(2).
+             05 N3 PIC +9(2).
+             05 N4 PIC +9(18).
+             05 N5 PIC +9(20).
+        """
+
+      val textFileContents: String = "122+33+123456789012345678+12345678901234567890\n3445551234567890123456789123456789012345678901\n"
+
+      withTempTextFile("num_overflow", ".dat", StandardCharsets.UTF_8, textFileContents) { tmpFileName =>
+        val df = spark
+          .read
+          .format("cobol")
+          .option("copybook_contents", copybook1)
+          .option("pedantic", "true")
+          .option("is_text", "true")
+          .option("encoding", "ascii")
+          .option("schema_retention_policy", "collapse_root")
+          .load(tmpFileName)
+
+        val actualSchema = df.schema.treeString
+        val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]"))
+
+        "schema should match" in {
+          val expectedSchema =
+            """root
+              | |-- N1: integer (nullable = true)
+              | |-- N2: integer (nullable = true)
+              | |-- N3: integer (nullable = true)
+              | |-- N4: long (nullable = true)
+              | |-- N5: decimal(20,0) (nullable = true)""".stripMargin
+
+          assertEqualsMultiline(actualSchema, expectedSchema)
+        }
+
+        "data should match" in {
+          val expectedData =
+            """[ {
+              |  "N1" : 1,
+              |  "N2" : 22,
+              |  "N3" : 33,
+              |  "N4" : 123456789012345678,
+              |  "N5" : 12345678901234567890
+              |}, {
+              |  "N1" : 3,
+              |  "N2" : 44,
+              |  "N3" : 555,
+              |  "N4" : 1234567890123456789
+              |} ]""".stripMargin
+
+          assertEqualsMultiline(actualData, expectedData)
+        }
+      }
+    }
+ """ + val textFileContents: String = "112.2+10.123\n334.4+55.666\n778.8999.999\n889.9-110222\n991.122" + + withTempTextFile("num_overflow", ".dat", StandardCharsets.UTF_8, textFileContents) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook2) + .option("pedantic", "true") + .option("is_text", "true") + .option("encoding", "ascii") + .option("schema_retention_policy", "collapse_root") + .load(tmpFileName) + + val expected = + """[ { + | "D1" : 1.1, + | "D2" : 2.2, + | "D3" : 10.123 + |}, { + | "D1" : 3.3, + | "D2" : 4.4, + | "D3" : 55.666 + |}, { + | "D1" : 7.7, + | "D2" : 8.8 + |}, { + | "D1" : 8.8, + | "D2" : 9.9 + |}, { + | "D1" : 9.9, + | "D2" : 1.1 + |} ]""".stripMargin + + val actual = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + assertEqualsMultiline(actual, expected) + } + } + } + + // Ignore exhaustive overflow tests since they take too much time. These tests were used to catch overflow exceptions + "Integral variants" ignore { + for (len <- Range(2, 40)) { + s"parse integers with length $len" in { + val fieldPic = "9" * len + + val copybook = + s""" 01 R. + 05 F PIC +$fieldPic. + """ + val n1 = getNumber(len-1, 0, sign = true) + val n2 = getNumber(len-1, 0, sign = false) + val n3 = getNumber(len, 0, sign = true) + val n4 = getNumber(len, 0, sign = false) + val n5 = getNumber(len+1, 0, sign = true) + val n6 = getNumber(len+1, 0, sign = false) + val n7 = getNumber(len+2, 0, sign = true) + val n8 = getNumber(len+2, 0, sign = false) + val n9 = getNumber(len+3, 0, sign = true) + val n10 = getNumber(len+3, 0, sign = false) + + val asciiFile = s"$n1\n$n2\n$n3\n$n4\n$n5\n$n6\n$n7\n$n8\n$n9\n$n10" + + //println(asciiFile) + withTempTextFile("num_overflow", ".dat", StandardCharsets.UTF_8, asciiFile) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("pedantic", "true") + .option("is_text", "true") + .option("encoding", "ascii") + .option("schema_retention_policy", "collapse_root") + .load(tmpFileName) + + df.count + } + } + } + } + + // Ignore exhaustive overflow tests since they take too much time. These tests were used to catch overflow exceptions + "Decimal variants" ignore { + for (len <- Range(2, 24)) { + s"parse decimal with length $len" when { + for (dec <- Range(1, len)) { + s"decimal point is placed at $dec" in { + val fieldPic = "9" * dec + "." + "9" * (len - dec) + + val copybook = + s""" 01 R. + 05 F PIC +$fieldPic. 
+ """ + val n1 = getNumber(len, dec, sign = true) + val n2 = getNumber(len+1, dec, sign = false) + val n3 = getNumber(len+2, 0, sign = false) + val n4 = getNumber(len+1, 0, sign = false) + val n5 = getNumber(len+1, 0, sign = true) + val n6 = getNumber(len+2, dec, sign = true) + val n7 = getNumber(len+2, dec, sign = false) + val n8 = getNumber(len+3, dec, sign = true) + val n9 = getNumber(len+3, dec, sign = false) + + val asciiFile = s"$n1\n$n2\n$n3\n$n4\n$n5\n$n6\n$n7\n$n8\n$n9" + + //println(asciiFile) + withTempTextFile("num_overflow", ".dat", StandardCharsets.UTF_8, asciiFile) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("pedantic", "true") + .option("is_text", "true") + .option("encoding", "ascii") + .option("schema_retention_policy", "collapse_root") + .load(tmpFileName) + + df.count + } + + } + + } + } + } + } + + private def getNumber(len: Int, dec: Int, sign: Boolean): String = { + val model = "123456789012345678901234567890" + val s = if (sign) "+" else "" + val num = if (dec <= 0) { + model.take(len) + } else { + model.take(dec) + "." + model.take(len-dec) + } + + s + num + } +}