3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -48,7 +48,8 @@ Suggests:
odbc,
duckdb,
pool,
ParallelLogger
ParallelLogger,
AzureStor
License: Apache License
VignetteBuilder: knitr
URL: https://ohdsi.github.io/DatabaseConnector/, https://github.com/OHDSI/DatabaseConnector
1 change: 1 addition & 0 deletions DatabaseConnector.Rproj
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 9d51e576-41a3-432f-b696-8bfdc3eed676

RestoreWorkspace: No
SaveWorkspace: No
81 changes: 81 additions & 0 deletions R/BulkLoad.R
@@ -62,6 +62,25 @@ checkBulkLoadCredentials <- function(connection) {
return(FALSE)
}
return(TRUE)
} else if (dbms(connection) == "spark") {
envSet <- FALSE
container <- FALSE

if (Sys.getenv("AZR_STORAGE_ACCOUNT") != "" && Sys.getenv("AZR_ACCOUNT_KEY") != "" && Sys.setenv("AZR_CONTAINER_NAME") != "") {
envSet <- TRUE
}

# List storage containers to confirm the container
# specified in the configuration exists
ensure_installed("AzureStor")
azureEndpoint <- getAzureEndpoint()
containerList <- getAzureContainerNames(azureEndpoint)

if (Sys.getenv("AZR_CONTAINER_NAME") %in% containerList) {
container <- TRUE
}

return(envSet && container)
} else {
return(FALSE)
}
@@ -72,6 +91,18 @@ getHiveSshUser <- function() {
return(if (sshUser == "") "root" else sshUser)
}

getAzureEndpoint <- function() {
azureEndpoint <- AzureStor::storage_endpoint(
paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
key = Sys.getenv("AZR_ACCOUNT_KEY")
)
return(azureEndpoint)
}

getAzureContainerNames <- function(azureEndpoint) {
return(names(AzureStor::list_storage_containers(azureEndpoint)))
}
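# Usage sketch (hypothetical values): with AZR_STORAGE_ACCOUNT and
# AZR_ACCOUNT_KEY set, the endpoint resolves to the account's ADLS Gen2 URL,
# and the names returned here are what checkBulkLoadCredentials() matches
# AZR_CONTAINER_NAME against:
#   azureEndpoint <- getAzureEndpoint()
#   getAzureContainerNames(azureEndpoint)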

countRows <- function(connection, sqlTableName) {
sql <- "SELECT COUNT(*) FROM @table"
count <- renderTranslateQuerySql(
@@ -354,3 +385,53 @@ bulkLoadPostgres <- function(connection, sqlTableName, sqlFieldNames, sqlDataTyp
delta <- Sys.time() - startTime
inform(paste("Bulk load to PostgreSQL took", signif(delta, 3), attr(delta, "units")))
}

bulkLoadSpark <- function(connection, sqlTableName, data) {
ensure_installed("AzureStor")
logTrace(sprintf("Inserting %d rows into table '%s' using Databricks bulk load", nrow(data), sqlTableName))
start <- Sys.time()

csvFileName <- tempfile("spark_insert_", fileext = ".csv")
write.csv(x = data, na = "", file = csvFileName, row.names = FALSE, quote = TRUE)
on.exit(unlink(csvFileName))

azureEndpoint <- getAzureEndpoint()
containers <- AzureStor::list_storage_containers(azureEndpoint)
targetContainer <- containers[[Sys.getenv("AZR_CONTAINER_NAME")]]
# Upload the CSV to the container root under its base name, so the COPY INTO
# statement (which references only the file name) can find it:
AzureStor::storage_upload(
targetContainer,
src = csvFileName,
dest = basename(csvFileName)
)

# Remove the uploaded copy on exit, in addition to the local temp file:
on.exit(
AzureStor::delete_storage_file(
targetContainer,
file = basename(csvFileName),
confirm = FALSE
),
add = TRUE
)

sql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "sparkCopy.sql",
packageName = "DatabaseConnector",
dbms = "spark",
sqlTableName = sqlTableName,
fileName = basename(csvFileName),
azureAccountKey = Sys.getenv("AZR_ACCOUNT_KEY"),
azureStorageAccount = Sys.getenv("AZR_STORAGE_ACCOUNT"),
azureContainerName = Sys.getenv("AZR_CONTAINER_NAME")
)

tryCatch(
{
DatabaseConnector::executeSql(connection = connection, sql = sql, reportOverallTime = FALSE)
},
error = function(e) {
abort(paste0("Error in Databricks bulk upload. Please check Databricks/Azure Storage access. Original error: ", conditionMessage(e)))
}
)
delta <- Sys.time() - start
inform(paste("Bulk load to DataBricks took", signif(delta, 3), attr(delta, "units")))
}
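A minimal usage sketch (placeholder connection, table name, and data frame): bulkLoadSpark() is internal and is reached through insertTable() with bulkLoad = TRUE, as wired into insertTable.default() in the next file:

insertTable(
  connection = connection,
  tableName = "scratch.insert_test",
  data = someData,
  bulkLoad = TRUE
)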

9 changes: 9 additions & 0 deletions R/InsertTable.R
@@ -121,6 +121,13 @@ validateInt64Insert <- function() {
#' "some_aws_region", "AWS_BUCKET_NAME" = "some_bucket_name", "AWS_OBJECT_KEY" = "some_object_key",
#' "AWS_SSE_TYPE" = "server_side_encryption_type").
#'
#' Spark (Databricks): The MPP bulk loading relies upon the AzureStor package
#' to upload data to, and verify access to, an Azure ADLS Gen2 storage
#' container using Azure credentials. Credentials are set as environment
#' variables under the keys "AZR_STORAGE_ACCOUNT", "AZR_ACCOUNT_KEY", and
#' "AZR_CONTAINER_NAME", as in the sketch below (placeholder values):
#'
#' PDW: The MPP bulk loading relies upon the client
#' having a Windows OS and the DWLoader exe installed, and the following permissions granted: --Grant
#' BULK Load permissions - needed at a server level USE master; GRANT ADMINISTER BULK OPERATIONS TO
@@ -308,6 +315,8 @@ insertTable.default <- function(connection,
bulkLoadHive(connection, sqlTableName, sqlFieldNames, data)
} else if (dbms == "postgresql") {
bulkLoadPostgres(connection, sqlTableName, sqlFieldNames, sqlDataTypes, data)
} else if (dbms == "spark") {
bulkLoadSpark(connection, sqlTableName, data)
}
} else if (useCtasHack) {
# Inserting using CTAS hack ----------------------------------------------------------------
34 changes: 34 additions & 0 deletions extras/TestBulkLoad.R
@@ -114,3 +114,37 @@ all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch_mschuemi.insert_test;")
disconnect(connection)


# Spark ------------------------------------------------------------------------------
# Assumes the Spark (Databricks) environment variables have been set
options(sqlRenderTempEmulationSchema = Sys.getenv("DATABRICKS_SCRATCH_SCHEMA"))
databricksConnectionString <- paste0("jdbc:databricks://", Sys.getenv("DATABRICKS_HOST"), "/default;transportMode=http;ssl=1;AuthMech=3;httpPath=", Sys.getenv("DATABRICKS_HTTP_PATH"))
connectionDetails <- createConnectionDetails(dbms = "spark",
connectionString = databricksConnectionString,
user = "token",
password = Sys.getenv("DATABRICKS_TOKEN"))


connection <- connect(connectionDetails)
system.time(
insertTable(connection = connection,
tableName = "scratch.scratch_asena5.insert_test",
data = data,
dropTableIfExists = TRUE,
createTable = TRUE,
tempTable = FALSE,
progressBar = TRUE,
camelCaseToSnakeCase = TRUE,
bulkLoad = TRUE)
)
data2 <- querySql(connection, "SELECT * FROM scratch.scratch_asena5.insert_test;", snakeCaseToCamelCase = TRUE, integer64AsNumeric = FALSE)

data <- data[order(data$id), ]
data2 <- data2[order(data2$id), ]
row.names(data) <- NULL
row.names(data2) <- NULL
all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch.scratch_asena5.insert_test;")
disconnect(connection)
10 changes: 10 additions & 0 deletions inst/sql/sql_server/sparkCopy.sql
@@ -0,0 +1,10 @@
COPY INTO @sqlTableName
FROM 'abfss://@azureContainerName@@azureStorageAccount.dfs.core.windows.net/@fileName'
WITH (
CREDENTIAL (AZURE_SAS_TOKEN = '@azureAccountKey')
)
FILEFORMAT = CSV
FORMAT_OPTIONS (
'header' = 'true',
'inferSchema' = 'true'
);
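
For reference, a sketch of the rendered statement after SqlRender substitutes the parameters passed in from bulkLoadSpark() (the table, account, container, and file names below are hypothetical; note that the CREDENTIAL clause expects an Azure SAS token, and the value of AZR_ACCOUNT_KEY is what gets substituted in):

COPY INTO scratch.insert_test
FROM 'abfss://some_container_name@some_azure_storage_account.dfs.core.windows.net/spark_insert_1a2b.csv'
WITH (
CREDENTIAL (AZURE_SAS_TOKEN = 'some_secret_account_key')
)
FILEFORMAT = CSV
FORMAT_OPTIONS (
'header' = 'true',
'inferSchema' = 'true'
);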