<a href="https://colab.research.google.com/github/Tstrebe2/predicting-text-difficulty/blob/dave-updates/code/dave-text-cleaning-pipeline-test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys

# Pin PySpark and Spark NLP. The spark-nlp version here should stay in step
# with the spark-nlp jar requested in the SparkSession config (4.2.0).
# `sys.executable` makes pip target the interpreter running this kernel.
!{sys.executable} -m pip install pyspark==3.1.2 -q
!{sys.executable} -m pip install spark-nlp==4.2.0 -q
# NOTE(review): spaCy is not imported by any cell in this section; these
# install/download lines are left disabled. Delete them if no later cell needs spaCy.
# !{sys.executable} -m pip install -U spacy==3.4.1 -q

# !{sys.executable} -m spacy download en_core_web_lg -q

[K     |████████████████████████████████| 212.4 MB 61 kB/s 
[K     |████████████████████████████████| 198 kB 20.0 MB/s 
[?25h  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 641 kB 5.1 MB/s 
[?25h

In [None]:
from google.colab import drive
# NOTE(review): none of the cells shown read from /content/drive. If no later
# cell needs Drive, this interactive mount can be removed so that
# "Run All" does not block on authentication.
drive.mount('/content/drive')

# Download the training sentences and lexical resources into the working
# directory (/content on Colab), where the loading cells read them from:
# WikiLarge training split, AntBNC lemma dictionary, Dale-Chall word list,
# age-of-acquisition norms, and Brysbaert et al. concreteness ratings.
!wget https://raw.githubusercontent.com/Tstrebe2/predicting-text-difficulty/main/assets/WikiLarge_Train.csv -q
!wget https://raw.githubusercontent.com/mahavivo/vocabulary/master/lemmas/AntBNC_lemmas_ver_001.txt -q
!wget https://raw.githubusercontent.com/Tstrebe2/predicting-text-difficulty/main/assets/dale_chall.txt -q
!wget https://raw.githubusercontent.com/Tstrebe2/predicting-text-difficulty/main/assets/AoA_51715_words.csv -q
!wget https://raw.githubusercontent.com/Tstrebe2/predicting-text-difficulty/main/assets/Concreteness_ratings_Brysbaert_et_al_BRM.txt -q

In [None]:
import pandas as pd

# pandas' default NA parsing turns literal strings such as "nan", "null", "NA"
# and "None" into NaN. Some of those are genuine word-list entries ("nan",
# "null"), and they would silently drop out of the lookups. Default NA parsing
# is therefore disabled, and the standard markers (pandas' documented default
# `na_values`) are re-applied only to the numeric columns.
_NUMERIC_NA_MARKERS = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN',
                       '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NA',
                       'NULL', 'NaN', 'None', 'n/a', 'nan', 'null']

# Age-of-acquisition norms: keep the lemma and its AoA_Kup_lem rating.
aoa = pd.read_csv('/content/AoA_51715_words.csv',
                  encoding_errors='ignore',
                  usecols=['Lemma_highest_PoS', 'AoA_Kup_lem'],
                  keep_default_na=False,
                  na_values={'AoA_Kup_lem': _NUMERIC_NA_MARKERS},
                  ).rename({'Lemma_highest_PoS':'lemma', 'AoA_Kup_lem':'aoa'}, axis=1)
# Blank lemma cells previously became NaN and were dropped by groupby; they
# are now read as '' and must be excluded explicitly.
aoa = aoa[aoa['lemma'] != '']

# One rating per lemma: groupby().first() takes the first non-null value.
aoa = aoa.groupby('lemma').first().to_dict()['aoa']

# Concreteness ratings. 'bigram' is 0 for single-word entries (see split_word).
conc = (pd.read_csv('/content/Concreteness_ratings_Brysbaert_et_al_BRM.txt',
                    sep='\t',
                    usecols=['Word', 'Bigram', 'Conc.M'],
                    keep_default_na=False,
                    na_values={'Bigram': _NUMERIC_NA_MARKERS,
                               'Conc.M': _NUMERIC_NA_MARKERS})
        .rename({'Word':'word', 'Bigram':'bigram', 'Conc.M':'conc_mean'}, axis=1))
conc = conc[conc['word'] != '']

def split_word(x):
  """Convert one concreteness-ratings row into a lookup record.

  Rows whose `bigram` flag is 0 keep the word string as their key. Any other
  row's `word` is split on single spaces into a tuple, so the key can be
  matched against a pair of consecutive tokens.

  Args:
    x: a row-like mapping with 'word', 'bigram' and 'conc_mean' entries.

  Returns:
    dict with keys 'word' (str or tuple of str) and 'conc_mean'.
  """
  single_word = x['bigram'] == 0
  key = x['word'] if single_word else tuple(x['word'].split(' '))
  return {'word': key, 'conc_mean': x['conc_mean']}

# Build the concreteness lookup: keys are single-word strings or
# (word1, word2) tuples for two-word entries, values are mean ratings.
# NOTE(review): if a key appears more than once, the later row's value wins
# when the indexed Series is converted to a dict.
conc = conc.apply(split_word, axis=1, result_type='expand').set_index('word').to_dict()['conc_mean']

# Dale-Chall familiar-word list: a single headerless column of words.
d_chall = set(pd.read_csv('/content/dale_chall.txt', names=['word'])['word'].tolist())

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.ml import Pipeline
import pyspark.sql.functions as F
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetector, Tokenizer, Normalizer, Lemmatizer
import sparknlp
from pyspark.sql.types import StringType, ArrayType, FloatType, StructType
import numpy as np

# Local Spark session (master "local" = a single worker thread). The Spark NLP
# jar is fetched via spark.jars.packages, pinned to the same 4.2.0 release as
# the pip-installed spark-nlp package.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Colab")
    .config('spark.ui.port', '4050')
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:4.2.0")
    .getOrCreate()
)

# Raw WikiLarge training split, exposed to Spark SQL as the `wiki` view.
df = spark.read.csv('/content/WikiLarge_Train.csv', header=True)
df.createOrReplaceTempView('wiki')

# Text normalisation rules, applied in order to `original_text`. Patterns are
# Java regexes (Spark's regexp_replace); `$1` in a replacement inserts capture
# group 1. Replacements that expand a unit keep surrounding spaces so that
# neighbouring words are not glued together.
# NOTE(review): tokens such as -LRB-/-RRB- and " , " spacing suggest
# PTB-style tokenized input; the space-anchored rules rely on that. Confirm.
CLEANING_RULES = (
    (r'-LRB-', '('),
    (r'-RRB-', ')'),
    (r"'' *", ''),                     # stray '' quote tokens
    (r", ; *", ''),
    (r"; , *", ''),
    (r", , *", ''),
    (r"; ; *", ''),
    (r"[(] [;,-]* [)] ", ''),          # brackets holding only punctuation
    (r"[(] +[)] *", ''),               # empty brackets
    (r" km ", " kilometers "),
    (r"([0-9]+)km ", "$1 kilometers "),  # keep the number: "5km " -> "5 kilometers "
    (r" mph ", " miles per hour "),
    (r" ?° C ", " degrees Celsius "),
    (r" ?° F ", " degrees Fahrenheit "),
    (r" ?°", " degrees"),
    (r" %", " percent"),
    (r" cm\b", " centimeters"),        # \b: don't rewrite words starting "cm"
    (r" kg ", " kilograms "),
)

# Rows whose cleaned text is this many characters or fewer are dropped.
MIN_TEXT_LENGTH = 20

# Compose a single nested regexp_replace expression. Passing patterns directly
# (rather than formatting them into SQL string literals) needs no SQL
# quoting/escaping and cannot break on a quote inside a pattern.
cleaned_text = F.col('original_text')
for pattern, replacement in CLEANING_RULES:
  cleaned_text = F.regexp_replace(cleaned_text, pattern, replacement)

df = (df.select(cleaned_text.alias('original_text'), 'label')
        .filter(F.length('original_text') > MIN_TEXT_LENGTH))
df.createOrReplaceTempView('wiki')

# Spark NLP annotation chain: raw text -> document -> tokens -> lemmas.
documentAssembler = (
    DocumentAssembler()
    .setInputCol("original_text")
    .setOutputCol("document")
)

tokenizer = (
    Tokenizer()
    .setInputCols(["document"])
    .setOutputCol("token")
)

# Dictionary-based lemmatizer using the AntBNC list downloaded earlier.
# NOTE(review): the delimiters assume lines shaped like
# "lemma -> form<TAB>form..." — confirm against the file.
lemmatizer = (
    Lemmatizer()
    .setInputCols(["token"])
    .setOutputCol("lemma")
    .setDictionary("./AntBNC_lemmas_ver_001.txt", value_delimiter ="\t", key_delimiter = "->")
)

nlp_stages = [documentAssembler, tokenizer, lemmatizer]

# `nlp_pipeline` ends up holding the fitted PipelineModel.
nlp_pipeline = Pipeline(stages=nlp_stages).fit(df)
df = nlp_pipeline.transform(df)
df.createOrReplaceTempView('wiki')
df.show(5, 0)

+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
def get_d_chall(x):
  """Return the fraction of tokens that appear on the Dale-Chall word list.

  Matching is case-insensitive, consistent with get_aoa and get_conc_rating.
  A token counts if either its original form or its lower-cased form is in
  `d_chall`, so a sentence-initial "The" matches "the".

  Args:
    x: sequence of token strings (lemma results), or None.

  Returns:
    float in [0, 1]; 0.0 for an empty or None input. Dividing by zero here
    would fail the whole Spark job.
  """
  if not x:
    return 0.0
  familiar = sum(1 for w in x if w in d_chall or w.lower() in d_chall)
  return familiar / len(x)

def get_aoa(x):
  """Collect age-of-acquisition ratings for the tokens found in `aoa`.

  Tokens are lower-cased before lookup; unknown tokens are skipped.

  Args:
    x: sequence of token strings (lemma results).

  Returns:
    list of ratings in token order, or [0.0] when no token has a rating.
  """
  ratings = []
  for token in x:
    key = token.lower()
    if key in aoa:
      ratings.append(aoa[key])
  return ratings if ratings else [0.0]

def get_conc_rating(x):
  """Look up concreteness ratings for a token sequence.

  Scans the tokens left to right, case-insensitively:
  - If a token and the next token form a two-word entry in `conc`, that
    phrase's rating is used and both tokens are consumed.
  - Otherwise the single token's rating is used, if it has one.
  - Tokens without a rating are skipped.

  Args:
    x: sequence of token strings (lemma results), or None.

  Returns:
    list of ratings in token order; empty when x is empty/None or nothing matched.
  """
  if not x:
    return []

  tokens = [w.lower() for w in x]
  ratings = []
  i = 0
  while i < len(tokens):
    # Prefer the two-word entry so its tokens are not also rated individually.
    pair = tuple(tokens[i:i + 2])
    if len(pair) == 2 and pair in conc:
      ratings.append(conc[pair])
      i += 2
      continue
    # This also covers the final token, which has no successor to pair with.
    if tokens[i] in conc:
      ratings.append(conc[tokens[i]])
    i += 1

  return ratings

# Expose the Python feature helpers to Spark SQL.
# Each entry is (SQL function name, Python callable, Spark return type).
udf_specs = (
    ('get_d_chall', get_d_chall, FloatType()),
    ('get_aoa', get_aoa, ArrayType(FloatType())),
    ('get_conc_rating', get_conc_rating, ArrayType(FloatType())),
    ('get_joined_text', lambda x: ' '.join(x), StringType()),
    ('array_mean', lambda x: float(np.mean(x)), FloatType()),
)
for udf_name, udf_fn, udf_type in udf_specs:
  spark.udf.register(udf_name, udf_fn, udf_type)

# Per-row features computed from the lemma strings produced by the pipeline.
query = r"""
SELECT 
  original_text, 
  get_joined_text(lemma.result) as lemmatized_text, 
  get_d_chall(lemma.result) as d_chall_score,
  get_aoa(lemma.result) as aoa,
  get_conc_rating(lemma.result) as conc_rating,
  label
FROM wiki;
"""
df = spark.sql(query)
df.createOrReplaceTempView('wiki')
df.show(5, 0)

+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+
|original_text                                                  

In [None]:
import numpy as np

def array_mean(values):
  """Return the mean of a numeric array column as a Python float.

  Returns None (SQL NULL) when the array is null or empty. This matches
  Spark's array_min/array_max, which yield NULL for an empty array. It also
  avoids numpy's per-row "Mean of empty slice" RuntimeWarning, and the
  TypeError np.mean raises on None.
  """
  if not values:
    return None
  return float(np.mean(values))

# Re-register so the SQL below uses the empty-safe mean.
spark.udf.register('array_mean', array_mean, FloatType())

# Collapse the per-token rating arrays into sentence-level summary statistics.
query = r"""
SELECT
  original_text,
  lemmatized_text,
  d_chall_score,
  array_mean(aoa) as aoa_mean,
  array_min(aoa) as aoa_min,
  array_max(aoa) as aoa_max,
  array_mean(conc_rating) as conc_rating_mean,
  array_min(conc_rating) as conc_rating_min,
  array_max(conc_rating) as conc_rating_max,
  label
FROM wiki;
"""
df = spark.sql(query)
df.createOrReplaceTempView('wiki')
df.show(5, 0)

+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+---------+-------+-------+----------------+---------------+---------------+-----+
|original_text                                                                                                                                                                                                                                         |lemmatized_text                                                                                                                                                                        

In [None]:
# Collect the Spark result into driver memory as a pandas DataFrame.
# NOTE(review): this materialises every row on the driver. That is fine for a
# single-machine Colab run; consider a sample/limit for larger data.
df = df.toPandas()

In [None]:
# Preview the first five rows of the feature table.
df.head()

Unnamed: 0,original_text,lemmatized_text,d_chall_score,aoa_mean,aoa_min,aoa_max,conc_rating_mean,conc_rating_min,conc_rating_max,label
0,There is manuscript evidence that Austen conti...,There be manuscript evidence that Austen conti...,0.5,5.80931,3.57,12.12,2.495517,1.33,4.57,1
1,"In a remarkable comparative analysis , Mandaea...","In a remarkable comparative analysis , Mandaea...",0.208333,7.402308,2.89,11.94,2.334286,1.46,4.93,1
2,"Before Persephone was released to Hermes , who...","Before Persephone be release to Hermes , who h...",0.630435,5.231351,2.78,11.17,2.556486,1.43,4.86,1
3,Cogeneration plants are commonly found in dist...,Cogeneration plant be commonly find in distric...,0.538462,6.742,3.56,11.53,3.369655,1.52,4.93,1
4,Geneva is the second-most-populous city in Swi...,Geneva be the second-most-populous city in Swi...,0.481481,5.455,3.69,12.62,2.399333,1.43,4.79,1
