Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions Sources/SparkConnect/Catalog.swift
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,114 @@ public actor Catalog: Sendable {
return try await self.listDatabases(pattern: dbName).count > 0
}

/// Builds a `CreateTable` catalog plan and returns it wrapped in a ``DataFrame``.
///
/// - Parameters:
///   - tableName: A qualified or unqualified table name. Without a database identifier,
///     it refers to a table in the current database.
///   - path: An optional location to load the table from. When given, it is stored under
///     the `"path"` key of the table options and overrides any `"path"` entry in `options`.
///   - source: An optional data source name. It is only set on the request when non-nil.
///   - description: An optional table description. `nil` is sent as an empty string.
///   - options: Optional table options copied into the request as-is.
/// - Returns: A ``DataFrame`` backed by the create-table catalog plan.
public func createTable(
  _ tableName: String,
  _ path: String? = nil,
  source: String? = nil,
  description: String? = nil,
  options: [String: String]? = nil
) -> DataFrame {
  return getDataFrame {
    var request = Spark_Connect_CreateTable()
    request.tableName = tableName
    if let source {
      request.source = source
    }
    request.description_p = description ?? ""
    for (key, value) in options ?? [:] {
      request.options[key] = value
    }
    // Applied after the user options so that an explicit `path` argument wins.
    if let path {
      request.options["path"] = path
    }
    var message = Spark_Connect_Catalog()
    message.createTable = request
    return message
  }
}

/// Checks whether a table or view with the given name exists. The name may refer to a
/// temporary view or to a table/view.
///
/// - Parameter tableName: A qualified or unqualified name that designates a table/view.
///   It follows the same resolution rule as SQL: temp views are searched first, then
///   tables/views in the current database (namespace).
/// - Returns: `true` if it exists.
public func tableExists(_ tableName: String) async throws -> Bool {
  let existsDF = getDataFrame {
    var request = Spark_Connect_TableExists()
    request.tableName = tableName
    var message = Spark_Connect_Catalog()
    message.tableExists = request
    return message
  }
  // The first column of the first row is compared against the string "true".
  // NOTE: the force-unwrap and force-cast trap if the result is empty or not a `String`.
  let answer = try await existsDF.collect().first!.get(0) as! String
  return answer == "true"
}

/// Checks whether a table or view with the given name exists in the given database.
/// The name may refer to a temporary view or to a table/view.
///
/// - Parameters:
///   - dbName: An unqualified name that designates a database.
///   - tableName: An unqualified name that designates a table.
/// - Returns: `true` if it exists.
public func tableExists(_ dbName: String, _ tableName: String) async throws -> Bool {
  let existsDF = getDataFrame {
    var request = Spark_Connect_TableExists()
    request.tableName = tableName
    request.dbName = dbName
    var message = Spark_Connect_Catalog()
    message.tableExists = request
    return message
  }
  // The first column of the first row is compared against the string "true".
  // NOTE: the force-unwrap and force-cast trap if the result is empty or not a `String`.
  let answer = try await existsDF.collect().first!.get(0) as! String
  return answer == "true"
}

/// Checks whether a function with the given name exists. The name may refer to a
/// temporary function or to a function.
///
/// - Parameter functionName: A qualified or unqualified name that designates a function.
///   It follows the same resolution rule as SQL: built-in/temp functions are searched
///   first, then functions in the current database (namespace).
/// - Returns: `true` if it exists.
public func functionExists(_ functionName: String) async throws -> Bool {
  let existsDF = getDataFrame {
    var request = Spark_Connect_FunctionExists()
    request.functionName = functionName
    var message = Spark_Connect_Catalog()
    message.functionExists = request
    return message
  }
  // The first column of the first row is compared against the string "true".
  // NOTE: the force-unwrap and force-cast trap if the result is empty or not a `String`.
  let answer = try await existsDF.collect().first!.get(0) as! String
  return answer == "true"
}

/// Checks whether a function with the given name exists in the given database under
/// the Hive Metastore.
///
/// - Parameters:
///   - dbName: An unqualified name that designates a database.
///   - functionName: An unqualified name that designates a function.
/// - Returns: `true` if it exists.
public func functionExists(_ dbName: String, _ functionName: String) async throws -> Bool {
  let existsDF = getDataFrame {
    var request = Spark_Connect_FunctionExists()
    request.functionName = functionName
    request.dbName = dbName
    var message = Spark_Connect_Catalog()
    message.functionExists = request
    return message
  }
  // The first column of the first row is compared against the string "true".
  // NOTE: the force-unwrap and force-cast trap if the result is empty or not a `String`.
  let answer = try await existsDF.collect().first!.get(0) as! String
  return answer == "true"
}

/// Caches the specified table in-memory.
/// - Parameters:
/// - tableName: A qualified or unqualified name that designates a table/view.
Expand Down
44 changes: 44 additions & 0 deletions Tests/SparkConnectTests/CatalogTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,50 @@ struct CatalogTests {
#expect(try await spark.catalog.databaseExists(dbName) == false)
await spark.stop()
}

// Creates a table over a freshly written single-row ORC dataset and verifies that it is
// readable and visible through `tableExists`.
@Test
func createTable() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // A random suffix keeps the table name unique across runs.
  let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
  let location = "/tmp/\(tableName)"
  // `withTable` returns a function, so the closure is passed with an explicit call.
  try await SQLHelper.withTable(spark, tableName)({
    try await spark.range(1).write.orc(location)
    #expect(try await spark.catalog.createTable(tableName, location, source: "orc").count() == 1)
    #expect(try await spark.catalog.tableExists(tableName))
  })
  await spark.stop()
}

// Exercises both `tableExists` overloads before creation, after creation, and after the
// `withTable` scope ends, and checks that a malformed name raises an error.
@Test
func tableExists() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // A random suffix keeps the table name unique across runs.
  let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
  let location = "/tmp/\(tableName)"
  // `withTable` returns a function, so the closure is passed with an explicit call.
  try await SQLHelper.withTable(spark, tableName)({
    try await spark.range(1).write.parquet(location)

    // Not registered yet.
    #expect(try await spark.catalog.tableExists(tableName) == false)

    // Registered without an explicit source.
    #expect(try await spark.catalog.createTable(tableName, location).count() == 1)
    #expect(try await spark.catalog.tableExists(tableName))
    #expect(try await spark.catalog.tableExists("default", tableName))
    #expect(try await spark.catalog.tableExists("default2", tableName) == false)
  })

  // The table must no longer exist once the `withTable` scope has finished.
  #expect(try await spark.catalog.tableExists(tableName) == false)

  // A name containing spaces is expected to be rejected with an error.
  try await #require(throws: Error.self) {
    try await spark.catalog.tableExists("invalid table name")
  }
  await spark.stop()
}

// Verifies `functionExists` for a known function, an unknown function, and a malformed
// function name.
@Test
func functionExists() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let knownFunction = "base64"
  let unknownFunction = "non_exist_function"
  let malformedName = "invalid function name"

  #expect(try await spark.catalog.functionExists(knownFunction))
  #expect(try await spark.catalog.functionExists(unknownFunction) == false)

  // A name containing spaces is expected to be rejected with an error.
  try await #require(throws: Error.self) {
    try await spark.catalog.functionExists(malformedName)
  }
  await spark.stop()
}
#endif

@Test
Expand Down
Loading