## (Optional) Step 2: Prepare your Dataset

By default, the torchtitan library uses the allenai/c4 dataset in its training configuration. This is streamed directly during training. 

However, you may want to pre-train your models on your own dataset residing in S3. In this notebook, we will download the allenai/c4 dataset to S3, and in the training notebook we will show you how to configure the torchtitan library to use the dataset residing in S3.

We first create a processing script that downloads the dataset from Hugging Face and saves it in Parquet format:

In [None]:

%%writefile download_c4_dataset.py

import sys
import subprocess
import os


# Install the third-party packages this script needs at runtime, using the same
# interpreter that is executing the script. This must run before the
# `datasets` / `tqdm` imports below, which is why it sits between the import groups.
# NOTE(review): pyarrow is imported below but is not listed here — presumably it
# arrives as a dependency of `datasets`; confirm for the container image in use.
subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets", "tqdm"])


import argparse
from datasets import load_dataset
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
import pyarrow as pa
import pyarrow.parquet as pq
import io
from tqdm import tqdm


def process_and_upload_chunk(args):
    """Serialize one chunk of examples to a Parquet file in the output directory.

    Despite the name, nothing is uploaded here: the file is only written to
    the local ``output_path``.

    Args:
        args: A 3-tuple ``(output_path, chunk_id, chunk)`` packed into a single
            argument so the function can be submitted to an executor.
            ``chunk`` is whatever ``pa.Table.from_pylist`` accepts (the caller
            passes a list of dataset examples).

    Returns:
        ``chunk_id`` when the file was written, or ``None`` if any exception
        was raised along the way (the error is printed, not re-raised).
    """
    output_path, chunk_id, chunk = args

    try:
        print(f"Processing chunk {chunk_id}")

        # Build the Arrow table and encode it as Parquet fully in memory, so
        # the destination file is only opened once the bytes are ready.
        table = pa.Table.from_pylist(chunk)
        sink = io.BytesIO()
        pq.write_table(table, sink)
        payload = sink.getvalue()

        # Zero-padded ids keep the chunk files lexicographically ordered.
        destination = os.path.join(output_path, f"chunk_{chunk_id:06d}.parquet")
        with open(destination, 'wb') as out_file:
            out_file.write(payload)
    except Exception as err:
        print(f"Error processing chunk {chunk_id}: {err!s}")
        return None
    else:
        return chunk_id

def main():
    """Stream the C4 dataset and write it to disk as fixed-size Parquet chunks.

    Command-line arguments:
        --input-data: optional value forwarded to ``load_dataset`` as ``data_files``.
        --output-dir: directory that receives the ``chunk_NNNNNN.parquet`` files
            (created if missing).
        --chunk-size: number of examples collected before a chunk is handed to
            a worker process.

    Examples are read sequentially in this process; each full chunk is
    submitted to a ``ProcessPoolExecutor`` with one worker per CPU. Chunks whose
    worker reported a failure (returned ``None``) are counted and listed in a
    final warning so an incomplete export is visible in the job log. The exit
    status is not changed by such failures.
    """
    parser = argparse.ArgumentParser(description='Process C4 dataset')
    parser.add_argument('--input-data', type=str, help='S3 path to input data (optional)')
    parser.add_argument('--output-dir', type=str, default='/opt/ml/processing/output', help='Output directory')
    parser.add_argument('--chunk-size', type=int, default=10000, help='Number of examples per chunk')
    args = parser.parse_args()

    output_path = args.output_dir
    os.makedirs(output_path, exist_ok=True)

    print("Loading C4 dataset...")
    # Streaming avoids materializing the whole dataset before iteration starts.
    dataset_args = {"streaming": True}
    if args.input_data:
        dataset_args["data_files"] = args.input_data
    # NOTE(review): the surrounding notebook text calls this dataset
    # "allenai/c4"; confirm that the "c4" identifier still resolves with the
    # `datasets` version installed at runtime.
    dataset = load_dataset("c4", "en", split="train", **dataset_args)

    print(f"Saving dataset to {output_path}")

    num_cpus = multiprocessing.cpu_count()
    print(f"Number of CPUs available: {num_cpus}")

    chunk_size = args.chunk_size
    chunk = []
    chunk_id = 0

    # Maps each submitted future to its chunk id so that a failed worker
    # (which only returns None) can still be attributed to a specific chunk.
    future_to_chunk_id = {}
    failed_chunk_ids = []

    with ProcessPoolExecutor(max_workers=num_cpus) as executor:
        with tqdm(total=None) as pbar:
            for example in dataset:
                chunk.append(example)
                if len(chunk) >= chunk_size:
                    future_args = (output_path, chunk_id, chunk)
                    future = executor.submit(process_and_upload_chunk, future_args)
                    future_to_chunk_id[future] = chunk_id
                    # Start a fresh list; the submitted one now belongs to the worker.
                    chunk = []
                    chunk_id += 1
                    pbar.update(chunk_size)

            # Flush the final, partially filled chunk.
            if chunk:
                future_args = (output_path, chunk_id, chunk)
                future = executor.submit(process_and_upload_chunk, future_args)
                future_to_chunk_id[future] = chunk_id
                pbar.update(len(chunk))

        for future in as_completed(future_to_chunk_id):
            result = future.result()
            if result is not None:
                print(f"Completed processing chunk {result}")
            else:
                failed_chunk_ids.append(future_to_chunk_id[future])

    if failed_chunk_ids:
        print(
            f"WARNING: {len(failed_chunk_ids)} of {len(future_to_chunk_id)} chunks "
            f"failed and were not written: {sorted(failed_chunk_ids)}"
        )

    print("Dataset processing complete.")

if __name__ == "__main__":
    main()

We then launch the above script using a [SageMaker Processing Job](https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html) to download the dataset to S3:

In [None]:
import os
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor

# Setup: resolve the execution role, account, and region used for the job.
role = get_execution_role()
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account: {account}")

session = boto3.session.Session()
region = session.region_name
print(f"AWS region: {region}")

sm_boto_client = boto3.client("sagemaker")
sagemaker_session = sagemaker.session.Session(boto_session=session)

# The processing output is uploaded under this session's default bucket.
default_bucket = sagemaker_session.default_bucket()
print("Default bucket for this session: ", default_bucket)

# Create a custom Processor. The scikit-learn image is used only as a
# convenient Python 3 runtime; the script pip-installs what it needs.
script_processor = ScriptProcessor(
    role=role,
    image_uri=sagemaker.image_uris.retrieve(
        framework="sklearn",
        region=region,
        version="0.23-1",
    ),
    # download_c4_dataset.py has no per-instance sharding: every instance would
    # stream the same dataset and write identically named chunk files to the
    # same S3 prefix. More than one instance therefore multiplies cost without
    # speeding anything up, so run a single instance.
    instance_count=1,
    instance_type="ml.c5.9xlarge",
    base_job_name='c4-dataset-processing',
    command=["python3"],
    sagemaker_session=sagemaker_session,
)
# NOTE(review): the script writes every chunk to local disk under
# /opt/ml/processing/output before it is uploaded. Confirm that the
# processor's storage volume (volume_size_in_gb) is large enough for the
# portion of C4 you intend to export.

# Run the processing job; everything written to the container's output
# directory is copied to the S3 destination below.
script_processor.run(
    code="download_c4_dataset.py",
    outputs=[
        ProcessingOutput(
            output_name="c4_dataset",
            source="/opt/ml/processing/output",
            destination=f"s3://{default_bucket}/c4-dataset",
        )
    ],
    arguments=[
        "--output-dir", "/opt/ml/processing/output",
        "--chunk-size", "100000"
    ],
)

print(f"Processing job '{script_processor.latest_job.job_name}' started.")

Please note down the S3 path where the dataset is downloaded, as we will use it in the torchtitan training notebook as the input channel for the training estimator function.