# EAI Taxonomy STEM w/ DCLM

This notebook demonstrates how to curate the EAI Taxonomy STEM dataset with DCLM filtering, a 912B token collection of high-quality STEM documents filtered from the Essential AI Common Crawl using semantic taxonomy labels and instruction density classification.
 
## Overview
The EAI Taxonomy STEM dataset represents a novel approach to STEM dataset curation. Rather than relying solely on domain-specific classifiers, we leverage semantic taxonomy labels to identify documents that:

 - Contain science, engineering, medical, and computer science content
 - Demonstrate reasoning capabilities 
 - Maintain high technical correctness
 - Come from high-quality document types per sub-topic

Key Statistics:
 - Base STEM Size: 1.74T tokens
 - STEM w/ DCLM Size: 912B tokens
 - Performance: 34.5% on MMLU-STEM (+6.8pp vs DCLM-baseline)
 - Hybrid curation: Combines taxonomy filters with DCLM instruction density classifier

## 1. Initialize Spark Session and Load Data


In [1]:
# Import necessary libraries
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import json

# Plot appearance for the notebook. The matplotlib style sheet is applied
# first and the seaborn palette second, in that order.
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette(palette="husl")

In [2]:
import logging
from pyspark import SparkConf
from pyspark.sql import SparkSession


class SessionFactory:
    """Factory for a pre-configured SparkSession plus a Spark-side logger."""

    @staticmethod
    def create(
        autoscale=False,
        max_partition_bytes=int(2**28),  # 256MB
        name=None,
        num_cores=None,
        num_instances=None,
        use_arrow=False,
        speculative_execution=True,
        additional_conf=None,
    ) -> tuple[SparkSession, logging.Logger]:
        """Build (or reuse) a SparkSession and return it with a Log4j logger.

        Args:
            autoscale: When False, dynamic allocation is disabled and the
                executor count, cores per executor, shuffle partitions and
                default parallelism are pinned from ``num_instances`` and
                ``num_cores``. When True, none of those keys are set here.
            max_partition_bytes: Value for
                ``spark.sql.files.maxPartitionBytes``.
            name: Optional application name. Also used as the logger name;
                this module's ``__name__`` is used when it is falsy.
            num_cores: Cores per executor. Required when ``autoscale`` is
                False.
            num_instances: Number of executors. Required when ``autoscale``
                is False.
            use_arrow: Enable Arrow for PySpark with 10000 records per batch.
            speculative_execution: Enable speculative re-launching of slow
                tasks.
            additional_conf: Optional mapping of extra Spark settings. It is
                applied after the base settings but before the speculation
                settings, so ``spark.speculation*`` keys supplied here are
                overwritten when ``speculative_execution`` is True.

        Returns:
            A ``(spark, spark_logger)`` tuple. ``spark_logger`` is the object
            returned by Log4j's ``LogManager.getLogger`` through
            ``spark._jvm`` (a JVM-side logger), not a Python
            ``logging.Logger`` as the annotation suggests.

        Raises:
            AssertionError: If ``autoscale`` is False and ``num_instances``
                or ``num_cores`` is None.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # across invocations.
        if additional_conf is None:
            additional_conf = {}

        # Set Spark configurations. Values are passed as strings throughout.
        conf = SparkConf()
        conf.set("spark.task.maxFailures", "15")
        conf.set("spark.sql.sources.parallelPartitionDiscovery.parallelism", "250")
        conf.set("spark.sql.files.maxPartitionBytes", str(max_partition_bytes))
        conf.set(
            "spark.sql.hive.filesourcePartitionFileCacheSize",
            str(16 * 1024 * 1024 * 1024),  # 16 GiB
        )

        if name:
            conf.set("spark.app.name", name)

        if not autoscale:
            assert num_instances is not None and num_cores is not None, (
                "num_instances and num_cores must be set if autoscale is False"
            )
            # One partition per executor core across the whole cluster.
            total_cores = num_cores * num_instances
            conf.set("spark.dynamicAllocation.enabled", "false")
            conf.set("spark.executor.instances", str(num_instances))
            conf.set("spark.executor.cores", str(num_cores))
            conf.set("spark.sql.shuffle.partitions", str(total_cores))
            conf.set("spark.default.parallelism", str(total_cores))

        if use_arrow:
            conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
            conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

        for k, v in additional_conf.items():
            conf.set(k, v)

        if speculative_execution:
            # Add configuration for handling stragglers
            conf.set("spark.speculation", "true")
            conf.set(
                "spark.speculation.interval", "5000ms"
            )  # Check for stragglers every 5 seconds
            conf.set(
                "spark.speculation.multiplier", "3"
            )  # Task is a straggler if it's running 3x longer than median
            conf.set("spark.speculation.quantile", "0.75")

        # Start the SparkSession (getOrCreate reuses an existing session).
        builder: SparkSession.Builder = SparkSession.builder
        spark = builder.config(conf=conf).getOrCreate()
        spark.sql("set spark.sql.files.ignoreCorruptFiles=true")
        spark.sql("set spark.sql.files.ignoreMissingFiles=true")

        # Quieten the Python-side bridge loggers.
        logging.getLogger("py4j").setLevel(logging.WARNING)
        logging.getLogger("pyspark").setLevel(logging.WARNING)

        # Retrieve Spark's Log4j logger via the Py4J JVM bridge so callers can
        # emit messages that appear in the same place as the rest of the Spark
        # runtime logs. If an application name was supplied, we use that as the
        # logger name; otherwise we fall back to this module's __name__.
        log4j = spark._jvm.org.apache.log4j
        spark_logger = log4j.LogManager.getLogger(name or __name__)

        # Return both the SparkSession and the Java Log4j logger so downstream
        # code can keep a reference to the shared logger instance.
        return spark, spark_logger


In [None]:
# from eai_taxonomy.infra.spark_session_factory import SessionFactory

# Input location and fixed cluster size for this run.
DATA_PATH = "<INPUT_PATH>"
NUM_INSTANCES = 50
NUM_CORES = 110

# autoscale is left at its default (False), so both sizing arguments are
# required by SessionFactory.create.
spark, logger = SessionFactory.create(
    name="eai-taxonomy-stem-w-dclm",
    num_instances=NUM_INSTANCES,
    num_cores=NUM_CORES,
)

# Read the input Parquet data into a DataFrame.
df = spark.read.parquet(DATA_PATH)

## 2. Define the Taxonomy STEM Filter
The filter combines multiple taxonomy dimensions to identify high-quality STEM content:

In [6]:
from pyspark.sql.functions import col

# ---------------------------------------------------------------------------
# Subject filter (free decimal correspondence codes)
# ---------------------------------------------------------------------------
# Two-character class prefixes accepted as science subjects.
science_codes = ["50", "51", "54", "57", "58", "59", "61"]

# Leading two characters of the primary and secondary subject codes.
dds_primary = col("eai_taxonomy.free_decimal_correspondence.primary.code").substr(0, 2)
dds_secondary = col("eai_taxonomy.free_decimal_correspondence.secondary.code").substr(0, 2)

# A document passes when one label has prefix "61" and the other label is
# any of the science prefixes, in either primary/secondary arrangement.
# NOTE(review): this requires "61" on one side; confirm that is intended for
# a STEM-wide filter rather than carried over from a narrower one.
primary_is_61 = (dds_primary == "61") & dds_secondary.isin(science_codes)
secondary_is_61 = (dds_secondary == "61") & dds_primary.isin(science_codes)
dds_filter = primary_is_61 | secondary_is_61

# ---------------------------------------------------------------------------
# Document type filter
# ---------------------------------------------------------------------------
doc_type_v1_col = col("eai_taxonomy.document_type_v1.primary.code")
doc_type_v2_col = col("eai_taxonomy.document_type_v2.primary.code")

# Accepted document types.
# v1: Academic/Research, Reference/Encyclopedic/Educational
doc_type_v1_codes = ["2", "3"]
# v2: Academic Writing, Documentation, Knowledge Article, Q&A Forum
doc_type_v2_codes = ["3", "8", "10", "18"]

# Rejected document types.
DOC_TYPE_V1_BLACKLIST = ["1", "4", "5", "6", "8", "9", "10", "15", "16"]
DOC_TYPE_V2_BLACKLIST = [
    "1", "2", "4", "5", "6", "7", "11", "12", "13", "14",
    "16", "17", "19", "20", "22", "23", "24",
]

# Accepted by at least one labelling scheme, and rejected by neither.
on_whitelist = (
    doc_type_v1_col.isin(doc_type_v1_codes)
    | doc_type_v2_col.isin(doc_type_v2_codes)
)
off_blacklist = (
    ~doc_type_v1_col.isin(DOC_TYPE_V1_BLACKLIST)
    & ~doc_type_v2_col.isin(DOC_TYPE_V2_BLACKLIST)
)
doc_type_filter = on_whitelist & off_blacklist

# Final predicate: subject filter AND document type filter.
final_filter = dds_filter & doc_type_filter

## 3. Apply the Filter

In [None]:
# `final_filter` is the combined subject + document-type predicate built in
# section 2.
stem_df = df.filter(final_filter)

# Count documents before and after filtering.
total_docs = df.count()
stem_docs = stem_df.count()

# Report how much of the input survived; guard against an empty input.
retention = stem_docs / total_docs if total_docs else 0.0
print(f"Kept {stem_docs:,} of {total_docs:,} documents ({retention:.2%})")

## 4. Save the Filtered Dataset

In [None]:
# Destination for the filtered dataset.
OUTPUT_PATH = "<OUTPUT_PATH>"

# Write the filtered DataFrame as Parquet; no save mode is specified.
stem_df.write.parquet(path=OUTPUT_PATH)