In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

In [None]:
# Local SparkSession with the PostgreSQL JDBC driver on the classpath.
# "local[*]" starts one worker thread per available core. Plain "local" runs a
# single thread, so the multi-partition JDBC reads further down would fetch
# their partitions one after another instead of in parallel.
spark = (
    SparkSession
        .builder
        .master("local[*]")
        .appName("Data Sources")
        # Relative path: resolved against the kernel's working directory.
        .config("spark.jars", "jars/postgresql-42.7.2.jar")
        .config("spark.driver.memory", "16g")
        .getOrCreate()
)

## Параметры подключения (драйвер, URL, учётные данные)

In [None]:
# JDBC connection parameters shared by every read/write in this notebook.
# Credentials come from the standard libpq environment variables PGUSER /
# PGPASSWORD when they are set, and fall back to the local-dev defaults
# otherwise, so real passwords never have to be committed with the notebook.
import os

driver = "org.postgresql.Driver"
url = "jdbc:postgresql://localhost:5432/spark"
user = os.environ.get("PGUSER", "postgres")
password = os.environ.get("PGPASSWORD", "postgres")

## Чтение таблицы целиком

### Пример 1

In [None]:
# Read the whole public.employees table through the generic JDBC source.
# No partitioning options are given, so Spark reads it as a single partition.
employees_df = (
    spark.read.format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        dbtable="public.employees",
    )
    .load()
)

employees_df.count()

In [None]:
# Preview the first 10 rows.
employees_df.show(n=10)

### Пример 2

In [None]:
# Connection properties for the DataFrameReader.jdbc(...) shortcut used below.
DBPARAMS = dict(
    user=user,
    password=password,
    driver=driver,
)

In [None]:
# Same full-table read, using the jdbc() shortcut with a properties dict.
df = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
)
df.count()

In [None]:
# Preview the first 10 rows.
df.show(n=10)

In [None]:
# Partition count of the plain jdbc() read (no column / predicates were given).
df.rdd.getNumPartitions()

In [None]:
# Range of emp_no, useful for choosing lowerBound/upperBound for partitioned reads.
# min/max here are pyspark.sql.functions (via the wildcard import), not Python builtins.
df.select(min("emp_no"), max("emp_no")).show()

### Как распараллелить чтение?

In [None]:
# numPartitions on its own (without column/lowerBound/upperBound) does not
# partition a jdbc() read; the partition count is checked in the next cell.
df101 = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
    numPartitions=10,
)
df101.count()

In [None]:
# Expected to be 1: numPartitions without a partition column is not used for splitting.
df101.rdd.getNumPartitions()

In [None]:
# Range-partitioned read: Spark splits [lowerBound, upperBound] on emp_no into
# numPartitions strides and issues one query per partition. The bounds only set
# the stride; rows outside them are still read (first/last partition).
df102 = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
    column="emp_no",
    lowerBound=10010,
    upperBound=499990,
    numPartitions=10,
)
df102.count()

In [None]:
# Should match numPartitions=10 from the range-partitioned read above.
df102.rdd.getNumPartitions()

In [None]:
# The same range-partitioned read expressed through generic JDBC options
# (partitionColumn / lowerBound / upperBound / numPartitions, passed as strings).
df103 = (
    spark.read.format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        dbtable="public.employees",
        partitionColumn="emp_no",
        lowerBound="10010",
        upperBound="499990",
        numPartitions="10",
    )
    .load()
)

df103.count()

In [None]:
# Should match the numPartitions="10" option above.
df103.rdd.getNumPartitions()

In [None]:
# A parenthesised, aliased subquery can stand in for the table name, so the
# column projection and the gender filter are executed by PostgreSQL.
employees_pruned = """(select e.first_name, e.last_name, e.hire_date from public.employees e where e.gender = 'F') as new_emp"""

df_pruned = spark.read.jdbc(
    url=url,
    table=employees_pruned,
    properties=DBPARAMS,
)
df_pruned.count()

In [None]:
# Preview the first 10 rows of the pruned result.
df_pruned.show(n=10)

## Предикаты

### Пример 1

In [None]:
# Each predicate becomes the WHERE clause of one partition's query,
# so two predicates give two partitions.
pred = ["gender = 'M'", "gender = 'F'"]

df_pred = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
    predicates=pred,
)
df_pred.count()

In [None]:
# One partition per predicate: expected 2.
df_pred.rdd.getNumPartitions()

In [None]:
# A single predicate acts as a pushed-down filter: one partition, female rows only.
pred1 = ["gender = 'F'"]

df_pred1 = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
    predicates=pred1,
)
df_pred1.count()

In [None]:
# One predicate, so expected 1 partition.
df_pred1.rdd.getNumPartitions()

In [None]:
# Note: "gender = 'M'" is listed twice, so the male rows are fetched by two
# partitions and appear twice in df_pred3 (compare the groupBy counts below).
pred3 = ["gender = 'F'", "gender = 'M'", "gender = 'M'"]

df_pred3 = spark.read.jdbc(
    url=url,
    table="public.employees",
    properties=DBPARAMS,
    predicates=pred3,
)
df_pred3.count()

In [None]:
# One partition per predicate, including the duplicate: expected 3.
df_pred3.rdd.getNumPartitions()

In [None]:
# Reference per-gender row counts from the unpartitioned read.
df.groupBy("gender").agg(count("emp_no")).show()

In [None]:
# Same counts over the predicate read; the duplicated 'M' predicate doubles that group.
df_pred3.groupBy("gender").agg(count("emp_no")).show()

### Пример 2

In [None]:
# Two range predicates give two partitions. The ranges must not overlap: a row
# matching more than one predicate is read once per matching partition and ends
# up duplicated. Both ranges are therefore half-open, (20000, 50000] and
# (50000, 100000], so no emp_no belongs to both.
pred2 = [
    "emp_no > 20000 and emp_no <= 50000",
    "emp_no > 50000 and emp_no <= 100000",
]

df_pred2 = spark.read.jdbc(url=url, table="public.employees", properties=DBPARAMS, predicates=pred2)

df_pred2.count()

In [None]:
# Preview the first 10 rows of the range-predicate read.
df_pred2.show(n=10)

## Фильтрация на стороне базы (опция `query`)

In [None]:
# The "query" option is wrapped by Spark as a subquery, so the filter runs in
# PostgreSQL and only matching rows are transferred.
q = """select * from public.employees where emp_no > 20000 and emp_no <= 50000"""

dfq = (
    spark.read.format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        query=q,
    )
    .load()
)

dfq.count()

In [None]:
# Preview the first 10 filtered rows.
dfq.show(n=10)

## Соединение таблиц (JOIN) на стороне базы

In [None]:
# The join is executed inside PostgreSQL through the "query" option; Spark only
# receives the joined rows. Tables are schema-qualified (public.) like every
# other read in this notebook, so the result does not depend on search_path.
qj = """select e.emp_no, birth_date, first_name, last_name, gender, hire_date, salary, from_date, to_date
from public.employees e join public.salaries s on e.emp_no = s.emp_no"""

dfj = (
    spark.read.format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        query=qj,
    )
    .load()
)

dfj.count()

In [None]:
# Preview the joined result (show() with its default row count).
dfj.show()

## Запись в таблицу

In [None]:
# Re-display the employees table that will be enriched and written back.
employees_df.show(n=10)

In [None]:
# Read the whole public.salaries table (no partitioning options, single partition).
salaries_df = (
    spark.read.format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        dbtable="public.salaries",
    )
    .load()
)

salaries_df.count()

In [None]:
# Preview the first 10 salary rows.
salaries_df.show(n=10)

In [None]:
# Attach each employee's highest salary (max over all salary rows per emp_no).
# max here is pyspark.sql.functions.max (via the wildcard import).
employees_salaries_df = employees_df.join(
    salaries_df
        .groupBy("emp_no")
        .agg(max("salary").alias("max_salary")),
    on="emp_no",
)

employees_salaries_df.show()

In [None]:
# Write the enriched frame to public.employees_salaries, replacing its contents.
(
    employees_salaries_df.write
    .format("jdbc")
    .mode("overwrite")
    .options(
        driver=driver,
        url=url,
        user=user,
        password=password,
        dbtable="public.employees_salaries",
        # With overwrite mode, truncate the existing table instead of dropping
        # and recreating it, which keeps the table's definition in place.
        truncate="true",
    )
    .save()
)