# EAI Taxonomy STEM w/ DCLM

This notebook demonstrates how to curate the EAI Taxonomy STEM dataset with DCLM filtering: a 912B-token collection of high-quality STEM documents filtered from the Essential AI Common Crawl using semantic taxonomy labels and instruction-density classification.
 
## Overview
The EAI Taxonomy STEM dataset represents a novel approach to STEM dataset curation. Rather than relying solely on domain-specific classifiers, we leverage semantic taxonomy labels to identify documents that:

 - Contain science, engineering, medical, and computer science content
 - Demonstrate reasoning capabilities 
 - Maintain high technical correctness
 - Come from high-quality document types per sub-topic

Key Statistics:
 - Base STEM Size: 1.74T tokens
 - STEM w/ DCLM Size: 912B tokens
 - Performance: 34.5% on MMLU-STEM (+6.8pp vs DCLM-baseline)
 - Hybrid curation: combines taxonomy filters with the DCLM instruction-density classifier

## 1. Initialize Spark Session and Load Data


In [1]:
# Import necessary libraries
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import json

# Plotting defaults: seaborn's dark-grid look with the "husl" colour palette.
# The "seaborn-v0_8-*" style names exist in matplotlib 3.6 and later.
# NOTE(review): no cell shown here draws a plot; this only matters if
# visualization cells are added.
_NOTEBOOK_PLOT_STYLE = 'seaborn-v0_8-darkgrid'
_NOTEBOOK_PALETTE = "husl"
plt.style.use(_NOTEBOOK_PLOT_STYLE)
sns.set_palette(_NOTEBOOK_PALETTE)

In [2]:
import logging
from pyspark import SparkConf
from pyspark.sql import SparkSession


class SessionFactory:
    """Build a configured ``SparkSession`` together with a JVM-side Log4j logger."""

    @staticmethod
    def create(
        autoscale=False,
        max_partition_bytes=int(2**28),  # 256MB
        name=None,
        num_cores=None,
        num_instances=None,
        use_arrow=False,
        speculative_execution=True,
        additional_conf=None,
    ) -> tuple[SparkSession, object]:
        """Create (or fetch) a SparkSession configured with this project's defaults.

        Args:
            autoscale: When False, dynamic allocation is disabled and a fixed
                executor count is configured from ``num_instances`` and
                ``num_cores`` (both are then required).
            max_partition_bytes: Value for ``spark.sql.files.maxPartitionBytes``.
            name: Optional ``spark.app.name``; also used as the logger name.
            num_cores: Cores per executor (fixed-size mode only).
            num_instances: Number of executors (fixed-size mode only).
            use_arrow: Enable Arrow-based JVM<->Python transfer, with record
                batches of 10,000 rows.
            speculative_execution: Enable speculative re-launch of straggler tasks.
            additional_conf: Optional mapping of extra Spark settings. These are
                applied last, so they override any setting made by this method
                (including the speculation defaults).

        Returns:
            ``(spark, logger)``. ``logger`` is the Log4j logger fetched through
            the Py4J gateway -- a Java object, not a Python ``logging.Logger``.

        Raises:
            ValueError: If ``autoscale`` is False and ``num_instances`` or
                ``num_cores`` is not provided.
        """
        # Set Spark configurations (values passed as strings, as Spark stores them)
        conf = SparkConf()
        conf.set("spark.task.maxFailures", "15")
        conf.set("spark.sql.sources.parallelPartitionDiscovery.parallelism", "250")
        conf.set("spark.sql.files.maxPartitionBytes", str(max_partition_bytes))
        conf.set(
            "spark.sql.hive.filesourcePartitionFileCacheSize",
            str(16 * 1024 * 1024 * 1024),  # 16 GiB
        )

        if name:
            conf.set("spark.app.name", name)

        if not autoscale:
            # Fixed-size cluster: explicit validation (unlike `assert`, this is
            # not stripped when Python runs with -O).
            if num_instances is None or num_cores is None:
                raise ValueError(
                    "num_instances and num_cores must be set if autoscale is False"
                )
            # Default shuffle / RDD parallelism = total cores across executors.
            total_cores = num_cores * num_instances
            conf.set("spark.dynamicAllocation.enabled", "false")
            conf.set("spark.executor.instances", str(num_instances))
            conf.set("spark.executor.cores", str(num_cores))
            conf.set("spark.sql.shuffle.partitions", str(total_cores))
            conf.set("spark.default.parallelism", str(total_cores))

        if use_arrow:
            conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
            conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

        if speculative_execution:
            # Add configuration for handling stragglers
            conf.set("spark.speculation", "true")
            # Check for stragglers every 5 seconds
            conf.set("spark.speculation.interval", "5000ms")
            # Task is a straggler if it's running 3x longer than median
            conf.set("spark.speculation.multiplier", "3")
            conf.set("spark.speculation.quantile", "0.75")

        # Caller-supplied settings go last so they can override anything above.
        for key, value in (additional_conf or {}).items():
            conf.set(key, value)

        # Start the SparkSession.
        # NOTE: getOrCreate() reuses an already-active session; in that case
        # settings that can only be fixed at startup (e.g. executor counts)
        # will not take effect.
        builder: SparkSession.Builder = SparkSession.builder
        spark = builder.config(conf=conf).getOrCreate()
        spark.sql("set spark.sql.files.ignoreCorruptFiles=true")
        spark.sql("set spark.sql.files.ignoreMissingFiles=true")

        # Quiet the Python-side py4j / pyspark loggers.
        logging.getLogger("py4j").setLevel(logging.WARNING)
        logging.getLogger("pyspark").setLevel(logging.WARNING)

        # Retrieve Spark's Log4j logger via the Py4J JVM bridge so callers can
        # emit messages that appear alongside the Spark runtime logs. The
        # application name (if given) is the logger name; otherwise __name__.
        log4j = spark._jvm.org.apache.log4j
        spark_logger = log4j.LogManager.getLogger(name if name else __name__)

        return spark, spark_logger


In [None]:
# from eai_taxonomy.infra.spark_session_factory import SessionFactory

# Input location (placeholder -- replace before running) and fixed cluster size.
DATA_PATH = "<INPUT_PATH>"
NUM_INSTANCES = 50  # number of executors
NUM_CORES = 110  # cores per executor
APP_NAME = "eai-taxonomy-stem-w-dclm"

spark, logger = SessionFactory.create(
    name=APP_NAME,
    num_instances=NUM_INSTANCES,
    num_cores=NUM_CORES,
)
df = spark.read.parquet(DATA_PATH)

## 2. Define the Taxonomy STEM Filter
The filter combines multiple taxonomy dimensions (subject codes, document types, and reasoning depth) to identify high-quality STEM content:

In [6]:
from pyspark.sql.functions import col

# Document-type whitelists, one pair per STEM domain: a list of accepted
# document_type_v1 codes and a list of accepted document_type_v2 codes.
# Each entry is the string code as stored, annotated with its meaning.

# --- Code / programming domain ---
code_document_type_v1 = [
    "2",  # Academic/Research
    "3",  # Reference/Encyclopedic/Educational
    "4",  # Code/Software
    "5",  # Social/Forum
]
code_document_type_v2 = [
    "3",   # Academic Writing
    "5",   # Comment Section
    "8",   # Documentation
    "10",  # Knowledge Article
    "16",  # Personal Blog
    "18",  # Q&A Forum
    "23",  # Tutorial
]

# --- Medical domain ---
medical_document_type_v1 = [
    "2",   # Academic/Research
    "3",   # Reference/Encyclopedic/Educational
    "4",   # Code/Software
    "11",  # Legal/Regulatory
]
medical_document_type_v2 = [
    "3",   # Academic Writing
    "8",   # Documentation
    "9",   # FAQ
    "10",  # Knowledge Article
    "14",  # News Article
    "23",  # Tutorial
]

# --- Engineering domain ---
engineering_document_type_v1 = [
    "2",   # Academic/Research
    "3",   # Reference/Encyclopedic/Educational
    "9",   # Personal/Misc
    "11",  # Legal/Regulatory
]
engineering_document_type_v2 = [
    "3",   # Academic Writing
    "4",   # Audio Transcript
    "8",   # Documentation
    "9",   # FAQ
    "10",  # Knowledge Article
    "14",  # News Article
    "23",  # Tutorial
]

# --- Default domain (sciences and remaining tech) ---
default_document_type_v1 = [
    "2",  # Academic/Research
    "3",  # Reference/Encyclopedic/Educational
]
default_document_type_v2 = [
    "3",   # Academic Writing
    "10",  # Knowledge Article
    "14",  # News Article
]

# Reasoning-depth values treated as unusable classifications.
reasoning_depth_bad = ["Abstain", "Indeterminate"]

# Two-character DDS prefixes accepted as STEM subjects.
science_categories = [
    "50",  # Science (general)
    "51",  # Mathematics
    "52",  # Astronomy & allied sciences
    "53",  # Physics
    "54",  # Chemistry & allied sciences
    "55",  # Earth sciences
    "56",  # Paleontology
    "57",  # Life sciences
    "58",  # Plants (Botany)
    "59",  # Animals (Zoology)
]
tech_categories = [
    "60",  # Technology (general)
    "61",  # Medicine & health
    "62",  # Engineering & allied operations
    "66",  # Chemical engineering
    "00",  # Computer science & general works
]
# Science prefixes first, then tech, in a fresh list.
valid_dds_prefixes = [*science_categories, *tech_categories]

# Shorthand handles for the taxonomy fields every filter below reads.
dds_primary_code = col("eai_taxonomy.free_decimal_correspondence.primary.code")
dds_secondary_code = col("eai_taxonomy.free_decimal_correspondence.secondary.code")
doc_type_v1_code = col("eai_taxonomy.document_type_v1.primary.code")
doc_type_v2_code = col("eai_taxonomy.document_type_v2.primary.code")
reasoning_depth_code = col("eai_taxonomy.reasoning_depth.primary.code")

# Column.substr is 1-based: substr(1, n) is the first n characters of the code.
dds_primary_2 = dds_primary_code.substr(1, 2)
dds_secondary_2 = dds_secondary_code.substr(1, 2)
dds_primary_5 = dds_primary_code.substr(1, 5)
dds_secondary_5 = dds_secondary_code.substr(1, 5)

# Create the base DDS filter
dds_filter = (
    # Both primary and secondary DDS must start with valid categories
    dds_primary_2.isin(valid_dds_prefixes)
    & dds_secondary_2.isin(valid_dds_prefixes)
)

# Subject tests shared by the domain filters AND by the default filter's
# exclusions, so both always agree. A document is in a domain when EITHER its
# primary or its secondary DDS code is.
# DDS is programming-related (005.1: Computer programming, 005.4: Systems programming)
programming_dds_sections = ["005.1", "005.4"]
is_code_dds = (
    dds_primary_5.isin(programming_dds_sections)
    | dds_secondary_5.isin(programming_dds_sections)
)
# DDS is medical (61: Medicine & health)
is_medical_dds = (dds_primary_2 == "61") | (dds_secondary_2 == "61")
# DDS is engineering (62: Engineering & allied operations)
is_engineering_dds = (dds_primary_2 == "62") | (dds_secondary_2 == "62")

# Create document type filters for each domain
code_filter = (
    is_code_dds
    # Document types suitable for code content
    & doc_type_v1_code.isin(code_document_type_v1)
    & doc_type_v2_code.isin(code_document_type_v2)
)

medical_filter = (
    is_medical_dds
    # Document types suitable for medical content
    & doc_type_v1_code.isin(medical_document_type_v1)
    & doc_type_v2_code.isin(medical_document_type_v2)
)

engineering_filter = (
    is_engineering_dds
    # Document types suitable for engineering content
    & doc_type_v1_code.isin(engineering_document_type_v1)
    & doc_type_v2_code.isin(engineering_document_type_v2)
)

# Default filter for sciences and other valid tech categories
default_filter = (
    # Not in code, medical, or engineering specific categories -- checked on
    # both primary and secondary DDS, exactly like the domain filters. (With
    # the document-type lists above, the default types are a subset of the
    # medical and engineering types, so the union in stem_filter is the same
    # either way; this keeps default_filter itself accurate.)
    ~is_code_dds
    & ~is_medical_dds
    & ~is_engineering_dds
    # But is in valid DDS categories (sciences 50-59, tech 60/66, or
    # non-programming computer science & general works 00x)
    & (dds_primary_2.isin(valid_dds_prefixes) | dds_secondary_2.isin(valid_dds_prefixes))
    # Apply default document type filters for general academic/reference content
    & doc_type_v1_code.isin(default_document_type_v1)
    & doc_type_v2_code.isin(default_document_type_v2)
)

# Reasoning depth filter (common to all)
# NOTE(review): reasoning_depth_bad holds label-like strings ("Abstain",
# "Indeterminate") but is compared against the `.code` field, whereas the other
# `.code` fields used here hold numeric strings. If reasoning-depth codes are
# numeric too, this exclusion never matches -- confirm against the schema.
reasoning_filter = (
    # Exclude uncertain reasoning depth classifications
    ~reasoning_depth_code.isin(reasoning_depth_bad)
    # Ensure reasoning depth is not null
    & reasoning_depth_code.isNotNull()
)

# Combine all filters
# NOTE(review): the overview describes an additional DCLM classifier filter,
# but no DCLM score column is referenced in the cells shown here; only the
# taxonomy filters are applied.
stem_filter = (
    # Must pass reasoning depth filter
    reasoning_filter
    # Must be in valid DDS categories
    & dds_filter
    # Must match one of the domain-specific document type filters
    & (code_filter | medical_filter | engineering_filter | default_filter)
)

## 3. Apply the Filter

In [None]:
# Keep only documents that pass the combined STEM taxonomy filter (lazy).
stem_df = df.filter(stem_filter)

# Each count() is a Spark action that runs its own job over the input
# (nothing here is cached), so this cell is expensive on the full corpus.
total_docs = df.count()
stem_docs = stem_df.count()

# Report the result: a bare assignment displays nothing in a notebook.
# Guard against an empty input so the ratio never divides by zero.
kept_fraction = stem_docs / total_docs if total_docs else 0.0
print(f"STEM filter kept {stem_docs:,} of {total_docs:,} documents ({kept_fraction:.2%})")

## 4. Save the Filtered Dataset

In [None]:
# Persist the filtered documents as Parquet. "errorifexists" is Spark's default
# save mode, spelled out here: the write fails if OUTPUT_PATH already exists.
OUTPUT_PATH = "<OUTPUT_PATH>"
stem_df.write.mode("errorifexists").parquet(OUTPUT_PATH)