In [1]:
import praw
import pandas as pd
import datetime
import pytz
import time
import json
import os
from typing import List, Dict, Any
import logging
import random
import os
from dotenv import load_dotenv

load_dotenv(dotenv_path="../references/.env")

True

In [None]:
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class RedditMentalHealthScraper:
    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """
        Initialize Reddit scraper
        
        Args:
            client_id: Reddit API client ID
            client_secret: Reddit API client secret  
            user_agent: User agent string for API requests
        """
        # api request
        self.reddit = praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent
        )
        
        # Define rome timezone and period of analysis
        rome_tz = pytz.timezone('Europe/Rome')
        self.start_date = rome_tz.localize(datetime.datetime(2025, 1, 1, 0, 0, 0))
        now_rome = datetime.datetime.now(rome_tz)
        self.end_date = now_rome
        
        # Convert to UTC for Reddit API
        self.start_timestamp = self.start_date.astimezone(pytz.UTC).timestamp()
        self.end_timestamp = self.end_date.astimezone(pytz.UTC).timestamp()
        
        # Target subreddits and number of posts for each of them
        self.subreddits = ['depression', 'depression_help', 'mentalhealth', 'Anxiety', 'Anxietyhelp']
        self.posts_per_subreddit = 2000
        
    def get_post_data(self, post) -> Dict[str, Any]:
        """Extract only the essential data from a Reddit post"""
        return {
            'post_id': post.id,
            'title': post.title,
            'selftext': post.selftext,
            'author': str(post.author) if post.author else '[deleted]',
            'subreddit': str(post.subreddit),
            'created_utc': post.created_utc,
            'created_rome': datetime.datetime.fromtimestamp(
                post.created_utc, pytz.timezone('Europe/Rome')
            ).isoformat(),
            'num_comments': post.num_comments,
            'score': post.score,
            'flair': post.link_flair_text
        }
    
    def get_comment_data(self, comment, post_id: str) -> Dict[str, Any]:
        """Extract only the essential data from a Reddit comment"""
        return {
            'comment_id': comment.id,
            'post_id': post_id,
            'body': comment.body,
            'author': str(comment.author) if comment.author else '[deleted]',
            'created_utc': comment.created_utc,
            'created_rome': datetime.datetime.fromtimestamp(
                comment.created_utc, pytz.timezone('Europe/Rome')
            ).isoformat(),
            'score': comment.score,
            'is_submitter': comment.is_submitter,
            'depth': getattr(comment, 'depth', 0)
        }
    
    def scrape_subreddit(self, subreddit_name: str) -> tuple[list[Dict], list[Dict]]:
        """
        Scrape posts and top-level comments from a specific subreddit using search + date filtering.
     """
        logger.info(f"Starting to scrape r/{subreddit_name}")
    
        subreddit = self.reddit.subreddit(subreddit_name)
        posts_data = []
        comments_data = []

        try:
            fetched_posts = []
        
            for post in subreddit.new(limit = 2000):
                if self.start_timestamp <= post.created_utc <= self.end_timestamp:
                    fetched_posts.append(post)
                elif post.created_utc < self.start_timestamp:
                    break  # We're past the date range, stop
        
                time.sleep(random.uniform(1.0, 2.0))  # Be gentle to Reddit

        except Exception as e:
           logger.error(f"Error searching posts from r/{subreddit_name}: {e}")
           return [], []

        logger.info(f"Found {len(fetched_posts)} posts in time range for r/{subreddit_name}")

        random.seed(77)

        if len(fetched_posts) > self.posts_per_subreddit:
           sampled_posts = random.sample(fetched_posts, self.posts_per_subreddit)
        else:
           sampled_posts = fetched_posts
           logger.warning(f"Only found {len(fetched_posts)} posts for r/{subreddit_name}, less than target {self.posts_per_subreddit}")

        for i, post in enumerate(sampled_posts):
            try:
                post_data = self.get_post_data(post)
                posts_data.append(post_data)

                # Fetch direct replies to the posts
                post.comments.replace_more(limit=0) # 

                for top_level_comment in post.comments:
                    if (
                       hasattr(top_level_comment, 'body') and 
                       top_level_comment.body not in ['[deleted]', '[removed]']
                    ):
                      comment_data = self.get_comment_data(top_level_comment, post.id)
                      comments_data.append(comment_data)

                    #if len(comments_data) >= 10:
                     #   break  # Only take up to 10 top-level comments

                if (i + 1) % 100 == 0:
                    logger.info(f"Processed {i + 1}/{len(sampled_posts)} posts from r/{subreddit_name}")

                time.sleep(1)

            except Exception as e:
                logger.error(f"Error processing post {post.id} from r/{subreddit_name}: {e}")
                continue

        logger.info(f"Completed r/{subreddit_name}: {len(posts_data)} posts, {len(comments_data)} comments")
        return posts_data, comments_data
    
    def scrape_all_subreddits(self, output_dir: str = "../references/data"):
        """
        Scrape all target subreddits and save data
        
        Args:
            output_dir: Directory to save the extracted data
        """
        # Create output directory
        os.makedirs(output_dir, exist_ok=True)
        
        all_posts = []
        all_comments = []
        
        for subreddit_name in self.subreddits:
            try:
                posts_data, comments_data = self.scrape_subreddit(subreddit_name)
                
                # Save individual subreddit data
                posts_df = pd.DataFrame(posts_data)
                comments_df = pd.DataFrame(comments_data)
                
                posts_df.to_csv(f"{output_dir}/{subreddit_name}_posts.csv", index=False)
                comments_df.to_csv(f"{output_dir}/{subreddit_name}_comments.csv", index=False)
                
                # Add to combined dataset
                all_posts.extend(posts_data)
                all_comments.extend(comments_data)
                
                logger.info(f"Saved data for r/{subreddit_name}")
                
                # Sleep between subreddits to be respectful to Reddit's servers
                time.sleep(1)
                
            except Exception as e:
                logger.error(f"Failed to scrape r/{subreddit_name}: {e}")
                continue
        
        # Save combined dataset
        all_posts_df = pd.DataFrame(all_posts)
        all_comments_df = pd.DataFrame(all_comments)
        
        all_posts_df.to_csv(f"{output_dir}/all_posts_combined.csv", index=False)
        all_comments_df.to_csv(f"{output_dir}/all_comments_combined.csv", index=False)
        
        # Save metadata
        metadata = {
            'scrape_date': datetime.datetime.now().isoformat(),
            'time_period_start': self.start_date.isoformat(),
            'time_period_end': self.end_date.isoformat(),
            'subreddits': self.subreddits,
            'target_posts_per_subreddit': self.posts_per_subreddit,
            'total_posts_collected': len(all_posts),
            'total_comments_collected': len(all_comments),
            'posts_per_subreddit': {sub: len([p for p in all_posts if p['subreddit'] == sub]) for sub in self.subreddits}
        }
        
        with open(f"{output_dir}/metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2)
        
        logger.info(f"Scraping completed! Total: {len(all_posts)} posts, {len(all_comments)} comments")
        logger.info(f"Data saved to {output_dir}/")
        
        return all_posts_df, all_comments_df

# Usage example
if __name__ == "__main__":
    # You need to register a Reddit app at https://www.reddit.com/prefs/apps/
    # and get your credentials
    CLIENT_ID = os.getenv("REDDIT_CLIENT_ID")
    CLIENT_SECRET = os.getenv("REDDIT_CLIENT_SECRET")
    USER_AGENT = os.getenv("REDDIT_USER_AGENT")
    
    # Initialize scraper
    scraper = RedditMentalHealthScraper(
        client_id=CLIENT_ID,
        client_secret=CLIENT_SECRET,
        user_agent=USER_AGENT
    )
    
    # Run the scraping
    posts_df, comments_df = scraper.scrape_all_subreddits()
    
    # Display summary statistics
    print("\n=== SCRAPING SUMMARY ===")
    print(f"Total posts collected: {len(posts_df)}")
    print(f"Total comments collected: {len(comments_df)}")
    print(f"\nPosts per subreddit:")
    print(posts_df['subreddit'].value_counts())
    print(f"\nComments per subreddit:")
    comment_subreddit_counts = comments_df.merge(
        posts_df[['post_id', 'subreddit']], 
        on='post_id', 
        how='left'
    )['subreddit'].value_counts()
    print(comment_subreddit_counts)

2025-06-22 16:48:05,738 - INFO - Starting to scrape r/depression
2025-06-22 17:12:48,959 - INFO - Found 983 posts in time range for r/depression
2025-06-22 17:14:50,624 - INFO - Processed 100/983 posts from r/depression
2025-06-22 17:16:53,590 - INFO - Processed 200/983 posts from r/depression
2025-06-22 17:18:58,425 - INFO - Processed 300/983 posts from r/depression
2025-06-22 17:21:01,531 - INFO - Processed 400/983 posts from r/depression
2025-06-22 17:23:03,759 - INFO - Processed 500/983 posts from r/depression
2025-06-22 17:25:09,150 - INFO - Processed 600/983 posts from r/depression
2025-06-22 17:27:11,659 - INFO - Processed 700/983 posts from r/depression
2025-06-22 17:29:23,026 - INFO - Processed 800/983 posts from r/depression
2025-06-22 17:31:29,626 - INFO - Processed 900/983 posts from r/depression
2025-06-22 17:33:17,253 - INFO - Completed r/depression: 983 posts, 1378 comments
2025-06-22 17:33:17,529 - INFO - Saved data for r/depression
2025-06-22 17:33:18,538 - INFO - Star


=== SCRAPING SUMMARY ===
Total posts collected: 4822
Total comments collected: 11310

Posts per subreddit:
subreddit
depression         983
Anxietyhelp        974
depression_help    968
mentalhealth       961
Anxiety            936
Name: count, dtype: int64

Comments per subreddit:
subreddit
Anxietyhelp        2799
Anxiety            2739
depression_help    2422
mentalhealth       1972
depression         1378
Name: count, dtype: int64
