In [3]:
import os
from dotenv import load_dotenv
import praw
import pandas as pd
import time
import re
from datetime import datetime
import logging
from typing import List, Dict, Set
from dataclasses import dataclass
from ratelimit import limits, sleep_and_retry

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class StockMention:
    date: datetime.date
    time: datetime.time
    subreddit: str
    title: str
    text: str
    ticker: str
    company_name: str

class RedditStockScraper:
    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """Initialize the Reddit scraper with API credentials."""
        self.reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent
        )
        
        self.stocks = {
            'NVDA': 'NVIDIA',
            'TSLA': 'Tesla',
            'AAPL': 'Apple',
            'MC.PA': 'LVMH'
        }
        
        self.subreddits = [
            "WallStreetBets",
            "StockMarket",
            "investing",
            "dividends",
            "cryptocurrency",
            "Investing_Discussion"
        ]

    def _validate_dates(self, start_date: str, end_date: str) -> tuple:
        """Validate and convert date strings to timestamps."""
        try:
            start_timestamp = int(datetime.strptime(start_date, '%Y-%m-%d').timestamp())
            end_timestamp = int(datetime.strptime(end_date, '%Y-%m-%d').timestamp())
            if start_timestamp > end_timestamp:
                raise ValueError("Start date must be before end date")
            return start_timestamp, end_timestamp
        except ValueError as e:
            logger.error(f"Date validation error: {e}")
            raise

    def _extract_stock_mentions(self, text: str) -> Set[str]:
        """Extract stock mentions from text using multiple patterns."""
        if not text:
            return set()

        mentions = set()
        
        # Find $TICKER mentions
        dollar_mentions = set(re.findall(r'\$([A-Z]{1,5})', text))
        
        # Find company name mentions
        company_mentions = {ticker for ticker, company in self.stocks.items() 
                          if company.lower() in text.lower()}
        
        # Find raw ticker mentions
        ticker_mentions = {ticker for ticker in self.stocks.keys() 
                         if ticker in text.upper()}
        
        mentions.update(dollar_mentions, company_mentions, ticker_mentions)
        return {m for m in mentions if m in self.stocks}

    @sleep_and_retry
    @limits(calls=60, period=60)  # Reddit API rate limit
    def _fetch_subreddit_posts(self, subreddit_name: str, start_timestamp: int, 
                             end_timestamp: int) -> List[StockMention]:
        """Fetch posts from a subreddit within the specified timeframe."""
        mentions = []
        subreddit = self.reddit.subreddit(subreddit_name)
        
        search_query = ' OR '.join(f'({ticker} OR "{company}")' 
                                 for ticker, company in self.stocks.items())
        
        try:
            for submission in subreddit.search(search_query,
                                             syntax='lucene',
                                             time_filter='all',
                                             sort='new',
                                             limit=None):
                
                if submission.created_utc < start_timestamp:
                    break
                    
                if submission.created_utc <= end_timestamp:
                    text = f"{submission.title} {submission.selftext}"
                    stock_mentions = self._extract_stock_mentions(text)
                    
                    timestamp = datetime.fromtimestamp(submission.created_utc)
                    
                    for ticker in stock_mentions:
                        mentions.append(StockMention(
                            date=timestamp.date(),
                            time=timestamp.time(),
                            subreddit=subreddit_name,
                            title=submission.title,
                            text=submission.selftext,
                            ticker=ticker,
                            company_name=self.stocks[ticker]
                        ))
                        
        except Exception as e:
            logger.error(f"Error scraping r/{subreddit_name}: {e}")
            raise
            
        return mentions

    def get_posts_by_timeframe(self, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Get all stock mentions within the specified timeframe.
        
        Args:
            start_date: Start date in 'YYYY-MM-DD' format
            end_date: End date in 'YYYY-MM-DD' format
            
        Returns:
            DataFrame containing all stock mentions
        """
        start_timestamp, end_timestamp = self._validate_dates(start_date, end_date)
        all_mentions = []
        
        for subreddit_name in self.subreddits:
            try:
                logger.info(f"Scraping r/{subreddit_name}...")
                mentions = self._fetch_subreddit_posts(
                    subreddit_name, start_timestamp, end_timestamp
                )
                all_mentions.extend(mentions)
                logger.info(f"Found {len(mentions)} mentions in r/{subreddit_name}")
                
            except Exception as e:
                logger.error(f"Failed to scrape r/{subreddit_name}: {e}")
                continue
                
        return pd.DataFrame([vars(mention) for mention in all_mentions])


if __name__ == "__main__":
    load_dotenv()
    
    scraper = RedditStockScraper(
        client_id=os.getenv("REDDIT_CLIENT_ID"),
        client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
        user_agent=os.getenv("REDDIT_USER_AGENT")
    )
    
    df = scraper.get_posts_by_timeframe('2024-01-01', '2025-01-31')
    print(f"Total mentions found: {len(df)}")

INFO:__main__:Scraping r/WallStreetBets...
INFO:__main__:Found 214 mentions in r/WallStreetBets
INFO:__main__:Scraping r/StockMarket...
INFO:__main__:Found 215 mentions in r/StockMarket
INFO:__main__:Scraping r/investing...
INFO:__main__:Found 304 mentions in r/investing
INFO:__main__:Scraping r/dividends...
INFO:__main__:Found 174 mentions in r/dividends
INFO:__main__:Scraping r/cryptocurrency...
INFO:__main__:Found 154 mentions in r/cryptocurrency
INFO:__main__:Scraping r/Investing_Discussion...
INFO:__main__:Found 184 mentions in r/Investing_Discussion


Total mentions found: 1245


In [4]:
df

Unnamed: 0,date,time,subreddit,title,text,ticker,company_name
0,2025-01-27,21:53:38,WallStreetBets,Why is nvidia stock going down when deep seek ...,I understand why ai company stock going down b...,NVDA,NVIDIA
1,2025-01-27,21:52:57,WallStreetBets,Proof even sophisticated investors don't research,It hasn't even been a week and people are doom...,NVDA,NVIDIA
2,2025-01-27,21:34:57,WallStreetBets,I sold my index funds to sell Nvidia puts last...,F,NVDA,NVIDIA
3,2025-01-27,21:29:35,WallStreetBets,"These are the 3 largest drops in NVDA history,...",https://preview.redd.it/toblad0cjlfe1.png?widt...,NVDA,NVIDIA
4,2025-01-27,20:51:22,WallStreetBets,Believe it or not calls it is,NVDA just went on a firesale today and dragged...,NVDA,NVIDIA
...,...,...,...,...,...,...,...
1240,2024-01-08,19:02:23,Investing_Discussion,Tesla earnings,What is everyone’s thoughts on TSLA earnings c...,TSLA,Tesla
1241,2024-01-07,15:20:56,Investing_Discussion,Improve my portfolio strategy,started with stocks then got into ETFs.\nI do ...,AAPL,Apple
1242,2024-01-04,07:06:05,Investing_Discussion,Should I invest in Individual companies if I'm...,Should I invest in individual companies like t...,TSLA,Tesla
1243,2024-01-04,07:06:05,Investing_Discussion,Should I invest in Individual companies if I'm...,Should I invest in individual companies like t...,AAPL,Apple


In [5]:
df.to_csv('../bertopic_project/data/stock_mentions.csv', index=False)

In [21]:
import yfinance as yf
import pandas as pd
from datetime import datetime
import logging
from typing import Dict
from pathlib import Path

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class StockDataFetcher:
    def __init__(self):
        self.stocks = {
            'NVDA': 'NVIDIA',
            'TSLA': 'Tesla',
            'AAPL': 'Apple',
            'MC.PA': 'LVMH'
        }
    
    def fetch_stock_data(self, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Fetch stock data and return a clean DataFrame.
        """
        all_data = []
        
        for ticker, company_name in self.stocks.items():
            try:
                logger.info(f"Fetching data for {ticker}...")
                
                # Fetch data for single stock
                df = yf.download(
                    ticker,
                    start=start_date,
                    end=end_date,
                    progress=False
                )
                
                # Clean up the DataFrame
                df = df.reset_index()  # Make Date a column
                df = df.droplevel(1, axis=1)  # Remove the ticker level from columns
                
                # Add identifier columns
                df['Ticker'] = ticker
                df['Stock Name'] = company_name
                
                all_data.append(df)
                logger.info(f"Successfully fetched data for {ticker}")
                
            except Exception as e:
                logger.error(f"Error fetching data for {ticker}: {e}")
                continue
        
        if not all_data:
            raise ValueError("No data was fetched for any stock")
            
        # Combine all stock data
        combined_df = pd.concat(all_data, ignore_index=True)
        
        # Reorder columns to desired format
        column_order = ['Date', 'Ticker', 'Stock Name', 'Open', 'High', 'Low', 'Close', 'Volume']
        combined_df = combined_df[column_order]
        
        # Sort by Date and Ticker
        combined_df = combined_df.sort_values(['Date', 'Ticker'])
        
        return combined_df

def main():
    # Initialize fetcher
    fetcher = StockDataFetcher()
    
    # Set date range
    start_date = '2024-01-01'
    end_date = '2024-01-31'
    
    try:
        # Fetch data
        df = fetcher.fetch_stock_data(start_date, end_date)
        
        # Save to CSV
        df.to_csv('stock_data.csv', index=False)
        
        # Display first few rows
        print("\nFirst few rows of the data:")
        print(df.head())
        
        # Display basic information
        print("\nDataFrame Info:")
        print(df.info())
        
    except Exception as e:
        logger.error(f"An error occurred: {e}")

if __name__ == "__main__":
    main()

INFO:__main__:Fetching data for NVDA...
INFO:__main__:Successfully fetched data for NVDA
INFO:__main__:Fetching data for TSLA...
INFO:__main__:Successfully fetched data for TSLA
INFO:__main__:Fetching data for AAPL...
INFO:__main__:Successfully fetched data for AAPL
INFO:__main__:Fetching data for MC.PA...
INFO:__main__:Successfully fetched data for MC.PA



First few rows of the data:
Price       Date Ticker Stock Name        Open        High         Low  \
536   2024-01-02   AAPL      Apple  186.237618  187.521338  182.993517   
804   2024-01-02  MC.PA       LVMH  726.230050  726.720746  705.326413   
0     2024-01-02   NVDA     NVIDIA   49.230042   49.281026   47.581511   
268   2024-01-02   TSLA      Tesla  250.080002  251.250000  244.410004   
537   2024-01-03   AAPL      Apple  183.321908  184.973819  182.535751   

Price       Close     Volume  
536    184.734985   82488700  
804    709.546387     271775  
0       48.154346  411254000  
268    248.419998  104654200  
537    183.351761   58414500  

DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
Index: 1078 entries, 536 to 535
Data columns (total 8 columns):
 #   Column      Non-Null Count  Dtype         
---  ------      --------------  -----         
 0   Date        1078 non-null   datetime64[ns]
 1   Ticker      1078 non-null   object        
 2   Stock Name  1078 non-nul

In [22]:
df

Price,Close,High,Low,Open,Volume
Ticker,AAPL,AAPL,AAPL,AAPL,AAPL
Date,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
2024-01-02,184.73497,187.521323,182.993502,186.237603,82488700
2024-01-03,183.351746,184.973804,182.535736,183.321893,58414500
2024-01-04,181.023178,182.197418,179.998201,181.261998,71983600
2024-01-05,180.296707,181.869006,179.291637,181.102771,62303300
2024-01-08,184.655365,184.695178,180.615161,181.202281,59144500
2024-01-09,184.237411,184.247357,181.839157,183.023358,42841800
2024-01-10,185.282303,185.491271,183.023365,183.451277,46792900
2024-01-11,184.685226,186.138115,182.724829,185.630592,49128400
2024-01-12,185.013611,185.829621,184.287174,185.152928,40444700
2024-01-16,182.734772,183.36169,180.047923,181.271937,65603000
