diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala
index 8ab7372..8b6e235 100644
--- a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala
+++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala
@@ -67,4 +67,23 @@ object functions {
     substring(col, 0, len)
   }
 
+  /**
+   * Like `array`, but null inputs are left out of the result.
+   *
+   * The inputs are joined with `concat_ws` (which skips nulls) and split
+   * back apart, so every element comes back as a string and a value that
+   * itself contains ",,," is broken into several elements.
+   *
+   * @param cols columns to collect
+   * @return an array column holding the non-null values of `cols`
+   */
+  def arrayExNull(cols: Column*): Column = {
+    val delimiter = ",,,"
+    // concat_ws yields "" when every input is null, and splitting ""
+    // gives Array(""), so that case is mapped to an empty array explicitly.
+    val allNull = cols.foldLeft(lit(true))((acc, c) => acc && c.isNull)
+    when(allNull, array().cast("array<string>"))
+      .otherwise(split(concat_ws(delimiter, cols: _*), delimiter))
+  }
+
 }
diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsSpec.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsSpec.scala
index 82431dc..7033c5c 100644
--- a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsSpec.scala
+++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsSpec.scala
@@ -2,7 +2,7 @@ package com.github.mrpowers.spark.daria.sql
 
 import java.sql.{Date, Timestamp}
 
-import com.github.mrpowers.spark.fast.tests.DataFrameComparer
+import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.scalatest.FunSpec
@@ -11,6 +11,7 @@ import SparkSessionExt._
 class FunctionsSpec
     extends FunSpec
     with DataFrameComparer
+    with ColumnComparer
     with SparkSessionTestWrapper {
 
   describe("#singleSpace") {
@@ -539,4 +540,46 @@ class FunctionsSpec
 
   }
 
+  describe("arrayExNull") {
+
+    it("creates an array excluding null elements") {
+
+      val sourceDF = spark.createDF(
+        List(
+          ("a", "b"),
+          (null, "b"),
+          ("a", null),
+          (null, null)
+        ), List(
+          ("c1", StringType, true),
+          ("c2", StringType, true)
+        )
+      )
+
+      val actualDF = sourceDF.withColumn(
+        "mucho_cols",
+        functions.arrayExNull(col("c1"), col("c2"))
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          ("a", "b", Array("a", "b")),
+          (null, "b", Array("b")),
+          ("a", null, Array("a")),
+          (null, null, Array[String]())
+        ), List(
+          ("c1", StringType, true),
+          ("c2", StringType, true),
+          ("mucho_cols", ArrayType(StringType, true), false)
+        )
+      )
+
+      // The comparison has to be asserted; a bare Boolean expression is
+      // discarded and the test would pass whatever arrayExNull returned.
+      assert(actualDF.collect().deep == expectedDF.collect().deep)
+
+    }
+
+  }
+
 }