# arXiv Classification with Spark NLP

This notebook demonstrates text classification on arXiv paper abstracts using two approaches:
1. **TF-IDF + Logistic Regression** - Traditional ML approach
2. **DistilBERT + ClassifierDL** - Deep learning approach with Spark NLP

---

## 1. Setup & Installation

In [1]:
# Clean install of PySpark and Spark NLP
!pip uninstall pyspark spark-nlp -y
!rm -rf ~/.ivy2/jars/*
!rm -rf ~/.ivy2/cache/com.johnsnowlabs.nlp/
!pip cache purge

!pip install --no-cache-dir pyspark==3.5.0
!pip install --no-cache-dir spark-nlp==5.5.0

# Additional dependencies
!pip install --upgrade pip
!pip install pandas matplotlib seaborn numpy scikit-learn
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers datasets "accelerate>=0.26.0"

Found existing installation: pyspark 3.5.0
Uninstalling pyspark-3.5.0:
  Successfully uninstalled pyspark-3.5.0
Found existing installation: spark-nlp 5.5.0
Uninstalling spark-nlp-5.5.0:
  Successfully uninstalled spark-nlp-5.5.0
Files removed: 6 (1.4 MB)
Defaulting to user installation because normal site-packages is not writeable
Collecting pyspark==3.5.0
  Downloading pyspark-3.5.0.tar.gz (316.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.9/316.9 MB 73.1 MB/s 0:00:04
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (pyproject.toml) ... done
  Created wheel for pyspark: filename=pyspark-3.5.0-py2.py3-none-any.whl size=317425400 sha256=c1d6c3d2b4c2290c97a24cae8631d611d70e6fc179536e943f13c706c7f13e8e
  Stored 

## 2. Imports & Configuration

In [3]:
# Standard libraries
import json
import sys
import warnings
from collections import Counter, defaultdict
from datetime import datetime

# Data science
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# PySpark
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split, explode, count, expr
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, Tokenizer, StopWordsRemover, HashingTF, IDF
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Spark NLP
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import (
    Tokenizer as SparkNLPTokenizer,
    DistilBertEmbeddings,
    ClassifierDLApproach,
    SentenceEmbeddings
)

# Configuration
warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

# Version check
print(f"Python version: {sys.version}")
print(f"PySpark version: {pyspark.__version__}")
print(f"Spark NLP version: {sparknlp.version()}")

Python version: 3.10.12 (main, Jan  8 2026, 06:52:19) [GCC 11.4.0]
PySpark version: 3.5.0
Spark NLP version: 5.5.0


## 3. Initialize Spark Session

In [4]:
# Initialize Spark with Spark NLP
spark = sparknlp.start(gpu=False, memory="8G")
spark.sparkContext.setLogLevel("WARN")

print(f"Spark runtime version: {spark.version}")

:: loading settings :: url = jar:file:/usr/local/spark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/ubuntu/.ivy2/cache
The jars for the packages stored in: /home/ubuntu/.ivy2/jars
com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-7e74a884-ecf1-436b-8d51-75466fca589f;1.0
	confs: [default]
	found com.johnsnowlabs.nlp#spark-nlp_2.12;5.5.0 in central
	found com.typesafe#config;1.4.2 in central
	found org.rocksdb#rocksdbjni;6.29.5 in central
	found com.amazonaws#aws-java-sdk-s3;1.12.500 in central
	found com.amazonaws#aws-java-sdk-kms;1.12.500 in central
	found com.amazonaws#aws-java-sdk-core;1.12.500 in central
	found commons-logging#commons-logging;1.1.3 in central
	found commons-codec#commons-codec;1.15 in central
	found org.apache.httpcomponents#httpclient;4.5.13 in central
	found org.apache.httpcomponents#httpcore;4.4.13 in central
	found software.amazon.ion#ion-java;1.0.2 in central
	found joda-time#joda-time;2.8.1 in central
	found com.amazonaws#jmespath-java;1.12.500 in centra

Spark runtime version: 3.5.0


## 4. Load & Explore Data

In [5]:
# Load arXiv metadata
df = spark.read.json("arxiv-metadata-oai-snapshot.json")

# Preview categories
df.select("categories").show(5, truncate=False)

# Count unique category combinations
unique_count = df.select("categories").distinct().count()
print(f"Unique category combinations: {unique_count}")


+---------------+
|categories     |
+---------------+
|hep-ph         |
|math.CO cs.CG  |
|physics.gen-ph |
|math.CO        |
|math.CA math.FA|
+---------------+
only showing top 5 rows





Unique category combinations: 21564
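`spark.read.json` reads JSON Lines by default (one JSON object per line; `multiLine=True` is needed for other layouts), so the distinct count above assumes the snapshot uses that layout. The pure-Python sketch below mirrors the same computation on a few hand-written records that reuse the category strings shown above; `count_category_combinations` is a hypothetical helper, not part of Spark.

```python
import json
from collections import Counter

def count_category_combinations(lines):
    """Count each raw `categories` string in JSON Lines text (hypothetical helper)."""
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines between records
        record = json.loads(line)
        counts[record["categories"]] += 1
    return counts

# Hand-written records cut down to the single field used here
sample_lines = [
    '{"categories": "hep-ph"}',
    '{"categories": "math.CO cs.CG"}',
    '{"categories": "physics.gen-ph"}',
    '{"categories": "math.CO"}',
    '{"categories": "math.CA math.FA"}',
    '{"categories": "hep-ph"}',
]

combo_counts = count_category_combinations(sample_lines)
print(len(combo_counts))  # 5 distinct combinations, "hep-ph" seen twice
```

Note that `math.CO` and `math.CO cs.CG` count as different combinations, which inflates the number of distinct combinations relative to the number of individual categories.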



In [6]:
# Top 20 category combinations
top_categories = df.groupBy("categories") \
                   .agg(count("*").alias("count")) \
                   .orderBy(col("count").desc()) \
                   .limit(20)
top_categories.show(truncate=False)

# Explode categories to see individual labels
df_exploded = df.withColumn("category", explode(split(col("categories"), " ")))
unique_exploded = df_exploded.select("category").distinct().count()
print(f"Unique individual categories: {unique_exploded}")


+-----------------+-----+
|categories       |count|
+-----------------+-----+
|astro-ph         |16405|
|hep-ph           |15547|
|quant-ph         |13733|
|astro-ph.CO      |10718|
|hep-th           |10424|
|astro-ph.SR      |8477 |
|cond-mat.mes-hall|7550 |
|cond-mat.mtrl-sci|7041 |
|gr-qc            |5777 |
|cond-mat.str-el  |5552 |
|cs.IT math.IT    |5204 |
|math.PR          |5149 |
|math.AP          |4916 |
|astro-ph.HE      |4872 |
|math.CO          |4792 |
|hep-ex           |4778 |
|math.AG          |4211 |
|nucl-th          |4110 |
|cond-mat.supr-con|4026 |
|astro-ph.GA      |3981 |
+-----------------+-----+





Unique individual categories: 155
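`explode(split(col("categories"), " "))` emits one row per space-separated category, while `split(...).getItem(0)`, used as the label in the next section, keeps only the first one. A pure-Python sketch of both on the five category strings shown earlier, assuming categories are separated by single spaces:

```python
from collections import Counter

categories = ["hep-ph", "math.CO cs.CG", "physics.gen-ph", "math.CO", "math.CA math.FA"]

# explode(split(...)): one entry per individual category
exploded = [cat for combo in categories for cat in combo.split(" ")]
individual_counts = Counter(exploded)

# split(...).getItem(0): only the first listed category becomes the label
primary_labels = [combo.split(" ")[0] for combo in categories]

print(len(exploded), len(individual_counts))  # 7 6
print(primary_labels)  # ['hep-ph', 'math.CO', 'physics.gen-ph', 'math.CO', 'math.CA']
```

Cross-listed secondary categories (such as `cs.CG` above) therefore never become labels.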



## 5. Data Preparation

In [7]:
# Use primary category (first in list) as label
df_ml = df.select("abstract", "categories") \
          .withColumn("label", split(col("categories"), " ").getItem(0)) \
          .select("abstract", "label") \
          .na.drop()

df_ml.show(5, truncate=False)


In [8]:
# Filter to top N categories and limit dataset size
TOP_N = 10
SAMPLE_SIZE = 20000

top_labels = (
    df_ml.groupBy("label")
         .agg(count("*").alias("count"))
         .orderBy(col("count").desc())
         .limit(TOP_N)
         .select("label")
)

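# Note: limit() keeps whichever rows Spark reaches first, not a random sample,
# so some of the TOP_N labels can disappear (only 8 of 10 survive in the run below)
df_ml = df_ml.join(top_labels, on="label", how="inner").limit(SAMPLE_SIZE)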

# Show unique labels
print(f"Selected {TOP_N} categories:")
df_ml.select("label").distinct().show(truncate=False)

Selected 10 categories:




+-----------------+
|label            |
+-----------------+
|hep-ph           |
|cond-mat.mes-hall|
|gr-qc            |
|cond-mat.mtrl-sci|
|astro-ph         |
|hep-th           |
|cond-mat.str-el  |
|quant-ph         |
+-----------------+
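`limit(SAMPLE_SIZE)` returns whichever rows Spark's plan reaches first, not a random sample, which is how only 8 of the 10 requested labels survive above. Capping each label separately keeps every class. In Spark this could be written with `DataFrame.sampleBy` or with `row_number()` over `Window.partitionBy("label").orderBy(rand(seed))`; the pure-Python sketch below shows only the idea, on made-up rows stored in label order. `cap_per_label` is a hypothetical helper.

```python
import random
from collections import Counter, defaultdict

def cap_per_label(rows, per_label, seed=42):
    """Keep at most `per_label` randomly chosen rows for each label (hypothetical helper)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[1]].append(row)
    sample = []
    for label in sorted(by_label):  # sorted for a reproducible order
        group = by_label[label]
        rng.shuffle(group)
        sample.extend(group[:per_label])
    return sample

# Made-up rows stored in label order, as an ordered file scan might return them
toy_rows = [(f"abstract {i}", label) for label in "ABCDE" for i in range(1000)]

first_rows = toy_rows[:2000]  # what a plain limit(2000) over this order would keep
capped = cap_per_label(toy_rows, per_label=400)

print(sorted({label for _, label in first_rows}))  # ['A', 'B']
print(Counter(label for _, label in capped))       # 400 rows for each of A..E
```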




---

## 6. Approach 1: TF-IDF + Logistic Regression

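A classical baseline: abstracts are lowercased and split on whitespace, stop words are removed, term counts are hashed into 20,000 buckets (HashingTF), reweighted by inverse document frequency (IDF), and passed to a logistic regression classifier.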

In [9]:
# Build TF-IDF + Logistic Regression Pipeline
label_indexer = StringIndexer(inputCol="label", outputCol="labelIndex")
tokenizer = Tokenizer(inputCol="abstract", outputCol="words")
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
hashingTF = HashingTF(inputCol="filtered_words", outputCol="rawFeatures", numFeatures=20000)
idf = IDF(inputCol="rawFeatures", outputCol="features")
lr = LogisticRegression(featuresCol="features", labelCol="labelIndex", maxIter=20)

tfidf_pipeline = Pipeline(stages=[
    label_indexer,
    tokenizer,
    remover,
    hashingTF,
    idf,
    lr
])
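The HashingTF and IDF stages can be sketched in plain Python. Spark's `IDF` weights a term by log((m + 1) / (df + 1)) for m training documents, so a term present in every document gets weight 0. The sketch uses CRC32 from `zlib` as a stand-in hash (Spark's `HashingTF` uses MurmurHash3, so bucket positions will differ), only 16 buckets so that collisions are plausible, and skips stop-word removal; the helper names and toy documents are hypothetical.

```python
import math
import zlib
from collections import Counter

NUM_FEATURES = 16  # tiny bucket count for illustration; the pipeline uses 20000

def hashed_tf(tokens, num_features=NUM_FEATURES):
    """Term counts keyed by hash bucket (CRC32 stands in for MurmurHash3)."""
    return Counter(zlib.crc32(t.encode("utf-8")) % num_features for t in tokens)

def idf_weights(tf_vectors):
    """IDF per bucket using log((m + 1) / (df + 1)) for m documents."""
    m = len(tf_vectors)
    doc_freq = Counter(bucket for tf in tf_vectors for bucket in tf)
    return {b: math.log((m + 1) / (df + 1)) for b, df in doc_freq.items()}

docs = [
    ["dark", "matter", "halo"],
    ["dark", "energy", "survey"],
    ["dark", "matter", "simulation"],
]
tfs = [hashed_tf(d) for d in docs]
weights = idf_weights(tfs)
tfidf = [{b: n * weights[b] for b, n in tf.items()} for tf in tfs]

dark_bucket = zlib.crc32(b"dark") % NUM_FEATURES
print(weights[dark_bucket])  # 0.0: a term in every document carries no IDF weight
```

Because the weight is computed per bucket, a rare term that collides with a common one inherits the common term's low weight; a large `numFeatures` keeps such collisions rare.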

In [10]:
# Train/test split and model training
train, test = df_ml.randomSplit([0.8, 0.2], seed=42)

tfidf_model = tfidf_pipeline.fit(train)
tfidf_model.write().overwrite().save("models/logreg_text_classifier")

predictions = tfidf_model.transform(test)
predictions.select("abstract", "label", "prediction").show(5, truncate=False)

26/01/15 10:53:47 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
26/01/15 10:54:03 WARN TaskSetManager: Stage 107 contains a task of very large size (1297 KiB). The maximum recommended task size is 1000 KiB.
26/01/15 10:54:06 WARN DAGScheduler: Broadcasting large task binary with size 1660.8 KiB




In [11]:
# Evaluate TF-IDF model
evaluator_acc = MulticlassClassificationEvaluator(
    labelCol="labelIndex", predictionCol="prediction", metricName="accuracy"
)
# metricName="f1" is Spark's weighted F1 (per-class F1 weighted by true-label frequency)
evaluator_f1 = MulticlassClassificationEvaluator(
    labelCol="labelIndex", predictionCol="prediction", metricName="f1"
)

accuracy = evaluator_acc.evaluate(predictions)
f1_score = evaluator_f1.evaluate(predictions)

print(f"TF-IDF + LogReg Results:")
print(f"  Accuracy: {accuracy:.4f}")
print(f"  F1 Score: {f1_score:.4f}")

26/01/15 10:54:10 WARN DAGScheduler: Broadcasting large task binary with size 1716.2 KiB
26/01/15 10:54:14 WARN DAGScheduler: Broadcasting large task binary with size 1716.2 KiB

TF-IDF + LogReg Results:
  Accuracy: 0.7898
  F1 Score: 0.7921
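With `metricName="f1"`, Spark's evaluator reports the weighted F-measure: each class's F1 weighted by its share of the true labels. Weighted recall equals plain accuracy, so weighted F1 can sit slightly above or below accuracy, as in the 0.7921 vs 0.7898 result. A small pure-Python sketch of the calculation on toy labels, treating a never-predicted class as having precision 0:

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights proportional to each class's true count."""
    support = Counter(y_true)
    predicted = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    total = len(y_true)
    score = 0.0
    for label, n in support.items():
        precision = correct[label] / predicted[label] if predicted[label] else 0.0
        recall = correct[label] / n
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += (n / total) * f1
    return score

def plain_accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

y_true = ["a", "a", "a", "b"]
y_pred = ["a", "a", "b", "b"]
print(plain_accuracy(y_true, y_pred), round(weighted_f1(y_true, y_pred), 4))  # 0.75 0.7667
```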



In [12]:
# Confusion matrix
confusion = predictions.groupBy("labelIndex", "prediction") \
    .count() \
    .orderBy("labelIndex", "prediction")
confusion.show(50)

# Label mapping
labels = tfidf_model.stages[0].labels
print("\nLabel Index Mapping:")
for i, label in enumerate(labels):
    print(f"  {i}: {label}")

26/01/15 10:54:18 WARN DAGScheduler: Broadcasting large task binary with size 1683.4 KiB


+----------+----------+-----+
|labelIndex|prediction|count|
+----------+----------+-----+
|       0.0|       0.0| 1241|
|       0.0|       1.0|   45|
|       0.0|       2.0|   24|
|       0.0|       3.0|   16|
|       0.0|       4.0|   60|
|       0.0|       5.0|    7|
|       0.0|       6.0|   11|
|       0.0|       7.0|    9|
|       1.0|       0.0|   23|
|       1.0|       1.0|  493|
|       1.0|       2.0|   49|
|       1.0|       3.0|    8|
|       1.0|       4.0|    7|
|       1.0|       6.0|    5|
|       1.0|       7.0|    2|
|       2.0|       0.0|    6|
|       2.0|       1.0|   41|
|       2.0|       2.0|  383|
|       2.0|       3.0|   18|
|       2.0|       4.0|   51|
|       2.0|       5.0|    4|
|       2.0|       6.0|    2|
|       2.0|       7.0|    2|
|       3.0|       0.0|    3|
|       3.0|       1.0|    7|
|       3.0|       2.0|   20|
|       3.0|       3.0|  355|
|       3.0|       4.0|    8|
|       3.0|       5.0|   12|
|       3.0|       6.0|   25|
|       3.
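The long-format counts above can be pivoted into a matrix whose diagonal gives per-class recall. In Spark, `predictions.groupBy("labelIndex").pivot("prediction").count()` produces the wide form directly. The sketch below computes recall in plain Python from the rows for label indices 0 to 2, which are fully visible in the output; pairs with a zero count, such as (1, 5), are simply absent.

```python
from collections import defaultdict

# (labelIndex, prediction, count) rows for label indices 0-2, copied from the output above
confusion_rows = [
    (0, 0, 1241), (0, 1, 45), (0, 2, 24), (0, 3, 16),
    (0, 4, 60), (0, 5, 7), (0, 6, 11), (0, 7, 9),
    (1, 0, 23), (1, 1, 493), (1, 2, 49), (1, 3, 8),
    (1, 4, 7), (1, 6, 5), (1, 7, 2),
    (2, 0, 6), (2, 1, 41), (2, 2, 383), (2, 3, 18),
    (2, 4, 51), (2, 5, 4), (2, 6, 2), (2, 7, 2),
]

def per_class_recall(rows):
    """Recall for each true label: diagonal count divided by that label's row total."""
    row_totals = defaultdict(int)
    diagonal = defaultdict(int)
    for true_idx, pred_idx, n in rows:
        row_totals[true_idx] += n
        if true_idx == pred_idx:
            diagonal[true_idx] += n
    return {k: diagonal[k] / row_totals[k] for k in sorted(row_totals)}

recalls = per_class_recall(confusion_rows)
print({k: round(v, 3) for k, v in recalls.items()})  # {0: 0.878, 1: 0.84, 2: 0.755}
```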


In [13]:
# Example prediction
example = predictions.select("abstract", "label", "prediction").limit(1).collect()[0]

true_label = example["label"]
pred_index = int(example["prediction"])
pred_label = labels[pred_index]

print("ABSTRACT:")
print(example["abstract"][:500], "...")
print(f"\nTRUE CATEGORY: {true_label}")
print(f"PREDICTED CATEGORY: {pred_label}")



ABSTRACT:
  $UBVRI$ photometry and medium resolution optical spectroscopy of peculiar
Type Ia supernova SN 2005hk are presented and analysed, covering the
pre-maximum phase to around 400 days after explosion. The supernova is found to
be underluminous compared to "normal" Type Ia supernovae. The photometric and
spectroscopic evolution of SN 2005hk is remarkably similar to the peculiar Type
Ia event SN 2002cx. The expansion velocity of the supernova ejecta is found to
be lower than normal Type Ia events. T ...

TRUE CATEGORY: astro-ph
PREDICTED CATEGORY: astro-ph


26/01/15 10:54:22 WARN DAGScheduler: Broadcasting large task binary with size 1659.9 KiB

---

## 7. Approach 2: DistilBERT + ClassifierDL

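Deep learning approach with Spark NLP: a pre-trained DistilBERT model produces contextual token embeddings, `SentenceEmbeddings` averages them into one vector per abstract, and `ClassifierDLApproach` trains a neural network classifier on those vectors. DistilBERT itself serves as a frozen feature extractor and is not fine-tuned.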

In [14]:
# Prepare data for DistilBERT (smaller sample due to computational cost)
BERT_SAMPLE_SIZE = 200
BERT_TOP_N = 20

df_bert = df.select("abstract", "categories") \
            .withColumn("label", split(col("categories"), " ").getItem(0)) \
            .select(col("abstract").alias("text"), "label") \
            .na.drop()

bert_top_labels = (
    df_bert.groupBy("label")
           .agg(count("*").alias("count"))
           .orderBy(col("count").desc())
           .limit(BERT_TOP_N)
           .select("label")
)

# limit() is not a random sample here either; the training log reports 17 classes, not BERT_TOP_N
df_bert = df_bert.join(bert_top_labels, on="label", how="inner").limit(BERT_SAMPLE_SIZE)
train_bert, test_bert = df_bert.randomSplit([0.8, 0.2], seed=42)
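`randomSplit([0.8, 0.2], seed=42)` assigns rows to splits by random sampling rather than cutting at an exact 80/20 boundary, so split sizes vary; that is consistent with the training log reporting 169 training examples from the 200 sampled rows rather than 160. A simplified pure-Python sketch of independent per-row assignment follows (Spark actually samples per partition, so its sizes will not match these):

```python
import random

def bernoulli_split_sizes(n_rows, train_weight=0.8, seed=42):
    """Send each row to train with probability train_weight; return (train, test) sizes."""
    rng = random.Random(seed)
    n_train = sum(1 for _ in range(n_rows) if rng.random() < train_weight)
    return n_train, n_rows - n_train

for seed in range(3):
    print(seed, bernoulli_split_sizes(200, seed=seed))  # sizes scatter around (160, 40)
```

With only 200 rows, a few rows either way shifts the test set noticeably, which matters when comparing metrics across runs.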

In [15]:
# Build DistilBERT Pipeline
document = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

tokenizer_bert = SparkNLPTokenizer() \
    .setInputCols(["document"]) \
    .setOutputCol("token")

distilbert = DistilBertEmbeddings.pretrained("distilbert_base_uncased", "en") \
    .setInputCols(["document", "token"]) \
    .setOutputCol("embeddings")

sentence_embeddings = SentenceEmbeddings() \
    .setInputCols(["document", "embeddings"]) \
    .setOutputCol("sentence_embeddings") \
    .setPoolingStrategy("AVERAGE")

classifier = ClassifierDLApproach() \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("prediction") \
    .setLabelColumn("label") \
    .setBatchSize(8) \
    .setMaxEpochs(3) \
    .setLr(1e-3) \
    .setEnableOutputLogs(True)

bert_pipeline = Pipeline(stages=[
    document,
    tokenizer_bert,
    distilbert,
    sentence_embeddings,
    classifier
])

distilbert_base_uncased download started this may take some time.
Approximate size to download 235.8 MB

26/01/15 10:54:32 WARN S3AbortableInputStream: Not all bytes were read from the S3ObjectInputStream, aborting HTTP connection. This is likely an error and may result in sub-optimal behavior. Request only the bytes you need via a ranged GET or drain the input stream after use.
26/01/15 10:54:32 WARN S3AbortableInputStream: Not all bytes were read from the S3ObjectInputStream, aborting HTTP connection. This is likely an error and may result in sub-optimal behavior. Request only the bytes you need via a ranged GET or drain the input stream after use.


Download done! Loading the resource.
Using CPUs
[OK!]
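`SentenceEmbeddings` with the `AVERAGE` pooling strategy collapses each document's token-level DistilBERT vectors into one vector by taking the element-wise mean, and that single vector per abstract is what `ClassifierDLApproach` trains on. A pure-Python sketch of the pooling step with toy 4-dimensional vectors (DistilBERT-base vectors have 768 dimensions); `average_pool` is a hypothetical helper:

```python
def average_pool(token_vectors):
    """Element-wise mean of equal-length token vectors (hypothetical helper)."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    n = len(token_vectors)
    return [sum(vec[i] for vec in token_vectors) / n for i in range(dim)]

# Three hypothetical token embeddings for one short document
tokens = [
    [0.0, 1.0, 2.0, 3.0],
    [2.0, 1.0, 0.0, 3.0],
    [1.0, 1.0, 1.0, 0.0],
]
print(average_pool(tokens))  # [1.0, 1.0, 1.0, 2.0]
```

Averaging discards word order and can dilute a few discriminative terms in a long abstract, which is one reason a small classifier on pooled embeddings needs enough training data.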


In [16]:
# Train DistilBERT model
bert_model = bert_pipeline.fit(train_bert)
bert_model.write().overwrite().save("models/distilbert_classifier")

bert_predictions = bert_model.transform(test_bert)
bert_predictions.select(
    col("text"),
    col("label"),
    col("prediction.result")[0].alias("predicted_label")
).show(5, truncate=80)

2026-01-15 10:55:47.969201: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /tmp/ffa7b793c862_classifier_dl15392169626878701214
2026-01-15 10:55:48.117838: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:107] Reading meta graph with tags { serve }
2026-01-15 10:55:48.118034: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:148] Reading SavedModel debug info (if present) from: /tmp/ffa7b793c862_classifier_dl15392169626878701214
2026-01-15 10:55:48.119540: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-15 10:55:48.882626: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2

Training started - epochs: 3 - learning_rate: 0.001 - batch_size: 8 - training_examples: 169 - classes: 17
Epoch 1/3 - 3.28s - loss: 59.748936 - acc: 0.23809524 - batches: 22
Epoch 2/3 - 0.11s - loss: 58.44898 - acc: 0.25 - batches: 22
Epoch 3/3 - 0.12s - loss: 58.44899 - acc: 0.25 - batches: 22



+--------------------------------------------------------------------------------+--------+---------------+
|                                                                            text|   label|predicted_label|
+--------------------------------------------------------------------------------+--------+---------------+
|  Aims and Methods: We present the results of VLBI observations of nineteen\n...|astro-ph|       astro-ph|
|  By combining high-resolution HST and wide-field ground based observations, ...|astro-ph|       astro-ph|
|  Common envelopes form in dynamical time scale mass exchange, when the\nenve...|astro-ph|       astro-ph|
|  Gamma-Ray Bursts (GRBs) have been detected at GeV energies by EGRET and\nmo...|astro-ph|       astro-ph|
|  Some of the means through which the possible presence of nearly deconfined\...|astro-ph|       astro-ph|
+--------------------------------------------------------------------------------+--------+---------------+
only showing top 5 rows




In [18]:
# Evaluate DistilBERT model
from pyspark.ml.feature import StringIndexerModel

# Extract predicted label from Spark NLP annotation
bert_predictions_eval = bert_predictions.withColumn(
    "predicted_label",
    expr("prediction.result[0]")
)

# Build ONE label-to-index mapping shared by true and predicted labels.
# Two separately fitted StringIndexers each order labels by frequency in their
# own column, so the same index could stand for different categories.
label_values = set()
for column in ("label", "predicted_label"):
    label_values.update(
        row[0] for row in bert_predictions_eval.select(column).distinct().collect()
    )
label_values.discard(None)  # a missing prediction must not become a label
shared_labels = sorted(label_values)

# handleInvalid="keep" sends anything outside the mapping to an extra index
label_indexer_model = StringIndexerModel.from_labels(
    shared_labels, inputCol="label", outputCol="label_index", handleInvalid="keep"
)
pred_indexer_model = StringIndexerModel.from_labels(
    shared_labels, inputCol="predicted_label", outputCol="predicted_index", handleInvalid="keep"
)
bert_predictions_indexed = pred_indexer_model.transform(
    label_indexer_model.transform(bert_predictions_eval)
)

# Now evaluate with numeric columns ('f1' is Spark's weighted F1)
bert_evaluator_acc = MulticlassClassificationEvaluator(
    labelCol="label_index", predictionCol="predicted_index", metricName="accuracy"
)
bert_evaluator_f1 = MulticlassClassificationEvaluator(
    labelCol="label_index", predictionCol="predicted_index", metricName="f1"
)

bert_accuracy = bert_evaluator_acc.evaluate(bert_predictions_indexed)
bert_f1 = bert_evaluator_f1.evaluate(bert_predictions_indexed)

print(f"DistilBERT + ClassifierDL Results:")
print(f"  Accuracy: {bert_accuracy:.4f}")
print(f"  F1 Score: {bert_f1:.4f}")


DistilBERT + ClassifierDL Results:
  Accuracy: 0.4516
  F1 Score: 0.2810
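Fitting one StringIndexer on the true labels and a second one on the predicted labels can produce two different label-to-index mappings, because each orders labels by frequency within its own column. Comparing indices from different mappings then miscounts hits. A toy pure-Python demonstration, mimicking frequency-descending ordering with ties broken alphabetically (an assumption about tie handling); `frequency_index` and `index_accuracy` are hypothetical helpers:

```python
from collections import Counter

def frequency_index(values):
    """Map strings to indices by descending frequency, ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    return {s: i for i, s in enumerate(ordered)}

def index_accuracy(y_true, y_pred, true_map, pred_map):
    """Fraction of rows whose mapped true and predicted indices agree."""
    hits = sum(true_map[t] == pred_map[p] for t, p in zip(y_true, y_pred))
    return hits / len(y_true)

y_true = ["a", "a", "a", "b", "b", "c"]
y_pred = ["b", "b", "b", "b", "a", "c"]

# Separate mappings: 'a' -> 0 for the truth, but 'b' -> 0 for the predictions
separate = index_accuracy(y_true, y_pred, frequency_index(y_true), frequency_index(y_pred))

# One mapping built from both columns and applied to both
shared_map = frequency_index(y_true + y_pred)
shared = index_accuracy(y_true, y_pred, shared_map, shared_map)

print(separate, shared)  # 0.8333333333333334 0.3333333333333333
```

Only the shared mapping agrees with the string-level accuracy (2 of 6 rows match).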


                                                                                
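The DistilBERT scores match what a classifier that labels every test abstract `astro-ph` would get. If the test split holds 31 rows (200 sampled rows minus the 169 training examples in the log) and 14 of them are `astro-ph`, constant prediction gives accuracy p = 14/31 and weighted F1 = p * 2p / (1 + p), because the predicted class has precision p and recall 1 while every other class scores 0. This fits the flat training loss and the all-`astro-ph` predictions shown earlier. A sketch of the arithmetic (the 14 of 31 split is inferred, not printed by the notebook):

```python
def constant_predictor_scores(n_majority, n_total):
    """Accuracy and weighted F1 when every example is assigned the majority class."""
    p = n_majority / n_total
    # Predicted class: precision = p, recall = 1, so its F1 is 2p / (1 + p).
    # Every other class is never predicted, so its F1 is 0 and adds nothing.
    majority_f1 = 2 * p / (1 + p)
    return p, p * majority_f1  # weighted F1 = majority share * majority-class F1

const_acc, const_f1 = constant_predictor_scores(14, 31)  # 14 of 31 is inferred
print(f"{const_acc:.4f} {const_f1:.4f}")  # 0.4516 0.2810
```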