36 changes: 31 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -338,6 +338,23 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
insertInto(tableName, None)
}

/**
* Inserts the content of the `DataFrame` into the specified partition of the table.
* The partition spec is given as a comma-separated list of `key='value'` pairs.
*
* {{{
* scala> Seq((3, 4)).toDF("j", "i").write.insertInto(ptTableName, "pt1='0101',pt2='0202'")
* }}}
*
* @since 3.0.0
*/
def insertInto(tableName: String, partitionInfo: String): Unit = {
insertInto(tableName, Some(partitionInfo))
}
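
// Usage sketch (illustration only, not part of this patch): assuming a table
// `pt_table` partitioned by columns (pt1, pt2), the new overload appends a
// row into one static partition:
//
//   Seq((3, 4)).toDF("j", "i").write.insertInto("pt_table", "pt1='0101',pt2='0202'")
//
// which is intended to behave like the SQL statement
//   INSERT INTO pt_table PARTITION (pt1 = '0101', pt2 = '0202') SELECT 3, 4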

private def insertInto(tableName: String, partitionInfo: Option[String]): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._

@@ -355,16 +372,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalogOpt = session.sessionState.analyzer.sessionCatalog

// Parse the partition spec, e.g. "pt1='0101',pt2='0202'", into the
// Map(key -> Some(value)) form that InsertIntoTable expects.
var partition = Map.empty[String, Option[String]]
partitionInfo.foreach { spec =>
spec.split(",").foreach { part =>
val kv = part.replaceAll("'", "").split("=")
partition += (kv(0) -> Some(kv(1)))
}
}
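// Worked example (illustration): "pt1='0101',pt2='0202'" splits on ',' into
// "pt1='0101'" and "pt2='0202'"; after quote stripping and splitting on '=',
// partition is Map("pt1" -> Some("0101"), "pt2" -> Some("0202")). Note the
// simple split assumes keys and values contain no ',' or '=' characters.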

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
insertInto(catalog, ident)

case CatalogObjectIdentifier(None, ident)
if canUseV2 && sessionCatalogOpt.isDefined && ident.namespace().length <= 1 =>
insertInto(sessionCatalogOpt.get, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
insertInto(tableIdentifier, partition)
case other =>
throw new AnalysisException(
s"Couldn't find a catalog to handle the identifier ${other.quoted}.")
@@ -376,7 +402,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

val table = catalog.asTableCatalog.loadTable(ident) match {
case _: UnresolvedTable =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption), Map.empty[String, Option[String]])
case t =>
DataSourceV2Relation.create(t)
}
@@ -406,11 +432,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}

private def insertInto(tableIdent: TableIdentifier): Unit = {
private def insertInto(tableIdent: TableIdentifier, partitionInfo: Map[String, Option[String]]): Unit = {
runCommand(df.sparkSession, "insertInto") {
InsertIntoTable(
table = UnresolvedRelation(tableIdent),
partition = Map.empty[String, Option[String]],
partition = partitionInfo,
query = df.logicalPlan,
overwrite = modeForDSV1 == SaveMode.Overwrite,
ifPartitionNotExists = false)
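// Illustration (assumed equivalence, not verified in this patch): with
// partitionInfo = Map("pt1" -> Some("0101"), "pt2" -> Some("0202")) and a
// non-Overwrite save mode, the InsertIntoTable node above corresponds to
//   INSERT INTO <table> PARTITION (pt1 = '0101', pt2 = '0202') <query>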