# Data Preparation for LLM Pre-training
## Notebook 2: Data Preprocessing

In this notebook we will set up a pipeline that preprocesses crawled web data for LLM pre-training. The processing steps are built with Hugging Face's [datatrove](https://github.com/huggingface/datatrove) library and orchestrated with [SageMaker Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/define-pipeline.html). We use the [step decorator](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-step-decorator-create-pipeline.html) to turn plain Python functions into pipeline steps.

<div style="border: 1px solid black; padding: 10px; background-color: #ffffcc; color: black;">
<strong>Note:</strong> Make sure to fully run the first notebook to obtain the data needed for this notebook.
</div>

In [None]:
%pip install -Uqq "datatrove[all]"
%pip install -Uqq nltk
%pip install -Uqq sagemaker

In [None]:
import sagemaker
from sagemaker.workflow.function_step import step
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.parameters import ParameterString
from sagemaker import image_uris
import json
import os
from pathlib import Path


role = sagemaker.get_execution_role()  # IAM execution role used by the pipeline and its steps
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
region = sess.boto_region_name  # region of the current SageMaker Studio environment
bucket = sess.default_bucket()  # default SageMaker S3 bucket
account_id = sess.account_id()

# get the URI of the container image that will run the pipeline steps;
# use the PyTorch training image as it ships with the most libraries preinstalled,
# and request its CPU variant to match the ml.m5.2xlarge instances the steps run on
image_uri = image_uris.retrieve(
    framework="pytorch",
    image_scope="training",
    region=region,
    version="2.3",
    py_version="py311",
    instance_type="ml.m5.2xlarge",
)

# read the S3 path of the raw data written by the previous notebook
s3_data_path = json.loads(Path("s3_path.json").read_text())["s3_path"]

### Configure pipeline using @step Decorator
With the `@step` decorator, each pipeline step is an ordinary Python function that we can develop and test locally. Decorating a function with `@step` makes it run as a SageMaker job, and passing the return value of one decorated function into another defines the pipeline DAG.
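
To make the mechanism concrete, here is a toy, pure-Python analogue of the idea. It is not the SageMaker implementation (in the real SDK, calling a `@step`-decorated function returns a `DelayedReturn` object); `toy_step`, `Delayed`, and `run` are hypothetical names used only for illustration. Calling a decorated function records a node instead of executing it, passing one node into another records a dependency edge, and only running the final node walks the graph in dependency order.

```python
from typing import Any, Callable, Dict, List


class Delayed:
    """Hypothetical placeholder for a step's future return value (toy, not the SDK)."""

    def __init__(self, name: str, func: Callable[..., Any], args: tuple):
        self.name = name
        self.func = func
        self.args = args


def toy_step(name: str):
    """Hypothetical stand-in for @step: calling the function builds a node, it does not run it."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Delayed]:
        def wrapper(*args: Any) -> Delayed:
            return Delayed(name, func, args)

        wrapper.local = func  # keep the plain function around for local testing
        return wrapper

    return decorator


def run(node: Delayed, cache: Dict[str, Any], order: List[str]) -> Any:
    """Execute a node after recursively executing the nodes it depends on."""
    if node.name in cache:
        return cache[node.name]
    # any argument that is itself a Delayed is an upstream dependency (a DAG edge)
    resolved = [run(a, cache, order) if isinstance(a, Delayed) else a for a in node.args]
    order.append(node.name)
    cache[node.name] = node.func(*resolved)
    return cache[node.name]


@toy_step("initial_filter")
def toy_initial_filter(prefix: str) -> str:
    return f"{prefix}/output"


@toy_step("deduplication")
def toy_deduplicate(path: str) -> str:
    return f"{path}/deduplication"


final = toy_deduplicate(toy_initial_filter("s3://bucket/data"))  # nothing runs yet
order: List[str] = []
result = run(final, {}, order)
print(order)   # ['initial_filter', 'deduplication']
print(result)  # s3://bucket/data/output/deduplication
print(toy_initial_filter.local("x"))  # x/output (plain local call, no DAG)
```

Only the last node is handed to `run`, yet both functions execute in the right order; this mirrors why the real pipeline below only needs the final step in its `steps` list.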

In [None]:
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.extractors import Trafilatura
from datatrove.pipeline.readers import WarcReader
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.pipeline.filters import (
    GopherQualityFilter,
    GopherRepetitionFilter,
    LanguageFilter,
    URLFilter,
)

import nltk

The first step of the pipeline is adapted from the datatrove example [here](https://github.com/huggingface/datatrove/blob/main/examples/process_common_crawl_dump.py) for processing crawled data. It reads the WARC files, filters out unwanted URLs (there are none in our example), extracts the text from the HTML using the [trafilatura library](https://trafilatura.readthedocs.io/en/latest/), keeps only English documents, applies repetition filters to remove documents with repeated content, applies additional quality filters (minimum number of words, average word length, ratio of alphabetic words, etc.), and finally saves the processed data to S3.
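
The quality filter boils down to a handful of document-level heuristics. The sketch below is a simplified stand-in for `GopherQualityFilter`, not its actual code: it uses whitespace tokenization, and the thresholds (50 to 100,000 words, average word length between 3 and 10 characters, at least 80% of words containing a letter) are assumptions modelled on the Gopher heuristics.

```python
import re
from typing import List, Tuple

# Assumed thresholds, modelled on the Gopher quality heuristics;
# datatrove's GopherQualityFilter exposes similar, configurable parameters.
MIN_WORDS, MAX_WORDS = 50, 100_000
MIN_AVG_WORD_LEN, MAX_AVG_WORD_LEN = 3, 10
MIN_ALPHA_WORD_RATIO = 0.8  # share of words containing at least one letter


def quality_check(text: str) -> Tuple[bool, List[str]]:
    """Return (keep, reasons) for a document using simplified Gopher-style rules."""
    words = text.split()  # whitespace tokenization (a simplification)
    reasons: List[str] = []
    n = len(words)
    if not (MIN_WORDS <= n <= MAX_WORDS):
        reasons.append("word_count")
    if n:
        avg_len = sum(len(w) for w in words) / n
        if not (MIN_AVG_WORD_LEN <= avg_len <= MAX_AVG_WORD_LEN):
            reasons.append("avg_word_length")
        # [^\W\d_] matches any Unicode letter
        alpha_ratio = sum(1 for w in words if re.search(r"[^\W\d_]", w)) / n
        if alpha_ratio < MIN_ALPHA_WORD_RATIO:
            reasons.append("alpha_ratio")
    return (not reasons, reasons)


good = "The quick brown fox jumps over the lazy dog. " * 10
bad = "123 456 789 " * 30
short = "Too short."
print(quality_check(good))   # (True, [])
print(quality_check(bad))    # (False, ['alpha_ratio'])
print(quality_check(short))  # (False, ['word_count'])
```

Returning the failing rule names, rather than a bare boolean, is the same idea as the `exclusion_writer` folders in the step: it makes it easy to inspect why documents were dropped.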

In [4]:
@step(
    name="initial_filter",
    keep_alive_period_in_seconds=300,
    image_uri=image_uri,
    instance_type="ml.m5.2xlarge",        # instance type for the pipeline
    instance_count=1,
    dependencies="requirements.txt",     # dependencies for the pipeline
)
def initial_filter(dataset_name: str, s3_data_path: str, s3_output_path: str):
    # NLTK tokenizer data used by the word tokenization inside the text filters
    nltk.download("punkt_tab")

    # LocalPipelineExecutor runs the pipeline on a single machine (here, the step's
    # SageMaker instance); datatrove also provides a SLURM executor for clusters
    executor = LocalPipelineExecutor(
        pipeline=[
            WarcReader(
                s3_data_path,
                glob_pattern="*warc.gz",  # we want the warc files
                default_metadata={"name": dataset_name},
            ),
            URLFilter(
                exclusion_writer=JsonlWriter(
                    f"{s3_output_path}/removed/url/{dataset_name}"
                )
            ),
            Trafilatura(favour_precision=True),
            LanguageFilter(
                exclusion_writer=JsonlWriter(
                    f"{s3_output_path}/non_english/",
                    output_filename="${language}/" + dataset_name + "/${rank}.jsonl.gz",
                )
            ),
            GopherRepetitionFilter(
                exclusion_writer=JsonlWriter(
                    f"{s3_output_path}/removed/repetitive/{dataset_name}"
                )
            ),
            GopherQualityFilter(
                exclusion_writer=JsonlWriter(
                    f"{s3_output_path}/removed/quality/{dataset_name}"
                )
            ),
            JsonlWriter(f"{s3_output_path}/output/{dataset_name}"),
        ],
        tasks=4,
        logging_dir=f"{s3_output_path}/logs/base_processing/{dataset_name}",
    )

    executor.run()

    final_output_path = f"{s3_output_path}/output/{dataset_name}"
    
    return final_output_path
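
`GopherRepetitionFilter` drops documents dominated by repeated lines or n-grams. The following is a simplified sketch of two such checks (the share of duplicated lines and the character share of the most frequent word 2-gram). The thresholds are assumptions modelled on the Gopher rules, and the real filter also checks other n-gram sizes and paragraph-level duplication.

```python
from collections import Counter
from typing import List

# Assumed thresholds, modelled on the Gopher repetition rules;
# datatrove's GopherRepetitionFilter makes these configurable and checks more cases.
MAX_DUP_LINE_FRAC = 0.3
MAX_TOP_2GRAM_CHAR_FRAC = 0.2


def duplicate_line_fraction(text: str) -> float:
    """Fraction of non-empty lines that repeat an earlier line."""
    lines = [line for line in text.split("\n") if line.strip()]
    seen = set()
    duplicates = 0
    for line in lines:
        if line in seen:
            duplicates += 1
        else:
            seen.add(line)
    return duplicates / len(lines) if lines else 0.0


def top_ngram_char_fraction(text: str, n: int = 2) -> float:
    """Characters covered by the most frequent word n-gram, relative to document length."""
    words = text.split()
    ngrams: List[str] = [" ".join(words[i : i + n]) for i in range(len(words) - n + 1)]
    if not ngrams or not text:
        return 0.0
    top, count = Counter(ngrams).most_common(1)[0]
    return len(top) * count / len(text)


def is_repetitive(text: str) -> bool:
    return (
        duplicate_line_fraction(text) > MAX_DUP_LINE_FRAC
        or top_ngram_char_fraction(text, 2) > MAX_TOP_2GRAM_CHAR_FRAC
    )


spam = "\n".join(["Buy now!"] * 10)
prose = (
    "MinHash estimates Jaccard similarity between shingle sets.\n"
    "Locality sensitive hashing groups similar signatures into buckets.\n"
    "Clusters of near duplicates keep a single representative document."
)
print(is_repetitive(spam), is_repetitive(prose))  # True False
```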

The next step deduplicates the filtered data. MinHash is used to compute a signature for each document; Locality Sensitive Hashing (LSH) then splits each signature into buckets so that near-duplicate documents collide in at least one bucket, the matching pairs are clustered, and only one document per cluster is kept.
The step is adapted from the datatrove example [here](https://github.com/huggingface/datatrove/blob/main/examples/minhash_deduplication.py).
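
The four stages can be illustrated end to end on three toy documents. This is a pure-Python sketch of the technique, not datatrove's implementation: the 5-word shingles and the seeded blake2b hash family are assumptions, while the 14 buckets of 8 hashes match the configuration used in the step below. Signatures that agree on all hashes of at least one bucket become candidate pairs, and union-find clustering keeps one document per cluster.

```python
import hashlib
from itertools import combinations
from typing import Dict, List, Set, Tuple

# Bucket layout mirrors the MinhashConfig below (14 buckets x 8 hashes);
# 5-word shingles and the blake2b-based hash family are assumptions of this sketch.
NUM_BUCKETS = 14
HASHES_PER_BUCKET = 8
NUM_HASHES = NUM_BUCKETS * HASHES_PER_BUCKET


def shingles(text: str, n: int = 5) -> Set[str]:
    """Set of lowercase word n-grams for a document."""
    words = text.lower().split()
    return {" ".join(words[i : i + n]) for i in range(max(1, len(words) - n + 1))}


def signature(doc_shingles: Set[str]) -> List[int]:
    """MinHash signature: for each seeded hash function, the minimum hash over all shingles."""
    return [
        min(
            int.from_bytes(hashlib.blake2b(f"{seed}:{s}".encode(), digest_size=8).digest(), "big")
            for s in doc_shingles
        )
        for seed in range(NUM_HASHES)
    ]


def lsh_candidate_pairs(sigs: Dict[str, List[int]]) -> Set[Tuple[str, str]]:
    """Documents whose signatures agree on every hash of at least one bucket become candidates."""
    pairs: Set[Tuple[str, str]] = set()
    for b in range(NUM_BUCKETS):
        band: Dict[Tuple[int, ...], List[str]] = {}
        for doc_id, sig in sigs.items():
            key = tuple(sig[b * HASHES_PER_BUCKET : (b + 1) * HASHES_PER_BUCKET])
            band.setdefault(key, []).append(doc_id)
        for ids in band.values():
            pairs.update(combinations(sorted(ids), 2))
    return pairs


def ids_to_remove(doc_ids: List[str], pairs: Set[Tuple[str, str]]) -> Set[str]:
    """Union-find clustering of candidate pairs; keep one document (the root) per cluster."""
    parent = {d: d for d in doc_ids}

    def find(x: str) -> str:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in pairs:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[max(ra, rb)] = min(ra, rb)
    return {d for d in doc_ids if find(d) != d}


base = ("amazon sagemaker pipelines lets you build machine learning workflows from python "
        "functions decorated with step and run each step as a managed job")
docs = {
    "a": base,
    "b": base + " today",  # near duplicate of "a" (true Jaccard 19/20 = 0.95)
    "c": "a completely different text about cooking pasta with garlic olive oil and fresh basil",
}
sigs = {doc_id: signature(shingles(text)) for doc_id, text in docs.items()}
pairs = lsh_candidate_pairs(sigs)
removed = ids_to_remove(sorted(docs), pairs)
est_jaccard = sum(x == y for x, y in zip(sigs["a"], sigs["b"])) / NUM_HASHES
print(pairs, removed, f"estimated Jaccard(a, b) = {est_jaccard:.2f}")
```

The fraction of matching signature positions is an estimate of the Jaccard similarity of the shingle sets, which is what makes comparing short fixed-length signatures a substitute for comparing whole documents.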

In [5]:
from datatrove.pipeline.dedup import MinhashDedupSignature
from datatrove.pipeline.dedup.minhash import (
    MinhashConfig,
    MinhashDedupBuckets,
    MinhashDedupCluster,
    MinhashDedupFilter,
)
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import TokensCounter
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.utils.hashing import HashConfig
from datatrove.utils.typeshelper import Languages

In [6]:
@step(
    name="deduplication",
    keep_alive_period_in_seconds=300,
    image_uri=image_uri,
    instance_type="ml.m5.2xlarge",
    instance_count=1,
    dependencies="requirements.txt",
)
def deduplicate(s3_input_path: str):
    nltk.download("punkt_tab")

    minhash_config = MinhashConfig(
        hash_config=HashConfig(precision=64),
        num_buckets=14,
        hashes_per_bucket=8,
    )  

    input_reader = JsonlReader(s3_input_path)
    s3_output_path = f"{os.path.dirname(s3_input_path)}/deduplication"

    stage1 = LocalPipelineExecutor(
        pipeline=[
            input_reader,
            MinhashDedupSignature(
                output_folder=f"{s3_output_path}/signatures",
                config=minhash_config,
                language=Languages.english,
            ),
        ],
        tasks=4,
    )

    stage2 = LocalPipelineExecutor(
        pipeline=[
            MinhashDedupBuckets(
                input_folder=f"{s3_output_path}/signatures",
                output_folder=f"{s3_output_path}/buckets",
                config=minhash_config,
            ),
        ],
        tasks=minhash_config.num_buckets,
        depends=stage1,
    )

    stage3 = LocalPipelineExecutor(
        pipeline=[
            MinhashDedupCluster(
                input_folder=f"{s3_output_path}/buckets",
                output_folder=f"{s3_output_path}/remove_ids",
                config=minhash_config,
            ),
        ],
        tasks=1,
        depends=stage2,
    )

    stage4 = LocalPipelineExecutor(
        pipeline=[
            input_reader,
            TokensCounter(),
            MinhashDedupFilter(
                input_folder=f"{s3_output_path}/remove_ids",
                exclusion_writer=JsonlWriter(f"{s3_output_path}/removed"),
            ),
            JsonlWriter(output_folder=f"{s3_output_path}/deduplicated_output"),
        ],
        tasks=4,
        depends=stage3,
    )
    
    # running the last stage also runs its dependencies first (stage1 -> stage2 -> stage3)
    stage4.run()

    final_output = f"{s3_output_path}/deduplicated_output"

    return final_output
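
The bucket layout in `MinhashConfig` determines how similar two documents must be before they are likely to be compared. Under the standard LSH banding analysis, a pair with Jaccard similarity s shares at least one of b buckets of r hashes with probability 1 - (1 - s^r)^b. A short calculation for the `num_buckets=14`, `hashes_per_bucket=8` setting used above:

```python
# Bucket layout copied from the MinhashConfig in the deduplication step.
NUM_BUCKETS = 14
HASHES_PER_BUCKET = 8


def candidate_probability(s: float, b: int = NUM_BUCKETS, r: int = HASHES_PER_BUCKET) -> float:
    """Probability that two documents with Jaccard similarity s share at least one LSH bucket."""
    return 1.0 - (1.0 - s**r) ** b


# Common rule of thumb for where the S-curve rises most steeply: (1/b)^(1/r)
threshold = (1.0 / NUM_BUCKETS) ** (1.0 / HASHES_PER_BUCKET)

for s in (0.5, 0.6, 0.7, 0.8, 0.9, 0.95):
    print(f"Jaccard {s:.2f} -> P(candidate) = {candidate_probability(s):.3f}")
print(f"approximate threshold: {threshold:.2f}")
```

With this layout a pair at Jaccard 0.5 becomes a candidate only about 5% of the time, while a pair at 0.8 does so about 92% of the time, with the transition around 0.72. Increasing `hashes_per_bucket` raises that threshold; increasing `num_buckets` lowers it.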

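Finally, we combine the two steps into a single pipeline. Because `deduplicate` receives the return value of `initial_filter`, SageMaker infers that deduplication depends on the filtering step, so only the final step has to be passed to `Pipeline`. The pipeline takes three parameters:
- The S3 location of the crawled data (WARC files)
- The name of the dataset to create
- The S3 output location for the processed data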

In [7]:
# Pipeline parameters
s3_source_param = ParameterString(name="s3_source_path")          # S3 path to the source data (WARC files)
dataset_name = ParameterString(name="dataset_name")               # dataset name, used as a folder name in the S3 output paths
s3_filtered_param = ParameterString(name="s3_filtered_path")      # base S3 prefix for all outputs

In [None]:
# create the pipeline definition
pipeline_name = "PreTrainDataPrep"
pipeline = Pipeline(
    name=pipeline_name,
    steps=[deduplicate(initial_filter(dataset_name, s3_source_param, s3_filtered_param))],
    parameters=[dataset_name, s3_source_param, s3_filtered_param],
)

# update or create the pipeline
pipeline.upsert(role_arn=role)

In [9]:
s3_output_path = f"s3://{bucket}/pre-training-data"  # no trailing slash: the steps append "/..." to this prefix

# start the pipeline
execution = pipeline.start(
    parameters={
        "dataset_name": "aws-blogs",
        "s3_source_path": s3_data_path,
        "s3_filtered_path": s3_output_path,
    }
)

In [None]:
execution.describe()

In [11]:
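# wait for the pipeline to finish; the waiter polls every `delay` seconds for at most
# `max_attempts` attempts (defaults: 30 s and 60 attempts), so allow more time for larger datasets
execution.wait(delay=30, max_attempts=120)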

Once the pipeline finishes, `execution.result()` returns the value returned by the `deduplication` step, i.e., the S3 location of the processed data.

In [None]:
final_data_location = execution.result(step_name="deduplication")
print(f"Final data location: {final_data_location}")

In [None]:
# list the files in the final data location
!aws s3 ls $final_data_location/

Let's open one of the output files to see the processed data.

In [None]:
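import pandas as pd

# JsonlWriter names its files by task rank (00000.jsonl.gz, 00001.jsonl.gz, ...);
# pandas infers gzip compression from the suffix and reads s3:// paths via s3fs
df = pd.read_json(f"{final_data_location}/00000.jsonl.gz", lines=True)
df.head()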
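
Since stage 4 of the deduplication step runs `TokensCounter`, each output record should carry a token count in its metadata. The helper below tallies documents and tokens across gzipped JSONL files. The record layout (`{"text": ..., "id": ..., "metadata": {"token_count": ...}}`) is an assumption, so check it against `df.columns` and `df["metadata"]` first. The demo writes a synthetic file; for the real output, copy the files locally first (for example with `aws s3 cp --recursive`).

```python
import gzip
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable


def summarize_jsonl_gz(paths: Iterable[Path]) -> Dict[str, int]:
    """Count documents and sum metadata["token_count"] across gzipped JSONL files.

    Assumes (not verified against datatrove) that each line is a JSON object
    with an optional metadata.token_count field; missing counts are treated as 0.
    """
    docs = tokens = 0
    for path in paths:
        with gzip.open(path, "rt", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                docs += 1
                tokens += int(record.get("metadata", {}).get("token_count", 0))
    return {"documents": docs, "tokens": tokens}


# Demo on a synthetic file shaped like the assumed output records.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "00000.jsonl.gz"
    with gzip.open(sample, "wt", encoding="utf-8") as f:
        for i, n in enumerate([120, 80]):
            f.write(json.dumps({"id": str(i), "text": "...", "metadata": {"token_count": n}}) + "\n")
    stats = summarize_jsonl_gz([sample])

print(stats)  # {'documents': 2, 'tokens': 200}
```

Streaming the files line by line keeps memory flat, which matters more than speed once the dataset no longer fits in a single DataFrame.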

### Conclusion
In this notebook we created a data preprocessing pipeline using SageMaker Pipelines and the datatrove library. The pipeline extracted text from the crawled HTML, filtered out unwanted URLs as well as non-English, repetitive, and low-quality documents, and deduplicated the result with MinHash. The processed data is saved to S3, ready to be used for LLM pre-training.