Add common utility code to map short names to fully-qualified codec names
maropu committed Feb 27, 2016
1 parent f77dc4e commit 25e9250
Showing 6 changed files with 106 additions and 57 deletions.
22 changes: 13 additions & 9 deletions core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}

/**
* :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
|| codec.isInstanceOf[LZ4CompressionCodec])
}

private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
"snappy" -> classOf[SnappyCompressionCodec].getName)
/** Maps the short versions of compression codec names to fully-qualified class names. */
private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
}

private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap

def getCodecName(conf: SparkConf): String = {
conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@ private[spark] object CompressionCodec {
}

def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
val codec = try {
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,18 +88,18 @@
* If it is already a short name, just return it.
*/
def getShortName(codecName: String): String = {
if (shortCompressionCodecNames.contains(codecName)) {
if (shortCompressionCodecMap.contains(codecName)) {
codecName
} else {
shortCompressionCodecNames
shortCompressionCodecMap
.collectFirst { case (k, v) if v == codecName => k }
.getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
}
}

val FALLBACK_COMPRESSION_CODEC = "snappy"
val DEFAULT_COMPRESSION_CODEC = "lz4"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
}

/**
…
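
A minimal sketch of the new lookup path (hypothetical caller code; since CompressionCodec is private[spark], this assumes a caller inside an org.apache.spark package):

import org.apache.spark.SparkConf
import org.apache.spark.io.{CompressionCodec, LZ4CompressionCodec}

val conf = new SparkConf()
// A short name resolves through shortCompressionCodecNameMapper to a class name...
val snappyCodec = CompressionCodec.createCodec(conf, "snappy")
// ...while a name that misses in the mapper is treated as a fully-qualified class name.
val lz4Codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
// getShortName maps a qualified name back to its short form.
assert(CompressionCodec.getShortName(classOf[LZ4CompressionCodec].getName) == "lz4")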
45 changes: 45 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -60,6 +60,51 @@ private[spark] object CallSite {
val empty = CallSite("", "")
}

/** A utility class to map short compression codec names to fully-qualified ones. */
private[spark] class ShortCompressionCodecNameMapper {

def get(codecName: String): Option[String] = codecName.toLowerCase match {
case "none" => none
case "uncompressed" => uncompressed
case "bzip2" => bzip2
case "deflate" => deflate
case "gzip" => gzip
case "lzo" => lzo
case "lz4" => lz4
case "lzf" => lzf
case "snappy" => snappy
case _ => None
}

def getAsMap: Map[String, String] = {
Seq(
("none", none),
("uncompressed", uncompressed),
("bzip2", bzip2),
("deflate", deflate),
("gzip", gzip),
("lzo", lzo),
("lz4", lz4),
("lzf", lzf),
("snappy", snappy)
).flatMap { case (shortCodecName, codecName) =>
if (codecName.isDefined) Some(shortCodecName, codecName.get) else None
}.toMap
}

// To support short codec names, derived classes need to override the methods below that return
// corresponding qualified codec names.
def none: Option[String] = None
def uncompressed: Option[String] = None
def bzip2: Option[String] = None
def deflate: Option[String] = None
def gzip: Option[String] = None
def lzo: Option[String] = None
def lz4: Option[String] = None
def lzf: Option[String] = None
def snappy: Option[String] = None
}

/**
* Various utility methods used by Spark.
*/
…
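
To illustrate the contract, a hypothetical subclass overrides only the codecs it supports; every method left at its default None simply misses in get and is filtered out of getAsMap (sketch assumes a caller inside an org.apache.spark package, as the class is private[spark]):

// Hypothetical mapper that knows only two codecs.
val mapper = new ShortCompressionCodecNameMapper {
  override def gzip: Option[String] = Some("org.apache.hadoop.io.compress.GzipCodec")
  override def snappy: Option[String] = Some("org.apache.hadoop.io.compress.SnappyCodec")
}

mapper.get("GZIP")   // Some(org.apache.hadoop.io.compress.GzipCodec): lookup lower-cases first
mapper.get("lzo")    // None: not overridden, so the default None comes back
mapper.getAsMap      // Map(gzip -> ..., snappy -> ...): undefined codecs are dropped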
30 changes: 15 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -1,19 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

…
@@ -19,31 +19,37 @@ package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}
import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec}
import org.apache.spark.util.ShortCompressionCodecNameMapper

import org.apache.spark.util.Utils

private[datasources] object CompressionCodecs {
private val shortCompressionCodecNames = Map(
"bzip2" -> classOf[BZip2Codec].getName,
"gzip" -> classOf[GzipCodec].getName,
"lz4" -> classOf[Lz4Codec].getName,
"snappy" -> classOf[SnappyCodec].getName)

/** Maps the short versions of compression codec names to fully-qualified class names. */
private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
override def lz4: Option[String] = Some(classOf[Lz4Codec].getCanonicalName)
override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
}

/**
* Return the full version of the given codec class.
* If it is already a class name, just return it.
*/
def getCodecClassName(name: String): String = {
val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name)
val codecName = hadoopShortCodecNameMapper.get(name).getOrElse(name)
try {
// Validate the codec name
Utils.classForName(codecName)
codecName
} catch {
case e: ClassNotFoundException =>
throw new IllegalArgumentException(s"Codec [$codecName] " +
s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
s"is not available. Known codecs are " +
s"${hadoopShortCodecNameMapper.getAsMap.keys.mkString(", ")}.")
}
}

…
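
A short usage sketch (hypothetical caller; CompressionCodecs is private[datasources]):

// Both spellings resolve to the same canonical Hadoop class name.
CompressionCodecs.getCodecClassName("GZip")                                     // org.apache.hadoop.io.compress.GzipCodec
CompressionCodecs.getCodecClassName("org.apache.hadoop.io.compress.GzipCodec") // passed through after validation
// An unrecognized name fails fast, e.g. (key order in the message may vary):
// IllegalArgumentException: Codec [zstd] is not available. Known codecs are bzip2, deflate, gzip, lz4, snappy.
CompressionCodecs.getCodecClassName("zstd")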
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}

private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {

@@ -284,10 +284,8 @@ private[sql] class ParquetRelation(
conf.set(
ParquetOutputFormat.COMPRESSION,
ParquetRelation
.shortParquetCompressionCodecNames
.getOrElse(
sqlContext.conf.parquetCompressionCodec.toUpperCase,
CompressionCodecName.UNCOMPRESSED).name())
.parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
.getOrElse(CompressionCodecName.UNCOMPRESSED.name()))

new BucketedOutputWriterFactory {
override def newInstance(
Expand Down Expand Up @@ -903,11 +901,12 @@ private[sql] object ParquetRelation extends Logging {
}
}

// The parquet compression short names
val shortParquetCompressionCodecNames = Map(
"NONE" -> CompressionCodecName.UNCOMPRESSED,
"UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
"SNAPPY" -> CompressionCodecName.SNAPPY,
"GZIP" -> CompressionCodecName.GZIP,
"LZO" -> CompressionCodecName.LZO)
/** Maps the short versions of compression codec names to Parquet's qualified codec names. */
val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
}
}
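
A sketch of the resolution step in the writer path above, assuming the session is configured with "snappy" (ParquetRelation is private[sql], so this presumes a caller inside the sql package):

import org.apache.parquet.hadoop.metadata.CompressionCodecName

val configured = "snappy" // stand-in for sqlContext.conf.parquetCompressionCodec
val resolved = ParquetRelation.parquetShortCodecNameMapper
  .get(configured)
  .getOrElse(CompressionCodecName.UNCOMPRESSED.name())
// resolved == "SNAPPY", which is then set as ParquetOutputFormat.COMPRESSION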
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.Utils
@@ -58,18 +58,13 @@ class TextSuite extends QueryTest with SharedSQLContext {
}

test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")

val tempFile = Utils.createTempDir()
tempFile.delete()
df.write
.option("compression", "gZiP")
.text(tempFile.getCanonicalPath)
val compressedFiles = tempFile.listFiles()
assert(compressedFiles.exists(_.getName.endsWith(".gz")))
verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))

Utils.deleteRecursively(tempFile)
Seq("bzip2", "deflate", "gzip").map { codecName =>
val tempDir = Utils.createTempDir()
val tempDirPath = tempDir.getAbsolutePath()
val df = sqlContext.read.text(testFile)
df.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
verifyFrame(sqlContext.read.text(tempDirPath))
}
}

private def testFile: String = {
…
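
One check the rewritten test drops is the old ".gz" extension assertion. A hedged sketch of restoring it per codec (extension values assumed from the default Hadoop codec implementations) could go inside the loop body:

// Hypothetical per-codec suffix check, mirroring the old ".gz" assertion.
val expectedExtensions = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz")
val compressedFiles = tempDir.listFiles()
assert(compressedFiles.exists(_.getName.endsWith(expectedExtensions(codecName))))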
