# Create Spark Session

In [None]:
import sagemaker_pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from datetime import date, datetime, timedelta

# Colon-separated list of the jars reported by sagemaker_pyspark.
classpath = ":".join(sagemaker_pyspark.classpath_jars())

# The surrounding parentheses already let the builder chain span several
# lines, so no backslash continuations are needed.
spark = (
    SparkSession.builder
    # Put the sagemaker_pyspark jars on the driver classpath.
    .config("spark.driver.extraClassPath", classpath)
    # SQL / ORC / allocation settings.
    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .config("spark.sql.orc.enabled", "true")
    .config("spark.dynamicAllocation.enabled", "true")
    # S3A output committer settings.
    .config(
        "spark.hadoop.mapreduce.outputcommitter.factory.scheme.s3a",
        "org.apache.hadoop.fs.s3a.commit.S3ACommitterFactory",
    )
    .config("spark.hadoop.fs.s3a.committer.name", "directory")
    .config("spark.hadoop.fs.s3a.committer.staging.conflict-mode", "append")
    # Serializer.
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
)

# Better Spark Display for .show()

In [None]:
from IPython.display import HTML


def better_show(df, num_rows=50):
    """
    Display a PySpark DataFrame as an HTML table in Jupyter notebook.

    Column names and cell values are converted with ``str()`` and then
    HTML-escaped, so data containing ``<``, ``>`` or ``&`` is shown literally
    instead of being interpreted as markup.

    Parameters:
    df (DataFrame): The PySpark DataFrame to display.
    num_rows (int): Number of rows to display. Default is 50.

    Returns:
    HTML: An IPython ``HTML`` object wrapping the generated ``<table>`` markup.
    """
    # Imported inside the function so the block carries its own dependency.
    from html import escape

    # Collect at most `num_rows` rows to the driver (a list of Row objects).
    rows = df.limit(num_rows).collect()

    # Header row built from the DataFrame's column names.
    header_cells = "".join(f"<th>{escape(str(col))}</th>" for col in df.columns)
    parts = [f"<table border='1'><tr>{header_cells}</tr>"]

    # One <tr> per collected row; fragments are joined once at the end
    # instead of growing a string with += inside the loop.
    for row in rows:
        cells = "".join(f"<td>{escape(str(value))}</td>" for value in row)
        parts.append(f"<tr>{cells}</tr>")

    parts.append("</table>")

    # Returned (not printed) so Jupyter renders it as the cell's result.
    return HTML("".join(parts))

# Read table from S3

In [None]:
# S3 location of the ORC table (placeholder path; point it at the real table).
df_path = 's3a://your-bucket-name/path/to/your/table/'

# Load the ORC data found under df_path.
df = spark.read.load(df_path, format="orc")

# Register the DataFrame as the temp view "your_table" for Spark SQL queries.
df.createOrReplaceTempView("your_table")

# Preview the first 5 rows as an HTML table.
better_show(df, 5)

# Write table to S3

In [None]:
# NOTE: `final_df` is not defined in the visible part of this notebook; assign
# it before running this cell. The write targets `df_path` in 'overwrite' mode,
# i.e. the same location the table was read from, so run it deliberately.
final_df = final_df.repartition('partition_id')  # write one file per partition

# The chain is wrapped in parentheses rather than using backslash
# continuations: a comment after a trailing backslash is a SyntaxError, and
# parentheses allow per-line comments.
(
    final_df.write
    .format('orc')  # change this depending on your format
    .mode('overwrite')
    .partitionBy('partition_id')
    .save(df_path)
)