In [43]:
#!/usr/bin/env python3
import pandas as pd
from pathlib import Path
import praw
from tqdm.notebook import tqdm
import pandas as pd
import yfinance as yf
import re

# Options

In [44]:
# Notebook configuration: paths, time window, search terms and the Reddit API client.
import os

years = 10

root = Path.cwd()
compressed_data_root = root / 'posts'
# Output names follow the configured window so they stay accurate if `years` changes.
reddit_csv_path = root.parent / f'{years}y_reddit_data.csv'
stock_csv_path = root.parent / f'{years}y_stock_data.csv'

subreddits = ["stocks", "StockMarket", "investing", "wallstreetbets", "options", "trading"]

# Maps each tracked stock to the search terms (ticker and company names) used to find posts about it.
stock_dict = {
    'nvidia': ['nvda', 'nvidia'],
    'tesla': ['tsla', 'tesla'],
    'apple': ['aapl', 'apple'],
    'amazon': ['amzn', 'amazon'],
    'microsoft': ['msft', 'microsoft'],
    'google': ['googl', 'google', 'alphabet']
}

tickers = ['nvda', 'tsla', 'aapl', 'amzn', 'msft', 'googl']

# Credentials are read from the environment so they never end up in the notebook
# or in version control. A missing variable raises KeyError naming the variable.
# NOTE(review): the client secret that used to be hard-coded here should be
# treated as leaked and regenerated in the Reddit app settings.
praw_api = praw.Reddit(
    client_id=os.environ['REDDIT_CLIENT_ID'],
    client_secret=os.environ['REDDIT_CLIENT_SECRET'],
    user_agent=os.environ.get(
        'REDDIT_USER_AGENT',
        "TeslaScraper:v1.0 (by u/RecognitionSame5433)",
    ),
)

# Reddit data fields
- id
- created_utc
- timestamp
- author
- title
- selftext
- score
- num_comments
- query (only for api data)
- stock
- subreddit
- source (api or archive)

# Stock data fields per ticker
- Open
- High
- Low
- Close
- Volume
- Open_pct
- High_pct
- Low_pct
- Close_pct
- Revenue
- Earnings
- Revenue_pct
- Earnings_pct

In [None]:
class RedditData:
    """Collects Reddit submissions that mention the tracked stocks.

    Posts come from two sources: the live Reddit API (``load_api_data``) and
    zstd-compressed NDJSON archive dumps (``load_compressed_ndjson``). Both end
    up in ``self.df`` with one row per (post id, stock) pair and a ``source``
    column recording where the row came from.
    """

    # Columns kept from the raw archive records.
    _ARCHIVE_COLUMNS = ['id', 'created_utc', 'author', 'title', 'selftext', 'score', 'num_comments']
    # Columns produced for API hits (archive columns plus the parsed timestamp).
    _API_COLUMNS = ['id', 'created_utc', 'timestamp', 'author', 'title', 'selftext', 'score', 'num_comments']
    # Placeholder texts that mark a post as removed or deleted.
    _INVALID_MARKERS = ['[removed]', '[deleted]']

    def __init__(self, years, subreddits, stock_dict, praw_api):
        """Store the configuration and start with an empty result frame.

        Args:
            years: Size of the look-back window, in years (365 days each).
            subreddits: Subreddit names to search through the API.
            stock_dict: Mapping of stock name -> list of search terms.
            praw_api: Reddit client exposing ``subreddit(name).search(...)``.
        """
        self.subreddits = subreddits
        self.stock_dict = stock_dict
        self.api = praw_api
        self.timedelta = pd.Timedelta(days=365 * years)
        self.df = pd.DataFrame()

    def search_subreddit(self, subreddit_name:str, query:str):
        """Search one subreddit for ``query`` and return the hits as a DataFrame.

        The search is sorted by newest, restricted to the last year and capped
        at 1000 results. The returned frame always has the API columns, even
        when there are no hits.
        """
        subreddit = self.api.subreddit(subreddit_name)
        submissions = subreddit.search(query, sort='new', time_filter='year', limit=1000)
        rows = [
            {
                "id": s.id,
                "created_utc": s.created_utc,
                "timestamp": pd.to_datetime(s.created_utc, unit='s', utc=True),
                "author": str(s.author) if s.author else None,
                "title": s.title,
                "selftext": s.selftext,
                "score": s.score,
                "num_comments": s.num_comments,
            }
            for s in submissions
        ]
        # Explicit columns keep the schema stable for empty result sets.
        return pd.DataFrame(rows, columns=self._API_COLUMNS)

    def _query_pairs(self):
        """Return every (stock, query) combination defined in ``stock_dict``."""
        # .items() is required here: iterating .values() would unpack the alias
        # list itself into (stock, query_list) and fail or search for garbage.
        return [
            (stock, query)
            for stock, query_list in self.stock_dict.items()
            for query in query_list
        ]

    def load_api_data(self):
        """Fetch API hits for all subreddits and queries and merge them into ``self.df``.

        Raises:
            RuntimeError: If no query returned any post.
            ValueError: Propagated from ``add_data`` when the new rows do not
                overlap with data that was loaded earlier.
        """
        df_list = []
        for stock, query in tqdm(self._query_pairs(), desc='Queries', position=0):
            for sub in tqdm(self.subreddits,
                            desc='Subreddits',
                            leave=False,
                            position=1):
                df = self.search_subreddit(sub, query)
                df['query'] = query
                df['stock'] = stock
                df['subreddit'] = sub
                df_list.append(df)
        nonempty_dfs = [df for df in df_list if not df.empty]
        print(f'{len(nonempty_dfs)} out of {len(df_list)} queries had hits')
        if not nonempty_dfs:
            raise RuntimeError('API search returned no posts')
        # Only concatenate frames that actually contain rows.
        new_df = pd.concat(nonempty_dfs, ignore_index=True)
        self.add_data(new_df, 'api')
        print('API data loaded')

    def load_compressed_ndjson(self, root: Path):
        """Load every ``*.zst`` NDJSON dump found under ``root`` into ``self.df``.

        Each file is read in chunks; every chunk is reduced to the known
        columns, limited to the look-back window and filtered by keyword
        before it is kept, so only matching rows are held in memory.

        Raises:
            RuntimeError: If a file yields no chunks or no file is found.
            ValueError: Propagated from ``add_data`` (see there).
        """
        df_list = []
        for file_path in root.rglob('*.zst'):
            print(f'Extracting {file_path.name}')
            # File names look like "<subreddit>_submissions.zst"; some dumps carry
            # an extra underscore ("wallstreetbets__submissions"), so strip it to
            # avoid a bogus "wallstreetbets_" subreddit label.
            subreddit = file_path.stem.removesuffix('_submissions').rstrip('_')
            df_chunks = []
            # The reader holds an open file handle; the context manager closes it.
            with pd.read_json(
                file_path,
                lines=True,
                compression='zstd',
                chunksize=2**16
            ) as chunk_iter:
                for df_chunk in tqdm(chunk_iter,
                                     desc='File chunks processed',
                                     leave=False):
                    df_chunk = self.restrict_columns_for(df_chunk)
                    df_chunk['timestamp'] = pd.to_datetime(df_chunk['created_utc'], unit='s', utc=True)
                    df_chunk = self.restrict_time_for(df_chunk)
                    df_chunk = self.keyword_filter_for(df_chunk)
                    df_chunks.append(df_chunk)
            if not df_chunks:
                raise RuntimeError('File yielded no rows')
            df = pd.concat(df_chunks, ignore_index=True)
            df = self.drop_duplicates_for(df)
            df['subreddit'] = subreddit
            df_list.append(df)
        if not df_list:
            raise RuntimeError('Failed to load any data from files')
        new_df = pd.concat(df_list, ignore_index=True)
        self.add_data(new_df, 'archive')
        print('Compressed data loaded')

    def add_data(self, df, source):
        """Clean ``df``, tag it with ``source`` and merge it into ``self.df``.

        The first batch is stored as is. Later batches must share at least one
        (id, stock) pair with the stored data; this is used as a check that
        the two sources cover adjoining time ranges.

        Raises:
            ValueError: If a later batch has no overlap with the stored data.
        """
        df = self.drop_duplicates_for(df)
        df = self.drop_invalid_posts_for(df)
        # assign() builds a new frame, so the caller's frame is never modified
        # and no SettingWithCopyWarning is triggered.
        df = df.assign(source=source)
        if self.df.empty:
            self.df = df
            return
        combined = pd.concat([self.df, df], ignore_index=True)
        overlap = combined.duplicated(subset=['id', 'stock']).sum()
        if overlap == 0:
            raise ValueError('Time ranges do not overlap')
        print(f'Dataframe overlap of {overlap} rows')
        self.df = self.drop_duplicates_for(combined)

    def keyword_filter_for(self, df):
        """Return the rows of ``df`` that mention a tracked stock.

        A row matches a stock when any of the stock's search terms occurs
        (case-insensitively, as a substring) in the title or the selftext.
        A row that mentions several stocks appears once per stock; the match
        is recorded in the new ``stock`` column.
        """
        keyword_cols = ['title', 'selftext']
        # Convert once, outside the loop. Missing text becomes '' so that the
        # string form of a missing value can never match a search term.
        text_df = df[keyword_cols].fillna('').astype(str)
        df_list = []
        for stock, query_list in self.stock_dict.items():
            if not query_list:
                # An empty pattern would match every row.
                continue
            pattern = '|'.join(re.escape(q) for q in query_list)
            mask = (
                text_df
                .apply(lambda col: col.str.contains(pattern, case=False, na=False))
                .any(axis=1)
            )
            df_list.append(df[mask].assign(stock=stock))
        if not df_list:
            # Keep the schema consistent with the non-empty case.
            return df.iloc[0:0].assign(stock=None)
        new_df = pd.concat(df_list, ignore_index=True)
        return self.drop_duplicates_for(new_df)

    def drop_invalid_posts_for(self, df):
        """Return a copy of ``df`` without removed or deleted posts.

        A post is dropped when its title or selftext contains one of the
        placeholder markers. Missing text counts as "no marker".
        """
        pattern = '|'.join(re.escape(s) for s in self._INVALID_MARKERS)
        mask = (
            df['selftext'].fillna('').astype(str).str.contains(pattern, na=False) |
            df['title'].fillna('').astype(str).str.contains(pattern, na=False)
        )
        # copy() so callers can add columns to the result safely.
        return df[~mask].copy()

    def restrict_columns_for(self, df):
        """Return a copy of ``df`` limited to the archive columns."""
        return df[self._ARCHIVE_COLUMNS].copy()

    def restrict_time_for(self, df):
        """Return a copy of the rows whose ``timestamp`` lies in the look-back window."""
        # Timestamp.utcnow() is deprecated; now(tz='UTC') is the equivalent call.
        start_timestamp = pd.Timestamp.now(tz='UTC') - self.timedelta
        return df[df.timestamp >= start_timestamp].copy()

    def drop_duplicates_for(self, df):
        """Return ``df`` with one row per (id, stock) pair, keeping the first."""
        return df.drop_duplicates(subset=['id', 'stock'])


class StockData:
    """Downloads price history for a set of tickers and derives daily features.

    ``self.df`` holds the downloaded frame. The columns carry the price field
    (Open, High, Low, Close, Volume) on the first level.
    """

    def __init__(self, tickers, years, interval='1d'):
        """Download ``years`` of history for ``tickers``.

        Raises:
            ValueError: If the download returns nothing or an empty frame.
        """
        self.tickers = tickers
        df = yf.download(tickers, period=f'{years}y', interval=interval, auto_adjust=True)
        # A failed download may come back as an empty frame rather than None.
        if df is None or df.empty:
            raise ValueError('Data failed to download')
        self.df = df

    def impute_off_days(self):
        """Insert a row for every calendar day between the first and last date.

        On the inserted days Volume is set to 0 and the remaining columns are
        forward-filled from the last available day.
        """
        full_idx = pd.date_range(
            self.df.index.min(),
            self.df.index.max(),
            freq='D',
            name=self.df.index.name,  # keep the index label through the reindex
        )
        self.df = self.df.reindex(full_idx)
        # Volume must be zeroed before the forward fill, otherwise the fill
        # would copy the previous day's volume into the inserted days.
        self.df['Volume'] = self.df['Volume'].fillna(0)
        self.df = self.df.ffill()

    def create_pct_columns(self):
        """Add ``<field>_pct`` columns with the row-to-row change of each price field.

        Calling the method again replaces the existing percentage columns
        instead of appending duplicates.
        """
        price_fields = ['Open', 'High', 'Low', 'Close']
        pct_fields = [f'{field}_pct' for field in price_fields]

        # Drop percentage columns from an earlier call so re-running is safe.
        first_level = self.df.columns.get_level_values(0)
        base_df = self.df.loc[:, ~first_level.isin(pct_fields)]

        price_df = base_df[price_fields]
        # Explicit ffill + fill_method=None gives the same result as the old
        # implicit padding of pct_change() without relying on the deprecated default.
        pct_df = price_df.ffill().pct_change(fill_method=None)

        if isinstance(pct_df.columns, pd.MultiIndex):
            level_values = [
                pct_df.columns.get_level_values(i)
                for i in range(pct_df.columns.nlevels)
            ]
            level_values[0] = [f'{field}_pct' for field in level_values[0]]
            pct_df.columns = pd.MultiIndex.from_arrays(
                level_values, names=pct_df.columns.names)
        else:
            pct_df.columns = pd.Index(
                [f'{field}_pct' for field in pct_df.columns],
                name=pct_df.columns.name)

        self.df = pd.concat([base_df, pct_df], axis=1).sort_index(axis=1)


# Create Reddit data

In [46]:
# Build the Reddit collector from the configuration defined above
# (look-back window, subreddit list, stock search terms, API client).
reddit_data = RedditData(years, subreddits, stock_dict, praw_api)

In [47]:
# Load the archived *.zst submission dumps found under compressed_data_root.
# This runs before the API load: add_data() stores the first batch as is and
# requires every later batch to overlap with it.
reddit_data.load_compressed_ndjson(compressed_data_root)

Extracting options_submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting wallstreetbets__submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting stocks_submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting Trading_submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting investing_submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting StockMarket_submissions.zst


Processing file: 0it [00:00, ?it/s]

Extracting wallstreetbets_submissions.zst


Processing file: 0it [00:00, ?it/s]

Compressed data loaded


In [48]:
# Search every subreddit for every stock search term through the Reddit API
# and merge the hits into reddit_data.df.
reddit_data.load_api_data()

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

78 out of 78 queries had hits
Dataframe overlap of 400 rows
API data loaded


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['source'] = source


In [52]:
# Total number of (post, stock) rows collected from both sources.
len(reddit_data.df)

233738

In [60]:
# Count the rows contributed by each source, then keep one sub-frame per
# source (api_df, arc_df) for the inspection cells below.
print((reddit_data.df['source'] == 'api').sum())
print((reddit_data.df['source'] == 'archive').sum())
api_df = reddit_data.df[reddit_data.df['source'] == 'api']
arc_df = reddit_data.df[reddit_data.df['source'] == 'archive']

8993
224745


In [61]:
# Number of distinct post ids overall and per source. These are lower than
# the row counts because a post is stored once per stock it mentions.
print(reddit_data.df['id'].nunique())
print(api_df['id'].nunique())
print(arc_df['id'].nunique())

191177
6219
184972


In [None]:
# Spot-check every 1,000th archive row (up to the first 100,000) and print
# selftext bodies shorter than 30 characters, to see what short posts look like.
sample_limit = min(100_000, len(arc_df))  # never index past the end of the frame
for i in range(0, sample_limit, 1_000):
    # str() guards against missing selftext values, which have no .lower().
    selftext = str(arc_df.iloc[i].selftext).lower()
    if len(selftext) < 30:
        print(selftext)

In [91]:
# Count the archive rows whose selftext or title contains a
# "[removed]" / "[deleted]" placeholder. The markers are regex-escaped
# because the square brackets would otherwise be read as a character class.
keywords = ['[removed]', '[deleted]']
pattern = '|'.join([re.escape(s) for s in keywords])
print(pattern)
mask = (
    arc_df.selftext.str.contains(pattern) |
    arc_df.title.str.contains(pattern)
)
mask.sum()

\[removed\]|\[deleted\]


np.int64(64946)

In [None]:
# Write the combined Reddit data to reddit_csv_path as comma-separated UTF-8,
# including the header row and the frame's index.
reddit_data.df.to_csv(reddit_csv_path, encoding='utf-8', index=True, header=True, sep=',')

# Create stock data

In [None]:
# Download the price history for all tickers over the configured number of years.
stock_data = StockData(tickers, years)

In [None]:
# Fill the calendar days missing from the download: Volume becomes 0 and the
# other columns are forward-filled.
stock_data.impute_off_days()

In [None]:
# Add the Open/High/Low/Close percentage-change columns.
stock_data.create_pct_columns()

In [None]:
# Write the stock data to stock_csv_path as comma-separated UTF-8,
# including the header and the date index.
stock_data.df.to_csv(stock_csv_path, encoding='utf-8', index=True, header=True, sep=',')