# Processing large datasets with Apache Spark and Amazon SageMaker

***This notebook runs on the `Python 3` kernel on an `ml.r5.xlarge` instance***.

Amazon SageMaker Processing jobs are used to analyze data and evaluate machine learning models on Amazon SageMaker. With Processing, you can use a simplified, managed experience on SageMaker to run your data processing workloads, such as feature engineering, data validation, model evaluation, and model interpretation. You can also use the Amazon SageMaker Processing APIs during the experimentation phase and after the code is deployed in production to evaluate performance.


![](https://docs.aws.amazon.com/images/sagemaker/latest/dg/images/Processing-1.png)

The preceding diagram shows how Amazon SageMaker spins up a Processing job. Amazon SageMaker takes your script, copies your data from Amazon Simple Storage Service (Amazon S3), and then pulls a processing container. The processing container image can either be an Amazon SageMaker built-in image or a custom image that you provide. The underlying infrastructure for a Processing job is fully managed by Amazon SageMaker. Cluster resources are provisioned for the duration of your job, and cleaned up when a job completes. The output of the Processing job is stored in the Amazon S3 bucket you specified.

## Our workflow for processing large amounts of data with SageMaker

We can divide our workflow into two steps:

1. Work with a small subset of the data using Spark in local mode in a SageMaker Studio notebook.

1. Once the code works on the small subset, provide the same code (as a Python script rather than a series of interactive steps) to SageMaker Processing, which launches a Spark cluster, runs our code, and terminates the cluster.

## In this notebook...

We will process the [Common Crawl](https://registry.opendata.aws/commoncrawl/) dataset made available by AWS on S3. The Common Crawl dataset is a corpus of web crawl data composed of over 50 billion web pages.

1. We will first read a small subset of the data locally: the Common Crawl index file `s3://commoncrawl/cc-index/table/cc-main/warc/crawl=CC-MAIN-2023-06/subset=warc/part-00260-b5ddf469-bf28-43c4-9c36-5b5ccc3b2bf1.c000.gz.parquet`. We will then run some analytics on it, such as counting the rows, counting website hits per hour, finding the top domains by page count, and finding the most common page languages. This file holds about 9 million rows (8,974,979), which is small enough to process with Spark in local mode in this notebook.

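1. We will then repeat the same operations on all files in the `s3://commoncrawl/cc-index/table/cc-main/warc/crawl=CC-MAIN-2023-06/subset=warc/` prefix, which holds about 3 billion rows. That is too much data for a single `ml.t3.large` instance (2 vCPUs, 8 GiB RAM), so we run it as a SageMaker Processing job on a Spark cluster. The job at the end of this notebook uses four `ml.m5.2xlarge` instances (8 vCPUs, 32 GiB RAM each), and its input path uses the wildcard `crawl=CC-MAIN-2023-*`, which matches every 2023 crawl rather than only `CC-MAIN-2023-06`.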
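As a rough sizing check, the total vCPUs and memory of a Processing cluster are the per-instance figures multiplied by the instance count. The sketch below assumes the per-instance specs in the hypothetical `INSTANCE_SPECS` table (covering only the instance types mentioned in this notebook); in practice Spark gets somewhat less memory than this because of OS and executor overhead.

```python
# Hypothetical lookup table: (vCPUs, memory in GiB) per instance, for the
# instance types mentioned in this notebook.
INSTANCE_SPECS = {
    "ml.t3.large": (2, 8),
    "ml.r5.xlarge": (4, 32),
    "ml.m5.xlarge": (4, 16),
    "ml.m5.2xlarge": (8, 32),
}


def cluster_capacity(instance_type, instance_count):
    """Return (total vCPUs, total memory in GiB) for a cluster of identical instances."""
    vcpus, memory_gib = INSTANCE_SPECS[instance_type]
    return vcpus * instance_count, memory_gib * instance_count


print(cluster_capacity("ml.m5.xlarge", 2))   # (8, 32)
print(cluster_capacity("ml.m5.2xlarge", 4))  # (32, 128)
```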

## Setup
PySpark needs a Java installation to run. The easiest way to get one is to install the OpenJDK package with conda, which also sets the proper paths.

In [1]:
# Setup - Run only once per Kernel App
%conda install https://anaconda.org/conda-forge/openjdk/11.0.1/download/linux-64/openjdk-11.0.1-hacce0ff_1021.tar.bz2

# install PySpark
%pip install pyspark==3.4.0

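# restart the kernel so that the new packages and Java paths are picked up
# (if the kernel does not restart automatically, restart it from the Kernel menu)
from IPython.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")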

Retrieving notices: ...working... done

Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: done

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
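PySpark starts a JVM through Spark's launcher script, which uses `$JAVA_HOME/bin/java` when `JAVA_HOME` is set and otherwise looks for `java` on the `PATH`. After the kernel restart, a small sketch that mirrors that lookup order can confirm that the conda-installed JDK is visible to the kernel. `find_java` is a hypothetical helper, not part of PySpark.

```python
import os
import shutil


def find_java(environ=None):
    """Return the java executable Spark's launcher would use, or None if none is found.

    Hypothetical helper mirroring the launcher's lookup order: $JAVA_HOME/bin/java
    if JAVA_HOME is set, otherwise the first `java` found on PATH.
    """
    environ = os.environ if environ is None else environ
    java_home = environ.get("JAVA_HOME")
    if java_home:
        return os.path.join(java_home, "bin", "java")
    return shutil.which("java", path=environ.get("PATH"))


# Prints the path of the java executable, or None if Java is not found
print(find_java())
```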


## Utilize S3 Data within local PySpark
* By adding the `hadoop-aws` package to our Spark config, we can access S3 datasets using the `s3a://` file prefix.
* Since we have already authenticated ourselves to SageMaker Studio, we can use our assumed SageMaker execution role for all S3 reads and writes by setting the credential provider to `ContainerCredentialsProvider`.

In [2]:
# Import pyspark and build Spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("PySparkApp")
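    # download hadoop-aws (the S3A connector) and its AWS SDK dependency from Maven Central
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
    # authenticate S3A requests with the Studio execution role's credentials
    .config(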
        "fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)

print(spark.version)



:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/sagemaker-user/.ivy2/cache
The jars for the packages stored in: /home/sagemaker-user/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-664ac2b9-8a1c-4f4e-a249-676522b243ac;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.2.2 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.563 in central
:: resolution report :: resolve 221ms :: artifacts dl 7ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.11.563 from central in [default]
	org.apache.hadoop#hadoop-aws;3.2.2 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
	----------------

3.4.0
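The session above sets `fs.s3a.aws.credentials.provider` without a prefix. Spark's documented way to pass arbitrary Hadoop properties is to prefix them with `spark.hadoop.`, which Spark copies into the Hadoop configuration. The sketch below builds that configuration as a plain dictionary; the helper name is hypothetical and the default `hadoop-aws` version is the one used above. As an assumption worth checking, `hadoop-aws` generally has to match the Hadoop version bundled with PySpark (PySpark 3.4.0 is built against Hadoop 3.3.4); `spark._jvm.org.apache.hadoop.util.VersionInfo.getVersion()` reports the bundled version.

```python
def s3a_spark_conf(hadoop_aws_version="3.2.2"):
    """Hypothetical helper: Spark properties for s3a access with container credentials."""
    return {
        # Maven coordinates resolved when the session starts
        "spark.jars.packages": f"org.apache.hadoop:hadoop-aws:{hadoop_aws_version}",
        # properties prefixed with "spark.hadoop." are copied into the Hadoop configuration
        "spark.hadoop.fs.s3a.aws.credentials.provider": (
            "com.amazonaws.auth.ContainerCredentialsProvider"
        ),
    }


# Usage with PySpark (not executed here):
#   builder = SparkSession.builder.appName("PySparkApp")
#   for key, value in s3a_spark_conf().items():
#       builder = builder.config(key, value)
#   spark = builder.getOrCreate()
print(s3a_spark_conf())
```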


### Reading data into a Spark DataFrame
Note that we use the S3A connector (`s3a://` URIs; read more [here](https://aws.amazon.com/blogs/opensource/community-collaboration-the-s3a-story)). S3A enables Hadoop, and therefore Spark, to read and write Amazon S3 objects directly.
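Open-source Hadoop 3 ships only the `s3a` connector, so a local Spark session does not understand `s3://` URIs out of the box, while the SageMaker Spark container used later in this notebook reads `s3://` paths. A tiny hypothetical helper for rewriting a URI to the `s3a://` convention:

```python
def to_s3a(uri):
    """Rewrite an s3://, s3n:// or s3a:// URI to use the s3a:// scheme."""
    scheme, sep, rest = uri.partition("://")
    if not sep or scheme not in ("s3", "s3n", "s3a"):
        raise ValueError(f"not an S3 URI: {uri!r}")
    return "s3a://" + rest


print(to_s3a("s3://commoncrawl/cc-index/table/cc-main/warc/"))
# s3a://commoncrawl/cc-index/table/cc-main/warc/
```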

In [3]:
%%time
cc_index_path = "s3a://commoncrawl/cc-index/table/cc-main/warc/crawl=CC-MAIN-2023-06/subset=warc/part-00260-b5ddf469-bf28-43c4-9c36-5b5ccc3b2bf1.c000.gz.parquet"
# Parquet files carry their own schema, so no header option is needed
cc_index = spark.read.parquet(cc_index_path)
cc_index.show()

24/10/18 05:19:40 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
24/10/18 05:19:45 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.

+--------------------+--------------------+--------------------+------------+----------------------+----------------------+----------------------+----------------------+------------------------+--------------------------+-----------------------+-----------------------+----------------------+------------+--------+--------------------+---------+-------------------+------------+--------------+--------------------+-----------------+---------------------+---------------+-----------------+-----------------+--------------------+------------------+------------------+----------------+
|         url_surtkey|                 url|       url_host_name|url_host_tld|url_host_2nd_last_part|url_host_3rd_last_part|url_host_4th_last_part|url_host_5th_last_part|url_host_registry_suffix|url_host_registered_domain|url_host_private_suffix|url_host_private_domain|url_host_name_reversed|url_protocol|url_port|            url_path|url_query|         fetch_time|fetch_status|fetch_redirect|      content_digest|con


In [4]:
cc_index.printSchema()

root
 |-- url_surtkey: string (nullable = true)
 |-- url: string (nullable = true)
 |-- url_host_name: string (nullable = true)
 |-- url_host_tld: string (nullable = true)
 |-- url_host_2nd_last_part: string (nullable = true)
 |-- url_host_3rd_last_part: string (nullable = true)
 |-- url_host_4th_last_part: string (nullable = true)
 |-- url_host_5th_last_part: string (nullable = true)
 |-- url_host_registry_suffix: string (nullable = true)
 |-- url_host_registered_domain: string (nullable = true)
 |-- url_host_private_suffix: string (nullable = true)
 |-- url_host_private_domain: string (nullable = true)
 |-- url_host_name_reversed: string (nullable = true)
 |-- url_protocol: string (nullable = true)
 |-- url_port: integer (nullable = true)
 |-- url_path: string (nullable = true)
 |-- url_query: string (nullable = true)
 |-- fetch_time: timestamp (nullable = true)
 |-- fetch_status: short (nullable = true)
 |-- fetch_redirect: string (nullable = true)
 |-- content_digest: string (nulla

### Analytics operations

Let us now do a few analytics operations locally in this notebook.

DataFrame shape (rows x columns)

In [5]:
%%time
print(f"shape of the dataframe is {cc_index.count():,}x{len(cc_index.columns)}")



shape of the dataframe is 8,974,979x30
CPU times: user 3.43 ms, sys: 3.96 ms, total: 7.39 ms
Wall time: 1.96 s



Select a few columns of interest and filter a few rows of interest. In this example let us filter all rows where `url_path` contains the string `fortnite`.
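In Spark SQL's `LIKE`, `%` matches any sequence of characters (including none) and `_` matches exactly one character; matching is case-sensitive. The pure-Python sketch below models only those two wildcards (it ignores the escape character) to show which paths `like('%fortnite%')` keeps. For a plain substring test, `cc_index.url_path.contains(pattern)` expresses the same filter without wildcards, which avoids surprises if the pattern itself contains `_` or `%`.

```python
import re


def sql_like(value, pattern):
    """Simplified SQL LIKE: '%' matches any run of characters, '_' exactly one.

    Escape sequences are not supported in this sketch.
    """
    regex = "".join(
        ".*" if ch == "%" else "." if ch == "_" else re.escape(ch) for ch in pattern
    )
    return re.fullmatch(regex, value, flags=re.DOTALL) is not None


paths = [
    "/2020/12/fortnite-wallpaper-hd.html",
    "/2020/12/Fortnite-wallpaper-hd.html",  # capital F: no match, LIKE is case-sensitive
    "/about.html",
]
print([p for p in paths if sql_like(p, "%fortnite%")])
# ['/2020/12/fortnite-wallpaper-hd.html']
```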

In [6]:
%%time
pattern = 'fortnite'
cc_index_filtered = cc_index.select("url_protocol", "url_host_tld", "fetch_time", "fetch_status", "content_languages", "url_host_registered_domain", "url_path") \
                            .filter(cc_index.url_path.like(f'%{pattern}%'))


print(f"number of rows where url_path contains \"{pattern}\" is {cc_index_filtered.count():,}")
cc_index_filtered.head(5)


number of rows where url_path contains "fortnite" is 252



CPU times: user 14.4 ms, sys: 1.94 ms, total: 16.4 ms
Wall time: 7.23 s



[Row(url_protocol='https', url_host_tld='org', fetch_time=datetime.datetime(2023, 1, 28, 13, 15, 59), fetch_status=200, content_languages='ind', url_host_registered_domain='eu.org', url_path='/2022/03/22/epic-games-sumbang-hasil-penjualan-game-fortnite-ke-ukraina/'),
 Row(url_protocol='https', url_host_tld='org', fetch_time=datetime.datetime(2023, 2, 3, 10, 36, 32), fetch_status=200, content_languages='ind,eng', url_host_registered_domain='eu.org', url_path='/2021/03/sekarang-game-fortnite-bisa-dimainkan.html'),
 Row(url_protocol='https', url_host_tld='org', fetch_time=datetime.datetime(2023, 2, 1, 18, 48, 34), fetch_status=200, content_languages='eng', url_host_registered_domain='eu.org', url_path='/2020/12/fortnite-wallpaper-hd.html'),
 Row(url_protocol='https', url_host_tld='org', fetch_time=datetime.datetime(2023, 2, 1, 19, 33, 51), fetch_status=200, content_languages='eng', url_host_registered_domain='eu.org', url_path='/2020/12/fortnite-wallpaper-iphone.html'),
 Row(url_protocol=

Count pages by `content_languages` and sort counts in descending order.
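`groupBy(...).count().orderBy(col("count").desc())` is the distributed form of a frequency count. On a small in-memory sample, the same result can be sketched with `collections.Counter`; like Spark's `groupBy`, it keeps `None` (SQL `null`) as a group of its own, which is why `null` appears as a language in the output of the next cell. The sample values are hypothetical.

```python
from collections import Counter

# A small hypothetical sample of content_languages values (None stands for null)
sample = ["eng", "fra", "eng", None, "eng", "fra"]

# Same idea as groupBy("content_languages").count().orderBy(col("count").desc())
language_counts = Counter(sample).most_common()
print(language_counts)
# [('eng', 3), ('fra', 2), (None, 1)]
```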

In [7]:
%%time
from pyspark.sql.functions import col
cc_index.groupBy("content_languages").count().orderBy(col("count").desc()).show()



+-----------------+-------+
|content_languages|  count|
+-----------------+-------+
|              eng|5738808|
|             null| 375493|
|              fra| 374666|
|          fra,eng| 244200|
|              spa| 214516|
|          spa,eng| 118904|
|              ita|  93876|
|              deu|  89955|
|              rus|  75802|
|          ita,eng|  74727|
|          rus,eng|  67568|
|          deu,eng|  64660|
|          eng,fra|  64135|
|              pol|  60654|
|              zho|  49808|
|          eng,spa|  46708|
|          pol,eng|  39836|
|              por|  33688|
|              tur|  33323|
|          zho,eng|  33005|
+-----------------+-------+
only showing top 20 rows

CPU times: user 3.63 ms, sys: 4.27 ms, total: 7.9 ms
Wall time: 3.39 s



Count pages by `url_host_registered_domain` and `content_languages`.
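Sorting only by `count` leaves the order of rows with equal counts unspecified, so repeated runs may list tied groups differently. Adding the grouping columns as secondary sort keys, for example `.orderBy(col("count").desc(), "url_host_registered_domain", "content_languages")`, makes the order deterministic. A pure-Python sketch of the same two-level sort on a hypothetical sample:

```python
from collections import Counter

# Hypothetical (url_host_registered_domain, content_languages) pairs
rows = [
    ("fao.org", "eng"),
    ("eu.org", "pol"),
    ("eu.org", "eng"),
    ("fao.org", "eng"),
    ("eu.org", "eng"),
]

counts = Counter(rows)
# Count descending, then domain and language ascending; None (null) sorts as "".
ranked = sorted(
    counts.items(),
    key=lambda kv: (-kv[1], tuple("" if v is None else v for v in kv[0])),
)
for (domain, language), n in ranked:
    print(domain, language, n)
# eu.org eng 2
# fao.org eng 2
# eu.org pol 1
```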

In [8]:
cc_index.groupBy(["url_host_registered_domain", "content_languages"]).count().orderBy(col("count").desc()).show()



+--------------------------+-----------------+------+
|url_host_registered_domain|content_languages| count|
+--------------------------+-----------------+------+
|         fedoraproject.org|              eng|223965|
|                   eun.org|              eng|207939|
|               freebsd.org|              eng|159226|
|                   fao.org|              eng| 99724|
|                    eu.org|              eng| 81565|
|           freedesktop.org|              eng| 81419|
|                   ewg.org|              eng| 66536|
|           fieldmuseum.org|              eng| 59673|
|                ffmpeg.org|              eng| 55019|
|          familysearch.org|              eng| 52541|
|                  etsi.org|              eng| 47654|
|                    eu.org|              pol| 46528|
|      filezilla-project...|              eng| 41060|
|            freepascal.org|              eng| 40552|
|         firstinspires.org|              eng| 32360|
|                    eu.org|


What percentage of pages are `text/html`?
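The percentage column is `100 * count / total`. Because `groupBy` puts every row into exactly one group (including the `null` group), the total also equals the sum of the group counts, which would avoid the separate `cc_index.count()` pass in the next cell. A minimal pure-Python sketch, checked against the `text/html` figure from this notebook's output:

```python
def with_percentage(counts, total=None):
    """Append 100 * count / total to each (key, count) pair.

    If total is omitted, the sum of the counts is used, which equals the row
    count when every row falls into exactly one group.
    """
    if total is None:
        total = sum(n for _, n in counts)
    return [(key, n, 100 * n / total) for key, n in counts]


# text/html count and total row count taken from this notebook's output
html = with_percentage([("text/html", 8663871)], total=8974979)[0]
print(round(html[2], 4))  # 96.5336
```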

In [9]:
# total number of rows, used as the denominator
rows = cc_index.count()

# count pages per MIME type, most frequent first
cc_index_mime_type_counts = cc_index.groupBy("content_mime_type").count().orderBy(col("count").desc())

# convert the absolute counts to percentages
cc_index_mime_type_counts \
        .withColumn("percentage", 100*col("count")/rows).show()



+--------------------+-------+--------------------+
|   content_mime_type|  count|          percentage|
+--------------------+-------+--------------------+
|           text/html|8663871|   96.53360748810665|
|     application/pdf| 122576|  1.3657524992537586|
|          text/plain|  41125| 0.45821834234932474|
|application/xhtml...|  24644|  0.2745856007016841|
|application/octet...|  23878|  0.2660507617900833|
|            text/csv|  17476|  0.1947191185628401|
|                 unk|  10333| 0.11513118860779507|
| application/rss+xml|  10031| 0.11176627822750337|
|          image/jpeg|   9205| 0.10256291407478502|
|       text/calendar|   9090|  0.1012815740293097|
|            text/xml|   4780| 0.05325917754236528|
|    application/json|   4747| 0.05289148865975062|
|application/atom+xml|   4008| 0.04465748610665273|
| application/rdf+xml|   2288|0.025493095861282795|
|         text/turtle|   2285| 0.02545966959922692|
|  charset=ISO-8859-1|   1989|0.022161611743046976|
|  applicati


## Process S3 data with SageMaker Processing Job `PySparkProcessor`

We are going to move the above processing code into a Python file and then submit that file to a SageMaker Processing job using [`PySparkProcessor`](https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_processing.html#pysparkprocessor).
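The script receives its inputs as command-line flags: the `arguments` list passed to `spark_processor.run(...)` is handed to the script as its command-line arguments, and `argparse` turns them into attributes such as `args.s3_dataset_path`. A minimal sketch of that round trip; the bucket and prefix values are hypothetical, and `required=True` is an optional change (not in the script) that makes a missing flag fail immediately instead of silently yielding `None`.

```python
import argparse


def parse_job_args(argv):
    """Parse the same flags as ./code/process.py (required=True added in this sketch)."""
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_dataset_path", type=str, required=True, help="Path of dataset in S3")
    parser.add_argument("--s3_output_bucket", type=str, required=True, help="s3 output bucket")
    parser.add_argument("--s3_output_key_prefix", type=str, required=True, help="s3 output key prefix")
    return parser.parse_args(argv)


args = parse_job_args([
    "--s3_dataset_path", "s3://commoncrawl/cc-index/table/cc-main/warc/crawl=CC-MAIN-2023-06/subset=warc/",
    "--s3_output_bucket", "my-example-bucket",   # hypothetical bucket
    "--s3_output_key_prefix", "example-prefix",  # hypothetical prefix
])
print(args.s3_output_bucket, args.s3_output_key_prefix)  # my-example-bucket example-prefix
```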

In [10]:
!mkdir -p ./code

In [11]:
%%writefile ./code/process.py

import argparse
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, date_trunc


def main():
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_dataset_path", type=str, help="Path of dataset in S3")
    parser.add_argument("--s3_output_bucket", type=str, help="s3 output bucket")
    parser.add_argument("--s3_output_key_prefix", type=str, help="s3 output key prefix")
    args = parser.parse_args()

    # Build the Spark session; SageMaker Processing supplies the cluster configuration
    spark = SparkSession.builder.appName("PySparkApp").getOrCreate()
    print(f"spark version = {spark.version}")

    # Read the data from S3 into a DataFrame. Parquet files are self-describing,
    # so Spark takes the column names and types from the files themselves.
    print(f"going to read {args.s3_dataset_path}")
    cc_index = spark.read.parquet(args.s3_dataset_path)
    print("finished reading files...")

    # Total row count, written to S3 as a one-row CSV file
    row_count = cc_index.count()
    print(f"count={row_count}")
    count_df = spark.createDataFrame([("count", row_count)])
    s3_path = "s3://" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix, "count")
    print(f"going to save count to {s3_path}")
    # coalesce(1) so that the result is written as a single CSV file
    count_df.coalesce(1).write.format("csv").option("header", "false").mode("overwrite").save(s3_path)

    # Page counts per registered domain, most frequent first. This includes every
    # domain, which can be a very large number; add .limit(n) to keep only the top n.
    counts_by_domain = cc_index.groupBy("url_host_registered_domain").count().orderBy(col("count").desc())
    s3_path = "s3://" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix, "counts_by_registered_domain")
    print(f"going to save counts by domain to {s3_path}")
    counts_by_domain.coalesce(1).write.format("csv").option("header", "true").mode("overwrite").save(s3_path)

    # Page counts per content language, plus each count as a percentage of all rows
    content_languages_counts = cc_index.groupBy("content_languages").count().orderBy(col("count").desc())
    content_languages_counts_pct = content_languages_counts.withColumn(
        "percentage", 100 * col("count") / row_count
    )
    s3_path = "s3://" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix, "counts_by_content_languages")
    print(f"going to save counts by content languages to {s3_path}")
    content_languages_counts_pct.coalesce(1).write.format("csv").option("header", "true").mode("overwrite").save(s3_path)

    # Hits per hour: truncate fetch_time to the hour and count rows with a non-null url_host_tld
    cc_hits_timeseries = (
        cc_index.select("url_host_tld", "fetch_time")
        .withColumn("fetch_hour", date_trunc("hour", col("fetch_time")))
        .groupBy("fetch_hour")
        .agg(count("url_host_tld").alias("hits"))
        .orderBy(col("fetch_hour").desc())
    )
    s3_path = "s3://" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix, "hits_per_hour_timeseries")
    print(f"going to save counts by hour to {s3_path}")
    cc_hits_timeseries.coalesce(1).write.format("csv").option("header", "true").mode("overwrite").save(s3_path)


if __name__ == "__main__":
    main()

Overwriting ./code/process.py
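The hourly time series in the script relies on `date_trunc("hour", ...)`, which zeroes out the minutes and seconds of each timestamp, and on `count("url_host_tld")`, which counts only non-null values. A pure-Python sketch of the same aggregation, assuming `fetch_time` is never null and using fetch times shown earlier in this notebook (the row with a `None` TLD is hypothetical):

```python
from collections import Counter
from datetime import datetime


def hits_per_hour(rows):
    """rows: (url_host_tld, fetch_time) pairs; fetch_time assumed non-null.

    Returns (hour, hits) pairs, newest hour first.
    """
    hits = Counter(
        fetch_time.replace(minute=0, second=0, microsecond=0)  # like date_trunc("hour", ...)
        for tld, fetch_time in rows
        if tld is not None  # count("url_host_tld") skips nulls
    )
    return sorted(hits.items(), reverse=True)


rows = [
    ("org", datetime(2023, 1, 28, 13, 15, 59)),
    ("org", datetime(2023, 2, 1, 18, 48, 34)),
    ("org", datetime(2023, 2, 1, 19, 33, 51)),
    (None, datetime(2023, 2, 1, 19, 5, 0)),  # hypothetical row with a null TLD
]
for hour, n in hits_per_hour(rows):
    print(hour, n)
# 2023-02-01 19:00:00 1
# 2023-02-01 18:00:00 1
# 2023-01-28 13:00:00 1
```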


Now submit this script as a SageMaker Processing job.

In [14]:
%%time
import sagemaker
from sagemaker.spark.processing import PySparkProcessor

# Set up the PySpark processor to run the job. Note the instance_type and instance_count parameters: SageMaker creates this many instances of this type for the Spark job.
role = sagemaker.get_execution_role()
spark_processor = PySparkProcessor(
    base_job_name="sm-spark",
    framework_version="3.1",
    role=role,
    instance_count=4,
    instance_type="ml.m5.2xlarge",
    max_runtime_in_seconds=3600,
)

# s3 paths
crawl = "crawl=CC-MAIN-2023-*"
s3_dataset_path = f"s3://commoncrawl/cc-index/table/cc-main/warc/{crawl}/subset=warc/*.gz.parquet"
print(f"s3_dataset_path={s3_dataset_path}")
session = sagemaker.Session()
bucket = session.default_bucket()
output_prefix_data = f"{crawl}-data"
output_prefix_logs = f"{crawl}-spark_logs"


# Run the job. The arguments list is passed to the Python script (the Spark code) as its command-line arguments.
spark_processor.run(
    submit_app="./code/process.py",
    arguments=[
        "--s3_dataset_path",
        s3_dataset_path,
        "--s3_output_bucket",
        bucket,
        "--s3_output_key_prefix",
        output_prefix_data,
    ],
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(bucket, output_prefix_logs),
    logs=False,
)

s3_dataset_path=s3://commoncrawl/cc-index/table/cc-main/warc/crawl=CC-MAIN-2023-*/subset=warc/*.gz.parquet
CPU times: user 2.94 s, sys: 274 ms, total: 3.21 s
Wall time: 45min 11s


![Processing job completed](img/sm-processing-job.jpg) 
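Each result is written under `s3://<default bucket>/<output_prefix_data>/`, one folder per output, and `coalesce(1)` means each folder holds a single `part-*.csv` file (Spark typically also writes a `_SUCCESS` marker). Note that the `*` from the `crawl` glob ends up as a literal character in the output prefix. The sketch below lists the expected output folders; the bucket name is hypothetical, and it uses `posixpath.join` instead of the script's `os.path.join` so that separators are always forward slashes regardless of the OS the code runs on.

```python
import posixpath

# Output folder names used by ./code/process.py
OUTPUT_NAMES = (
    "count",
    "counts_by_registered_domain",
    "counts_by_content_languages",
    "hits_per_hour_timeseries",
)


def output_uris(bucket, prefix, names=OUTPUT_NAMES):
    """Return the s3:// URI of each output folder written by the job."""
    return ["s3://" + posixpath.join(bucket, prefix, name) for name in names]


# "my-example-bucket" is hypothetical; the prefix matches output_prefix_data above
for uri in output_uris("my-example-bucket", "crawl=CC-MAIN-2023-*-data"):
    print(uri)
```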