diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index ae014becef755..97fdc232be8ff 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}
 
 /**
  * :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
       || codec.isInstanceOf[LZ4CompressionCodec])
   }
 
-  private val shortCompressionCodecNames = Map(
-    "lz4" -> classOf[LZ4CompressionCodec].getName,
-    "lzf" -> classOf[LZFCompressionCodec].getName,
-    "snappy" -> classOf[SnappyCompressionCodec].getName)
+  /** Maps the short versions of compression codec names to fully-qualified class names. */
+  private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
+    override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
+    override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
+  }
+
+  private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap
 
   def getCodecName(conf: SparkConf): String = {
     conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@ private[spark] object CompressionCodec {
   }
 
   def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
-    val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
+    val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
     val codec = try {
       val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
       Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,10 +88,10 @@ private[spark] object CompressionCodec {
    * If it is already a short name, just return it.
    */
   def getShortName(codecName: String): String = {
-    if (shortCompressionCodecNames.contains(codecName)) {
+    if (shortCompressionCodecMap.contains(codecName)) {
       codecName
     } else {
-      shortCompressionCodecNames
+      shortCompressionCodecMap
        .collectFirst { case (k, v) if v == codecName => k }
        .getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
     }
@@ -95,7 +99,7 @@
 
   val FALLBACK_COMPRESSION_CODEC = "snappy"
   val DEFAULT_COMPRESSION_CODEC = "lz4"
-  val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
+  val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
 }
 
 /**
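The resolution path in `createCodec` now goes through the mapper instead of a raw `Map`. A minimal sketch of the intended behavior, assuming test code placed under the `org.apache.spark.io` package (the `CompressionCodec` object is `private[spark]`):

```scala
package org.apache.spark.io

import org.apache.spark.SparkConf

// Hypothetical check: a short name (any casing, since get() lowercases its
// argument) and the fully-qualified class name resolve to the same codec class.
object CodecResolutionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    val byShortName = CompressionCodec.createCodec(conf, "LZ4")
    val byClassName = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
    assert(byShortName.getClass == byClassName.getClass)
  }
}
```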
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e0c9bf02a1a20..5bc827f88b427 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -60,6 +60,51 @@ private[spark] object CallSite {
   val empty = CallSite("", "")
 }
 
+/** A utility class to map short compression codec names to qualified ones. */
+private[spark] class ShortCompressionCodecNameMapper {
+
+  def get(codecName: String): Option[String] = codecName.toLowerCase match {
+    case "none" => none
+    case "uncompressed" => uncompressed
+    case "bzip2" => bzip2
+    case "deflate" => deflate
+    case "gzip" => gzip
+    case "lzo" => lzo
+    case "lz4" => lz4
+    case "lzf" => lzf
+    case "snappy" => snappy
+    case _ => None
+  }
+
+  def getAsMap: Map[String, String] = {
+    Seq(
+      ("none", none),
+      ("uncompressed", uncompressed),
+      ("bzip2", bzip2),
+      ("deflate", deflate),
+      ("gzip", gzip),
+      ("lzo", lzo),
+      ("lz4", lz4),
+      ("lzf", lzf),
+      ("snappy", snappy)
+    ).flatMap { case (shortCodecName, codecName) =>
+      if (codecName.isDefined) Some((shortCodecName, codecName.get)) else None
+    }.toMap
+  }
+
+  // To support short codec names, derived classes need to override the methods below that return
+  // corresponding qualified codec names.
+  def none: Option[String] = None
+  def uncompressed: Option[String] = None
+  def bzip2: Option[String] = None
+  def deflate: Option[String] = None
+  def gzip: Option[String] = None
+  def lzo: Option[String] = None
+  def lz4: Option[String] = None
+  def lzf: Option[String] = None
+  def snappy: Option[String] = None
+}
+
 /**
  * Various utility methods used by Spark.
  */
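Since every default in `ShortCompressionCodecNameMapper` returns `None`, a client overrides only the codecs it actually supports, and `getAsMap` drops the rest. An illustrative subclass (not from this patch; it assumes code inside the `org.apache.spark` namespace, since the class is `private[spark]`):

```scala
import org.apache.spark.util.ShortCompressionCodecNameMapper

// A mapper that only knows gzip; the other eight defaults stay None.
val gzipOnly = new ShortCompressionCodecNameMapper {
  override def gzip: Option[String] = Some("org.apache.hadoop.io.compress.GzipCodec")
}

assert(gzipOnly.get("GzIp").contains("org.apache.hadoop.io.compress.GzipCodec")) // lookup lowercases
assert(gzipOnly.get("lzo").isEmpty)                                              // not overridden
assert(gzipOnly.getAsMap == Map("gzip" -> "org.apache.hadoop.io.compress.GzipCodec"))
```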
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index d6bdd3d825565..c5839fc3b6488 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -1,19 +1,19 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 package org.apache.spark.sql
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
index bc8ef4ad7e236..abfbd09366be8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
@@ -19,23 +19,28 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.SequenceFile.CompressionType
-import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}
+import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec}
 
+import org.apache.spark.util.ShortCompressionCodecNameMapper
 import org.apache.spark.util.Utils
 
 private[datasources] object CompressionCodecs {
-  private val shortCompressionCodecNames = Map(
-    "bzip2" -> classOf[BZip2Codec].getName,
-    "gzip" -> classOf[GzipCodec].getName,
-    "lz4" -> classOf[Lz4Codec].getName,
-    "snappy" -> classOf[SnappyCodec].getName)
+
+  /** Maps the short versions of compression codec names to fully-qualified class names. */
+  private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
+    override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
+    override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
+    override def lz4: Option[String] = Some(classOf[Lz4Codec].getCanonicalName)
+    override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
+  }
 
   /**
    * Return the full version of the given codec class.
    * If it is already a class name, just return it.
    */
   def getCodecClassName(name: String): String = {
-    val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name)
+    val codecName = hadoopShortCodecNameMapper.get(name).getOrElse(name)
     try {
       // Validate the codec name
       Utils.classForName(codecName)
@@ -43,7 +48,8 @@ private[datasources] object CompressionCodecs {
       codecName
     } catch {
       case e: ClassNotFoundException =>
         throw new IllegalArgumentException(s"Codec [$codecName] " +
-          s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
+          s"is not available. Known codecs are " +
+          s"${hadoopShortCodecNameMapper.getAsMap.keys.mkString(", ")}.")
     }
   }
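With `deflate` newly resolvable, a ScalaTest-style sketch of `getCodecClassName` (hypothetical test code living in the `datasources` package, since the object is `private[datasources]`, and inside a suite that mixes in ScalaTest's `Assertions` for `intercept`):

```scala
// Short names resolve case-insensitively to Hadoop codec classes...
assert(CompressionCodecs.getCodecClassName("DeFlAtE") ==
  "org.apache.hadoop.io.compress.DeflateCodec")
// ...full class names pass through unchanged...
assert(CompressionCodecs.getCodecClassName("org.apache.hadoop.io.compress.GzipCodec") ==
  "org.apache.hadoop.io.compress.GzipCodec")
// ...and unknown names raise an IllegalArgumentException listing the known codecs.
intercept[IllegalArgumentException] {
  CompressionCodecs.getCodecClassName("zstd")
}
```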
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 184cbb2f296b0..d0a52781bfb27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}
 
 private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
 
@@ -284,10 +284,8 @@ private[sql] class ParquetRelation(
     conf.set(
       ParquetOutputFormat.COMPRESSION,
       ParquetRelation
-        .shortParquetCompressionCodecNames
-        .getOrElse(
-          sqlContext.conf.parquetCompressionCodec.toUpperCase,
-          CompressionCodecName.UNCOMPRESSED).name())
+        .parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
+        .getOrElse(CompressionCodecName.UNCOMPRESSED.name()))
 
     new BucketedOutputWriterFactory {
       override def newInstance(
@@ -903,11 +901,12 @@ private[sql] object ParquetRelation extends Logging {
     }
   }
 
-  // The parquet compression short names
-  val shortParquetCompressionCodecNames = Map(
-    "NONE" -> CompressionCodecName.UNCOMPRESSED,
-    "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
-    "SNAPPY" -> CompressionCodecName.SNAPPY,
-    "GZIP" -> CompressionCodecName.GZIP,
-    "LZO" -> CompressionCodecName.LZO)
+  /** Maps the short versions of compression codec names to qualified compression names. */
+  val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
+    override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
+    override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
+    override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
+    override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
+  }
 }
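The `conf.set(...)` change above leans on the mapper returning `None` for unsupported names, with `UNCOMPRESSED` as the fallback. A small illustrative check, assuming access to the `private[sql]` `ParquetRelation` object (e.g. from a test in the same package):

```scala
import org.apache.parquet.hadoop.metadata.CompressionCodecName

// Defined short names map to Parquet codec names, case-insensitively...
assert(ParquetRelation.parquetShortCodecNameMapper.get("Snappy")
  .contains(CompressionCodecName.SNAPPY.name()))
// ...while names Parquet cannot write here ("lz4" is not overridden in this
// mapper) fall back to UNCOMPRESSED, mirroring the conf.set(...) call above.
assert(ParquetRelation.parquetShortCodecNameMapper.get("lz4")
  .getOrElse(CompressionCodecName.UNCOMPRESSED.name()) == "UNCOMPRESSED")
```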
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 6ae42a30fb00c..0337ead894525 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.text
 
-import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.Utils
@@ -58,18 +58,13 @@ class TextSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
-    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
-
-    val tempFile = Utils.createTempDir()
-    tempFile.delete()
-    df.write
-      .option("compression", "gZiP")
-      .text(tempFile.getCanonicalPath)
-    val compressedFiles = tempFile.listFiles()
-    assert(compressedFiles.exists(_.getName.endsWith(".gz")))
-    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
-
-    Utils.deleteRecursively(tempFile)
+    Seq("bzip2", "deflate", "gzip").foreach { codecName =>
+      val tempDir = Utils.createTempDir()
+      val tempDirPath = tempDir.getAbsolutePath()
+      val df = sqlContext.read.text(testFile)
+      df.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
+      verifyFrame(sqlContext.read.text(tempDirPath))
+    }
   }
 
   private def testFile: String = {
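One thing the rewritten test no longer asserts is the on-disk file extension (the old test checked for `.gz`). If that coverage is worth keeping, a possible per-codec check inside the loop, assuming Hadoop's default codec suffixes:

```scala
// Hypothetical extension assertion (not part of this patch): Hadoop's
// BZip2Codec, DeflateCodec and GzipCodec use these default file suffixes.
val expectedSuffix = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz")
assert(tempDir.listFiles().exists(_.getName.endsWith(expectedSuffix(codecName))))
```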