# Reddit Spark Streaming Consumer
This notebook receives Reddit posts and comments from a socket, stores them in an in-memory Spark table, and computes metrics such as reference counts, top TF-IDF words, and sentiment scores.

In [15]:
!pip install textblob

[0m

In [16]:
import json
import os
import re

from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF
from pyspark.ml.feature import CountVectorizer, RegexTokenizer
from pyspark.ml.functions import vector_to_array
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
from pyspark.sql.types import DoubleType
from textblob import TextBlob

### Set up the Spark stream consumer and the data schema
##### The command that starts the Spark server is given in a comment in the code cell below.

In [17]:
# Command to run spark server on docker to plug into kernel for running notebook
# docker run -it -p 4040:4040 -p 8080:8080 -p 8081:8081 -p 8888:8888 -p 5432:5432 --cpus=2 --memory=2048m -h spark -w /mnt/host_home/ pyspark_container jupyter-lab --ip 0.0.0.0 --port 8888 --no-browser --allow-root
spark = SparkSession.builder.appName('RedditConsumer').getOrCreate()

# Producer endpoint; override with the PRODUCER_HOST / PRODUCER_PORT env vars.
HOST = os.getenv("PRODUCER_HOST", "host.docker.internal")
PORT = int(os.getenv("PRODUCER_PORT", "9998"))
print("→ Connecting to producer at", HOST, PORT)

# Fields expected in each JSON line sent by the producer (all nullable).
schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in (
        ('type', StringType()),
        ('subreddit', StringType()),
        ('id', StringType()),
        ('text', StringType()),
        ('created_utc', DoubleType()),
        ('author', StringType()),
    )
])

# Streaming DataFrame of raw text lines read from the producer's socket.
raw_lines = (
    spark.readStream
    .format('socket')
    .options(host=HOST, port=PORT)
    .load()
)

# Parse each line's `value` as JSON with `schema` and flatten the struct
# into top-level columns.
json_df = (
    raw_lines
    .select(F.from_json(F.col('value'), schema).alias('data'))
    .select('data.*')
)

→ Connecting to producer at host.docker.internal 9998


25/06/08 11:48:04 WARN TextSocketSourceProvider: The socket source should not be used for production applications! It does not support recovery.


### Write raw data to an in-memory table

In [None]:
# Append every parsed record to the in-memory table named 'raw' so the raw
# stream can be inspected with spark.sql.
# NOTE(review): per the Spark docs, the memory sink keeps its whole output in
# driver memory; fine for debugging, but it grows without bound.
query_memory = (
    json_df.writeStream
    .queryName('raw')
    .format('memory')
    .outputMode('append')
    .start()
)

25/06/08 11:48:08 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-8e757b5b-d967-4f14-82ce-b844b67fecae. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/06/08 11:48:08 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.


25/06/08 11:48:08 ERROR MicroBatchExecution: Query raw [id = 922c9fb4-7c13-4f56-ab55-2c0f77e916ff, runId = a7ae152b-508c-414d-b560-46681f5fe617] terminated with error
java.net.ConnectException: Connection refused
	at java.base/sun.nio.ch.Net.connect0(Native Method)
	at java.base/sun.nio.ch.Net.connect(Net.java:579)
	at java.base/sun.nio.ch.Net.connect(Net.java:568)
	at java.base/sun.nio.ch.NioSocketImpl.connect(NioSocketImpl.java:593)
	at java.base/java.net.SocksSocketImpl.connect(SocksSocketImpl.java:327)
	at java.base/java.net.Socket.connect(Socket.java:633)
	at java.base/java.net.Socket.connect(Socket.java:583)
	at java.base/java.net.Socket.<init>(Socket.java:507)
	at java.base/java.net.Socket.<init>(Socket.java:287)
	at org.apache.spark.sql.execution.streaming.sources.TextSocketMicroBatchStream.initialize(TextSocketMicroBatchStream.scala:71)
	at org.apache.spark.sql.execution.streaming.sources.TextSocketMicroBatchStream.planInputPartitions(TextSocketMicroBatchStream.scala:117)
	at 

### Extract references to users, subreddits, and URLs

In [5]:
# Add array columns holding every user (/u/...), subreddit (/r/...) and URL
# reference found in `text` (null text gives null arrays).
# - Inside a Spark SQL string literal a single backslash is consumed as an
#   escape ('\s' reached the regex engine as a plain 's'), so the regex class
#   is written '\\s' here (the Python string is raw).
# - The trailing 0 selects the whole match: regexp_extract_all defaults to
#   group 1, which raises for patterns that have no capture group.
df2 = (
    json_df
    .withColumn(
        "user_refs",
        F.expr(r"regexp_extract_all(text, '/u/[^\\s]+', 0)")
    )
    .withColumn(
        "subreddit_refs",
        F.expr(r"regexp_extract_all(text, '/r/[^\\s]+', 0)")
    )
    .withColumn(
        "url_refs",
        F.expr(r"regexp_extract_all(text, 'https?://[^\\s]+', 0)")
    )
)

### Aggregate reference counts over sliding windows

In [6]:
# Count each kind of reference per record and attach an event-time timestamp
# (created_utc cast to timestamp, i.e. read as epoch seconds) for windowed
# aggregation.
# Records without text have null reference arrays. Under Spark's default
# (non-ANSI) settings F.size returns -1 for null, which would be subtracted
# from the window sums, so null arrays are counted as 0 instead.
refs_df = df2.select(
    F.col("created_utc").cast("timestamp").alias("created_ts"),
    *[
        F.when(F.col(array_col).isNull(), 0).otherwise(F.size(array_col)).alias(count_col)
        for array_col, count_col in [
            ("user_refs", "user_ref_count"),
            ("subreddit_refs", "subreddit_ref_count"),
            ("url_refs", "url_ref_count"),
        ]
    ]
)

In [7]:
# Total references per sliding event-time window: 60-second windows that
# start every 5 seconds. The 1-minute watermark on created_ts bounds how late
# events are accepted and how long window state is kept.
windowed_refs = (
    refs_df
    .withWatermark('created_ts', '1 minute')
    .groupBy(F.window('created_ts', '60 seconds', '5 seconds'))
    .agg(
        F.sum('user_ref_count'),
        F.sum('subreddit_ref_count'),
        F.sum('url_ref_count'),
    )
)

In [8]:
# Print the windowed reference counts to the console (untruncated).
# In 'update' mode each trigger prints only the windows whose sums changed.
# NOTE: nothing is written to an in-memory table or to Parquet here.
ref_query = (windowed_refs
    .writeStream
    .outputMode('update')
    .format('console')
    .option('truncate', False)
    .start())

25/06/08 11:47:31 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-dd3d928f-cf36-4fa5-9e56-e3dce6d43668. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/06/08 11:47:31 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.


[Stage 48:>                                                       (2 + 2) / 200]

### Extract reference counts

In [None]:
def extract_reference_counts(batch_df):
    """Sum the user, subreddit and URL references found in a batch's ``text``.

    Args:
        batch_df: DataFrame with a string column ``text``.

    Returns:
        One-row DataFrame with the columns ``sum(user_refs)``,
        ``sum(sub_refs)`` and ``sum(url_refs)``.
    """
    # Null text would give null arrays, and F.size(null) is -1 under default
    # settings; treat missing text as an empty string instead.
    text = F.coalesce(F.col('text'), F.lit(''))

    def _count_matches(pattern):
        # The pattern must be wrapped in F.lit: a bare string is resolved as a
        # column name. idx=0 returns whole matches; the default group index 1
        # raises for patterns without a capture group.
        return F.size(F.regexp_extract_all(text, F.lit(pattern), 0))

    refs = batch_df.select(
        _count_matches(r'/u/\w+').alias('user_refs'),
        _count_matches(r'/r/\w+').alias('sub_refs'),
        _count_matches(r'https?://[^\s]+').alias('url_refs'),
    )
    return refs.groupBy().sum('user_refs', 'sub_refs', 'url_refs')



### Compute TF-IDF scores and find the 10 most important words in each micro-batch

In [None]:
def compute_tf_idf(batch_df, top_n=10, vocab_size=10000):
    """Return the ``top_n`` words with the highest TF-IDF score in a batch.

    Each record's ``text`` is lower-cased and split on runs of whitespace, stop
    words are removed, and term counts are built with ``CountVectorizer``.
    Unlike ``HashingTF``, ``CountVectorizer`` keeps a vocabulary, so every
    feature index can be mapped back to the word it stands for. IDF weights are
    fitted on the batch itself, so scores are only comparable within one batch.

    Args:
        batch_df: DataFrame with a string column ``text``; null texts are ignored.
        top_n: number of words to return (default 10).
        vocab_size: maximum number of distinct words kept (default 10000).

    Returns:
        DataFrame with columns ``word`` (string) and ``score`` (double): each
        word's highest TF-IDF score over the batch's documents, in descending
        score order, at most ``top_n`` rows. Empty when no tokens remain.
    """
    session = batch_df.sparkSession

    # Tokenizers fail on null input, so drop records without text first.
    docs = batch_df.where(F.col('text').isNotNull()).select('text')
    # Splitting on \s+ (rather than Tokenizer's single-whitespace split) avoids
    # empty-string tokens from repeated spaces/newlines.
    tokens = RegexTokenizer(inputCol='text', outputCol='words', pattern=r'\s+').transform(docs)
    filtered = StopWordsRemover(inputCol='words', outputCol='filtered').transform(tokens)

    # CountVectorizer rejects an empty vocabulary and IDF cannot be fitted
    # without documents, so short-circuit when there is nothing to score.
    if filtered.where(F.size('filtered') > 0).limit(1).count() == 0:
        return session.createDataFrame([], 'word string, score double')

    cv_model = CountVectorizer(
        inputCol='filtered', outputCol='rawFeatures', vocabSize=vocab_size
    ).fit(filtered)
    featurized = cv_model.transform(filtered)
    tfidf = IDF(inputCol='rawFeatures', outputCol='features').fit(featurized).transform(featurized)

    # Feature index i holds the word cv_model.vocabulary[i]. Expand both
    # vectors to arrays (dense, vocabulary-sized per document), keep positions
    # the document actually contains (count > 0; the TF-IDF score itself can be
    # 0 for words present in every document) and read the score there.
    entries = (tfidf
        .select(vector_to_array('rawFeatures').alias('counts'),
                vector_to_array('features').alias('scores'))
        .select(F.posexplode('counts').alias('idx', 'count'), 'scores')
        .where(F.col('count') > 0)
        .select('idx', F.expr('scores[idx]').alias('score')))

    vocab_df = session.createDataFrame(
        list(enumerate(cv_model.vocabulary)), 'idx int, word string')

    return (entries
        .join(F.broadcast(vocab_df), 'idx')
        .groupBy('word')
        .agg(F.max('score').alias('score'))
        .orderBy(F.desc('score'))
        .limit(top_n))



                                                                                

-------------------------------------------
Batch: 0
-------------------------------------------
+------+-------------------+------------------------+------------------+
|window|sum(user_ref_count)|sum(subreddit_ref_count)|sum(url_ref_count)|
+------+-------------------+------------------------+------------------+
+------+-------------------+------------------------+------------------+



### Sentiment analysis of text with TextBlob

In [11]:
@udf(returnType=DoubleType())
def sentiment_udf(text):
    """Spark UDF: TextBlob polarity of ``text``; 0.0 for null or empty text."""
    if not text:
        return 0.0
    return TextBlob(text).sentiment.polarity

#### Batch Processing of Streaming Data.
- TODO:
    - Use the windowed reference counts computed earlier
    - Include the top 10 TF-IDF words
    - Write the processed data to an in-memory table

In [12]:
def process_batch(batch_df, batch_id=None):
    """Compute reference counts, top TF-IDF words and sentiment for one micro-batch.

    Used as a ``foreachBatch`` callback. Spark calls such callbacks with two
    arguments, the micro-batch DataFrame and its batch id, so ``batch_id`` must
    be accepted. It defaults to ``None`` so direct calls with only a DataFrame
    keep working.

    Results are shown on stdout and registered as the temp views
    ``current_batch``, ``batch_references``, ``batch_tfidf`` and
    ``batch_sentiment``.

    NOTE(review): inside foreachBatch, ``batch_df`` may belong to a cloned
    Spark session, in which case these temp views are not visible from the
    notebook's ``spark`` session. Consider global temp views or a real sink;
    verify before relying on them.
    """
    batch_df.cache()
    try:
        batch_df.createOrReplaceTempView('current_batch')
        count = batch_df.count()
        print(f'Processing batch with {count} records')
        if count == 0:
            # Nothing to analyse, and TF-IDF fitting fails on an empty batch.
            return

        refs_summary = extract_reference_counts(batch_df)
        refs_summary.show(truncate=False)
        refs_summary.createOrReplaceTempView('batch_references')

        top_words = compute_tf_idf(batch_df)
        top_words.show(truncate=False)
        top_words.createOrReplaceTempView('batch_tfidf')

        sentiment_scores = batch_df.withColumn('sentiment', sentiment_udf(F.col('text')))
        sentiment_scores.createOrReplaceTempView('batch_sentiment')
    finally:
        # Release the cached micro-batch; otherwise every batch stays cached.
        batch_df.unpersist()


In [13]:
# Run process_batch on every micro-batch of parsed records. foreachBatch
# replaces the sink entirely, so no format/options are configured here (a
# console format and its 'truncate' option would be silently ignored).
# Spark passes the callback each micro-batch DataFrame and its batch id.
process_query = (
    json_df
    .writeStream
    .foreachBatch(process_batch)
    .start()
)

25/06/08 11:47:50 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-f6060d47-cb7e-4eaf-ab43-bda3791ae1d0. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
25/06/08 11:47:50 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.


25/06/08 11:47:50 ERROR MicroBatchExecution: Query [id = f7b16dac-d690-4a74-a8a5-4f550a652cb5, runId = 47d7ff71-189b-4db7-b7b5-5669c111bf6c] terminated with error
java.net.ConnectException: Connection refused
	at java.base/sun.nio.ch.Net.connect0(Native Method)
	at java.base/sun.nio.ch.Net.connect(Net.java:579)
	at java.base/sun.nio.ch.Net.connect(Net.java:568)
	at java.base/sun.nio.ch.NioSocketImpl.connect(NioSocketImpl.java:593)
	at java.base/java.net.SocksSocketImpl.connect(SocksSocketImpl.java:327)
	at java.base/java.net.Socket.connect(Socket.java:633)
	at java.base/java.net.Socket.connect(Socket.java:583)
	at java.base/java.net.Socket.<init>(Socket.java:507)
	at java.base/java.net.Socket.<init>(Socket.java:287)
	at org.apache.spark.sql.execution.streaming.sources.TextSocketMicroBatchStream.initialize(TextSocketMicroBatchStream.scala:71)
	at org.apache.spark.sql.execution.streaming.sources.TextSocketMicroBatchStream.planInputPartitions(TextSocketMicroBatchStream.scala:117)
	at org.

In [14]:
# Stop every streaming query still running before shutting the session down,
# so the streams are ended explicitly rather than as a side effect of stop().
for active_query in spark.streams.active:
    active_query.stop()
spark.stop()

### Wait for all streams to terminate

In [None]:
# Block until each stream stops. awaitTermination() re-raises the error of a
# query that already failed (e.g. the socket connection was refused), which
# would skip the remaining queries, so failed queries are reported instead.
for q in [query_memory, ref_query, process_query]:
    failure = q.exception()
    if failure is not None:
        print(f"Query {q.name or q.id} terminated with error: {failure}")
        continue
    q.awaitTermination()