Replacement Assignment for Daan Steur

Assignment
You will need to parse one of the PubMed XML files located in the /data/datasets/NCBI/PubMed/ directory. They contain all the information about articles published in a particular time period.

the script
I would like you to write a script that;

Parses a PubMed XML file into a PySpark dataframe.
The dataframe should contain the following information in the columns;

- PubMed ID
- First Author
- Last Author
- Year published
- Title
- Journal Title
- Length of Abstract (if Abstract text is present).

A column of references in a list variable, if references are present for the article.
When you have the table, please use PySpark functions to compute:

- Number of articles per First Author
- Number of articles per Year
- Minimum, maximum, Average length of an abstract
- Average Number of articles per Journal Title per Year

### Load needed packages and data into a spark dataframe

In [8]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year, length, avg
from pyspark.sql.types import ArrayType, StringType, IntegerType, FloatType, StructType, StructField
import csv

In [9]:

# Function to create and configure a Spark session
def create_spark_session():
    """
    Creates and configures a Spark session with specific settings.

    Returns:
        SparkSession: A configured Spark session.
    """
    return SparkSession.builder.master("local[16]") \
        .config('spark.driver.memory', '128g') \
        .config('spark.executor.memory', '128g') \
        .config("spark.sql.debug.maxToStringFields", "100") \
        .appName("PubMedAnalysis").getOrCreate()

# User-Defined Functions (UDFs)

# UDF to extract references as a list
def extract_references(ref_list):
    """
    Extracts references from a list and returns them as a Python list.

    Args:
        ref_list (list): List of references in XML format.

    Returns:
        list: List of references as text, or an empty list if none are present.
    """
    return [ref[0].text for ref in ref_list] if ref_list is not None else []

# UDF to extract the year from PubDate
def extract_year(pub_date):
    """
    Extracts the year from a PubDate element.

    Args:
        pub_date (Element): PubDate element from XML.

    Returns:
        int or None: Year as an integer or None if not found.
    """
    return int(pub_date.Year[0].text) if pub_date is not None and pub_date.Year is not None else None

# Function to parse PubMed XML into a PySpark DataFrame
def parse_pubmed_xml(spark, xml_file_path):
    """
    Parses a PubMed XML file into a PySpark DataFrame with specified columns.

    Args:
        spark (SparkSession): Spark session.
        xml_file_path (str): Path to the XML file.

    Returns:
        DataFrame: PySpark DataFrame containing parsed data.
    """
    # Define the schema for the XML data
    xml_schema = StructType([
        StructField("PubMed_ID", StringType(), True),
        StructField("First_Author", StringType(), True),
        StructField("Last_Author", StringType(), True),
        StructField("Year_Published", StringType(), True),
        StructField("Title", StringType(), True),
        StructField("Journal_Title", StringType(), True),
        StructField("Abstract_Length", StringType(), True),
        StructField("References", StringType(), True)
    ])
    
    # Read the XML file using spark-xml and the specified schema
    df = spark.read \
        .format("com.databricks.spark.xml") \
        .option("rootTag", "PubmedArticle") \
        .option("rowTag", "PubmedArticle") \
        .schema(xml_schema) \
        .load(xml_file_path)
    
    # Perform data transformations if needed
    return df


# Path to the PubMed XML file
pubmed_path = "/data/datasets/NCBI/PubMed/pubmed21n0001.xml"

# Create a Spark session
spark = create_spark_session()

# Parse PubMed XML into a DataFrame
pubmed_df = parse_pubmed_xml(spark, pubmed_path)

# Display the first 5 rows of the DataFrame
pubmed_df.show(5, truncate=False)


Py4JJavaError: An error occurred while calling o71.load.
: org.apache.spark.SparkClassNotFoundException: [DATA_SOURCE_NOT_FOUND] Failed to find the data source: com.databricks.spark.xml. Please find packages at `https://spark.apache.org/third-party-projects.html`.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.dataSourceNotFoundError(QueryExecutionErrors.scala:738)
	at org.apache.spark.sql.execution.datasources.DataSource$.lookupDataSource(DataSource.scala:647)
	at org.apache.spark.sql.execution.datasources.DataSource$.lookupDataSourceV2(DataSource.scala:697)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:208)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:186)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.ClassNotFoundException: com.databricks.spark.xml.DefaultSource
	at java.base/java.net.URLClassLoader.findClass(URLClassLoader.java:476)
	at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:589)
	at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:522)
	at org.apache.spark.sql.execution.datasources.DataSource$.$anonfun$lookupDataSource$5(DataSource.scala:633)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.sql.execution.datasources.DataSource$.$anonfun$lookupDataSource$4(DataSource.scala:633)
	at scala.util.Failure.orElse(Try.scala:224)
	at org.apache.spark.sql.execution.datasources.DataSource$.lookupDataSource(DataSource.scala:633)
	... 15 more


## Anwser questions about the pubmed data

In [None]:
# Compute Number of articles per First Author
def compute_articles_per_first_author(pubmed_df):
    return pubmed_df.groupBy("First_Author").count()

aws1 = compute_articles_per_first_author(pubmed_df)


In [None]:
# Compute Number of articles per Year
def compute_articles_per_year(pubmed_df):
    return pubmed_df.groupBy("Year_Published").count()

aws2 = compute_articles_per_year(pubmed_df)

In [None]:
# Compute Minimum, Maximum, and Average length of an abstract
def compute_abstract_stats(pubmed_df):
    return pubmed_df.agg(
        {"Abstract_Length": "min", "Abstract_Length": "max", "Abstract_Length": "avg"}
    )
    
aws3 = compute_abstract_stats(pubmed_df)

In [None]:
# Compute Average Number of articles per Journal Title per Year
def compute_avg_articles_per_journal_year(pubmed_df):
    return pubmed_df.groupBy("Journal_Title", "Year_Published").count().groupBy("Journal_Title").agg(avg("count").alias("Avg_Articles_Per_Year"))

aws4 = compute_avg_articles_per_journal_year(pubmed_df)

In [None]:
# Main function
def main(xml_file_path):
    spark = create_spark_session()
    pubmed_df = parse_pubmed_xml(spark, xml_file_path)
    
    articles_per_first_author = compute_articles_per_first_author(pubmed_df)
    articles_per_year = compute_articles_per_year(pubmed_df)
    abstract_stats = compute_abstract_stats(pubmed_df)
    avg_articles_per_journal_year = compute_avg_articles_per_journal_year(pubmed_df)
    
    # Write answers to a CSV file
    with open("pubmed_analysis_answers.csv", "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        
        # Write headers
        csv_writer.writerow(["Question", "Answer"])
        
        # Write answers
        csv_writer.writerow(["Number of articles per First Author", ""])
        for row in articles_per_first_author.collect():
            csv_writer.writerow([row["First_Author"], row["count"]])
        
        csv_writer.writerow(["Number of articles per Year", ""])
        for row in articles_per_year.collect():
            csv_writer.writerow([row["Year_Published"], row["count"]])
        
        csv_writer.writerow(["Minimum, Maximum, and Average length of an abstract", ""])
        abstract_stats_data = abstract_stats.collect()[0]
        csv_writer.writerow(["Minimum Abstract Length", abstract_stats_data["min(Abstract_Length)"]])
        csv_writer.writerow(["Maximum Abstract Length", abstract_stats_data["max(Abstract_Length)"]])
        csv_writer.writerow(["Average Abstract Length", abstract_stats_data["avg(Abstract_Length)"]])
        
        csv_writer.writerow(["Average Number of articles per Journal Title per Year", ""])
        for row in avg_articles_per_journal_year.collect():
            csv_writer.writerow([row["Journal_Title"], row["Avg_Articles_Per_Year"]])
    
    spark.stop()

xml_file_path = "pubmed21n0001.xml"    
main(xml_file_path)