# Analyzing Trends in AI Research Publications
# *Full Ingestion from Kaggle*
# From Bronze To Silver
---


# Prepare Environment

## Import Packages

In [None]:
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

from datetime import datetime

import pytz


### arXiv Database in the Hive Metastore

In [None]:
# Create the 'arxiv' database if it is missing. `IF NOT EXISTS` makes this one
# idempotent statement: there is no window between "check" and "create" in which
# another job could create the database and make this cell fail, and it does not
# depend on `spark.catalog.databaseExists` (only available from Spark 3.3).
spark.sql("CREATE DATABASE IF NOT EXISTS arxiv")

# Make 'arxiv' the current database so the unqualified table names used below
# (raw_kaggle, preprocessed) are created and resolved inside it.
spark.sql("USE arxiv")

## Constants

In [None]:
# Root of the mounted arXiv storage; layer/table paths are built from it.
BASE_PATH = "/mnt/arxiv/"
# Folder scanned for Kaggle JSON snapshots (bronze layer, under BASE_PATH).
INGESTION_PATH = BASE_PATH + "bronze/kaggle/"


## Define Schema

In [None]:
# Plain string fields of the Kaggle arXiv records, in schema order up to `abstract`.
_KAGGLE_STRING_FIELDS = (
    "id", "submitter", "authors", "title", "comments", "journal-ref",
    "doi", "report-no", "categories", "license", "abstract",
)

# Explicit read schema for the Kaggle arXiv JSON (every field nullable).
kaggle_schema = StructType(
    [StructField(name, StringType(), True) for name in _KAGGLE_STRING_FIELDS]
    + [
        # Revision history: each element carries `created` and `version` strings.
        StructField("versions", ArrayType(StructType([
            StructField("created", StringType(), True),
            StructField("version", StringType(), True),
        ])), True),
        StructField("update_date", StringType(), True),
        # A list of string lists, one inner list per author.
        StructField("authors_parsed", ArrayType(ArrayType(StringType())), True),
    ]
)


### Define Functions

In [None]:
def load_latest_json():
    """
    Load the most recent Kaggle JSON snapshot from INGESTION_PATH.

    "Most recent" is the ``.json`` file whose name sorts last
    lexicographically. NOTE(review): this assumes snapshot file names embed a
    sortable timestamp — confirm against the ingestion job's naming scheme.

    Returns:
        DataFrame: The snapshot parsed with ``kaggle_schema``.

    Raises:
        FileNotFoundError: If INGESTION_PATH contains no ``.json`` files
            (previously surfaced as an unhelpful ``IndexError``).
    """
    json_files = [f.name for f in dbutils.fs.ls(INGESTION_PATH) if f.name.endswith('.json')]
    if not json_files:
        raise FileNotFoundError(f"No .json files found in {INGESTION_PATH}")
    # max() picks the same file as sorted(..., reverse=True)[0] without a full sort.
    latest_file = max(json_files)
    return spark.read.schema(kaggle_schema).json(INGESTION_PATH + latest_file)
    

In [None]:
def delta_table_exists(layer, table_name):
    """
    Return True if a Delta table exists at BASE_PATH/<layer>/delta/<table_name>/.

    Existence is detected by listing the table's ``_delta_log/`` directory;
    any error raised by the listing (e.g. the path does not exist) is treated
    as "the table does not exist".
    """
    table_path = f"{BASE_PATH}{layer}/delta/{table_name}/_delta_log/"
    try:
        dbutils.fs.ls(table_path)
    except Exception:
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt
        # and SystemExit propagate instead of being reported as "no table".
        return False
    return True


In [None]:
def create_or_replace_kaggle_delta(layer, table_name, chunk_size=12, recreate=False, join_on=None):
    """
    Create (or, on request, recreate) a Delta table from the latest Kaggle JSON file.

    Only full ingestion is supported: the table is always written from the
    complete latest snapshot. If the table already exists and ``recreate`` is
    False, it is left untouched and the JSON folder is not read at all. In
    every case the first five rows of the table are displayed at the end.

    Args:
        layer (str): The layer (bronze, silver, gold) where the Delta table resides or will reside.
        table_name (str): The name of the Delta table; it is registered under this
            name in the current database.
        chunk_size (int, optional): Number of partitions the JSON data is repartitioned
            into before writing. Default is 12.
        recreate (bool, optional): Whether to drop and rebuild the table if it already
            exists. Default is False.
        join_on (list, optional): Kept only for signature compatibility; it is not used.
            Default is None (avoids a shared mutable default list).
    """
    delta_path = f"{BASE_PATH}{layer}/delta/{table_name}/"
    table_exists = delta_table_exists(layer, table_name)

    if table_exists and not recreate:
        print(f"The Delta table '{table_name}' already exists.")
        print("Keeping the existing table; pass recreate=True to rebuild it.")
    else:
        if table_exists:
            print(f"The Delta table '{table_name}' already exists.")
            print(f"Recreating the Delta table '{table_name}'...")
            # Remove both the metastore entry and the files so the new snapshot
            # is written to a clean location.
            spark.sql(f"DROP TABLE IF EXISTS {table_name}")
            dbutils.fs.rm(delta_path, recurse=True)
        else:
            print(f"The Delta table '{table_name}' does not exist. Creating a new table...")

        # Only list/read the ingestion folder when the data will actually be written.
        print("Reading the latest JSON file in chunks...")
        json_chunks = load_latest_json().repartition(chunk_size)
        json_chunks.write.format("delta").mode("overwrite").save(delta_path)

        # Register the table in the metastore (current database) at the Delta location.
        spark.sql(f"""
        CREATE TABLE {table_name}
        USING DELTA
        LOCATION '{delta_path}'
        """)

        action = "recreated" if table_exists else "created"
        print(f"The Delta table '{table_name}' has been {action}.")

    print("Displaying the first five rows of the Delta table...")
    display(spark.read.format("delta").load(delta_path).limit(5))

In [None]:
def create_or_update_delta(layer, table_name, data_source,
                           join_on=None, recreate=False):
    """
    Create, append to, or recreate a Delta table from a DataFrame, register it
    in the metastore, and display its first five rows.

    - Table missing: write ``data_source`` and register it.
    - Table present, ``recreate=False``: append only the rows of ``data_source``
      with no match in the existing table on ``join_on`` (a left-anti join), so
      re-running with the same data appends nothing.
    - Table present, ``recreate=True``: drop the table and its files, then write
      ``data_source`` from scratch.

    Args:
        layer (str): The layer (silver or gold) in which to create/append/recreate the Delta table.
        table_name (str): The name of the Delta table (registered in the current database).
        data_source (DataFrame): The Spark DataFrame to be written. Required.
        join_on (list, optional): Columns identifying a row when deduplicating against
            the existing table; each must exist in both ``data_source`` and the table.
            Default is ["id", "last_update"].
        recreate (bool, optional): If True, drop and recreate the existing Delta table. Default is False.

    Raises:
        ValueError: If ``data_source`` is None, or (when appending) if a ``join_on``
            column is missing from ``data_source`` or from the existing table.
    """
    if data_source is None:
        raise ValueError("data_source must be a Spark DataFrame.")
    if join_on is None:
        # Fresh list per call instead of a shared mutable default.
        join_on = ["id", "last_update"]

    delta_path = f"{BASE_PATH}{layer}/delta/{table_name}/"
    table_exists = delta_table_exists(layer, table_name)

    if table_exists and not recreate:
        print(f"The Delta table '{table_name}' already exists.")
        print(f"Appending new data to the existing Delta table '{table_name}'...")

        existing_data = spark.read.format("delta").load(delta_path)

        # Fail with a clear message instead of a Spark AnalysisException when the
        # join keys are not present on both sides. Names are compared
        # case-insensitively, matching Spark's default column resolution.
        new_cols = {c.lower() for c in data_source.columns}
        existing_cols = {c.lower() for c in existing_data.columns}
        missing = [c for c in join_on
                   if c.lower() not in new_cols or c.lower() not in existing_cols]
        if missing:
            raise ValueError(
                f"Cannot deduplicate on {join_on}: column(s) {missing} are not present "
                f"in both the new data and the Delta table '{table_name}'. "
                "Pass join_on with columns that exist in both."
            )

        # Keep only rows that have no match in the existing table.
        new_data = data_source.join(existing_data, join_on, "left_anti")
        new_data.write.format("delta").mode("append").save(delta_path)

        print(f"New data has been appended to the Delta table '{table_name}'.")
    else:
        if table_exists:
            print(f"The Delta table '{table_name}' already exists.")
            print(f"Recreating the Delta table '{table_name}'...")
            # Remove both the metastore entry and the files before rewriting.
            spark.sql(f"DROP TABLE IF EXISTS {table_name}")
            dbutils.fs.rm(delta_path, recurse=True)
        else:
            print(f"The Delta table '{table_name}' does not exist. Creating a new table...")

        # The table takes the schema of the provided DataFrame.
        data_source.write.format("delta").mode("overwrite").save(delta_path)

        # Register the table in the metastore (current database) at the Delta location.
        spark.sql(f"""
        CREATE TABLE {table_name}
        USING DELTA
        LOCATION '{delta_path}'
        """)

        action = "recreated" if table_exists else "created"
        print(f"The Delta table '{table_name}' has been {action}.")

    print("Displaying the first five rows of the Delta table...")
    display(spark.read.format("delta").load(delta_path).limit(5))


# Silver Layer

## Raw Table

In [None]:
# Full ingestion of the latest Kaggle snapshot into the silver `raw_kaggle` table
# (an existing table is left as-is unless recreate=True is passed).
create_or_replace_kaggle_delta(layer="silver", table_name="raw_kaggle")

### Create a DataFrame from the Raw Table

In [None]:
# Load the registered raw table (resolved in the current database) and preview it.
raw_df = spark.read.table("raw_kaggle")
display(raw_df)

In [None]:
# Number of records in the raw table.
raw_df.count()

## Data Preprocessing

### Filter by AI Research Categories

In [None]:
# arXiv categories treated as AI research
categories_list = ['cs.MA', 'cs.RO', 'cs.CV', 'cs.LG', 'cs.AI', 'cs.CL', 'cs.NE']

# `categories` is a space-separated string such as "cs.AI cs.LG". Compare whole
# tokens rather than using an unescaped regex: in a pattern like 'cs.AI|cs.LG' the
# '.' matches any character, and each alternative also matches inside longer tokens.
category_tokens = F.split(F.col("categories"), " ")
wanted_categories = F.array(*[F.lit(c) for c in categories_list])

# Keep papers tagged with at least one of the wanted categories
# (rows with a null `categories` value are dropped).
filtered_df = raw_df.filter(F.arrays_overlap(category_tokens, wanted_categories))

# Display the filtered DataFrame
display(filtered_df)

In [None]:
# Number of papers remaining after the AI-category filter.
filtered_df.count()

### Schema Alignment

In [None]:
# Define UDFs
def convert_date(date_str):
    """
    Convert an arXiv version timestamp to an ISO-8601 UTC string.

    Example: "Mon, 2 Apr 2007 19:18:42 GMT" -> "2007-04-02T19:18:42Z".

    Args:
        date_str (str | None): Date formatted as '%a, %d %b %Y %H:%M:%S %Z'.

    Returns:
        str | None: "YYYY-MM-DDTHH:MM:SSZ", or None when the input is empty,
        None, or not in the expected format.
    """
    if not date_str:
        return None
    try:
        parsed = datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %Z')
    except (TypeError, ValueError):
        # Unparseable values become nulls, which the later dropna() removes.
        return None
    # The output pattern hard-codes the "Z" suffix, so attaching UTC tzinfo
    # first (the former pytz.utc.localize call) does not change the result.
    return parsed.strftime('%Y-%m-%dT%H:%M:%SZ')

def format_authors(authors_parsed):
    """
    Turn ``authors_parsed`` entries into "First Last" name strings.

    Each entry is treated as [last_name, first_names, ...]; items after the
    second are ignored. Entries that are None, have fewer than two items, or
    lack either name part are skipped (previously such entries, or a null
    ``authors_parsed``, crashed the UDF with IndexError/TypeError).

    Args:
        authors_parsed (list[list[str]] | None): Parsed author entries.

    Returns:
        list[str] | None: Formatted names in input order, or None when
        ``authors_parsed`` is None so the value stays null for the null checks.
    """
    if authors_parsed is None:
        return None
    return [
        f"{author[1]} {author[0]}"
        for author in authors_parsed
        if author is not None and len(author) >= 2 and author[0] and author[1]
    ]

# Expose the plain-Python helpers to Spark as UDFs with explicit return types.
convert_date_udf = F.udf(convert_date, StringType())
format_authors_udf = F.udf(format_authors, ArrayType(StringType()))

# Creation timestamps of every revision; the first element is used as the
# publication date and the last as the latest update.
version_dates = F.col("versions.created")

# Map the category-filtered Kaggle records (filtered_df) onto the target schema.
aligned_df = (
    filtered_df
    .withColumnRenamed("abstract", "summary")
    .withColumn("categories", F.split(F.col("categories"), " "))
    .withColumn("last_update", convert_date_udf(F.element_at(version_dates, -1)))
    .withColumn("published_on", convert_date_udf(F.element_at(version_dates, 1)))
    .withColumn("authors", format_authors_udf(F.col("authors_parsed")))
    .select("id", "title", "summary", "authors", "categories", "published_on", "last_update")
)

# Preview the aligned DataFrame
display(aligned_df)

### Data Cleaning

#### Drop duplicate rows

In [None]:
# Row count before deduplication (selecting id/title would not change it).
original_count = aligned_df.count()
print(f"Original row count in the raw DataFrame: {original_count}")

# Rows sharing the same (id, title) pair are duplicates; keep one of each.
deduped_df = aligned_df.dropDuplicates(["id", "title"])
new_count = deduped_df.count()

# Report how many rows deduplication removed.
removed = original_count - new_count
if removed <= 0:
    print("No duplicates found in the raw DataFrame.")
else:
    print(f"There were {removed} duplicate rows in the raw DataFrame.")

#### Drop rows with null values

In [None]:
# Count the nulls in every column with a single aggregation pass.
null_counts = deduped_df.agg(
    *[F.sum(F.when(F.isnull(c), 1).otherwise(0)).alias(c) for c in deduped_df.columns]
)

# The aggregate is a single row, so collecting it to the driver is cheap.
null_counts_collected = null_counts.collect()[0].asDict()

# F.sum over an empty DataFrame yields None instead of 0; `if count` treats it as
# "no nulls" (the previous `null_count > 0` raised TypeError in that case).
columns_with_nulls = {
    column: count for column, count in null_counts_collected.items() if count
}
has_nulls = bool(columns_with_nulls)

for column, null_count in columns_with_nulls.items():
    print(f"Column {column} has {null_count} null values.")

# dropna() removes nothing when there are no nulls, so apply it unconditionally.
clean_df = deduped_df.dropna()

if not has_nulls:
    print("No null values found in the DataFrame.")
else:
    print("\nRows with null values:")
    # A single filter with OR-ed conditions shows each offending row once and scans
    # the data once. A union of per-column filters repeats rows that have nulls in
    # several columns and rescans the DataFrame once per column.
    any_null = F.isnull(deduped_df.columns[0])
    for c in deduped_df.columns[1:]:
        any_null = any_null | F.isnull(c)
    deduped_df.filter(any_null).show()

    print("\nRows with null values have been removed.")

#### Column conversion

In [None]:
# Cast the ISO-8601 strings produced by convert_date_udf to Spark timestamps
# (timestamp type, not date: the time of day is kept).
for timestamp_col in ("published_on", "last_update"):
    clean_df = clean_df.withColumn(timestamp_col, F.col(timestamp_col).cast("timestamp"))

In [None]:
# Preview the cleaned DataFrame with timestamp-typed published_on / last_update.
display(clean_df)

### Data Enrichment

#### Create date and time columns from timestamp columns

In [None]:
# (source timestamp column, new date column, new time-of-day column)
timestamp_splits = [
    ("published_on", "published_date", "published_time"),
    ("last_update", "last_update_date", "last_update_time"),
]

# Derive a DATE column and an "HH:mm:ss" string column from each timestamp.
# NOTE(review): to_date/date_format use the Spark session time zone; this assumes
# it is UTC so the values agree with the "Z" timestamps — confirm the cluster setting.
enriched_df = clean_df
for ts_col, date_col, time_col in timestamp_splits:
    enriched_df = (
        enriched_df
        .withColumn(date_col, F.to_date(F.col(ts_col)))
        .withColumn(time_col, F.date_format(F.col(ts_col), "HH:mm:ss"))
    )

# The split columns replace the original timestamp columns.
enriched_df = enriched_df.drop("published_on", "last_update")

# Preview the enriched DataFrame
display(enriched_df)

## Create or Update the Preprocessed Table

In [None]:
# `enriched_df` no longer has a `last_update` column (it was split into
# last_update_date / last_update_time and dropped above), so the function's
# default join_on=["id", "last_update"] fails on any run that appends to an
# existing table. Deduplicate on columns that actually exist instead.
create_or_update_delta(
    "silver",
    "preprocessed",
    data_source=enriched_df,
    join_on=["id", "last_update_date", "last_update_time"],
)

## Show Preprocessed Table

In [None]:
# Column names and types of the registered preprocessed table.
spark.read.table("preprocessed").printSchema()

In [None]:
# Preview the registered preprocessed table.
display(spark.read.table("preprocessed"))

In [None]:
# Number of rows in the registered preprocessed table.
spark.read.table("preprocessed").count()