In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import year, month, to_date, date_format, lit
from pyspark import SparkConf
from botocore.exceptions import ClientError
import boto3
import botocore
import os
import awswrangler as wr

In [2]:
# Runtime configuration, read from environment variables.
# A variable that is not set in the environment resolves to None.
database_name = os.getenv("DB_NAME")      # Glue database name
table_name = os.getenv("TB_NAME")         # table name
bucket_name = os.getenv("BUCKET_NAME")    # S3 bucket name
work_group = os.getenv("WORKGROUP")       # workgroup passed to awswrangler
region = os.getenv("AWS_REGION")          # AWS region
role_arn = os.getenv("AWS_ROLE")          # ARN of the IAM role to assume

In [4]:
# Build an STS client from the credentials boto3 resolves by default
# (the "starter" session).
sts_client = boto3.client("sts")

# Ask STS for temporary credentials for the role configured in `role_arn`.
# NOTE(review): the response is stored but no active code visible in this
# notebook reads it afterwards — confirm whether the credentials are meant
# to be used.
assume_role_request = {
    "RoleArn": role_arn,
    "RoleSessionName": "SampleSessionName",
}
assumed_role_response = sts_client.assume_role(**assume_role_request)

# # 3. Get the temporary credentials from the response
# credentials = assumed_role_response['Credentials']

# # 4. Create a new session with the assumed role credentials
# session = boto3.Session(
#     aws_access_key_id=credentials['AccessKeyId'],
#     aws_secret_access_key=credentials['SecretAccessKey'],
#     aws_session_token=credentials['SessionToken']
# )

In [5]:
# Create SparkSession
# The Iceberg warehouse lives in the bucket configured through BUCKET_NAME, so
# the catalog follows the same bucket as the rest of the notebook instead of a
# hard-coded name.
warehouse_uri = f"s3a://{bucket_name}/"

spark = (
    SparkSession.builder
    .appName('sample_spark_app')
    # Register a catalog named "AwsGlueCatalog" that uses the Iceberg Spark
    # catalog with the Glue catalog implementation and S3 file IO.
    .config("spark.sql.catalog.AwsGlueCatalog", "org.apache.iceberg.spark.SparkCatalog")
    .config("spark.sql.catalog.AwsGlueCatalog.catalog-impl", "org.apache.iceberg.aws.glue.GlueCatalog")
    .config("spark.sql.catalog.AwsGlueCatalog.warehouse", warehouse_uri)
    .config("spark.sql.catalog.AwsGlueCatalog.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
    .getOrCreate()
)

# Set log level to WARN
spark.sparkContext.setLogLevel("WARN")


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/home/glue_user/spark/jars/slf4j-reload4j-1.7.36.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/spark/jars/log4j-slf4j-impl-2.17.2.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/aws-glue-libs/jars/slf4j-reload4j-1.7.36.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/aws-glue-libs/jars/log4j-slf4j-impl-2.17.2.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Reload4jLoggerFactory]
25/02/09 18:41:28 INFO SparkContext: Running Spark version 3.3.0-amzn-1
25/02/09 18:41:29 INFO ResourceUtils: No custom resources configured for spark.driver.
25/02/09 18:41:29 INFO SparkContext: Submitted application: sample_spark_app
25/02/09 18:41:29 INFO 

In [6]:
# Load the sample CSV dataset: the first row is treated as the header and
# column types are inferred from the data.
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/tips.csv")
)

In [19]:
# It is good practice to add a 'created_at' column storing the data timestamp.
# In this sample it is also used as the Iceberg partition, in the form month('created_at').
# Create (or replace) the 'created_at' column with the current timestamp.
df = df.withColumn("created_at", F.current_timestamp())

# Just for testing purposes, in case you want to test Iceberg partition behaviour by inserting several months:
# from dateutil.relativedelta import relativedelta
# from datetime import datetime
# df = df.withColumn('created_at', F.lit(datetime.now() + relativedelta(months=-1)).cast('timestamp'))

# Move 'created_at' to the front; every other column keeps its relative order
# and 'created_at' is not selected twice.
remaining_columns = [name for name in df.columns if name != "created_at"]
df = df.select("created_at", *remaining_columns)


In [8]:
# Add a 'vat' column: 21% VAT computed on the 'total_bill' column.
vat_rate = 0.21
df = df.withColumn("vat", F.col("total_bill") * vat_rate)

In [9]:
def create_s3_bucket_if_not_exists(bucket_name, region_name=None):
    """Ensure the S3 bucket ``bucket_name`` exists, creating it when missing.

    Args:
        bucket_name: Name of the bucket to check or create.
        region_name: Optional AWS region. It is used for the S3 client and,
            when the bucket has to be created, as its location constraint.

    Returns:
        True if the bucket already exists or was created successfully,
        False on any error. Errors are printed, never raised.
    """
    try:
        s3 = boto3.client('s3', region_name=region_name)

        # Check if the bucket exists
        try:
            s3.head_bucket(Bucket=bucket_name)
        except ClientError as head_error:
            error_code = head_error.response['Error']['Code']
            if error_code != '404':
                # Anything other than "not found" (e.g. access denied) is
                # reported instead of attempting a create.
                print(f"An unexpected error occurred checking or creating bucket: {head_error}")
                return False
        else:
            print(f"Bucket '{bucket_name}' already exists.")
            return True

        print(f"Bucket '{bucket_name}' does not exist. Creating...")
        create_kwargs = {'Bucket': bucket_name}
        # us-east-1 is the default location and S3 rejects it as an explicit
        # LocationConstraint, so the constraint is only sent for other regions.
        if region_name and region_name != 'us-east-1':
            create_kwargs['CreateBucketConfiguration'] = {'LocationConstraint': region_name}

        try:
            s3.create_bucket(**create_kwargs)
        except Exception as create_error:
            location = f" in region {region_name}" if region_name else ""
            print(f"Error creating bucket{location}: {create_error}")
            return False

        print(f"Bucket '{bucket_name}' created successfully.")
        return True
    except Exception as error:
        # Catch-all so the caller always gets a boolean back.
        print(f"A general error occurred: {error}")
        return False


# Make sure the configured bucket exists (it is created in `region` when missing).
create_s3_bucket_if_not_exists(bucket_name, region_name=region)

Bucket 'bd-datawarehouse' already exists.


True

In [10]:
# S3 client shared by the path helpers defined in this cell
s3 = boto3.client("s3")

# Key prefix, relative to the bucket root, for this database/table
s3_path = f"{database_name}/{table_name}/"

def s3_path_exists(bucket, path):
    """Return True when at least one object key in ``bucket`` starts with ``path``.

    Args:
        bucket: Name of the S3 bucket to inspect.
        path: Key prefix to look for.

    Returns:
        bool: True if the listing response contains a "Contents" entry,
        False otherwise.
    """
    # A single matching key is enough to prove the prefix is in use, so cap
    # the listing at one result instead of fetching a full page of keys.
    response = s3.list_objects_v2(Bucket=bucket, Prefix=path, MaxKeys=1)
    return "Contents" in response

def create_s3_path(bucket, path):
    """Make ``path`` show up in ``bucket`` by uploading a zero-byte placeholder.

    Nothing is uploaded when objects already exist under the prefix. In both
    cases a status line is printed.
    """
    if s3_path_exists(bucket, path):
        print(f"Path already exists: s3://{bucket}/{path}")
        return

    # S3 has no real directories; an empty object under the prefix stands in for one.
    placeholder_key = f"{path}placeholder.txt"
    s3.put_object(Bucket=bucket, Key=placeholder_key, Body=b"")
    print(f"Created path: s3://{bucket}/{path}")

# Ensure the database/table prefix exists in the bucket
create_s3_path(bucket=bucket_name, path=s3_path)

Path already exists: s3://bd-datawarehouse/database_name/table_name/


In [11]:
# Create database if not exists
# Use the region configured through AWS_REGION; fall back to the previous
# hard-coded default when the variable is not set.
glue_client = boto3.client('glue', region_name=region or "eu-west-1")

database_location = f"s3://{bucket_name}/{database_name}"

# Check if the database already exists.
# get_databases() returns results in pages, so walk every page; looking only
# at the first response could miss an existing database and make
# create_database fail.
database_pages = glue_client.get_paginator('get_databases').paginate()
existing_databases = [
    db['Name']
    for page in database_pages
    for db in page['DatabaseList']
]
if database_name not in existing_databases:
    glue_client.create_database(DatabaseInput={'Name': database_name, 'LocationUri': database_location})
    print(f"Database {database_name} created successfully.")
else:
    print(f"Database {database_name} already exists.")

Database database_name created successfully.


In [12]:
# There are two options for partitioning by month:
#   1. Create a regular (non-hidden) column and use it as the partition field:
#      df = df.withColumn("yearmonth", date_format(F.col("created_at"), "yyyyMM").cast("int")) # or string
#   2. Use a hidden partition through a transform, e.g. "PARTITIONED BY (month(created_at))".
#      See https://medium.com/@life-is-short-so-enjoy-it/aws-athena-iceberg-experiment-dropping-partitions-month-b5074e56c911
#      to understand what the month partition values mean.
# Iceberg table properties can be read using: "describe formatted <database_name>.<table_name>;"

# Create the Iceberg table, partitioned by month of created_at, if it does not exist yet.
# Only 'created_at' is declared in this statement.
create_table_sql = f"""    
    CREATE TABLE IF NOT EXISTS AwsGlueCatalog.{database_name}.{table_name} (
        created_at timestamp
    )
    PARTITIONED BY (month(created_at))
    LOCATION 's3://{bucket_name}/{database_name}/{table_name}'
    TBLPROPERTIES (
        'table_type' = 'ICEBERG',
        'format'='parquet',        
        'write_compression'='ZSTD',        
        'optimize_rewrite_data_file_threshold'='5',
        'optimize_rewrite_delete_file_threshold'='2',
        'vacuum_min_snapshots_to_keep'='5'
    )
"""
spark.sql(create_table_sql)


SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.


DataFrame[]

In [20]:
# Write to the table using awswrangler.athena.to_iceberg -> https://aws-sdk-pandas.readthedocs.io/en/3.2.1/stubs/awswrangler.athena.to_iceberg.html
# to_iceberg is given a pandas DataFrame, so the Spark DataFrame is converted first.
pandas_df = df.toPandas()

# Temporary S3 location handed to to_iceberg via `temp_path`
iceberg_temp_path = f"s3://{bucket_name}/{database_name}/tmp_table_{table_name}/"

wr.athena.to_iceberg(
    df=pandas_df,
    database=database_name,
    table=table_name,
    temp_path=iceberg_temp_path,
    workgroup=work_group,
    keep_files=False,
    schema_evolution=True,
    fill_missing_columns_in_df=True,
)

  series = series.astype(t, copy=False)


In [None]:
# Show the table's extended description (up to 100 rows, values not truncated).
# A dedicated name is used so the data DataFrame `df` is not overwritten;
# otherwise re-running the write cell would send these description rows to
# the table.
table_description = spark.sql(f"DESCRIBE TABLE EXTENDED AwsGlueCatalog.{database_name}.{table_name}")
table_description.show(100, False)

In [None]:
# Stop the SparkSession held in `spark`; the remaining cell below contains only commented-out code.
spark.stop()

In [None]:
# from pyspark.sql import SparkSession
# from pyspark.sql.functions import *  # Import Spark functions

# # Initialize a SparkSession
# spark = SparkSession.builder.appName("IcebergExample").getOrCreate()

# # Configure Iceberg (replace with your actual configuration)
# spark.conf.set("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
# spark.conf.set("spark.sql.catalog.iceberg", "org.apache.iceberg.spark.SparkCatalog")
# spark.conf.set("spark.sql.catalog.iceberg.type", "hadoop") # or hive
# spark.conf.set("spark.sql.catalog.iceberg.warehouse", "s3a://your-iceberg-warehouse") # or hdfs://path

# # Create a sample DataFrame
# data = [("Alice", 25), ("Bob", 30), ("Charlie", 28)]
# df = spark.createDataFrame(data, ["name", "age"])

# # Write to Iceberg (create a new table or overwrite if it exists)
# df.write.format("iceberg").mode("overwrite").saveAsTable("iceberg.your_catalog.your_table") # iceberg.your_catalog is required. your_table is the table name.

# # Read from Iceberg
# iceberg_df = spark.read.format("iceberg").table("iceberg.your_catalog.your_table") # Read from iceberg.your_catalog.your_table

# iceberg_df.show()

# # Example Iceberg queries
# # You can use SQL queries to interact with Iceberg tables
# spark.sql("SELECT * FROM iceberg.your_catalog.your_table WHERE age > 25").show()

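# # Example of updating data in an Iceberg table
# # NOTE: DataFrameWriter.mode() has no "merge" save mode (valid modes are append, overwrite, ignore and error/errorifexists); use a Spark SQL `MERGE INTO` statement to upsert into an Iceberg table instead.
# updatesDF = spark.createDataFrame([("Alice", 26)], ["name", "age"]) # Create a dataframe with updates
# updatesDF.write.format("iceberg").mode("merge").option("mergeSchema", "true").saveAsTable("iceberg.your_catalog.your_table") # Merge the updates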

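# # Example of deleting data from an Iceberg table
# # NOTE: DataFrameWriter.mode() has no "delete" save mode either; use a Spark SQL `DELETE FROM ... WHERE ...` statement to delete rows instead.
# df.filter("age > 27").write.format("iceberg").mode("delete").saveAsTable("iceberg.your_catalog.your_table")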

# # Show the updated table
# iceberg_df = spark.read.format("iceberg").table("iceberg.your_catalog.your_table") # Read from iceberg.your_catalog.your_table
# iceberg_df.show()

# # Stop the SparkSession
# spark.stop()