In [1]:
# Import libraries
import glob
import os
import time
from typing import Dict, Union 

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lower, regexp_replace, udf, monotonically_increasing_id
from pyspark.sql.types import ArrayType, StringType, DoubleType, IntegerType, StructType, StructField

from pyspark.ml import Pipeline, Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.feature import CountVectorizer, IDF, StopWordsRemover, Tokenizer, SQLTransformer, VectorAssembler
from pyspark.ml.clustering import LDA
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.linalg import Vector, Vectors

from nltk.stem import WordNetLemmatizer
import nltk

try:
    nltk.download('wordnet')
except LookupError:
    nltk.download('wordnet')

import spacy
from scispacy.abbreviation import AbbreviationDetector


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/birgitte/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
class Lemmatizer(Transformer, HasInputCol, HasOutputCol):
    """
    A custom transformer that lemmatizes a text column using WordNetLemmatizer from the nltk package.
    """
    def __init__(self, inputCol="filtered_tokens", outputCol="lemmatized"):
        super(Lemmatizer, self).__init__()
        self.inputCol = Param(self, "inputCol", "input column name")
        self.outputCol = Param(self, "outputCol", "output column name")
        self._setDefault(inputCol=inputCol, outputCol=outputCol)

    def getLemma(self, tokens):
        lemmatizer = WordNetLemmatizer()
        return [lemmatizer.lemmatize(token) for token in tokens]

    def _transform(self, df):
        lemma_udf = udf(lambda tokens: self.getLemma(tokens), ArrayType(StringType()))
        return df.withColumn(self.getOutputCol(), lemma_udf(df[self.getInputCol()]))

In [3]:
# Initialize Spark Session
spark = (SparkSession.builder.appName("Spark ETL and LDA model pipeline")
    .master("local[4]")                
    .config("spark.driver.memory", "8g") 
    .config("spark.driver.maxResultSize", "2G")
).getOrCreate()

23/06/02 21:46:40 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [4]:
def ingest_data(path: str) -> pyspark.sql.dataframe.DataFrame: 
    
    """
    Ingest the latest Pubmed Abstract CSV file into a PySpark DataFrame

    """
    filename = max(glob.iglob(f"{path}/*.csv"), key=os.path.getmtime)
    
    schema = StructType([
        StructField("Author", StringType(), True),
        StructField("Title", StringType(), True),
        StructField("Year", IntegerType(), True),
        StructField("Country", StringType(), True),
        StructField("Journal", StringType(), True),
        StructField("DOI", StringType(), True),
        StructField("Abstract", StringType(), True)
    ])
    
    df = spark.read.format("csv") \
    .option("header", "true") \
    .schema(schema) \
    .load(filename)
    
    return df

In [5]:
def transform(df: pyspark.sql.dataframe.DataFrame) -> pyspark.sql.dataframe.DataFrame:
    
    # Transformation stages
    tokenizer = Tokenizer(inputCol="Abstract", outputCol="tokens")
    stopword_remover = StopWordsRemover(inputCol=tokenizer.getOutputCol(), outputCol="filtered_tokens")
    lemmatizer = Lemmatizer(inputCol=stopword_remover.getOutputCol(), outputCol='lemmatized')
     
    # Feature stages: convert text into numerical features
    count_vectorizer = CountVectorizer(inputCol=lemmatizer.getOutputCol(), outputCol="tf")
    idf = IDF(inputCol=count_vectorizer.getOutputCol(), outputCol="features") # tf_idf
    
    # Pipeline
    pipeline = Pipeline(stages=[
        tokenizer,
        stopword_remover,
        lemmatizer,
        count_vectorizer,
        idf
    ])
    
    # Fit the pipeline on the DataFrame
    pipeline = pipeline.fit(df)
    
    # Apply the pipeline on the DataFrame
    df_transformed = pipeline.transform(df)
    
    # Remove duplicates
    df_transformed = df_transformed.dropDuplicates()
    
    # Handle missing values
    df_transformed = df_transformed.na.fill("Unknown", subset=["Author", "Title", "Country", "Journal", "DOI", "Abstract"])
    df_transformed = df_transformed.na.fill(0, subset=["Year"])
    
    # Standardize data formats
    df_transformed = df_transformed.withColumn("Journal", lower(col("Journal")))
    df_transformed = df_transformed.withColumn("DOI", regexp_replace(col("DOI"), "^http(s)?://doi.org/", ""))
    
    # Add incrementally increasing index
    df_transformed = df_transformed.withColumn("index", monotonically_increasing_id())
    
    return df_transformed, pipeline

In [6]:
def main(path):
    """
    Main function that puts everything together. 
    The function will expand as the PoC goes on.
    """
    df = ingest_data(path)
    transformed, pipeline = transform(df)
    return transformed, pipeline

In [7]:
data_dir ="../data/proc"

transformed, pipeline_model = main(data_dir)
transformed.show(5, truncate=True)



+------------+--------------------+----+-------+--------------------+-------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----+
|      Author|               Title|Year|Country|             Journal|    DOI|            Abstract|              tokens|     filtered_tokens|          lemmatized|                  tf|            features|index|
+------------+--------------------+----+-------+--------------------+-------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----+
|  M J Carson|Simultaneous occu...|1977|Unknown|      am j dis child|Unknown|we report the cas...|[we, report, the,...|[report, cases, t...|[report, case, tw...|(21520,[0,1,2,4,7...|(21520,[0,1,2,4,7...|    0|
| J Kavelaars|Hypernatremia in ...|2001|Unknown|          neth j med|Unknown|we describe a pat...|[we, describe, a,...|[describe, patien...|[describe, patien...

                                                                                

## The following code is work in progress

Topic modelling using LDA

In [8]:
from pyspark.ml.clustering import LDA

# Split the transformed data into training and test sets
train_data, test_data = transformed.randomSplit([0.8, 0.2], seed=42)

# LDA model
lda = LDA(k=30, maxIter=10, featuresCol="features")  # Specify the number of topics (k) and iterations
lda_model = lda.fit(train_data)

wordNumbers = 5

def topic_render(topic, vocabulary):
    terms = topic[1]
    result = []
    for i in range(wordNumbers):
        term = vocabulary[terms[i]]
        result.append(term)
    return result

count_vectorizer_model = pipeline_model.stages[3]
vocabulary = count_vectorizer_model.vocabulary
topics_final = lda_model.describeTopics(maxTermsPerTopic=wordNumbers).rdd.map(lambda x: topic_render(x, vocabulary)).collect()

for topic in range(len(topics_final)):
    print(f"Topic {topic}:")
    for term in topics_final[topic]:
        print(term)
    print('\n')
    


                                                                                

23/06/02 21:47:36 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
23/06/02 21:47:36 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS


[Stage 54:>                                                         (0 + 4) / 4]

Topic 0:
dm
corneal
bb
alanine
limb


Topic 1:
program
spinal
injury
cord
arrest


Topic 2:
alpha
pge
adrenoceptor
excretion



Topic 3:
diabetic

pregnancy
control
woman


Topic 4:
wfs
syndrome
wolfram
atrophy
optic


Topic 5:
bladder
cat
donor
hz
organ


Topic 6:
camp
tolbutamide
phosphodiesterase
hctz
ilk


Topic 7:
rat
brattleboro
di
na
nucleus


Topic 8:
mirnas
abscess
nystagmus
noteworthy
bind


Topic 9:
mri
lymphocytic
ma
hypophysitis
autoimmune


Topic 10:
neurosurgical
cdi
epidemiology
pmolliter
ttp


Topic 11:

diabetic
glucose
insulin
p


Topic 12:
proximal
microlitermin
par
cli
delivery


Topic 13:
periodontal
haematoma
periodontitis
dental
microm


Topic 14:
vep
bfv
oa
ophthalmic
latency


Topic 15:
siadh
mood
li
hyponatremia
urate


Topic 16:
retinopathy

heart
diabetic
ii


Topic 17:
pmolliter
parenchyma
net
range
glomerular


Topic 18:
crf
anorectal
ici
tumor
grading


Topic 19:
pmolliter
nelson
grl
pituitary
homeostatic


Topic 20:
platelet
type
ii
altitude
cofactor




                                                                                

In [9]:
# Evaluate perplexity on the test data
perplexity = lda_model.logPerplexity(test_data)
print(f"Perplexity: {perplexity}")

                                                                                

Perplexity: 12.691801656873688


In [10]:
# TODO: Create abbreviation/acronym replacer
# TODO: Hyperparameter tuning of basemodel
# TODO: Model evaluation
# TODO: Visualize topics (with Bokeh or Dash?)
# TODO: Visualize publications per country per year

In [None]:
# Non-spark test
import spacy
from scispacy.abbreviation import AbbreviationDetector
nlp = spacy.load("en_core_sci_md")
# Add the abbreviation pipe to the spacy pipeline.
nlp.add_pipe("abbreviation_detector")
doc = nlp(
    "Chronic lymphocytic leukemia (CLL), autoimmune hemolytic anemia, and oral ulcer. The patient was diagnosed with chronic lymphocytic leukemia and was noted to have autoimmune hemolytic anemia at the time of his CLL diagnosis."
)
fmt_str = "{:<6}| {:<30}| {:<6}| {:<6}"
print(fmt_str.format("Short", "Long", "Starts", "Ends"))
for abrv in doc._.abbreviations:
    print(fmt_str.format(abrv.text, str(abrv._.long_form), abrv.start, abrv.end))


In [None]:
# WIP: Abbreviation replacement using scispacy
import spacy
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType


def convert_abbreviations(text: str) -> str:
    """
    Convert abbreviations to text using SciSpacy.
    :param text: Input text data.
    """
    nlp = spacy.load("en_core_sci_md")
    abbreviation_pipe = nlp.create_pipe("abbreviation_detector")
    nlp.add_pipe("abbreviation_detector")
    doc = nlp(text)
    altered_tok = [tok.text for tok in doc]
    for abrv in doc._.abbreviations:
        altered_tok[abrv.start] = str(abrv._.long_form)
    return " ".join(altered_tok)


# Create a UDF from the convert_abbreviations function
convert_abbreviations_udf = udf(convert_abbreviations, StringType())

# Apply the UDF to the text_column
test = transformed.withColumn("converted_text", convert_abbreviations_udf(transformed["Abstract"]))

test.select(test["converted_text"]).show(5)