
Commit

Update InsertIntoHiveTable.scala
baishuo authored and liancheng committed Sep 17, 2014
1 parent 701a814 commit a2374a8
Showing 1 changed file with 130 additions and 20 deletions.
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharOb
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}

- import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.{SerializableWritable, SparkException, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
@@ -159,6 +159,30 @@ case class InsertIntoHiveTable(
writer.commitJob()
}

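// Builds the "/partcol=value/..." directory suffix for a row's dynamic partition values,
// using the "columns" and "partition_columns" properties of the table descriptor.
// Illustrative example (not from the commit): with partition_columns "ds/hr" and a row whose
// trailing values are ("2014-09-17", "12"), the intended result is "/ds=2014-09-17/hr=12".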
def getDynamicPartDir(tableInfo: TableDesc, row: Row, dynamicPartNum2: Int): String = {
dynamicPartNum2 match {
case 0 => ""
case _ => {
val colsNum = tableInfo.getProperties.getProperty("columns").split("\\,").length
val partColStr = tableInfo.getProperties.getProperty("partition_columns")
val partCols = partColStr.split("/")
val buf = new StringBuffer()
if (partCols.length == dynamicPartNum2) {
for (j <- 0 until partCols.length) {
buf.append("/").append(partCols(j)).append("=").append(row(j + row.length - colsNum))
}
} else {
for (j <- 0 until dynamicPartNum2) {
buf.append("/").append(partCols(j + partCols.length - dynamicPartNum2)).append("=").append(row(j + colsNum))
}
}
buf.toString
}
}
}

override def execute() = result

/**
@@ -178,6 +202,12 @@ case class InsertIntoHiveTable(
val tableLocation = table.hiveQlTable.getDataLocation
val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation)
val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
var dynamicPartNum = 0
var dynamicPartPath = ""
val partitionSpec = partition.map {
case (key, Some(value)) => key -> value
case (key, None) => { dynamicPartNum += 1; key -> "" } // A None value marks a dynamic partition column.
}
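// For example (illustrative only), "INSERT ... PARTITION (ds = '2014-09-17', hr) ..." yields
// partitionSpec = Map("ds" -> "2014-09-17", "hr" -> "") and dynamicPartNum = 1.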
val rdd = childRdd.mapPartitions { iter =>
val serializer = newSerializer(fileSinkConf.getTableInfo)
val standardOI = ObjectInspectorUtils
@@ -191,7 +221,10 @@
val outputData = new Array[Any](fieldOIs.length)
iter.map { row =>
var i = 0
- while (i < row.length) {
while (i < fieldOIs.length) {
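// Any trailing row values beyond the table's data columns are the dynamic partition
// values; derive this row's partition directory from them.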
if (fieldOIs.length < row.length && row.length - fieldOIs.length == dynamicPartNum) {
dynamicPartPath = getDynamicPartDir(fileSinkConf.getTableInfo, row, dynamicPartNum)
}
// Casts Strings to HiveVarchars when necessary.
outputData(i) = wrap(row(i), fieldOIs(i))
i += 1
@@ -204,12 +237,81 @@
// ORC stores compression information in table properties. While, there are other formats
// (e.g. RCFile) that rely on hadoop configurations to store compression information.
val jobConf = new JobConf(sc.hiveconf)
- saveAsHiveFile(
- rdd,
- outputClass,
- fileSinkConf,
- jobConf,
- sc.hiveconf.getBoolean("hive.exec.compress.output", false))
val jobConfSer = new SerializableWritable(jobConf)
if (dynamicPartNum > 0) {
if (outputClass == null) {
throw new SparkException("Output value class not set")
}
jobConfSer.value.setOutputValueClass(outputClass)
if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) {
throw new SparkException("Output format class not set")
}
// Doesn't work in Scala 2.9 due to what may be a generics bug
// TODO: Should we uncomment this for Scala 2.10?
// conf.setOutputFormat(outputFormatClass)
jobConfSer.value.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName)
if (sc.hiveconf.getBoolean("hive.exec.compress.output", false)) {
// Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
// and "mapred.output.compression.type" have no impact on ORC because it uses table properties
// to store compression information.
jobConfSer.value.set("mapred.output.compress", "true")
fileSinkConf.setCompressed(true)
fileSinkConf.setCompressCodec(jobConfSer.value.get("mapred.output.compression.codec"))
fileSinkConf.setCompressType(jobConfSer.value.get("mapred.output.compression.type"))
}
jobConfSer.value.setOutputCommitter(classOf[FileOutputCommitter])

FileOutputFormat.setOutputPath(
jobConfSer.value,
SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, jobConfSer.value))

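// One writer per dynamic partition directory, keyed by the "/partcol=value/..." path suffix.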
val writerMap = new scala.collection.mutable.HashMap[String, SparkHiveHadoopWriter]
def writeToFile2(context: TaskContext, iter: Iterator[Writable]) {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
val serializer = newSerializer(fileSinkConf.getTableInfo)
var count = 0
var writer2: SparkHiveHadoopWriter = null
while (iter.hasNext) {
val record = iter.next()
val location = fileSinkConf.getDirName
val partLocation = location + dynamicPartPath
writer2 = writerMap.get(dynamicPartPath) match {
case Some(writer) => writer
case None => {
val tempWriter = new SparkHiveHadoopWriter(jobConfSer.value, new FileSinkDesc(partLocation, fileSinkConf.getTableInfo, false))
tempWriter.setup(context.stageId, context.partitionId, attemptNumber)
tempWriter.open(dynamicPartPath)
writerMap += (dynamicPartPath -> tempWriter)
tempWriter
}
}
count += 1
writer2.write(record)
}
for ((k, v) <- writerMap) {
v.close()
v.commit()
}
}

sc.sparkContext.runJob(rdd, writeToFile2 _)

for ((k, v) <- writerMap) {
v.commitJob()
}
writerMap.clear()

} else {
saveAsHiveFile(
rdd,
outputClass,
fileSinkConf,
jobConf,
sc.hiveconf.getBoolean("hive.exec.compress.output", false))
}
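// Illustrative example (hypothetical table names, not part of this commit): an insert like
//   INSERT OVERWRITE TABLE dst PARTITION (ds, hr) SELECT key, value, ds, hr FROM src
// leaves every partition value unspecified, so dynamicPartNum > 0 and the per-partition
// writers above are used; a fully specified PARTITION clause still goes through saveAsHiveFile.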

// TODO: Handle dynamic partitioning.
val outputPath = FileOutputFormat.getOutputPath(jobConf)
@@ -220,25 +322,33 @@
// holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint.
val holdDDLTime = false
if (partition.nonEmpty) {
- val partitionSpec = partition.map {
- case (key, Some(value)) => key -> value
- case (key, None) => key -> "" // Should not reach here right now.
- }
val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec)
db.validatePartitionNameCharacters(partVals)
// inheritTableSpecs is set to true. It should be set to false for a IMPORT query
// which is currently considered as a Hive native command.
val inheritTableSpecs = true
// TODO: Correctly set isSkewedStoreAsSubdir.
val isSkewedStoreAsSubdir = false
- db.loadPartition(
- outputPath,
- qualifiedTableName,
- partitionSpec,
- overwrite,
- holdDDLTime,
- inheritTableSpecs,
- isSkewedStoreAsSubdir)
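// Hive's loadDynamicPartitions discovers the per-partition directories written under
// outputPath and registers each new partition in the metastore; loadPartition handles a
// single, fully specified partition.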
if (dynamicPartNum > 0) {
db.loadDynamicPartitions(
outputPath,
qualifiedTableName,
partitionSpec,
overwrite,
dynamicPartNum /* dpCtx.getNumDPCols() */,
holdDDLTime,
isSkewedStoreAsSubdir
)
} else {
db.loadPartition(
outputPath,
qualifiedTableName,
partitionSpec,
overwrite,
holdDDLTime,
inheritTableSpecs,
isSkewedStoreAsSubdir)
}
} else {
db.loadTable(
outputPath,
