Skip to content

Commit

Permalink
Add schemaFor method to get StructType from Scala Type (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
manuzhang authored and MrPowers committed Apr 25, 2019
1 parent e1550e2 commit 53ff11d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.github.mrpowers.spark.daria.sql.types

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions._

import scala.reflect.runtime.universe._

object StructTypeHelpers {

def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = {
Expand All @@ -27,4 +30,19 @@ object StructTypeHelpers {
})
}

/**
* gets a StructType from a Scala type and
* transforms field names from camel case to snake case
*/
def schemaFor[T: TypeTag]: StructType = {
val struct = ScalaReflection.schemaFor[T]
.dataType.asInstanceOf[StructType]

struct.copy(fields =
struct.fields.map { field =>
field.copy(name = com.github.mrpowers.spark.daria.utils.StringHelpers.camelCaseToSnakeCase(field.name))
}
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ object StructTypeHelpersTest extends TestSuite {

}

'schemaFor - {
"gets schema from a scala Type" - {
val actualSchema = StructTypeHelpers.schemaFor[FooBar]
val expectedSchema = StructType(List(
StructField("foo", IntegerType, false),
StructField("bar", StringType),
StructField("foo_bar", ArrayType(IntegerType, false))
))

assert(actualSchema == expectedSchema)
}
}

}

case class FooBar(foo: Int, bar: String, fooBar: Array[Int])
}

0 comments on commit 53ff11d

Please sign in to comment.