## Demo 5: Connecting to Database Management Systems using PySpark

This notebook illustrates how to use PySpark to extract data from common database systems including:
- A local instance of MySQL
- An instance of MySQL hosted in Azure
- A local instance of Microsoft SQL Server
- An instance of Azure SQL Database
- A local instance of MongoDB
- An instance of MongoDB Atlas hosted on AWS

This requires several prerequisite steps: making the driver software (i.e., connection libraries) available to the Spark Session, and providing the appropriate connection information. The following JAR files should be located in the current working directory (i.e., `os.getcwd()`):
- `mysql-connector-j-9.1.0\mysql-connector-j-9.1.0.jar`
- `sqljdbc_12.8\enu\jars\mssql-jdbc-12.8.1.jre11.jar`

The following package is downloaded from Maven when the Spark Session object is instantiated:
- `org.mongodb.spark:mongo-spark-connector_2.12:3.0.1`
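
Because a missing driver JAR typically only surfaces later as a class-not-found error when a query runs, it can help to check for the files up front. The following is a minimal sketch, assuming the relative paths listed above; the helper name `find_missing_jars` is hypothetical and not part of this notebook.

```python
import os

# Relative paths taken from the list above; adjust them if your driver versions differ.
REQUIRED_JARS = [
    os.path.join("mysql-connector-j-9.1.0", "mysql-connector-j-9.1.0.jar"),
    os.path.join("sqljdbc_12.8", "enu", "jars", "mssql-jdbc-12.8.1.jre11.jar"),
]


def find_missing_jars(base_dir, relative_paths=REQUIRED_JARS):
    """Return the full paths of any expected JAR files that are not present."""
    full_paths = [os.path.join(base_dir, rel) for rel in relative_paths]
    return [path for path in full_paths if not os.path.isfile(path)]


# Report anything that is missing from the current working directory.
missing = find_missing_jars(os.getcwd())
if missing:
    print("Missing JAR files:")
    for path in missing:
        print(f"  {path}")
```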

This notebook assumes that the following database management systems and corresponding databases exist:
    
| Database Product | Location | Database |
| ----- | ----- | ----- |
| MySQL| Local Workstation| Northwind_DW2|
| Azure MySQL | Azure Cloud | Northwind_DW2 |
| MS SQL Server| Local Workstation| AdventureWorksLT2022|
| Azure SQL Database | Azure Cloud | AdventureWorksLT |
| MongoDB| Local Workstation| Northwind_DW2 JSON Extracts|
| MongoDB Atlas | AWS Cloud | Northwind_DW2 JSON Extracts |

### Prerequisites:
#### Import Required Libraries

In [1]:
import findspark
findspark.init()
findspark.find()

'C:\\spark-3.5.4-bin-hadoop3'

In [2]:
import os
import sys
import json
import pymongo
import certifi

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

#### Instantiate Global Variables

In [3]:
mongodb_args = {
    "cluster_location" : "local", # "atlas"
    "user_name" : "jtupitza",
    "password" : "Passw0rd1234",
    "cluster_name" : "sandbox",
    "cluster_subnet" : "zibbf",
    "db_name" : "northwind_purchasing",
    "collection" : "",
    "null_column_threshold" : 0.5
}

jars = []
mysql_spark_jar = os.path.join(os.getcwd(), "mysql-connector-j-9.1.0", "mysql-connector-j-9.1.0.jar")
mssql_spark_jar = os.path.join(os.getcwd(), "sqljdbc_12.8", "enu", "jars", "mssql-jdbc-12.8.1.jre11.jar")

jars.append(mysql_spark_jar)
jars.append(mssql_spark_jar)

dest_database = "healthcare_dlh"
sql_warehouse_dir = os.path.abspath('spark-warehouse')
database_dir = os.path.join(sql_warehouse_dir, dest_database)
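
The dictionary above stores the user name and password in plain text, which is convenient for a demo but risky in a shared notebook. One common alternative is to read credentials from environment variables. The sketch below assumes hypothetical variable names (`MONGODB_USER`, `MONGODB_PASSWORD`) and a hypothetical helper `get_secret`; the fallback values are placeholders for illustration only.

```python
import os


def get_secret(env_name, default=None):
    """Read a credential from an environment variable, failing loudly if it is absent."""
    value = os.environ.get(env_name, default)
    if value is None:
        raise KeyError(f"Environment variable '{env_name}' is not set.")
    return value


# Hypothetical variable names; the defaults are placeholders, not real credentials.
example_args = {
    "user_name": get_secret("MONGODB_USER", default="demo_user"),
    "password": get_secret("MONGODB_PASSWORD", default="demo_password"),
}
```

The resulting values could then be merged into `mongodb_args` before the Spark configuration is built.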

#### Define Global Functions

In [4]:
def get_mongo_uri(**args):
    '''Build a MongoDB connection URI for either a local instance or an Atlas cluster.'''
    # Validate the requested cluster location before building the URI.
    if args["cluster_location"] not in ['atlas', 'local']:
        raise ValueError("You must specify either 'atlas' or 'local' for the 'cluster_location' parameter.")
        
    if args['cluster_location'] == "atlas":
        # Read from the function's own arguments rather than the global 'mongodb_args'.
        uri = f"mongodb+srv://{args['user_name']}:{args['password']}@"
        uri += f"{args['cluster_name']}.{args['cluster_subnet']}.mongodb.net/"
    else:
        uri = "mongodb://localhost:27017/"

    return uri


def get_spark_conf_args(spark_jars : list, **args):
    # 'spark.jars' expects a comma-separated list of paths.
    jars = ",".join(spark_jars)
    
    # Use half of the available cores, but never fewer than one worker thread.
    cpu_count = os.cpu_count() or 1
    
    sparkConf_args = {
        "app_name" : "PySpark Northwind Data Lakehouse (Medallion Architecture)",
        "worker_threads" : f"local[{max(1, cpu_count // 2)}]",
        "shuffle_partitions" : cpu_count,
        "mongo_uri" : get_mongo_uri(**args),
        "spark_jars" : jars,
        "database_dir" : database_dir
    }
    
    return sparkConf_args
    

def get_spark_conf(**args):
    sparkConf = SparkConf().setAppName(args['app_name'])\
    .setMaster(args['worker_threads']) \
    .set("spark.executor.memory", "2g") \
    .set("spark.driver.memory", "4g") \
    .set("spark.sql.shuffle.partitions", args['shuffle_partitions']) \
    .set("spark.jars", args['spark_jars']) \
    .set("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.1") \
    .set("spark.mongodb.input.uri", args['mongo_uri']) \
    .set("spark.mongodb.output.uri", args['mongo_uri']) \
    .set("spark.streaming.stopGracefullyOnShutdown", "true") \
    .set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true") \
    .set("spark.sql.streaming.schemaInference", "true") \
    .set("spark.sql.warehouse.dir", args['database_dir'])
    
    return sparkConf
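
MongoDB connection strings require the user name and password to be percent-encoded when they contain reserved characters such as `@`, `:`, or `/`. The sketch below shows one way to do that with the standard library; `build_atlas_uri` is a hypothetical helper, and the credentials are placeholders.

```python
from urllib.parse import quote_plus


def build_atlas_uri(user_name, password, cluster_name, cluster_subnet):
    """Build a mongodb+srv URI with percent-encoded credentials."""
    user = quote_plus(user_name)
    pwd = quote_plus(password)
    return f"mongodb+srv://{user}:{pwd}@{cluster_name}.{cluster_subnet}.mongodb.net/"


# Placeholder credentials containing characters that must be encoded.
print(build_atlas_uri("demo_user", "p@ss:word/1", "sandbox", "zibbf"))
# mongodb+srv://demo_user:p%40ss%3Aword%2F1@sandbox.zibbf.mongodb.net/
```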

#### Create a New Spark Session

In [5]:
sparkConf_args = get_spark_conf_args(jars, **mongodb_args)

sparkConf = get_spark_conf(**sparkConf_args)
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()

### Fetch Data from MySQL

In [6]:
def get_mysql_dataframe(spark_session, sql_query, **args):
    '''Query a MySQL database over JDBC and return the result as a Spark DataFrame.'''
    # Create a JDBC URL to the MySQL database.
    jdbc_url = f"jdbc:mysql://{args['host_name']}:{args['port']}/{args['db_name']}"
    
    # Use the JDBC data source with the 'query' option to run the SQL statement.
    # Alternative: spark_session.read.jdbc(url=jdbc_url, table=f"({sql_query}) AS t", properties=args['conn_props'])
    dframe = spark_session.read.format("jdbc") \
    .option("url", jdbc_url) \
    .option("driver", args['conn_props']['driver']) \
    .option("user", args['conn_props']['user']) \
    .option("password", args['conn_props']['password']) \
    .option("query", sql_query) \
    .load()
    
    return dframe
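
Spark's JDBC data source accepts either a `query` option (a full SQL statement) or a `dbtable` option (a table name or parenthesized subquery), but not both at once. The following sketch builds the option dictionary in plain Python so the rule is explicit; `build_jdbc_options` is a hypothetical helper and the credentials are placeholders.

```python
def build_jdbc_options(url, driver, user, password, query=None, table=None):
    """Return a dict of JDBC reader options, using exactly one of 'query' or 'dbtable'."""
    if (query is None) == (table is None):
        raise ValueError("Specify exactly one of 'query' or 'table'.")

    options = {"url": url, "driver": driver, "user": user, "password": password}
    if query is not None:
        options["query"] = query
    else:
        options["dbtable"] = table
    return options


# Placeholder connection details for illustration only.
opts = build_jdbc_options(
    url="jdbc:mysql://localhost:3306/Northwind_DW2",
    driver="com.mysql.cj.jdbc.Driver",
    user="demo_user",
    password="demo_password",
    table="dim_customers",
)
```

With a live session, the dictionary could be passed along as `spark.read.format("jdbc").options(**opts).load()`.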

##### Fetch Data from a Local Instance of MySQL Server

In [7]:
mysql_args = {
    "host_name" : "localhost",
    "port" : "3306",
    "db_name" : "Northwind_DW2",
    "conn_props" : {
        "user" : "jtupitza",
        "password" : "Passw0rd123!",
        # Connector/J 8.x and later use the 'cj' driver class; the older name is deprecated.
        "driver" : "com.mysql.cj.jdbc.Driver"
    }
}

print(f"Fetching data from: {mysql_args['host_name']}")

sql_dim_customers = f"SELECT * FROM {mysql_args['db_name']}.dim_customers"

df_dim_customers = get_mysql_dataframe(spark, sql_dim_customers, **mysql_args)
df_dim_customers.toPandas().head(5)

Fetching data from: localhost


| | customer_key | customer_id | company | last_name | first_name | job_title | business_phone | fax_number | address | city | state_province | zip_postal_code | country_region |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| 0 | 1 | 1 | Company A | Bedecs | Anna | Owner | (123)555-0100 | (123)555-0101 | 123 1st Street | Seattle | WA | 99999 | USA |
| 1 | 2 | 2 | Company B | Gratacos Solsona | Antonio | Owner | (123)555-0100 | (123)555-0101 | 123 2nd Street | Boston | MA | 99999 | USA |
| 2 | 3 | 3 | Company C | Axen | Thomas | Purchasing Representative | (123)555-0100 | (123)555-0101 | 123 3rd Street | Los Angelas | CA | 99999 | USA |
| 3 | 4 | 4 | Company D | Lee | Christina | Purchasing Manager | (123)555-0100 | (123)555-0101 | 123 4th Street | New York | NY | 99999 | USA |
| 4 | 5 | 5 | Company E | O’Donnell | Martin | Owner | (123)555-0100 | (123)555-0101 | 123 5th Street | Minneapolis | MN | 99999 | USA |


##### Fetch Data from an Instance of Azure MySQL Server

In [8]:
mysql_args['host_name'] = "jtupitz-mysql.mysql.database.azure.com"
print(f"Fetching data from: {mysql_args['host_name']}")

sql_dim_customers = f"SELECT * FROM {mysql_args['db_name']}.dim_customers"

df_customers = get_mysql_dataframe(spark, sql_dim_customers, **mysql_args)
df_customers.toPandas().head(5)

Fetching data from: jtupitz-mysql.mysql.database.azure.com


| | customer_key | customer_id | company | last_name | first_name | job_title | business_phone | fax_number | address | city | state_province | zip_postal_code | country_region |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| 0 | 1 | 1 | Company A | Bedecs | Anna | Owner | (123)555-0100 | (123)555-0101 | 123 1st Street | Seattle | WA | 99999 | USA |
| 1 | 2 | 2 | Company B | Gratacos Solsona | Antonio | Owner | (123)555-0100 | (123)555-0101 | 123 2nd Street | Boston | MA | 99999 | USA |
| 2 | 3 | 3 | Company C | Axen | Thomas | Purchasing Representative | (123)555-0100 | (123)555-0101 | 123 3rd Street | Los Angelas | CA | 99999 | USA |
| 3 | 4 | 4 | Company D | Lee | Christina | Purchasing Manager | (123)555-0100 | (123)555-0101 | 123 4th Street | New York | NY | 99999 | USA |
| 4 | 5 | 5 | Company E | O’Donnell | Martin | Owner | (123)555-0100 | (123)555-0101 | 123 5th Street | Minneapolis | MN | 99999 | USA |


### Fetch Data from SQL Server

In [9]:
def get_sql_dataframe(spark_session, sql_query, **args):
    '''Query a SQL Server database over JDBC and return the result as a Spark DataFrame.'''
    # Create a JDBC URL to the SQL Server database.
    jdbc_url = f"jdbc:sqlserver://{args['host_name']}:{args['port']};database={args['db_name']};TrustServerCertificate=True"
    
    # Use the JDBC data source with the 'query' option to run the SQL statement.
    # Alternative: spark_session.read.jdbc(url=jdbc_url, table=f"({sql_query}) AS t", properties=args['conn_props'])
    dframe = spark_session.read.format("jdbc") \
    .option("url", jdbc_url) \
    .option("driver", args['conn_props']['driver']) \
    .option("user", args['conn_props']['user']) \
    .option("password", args['conn_props']['password']) \
    .option("query", sql_query) \
    .load()
    
    return dframe
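
A SQL Server JDBC URL is a server address followed by semicolon-separated `property=value` pairs. The sketch below assembles one from its parts so that additional properties (for example `encrypt` and `trustServerCertificate`) can be added without editing a long f-string; `build_sqlserver_url` is a hypothetical helper, not part of the Microsoft driver.

```python
def build_sqlserver_url(host_name, port, db_name, **properties):
    """Build a SQL Server JDBC URL from a host, port, database, and extra properties."""
    parts = [f"jdbc:sqlserver://{host_name}:{port}", f"databaseName={db_name}"]
    # Keyword arguments keep their order, so properties appear as they were passed.
    parts += [f"{key}={value}" for key, value in properties.items()]
    return ";".join(parts)


url = build_sqlserver_url(
    "localhost", "1433", "AdventureWorksLT2022",
    encrypt="true", trustServerCertificate="true",
)
print(url)
# jdbc:sqlserver://localhost:1433;databaseName=AdventureWorksLT2022;encrypt=true;trustServerCertificate=true
```

Trusting the server certificate without validation is reasonable for a local demo; for a cloud-hosted database, certificate validation is usually left enabled.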

##### Fetch Data from a Local Instance of SQL Server

In [10]:
sqlsvr_args = {
    "host_name" : "localhost",
    "port" : "1433",
    "db_name" : "AdventureWorksLT2022",
    "conn_props" : {
        "user" : "jtupitza",
        "password" : "Passw0rd123!",
        "driver" : "com.microsoft.sqlserver.jdbc.SQLServerDriver"
    }
}

sql_customer = "SELECT * FROM SalesLT.Customer"

df_customers = get_sql_dataframe(spark, sql_customer, **sqlsvr_args)
df_customers.toPandas().head(5)

| | CustomerID | NameStyle | Title | FirstName | MiddleName | LastName | Suffix | CompanyName | SalesPerson | EmailAddress | Phone | PasswordHash | PasswordSalt | rowguid | ModifiedDate |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| 0 | 1 | False | Mr. | Orlando | N. | Gee | | A Bike Store | adventure-works\pamela0 | orlando0@adventure-works.com | 245-555-0173 | L/Rlwxzp4w7RWmEgXX+/A7cXaePEPcp+KwQhl2fJL7w= | 1KjXYs4= | 3F5AE95E-B87D-4AED-95B4-C3797AFCB74F | 2005-08-01 |
| 1 | 2 | False | Mr. | Keith | | Harris | | Progressive Sports | adventure-works\david8 | keith0@adventure-works.com | 170-555-0127 | YPdtRdvqeAhj6wyxEsFdshBDNXxkCXn+CRgbvJItknw= | fs1ZGhY= | E552F657-A9AF-4A7D-A645-C429D6E02491 | 2006-08-01 |
| 2 | 3 | False | Ms. | Donna | F. | Carreras | | Advanced Bike Components | adventure-works\jillian0 | donna0@adventure-works.com | 279-555-0130 | LNoK27abGQo48gGue3EBV/UrlYSToV0/s87dCRV7uJk= | YTNH5Rw= | 130774B1-DB21-4EF3-98C8-C104BCD6ED6D | 2005-09-01 |
| 3 | 4 | False | Ms. | Janet | M. | Gates | | Modular Cycle Systems | adventure-works\jillian0 | janet1@adventure-works.com | 710-555-0173 | ElzTpSNbUW1Ut+L5cWlfR7MF6nBZia8WpmGaQPjLOJA= | nm7D5e4= | FF862851-1DAA-4044-BE7C-3E85583C054D | 2006-07-01 |
| 4 | 5 | False | Mr. | Lucy | | Harrington | | Metropolitan Sports Supply | adventure-works\shu0 | lucy0@adventure-works.com | 828-555-0186 | KJqV15wsX3PG8TS5GSddp6LFFVdd3CoRftZM/tP0+R4= | cNFKU4w= | 83905BDC-6F5E-4F71-B162-C98DA069F38A | 2006-09-01 |


##### Fetch Data from an Instance of Azure SQL Database
- Note: `Microsoft Entra Authentication Only` must be disabled so that the SQL login (user name and password) used below is accepted.

In [None]:
sqlsvr_args = {
    "host_name" : "jtupitz-sql.database.windows.net",
    "port" : "1433",
    "db_name" : "AdventureWorksLT",
    "conn_props" : {
        "user" : "jtupitza",
        "password" : "Passw0rd123!",
        "driver" : "com.microsoft.sqlserver.jdbc.SQLServerDriver"
    }
}

sql_customer = "SELECT * FROM SalesLT.Customer"

df_customers = get_sql_dataframe(spark, sql_customer, **sqlsvr_args)
df_customers.toPandas().head(5)

### Fetch Data from MongoDB

In [11]:
def get_mongo_client(**args):
    '''Create a MongoDB Client Connection'''
    mongo_uri = get_mongo_uri(**args)
    
    if args['cluster_location'] == "atlas":
        client = pymongo.MongoClient(mongo_uri, tlsCAFile=certifi.where())

    elif args['cluster_location'] == "local":
        client = pymongo.MongoClient(mongo_uri)
        
    else:
        raise Exception("A MongoDB Client could not be created.")

    return client

In [12]:
def set_mongo_collections(mongo_client, db_name, data_directory, json_files):
    '''Replace each named collection with the documents from its JSON file.'''
    db = mongo_client[db_name]
    
    for collection_name, file_name in json_files.items():
        # Drop any existing collection so that reruns do not duplicate documents.
        db.drop_collection(collection_name)
        json_file = os.path.join(data_directory, file_name)
        with open(json_file, 'r') as openfile:
            json_object = json.load(openfile)
        db[collection_name].insert_many(json_object)
        
    mongo_client.close()
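
PyMongo's `insert_many` expects a non-empty list of documents, so a JSON file that holds a single object or an empty array would fail at insert time. The sketch below validates a file before it is uploaded; `load_documents` is a hypothetical helper that is not part of this notebook.

```python
import json


def load_documents(json_path):
    """Load a JSON file and return its contents as a non-empty list of documents."""
    with open(json_path, "r", encoding="utf-8") as handle:
        data = json.load(handle)

    # A single JSON object is treated as one document.
    if isinstance(data, dict):
        data = [data]

    if not isinstance(data, list) or not data:
        raise ValueError(f"'{json_path}' does not contain any documents.")
    if not all(isinstance(doc, dict) for doc in data):
        raise ValueError(f"'{json_path}' must contain only JSON objects.")
    return data
```

The returned list could be passed directly to `insert_many` for the target collection.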

In [13]:
def get_mongodb_dataframe(spark_session, **args):
    '''Read a MongoDB collection into a Spark DataFrame.'''
    # The connection URI comes from the 'spark.mongodb.input.uri' setting of the session.
    dframe = spark_session.read.format("com.mongodb.spark.sql.DefaultSource") \
        .option("database", args['db_name']) \
        .option("collection", args['collection']).load()
    
    # Drop the '_id' index column to clean up the response.
    dframe = dframe.drop('_id')
    
    return dframe

#### Fetch Data from a Local Instance of MongoDB
##### Upload Local JSON Files to a Local Instance of MongoDB

In [14]:
mongodb_args["cluster_location"] = "local"

client = get_mongo_client(**mongodb_args)

data_dir = os.path.join(os.getcwd(), 'lab_data', 'mongodb-data')

json_files = {"suppliers" : 'northwind_suppliers.json',
              "invoices" : 'northwind_invoices.json',
              "purchase_orders" : 'northwind_purchase_orders.json',
              "inventory_transactions" : 'northwind_inventory_transactions.json'
             }

set_mongo_collections(client, mongodb_args["db_name"], data_dir, json_files) 

##### Fetch Data from a Local MongoDB Collection

In [15]:
mongodb_args["cluster_location"] = "local"
mongodb_args["collection"] = "suppliers"

df_suppliers = get_mongodb_dataframe(spark, **mongodb_args)
df_suppliers.toPandas().head(5)

| | company | first_name | id | job_title | last_name |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 0 | Supplier A | Elizabeth A. | 1 | Sales Manager | Andersen |
| 1 | Supplier B | Cornelia | 2 | Sales Manager | Weiler |
| 2 | Supplier C | Madeleine | 3 | Sales Representative | Kelley |
| 3 | Supplier D | Naoki | 4 | Marketing Manager | Sato |
| 4 | Supplier E | Amaya | 5 | Sales Manager | Hernandez-Echevarria |


In [16]:
spark.stop()
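
The session is presumably stopped here because the MongoDB URI is fixed in the Spark configuration when the session is created, so switching to Atlas requires a new session. As a possible alternative, the 2.x/3.x MongoDB Spark connector documentation describes a per-read `uri` option. The sketch below only assembles such options in plain Python; `build_mongo_read_options` is a hypothetical helper, and whether a per-read `uri` fits this setup is an assumption to verify.

```python
def build_mongo_read_options(uri, db_name, collection):
    """Return per-read options for the MongoDB Spark connector."""
    return {"uri": uri, "database": db_name, "collection": collection}


# Values mirror the local configuration used earlier in this notebook.
opts = build_mongo_read_options(
    "mongodb://localhost:27017/", "northwind_purchasing", "suppliers"
)

# With a live session (not executed here):
# spark.read.format("com.mongodb.spark.sql.DefaultSource").options(**opts).load()
```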

#### Fetch Data from MongoDB Atlas
##### Upload Local JSON Files to a MongoDB Atlas Cluster

In [17]:
mongodb_args["cluster_location"] = "atlas"

client = get_mongo_client(**mongodb_args)

data_dir = os.path.join(os.getcwd(), 'lab_data', 'mongodb-data')

json_files = {"suppliers" : 'northwind_suppliers.json',
              "invoices" : 'northwind_invoices.json',
              "purchase_orders" : 'northwind_purchase_orders.json',
              "inventory_transactions" : 'northwind_inventory_transactions.json'
             }

set_mongo_collections(client, mongodb_args["db_name"], data_dir, json_files) 

##### Fetch Data from a MongoDB Atlas Collection

In [19]:
mongodb_args["cluster_location"] = "atlas"
mongodb_args["collection"] = "suppliers"

sparkConf_args = get_spark_conf_args(jars, **mongodb_args)

sparkConf = get_spark_conf(**sparkConf_args)
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()

df_suppliers = get_mongodb_dataframe(spark, **mongodb_args)
df_suppliers.toPandas().head(5)

| | company | first_name | id | job_title | last_name |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 0 | Supplier A | Elizabeth A. | 1 | Sales Manager | Andersen |
| 1 | Supplier B | Cornelia | 2 | Sales Manager | Weiler |
| 2 | Supplier C | Madeleine | 3 | Sales Representative | Kelley |
| 3 | Supplier D | Naoki | 4 | Marketing Manager | Sato |
| 4 | Supplier E | Amaya | 5 | Sales Manager | Hernandez-Echevarria |


In [20]:
spark.stop()