Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9464][SQL] Property checks for UTF8String #7830

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions unsafe/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's used to compare the result of UTF8String.levenshteinDistance against StringUtils.getLevenshteinDistance.

<scope>test</scope>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
19 changes: 9 additions & 10 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,9 @@ public UTF8String trim() {
int s = 0;
int e = this.numBytes - 1;
// skip all of the space (0x20) in the left side
while (s < this.numBytes && getByte(s) == 0x20) s++;
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
// skip all of the space (0x20) in the right side
while (e >= 0 && getByte(e) == 0x20) e--;

while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
if (s > e) {
// empty string
return UTF8String.fromBytes(new byte[0]);
Expand All @@ -316,7 +315,7 @@ public UTF8String trim() {
public UTF8String trimLeft() {
int s = 0;
// skip all of the space (0x20) in the left side
while (s < this.numBytes && getByte(s) == 0x20) s++;
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
if (s == this.numBytes) {
// empty string
return UTF8String.fromBytes(new byte[0]);
Expand All @@ -328,7 +327,7 @@ public UTF8String trimLeft() {
public UTF8String trimRight() {
int e = numBytes - 1;
// skip all of the space (0x20) in the right side
while (e >= 0 && getByte(e) == 0x20) e--;
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;

if (e < 0) {
// empty string
Expand All @@ -354,7 +353,7 @@ public UTF8String reverse() {
}

public UTF8String repeat(int times) {
if (times <=0) {
if (times <= 0) {
return EMPTY_UTF8;
}

Expand Down Expand Up @@ -414,7 +413,7 @@ public int indexOf(UTF8String v, int start) {
*/
public UTF8String rpad(int len, UTF8String pad) {
int spaces = len - this.numChars(); // number of char need to pad
if (spaces <= 0) {
if (spaces <= 0 || pad.numChars() == 0) {
// no padding at all, return the substring of the current string
return substring(0, len);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenghao-intel and I think it is more appropriate to return the original string when pad is an empty string. Hive doesn't handle this case and simply hangs forever.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use numBytes

} else {
Expand All @@ -429,7 +428,7 @@ public UTF8String rpad(int len, UTF8String pad) {
int idx = 0;
while (idx < count) {
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
++ idx;
offset += pad.numBytes;
}
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
Expand All @@ -446,7 +445,7 @@ public UTF8String rpad(int len, UTF8String pad) {
*/
public UTF8String lpad(int len, UTF8String pad) {
int spaces = len - this.numChars(); // number of char need to pad
if (spaces <= 0) {
if (spaces <= 0 || pad.numChars() == 0) {
// no padding at all, return the substring of the current string
return substring(0, len);
} else {
Expand All @@ -461,7 +460,7 @@ public UTF8String lpad(int len, UTF8String pad) {
int idx = 0;
while (idx < count) {
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
++ idx;
offset += pad.numBytes;
}
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ public void pad() {
assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????")));
assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????")));


assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????")));
assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????")));
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
Expand All @@ -289,6 +288,18 @@ public void pad() {
assertEquals(
fromString("数据砖头孙行者孙行者孙行"),
fromString("数据砖头").rpad(12, fromString("孙行者")));

assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者")));
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8));
assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8));
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8));
assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8));

assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者")));
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8));
assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8));
assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8));
assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.unsafe.types

import org.apache.commons.lang3.StringUtils

import org.scalacheck.{Arbitrary, Gen}
import org.scalatest.prop.GeneratorDrivenPropertyChecks
// scalastyle:off
import org.scalatest.{FunSuite, Matchers}

import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}

class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChecks with Matchers {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't this extend SparkFunSuite?

Also add "Suite" to the end of the class name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in the unsafe package, which doesn't have a dependency on Spark Core, so it can't extend SparkFunSuite.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, makes sense. We should still rename the class to end with Suite to follow our convention.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be great to also add scaladoc saying this is a randomized test suite for utf8string

// scalastyle:on

test("toString") {
  // Encoding a String as a UTF8String and decoding it again must give back the input.
  forAll { (original: String) =>
    val decoded = toUTF8(original).toString()
    assert(decoded === original)
  }
}

test("numChars") {
  // Compare against the code point count rather than String.length: length counts UTF-16
  // code units, which differs from the number of characters as soon as a string contains
  // a supplementary character (surrogate pair). For strings without surrogates the two
  // values are identical, so this is a strict generalization of the previous check.
  forAll { (s: String) =>
    assert(toUTF8(s).numChars() === s.codePointCount(0, s.length))
  }
}

test("startsWith") {
  // A string starts with itself and with every one of its proper prefixes (down to "").
  forAll { (text: String) =>
    val encoded = toUTF8(text)
    assert(encoded.startsWith(encoded))
    (1 to text.length).foreach { dropped =>
      val prefix = toUTF8(text.dropRight(dropped))
      assert(encoded.startsWith(prefix))
    }
  }
}

test("endsWith") {
  // A string ends with itself and with every one of its proper suffixes (down to "").
  forAll { (text: String) =>
    val encoded = toUTF8(text)
    assert(encoded.endsWith(encoded))
    (1 to text.length).foreach { dropped =>
      val suffix = toUTF8(text.drop(dropped))
      assert(encoded.endsWith(suffix))
    }
  }
}

test("toUpperCase") {
  // java.lang.String.toUpperCase is the reference implementation.
  forAll { (text: String) =>
    val expected = toUTF8(text.toUpperCase)
    assert(toUTF8(text).toUpperCase === expected)
  }
}

test("toLowerCase") {
  // java.lang.String.toLowerCase is the reference implementation.
  forAll { (text: String) =>
    val expected = toUTF8(text.toLowerCase)
    assert(toUTF8(text).toLowerCase === expected)
  }
}

test("compare") {
  // Only the sign of compareTo is compared, not its magnitude.
  forAll { (left: String, right: String) =>
    val actualSign = Math.signum(toUTF8(left).compareTo(toUTF8(right)))
    val expectedSign = Math.signum(left.compareTo(right))
    assert(actualSign === expectedSign)
  }
}

test("substring") {
  // Checks every (from, until) pair with 0 <= from <= until <= length.
  forAll { (text: String) =>
    val encoded = toUTF8(text)
    for {
      from <- 0 to text.length
      until <- from to text.length
    } {
      assert(encoded.substring(from, until).toString === text.substring(from, until))
    }
  }
}

test("contains") {
  // Every substring of the input is looked up in the input itself.
  forAll { (text: String) =>
    val encoded = toUTF8(text)
    for {
      from <- 0 to text.length
      until <- from to text.length
    } {
      val piece = text.substring(from, until)
      assert(encoded.contains(toUTF8(piece)) === text.contains(piece))
    }
  }
}

// Generates one char between 0x00 and 0x20, the range the trim test treats as whitespace.
val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar)
// Generates a string built only from chars produced by whitespaceChar.
val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString)
// ScalaCheck's default arbitrary String generator, exposed as an explicit Gen.
val randomString: Gen[String] = Arbitrary.arbString.arbitrary

test("trim, trimLeft, trimRight") {
  // One-sided reference implementations derived from java.lang.String.trim, which
  // strips every char <= ' ' (0x20).
  def lTrim(s: String): String = s.dropWhile(_ <= ' ')
  def rTrim(s: String): String = s.substring(0, s.lastIndexWhere(_ > ' ') + 1)

  // Wrap an arbitrary string in generated whitespace on both sides.
  forAll(whitespaceString, randomString, whitespaceString) {
    (leading: String, body: String, trailing: String) =>
      val s = leading + body + trailing
      val encoded = toUTF8(s)
      assert(encoded.trim() === toUTF8(s.trim()))
      assert(encoded.trimLeft() === toUTF8(lTrim(s)))
      assert(encoded.trimRight() === toUTF8(rTrim(s)))
  }
}

test("reverse") {
  // Scala's String reverse is the reference implementation.
  forAll { (text: String) =>
    val expected = toUTF8(text.reverse)
    assert(toUTF8(text).reverse === expected)
  }
}

test("indexOf") {
  // Searches for every substring of the input, always starting at position 0.
  forAll { (text: String) =>
    val encoded = toUTF8(text)
    for {
      from <- 0 to text.length
      until <- from to text.length
    } {
      val piece = text.substring(from, until)
      assert(encoded.indexOf(toUTF8(piece), 0) === text.indexOf(piece))
    }
  }
}

// Bounded integer generator (between -100 and 100) used for repeat counts and pad lengths.
val randomInt = Gen.choose(-100, 100)

test("repeat") {
  // Reference implementation: a non-positive count yields the empty string.
  def repeat(str: String, times: Int): String = {
    if (times <= 0) "" else str * times
  }
  // Use the bounded randomInt generator: ScalaCheck's default integers are often so large
  // that repeating a string that many times could hang the test.
  forAll(randomString, randomInt) { (text: String, count: Int) =>
    val expected = toUTF8(repeat(text, count))
    assert(toUTF8(text).repeat(count) === expected)
  }
}

test("lpad, rpad") {
  // Reference implementation of left/right padding on plain Strings.
  //  - length <= 0                  => empty string
  //  - length <= origin.length      => origin truncated to `length` chars
  //  - empty pad, longer length     => origin unchanged (nothing to pad with)
  //  - otherwise                    => pad repeated (and cut) to fill the missing chars
  // Written as one expression: the previous version re-checked `length <= 0` in a branch
  // that could never see that value, and relied on early `return`s.
  def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = {
    if (length <= 0) {
      ""
    } else if (length <= origin.length) {
      origin.substring(0, length)
    } else if (pad.isEmpty) {
      origin
    } else {
      val toPad = length - origin.length
      // Whole repetitions of pad followed by the leading part of pad that still fits;
      // substring(0, 0) is "" when toPad is an exact multiple of pad.length.
      val filler = pad * (toPad / pad.length) + pad.substring(0, toPad % pad.length)
      if (isLPad) filler + origin else origin + filler
    }
  }

  forAll(randomString, randomString, randomInt) { (s: String, pad: String, length: Int) =>
    assert(toUTF8(s).lpad(length, toUTF8(pad)) ===
      toUTF8(padding(s, pad, length, true)))
    assert(toUTF8(s).rpad(length, toUTF8(pad)) ===
      toUTF8(padding(s, pad, length, false)))
  }
}

// Generates lists whose elements are drawn from either null or randomString; used by the
// concat and concatWs tests to exercise null inputs.
// NOTE(review): the name is a misspelling of "nullableSeq"; rename together with its usages.
val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString))

test("concat") {
  // Reference implementation: any null input makes the whole result null.
  def concat(parts: Seq[String]): String = {
    val hasNull = parts.exists(_ == null)
    if (hasNull) null else parts.mkString
  }

  // Default-generated sequences: expected value is the plain concatenation.
  forAll { (inputs: Seq[String]) =>
    val expected = toUTF8(inputs.mkString)
    assert(UTF8String.concat(inputs.map(toUTF8): _*) === expected)
  }
  // Sequences that may contain nulls: expected value comes from the null-aware reference.
  forAll(nullalbeSeq) { (inputs: Seq[String]) =>
    val expected = toUTF8(concat(inputs))
    assert(UTF8String.concat(inputs.map(toUTF8): _*) === expected)
  }
}

test("concatWs") {
  // Reference implementation: a null separator yields null; null inputs are skipped.
  def concatWs(sep: String, inputs: Seq[String]): String =
    if (sep == null) null else inputs.filterNot(_ == null).mkString(sep)

  // Default-generated separator and inputs: expected value is a plain mkString.
  forAll { (sep: String, inputs: Seq[String]) =>
    val expected = toUTF8(inputs.mkString(sep))
    assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === expected)
  }
  // Inputs that may contain nulls: expected value comes from the null-aware reference.
  forAll(randomString, nullalbeSeq) { (sep: String, inputs: Seq[String]) =>
    val expected = toUTF8(concatWs(sep, inputs))
    assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) === expected)
  }
}

// TODO: enable this when we find a proper way to generate valid patterns
ignore("split") {
  // java.lang.String.split is the reference implementation.
  forAll { (text: String, regex: String, maxParts: Int) =>
    val actual = toUTF8(text).split(toUTF8(regex), maxParts)
    val expected = text.split(regex, maxParts).map(toUTF8(_))
    assert(actual === expected)
  }
}

test("levenshteinDistance") {
  // Apache Commons Lang's StringUtils is the reference implementation.
  forAll { (source: String, target: String) =>
    val actual = toUTF8(source).levenshteinDistance(toUTF8(target))
    val expected = StringUtils.getLevenshteinDistance(source, target)
    assert(actual === expected)
  }
}

test("hashCode") {
  // Two instances built independently from the same String must have the same hash code.
  forAll { (text: String) =>
    val first = toUTF8(text)
    val second = toUTF8(text)
    assert(first.hashCode() === second.hashCode())
  }
}

test("equals") {
  // UTF8String equality must agree with equality of the underlying Strings.
  forAll { (left: String, right: String) =>
    val actual = toUTF8(left).equals(toUTF8(right))
    val expected = left.equals(right)
    assert(actual === expected)
  }
}
}