In [None]:
#!usr/bin/env python3
import pandas as pd
from pathlib import Path
import praw
from tqdm import tqdm
import pandas as pd
import yfinance as yf

In [None]:
import os

# Paths: compressed Reddit dumps live next to the notebook; CSV outputs go one level up.
root = Path.cwd()
compressed_data_root = root / 'posts'
reddit_csv_path = root.parent / '10y_reddit_data.csv'
stock_csv_path = root.parent / '10y_stock_data.csv'

# Subreddits searched through the Reddit API.
subreddits = ["stocks", "StockMarket", "investing", "wallstreetbets", "options", "trading"]

# Stock name -> Reddit search queries for it. The first query is the ticker symbol.
stock_dict = {
    'nvidia': ['nvda', 'nvidia'],
    'tesla': ['tsla', 'tesla'],
    'apple': ['aapl', 'apple'],
    'amazon': ['amzn', 'amazon'],
    'microsoft': ['msft', 'microsoft'],
    'google': ['googl', 'google', 'alphabet']
}

# Derived from stock_dict so the ticker list cannot drift out of sync with the queries.
tickers = [queries[0] for queries in stock_dict.values()]

# Reddit API credentials must come from the environment, never from the notebook itself.
# The secret that used to be hard-coded here should be treated as leaked and rotated.
_missing_env = [name for name in ("REDDIT_CLIENT_ID", "REDDIT_CLIENT_SECRET") if not os.environ.get(name)]
if _missing_env:
    raise RuntimeError(f"Set {', '.join(_missing_env)} in the environment before running this notebook")

praw_api = praw.Reddit(
    client_id=os.environ["REDDIT_CLIENT_ID"],
    client_secret=os.environ["REDDIT_CLIENT_SECRET"],
    user_agent=os.environ.get("REDDIT_USER_AGENT", "TeslaScraper:v1.0 (by u/RecognitionSame5433)"),
)

In [None]:
class RedditData:
    """Collects Reddit submissions about the tracked stocks into one DataFrame.

    Rows can come from compressed NDJSON files on disk (``load_compressed_ndjson``)
    and from live PRAW searches (``load_api_data``). Every batch goes through
    ``add_data``, which works in three steps:

    1. De-duplicate the batch on ``(id, stock)``.
    2. Keep only submissions whose ``created_utc`` falls within the last ``years`` years.
    3. Merge the batch into ``self.df``.
    """

    # Columns returned by search_subreddit, in output order.
    API_COLUMNS = [
        "id", "created_utc", "ticker", "subreddit", "author",
        "title", "selftext", "score", "upvote_ratio", "num_comments",
    ]

    def __init__(self, years, subreddits=subreddits, stock_dict=stock_dict, praw_api=praw_api):
        """
        Args:
            years: Look-back window, in years, kept in ``self.df``; ``None`` keeps every row.
            subreddits: Subreddit names searched by ``load_api_data``.
            stock_dict: Maps a stock name to the list of search queries for it.
            praw_api: Reddit client used for searches.
        """
        self.subreddits = subreddits
        self.stock_dict = stock_dict
        self.api = praw_api
        self.years = years
        self.df = None  # populated by add_data

    def search_subreddit(self, subreddit_name: str, query: str, time_filter: str = 'year', limit: int = 1000):
        """Search one subreddit for ``query``, newest first.

        Args:
            subreddit_name: Subreddit to search.
            query: Search string; also stored in the ``ticker`` column.
            time_filter: PRAW search time window (defaults to the last year).
            limit: Maximum number of submissions requested.

        Returns:
            DataFrame with one row per submission and the ``API_COLUMNS`` columns.
            The columns are present even when the search finds nothing.
        """
        subreddit = self.api.subreddit(subreddit_name)
        gen = subreddit.search(query, sort='new', time_filter=time_filter, limit=limit)
        rows = []
        for s in tqdm(gen, desc="pulling submissions"):
            rows.append({
                "id": s.id,
                "created_utc": pd.to_datetime(s.created_utc, unit="s"),
                "ticker": query,
                "subreddit": str(s.subreddit),
                "author": str(s.author) if s.author else None,  # deleted accounts have no author
                "title": s.title,
                "selftext": s.selftext,
                "score": s.score,
                "upvote_ratio": s.upvote_ratio,
                "num_comments": s.num_comments,
            })
        return pd.DataFrame(rows, columns=self.API_COLUMNS)

    def load_api_data(self):
        """Search every (stock, query, subreddit) combination and merge the results via add_data.

        Raises:
            ValueError: if no stock queries or subreddits are configured.
        """
        df_list = []
        for stock, query_list in self.stock_dict.items():
            for query in query_list:
                for sub in self.subreddits:
                    df = self.search_subreddit(sub, query)
                    df['query'] = query
                    df['stock'] = stock
                    # Label rows with the configured subreddit name rather than the API-reported one.
                    df['subreddit'] = sub
                    df_list.append(df)
        if not df_list:
            raise ValueError('No subreddits or queries configured')
        self.add_data(pd.concat(df_list, ignore_index=True))

    def load_compressed_ndjson(self, root: Path):
        """Read every ``*.zst`` newline-delimited JSON file under ``root`` and merge it via add_data.

        Raises:
            FileNotFoundError: if ``root`` contains no ``*.zst`` files.
        """
        # NOTE(review): Reddit archive dumps are commonly compressed with a 2 GiB long-distance
        # window; raising the decoder's limit lets zstandard read them and is harmless otherwise.
        compression = {'method': 'zstd', 'max_window_size': 2**31}
        df_list = [
            pd.read_json(file_path, lines=True, compression=compression)
            for file_path in sorted(root.rglob('*.zst'))  # sorted for a deterministic row order
        ]
        if not df_list:
            raise FileNotFoundError(f'No .zst files found under {root}')
        self.add_data(pd.concat(df_list, ignore_index=True))

    def add_data(self, df):
        """De-duplicate ``df``, restrict it to ``self.years``, and merge it into ``self.df``.

        The first batch is stored as-is (after cleaning). Each later batch must share at
        least one submission ``id`` with the data already loaded. Without a shared id,
        the two sources are treated as non-overlapping time ranges.

        Raises:
            ValueError: if a later batch shares no submission id with ``self.df``.
        """
        incoming = RedditData.restrict_time(RedditData.drop_duplicates(df), self.years)
        if self.df is None:
            self.df = incoming
            return
        overlap = int(incoming['id'].isin(self.df['id']).sum())
        if overlap == 0:
            raise ValueError('Time ranges do not overlap')
        print(f'Dataframes overlap of {overlap} rows')
        # NOTE(review): rows without a stock label (NaN) are not collapsed with labelled API rows
        # for the same id, so a post may appear once per source.
        self.df = RedditData.drop_duplicates(pd.concat([self.df, incoming], ignore_index=True))

    @staticmethod
    def _created_as_datetime(created):
        """Return ``created`` as naive-UTC datetimes.

        Epoch seconds (numeric, or numeric strings) are converted. Values that cannot be
        parsed become NaT.
        """
        if isinstance(created.dtype, pd.DatetimeTZDtype):
            return created.dt.tz_convert('UTC').dt.tz_localize(None)
        if pd.api.types.is_datetime64_any_dtype(created):
            return created
        return pd.to_datetime(pd.to_numeric(created, errors='coerce'), unit='s')

    @staticmethod
    def restrict_time(df, years):
        """Return ``df`` with ``created_utc`` as datetimes, keeping only the last ``years`` years.

        Rows with an unparseable ``created_utc`` (NaT) are dropped. With ``years=None``,
        only the dtype normalisation is applied.

        Raises:
            ValueError: if ``df`` is None.
        """
        if df is None:
            raise ValueError('No data loaded')
        df = df.assign(created_utc=RedditData._created_as_datetime(df['created_utc']))
        if years is None:
            return df
        # Compare in naive UTC, matching the datetimes produced above.
        cutoff = pd.Timestamp.now(tz='UTC').tz_localize(None) - pd.DateOffset(years=years)
        return df.loc[df['created_utc'] >= cutoff]

    @staticmethod
    def drop_duplicates(df):
        """Return a copy of ``df`` without repeated ``(id, stock)`` pairs.

        Frames without a ``stock`` column are de-duplicated on ``id`` alone.

        Raises:
            ValueError: if ``df`` is None or has no ``id`` column.
        """
        if df is None:
            raise ValueError('Dataframe is None')
        if 'id' not in df.columns:
            raise ValueError("Dataframe has no 'id' column")
        keys = [col for col in ('id', 'stock') if col in df.columns]
        return df.drop_duplicates(subset=keys)


class StockData:
    """Daily OHLCV price history for ``tickers``, downloaded with yfinance."""

    # First-level column labels that hold prices (the remaining yfinance field used here is 'Volume').
    PRICE_COLS = ['Open', 'High', 'Low', 'Close']

    def __init__(self, tickers, period='10y', interval='1d'):
        """Download the price history.

        Args:
            tickers: Ticker symbols passed to ``yf.download``.
            period: yfinance look-back period string (e.g. ``'10y'``).
            interval: yfinance bar-size string (e.g. ``'1d'``).

        Raises:
            ValueError: if the download produced no data.
        """
        self.tickers = tickers
        df = yf.download(tickers, period=period, interval=interval, auto_adjust=True)
        # yf.download can hand back an empty frame when every ticker fails, not only None.
        if df is None or df.empty:
            raise ValueError('Data failed to download')
        self.df = df

    @staticmethod
    def _columns_with_label(df, labels):
        """Return the columns of ``df`` whose first-level label is in ``labels``.

        Works for flat columns and for a MultiIndex whose first level is the price
        field (as returned for several tickers).
        """
        return df.columns[df.columns.get_level_values(0).isin(labels)]

    def impute_off_days(self):
        """Extend the index to every calendar day between the first and last row.

        On the inserted days, prices are forward-filled and Volume is set to 0.
        """
        full_idx = pd.date_range(
            self.df.index.min(), self.df.index.max(), freq='D', name=self.df.index.name
        )
        df = self.df.reindex(full_idx)  # reindex returns a new frame; it has to be kept
        volume_cols = self._columns_with_label(df, ['Volume'])
        price_cols = self._columns_with_label(df, self.PRICE_COLS)
        # Assign via explicit column keys so this works with MultiIndex columns too.
        df[volume_cols] = df[volume_cols].fillna(0)
        df[price_cols] = df[price_cols].ffill()
        self.df = df

    def create_pct_columns(self):
        """Add ``<price>_pct`` columns with the row-over-row percent change of each price.

        Re-running replaces the existing ``_pct`` columns instead of duplicating them.
        """
        prices = self.df[self._columns_with_label(self.df, self.PRICE_COLS)]
        pct = prices.pct_change().rename(columns=lambda label: f'{label}_pct', level=0)
        self.df = pd.concat([self.df.drop(columns=pct.columns, errors='ignore'), pct], axis=1)

In [None]:
# Build the Reddit dataset: the compressed dump files first, then recent posts from the Reddit API.
reddit_data = RedditData(years=10)
reddit_data.load_compressed_ndjson(compressed_data_root)
reddit_data.load_api_data()
assert isinstance(reddit_data.df, pd.DataFrame)
assert not reddit_data.df.empty, 'no Reddit rows collected'
reddit_data.df.to_csv(reddit_csv_path, encoding='utf-8', index=True, header=True, sep=',')

# Build the daily stock dataset (calendar-day index, with percent-change columns).
stock_data = StockData(tickers, period='10y')
stock_data.impute_off_days()
stock_data.create_pct_columns()
stock_data.df.to_csv(stock_csv_path, encoding='utf-8', index=True, header=True, sep=',')