In [None]:
// Folder holding the "databases", "tables" and "partitions" inputs read by ImportMetadata.
// Sub-folder names are appended directly, so this value must end with "/".
var IntermediateFolderPath: String = "abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/intermediate_output/"
// Storage account whose access key is registered with Spark in the next cell.
var StorageAccountName: String = "<storage_account_name>"
// NOTE(review): consider loading this from a secret store instead of hard-coding it in the notebook.
var StorageAccountAccessKey: String = "<storage_account_access_key>"

// Prepended to every imported database name (DatabasePrefix) and table name (TablePrefix).
var DatabasePrefix: String = ""
var TablePrefix: String = ""
// IgnoreIfExists: skip objects that already exist instead of failing.
// OverrideIfExists: drop (cascade) and recreate existing databases.
var IgnoreIfExists: Boolean = false
var OverrideIfExists: Boolean = false

// Source-location prefix -> target-location prefix used to rewrite database/table/partition locations.
var LocationPrefixMappings: Map[String, String] = Map("dbfs:/user/hive/warehouse" -> "abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/catalog/hive/warehouse")

In [4]:

// Register the ABFS account-key setting for StorageAccountName so abfss:// paths on that account can be accessed.
val accountKeyConfName = s"fs.azure.account.key.${StorageAccountName}.dfs.core.windows.net"
spark.conf.set(accountKeyConfName, StorageAccountAccessKey)

In [None]:
import java.net.URI
import java.util.Calendar

import scala.collection.mutable.{ListBuffer, Map}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{ObjectType, _}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.json4s._
import org.json4s.JsonAST.JString
import org.json4s.jackson.Serialization


// Mapping pairs ordered by source prefix, descending, so a longer prefix is tried before a shorter one that it extends.
var locationPrefixMappingList = LocationPrefixMappings.toList.sortWith((left, right) => left._1 > right._1)

/**
 * Recreates metastore objects (databases, tables, partitions) from JSON-lines text files.
 *
 * Each input line is deserialized with json4s into Spark's catalog case classes. Names are
 * prefixed with `DatabasePrefix` / `TablePrefix` and storage locations are rewritten through
 * `locationPrefixMappingList` before the objects are written via
 * `spark.sharedState.externalCatalog`.
 */
object ImportMetadata {
  import scala.util.control.NonFatal

  val spark = SparkSession.builder().getOrCreate()

  // json4s serializer for java.net.URI, stored as its string form.
  case object URISerializer extends CustomSerializer[URI](_ => ( {
    case JString(uri) => new URI(uri)
  }, {
    case uri: URI => JString(uri.toString())
  }))

  // json4s serializer for org.apache.spark.sql.types.StructType, stored as Spark's schema JSON string.
  // (The object keeps its original spelling so existing references still resolve.)
  case object SturctTypeSerializer extends CustomSerializer[StructType](_ => ( {
    case JString(structType) => DataType.fromJson(structType).asInstanceOf[StructType]
  }, {
    case structType: StructType => JString(structType.json)
  }))

  // Formats used by org.json4s.jackson.Serialization.read below.
  implicit val formats: Formats = DefaultFormats + URISerializer + SturctTypeSerializer

  /** One line of the partitions input. `tablePartitons` (sic) must match the JSON field name. */
  case class CatalogPartitions(database: String, table: String, tablePartitons: Seq[CatalogTablePartition])

  /** One line of the tables input: the tables of a single source database. */
  case class CatalogTables(database: String, tables: Seq[CatalogTable])

  /**
   * Rewrites `location` using the first entry of `locationPrefixMappingList` whose source prefix
   * it starts with. Returns `location` unchanged when no prefix matches.
   *
   * The prefix is swapped literally. `String.replaceFirst` would have treated the prefix as a
   * regex and any `$` or `\` in the target as group references.
   */
  def ConvertLocation(location: String): String =
    locationPrefixMappingList
      .find { case (sourcePrefix, _) => location.startsWith(sourcePrefix) }
      .map { case (sourcePrefix, targetPrefix) => targetPrefix + location.substring(sourcePrefix.length) }
      .getOrElse(location)

  /**
   * Returns a copy of `databsae` with `DatabasePrefix` prepended to its name and its location
   * remapped. `copy` keeps every other field, whatever the Spark version defines.
   */
  def ConvertCatalogDatabase(databsae: CatalogDatabase): CatalogDatabase =
    databsae.copy(
      name = DatabasePrefix + databsae.name,
      locationUri = new URI(ConvertLocation(databsae.locationUri.toString())))

  /** Returns a copy of `format` whose optional location (if any) is remapped. */
  def ConvertCatalogStorageFormat(format: CatalogStorageFormat): CatalogStorageFormat =
    format.copy(locationUri = format.locationUri.map(uri => new URI(ConvertLocation(uri.toString()))))

  /**
   * Returns a copy of `table` with prefixed database/table names and remapped storage location.
   *
   * `copy` is used instead of the positional constructor. Fields added in newer Spark versions
   * are then carried over rather than reset to their defaults.
   *
   * NOTE(review): `viewText` and the table properties are copied unchanged, so views may still
   * reference un-prefixed database names. Confirm this if prefixes are used together with views.
   *
   * @throws IllegalArgumentException if the table identifier carries no database name
   */
  def ConvertCatalogTable(table: CatalogTable): CatalogTable = {
    val sourceDb = table.identifier.database.getOrElse(
      throw new IllegalArgumentException("Table " + table.identifier.table + " has no database in its identifier"))

    table.copy(
      identifier = table.identifier.copy(
        table = TablePrefix + table.identifier.table,
        database = Some(DatabasePrefix + sourceDb)),
      storage = ConvertCatalogStorageFormat(table.storage))
  }

  /** Returns a copy of `partition` with its storage location remapped. */
  def ConvertCatalogTablePartition(partition: CatalogTablePartition): CatalogTablePartition =
    partition.copy(storage = ConvertCatalogStorageFormat(partition.storage))

  val MaxRetryCount = 3

  /**
   * Runs `func`. On a non-fatal exception it retries immediately, up to `MaxRetryCount` more
   * times, and then lets the last failure propagate. `InterruptedException` is not retried, so a
   * cancelled cell stops promptly.
   */
  def RetriableFunc(func: () => Unit, retryCount: Int = 0): Unit = {
    try {
      func()
    } catch {
      case e: Exception if NonFatal(e) && retryCount < MaxRetryCount =>
        RetriableFunc(func, retryCount + 1)
    }
  }

  /** Same retry policy as `RetriableFunc`, for a function that produces a result. */
  def RetriableQueryFunc(func: () => Object, retryCount: Int = 0): Object = {
    try {
      func()
    } catch {
      case e: Exception if NonFatal(e) && retryCount < MaxRetryCount =>
        RetriableQueryFunc(func, retryCount + 1)
    }
  }

  /**
   * Creates every database listed (one JSON `CatalogDatabase` per line) in the text files at
   * `dataPath`.
   *
   * An existing database fails the import unless `IgnoreIfExists` or `OverrideIfExists` is set.
   *
   * @throws Exception if a target database already exists and neither flag is set
   */
  def CreateDatabases(dataPath: String): Unit = {
    println("Start to create databases " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    var createdCount = 0
    // Set lookup keeps each existence check O(1).
    val existsDbs = spark.sharedState.externalCatalog.listDatabases().toSet
    val data = ds.collect()
    val total = data.size

    data.foreach(row => {
      val newDb = ConvertCatalogDatabase(Serialization.read[CatalogDatabase](row.getString(0)))

      val exists = existsDbs.contains(newDb.name)
      if (exists && !IgnoreIfExists && !OverrideIfExists) {
        println(s"$createdCount/$total databases created. ${Calendar.getInstance().getTime()}")
        println("Database " + newDb.name + " already exists")

        throw new Exception("Database " + newDb.name + " already exists")
      } else if (!exists || OverrideIfExists) {
        CreateDatabase(newDb)
      }

      // Counts processed databases, including ones skipped because IgnoreIfExists is set.
      createdCount += 1

      if (createdCount % 100 == 0) {
        println(s"$createdCount/$total databases created${Calendar.getInstance().getTime()}")
      }
    })

    println(s"Databases Created completed. Total $createdCount database created. ${Calendar.getInstance().getTime()}")
  }

  /** Creates `db`, first dropping any existing database of that name (cascade) when `OverrideIfExists` is set. */
  def CreateDatabase(db: CatalogDatabase): Unit = {
    if (OverrideIfExists) {
      RetriableFunc(() => {
        spark.sharedState.externalCatalog.dropDatabase(db.name, true, true)
      })
    }

    RetriableFunc(() => {
      spark.sharedState.externalCatalog.createDatabase(db, IgnoreIfExists)
    })
  }

  /**
   * Creates every table listed in the text files at `dataPath`. Each line is a JSON
   * `CatalogTables` holding all tables of one source database.
   *
   * Tables of one database are created from a parallel collection. `OverrideIfExists` is not
   * consulted here: an existing table fails the import unless `IgnoreIfExists` is set.
   *
   * @throws Exception if a target table already exists and `IgnoreIfExists` is not set
   */
  def CreateTables(dataPath: String): Unit = {
    import java.util.concurrent.atomic.AtomicInteger

    println("Start to create tables " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    // Incremented from parallel-collection worker threads, so it must be atomic.
    // A captured `var` would lose updates.
    val createdCount = new AtomicInteger(0)

    ds.collect().foreach(row => {
      val tables = Serialization.read[CatalogTables](row.getString(0))

      // Set lookup keeps each existence check O(1) for databases with many tables.
      val existsTables = spark.sharedState.externalCatalog.listTables(DatabasePrefix + tables.database).toSet

      tables.tables.toParArray.foreach(table => {
        val newTable = ConvertCatalogTable(table)
        val exists = existsTables.contains(newTable.identifier.table)
        if (exists && !IgnoreIfExists) {
          val qualifiedName = newTable.identifier.database.getOrElse("") + "." + newTable.identifier.table

          println(s"${createdCount.get} tables created. ${Calendar.getInstance().getTime()}")
          println("Table " + qualifiedName + " already exists")

          throw new Exception("Table " + qualifiedName + " already exists")
        } else if (!exists) {
          CreateTable(newTable)
        }

        createdCount.incrementAndGet()
      })

      println(s"${createdCount.get} tables created${Calendar.getInstance().getTime()}")
    })

    println(s"Tables Created completed. Total ${createdCount.get} table created. ${Calendar.getInstance().getTime()}")
  }

  /** Creates `table` in the external catalog, with retries. */
  def CreateTable(table: CatalogTable): Unit = {
    RetriableFunc(() => {
      spark.sharedState.externalCatalog.createTable(table, IgnoreIfExists)
    })
  }

  /**
   * Creates the partitions listed in the text files at `dataPath`. Each line is a JSON
   * `CatalogPartitions` holding the partitions of one table. The partitions of a line are
   * created in a single `createPartitions` call.
   */
  def CreatePartitions(dataPath: String): Unit = {
    println("Start to create partitions " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    var createdCount = 0
    ds.collect().foreach(row => {
      val parts = Serialization.read[CatalogPartitions](row.getString(0))
      val catalogTablePartitions = parts.tablePartitons.map(part => ConvertCatalogTablePartition(part))

      RetriableFunc(() => {
        spark.sharedState.externalCatalog.createPartitions(DatabasePrefix + parts.database, TablePrefix + parts.table, catalogTablePartitions, IgnoreIfExists)
      })

      createdCount += catalogTablePartitions.size
      println(s"$createdCount partitions created${Calendar.getInstance().getTime()}")
    })

    println(s"Partition Created completed. Total $createdCount partition created. ${Calendar.getInstance().getTime()}")
  }

  /**
   * Imports databases, then tables, then partitions from `inputPath`. Sub-folder names are
   * appended directly, so `inputPath` is expected to end with "/".
   */
  def ImportCatalogObjectsFromFile(inputPath: String): Unit = {
    CreateDatabases(inputPath + "databases")
    CreateTables(inputPath + "tables")
    CreatePartitions(inputPath + "partitions")
  }
}

In [None]:
// Step 1: recreate databases from <IntermediateFolderPath>databases
ImportMetadata.CreateDatabases(s"${IntermediateFolderPath}databases")

In [None]:
// Step 2: recreate tables from <IntermediateFolderPath>tables
ImportMetadata.CreateTables(s"${IntermediateFolderPath}tables")

In [None]:
// Step 3: recreate partitions from <IntermediateFolderPath>partitions
ImportMetadata.CreatePartitions(s"${IntermediateFolderPath}partitions")