# Reddit Spark Streaming Consumer
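This notebook reads Reddit posts and comments (one JSON record per line) from a socket, stores them in an in-memory Spark table and in Parquet files, and computes metrics on them: counts of user, subreddit and URL references, the top TF-IDF words, and the average sentiment of each micro-batch.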

In [None]:
import json
import re
from textblob import TextBlob
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

spark = (SparkSession.builder
    .appName('RedditConsumer')
    .getOrCreate())

# Socket endpoint that produces the Reddit records.
# 'host.docker.internal' is assumed to reach the producer when this notebook
# runs inside Docker -- change it if running locally.
HOST = 'host.docker.internal'
PORT = 9998

# Fields expected in every JSON record; each becomes a nullable column.
_RECORD_FIELDS = [
    ('type', StringType()),
    ('subreddit', StringType()),
    ('id', StringType()),
    ('text', StringType()),
    ('created_utc', DoubleType()),
    ('author', StringType()),
]
schema = StructType([StructField(field_name, field_type)
                     for field_name, field_type in _RECORD_FIELDS])

# The socket source yields one string column, 'value', per received line.
raw_lines = (spark.readStream
    .format('socket')
    .option('host', HOST)
    .option('port', PORT)
    .load())

# Parse each line against `schema`, then flatten the parsed struct into columns.
json_df = (raw_lines
    .select(F.from_json(F.col('value'), schema).alias('data'))
    .select('data.*'))


In [None]:
# Persist every parsed record to two append-mode sinks:
#   * a 'memory' sink, queryable by name as the table `raw` (read by compute_tfidf);
#     NOTE(review): the memory sink keeps its rows on the driver, so this table
#     grows for as long as the stream runs -- confirm that is acceptable.
#   * a 'parquet' sink writing files under data/raw, checkpointed in chk/raw.
# Each sink gets its own writer from json_df.writeStream.
query_memory = (json_df.writeStream
    .queryName('raw')
    .format('memory')
    .outputMode('append')
    .start())

query_files = (json_df.writeStream
    .format('parquet')
    .outputMode('append')
    .option('checkpointLocation', 'chk/raw')
    .option('path', 'data/raw')
    .start())


In [None]:
# Helper expressions for counting references in each record's text.
#
# Two Spark SQL details matter here:
#   * The SQL parser unescapes backslashes in string literals, so '\s' would reach
#     the regex engine as a plain 's'. Patterns are therefore written with '\\s'
#     (kept intact by the raw Python strings below) so the regex sees '\s'.
#   * regexp_extract_all returns capture group 1 by default. These patterns have
#     no groups, so group 0 (the whole match) is requested explicitly.
# coalesce(text, '') turns null text into an empty array, so size() counts 0
# for it instead of returning -1.
def _extract_all(pattern):
    """Return a Column with all matches of `pattern` in `text` (array<string>).

    `pattern` is embedded in a SQL string literal: backslashes must already be
    doubled and it must not contain single quotes.
    """
    return F.expr(f"regexp_extract_all(coalesce(text, ''), '{pattern}', 0)")


user_refs = _extract_all(r'/u/[^\\s]+')
subreddit_refs = _extract_all(r'/r/[^\\s]+')
url_refs = _extract_all(r'https?://[^\\s]+')

# Per-record reference counts, keyed by the record's creation time
# (created_utc is cast from epoch seconds to a timestamp).
refs_df = json_df.select(
    F.col('created_utc').cast('timestamp').alias('created_ts'),
    F.size(user_refs).alias('user_ref_count'),
    F.size(subreddit_refs).alias('subreddit_ref_count'),
    F.size(url_refs).alias('url_ref_count')
)

# Sum the counts over 60-second windows that slide every 5 seconds; the
# 1-minute watermark bounds how long late records may update a window.
windowed_refs = (refs_df
    .withWatermark('created_ts', '1 minute')
    .groupBy(F.window('created_ts', '60 seconds', '5 seconds'))
    .sum('user_ref_count', 'subreddit_ref_count', 'url_ref_count')
)

ref_query = (windowed_refs
    .writeStream
    .outputMode('update')
    .format('console')
    .option('truncate', False)
    .start())


In [None]:
# TF-IDF top words based on history stored in table 'raw'
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF

def compute_tfidf(top_n=10, table='raw'):
    """Print the `top_n` words with the highest TF-IDF score in `table`.

    Every row currently in `table` (by default 'raw', the memory-sink table)
    is one document. Each word is ranked by its maximum TF-IDF score across
    all documents.

    CountVectorizer is used instead of HashingTF because hashed feature
    indices cannot be mapped back to words. CountVectorizer exposes its
    vocabulary, so every entry of the TF-IDF vector can be labelled with
    the word it belongs to.

    Prints a notice and returns early when there is no text to score yet.
    """
    from pyspark.ml.feature import CountVectorizer, RegexTokenizer
    from pyspark.sql.types import ArrayType

    # Null text would break tokenization, so drop it up front.
    raw_df = spark.table(table).where(F.col('text').isNotNull())

    # Lowercases and splits on runs of whitespace (default pattern '\s+'),
    # dropping empty tokens.
    words_data = RegexTokenizer(inputCol='text', outputCol='words').transform(raw_df)
    filtered = StopWordsRemover(inputCol='words', outputCol='filtered').transform(words_data)

    # Fitting a vocabulary requires at least one token (e.g. not the case
    # right after the stream starts).
    if filtered.where(F.size('filtered') > 0).limit(1).count() == 0:
        print('No text available yet for TF-IDF.')
        return

    cv_model = CountVectorizer(inputCol='filtered', outputCol='rawFeatures',
                               vocabSize=10000).fit(filtered)
    featurized = cv_model.transform(filtered)
    idf_model = IDF(inputCol='rawFeatures', outputCol='features').fit(featurized)
    tfidf = idf_model.transform(featurized)

    vocabulary = cv_model.vocabulary
    word_score_type = ArrayType(StructType([
        StructField('word', StringType()),
        StructField('score', DoubleType()),
    ]))

    @F.udf(returnType=word_score_type)
    def _word_scores(vector):
        # Label each stored entry of the TF-IDF vector with its vocabulary word.
        if vector is None:
            return []
        if hasattr(vector, 'indices'):  # SparseVector: only the stored entries
            entries = zip(vector.indices, vector.values)
        else:  # DenseVector fallback: keep only non-zero entries
            entries = ((i, s) for i, s in enumerate(vector.toArray()) if s != 0.0)
        return [(vocabulary[int(i)], float(s)) for i, s in entries]

    word_scores = (tfidf
        .select(F.explode(_word_scores('features')).alias('z'))
        .select(F.col('z.word').alias('word'), F.col('z.score').alias('score')))
    top_words = (word_scores
        .groupBy('word')
        .agg(F.max('score').alias('score'))
        .orderBy(F.desc('score'))
        .limit(top_n))
    top_words.show(truncate=False)

compute_tfidf()


In [None]:
# Time range and sentiment analysis for each batch
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

@udf(returnType=DoubleType())
def sentiment_udf(text):
    """Return TextBlob's polarity score for `text`; 0.0 for None or empty text."""
    if not text:
        return 0.0
    return TextBlob(text).sentiment.polarity

def process_batch(batch_df, epoch_id):
    """Show the created_utc range and print the average sentiment of one micro-batch.

    Intended for writeStream.foreachBatch, which passes the micro-batch
    DataFrame and its epoch id; epoch_id is not used.

    Aggregates on batch_df directly rather than registering a temp view named
    'raw'. That name already belongs to the memory sink, and spark.sql() on
    the notebook's session may resolve it to the memory table instead of
    this batch.
    """
    batch_df.cache()  # the batch is scanned by two separate aggregations below
    try:
        time_bounds = batch_df.agg(
            F.min('created_utc').alias('min_ts'),
            F.max('created_utc').alias('max_ts'),
        )
        time_bounds.show()
        avg_sent = batch_df.agg(F.avg(sentiment_udf('text'))).collect()[0][0]
        print(f'Average sentiment: {avg_sent}')
    finally:
        # Release the cached batch; otherwise every micro-batch stays cached.
        batch_df.unpersist()

# Run process_batch once for every micro-batch of parsed records.
process_query = json_df.writeStream.foreachBatch(process_batch).start()


In [None]:
# Await termination of all streams.
# Waiting on each query in turn would block on query_memory indefinitely even
# if a later query (e.g. ref_query) had already failed. awaitAnyTermination()
# returns as soon as any query stops and re-raises the error of a failed one.
# resetTerminated() clears that state so the loop keeps waiting for the
# remaining queries after a clean stop.
all_queries = [query_memory, query_files, ref_query, process_query]
while any(q.isActive for q in all_queries):
    spark.streams.awaitAnyTermination()
    spark.streams.resetTerminated()
