# EAI Taxonomy Top Math

This notebook demonstrates how to curate the Taxonomy Top Math dataset, a 29B-token collection of high-quality mathematical documents filtered from the Essential AI Common Crawl using semantic taxonomy labels.

## Overview
The Taxonomy Top Math dataset takes a novel approach to mathematical dataset curation. Rather than going through long iteration cycles to build domain-specific classifiers, we use semantic taxonomy labels to identify documents that:

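 - Contain mathematical content (primary Free Decimal Correspondence code starting with 51, i.e. 510–519)
 - Demonstrate at least basic reasoning
 - Have high technical correctness
 - Belong to selected document types (e.g., reference/educational material, tutorials, documentation, and Q&A forums)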

Key Statistics:
 - Size: 29B tokens
 - Documents: 19.8M (selected from 23.6B documents in Essential Common Crawl)
 - Performance: 21.3% on GSM8K, 11.0% on MATH
 - No domain-specific curation: Uses only taxonomy filters

## 1. Initialize Spark Session and Load Data


In [12]:
# Import necessary libraries
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import json

# Plot styling: apply the seaborn-flavoured matplotlib theme first, then the
# "husl" colour palette (order kept so the palette is applied last).
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")

In [2]:
import logging
from pyspark import SparkConf
from pyspark.sql import SparkSession


class SessionFactory:
    """Factory for SparkSessions preconfigured with this project's Spark defaults."""

    @staticmethod
    def create(
        autoscale=False,
        max_partition_bytes=int(2**28),  # 256MB
        name=None,
        num_cores=None,
        num_instances=None,
        use_arrow=False,
        speculative_execution=True,
        additional_conf=None,
    ) -> tuple[SparkSession, object]:
        """Create (or reuse, via ``getOrCreate``) a configured SparkSession.

        Args:
            autoscale: If False, dynamic allocation is disabled and a fixed
                cluster of ``num_instances`` executors with ``num_cores`` cores
                each is requested; both must then be provided.
            max_partition_bytes: Value for ``spark.sql.files.maxPartitionBytes``.
            name: Optional application name; also used as the logger name.
            num_cores: Cores per executor (required when ``autoscale`` is False).
            num_instances: Number of executors (required when ``autoscale`` is
                False).
            use_arrow: If True, set ``spark.sql.execution.arrow.pyspark.enabled``
                and cap Arrow batches at 10000 records.
            speculative_execution: If True, enable speculative re-launch of
                straggling tasks.
            additional_conf: Optional ``{key: value}`` Spark settings. They are
                applied after the defaults above but before the speculation
                settings, so speculation settings take precedence.

        Returns:
            ``(spark, logger)``, where ``logger`` is a Py4J handle to a JVM
            Log4j logger (not a Python ``logging.Logger``).

        Raises:
            ValueError: If ``autoscale`` is False and ``num_instances`` or
                ``num_cores`` is missing.
        """
        # Validate with a real exception: ``assert`` is stripped under ``python -O``.
        if not autoscale and (num_instances is None or num_cores is None):
            raise ValueError(
                "num_instances and num_cores must be set if autoscale is False"
            )

        # Set Spark configurations (all values passed as strings for consistency).
        conf = SparkConf()
        conf.set("spark.task.maxFailures", "15")
        conf.set("spark.sql.sources.parallelPartitionDiscovery.parallelism", "250")
        conf.set("spark.sql.files.maxPartitionBytes", str(max_partition_bytes))
        conf.set(
            "spark.sql.hive.filesourcePartitionFileCacheSize",
            str(16 * 1024 * 1024 * 1024),  # 16 GiB
        )

        if name:
            conf.set("spark.app.name", name)

        if not autoscale:
            # Fixed-size cluster: disable dynamic allocation and size the default
            # shuffle / RDD parallelism to the total number of executor cores.
            total_cores = num_cores * num_instances
            conf.set("spark.dynamicAllocation.enabled", "false")
            conf.set("spark.executor.instances", str(num_instances))
            conf.set("spark.executor.cores", str(num_cores))
            conf.set("spark.sql.shuffle.partitions", str(total_cores))
            conf.set("spark.default.parallelism", str(total_cores))

        if use_arrow:
            conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
            conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

        # Caller overrides win over the defaults above; ``None`` avoids a shared
        # mutable default dict.
        for key, value in (additional_conf or {}).items():
            conf.set(key, value)

        if speculative_execution:
            # Add configuration for handling stragglers
            conf.set("spark.speculation", "true")
            # Check for stragglers every 5 seconds
            conf.set("spark.speculation.interval", "5000ms")
            # Task is a straggler if it's running 3x longer than median
            conf.set("spark.speculation.multiplier", "3")
            # Only speculate once 75% of a stage's tasks have finished
            conf.set("spark.speculation.quantile", "0.75")

        # Start (or reuse) the SparkSession
        builder: SparkSession.Builder = SparkSession.builder
        spark = builder.config(conf=conf).getOrCreate()
        # Skip corrupt or vanished input files rather than failing the job.
        spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")
        spark.conf.set("spark.sql.files.ignoreMissingFiles", "true")

        logging.getLogger("py4j").setLevel(logging.WARNING)
        logging.getLogger("pyspark").setLevel(logging.WARNING)

        # Retrieve Spark's Log4j logger via the Py4J JVM bridge so callers can
        # emit messages alongside the rest of the Spark runtime logs. The
        # application name (if given) is the logger name; otherwise __name__.
        log4j = spark._jvm.org.apache.log4j
        spark_logger = log4j.LogManager.getLogger(name if name else __name__)

        return spark, spark_logger


In [None]:
# Alternative to the inline SessionFactory defined earlier in this notebook:
# from eai_taxonomy.infra.spark_session_factory import SessionFactory

# Glob matching every raw parquet shard of the taxonomy-labelled crawl.
DATA_PATH = "gs://consus-dataproc/taxonomy/hf/raw/*/*/*.parquet"

# Fixed cluster shape: executor count and cores per executor.
NUM_INSTANCES = 50
NUM_CORES = 110

spark, logger = SessionFactory.create(
    name="eai-taxonomy-top-math",
    num_instances=NUM_INSTANCES,
    num_cores=NUM_CORES,
)
df = spark.read.parquet(DATA_PATH)

## 2. Define the Taxonomy Top Math Filter
The filter combines five taxonomy dimensions; a document is kept only if its primary label satisfies every condition (logical AND):

In [8]:
from functools import reduce
import operator

from pyspark.sql.functions import col, lit

# Label allow-lists for each taxonomy dimension (based on the Taxonomy Top Math
# recipe from the paper). Each check uses the document's *primary* label.
DOC_TYPE_V1 = [
    "Reference/Encyclopedic/Educational",
    "Code/Software",
    "Social/Forum",
    "Personal/Misc",
]

DOC_TYPE_V2 = [
    "Comment Section",
    "Documentation",
    "FAQ",
    "Knowledge Article",
    "Nonfiction Writing",
    "Personal Blog",
    "Q&A Forum",
    "Structured Data",
    "Tutorial",
]

REASONING_DEPTH = [
    "Basic Reasoning",
    "Intermediate Reasoning",
    "Advanced Reasoning",
    "Exceptional Reasoning",
]

TECH_CORRECTNESS = ["Highly Correct", "Exceptionally Correct"]

# Free Decimal Correspondence code prefixes to keep: "51" = Mathematics (510-519).
FDC_KEEP = ["51"]

_fdc_code = col("eai_taxonomy.free_decimal_correspondence.primary.code")

# Match if the primary FDC code starts with ANY prefix in FDC_KEEP. Seeding the
# OR-chain with lit(False) keeps an empty FDC_KEEP well-defined (selects
# nothing) instead of making reduce() raise on an empty sequence.
_fdc_match = reduce(
    operator.or_,
    (_fdc_code.startswith(prefix) for prefix in FDC_KEEP),
    lit(False),
)

# All conditions must hold (logical AND).
MATH_FILTER = (
    _fdc_match
    & col("eai_taxonomy.document_type_v1.primary.label").isin(DOC_TYPE_V1)
    & col("eai_taxonomy.document_type_v2.primary.label").isin(DOC_TYPE_V2)
    & col("eai_taxonomy.reasoning_depth.primary.label").isin(REASONING_DEPTH)
    & col("eai_taxonomy.technical_correctness.primary.label").isin(TECH_CORRECTNESS)
)

# Legacy name used by `df.filter(filter)` in the next section. NOTE: it shadows
# the builtin `filter` for the rest of the notebook.
filter = MATH_FILTER

## 3. Apply the Filter

In [None]:
# Keep only documents passing every taxonomy condition. Cache the result since
# it is counted here and reused by the save and analysis cells.
math_df = df.filter(filter).cache()

# Document counts before and after filtering.
total_docs = df.count()
math_docs = math_df.count()

In [None]:
# Report what fraction of the input survived the filter. Guard against an empty
# input so the ratios cannot raise ZeroDivisionError.
if total_docs == 0:
    print("Selection rate: n/a (input dataset is empty)")
else:
    selection_rate = math_docs / total_docs
    print(f"Selection rate: {selection_rate:.4%}")
    print(f"Reduction: {1 - selection_rate:.2%} filtered out")

## 4. Save the Filtered Dataset

In [None]:
# Destination for the filtered dataset; replace the placeholder before running.
OUTPUT_PATH = "<OUTPUT_PATH>"

# Fail fast instead of writing the whole dataset to a literal "<OUTPUT_PATH>"
# directory on the cluster's default filesystem.
if OUTPUT_PATH == "<OUTPUT_PATH>":
    raise ValueError("Set OUTPUT_PATH to a real destination before writing the dataset.")

math_df.write.parquet(OUTPUT_PATH)

## 5. Analyze the Mathematics Subcategory Distribution

Let's examine how the selected documents are distributed across the Mathematics subcategories (FDC codes 510–519):


In [None]:
# FDC Mathematics subcategories (51X)
math_subcategories = {
    "510": "Mathematics (General)",
    "511": "General principles of mathematics",
    "512": "Algebra",
    "513": "Arithmetic",
    "514": "Topology",
    "515": "Analysis",
    "516": "Geometry",
    "517": "[Unassigned]",
    "518": "Numerical analysis",
    "519": "Probabilities & applied mathematics"
}

# Count filtered documents per 3-character FDC prefix (e.g. "512").
subcategory_dist = (
    math_df.groupBy(
        F.substring(
            F.col("eai_taxonomy.free_decimal_correspondence.primary.code"), 1, 3
        ).alias("subcategory")
    )
    .count()
    .collect()
)

# Convert to {prefix: count}
subcategory_counts = {row[0]: row[1] for row in subcategory_dist}

# Known subcategories that actually occur, in code order, as (code, name, count).
# Building labels and sizes from this one list keeps them aligned.
present = [
    (code, name, subcategory_counts[code])
    for code, name in math_subcategories.items()
    if subcategory_counts.get(code, 0) > 0
]

# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Pie chart
sizes = [subcategory_counts.get(cat, 0) for cat in math_subcategories]
labels = [f"{code}: {name}" for code, name, _ in present]
sizes_filtered = [count for _, _, count in present]

ax1.pie(sizes_filtered, labels=labels, autopct='%1.1f%%', startangle=90)
ax1.set_title('Distribution of Mathematics Subcategories in Dataset')

# Bar chart with counts for every observed prefix. Sorted by code because
# collect() returns groups in no particular order. groupBy counts are always
# >= 1, so no zero-filtering is needed.
categories = sorted(subcategory_counts)
counts = [subcategory_counts[cat] for cat in categories]

ax2.bar(range(len(categories)), counts)
ax2.set_xticks(range(len(categories)))
ax2.set_xticklabels(categories, rotation=45)
ax2.set_ylabel('Number of Documents')
ax2.set_title('Document Counts by Mathematics Subcategory')
ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1e6:.1f}M'))

plt.tight_layout()
plt.show()

# Print exact counts
print("\nMathematics subcategory distribution:")
for code, name, count in present:
    print(f"{code} - {name}: {count:,} documents ({count/math_docs:.1%})")