In [1]:
import sys
import glob
from pathlib import Path

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    explode,
    split,
    regexp_replace,
    lit,
    concat,
    lead,
    lower,
    monotonically_increasing_id
)
from pyspark.sql import Window

In [2]:
# Inputs and outputs share one data directory.
DATA_DIR = "/home/jovyan/work/data"

# Glob pattern selecting the lyric text files to read.
SONG_LYRICS_FILES = f"{DATA_DIR}/song_lyrics/*.txt"
# CSV file that receives the final bigram counts.
OUTPUT_FILE = f"{DATA_DIR}/bigram_counts.csv"

# Token delimiter: a run of whitespace/underscores, a run of characters that
# are neither word characters nor one of - ' ` * : $ & . , or two or more
# consecutive dots.
SPLIT_PATTERN = r"[\s_]+|[^\w-'`*:$&.]+|\.\.\.*"
# Matches the non-word characters at the very start or very end of a token.
REPLACE_PATTERN = r"^\W+|\W+$"

In [3]:
# Resolve the inputs and the output directory before starting Spark, so a
# misconfigured path stops the notebook immediately.
files = glob.glob(SONG_LYRICS_FILES)
output_file_dir_path = Path(OUTPUT_FILE).parent

# Pick the first failing location; the input check takes priority.
if not files:
    missing_path = SONG_LYRICS_FILES
elif not output_file_dir_path.exists():
    missing_path = str(output_file_dir_path)
else:
    missing_path = None

if missing_path is not None:
    sys.exit(f"Path does not exist: {missing_path}")

In [4]:
# Configure the session for the cluster master at spark-master:7077, then
# obtain it (getOrCreate returns an existing session when there is one).
session_builder = (
    SparkSession.builder
    .master("spark://spark-master:7077")
    .appName("bigram-count")
)
spark = session_builder.getOrCreate()

In [5]:
# Read every matching lyrics file with the text source; the text of each row
# is exposed in the `value` column used by the tokenising step.
song_lyrics_df = spark.read.text(SONG_LYRICS_FILES)

In [6]:
# One output row per token of each input row.
raw_token = explode(split(col("value"), SPLIT_PATTERN))
# Strip non-word characters from both ends of the token, then lower-case it.
cleaned_token = lower(regexp_replace(col("word"), REPLACE_PATTERN, ""))

words_df = (
    song_lyrics_df
    .select(raw_token.alias("word"))
    .select(cleaned_token.alias("word"))
    # Tokens that were nothing but stripped characters end up empty; drop them.
    .filter(col("word") != "")
)

In [7]:
# Build bigrams by pairing every word with the word that follows it.
#
# lead() needs a real ordering: sorting the window on a constant leaves the
# row order undefined, so the "next" word is not guaranteed to be the next
# word in the text. Each row is therefore tagged with an increasing id (in the
# order the rows are produced from the input) and the window is ordered by it.
#
# NOTE: the window is still un-partitioned, so Spark evaluates it on a single
# partition, and bigrams still span line and file boundaries as before.
window_spec_source_order = Window.orderBy(col("_position"))

bigrams_df = (
    words_df
    .withColumn("_position", monotonically_increasing_id())
    .withColumn(
        "word",
        concat(
            col("word"),
            lit(" "),
            # NULL for the last row, which makes the whole concat() NULL.
            lead(col("word"), 1, None).over(window_spec_source_order)
        )
    )
    # The helper column is only needed for ordering; keep the output schema
    # as a single `word` column.
    .drop("_position")
    # Remove the final word, which has no successor to pair with.
    .dropna()
)

In [8]:
# Count the occurrences of each distinct bigram, then rank them from most to
# least frequent.
bigram_tallies = bigrams_df.groupBy("word").count()
counted_unique_bigrams_df = bigram_tallies.orderBy("count", ascending=False)


In [9]:
# Print a label, then display the number of rows in bigrams_df (one row per
# bigram occurrence) as the cell's output.
print("total_bigrams")
bigrams_df.count()

total_bigrams


1348940

In [10]:
# Preview the first ten rows of the counted bigrams with full, untruncated
# column text.
print("bigram_counts")
counted_unique_bigrams_df.show(n=10, truncate=False)

bigram_counts
+--------+-----+
|word    |count|
+--------+-----+
|in the  |5562 |
|and i   |2809 |
|on the  |2651 |
|you know|2359 |
|i don't |2142 |
|i know  |2119 |
|to the  |2118 |
|i got   |1831 |
|if you  |1736 |
|like a  |1665 |
+--------+-----+
only showing top 10 rows



In [11]:
# Collect the counts into a pandas DataFrame, then write them to OUTPUT_FILE
# as a comma-separated file without the pandas index column.
bigram_counts_pdf = counted_unique_bigrams_df.toPandas()
bigram_counts_pdf.to_csv(OUTPUT_FILE, sep=",", index=False)

In [12]:
# Stop the Spark session; this is the notebook's final step.
spark.stop()