# Reddit Spark Streaming Consumer

In [None]:
import json
from textblob import TextBlob
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

In [None]:
# Socket endpoint the streaming reader connects to.
HOST = 'host.docker.internal'
PORT = 9998

# Create the session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName('RedditConsumer')
    .getOrCreate()
)

In [None]:
# Layout of one JSON message received from the socket. Every field is
# nullable (the StructField default), so missing keys parse to null.
schema = (
    StructType()
    .add('type', StringType())
    .add('subreddit', StringType())
    .add('id', StringType())
    .add('text', StringType())
    .add('created_utc', DoubleType())
    .add('author', StringType())
)

In [None]:
# Each line read from the socket arrives as a single string column `value`.
raw_lines = (
    spark.readStream
    .format('socket')
    .options(host=HOST, port=PORT)
    .load()
)

# Parse every line against `schema` and flatten the struct into top-level
# columns.
json_df = (
    raw_lines
    .select(F.from_json(F.col('value'), schema).alias('data'))
    .select('data.*')
)

In [None]:
# Keep the parsed rows in an in-memory table queryable as `raw`.
query_memory = (
    json_df.writeStream
    .format('memory')
    .queryName('raw')
    .outputMode('append')
    .start()
)

# Persist the same rows as parquet files, with a checkpoint for the sink.
query_files = (
    json_df.writeStream
    .format('parquet')
    .outputMode('append')
    .options(path='data/raw', checkpointLocation='chk/raw')
    .start()
)

In [None]:
# Count user (/u/...), subreddit (/r/...) and URL references per message and
# sum them over a sliding window.
#
# Notes on the regexp_extract_all calls:
#   * The third argument 0 selects the whole match. The default group index
#     is 1, which fails for patterns that contain no capture group.
#   * The SQL parser unescapes string literals, so the regex `\s` has to be
#     written as `\\s` inside the SQL text (hence the raw Python strings).
#     This assumes the default spark.sql.parser.escapedStringLiterals=false.
#   * coalesce(text, '') makes rows without text contribute 0; size() of a
#     null array is -1 under the default (non-ANSI) settings and would
#     otherwise lower the sums.
user_refs = F.expr(r"regexp_extract_all(coalesce(text, ''), '/u/[^\\s]+', 0)")
sub_refs = F.expr(r"regexp_extract_all(coalesce(text, ''), '/r/[^\\s]+', 0)")
url_refs = F.expr(r"regexp_extract_all(coalesce(text, ''), 'https?://[^\\s]+', 0)")

# created_utc holds epoch seconds as a double; casting yields the event time.
refs_df = json_df.select(
    F.col('created_utc').cast('timestamp').alias('created_ts'),
    F.size(user_refs).alias('user_ref_count'),
    F.size(sub_refs).alias('sub_ref_count'),
    F.size(url_refs).alias('url_ref_count'),
)

# 60-second windows sliding every 5 seconds; events arriving more than one
# minute behind the watermark are dropped.
windowed_refs = (
    refs_df
    .withWatermark('created_ts', '1 minute')
    .groupBy(F.window('created_ts', '60 seconds', '5 seconds'))
    .sum('user_ref_count', 'sub_ref_count', 'url_ref_count')
)

# Print each updated window to the console without truncating columns.
ref_query = (
    windowed_refs.writeStream
    .outputMode('update')
    .format('console')
    .option('truncate', False)
    .start()
)

In [None]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF
@F.udf('double')
def sentiment_udf(text):
    """Return the TextBlob polarity of ``text`` as a float.

    Null or empty text yields 0.0 instead of being passed to TextBlob.
    """
    if not text:
        return 0.0
    polarity = TextBlob(text).sentiment.polarity
    return float(polarity)

def compute_tfidf(df):
    """Print the ten terms with the highest TF-IDF score in ``df``.

    Each row's ``text`` column is treated as one document. The text is
    tokenized, stop words are removed, and every remaining term is scored as
    ``tf * log((N + 1) / (doc_freq + 1))``, where ``tf`` is the term's count
    within a document, ``doc_freq`` the number of documents containing it and
    ``N`` the number of documents. A term's reported score is its maximum
    over all documents.

    The scores are computed with plain DataFrame aggregations rather than
    HashingTF/IDF: the ML ``features`` column is a hashed vector, not an
    array aligned with the token list, so it cannot be zipped with the
    tokens to recover per-word scores.

    Args:
        df: DataFrame with a string column ``text``.

    Returns:
        None. The top terms are printed with ``show``; a short message is
        printed instead when there is no text to analyse.
    """
    # Rows that failed JSON parsing carry null text, which Tokenizer cannot
    # handle, so they are excluded up front.
    docs = df.where(F.col('text').isNotNull())

    tokenizer = Tokenizer(inputCol='text', outputCol='words')
    remover = StopWordsRemover(inputCol='words', outputCol='filtered')
    filtered = (
        remover.transform(tokenizer.transform(docs))
        .select(F.monotonically_increasing_id().alias('doc_id'), 'filtered')
        .cache()  # reused by the document count and the term aggregation
    )
    try:
        num_docs = filtered.count()
        if num_docs == 0:
            print('No text available for TF-IDF')
            return

        # Term frequency: occurrences of each word within each document.
        # Tokenizer splits on single whitespace characters, so repeated
        # whitespace produces empty tokens; those are not real terms.
        term_counts = (
            filtered
            .select('doc_id', F.explode('filtered').alias('word'))
            .where(F.col('word') != '')
            .groupBy('doc_id', 'word')
            .count()
        )

        # The IDF is constant per word and never negative, so the maximum
        # tf * idf over documents equals max(tf) * idf.
        idf = F.log((F.lit(num_docs) + 1.0) / (F.col('doc_freq') + 1.0))
        top_words = (
            term_counts
            .groupBy('word')
            .agg(F.count('*').alias('doc_freq'), F.max('count').alias('max_tf'))
            .select('word', (F.col('max_tf') * idf).alias('score'))
            .orderBy(F.desc('score'))
            .limit(10)
        )
        top_words.show(truncate=False)
    finally:
        filtered.unpersist()

def process_batch(batch_df, epoch_id):
    """Run the per-micro-batch analytics for the ``foreachBatch`` sink.

    For every non-empty batch this:
      1. appends the batch to a parquet history directory,
      2. prints the ``created_utc`` range of all data stored so far,
      3. prints the batch's average sentiment,
      4. shows the five most active authors in the batch,
      5. shows the top TF-IDF terms over all data stored so far.

    The history lives in its own directory (``data/raw_batches``) and is
    exposed as the temp view ``raw_history``:
      * ``data/raw`` belongs to the streaming parquet sink started elsewhere
        in this file. Files appended there by a batch write are not recorded
        in that sink's metadata log.
      * The view name ``raw`` belongs to the memory sink. Replacing it with a
        growing union of micro-batch DataFrames would discard that table and
        keep references to batches that are no longer readable once their
        micro-batch has finished.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        epoch_id: Micro-batch identifier supplied by Spark (unused).
    """
    history_path = 'data/raw_batches'

    # The batch is scanned several times below, so cache it for the duration.
    batch_df.persist()
    try:
        if not batch_df.head(1):
            # Nothing to store or analyse for an empty micro-batch.
            return

        batch_df.write.mode('append').parquet(history_path)

        # Re-read the accumulated history from disk so it does not depend on
        # DataFrames from earlier micro-batches.
        full_df = spark.read.schema(batch_df.schema).parquet(history_path)
        full_df.createOrReplaceTempView('raw_history')

        bounds = full_df.agg(
            F.min('created_utc').alias('min_ts'),
            F.max('created_utc').alias('max_ts'),
        ).collect()[0]
        print(f"Data time range: {bounds['min_ts']} - {bounds['max_ts']}")

        sentiments = batch_df.withColumn('sentiment', sentiment_udf('text'))
        avg_sent = sentiments.agg(F.avg('sentiment')).collect()[0][0]
        print(f'Average sentiment (batch): {avg_sent}')

        top_authors = (
            batch_df.groupBy('author')
            .count()
            .orderBy(F.desc('count'))
            .limit(5)
        )
        top_authors.show(truncate=False)

        compute_tfidf(full_df)
    finally:
        # Release the cached batch even when one of the steps above fails.
        batch_df.unpersist()

In [None]:
# Hand every micro-batch of parsed rows to process_batch.
process_query = (
    json_df.writeStream
    .foreachBatch(process_batch)
    .start()
)

In [None]:
# Block until every streaming query has finished.
#
# Awaiting the queries one after another would block on the first one
# indefinitely and hide a failure in any of the later ones.
# awaitAnyTermination() returns, or raises the query's exception, as soon as
# any query in the session stops.
streaming_queries = [query_memory, query_files, ref_query, process_query]
while any(q.isActive for q in streaming_queries):
    spark.streams.awaitAnyTermination()
    # Forget the recorded termination so the next wait blocks again.
    spark.streams.resetTerminated()

# All queries are inactive now, so these calls return immediately. They
# re-raise the exception of any query that failed.
for q in streaming_queries:
    q.awaitTermination()