# Challenge 1: ETL Pipeline - Data Extraction from PostgreSQL

## Task Description
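In this challenge, we need to:
1. Connect to a PostgreSQL database
2. Extract data from multiple tables
3. Optimize data loading strategies
4. Handle schema discovery
5. Load data incrementally based on a timestamp column
6. Extract data within a database transaction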

## Setup

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Build (or reuse) the notebook's Spark session: run locally on all cores and
# use 8 shuffle partitions instead of Spark's default (presumably sized for
# small local extracts).
spark = (
    SparkSession.builder
    .appName("ETL Pipeline - Data Extraction")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "8")
    .getOrCreate()
)

# Only surface Spark's own log messages at WARN level and above.
spark.sparkContext.setLogLevel("WARN")

## PostgreSQL Connection Setup

In [None]:
import os

# PostgreSQL connection parameters. Each value can be overridden through an
# environment variable so real credentials need not be hard-coded in the
# notebook; the defaults are the original values (presumably the local dev
# container -- confirm before pointing this at another database).
jdbc_url = os.environ.get("PG_JDBC_URL", "jdbc:postgresql://postgres:5432/datamart")
connection_properties = {
    "user": os.environ.get("PG_USER", "spark"),
    "password": os.environ.get("PG_PASSWORD", "spark"),
    "driver": "org.postgresql.Driver",
}

# Validate the connection up front with a trivial round-trip query, so a wrong
# host, missing driver, or bad credentials fail here with a clear message
# instead of inside a later extraction cell.
try:
    spark.read.format("jdbc").options(
        url=jdbc_url, query="SELECT 1 AS ok", **connection_properties
    ).load().collect()
except Exception as err:  # Spark surfaces JDBC failures as Java-backed errors
    raise RuntimeError(f"Could not connect to PostgreSQL at {jdbc_url}") from err
print(f"PostgreSQL connection OK: {jdbc_url}")

## Basic Data Extraction

In [None]:
# Read the whole raw.customers table over JDBC ("dbtable" names the source
# relation; Spark generates the SELECT).
customers_df = (
    spark.read.format("jdbc")
    .options(
        url=jdbc_url,
        dbtable="raw.customers",
        user=connection_properties["user"],
        password=connection_properties["password"],
        driver=connection_properties["driver"],
    )
    .load()
)

# Inspect the resulting schema and the first few rows.
customers_df.printSchema()
customers_df.show(5)

## Parallel Data Extraction

In [None]:
# Read raw.orders in parallel: Spark issues one query per partition, each
# covering a contiguous slice of order_id values. lowerBound/upperBound only
# set the width of those slices -- they do not filter rows -- so they must
# span the real id range. With a hard-coded 1..1000 range, every order_id
# beyond it piled into the last partition and the read was badly skewed.
# The edge partitions are open-ended, so rows inserted after the bounds query
# below are still read.

order_id_bounds = (
    spark.read.format("jdbc")
    .options(
        url=jdbc_url,
        query="SELECT MIN(order_id) AS min_id, MAX(order_id) AS max_id FROM raw.orders",
        user=connection_properties["user"],
        password=connection_properties["password"],
        driver=connection_properties["driver"],
    )
    .load()
    .first()
)

if order_id_bounds["min_id"] is None:
    # Empty table: MIN/MAX are NULL. Equal bounds make Spark read it as a
    # single partition.
    min_order_id = max_order_id = 1
else:
    # int() also covers NUMERIC id columns, which arrive as Decimal.
    min_order_id = int(order_id_bounds["min_id"])
    max_order_id = int(order_id_bounds["max_id"])

orders_df = (
    spark.read.format("jdbc")
    .options(
        url=jdbc_url,
        dbtable="raw.orders",
        user=connection_properties["user"],
        password=connection_properties["password"],
        driver=connection_properties["driver"],
        numPartitions=4,
        partitionColumn="order_id",
        lowerBound=min_order_id,
        upperBound=max_order_id,
    )
    .load()
)

# Check number of partitions
print(f"Number of partitions: {orders_df.rdd.getNumPartitions()}")
orders_df.show(5)

## Custom Query Extraction

In [None]:
# Push the aggregation down to PostgreSQL: the "query" option (used instead of
# "dbtable") hands this SQL to the database and exposes its result as a
# DataFrame.

query = """
SELECT o.order_id, o.customer_id, o.order_date, 
       SUM(oi.quantity * oi.price) as total_amount,
       COUNT(oi.order_item_id) as item_count
FROM raw.orders o
JOIN raw.order_items oi ON o.order_id = oi.order_id
GROUP BY o.order_id, o.customer_id, o.order_date
ORDER BY o.order_date DESC
LIMIT 100
"""

order_summary_df = (
    spark.read.format("jdbc")
    .options(
        url=jdbc_url,
        query=query,
        user=connection_properties["user"],
        password=connection_properties["password"],
        driver=connection_properties["driver"],
    )
    .load()
)

order_summary_df.show()

## Schema Discovery

In [None]:
# Schema discovery: list the relations PostgreSQL reports for a schema via
# the information_schema.tables view.

def get_schema_tables(schema_name):
    """Return a DataFrame listing the tables of ``schema_name``.

    Queries ``information_schema.tables`` without a ``table_type`` filter, so
    every relation kind that view lists for the schema is returned (in
    PostgreSQL this includes views). Rows are ordered by ``table_name`` so the
    listing is stable across runs.

    Args:
        schema_name: Schema to inspect, e.g. ``"raw"``. It is embedded as a
            SQL string literal with single quotes doubled, so a name containing
            an apostrophe can neither break nor alter the query.

    Returns:
        A Spark DataFrame with a single ``table_name`` column.
    """
    # Spark's JDBC "query" option has no bind parameters, so escape the literal
    # by doubling single quotes. This assumes standard_conforming_strings is on
    # (the PostgreSQL default), where backslashes are not escape characters.
    escaped_schema = schema_name.replace("'", "''")
    query = (
        "SELECT table_name "
        "FROM information_schema.tables "
        f"WHERE table_schema = '{escaped_schema}' "
        "ORDER BY table_name"
    )

    return (
        spark.read.format("jdbc")
        .options(
            url=jdbc_url,
            query=query,
            user=connection_properties["user"],
            password=connection_properties["password"],
            driver=connection_properties["driver"],
        )
        .load()
    )

# Get all tables in the 'raw' schema
raw_tables = get_schema_tables("raw")
raw_tables.show()

## Incremental Data Loading

In [None]:
# Incremental loading: only pull rows whose timestamp column is newer than the
# watermark (last processed time) recorded by the previous run.

def load_incremental_data(table_name, timestamp_col, last_processed_time):
    """Load only rows of ``table_name`` newer than ``last_processed_time``.

    Args:
        table_name: Table to read, e.g. ``"raw.orders"``. Interpolated into
            the SQL verbatim, so it must be a trusted identifier.
        timestamp_col: Column compared against the watermark; also
            interpolated verbatim.
        last_processed_time: Watermark from the previous run. Only rows with
            ``timestamp_col`` strictly greater than it are returned. A
            ``datetime`` keeps its microseconds and UTC offset; a plain
            ``date`` means midnight of that day. ``None`` loads the whole
            table (first run).

    Returns:
        A Spark DataFrame with the matching rows.
    """
    from datetime import datetime, time

    query = f"SELECT * FROM {table_name}"

    if last_processed_time is not None:
        watermark = last_processed_time
        if not isinstance(watermark, datetime):
            # A plain date means midnight, as the previous strftime-based
            # formatting produced.
            watermark = datetime.combine(watermark, time.min)
        # isoformat keeps microseconds (and the offset of aware datetimes).
        # The old "%H:%M:%S" formatting truncated them, so rows within the
        # same second as the watermark were re-loaded on every run.
        formatted_time = watermark.isoformat(sep=" ")
        query += f" WHERE {timestamp_col} > '{formatted_time}'"

    return (
        spark.read.format("jdbc")
        .options(
            url=jdbc_url,
            query=query,
            user=connection_properties["user"],
            password=connection_properties["password"],
            driver=connection_properties["driver"],
        )
        .load()
    )

# Example usage (comment out if tables don't exist yet)
# from datetime import datetime
# last_processed = datetime(2023, 1, 1)
# new_orders = load_incremental_data("raw.orders", "order_date", last_processed)
# new_orders.show()

## Transaction Scope

In [None]:
# TODO: Implement extraction within a database transaction
# Note: This requires using the PySpark JDBC internals or the psycopg2 library