# DistilBERT Embedding Generation (Spark NLP)

This notebook generates document-level embeddings for arXiv papers
using **DistilBERT** and Spark NLP.

Pipeline overview:
1. Load cleaned abstracts from HDFS
2. Convert text to Spark NLP documents
3. Tokenize text
4. Generate DistilBERT embeddings (token-level)
5. Apply mean pooling to obtain one vector per paper
6. Save embeddings to HDFS for later reuse (similarity search, recommendation)

This step is **computationally expensive, so it is run once** and the resulting embeddings are persisted for reuse.


## 1. Start Spark Session with Spark NLP


In [1]:
from pyspark.sql import SparkSession
import sparknlp

spark = (
    SparkSession.builder
    .appName("Arxiv-Embeddings")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.5.0")
    .config("spark.driver.memory", "16g")
    .config("spark.executor.memory", "8g")
    .config("spark.sql.shuffle.partitions", "32")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")

print("Spark version:", spark.version)
print("Spark NLP version:", sparknlp.version())


Spark version: 3.5.0
Spark NLP version: 5.5.0


## 2. Load cleaned abstracts from HDFS

We load only the columns needed for embeddings:
- `id`: paper identifier
- `abstract`: cleaned abstract text (renamed to `text`, the input column the pipeline expects)


In [2]:
from pyspark.sql.functions import col

df = (
    spark.read.parquet("hdfs:///arxiv/clean")
    .select(
        col("id"),
        col("abstract").alias("text")
    )
    .na.drop()
)

df.printSchema()
df.show(3, truncate=False)


root
 |-- id: string (nullable = true)
 |-- text: string (nullable = true)

## 3. Spark NLP Pipeline (DistilBERT)

Pipeline stages:
- **DocumentAssembler**: converts raw text → NLP document
- **Tokenizer**: splits document into tokens
- **DistilBertEmbeddings**: generates contextual embeddings for each token

Output:
- One embedding vector **per token**


In [3]:
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import Tokenizer, DistilBertEmbeddings
from pyspark.ml import Pipeline

document = (
    DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")
)

tokenizer = (
    Tokenizer()
    .setInputCols(["document"])
    .setOutputCol("token")
)

distilbert = (
    DistilBertEmbeddings.pretrained(
        "distilbert_base_uncased",
        "en"
    )
    .setInputCols(["document", "token"])
    .setOutputCol("embeddings")
)

pipeline = Pipeline(stages=[document, tokenizer, distilbert])


distilbert_base_uncased download started this may take some time.
Approximate size to download 235.8 MB
Download done! Loading the resource.
Using CPUs
[OK!]


## 4. Generate token-level embeddings with DistilBERT

In this step, we run the Spark NLP pipeline that includes:
- `DocumentAssembler`
- `Tokenizer`
- `DistilBertEmbeddings`

We start with a 2,000-row sample to keep memory use low and avoid JVM crashes; raise the limit once the pipeline runs reliably.


In [4]:
# Use a safe subset to avoid JVM crashes (adjust later)
df_sample = df.limit(2000)

# Run the embedding pipeline
embedded = pipeline.fit(df_sample).transform(df_sample)

# Verify execution
embedded.select("id").show(5)


+---------+
|       id|
+---------+
|0704.0001|
|0704.0002|
|0704.0003|
|0704.0004|
|0704.0005|
+---------+
only showing top 5 rows

## 5. Mean Pooling (Document-Level Embeddings)

DistilBERT outputs one embedding per token.

For similarity search, nearest-neighbor queries, and clustering, however, we need one fixed-length vector per document.

Solution: **mean pooling**. We average all token vectors of an abstract into one document vector.

Instead of a Python UDF, which has to ship every token vector from the JVM to Python workers and proved unstable, we use Spark NLP's built-in pooling.
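
To make the averaging concrete, here is a minimal plain-Python sketch of mean pooling on toy data. It does not involve Spark and is not how Spark NLP implements pooling internally; the function name `mean_pool` and the 4-dimensional vectors are illustrative assumptions.

```python
from typing import List


def mean_pool(token_vectors: List[List[float]]) -> List[float]:
    """Average token vectors element-wise into one fixed-length vector."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    if any(len(vec) != dim for vec in token_vectors):
        raise ValueError("all token vectors must have the same length")
    count = len(token_vectors)
    # For each dimension, sum across tokens and divide by the token count.
    return [sum(vec[i] for vec in token_vectors) / count for i in range(dim)]


# Toy input (hypothetical): three "tokens", each with a 4-dimensional vector.
tokens = [
    [1.0, 2.0, 3.0, 4.0],
    [3.0, 2.0, 1.0, 0.0],
    [2.0, 2.0, 2.0, 2.0],
]

doc_vector = mean_pool(tokens)
print(doc_vector)  # [2.0, 2.0, 2.0, 2.0]
```

The output length equals the token-vector length regardless of how many tokens the abstract has, which is what makes the pooled vectors comparable across papers.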


### What this cell does

- Converts token embeddings → one vector per document
- Runs fully inside the Spark JVM
- Safe, fast, and scalable


In [5]:
from sparknlp.annotator import SentenceEmbeddings

sentence_pool = (
    SentenceEmbeddings()
    .setInputCols(["document", "embeddings"])
    .setOutputCol("sentence_embedding")
    .setPoolingStrategy("AVERAGE")
)

# Extend the token-level pipeline with the pooling stage
# (Pipeline was already imported in cell 3)

pooling_pipeline = Pipeline(
    stages=pipeline.getStages() + [sentence_pool]
)

embedded_pooled = pooling_pipeline.fit(df_sample).transform(df_sample)

embedded_pooled.select("id").show(3)


+---------+
|       id|
+---------+
|0704.0001|
|0704.0002|
|0704.0003|
+---------+
only showing top 3 rows

## 6. Create final document-level embeddings

Each row becomes:
- `id`: paper identifier
- `embedding`: vector representation of the abstract


In [6]:
from pyspark.sql.functions import col

embeddings_df = embedded_pooled.select(
    col("id"),
    col("sentence_embedding.embeddings")[0].alias("embedding")
)

embeddings_df.printSchema()



root
 |-- id: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = false)
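
The `[0]` in `col("sentence_embedding.embeddings")[0]` is needed because a Spark NLP output column holds an array of annotations, so selecting the `embeddings` field gives an array of vectors. The plain-Python mock below mirrors that shape for a single row. The id, text, metadata, and 3-dimensional vector are made-up placeholders (real DistilBERT base vectors have 768 dimensions), and the sketch assumes one annotation per row, which should hold here because the pipeline has no sentence splitter.

```python
# Hypothetical row shaped like the pooled Spark NLP output (toy values only).
row = {
    "id": "toy-0001",
    "sentence_embedding": [
        {
            "annotatorType": "sentence_embeddings",
            "begin": 0,
            "end": 40,
            "result": "toy abstract text standing in for a paper",
            "metadata": {"sentence": "0"},
            "embeddings": [0.12, -0.03, 0.57],
        }
    ],
}

# Equivalent of col("sentence_embedding.embeddings"): one vector per annotation.
all_vectors = [annotation["embeddings"] for annotation in row["sentence_embedding"]]

# Equivalent of [0]: keep the single document-level vector.
embedding = all_vectors[0]

print(len(all_vectors), embedding)
```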



## 7. Save embeddings to HDFS (DistilBERT)

We store the embeddings in **HDFS** to:
- Avoid recomputation
- Enable reuse across notebooks
- Support similarity search and recommendation

Format: **Parquet**
- Columnar
- Compressed
- Fast loading


In [7]:
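# embeddings_df derives from df_sample, so only the 2,000-paper sample is written here.
# coalesce(1) yields a single Parquet file; remove it when embedding the full dataset.
embeddings_df.coalesce(1).write.mode("overwrite").parquet(
    "hdfs:///arxiv/embeddings/distilbert"
)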
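
As a rough idea of how a later notebook might use the saved vectors, the sketch below ranks papers by cosine similarity to a query vector. It is plain Python on toy data, not a scalable implementation; the helper names `cosine_similarity` and `top_k`, the paper ids, and the 3-dimensional vectors are all hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


def top_k(
    query: Sequence[float], corpus: Dict[str, Sequence[float]], k: int = 2
) -> List[Tuple[str, float]]:
    """Return the k (id, score) pairs most similar to the query."""
    scored = [(pid, cosine_similarity(query, vec)) for pid, vec in corpus.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy corpus (hypothetical ids and vectors).
corpus = {
    "paper_a": [1.0, 0.0, 0.0],
    "paper_b": [0.9, 0.1, 0.0],
    "paper_c": [0.0, 1.0, 0.0],
}

results = top_k([1.0, 0.0, 0.0], corpus, k=2)
print(results)
```

For the full corpus, a brute-force scan like this would likely be too slow, so an approximate nearest-neighbor index may be a better fit.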


                                                                                