In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr, concat
import findspark
import logging
import time

findspark.init()

# Setup basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def log_time_taken(start, operation):
    end = time.time()
    logger.info(f"{operation} completed in {end - start:.2f} seconds")

# Start timing and log the initialization of the Spark session
logger.info("Initializing Spark session with optimized memory settings")
start_time = time.time()
spark = SparkSession.builder \
    .appName("Reddit Comment Context Builder") \
    .master("local[*]")  \
    .config("spark.executor.memory", "64g")  \
    .config("spark.driver.memory", "32g")  \
    .config("spark.executor.memoryOverhead", "4096") \
    .config("spark.driver.memoryOverhead", "2048")  \
    .config("spark.driver.maxResultSize", "8g") \
    .config("spark.driver.extraClassPath", "/Volumes/LaCie/wsb_archive/postgresql-42.7.3.jar") \
    .config("spark.driver.extraJavaOptions", "-XX:+UseG1GC") \
    .config("spark.executor.extraJavaOptions", "-XX:+UseG1GC") \
    .getOrCreate()
log_time_taken(start_time, "SparkSession initialization")

2024-04-03 06:57:41,081 - INFO - Initializing Spark session with optimized memory settings


24/04/03 06:57:42 WARN Utils: Your hostname, Binmings-iMac-5.local resolves to a loopback address: 127.0.0.1; using 192.168.1.69 instead (on interface en1)
24/04/03 06:57:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/04/03 06:57:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


2024-04-03 06:57:43,666 - INFO - SparkSession initialization completed in 2.58 seconds


```
# CREATE TABLE wsb_comments (
#     datetime_utc TIMESTAMP WITH TIME ZONE NOT NULL,
#     comment_id VARCHAR(20) PRIMARY KEY,
#     submission_id VARCHAR(20) NOT NULL,
#     parent_id VARCHAR(20) NULL, -- Can be NULL for top-level comments
#     distinguished TEXT NULL, -- Can be NULL if the comment is not distinguished
#     archived BOOLEAN NOT NULL,
#     edited BOOLEAN NOT NULL,
#     ups INT NOT NULL,
#     downs INT NOT NULL,
#     controversiality INT NOT NULL CHECK (controversiality IN (0, 1)),
#     comment_score INT NOT NULL,
#     comment_body TEXT NULL -- Can be NULL for deleted or removed comments
# );

# -- Indexes to improve query performance
# CREATE INDEX idx_submission_id ON wsb_comments (submission_id);
# CREATE INDEX idx_parent_id ON wsb_comments (parent_id);
```

------------------------------------------------------------------------------------------------------------------------------------------

```
# CREATE TABLE wsb_submissions (
#     datetime_utc TIMESTAMP WITH TIME ZONE,
#     submission_id VARCHAR(10) PRIMARY KEY,
#     url TEXT,
#     title TEXT,
#     self_text TEXT,
#     is_self BOOLEAN,
#     num_comments INT,
#     likes INT,
#     downs INT,
#     ups INT,
#     post_score INT,
#     distinguished TEXT,
#     edited BOOLEAN,
#     author TEXT,
#     over_18 BOOLEAN
# );
```

In [2]:
wsb_comments = spark.read.parquet("./wsb_comments")
wsb_comments = wsb_comments.orderBy("datetime_utc")

In [3]:
# shape
print((wsb_comments.count(), len(wsb_comments.columns)))



(77587930, 12)


                                                                                

In [4]:
wsb_comments.show()



+-------------------+----------+-------------+----------+-------------+--------+------+---+-----+----------------+-------------+--------------------+
|       datetime_utc|comment_id|submission_id| parent_id|distinguished|archived|edited|ups|downs|controversiality|comment_score|        comment_body|
+-------------------+----------+-------------+----------+-------------+--------+------+---+-----+----------------+-------------+--------------------+
|2012-04-11 09:46:43|   c4b0pvu|     t3_s4jw1|  t3_s4jw1|         null|    true|  true|  2|    0|               0|            2|This is a fantast...|
|2012-04-11 10:12:16|   c4b127p|     t3_s4jw1|  t3_s4jw1|         null|    true| false|  1|    0|               0|            1|           [deleted]|
|2012-04-11 10:39:08|   c4b1fpf|     t3_s4jw1|  t3_s4jw1|         null|    true| false|  2|    0|               0|            2|     INTC is on 4/17|
|2012-04-11 11:02:31|   c4b1rmm|     t3_s4jw1|  t3_s4jw1|         null|    true| false|  1|    0|   

                                                                                

In [5]:
wsb_submissions = spark.read.parquet("./wsb_submissions")
wsb_submissions = wsb_submissions.orderBy("datetime_utc")

In [6]:
# shape
print((wsb_submissions.count(), len(wsb_submissions.columns)))

(2349120, 15)


In [7]:
wsb_submissions.show()

                                                                                

+-------------------+-------------+--------------------+--------------------+--------------------+-------+------------+-----+-----+---+----------+-------------+------+----------------+-------+
|       datetime_utc|submission_id|                 url|               title|           self_text|is_self|num_comments|likes|downs|ups|post_score|distinguished|edited|          author|over_18|
+-------------------+-------------+--------------------+--------------------+--------------------+-------+------------+-----+-----+---+----------+-------------+------+----------------+-------+
|2012-04-11 09:40:40|        s4jw1|http://www.reddit...|Earnings season i...|I know that /r/in...|   true|          22| null|    3| 16|        13|         null| false|       [deleted]|  false|
|2012-04-12 13:37:31|        s6r57|http://www.bloomb...|GOOG - beat estim...|                    |  false|           0| null|    3|  5|         2|         null| false|       [deleted]|  false|
|2012-04-16 15:29:37|        sd5ai|

In [26]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when, concat_ws, lit, expr, greatest

def build_context_chain(comments: DataFrame, submissions: DataFrame, max_depth: int = None) -> DataFrame:
    # Alias submissions and comments with unique column names
    submissions_kv = submissions.select(
        expr("submission_id as s_key"),
        concat_ws(" ", col("title"), col("self_text")).alias("s_value")
    )

    # Including submission_id in the comments_kv DataFrame and introducing the reached_top flag
    comments_kv = comments.select(
        expr("comment_id as c_key"),
        col("parent_id").alias("c_parent_id"),
        col("comment_body").alias("c_value"),
        lit(False).alias("reached_top"),  # Initial reached_top flag set to False
    ).withColumn("curr_parent_id", col("c_parent_id"))  # Initialize curr_parent_id

    # Initialize context with the comment itself
    context_df = comments_kv.withColumn("context", col("c_value"))

    i = 1
    while True and (max_depth is None or i <= max_depth):
        comments_iter = comments_kv.alias(f"c{i}")
        
        context_df = context_df.join(
            comments_iter,
            context_df["curr_parent_id"] == expr(f"concat('t1_', c{i}.c_key)"),
            "left_outer"
        ).join(
            submissions_kv,
            context_df["curr_parent_id"] == expr(f"concat('t3_', s_key)"),
            "left_outer"
        ).select(
            context_df["c_key"],
            when(
                context_df["curr_parent_id"].startswith("t3_"), 
                concat_ws(" |->| ", context_df["context"], submissions_kv["s_value"])
            ).when(
                context_df["curr_parent_id"].startswith("t1_"), 
                concat_ws(" |->| ", context_df["context"], col(f"c{i}.c_value"))
            ).otherwise(context_df["context"]).alias("context"),
            # Update curr_parent_id based on the join result
            when(context_df["curr_parent_id"].startswith("t1_"), col(f"c{i}.c_parent_id")).otherwise(context_df["curr_parent_id"]).alias("curr_parent_id"),
            # Update reached_top flag
            when(context_df["curr_parent_id"].startswith("t3_"), lit(True)).otherwise(context_df["reached_top"]).alias("reached_top")
        )

        # Check if all rows have reached the top; if so, break the loop
        if context_df.filter(col("reached_top") == False).count() == 0:
            break

        i += 1

    # Final join with original comments DataFrame to include additional details
    final_df = comments.join(
        context_df,
        comments["comment_id"] == context_df["c_key"]
    ).select(
        comments["datetime_utc"], comments["comment_id"], comments["submission_id"], 
        comments["parent_id"], comments["comment_score"], comments["comment_body"], 
        context_df["curr_parent_id"],
        context_df["context"].alias("comment_context"), context_df["reached_top"]
    )

    return final_df

# Assuming wsb_comments and wsb_submissions DataFrames are already defined
wsb_comments_with_context = build_context_chain(wsb_comments, wsb_submissions, 5)
# Show the result
wsb_comments_with_context.show(truncate=False)

[Stage 42510:>                                                      (0 + 1) / 1]]

+-------------------+----------+-------------+----------+-------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [11]:
print((wsb_comments_with_context.count(), len(wsb_comments_with_context.columns)))



(77587930, 7)


                                                                                

In [27]:
wsb_comments_with_context = wsb_comments_with_context.orderBy("datetime_utc")

In [28]:
wsb_comments_with_context.show()



+-------------------+----------+-------------+----------+-------------+--------------------+--------------+--------------------+-----------+
|       datetime_utc|comment_id|submission_id| parent_id|comment_score|        comment_body|curr_parent_id|     comment_context|reached_top|
+-------------------+----------+-------------+----------+-------------+--------------------+--------------+--------------------+-----------+
|2012-04-11 09:46:43|   c4b0pvu|     t3_s4jw1|  t3_s4jw1|            2|This is a fantast...|      t3_s4jw1|This is a fantast...|       true|
|2012-04-11 10:12:16|   c4b127p|     t3_s4jw1|  t3_s4jw1|            1|           [deleted]|      t3_s4jw1|[deleted] |->| Ea...|       true|
|2012-04-11 10:39:08|   c4b1fpf|     t3_s4jw1|  t3_s4jw1|            2|     INTC is on 4/17|      t3_s4jw1|INTC is on 4/17 |...|       true|
|2012-04-11 11:02:31|   c4b1rmm|     t3_s4jw1|  t3_s4jw1|            1|straddle, call, s...|      t3_s4jw1|straddle, call, s...|       true|
|2012-04-11 1

                                                                                

In [30]:
wsb_comments_with_context.write.parquet("wsb_comments_with_context")

                                                                                