Imports and initialization.

In [1]:
# imports
import json
import findspark
import pyspark
from pyspark.sql import SparkSession
from transformers import BertTokenizer

# initialization
findspark.init()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', local_files_only=True)
spark = (
    SparkSession
    .builder
    .appName("Preprocessing")
    .getOrCreate()
)
sc = spark.sparkContext

22/07/06 12:04:52 WARN Utils: Your hostname, ThinkPad-X1-Gen6 resolves to a loopback address: 127.0.1.1; using 192.168.0.16 instead (on interface wlp2s0)
22/07/06 12:04:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/07/06 12:04:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Read the lines of the text file.

In [2]:
# Load the text file as an RDD of strings (one element per line), split into at least 40 partitions.
lines = sc.textFile("wikitext-2/wiki.train.tokens", 40)
print(f'Number of lines: {lines.count()}')



Number of lines: 36718

Preprocessing functions: Lower-casing, stopword removal, transforming to input IDs.

In [3]:
def lowercase(text):
    return text.lower()

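def remove_stopwords(text):
    # Imported inside the function so the import also resolves on the Spark workers.
    from nltk.corpus import stopwords
    # Build the stopword set once per line instead of reloading the list for every token.
    stop = set(stopwords.words('english'))
    return ' '.join(t for t in text.split() if t not in stop)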

def encode(tokenizer, text):
    return tokenizer(text)
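
The two text-level functions can be tried on a single string before handing them to Spark. The following is a minimal sketch: to run without the NLTK corpus, it swaps in a small hand-written stopword set (`DEMO_STOPWORDS` and `remove_stopwords_demo` are hypothetical names, not part of the notebook). The real NLTK list is much larger, so results on real text will differ.

```python
# Hypothetical mini stopword list for illustration; NLTK's English list is much larger.
DEMO_STOPWORDS = {"the", "is", "a", "of", "in", "and"}

def lowercase(text):
    return text.lower()

def remove_stopwords_demo(text, stopwords=DEMO_STOPWORDS):
    # Same whitespace-split logic as remove_stopwords, with the stopword set passed in.
    return ' '.join(t for t in text.split() if t not in stopwords)

sample = "The Valkyria Chronicles is a series of Games"
cleaned = remove_stopwords_demo(lowercase(sample))
print(cleaned)  # valkyria chronicles series games
```

Lower-casing has to come first, because the stopword comparison is case-sensitive and the stopword entries are lower-case.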

We apply the preprocessing functions. Spark will take care of distributing the processing.

In [4]:
output = (
    lines
    .filter(lambda l: l.strip())                      # remove empty lines
    .filter(lambda l: not l.strip().startswith('='))  # remove heading lines starting with '='
    .map(lowercase)                                   # lowercase
    .map(str.strip)                                   # strip whitespace
    .map(remove_stopwords)                            # remove stopwords
    .map(lambda l: encode(tokenizer, l))              # encode words to IDs read by the BERT model
)

The transformations above only describe the pipeline. Spark starts the computation once an action such as `collect()` requires the output (*lazy evaluation*).
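
Lazy evaluation can be illustrated without Spark. The sketch below is only an analogy: it uses Python's built-in `map` and `filter`, which also return lazy iterators, and a hypothetical `traced_lower` function that records each call. Spark's scheduler works differently internally, but the observable behaviour is similar: nothing runs until the result is requested.

```python
calls = []

def traced_lower(text):
    calls.append(text)  # record that the function actually ran
    return text.lower()

demo_lines = ["Hello World", "", "= Heading =", "Spark Is Lazy"]

# Building the pipeline runs nothing yet: filter() and map() return lazy iterators.
pipeline = map(
    traced_lower,
    filter(lambda l: l.strip() and not l.strip().startswith('='), demo_lines),
)
calls_before = len(calls)  # still 0 at this point

# Consuming the iterator plays the role of collect(): now the functions run.
result = list(pipeline)
calls_after = len(calls)

print(calls_before, calls_after, result)
```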

In [5]:
output_ids = output.collect()
print(output_ids[0]['input_ids'][:10])

Token indices sequence length is longer than the specified maximum sequence length for this model (796 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (591 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (517 > 512). Running this sequence through the model will result in indexing errors

[101, 12411, 5558, 11748, 4801, 4360, 1017, 1024, 1026, 4895]
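
The warnings in the output mean that some lines encode to more than 512 tokens, the maximum sequence length reported for this model. One possible remedy, which is an assumption and not something the notebook does, is to truncate long sequences; the Hugging Face tokenizer call accepts `truncation=True` and `max_length` for that purpose. The pure-Python sketch below shows the idea on a plain list of IDs. `truncate_ids` is a hypothetical helper, and the input is synthetic.

```python
MAX_LEN = 512

def truncate_ids(input_ids, max_len=MAX_LEN):
    """Cut a list of token IDs to max_len, keeping the final (separator) ID."""
    if len(input_ids) <= max_len:
        return list(input_ids)
    # Keep the leading IDs and re-append the last ID so the sequence stays terminated.
    return list(input_ids[:max_len - 1]) + [input_ids[-1]]

# Synthetic sequence of 796 IDs: 101 and 102 stand in for the start and end markers.
long_ids = [101] + list(range(1000, 1794)) + [102]
short_ids = truncate_ids(long_ids)
print(len(long_ids), len(short_ids))
```

Truncation discards the tail of long lines; splitting them into several windows would keep all text, at the cost of more sequences.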


Now we can write the output IDs to a JSON file.

In [6]:
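# The tokenizer returns dict-like objects that json cannot serialize directly; convert each to a plain dict.
output_ids = [dict(oids) for oids in output_ids]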
with open('output_ids.json', 'w') as f:
    json.dump(output_ids, f)
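
To check that the file can be read back, a round trip through `json.load` is enough. The sketch below writes to a temporary directory instead of the working directory, and the record is illustrative (the key names match what a BERT tokenizer returns by default; the ID values are made up for the example).

```python
import json
import os
import tempfile

# One illustrative record with the three fields a BERT tokenizer returns by default.
demo_ids = [{
    "input_ids": [101, 7592, 102],
    "token_type_ids": [0, 0, 0],
    "attention_mask": [1, 1, 1],
}]

path = os.path.join(tempfile.mkdtemp(), "output_ids.json")
with open(path, "w") as f:
    json.dump(demo_ids, f)

with open(path) as f:
    loaded = json.load(f)

print(loaded == demo_ids)
```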

We can now stop the Spark session.

In [7]:
spark.stop()