In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col,split,count, size, explode, length, when

# Start (or reuse, via getOrCreate) the Spark session used by the whole notebook.
spark = (
    SparkSession.builder
    .appName("arxivCategoryEDA")
    .getOrCreate()
)

# Load the raw arXiv metadata snapshot; spark.read.json expects one JSON object per line by default.
df = spark.read.json("../../../data/arxiv-metadata-oai-snapshot.json")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/04 01:29:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/02/04 01:29:16 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
                                                                                

In [9]:
from pyspark.sql.functions import (
    col, split, count, length, avg,
    array, lit, to_date, size, array_intersect,
)
# Import Spark's min/max under aliases so Python's builtin min()/max()
# are not shadowed for the rest of the notebook.
from pyspark.sql.functions import max as spark_max, min as spark_min

# 1. Configuration: a paper must carry ALL of these categories
REQUIRED_CATEGORIES = ["cs.AI", "cs.LG"]
# Compared against `update_date`.
# NOTE(review): that column looks like a last-update date rather than the
# first-submission date — confirm this is the intended cutoff.
START_DATE = "2025-01-01"

# 2. Create the array column literal
req_cats_col = array([lit(c) for c in REQUIRED_CATEGORIES])
# array_intersect returns DISTINCT elements, so compare against the number of
# distinct required categories; a duplicate in the config would otherwise make
# the size check fail for every paper.
n_required = len(set(REQUIRED_CATEGORIES))

# 3. Apply filtering logic: keep a paper only if every required category is
#    present (intersection size == number of required categories) and it was
#    updated on/after START_DATE.
filtered_df = (
    df.withColumn("cat_array", split(col("categories"), " "))
    .filter(
        (size(array_intersect(col("cat_array"), req_cats_col)) == n_required)
        & (to_date(col("update_date")) >= to_date(lit(START_DATE)))
    )
)

# 4. Abstract-length statistics for the matching papers
#    (count("*") counts rows; min/max/avg skip NULL abstracts).
stats_df = (
    filtered_df
    .select(length(col("abstract")).alias("abs_len"))
    .agg(
        count("*").alias("total_abstracts"),
        spark_min("abs_len").alias("min_length"),
        spark_max("abs_len").alias("max_length"),
        avg("abs_len").alias("avg_length"),
    )
)

stats_df.show()



+---------------+----------+----------+------------------+
|total_abstracts|min_length|max_length|        avg_length|
+---------------+----------+----------+------------------+
|          19947|        87|      2231|1324.6611520529402|
+---------------+----------+----------+------------------+



                                                                                

First 20 rows of the `categories` column (the default for `show()`)

In [None]:
# Peek at the raw, space-separated `categories` strings.
df.select("categories").show()

+-----------------+
|       categories|
+-----------------+
|           hep-ph|
|    math.CO cs.CG|
|   physics.gen-ph|
|          math.CO|
|  math.CA math.FA|
|cond-mat.mes-hall|
|            gr-qc|
|cond-mat.mtrl-sci|
|         astro-ph|
|          math.CO|
|  math.NT math.AG|
|          math.NT|
|          math.NT|
|  math.CA math.AT|
|           hep-th|
|           hep-ph|
|         astro-ph|
|           hep-th|
|  math.PR math.AG|
|           hep-ex|
+-----------------+
only showing top 20 rows



Explode Categories

In [None]:
# One row per (paper, category) pair: split the space-separated list, then explode it.
category_list = split(col("categories"), " ")
cat_df = df.withColumn("category", explode(category_list))


Category Frequency

In [None]:
# Number of papers carrying each individual category, most frequent first.
cat_freq = (
    cat_df
    .groupBy("category")
    .count()
    .orderBy("count", ascending=False)
)
cat_freq.show(20)



+------------------+------+
|          category| count|
+------------------+------+
|             cs.LG|242508|
|            hep-ph|191810|
|            hep-th|177961|
|             cs.CV|173824|
|          quant-ph|170398|
|             cs.AI|151883|
|             gr-qc|117690|
|          astro-ph|105380|
| cond-mat.mtrl-sci|104154|
| cond-mat.mes-hall| 98338|
|             cs.CL| 97377|
|           math.MP| 86880|
|           math-ph| 86880|
|   cond-mat.str-el| 80258|
|cond-mat.stat-mech| 79055|
|           math.CO| 74707|
|       astro-ph.CO| 74381|
|           stat.ML| 74330|
|       astro-ph.GA| 73745|
|           math.AP| 70446|
+------------------+------+
only showing top 20 rows



                                                                                

Long-tail detection: how many categories appear on fewer than 100 papers

In [None]:
# Categories used by fewer than this many papers count as the long tail.
LONG_TAIL_THRESHOLD = 100
cat_freq.where(col("count") < LONG_TAIL_THRESHOLD).count()


                                                                                

4

Number of category labels per paper

In [None]:
# Distribution of how many categories each paper lists.
labels_per_paper = size(split(col("categories"), " "))
(
    df.withColumn("num_labels", labels_per_paper)
    .groupBy("num_labels")
    .count()
    .orderBy("num_labels")
    .show()
)




+----------+-------+
|num_labels|  count|
+----------+-------+
|         1|1514466|
|         2| 868892|
|         3| 352262|
|         4| 113661|
|         5|  32690|
|         6|   7126|
|         7|   1013|
|         8|    171|
|         9|     34|
|        10|     14|
|        11|      2|
|        13|      1|
+----------+-------+



                                                                                

In [3]:
# Drop rows that are identical in every column, then report how many rows remain.
df = df.distinct()
df.count()

                                                                                

2890321

Extract first category (single-label baseline)

In [2]:
# Single-label baseline: treat the first listed category as the paper's primary one.
df = df.withColumn("primary_category", split(col("categories"), " ")[0])


Extract the archive prefix (the part of the primary category before the first dot)

In [3]:
# Archive prefix = text before the first "." (e.g. "math.CO" -> "math"); categories
# without a dot such as "hep-ph" are kept whole. split() takes a regex, so the dot
# must be escaped; a raw string avoids Python's invalid-escape-sequence warning
# that the plain literal "\." triggers.
df = df.withColumn(
    "category_prefix",
    split(col("primary_category"), r"\.").getItem(0),
)
df.show(5)

+--------------------+--------------------+--------------------+---------------+--------------------+--------------------+---------+--------------------+--------------------+----------------+------------------+--------------------+-----------+--------------------+----------------+---------------+
|            abstract|             authors|      authors_parsed|     categories|            comments|                 doi|       id|         journal-ref|             license|       report-no|         submitter|               title|update_date|            versions|primary_category|category_prefix|
+--------------------+--------------------+--------------------+---------------+--------------------+--------------------+---------+--------------------+--------------------+----------------+------------------+--------------------+-----------+--------------------+----------------+---------------+
|  A fully differe...|C. Bal\'azs, E. L...|[[Balázs, C., ], ...|         hep-ph|37 pages, 15 figu...|10.11

In [None]:
# Sanity check: are there any papers whose primary archive prefix is "chem"?
chem_filter = col("category_prefix") == "chem"
df.where(chem_filter).count()

                                                                                

0

In [4]:
from pyspark.sql.functions import when, col

# Coarse label for each arXiv archive prefix, checked in this order; any prefix
# not listed (or a NULL prefix) falls through to "other".
# NOTE(review): physics archives without a "physics." prefix (cond-mat, astro-ph,
# hep-*, quant-ph, gr-qc, ...) end up in "other"; the "other" breakdown shows they
# dominate that class. The "chem" entry currently matches nothing (the chem check
# returned 0).
PREFIX_TO_LABEL = {
    "cs": "cs",
    "math": "math",
    "physics": "physics",
    "stat": "stat",
    "q-bio": "bio",
    "econ": "econ",
    "eess": "engineering",
    "chem": "chemistry",
}

# Build when(...).when(...)... in dict order, then close with the fallback label.
label_expr = None
for prefix, label in PREFIX_TO_LABEL.items():
    is_prefix = col("category_prefix") == prefix
    label_expr = (
        when(is_prefix, label)
        if label_expr is None
        else label_expr.when(is_prefix, label)
    )

df = df.withColumn("final_category", label_expr.otherwise("other"))

In [5]:
# Spot-check the assigned label for a few papers.
df.select("id", "title", "final_category").show(5)

+---------+--------------------+--------------+
|       id|               title|final_category|
+---------+--------------------+--------------+
|0704.0001|Calculation of pr...|         other|
|0704.0002|Sparsity-certifyi...|          math|
|0704.0003|The evolution of ...|       physics|
|0704.0004|A determinant of ...|          math|
|0704.0005|From dyadic $\Lam...|          math|
+---------+--------------------+--------------+
only showing top 5 rows



In [None]:
# Papers per final_category, largest class first.
class_counts = df.groupBy("final_category").count()
class_counts.orderBy(col("count").desc()).show(truncate=False)




+--------------+-------+
|final_category|count  |
+--------------+-------+
|other         |1273379|
|cs            |691231 |
|math          |556370 |
|physics       |201469 |
|engineering   |66850  |
|stat          |58160  |
|bio           |32717  |
|econ          |10145  |
+--------------+-------+



                                                                                

PERCENTAGE DISTRIBUTION OF FINAL CATEGORIES

In [8]:
# Share of all rows falling into each final_category, in percent.
total = df.count()

count_col = col("count")
dist_df = (
    df.groupBy("final_category")
    .count()
    .withColumn("percentage", count_col / total * 100)
    .orderBy(count_col.desc())
)

dist_df.show(truncate=False)




+--------------+-------+-------------------+
|final_category|count  |percentage         |
+--------------+-------+-------------------+
|other         |1273379|44.05666360241648  |
|cs            |691231 |23.91537133764727  |
|math          |556370 |19.24941900916888  |
|physics       |201469 |6.970471445905144  |
|engineering   |66850  |2.312891889862752  |
|stat          |58160  |2.0122332432971977 |
|bio           |32717  |1.131950395821087  |
|econ          |10145  |0.35099907588119106|
+--------------+-------+-------------------+



                                                                                

Percentage distribution of category prefixes within the "other" class

In [None]:
from pyspark.sql.functions import sum as spark_sum

# Break the catch-all "other" class down by archive prefix to see what it contains.
other_df = df.where(col("final_category") == "other")
total_other = other_df.count()

prefix_counts = other_df.groupBy("category_prefix").count()
other_dist = prefix_counts.withColumn(
    "percent", col("count") * 100 / total_other
).orderBy(col("percent").desc())

other_dist.show(50, truncate=False)




+---------------+------+---------------------+
|category_prefix|count |percent              |
+---------------+------+---------------------+
|cond-mat       |338294|26.56663884043949    |
|astro-ph       |326072|25.6068303309541     |
|hep-ph         |139714|10.971910169713809   |
|quant-ph       |124246|9.75718933640338     |
|hep-th         |110761|8.698195902398265    |
|gr-qc          |69205 |5.434752732689954    |
|nucl-th        |34987 |2.7475716185047814   |
|math-ph        |33378 |2.6212148936019832   |
|hep-ex         |24463 |1.9211091120554054   |
|nlin           |20283 |1.5928486334390626   |
|hep-lat        |18750 |1.4724602808747436   |
|q-fin          |12827 |1.0073198945482846   |
|nucl-ex        |12196 |0.9577666978959132   |
|chao-dyn       |1770  |0.1390002505145758   |
|alg-geom       |1209  |0.09494423891080346  |
|q-alg          |1177  |0.0924312400314439   |
|cmp-lg         |894   |0.07020690619210777  |
|solv-int       |844   |0.06628034544310846  |
|dg-ga       

                                                                                

DROP THE NO-LONGER-NEEDED `categories` AND `category_prefix` COLUMNS

In [5]:
# Only primary_category and final_category are needed from here on.
COLUMNS_TO_DROP = ["categories", "category_prefix"]
df = df.drop(*COLUMNS_TO_DROP)


In [6]:
# Confirm the remaining columns after dropping `categories` and `category_prefix`.
df.printSchema()

root
 |-- abstract: string (nullable = true)
 |-- authors: string (nullable = true)
 |-- authors_parsed: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: string (containsNull = true)
 |-- comments: string (nullable = true)
 |-- doi: string (nullable = true)
 |-- id: string (nullable = true)
 |-- journal-ref: string (nullable = true)
 |-- license: string (nullable = true)
 |-- report-no: string (nullable = true)
 |-- submitter: string (nullable = true)
 |-- title: string (nullable = true)
 |-- update_date: string (nullable = true)
 |-- versions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- created: string (nullable = true)
 |    |    |-- version: string (nullable = true)
 |-- primary_category: string (nullable = true)
 |-- final_category: string (nullable = false)



DATASET FOR TABLEAU

In [10]:
from pyspark.sql.functions import col, concat_ws, explode_outer, to_timestamp, year

# Columns needed for the Tableau export.
TABLEAU_COLUMNS = ["id", "final_category", "authors_parsed", "versions", "submitter", "update_date"]

# Keep papers with a known submitter and cap the extract at 60,000 papers.
# NOTE(review): limit() without an orderBy picks an arbitrary subset (the preview
# starts at 2508.* ids, not the snapshot's first ids), so the extract is not
# guaranteed to be the same between runs.
df_small = (
    df.select(*TABLEAU_COLUMNS)
    .where(col("submitter").isNotNull())
    .limit(60000)
)

In [11]:
from pyspark.sql.functions import trim

# Flatten authors and versions. The two explode_outer calls are applied in
# sequence, so the result has one row per (paper, author, version) combination:
# each author row is repeated once per version.
# NOTE(review): per-author or per-version counts in Tableau are inflated by this
# cross product; count distinct ids there.
df_exploded = (
    df_small
    .withColumn("author", explode_outer("authors_parsed"))
    .withColumn("version", explode_outer("versions"))
    # Entries look like [last, first, ...] in the preview (e.g. [Markant, Douglas, ]).
    .withColumn("last_name", col("author")[0])
    .withColumn("first_name", col("author")[1])
    .withColumn("version_no", col("version.version"))
    .withColumn("version_date", col("version.created"))
)

# concat_ws skips NULLs but not empty strings, so an empty first name would leave
# a leading space in front of the last name; trim() removes it.
df_exploded = df_exploded.withColumn(
    "author_name",
    trim(concat_ws(" ", col("first_name"), col("last_name"))),
)


In [15]:
df_exploded.show()



+----------+--------------+--------------------+--------------------+-----------------+-----------+--------------------+--------------------+---------+------------+----------+--------------------+--------------------+
|        id|final_category|      authors_parsed|            versions|        submitter|update_date|              author|             version|last_name|  first_name|version_no|        version_date|         author_name|
+----------+--------------+--------------------+--------------------+-----------------+-----------+--------------------+--------------------+---------+------------+----------+--------------------+--------------------+
|2508.00233|            cs|[[Markant, Dougla...|[{Fri, 01 Aug 202...|       Subham Sah| 2025-08-04|[Markant, Douglas, ]|{Fri, 01 Aug 2025...|  Markant|     Douglas|        v1|Fri, 01 Aug 2025 ...|     Douglas Markant|
|2508.00233|            cs|[[Markant, Dougla...|[{Fri, 01 Aug 202...|       Subham Sah| 2025-08-04|     [Sah, Subham, ]|{Fri, 01

                                                                                

In [12]:
df_exploded.printSchema()

root
 |-- id: string (nullable = true)
 |-- final_category: string (nullable = false)
 |-- authors_parsed: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: string (containsNull = true)
 |-- versions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- created: string (nullable = true)
 |    |    |-- version: string (nullable = true)
 |-- submitter: string (nullable = true)
 |-- update_date: string (nullable = true)
 |-- author: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- version: struct (nullable = true)
 |    |-- created: string (nullable = true)
 |    |-- version: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- version_no: string (nullable = true)
 |-- version_date: string (nullable = true)
 |-- author_name: string (nullable = false)



In [23]:
# Spark 3's default datetime parser does not accept day-of-week ('E') symbols
# when parsing, so fall back to the legacy parser for this pattern.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# Parse the raw version string (e.g. "Fri, 01 Aug 2025 ...") into a timestamp.
# The source is `version_date`; `version_ts` does not exist until this cell
# creates it, so parsing that name fails on a fresh kernel.
df_exploded = df_exploded.withColumn(
    "version_ts",
    to_timestamp(col("version_date"), "EEE, dd MMM yyyy HH:mm:ss z"),
)

In [14]:
# List the columns currently on df_exploded, before the nested ones are dropped.
df_exploded.columns

['id',
 'final_category',
 'authors_parsed',
 'versions',
 'submitter',
 'update_date',
 'author',
 'version',
 'last_name',
 'first_name',
 'version_no',
 'version_date',
 'author_name',
 'version_ts']

In [19]:
# Drop the nested/raw columns now that their fields have been flattened out.
NESTED_COLS = ["authors_parsed", "versions", "author", "version", "version_date"]
df_exploded = df_exploded.drop(*NESTED_COLS)

In [24]:
# Tableau export. coalesce(1) yields a single part file; Spark still writes a
# *directory* at this path (containing part-*.csv), not one file named *.csv.
(
    df_exploded
    .coalesce(1)
    .write
    .mode("overwrite")
    .option("header", True)
    .csv("../../../data/tableauProcessedCSV/arxiv_exploded.csv")
)

                                                                                

SAMPLING (stratified by final_category)

In [7]:
from pyspark.sql.functions import col

# Per-class sample sizes for the balanced dataset (they sum to 60,000 rows).
TARGETS = dict(
    cs=12000,
    math=10000,
    physics=8000,
    engineering=7000,
    stat=6000,
    bio=6000,
    econ=5000,
    other=6000,
)


In [8]:
def sample_class(df, label, target, seed=42, exact=False):
    """Randomly sample up to ``target`` rows of one ``final_category`` class.

    Args:
        df: Spark DataFrame with a ``final_category`` column.
        label: The ``final_category`` value to sample.
        target: Desired number of rows for this class.
        seed: Random seed passed to Spark, so repeated runs over the same input
            partitioning draw the same rows.
        exact: If False (default), use Bernoulli sampling via ``DataFrame.sample``;
            this only hits ``target`` in expectation (e.g. 12132 rows for a
            12000 target). If True, shuffle with ``rand(seed)`` and keep exactly
            ``target`` rows.

    Returns:
        Every row of the class when it has at most ``target`` rows, otherwise a
        random subset of it.
    """
    class_df = df.filter(col("final_category") == label)
    total = class_df.count()

    if total <= target:
        # Not enough samples -> keep all
        print(f"[INFO] Keeping all {total} samples for class '{label}'")
        return class_df

    if exact:
        from pyspark.sql.functions import rand

        print(f"[INFO] Sampling class '{label}': exactly {target}/{total}")
        return class_df.orderBy(rand(seed)).limit(target)

    fraction = target / total
    print(f"[INFO] Sampling class '{label}': {target}/{total} (fraction={fraction:.4f})")
    return class_df.sample(
        withReplacement=False,
        fraction=fraction,
        seed=seed
    )


In [9]:
# Draw one sample per class, in TARGETS order.
# NOTE(review): the 'other' total printed below (1273390) is 11 rows higher than the
# grouped count (1273379), matching the 11 rows removed by dropDuplicates
# (2890332 -> 2890321). These outputs look like they came from a run without the
# dedup step; re-run the notebook top to bottom.
sampled_dfs = [
    sample_class(df, label, target)
    for label, target in TARGETS.items()
]


                                                                                

[INFO] Sampling class 'cs': 12000/691231 (fraction=0.0174)


                                                                                

[INFO] Sampling class 'math': 10000/556370 (fraction=0.0180)


                                                                                

[INFO] Sampling class 'physics': 8000/201469 (fraction=0.0397)


                                                                                

[INFO] Sampling class 'engineering': 7000/66850 (fraction=0.1047)


                                                                                

[INFO] Sampling class 'stat': 6000/58160 (fraction=0.1032)


                                                                                

[INFO] Sampling class 'bio': 6000/32717 (fraction=0.1834)


                                                                                

[INFO] Sampling class 'econ': 5000/10145 (fraction=0.4929)




[INFO] Sampling class 'other': 6000/1273390 (fraction=0.0047)


                                                                                

In [10]:
from functools import reduce

# Stack the per-class samples into one DataFrame; unionByName aligns columns by
# name rather than by position.
final_df = reduce(
    lambda combined, part: combined.unionByName(part),
    sampled_dfs,
)


In [11]:
from pyspark.sql.functions import count

# Realised class sizes; Bernoulli sampling only approximates each target.
row_count = count("*").alias("count")
final_dist = (
    final_df
    .groupBy("final_category")
    .agg(row_count)
    .orderBy(col("count").desc())
)

final_dist.show(truncate=False)




+--------------+-----+
|final_category|count|
+--------------+-----+
|cs            |12132|
|math          |10082|
|physics       |7987 |
|engineering   |6991 |
|other         |6114 |
|bio           |5998 |
|stat          |5903 |
|econ          |4968 |
+--------------+-----+



                                                                                

In [12]:
# Persist the sampled dataset as Parquet with at most 8 part files.
OUTPUT_PATH = "../../../data/processed/arxiv_8class_60k.parquet"
final_df = final_df.coalesce(8)
final_df.write.mode("overwrite").parquet(OUTPUT_PATH)


                                                                                

In [14]:
# Preview five sampled rows. limit() runs first, so no "only showing top" footer is printed.
preview_df = final_df.limit(5)
preview_df.show()

+--------------------+--------------------+--------------------+--------------------+--------------------+---------+--------------------+--------------------+---------+-------------------+--------------------+-----------+--------------------+----------------+--------------+
|            abstract|             authors|      authors_parsed|            comments|                 doi|       id|         journal-ref|             license|report-no|          submitter|               title|update_date|            versions|primary_category|final_category|
+--------------------+--------------------+--------------------+--------------------+--------------------+---------+--------------------+--------------------+---------+-------------------+--------------------+-----------+--------------------+----------------+--------------+
|  We illustrate t...|Shenghui Su, and ...|[[Su, Shenghui, ]...|14 pages, and 2 t...|                NULL|0704.0492|                NULL|http://arxiv.org/...|     NULL|       