In [None]:
import os
import random
import string

from pyspark import SparkConf
from pyspark.sql import SparkSession

import pyspark.sql.functions as f
import pyspark.sql.types as t

In [None]:
def create_spark_session(app_name: str, config: dict) -> SparkSession:
    """Return a SparkSession named ``app_name`` with ``config`` applied.

    Args:
        app_name: Application name passed to the session builder.
        config: Mapping of Spark property names to values; every entry is
            copied into a ``SparkConf``.

    Returns:
        The session produced by the builder's ``getOrCreate()``.
    """
    # Copy every key/value pair of the plain dict into a SparkConf.
    conf = SparkConf()
    conf.setAll(list(config.items()))

    builder = SparkSession.builder.appName(app_name).config(conf=conf)
    return builder.getOrCreate()

def setup_database(spark: SparkSession, db_conf: dict,
                    batchsize: int = 1000, timeout: int = 60,
                    n_users: int = 1000, n_purchases: int = 20000):
    """Populate the database with randomly generated dummy data.

    Writes two tables over JDBC in ``overwrite`` mode:

    * ``users``     -- columns ``id``, ``name``, ``age``
    * ``purchases`` -- columns ``user``, ``product``, ``price``

    Every value in ``purchases.user`` is drawn from the ids generated for
    ``users``, so each purchase refers to a user row that exists.

    Args:
        spark: Session used to build and write the DataFrames.
        db_conf: Mapping with the keys ``"url"``, ``"driver"``, ``"user"``
            and ``"password"``.
        batchsize: Value for the JDBC ``batchsize`` write option.
        timeout: Value for the JDBC ``queryTimeout`` write option.
        n_users: Number of user rows to generate (ids ``0 .. n_users - 1``).
        n_purchases: Number of purchase rows to generate.

    Raises:
        ValueError: If ``n_users`` or ``n_purchases`` is smaller than 1.
    """
    # An empty row list gives createDataFrame nothing to infer column types
    # from, so reject it up front with a clear message.
    if n_users < 1 or n_purchases < 1:
        raise ValueError("n_users and n_purchases must both be >= 1")

    products = ("apple", "banana", "orange", "chicken", "steak", "shrimp", "pasta", "rice")

    def name():
        """Random string of 4-10 ASCII letters."""
        return "".join(random.choice(string.ascii_letters) for _ in range(random.randint(4, 10)))

    def product():
        """Random pick from the fixed product list."""
        return random.choice(products)

    def age():
        """Random integer age between 18 and 80 inclusive."""
        return random.randint(18, 80)

    def price():
        """Random price in [0, 100], rounded to two decimals."""
        return round(random.random() * 100, 2)

    def write_table(frame, table):
        """Write ``frame`` to ``table`` over JDBC as a single partition, replacing it."""
        frame.repartition(1).write.format("jdbc").options(
            url=db_conf["url"],
            driver=db_conf["driver"],
            dbtable=table,
            user=db_conf["user"],
            password=db_conf["password"],
            batchsize=batchsize,
            queryTimeout=timeout,
        ).mode("overwrite").save()

    # USERS
    print("Creating Users Table...")
    users = spark.createDataFrame([
        (i, name(), age()) for i in range(n_users)
    ], schema=["id", "name", "age"])
    write_table(users, "users")

    # PURCHASES
    print("Creating Purchases Table...")
    # randrange(n_users) yields 0 .. n_users - 1, exactly the ids created
    # above. The previous randint(1, 1000) could produce id 1000 (which has
    # no user row) and never produced id 0.
    purchases = spark.createDataFrame([
        (random.randrange(n_users), product(), price())
        for _ in range(n_purchases)
    ], schema=["user", "product", "price"])
    write_table(purchases, "purchases")

In [None]:
# Configuration
APP = "pyspark-postgres-demo"

# JDBC connection settings. The URL, user and password can each be overridden
# through an environment variable so that real credentials do not have to be
# written into the notebook; the fallbacks are the demo values.
DB_CONF = {
    "url": os.environ.get("DWH_JDBC_URL", "jdbc:postgresql://warehouse:5432/dwh"),
    "user": os.environ.get("DWH_USER", "dwh"),
    "password": os.environ.get("DWH_PASSWORD", "password"),
    "driver": "org.postgresql.Driver",
}

In [None]:
# Session
# "spark.jars" points Spark at the PostgreSQL JDBC driver jar.
spark = create_spark_session(
    config={"spark.jars": "/home/jovyan/work/jars/postgresql-42.2.24.jre6.jar"},
    app_name=APP,
)

In [None]:
# Setup Database
# Fill the users and purchases tables using the default batch size and timeout.
setup_database(spark, DB_CONF)

In [None]:
# Load the "users" table over JDBC.
users = (
    spark.read.format("jdbc")
    .options(
        url=DB_CONF["url"],
        dbtable="users",
        user=DB_CONF["user"],
        password=DB_CONF["password"],
        driver=DB_CONF["driver"],
    )
    .load()
)

In [None]:
# Load the "purchases" table over JDBC.
purchases = (
    spark.read.format("jdbc")
    .options(
        url=DB_CONF["url"],
        dbtable="purchases",
        user=DB_CONF["user"],
        password=DB_CONF["password"],
        driver=DB_CONF["driver"],
    )
    .load()
)

In [None]:
# Stage
# Write up to 100 user rows to "stg__users" as a single partition, replacing
# any existing table. No ordering is applied before limit().
(
    users.limit(100)
    .repartition(1)
    .write.format("jdbc")
    .option("url", DB_CONF["url"])
    .option("driver", DB_CONF["driver"])
    .option("dbtable", "stg__users")
    .option("user", DB_CONF["user"])
    .option("password", DB_CONF["password"])
    .mode("overwrite")
    .save()
)