In [1]:
import logging
import time

import findspark

# findspark.init() locates the Spark installation and puts its pyspark on
# sys.path, so it must run *before* the first `import pyspark...`; calling it
# afterwards cannot influence an import that has already happened.
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr, concat

# Setup basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def log_time_taken(start, operation):
    """Log how long ``operation`` took, measured from ``start`` until now.

    Args:
        start: Start timestamp obtained from ``time.time()``.
        operation: Human-readable label used in the log message.
    """
    elapsed = time.time() - start
    logger.info(f"{operation} completed in {elapsed:.2f} seconds")

import os

# Machine-specific location of the PostgreSQL JDBC driver. Set the
# POSTGRES_JDBC_JAR environment variable to run this notebook elsewhere;
# the default is the original hard-coded path.
POSTGRES_JDBC_JAR = os.environ.get(
    "POSTGRES_JDBC_JAR", "/Volumes/LaCie/wsb_archive/postgresql-42.7.3.jar"
)

# Settings applied to the session builder below.
# NOTE: with master("local[*]") the driver and executors share a single JVM, so
# the spark.executor.* settings have no effect in this mode, and (per the Spark
# configuration docs) the memoryOverhead settings are only honoured by cluster
# managers such as YARN/Kubernetes. The heap size comes from spark.driver.memory.
# They are kept so the same configuration works against a real cluster.
# NOTE: JVM-level settings (memory, class path, Java options) only take effect
# when the session is first created; re-running this cell while a session is
# alive returns it via getOrCreate() without changing them.
SPARK_CONF = {
    "spark.executor.memory": "64g",
    "spark.driver.memory": "32g",
    "spark.executor.memoryOverhead": "4096",
    "spark.driver.memoryOverhead": "2048",
    "spark.driver.maxResultSize": "8g",
    "spark.driver.extraClassPath": POSTGRES_JDBC_JAR,
    "spark.driver.extraJavaOptions": "-XX:+UseG1GC",
    "spark.executor.extraJavaOptions": "-XX:+UseG1GC",
}

# Start timing and log the initialization of the Spark session
logger.info("Initializing Spark session with optimized memory settings")
start_time = time.time()
builder = SparkSession.builder.appName("Reddit Comment Context Builder").master("local[*]")
for conf_key, conf_value in SPARK_CONF.items():
    builder = builder.config(conf_key, conf_value)
spark = builder.getOrCreate()
log_time_taken(start_time, "SparkSession initialization")

2024-04-03 18:26:55,167 - INFO - Initializing Spark session with optimized memory settings
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/04/03 18:26:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/03 18:26:57 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


2024-04-03 18:26:57,576 - INFO - SparkSession initialization completed in 2.41 seconds


### Load the WSB comments dataset (with comment context)

In [2]:
# Load the comments (including their comment_context column) from the local Parquet dataset
wsb_comments_with_context = spark.read.format("parquet").load("./wsb_comments_with_context")

                                                                                

In [3]:
import yfinance as yf
import re

class CompanyNameSimplifier:
    """Turn a company's legal name into a short name suitable for text search.

    Examples: "Apple Inc." -> "Apple", "NVIDIA Corporation" -> "NVIDIA",
    "Amazon.com, Inc." -> "Amazon".
    """

    def __init__(self):
        # Legal-form suffixes stripped from the end of a name. At most one is
        # removed, and only when it is a separate trailing word. Longer variants
        # come before their prefixes (e.g. 'Inc.' before 'Inc') so the full
        # suffix is tried first.
        self.suffixes = [
            'Inc.', 'Inc', 'Corporation', 'Corp.', 'Corp', 'Company', 'Co.', 'Co',
            'Limited', 'Ltd.', 'Ltd', ' PLC', ' NV', ' SA', ' AG', ' LLC', ' L.P.', ' LP'
        ]
        # Web-domain endings (e.g. the ".com" in "Amazon.com, Inc."), matched
        # only when they end at a word boundary, so ".Corp" or ".cool" inside a
        # name are left untouched.
        self.web_domains_regex = r'\.(?:com|org|net|io|co|ai)\b'

    def simplify_company_name(self, name):
        """Return a simplified version of the company name ``name``.

        Steps: drop web-domain endings, drop one trailing legal-form suffix,
        cut everything from the first comma or " -", then trim whitespace.

        Args:
            name: Full company name (e.g. yfinance's ``longName``). ``None`` or
                an empty string yields ``''``.

        Returns:
            The simplified name (may be ``''``).
        """
        if not name:
            return ''

        # Remove web domain suffixes first, so "Amazon.com, Inc." -> "Amazon, Inc."
        name = re.sub(self.web_domains_regex, '', name, flags=re.IGNORECASE).strip()

        # Remove at most one trailing suffix. Slice it off the end rather than
        # using str.replace(), which would also delete every earlier occurrence
        # ("Coca-Cola Co" -> "ca-la"). Require the suffix to be its own word so
        # "Corpus Corp" keeps "Corpus" intact.
        for suffix in self.suffixes:
            token = suffix.strip()
            if not name.endswith(token):
                continue
            head = name[:-len(token)]
            if head and (head[-1].isspace() or head[-1] == ','):
                name = head
                break

        # Additional cleanup: remove anything after a comma or " -"
        name = re.split(r',| -', name)[0]

        # Strip leading and trailing whitespace
        return name.strip()

    def get_simplified_company_name(self, ticker):
        """Look up ``ticker`` with yfinance and return its simplified long name.

        Args:
            ticker: Ticker symbol, e.g. ``'AAPL'``.

        Returns:
            The simplified ``longName``, or ``''`` when yfinance reports no
            ``longName`` (key missing or ``None``).
        """
        # Fetch the company info using yfinance (presumably a network lookup)
        company_info = yf.Ticker(ticker).info

        # `or ''` also covers a longName that is present but None
        full_name = company_info.get('longName') or ''

        return self.simplify_company_name(full_name)

2024-04-03 18:27:00,491 - INFO - Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-04-03 18:27:00,492 - INFO - NumExpr defaulting to 8 threads.


In [4]:
def spark_df_shape(df, temp_path=None):
    """
    Return the shape ``(num_rows, num_columns)`` of a Spark DataFrame.

    Rows are counted by writing ``df`` to a scratch Parquet directory and
    counting the rows read back. The scratch directory is deleted afterwards,
    even when the write or the count fails.

    NOTE(review): for a DataFrame that is already persisted, a plain
    ``df.count()`` is likely cheaper than this round trip; confirm before
    relying on it for speed.

    Args:
    - df: The Spark DataFrame whose shape is to be calculated.
    - temp_path: Scratch directory to use. Defaults to a uniquely named
      directory in the current working directory, so an existing ``./temp``
      (or a concurrent call) is never overwritten or deleted. If given
      explicitly, that directory is overwritten and then removed.

    Returns:
    Tuple representing the shape (number of rows, number of columns) of the DataFrame.
    """
    import uuid

    if temp_path is None:
        temp_path = f"./temp_spark_df_shape_{uuid.uuid4().hex}"

    sc = spark.sparkContext
    try:
        # The write executes df's plan once; counting the re-read Parquet data
        # does not recompute it.
        df.write.mode("overwrite").parquet(temp_path)
        num_rows = spark.read.parquet(temp_path).count()
    finally:
        # Cleanup: delete the temporary directory and its contents
        try:
            path = sc._jvm.org.apache.hadoop.fs.Path(temp_path)
            # Use Spark's own Hadoop configuration so the delete targets the
            # same filesystem the write went to.
            fs = path.getFileSystem(sc._jsc.hadoopConfiguration())
            if fs.exists(path):
                fs.delete(path, True)  # True for recursive delete
        except Exception as e:
            # Best-effort cleanup: report, but never mask the real result/error.
            logger.warning(f"Failed to delete temporary files at {temp_path}: {e}")

    return (num_rows, len(df.columns))

In [5]:
# Example usage: print the simplified company name for a few sample tickers
simplifier = CompanyNameSimplifier()
ticker_symbols = ['AAPL', 'NVDA', 'TSLA']
for ticker in ticker_symbols:
    print(ticker, simplifier.get_simplified_company_name(ticker), sep=": ")

AAPL: Apple
NVDA: NVIDIA
TSLA: Tesla


In [6]:
from pyspark.sql.functions import col, lower

def filter_comments_by_ticker(df, ticker, whole_word=True):
    """Keep rows whose ``comment_context`` mentions ``ticker`` or its company name.

    The company name is obtained through ``CompanyNameSimplifier`` (a yfinance
    lookup). Matching is case-insensitive.

    Args:
        df: Spark DataFrame with at least the columns ``comment_context``,
            ``datetime_utc``, ``comment_score`` and ``comment_body``.
        ticker: Ticker symbol, e.g. ``'AAPL'``.
        whole_word: If True (default), a term only matches as a standalone
            word, so short tickers (``'F'``, ``'ON'``) or names (``'Apple'``)
            do not match inside unrelated words such as ``'pineapple'``.
            Pass False for plain substring matching.

    Returns:
        DataFrame with columns ``datetime_utc``, ``comment_score`` and
        ``comment_body`` for the matching rows.
    """
    simplifier = CompanyNameSimplifier()
    # Obtain the simplified company name for the given ticker ('' if unknown)
    company_name = (simplifier.get_simplified_company_name(ticker) or '').strip()

    # Lower-case the context once; every search term is lower-cased as well.
    context = lower(col("comment_context"))

    def _mentions(term):
        # Case-insensitive "comment_context mentions term" predicate.
        term = term.lower()
        if whole_word:
            # Lookarounds instead of \b so terms starting/ending with
            # punctuation still work; re.escape keeps e.g. '.' literal.
            pattern = r"(?<![a-z0-9_])" + re.escape(term) + r"(?![a-z0-9_])"
            return context.rlike(pattern)
        return context.contains(term)

    condition = _mentions(ticker)
    # An empty company name would become contains("") and match every row,
    # so only add the company-name condition when there is a name.
    if company_name:
        condition = condition | _mentions(company_name)

    return df.filter(condition).select("datetime_utc", "comment_score", "comment_body")

In [7]:
# AAPL-related comments, persisted so later cells reuse the filtered result
# instead of re-scanning the whole comment dataset.
AAPL_comments = (
    filter_comments_by_ticker(df=wsb_comments_with_context, ticker='AAPL')
    .persist()
)

In [8]:
# Preview the first rows (long comment bodies are truncated in the display)
AAPL_comments.show(n=20, truncate=True)

[Stage 1:>                                                          (0 + 1) / 1]

+-------------------+-------------+--------------------+
|       datetime_utc|comment_score|        comment_body|
+-------------------+-------------+--------------------+
|2012-04-11 09:46:43|            2|This is a fantast...|
|2012-04-11 10:12:16|            1|           [deleted]|
|2012-04-11 10:39:08|            2|     INTC is on 4/17|
|2012-04-11 11:02:31|            1|straddle, call, s...|
|2012-04-11 11:47:11|            6|GMCR falls, GOOG ...|
|2012-04-11 12:44:33|            1|CROX 4/26\n\nBZH ...|
|2012-04-11 13:02:56|            1|Shorting GOOG and...|
|2012-04-11 13:16:44|            2|I'm looking at CJ...|
|2012-04-11 13:48:27|            2|GRPN earnings rep...|
|2012-04-11 13:54:48|            1|           [deleted]|
|2012-04-11 13:59:34|            2|I'm long CJES. Se...|
|2012-04-11 14:01:27|            1|Shorting FB for t...|
|2012-04-11 14:02:43|            1|BAC 4/19 betting ...|
|2012-04-11 14:19:43|            1|Shorting AAPL ahe...|
|2012-04-11 14:55:43|          

                                                                                

In [9]:
# spark_df_shape(AAPL_comments)

In [10]:
# AAPL_comments.write.parquet('./stock_comments/AAPL_comments')

In [11]:
# Stop the SparkSession and its underlying SparkContext; DataFrames created
# above (including persisted ones) cannot be used after this.
spark.stop()