In [1]:
import os
import time

import pandas as pd
import requests

In [3]:
# ================================================================================================================
# Query the pushshift API for one batch of submissions from a subreddit.
# ================================================================================================================
def get_reddit_data(subreddit_name, url, num_posts, posts_before_time = 0, verbose=False, timeout=30):
    """Request one batch of posts from a subreddit through the pushshift API.

    Parameters
    ----------
    subreddit_name : str
        Name of the subreddit, e.g. 'wallstreetbets'.
    url : str
        Pushshift search endpoint, e.g. 'https://api.pushshift.io/reddit/search/submission'.
    num_posts : int
        Number of posts to request (sent as the ``size`` query parameter).
        NOTE: the API caps the batch size; 100 was the limit noted when this notebook was written.
    posts_before_time : int or str, optional
        When not 0, sent as the ``before`` query parameter so that only posts older than this UTC
        timestamp are returned. Used in a loop so each batch is older than the previous one.
    verbose : bool, optional
        Print the HTTP status code of the request.
    timeout : float or None, optional
        Seconds to wait for the server before the request is abandoned. Without a timeout a stalled
        connection blocks the whole scrape forever.

    Returns
    -------
    dict or int
        The decoded JSON response on HTTP 200. Otherwise -1, which is returned for a non-200 status,
        a network error / timeout, or a body that is not valid JSON. Callers treat -1 as
        "retry or give up".
    """
    # Query parameters understood by pushshift: which subreddit and how many posts.
    params = {'subreddit' : subreddit_name,
              'size' : num_posts}

    # Only page backwards in time when a starting timestamp was supplied.
    if posts_before_time != 0:
        params['before'] = posts_before_time

    # Network-level failures (connection reset, DNS, timeout) are reported like a bad status code so the
    # retry logic upstream can handle them instead of the whole scrape crashing.
    try:
        res = requests.get(url, params=params, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        print(f"Exiting function due to request error ---> {exc}")
        return -1

    status_code = res.status_code

    if verbose:
        print("===============================================================================")
        print(f"Connection to subreddit {subreddit_name} returned status code {status_code}.")
        print("===============================================================================")

    if status_code != 200:
        print(f"Exiting function due to invalid status code ---> {status_code}")
        return -1

    # A 200 with a non-JSON body (e.g. an HTML error page) is also treated as a failed request.
    try:
        return res.json()
    except ValueError as exc:
        print(f"Exiting function due to an unparseable response body ---> {exc}")
        return -1
        
    

In [4]:
# ================================================================================================================
# Helper: turn one decoded pushshift response (the dict returned by get_reddit_data) into a pandas DataFrame.
# ================================================================================================================
def posts_to_dataframe(data, verbose=False, posts_left=None, total_posts=None, name=None):
    """Convert one batch of pushshift results into a DataFrame.

    Parameters
    ----------
    data : dict
        Decoded JSON response; its ``'data'`` entry is a list of post dictionaries, each with a
        ``'created_utc'`` key.
    verbose : bool, optional
        Print a progress summary. When True, ``posts_left`` and ``total_posts`` must be numbers.
    posts_left, total_posts : int, optional
        Progress counters, used only for the verbose summary.
    name : str, optional
        Subreddit name, used only for the verbose summary.

    Returns
    -------
    (pandas.DataFrame, created_utc of the final post in the batch)
        The timestamp is what the next request passes as ``before`` to continue further back in time.

    Raises
    ------
    IndexError
        If ``data['data']`` is empty (there is no last post to read a timestamp from).
    """
    batch = data['data']

    # Reading the final element up front also surfaces an empty batch immediately.
    oldest_time = batch[-1]['created_utc']

    if verbose:
        print("\n==============================================================================")
        print(f"Successfully obtained {total_posts - posts_left} of {total_posts} posts so far in subreddit {name}.")
        print(f"First post in this batch created at UTC: {batch[0]['created_utc']}")
        print(f"Last post in this batch created at UTC: {oldest_time}")
        print(f"There are {posts_left} more posts to collect.")
        print("==============================================================================\n")

    return pd.DataFrame(batch), oldest_time

In [5]:
# ================================================================================================================
# When scraping posts in a loop for long periods of time, it is useful to periodically save the data to a .csv.
# This shows that the script is running as intended, and allows some data to be used before collection finishes.
# ================================================================================================================
def posts_to_csv(df, posts_left, total_posts, subreddit, intermediate=True):
    """Write the posts collected so far to a .csv under ./data.

    Parameters
    ----------
    df : pandas.DataFrame
        Posts collected so far.
    posts_left : int
        Posts still to collect; only used to label intermediate (checkpoint) files.
    total_posts : int
        Target number of posts for this subreddit.
    subreddit : str
        Subreddit name, used in the filename.
    intermediate : bool, optional
        True writes a checkpoint to ./data/Intermediate/, False writes the final file to ./data/Final/.

    The target directory is created if it does not exist yet. Otherwise the first checkpoint of a
    fresh checkout would raise an OSError partway through a long scrape.
    """
    if intermediate:
        posts_completed = total_posts - posts_left
        out_dir = "./data/Intermediate"
        filename = f"Intermediate_{subreddit}_{posts_completed}_of_{total_posts}_post_data.csv"
    else:
        out_dir = "./data/Final"
        filename = f"Final_{subreddit}_data_{total_posts}_posts.csv"

    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(f"{out_dir}/{filename}", index=False)

    return

In [6]:
# ================================================================================================================
# When scraping posts in a loop for long periods of time, there are likely to be intermittent failures to
# connect, i.e. get_reddit_data() returning -1. This function keeps those intermittent errors from ending the
# data collection process.
#
# After a failed request, this function tries up to `max_attempts` (default 5) more times to reconnect, waiting
# `wait_seconds` (default 5) before every attempt.
#
# If any reconnection attempt is successful, its data is returned and collection continues as normal. If every
# attempt fails, the function returns -1, which causes the data collection process to halt.
# ================================================================================================================
def retry_reddit_scrape(subreddit_name, url, posts_per_request, posts_before_time, max_attempts=5, wait_seconds=5):
    """Retry get_reddit_data a bounded number of times after a failed request.

    Parameters
    ----------
    subreddit_name, url, posts_per_request, posts_before_time
        Forwarded unchanged to get_reddit_data for every attempt.
    max_attempts : int, optional
        Maximum number of reconnection attempts (default 5).
    wait_seconds : float, optional
        Cool-down before each attempt, in seconds (default 5).

    Returns
    -------
    dict or int
        The data from the first successful attempt, or -1 if all attempts fail.
    """
    for attempt in range(1, max_attempts + 1):

        # Wait before every attempt (not just the first) so a struggling server is not hammered.
        time.sleep(wait_seconds)

        print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print(f"Reconnection attempt number {attempt} after a {wait_seconds} second wait...")
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

        data = get_reddit_data(subreddit_name=subreddit_name, url=url, num_posts=posts_per_request,
                               posts_before_time = posts_before_time, verbose=True)

        if data != -1:
            return data

    return -1

In [7]:
# ================================================================================================================
# This function scrapes a specified number of posts from a given subreddit, walking backwards in time one batch
# at a time.
#
# Optional parameters warm, before, start_file and posts_completed allow collection to resume from a specified
# timestamp and append to a previously saved .csv; save_every controls how often a checkpoint .csv is written.
# ================================================================================================================
def scrape_reddit(subreddit_name = 'CryptoCurrency', url='https://api.pushshift.io/reddit/search/submission', posts_per_request=100, total_posts=2000,
                  warm=False, before=None, start_file=None, posts_completed=None, save_every=1000):
    """Scrape about ``total_posts`` submissions from one subreddit via the pushshift API.

    Parameters
    ----------
    subreddit_name : str
        Subreddit to scrape.
    url : str
        Pushshift search endpoint.
    posts_per_request : int
        Batch size requested on every call. The remaining-post bookkeeping assumes every batch
        returns this many posts.
    total_posts : int
        Number of posts to collect in total (including ``posts_completed`` on a warm start).
    warm : bool
        Resume an earlier run: load ``start_file`` and continue from the ``before`` timestamp.
    before : int or str, optional
        UTC timestamp to continue from on a warm start.
    start_file : str, optional
        .csv of previously collected posts (warm start only).
    posts_completed : int, optional
        Number of posts ``start_file`` already holds (warm start only).
    save_every : int
        A checkpoint .csv is written whenever the remaining-post count is a multiple of this value.

    Returns
    -------
    pandas.DataFrame or int
        All posts collected. This may be fewer than requested if the API runs out of older posts or
        reconnection fails; in the latter case a checkpoint .csv is written first. Returns -1 if the
        very first request of a cold start fails or returns no posts.
    """
    if warm:
        # Pick up where an earlier round of data collection left off.
        df = pd.read_csv(start_file)
        posts_left = total_posts - posts_completed
        last_post_time = before

        print(f"warm start...{posts_completed} posts have been scraped already, {posts_left} more to go")

    else:
        # Cold start: the first batch seeds the DataFrame that later batches are appended to.
        data = get_reddit_data(subreddit_name, url, posts_per_request)

        if data == -1:
            print(f"Exiting {subreddit_name} reddit scrape due to invalid response code after collecting 0 posts")
            return -1

        # posts_to_dataframe needs at least one post to read a timestamp from.
        if not data['data']:
            print(f"Exiting {subreddit_name} reddit scrape because the API returned no posts")
            return -1

        df, last_post_time = posts_to_dataframe(data)
        posts_left = total_posts - posts_per_request

    # Batches are gathered in a list and only concatenated when a full DataFrame is needed (checkpoint,
    # early exit, final save); concatenating after every batch would re-copy all earlier rows each time.
    frames = [df]

    while posts_left > 0:

        # Rest for 1 second so we don't ping the server too often
        time.sleep(1)

        # Every `save_every` posts, write a checkpoint .csv so an error does not lose everything; the batch
        # fetched right after a checkpoint is also reported verbosely.
        verbose = posts_left % save_every == 0
        if verbose:
            df = pd.concat(frames, ignore_index=True)
            frames = [df]
            posts_to_csv(df, posts_left, total_posts, subreddit_name)

        # Fetch the next (older) batch of posts.
        data = get_reddit_data(subreddit_name, url, posts_per_request, posts_before_time = last_post_time)

        if data == -1:

            # Try to reconnect before giving up.
            data = retry_reddit_scrape(subreddit_name, url, posts_per_request, posts_before_time = last_post_time)

            # Reconnection failed: checkpoint what has been collected so the run can be warm-started later,
            # and return it. (Previously this passed -1 to posts_to_dataframe and crashed.)
            if data == -1:
                print(f"Failed to reconnect with a valid status code, exiting the {subreddit_name} call to scrape_reddit")
                print(f"Resume with warm=True, before={last_post_time}, posts_completed={total_posts - posts_left}")
                df = pd.concat(frames, ignore_index=True)
                posts_to_csv(df, posts_left, total_posts, subreddit_name)
                return df

        # An empty batch means there are no older posts left; stop rather than crash in posts_to_dataframe.
        if not data['data']:
            print(f"No posts returned before UTC {last_post_time} in subreddit {subreddit_name}; stopping early.")
            break

        # Decrease the number of posts left to scrape
        posts_left = posts_left - posts_per_request

        new_df, last_post_time = posts_to_dataframe(data, verbose=verbose, posts_left=posts_left, total_posts=total_posts, name=subreddit_name)
        frames.append(new_df)

    df = pd.concat(frames, ignore_index=True)

    # Final save
    posts_to_csv(df, posts_left, total_posts, subreddit_name, intermediate=False)

    print(f"\n>>>>>>>>>>>>>>>>>>>> Finished Scraping {total_posts} posts from subreddit {subreddit_name} <<<<<<<<<<<<<<<<<<<<\n")

    return df

In [8]:
# ================================================================================================================
# Helper function to scrape_multiple_subreddits.
# Takes a pandas dataframe with reddit post information and returns a dataframe that only contains the few
# columns we are interested in keeping (also saved to ./data/Final/).
# ================================================================================================================
def clean_reddit_data(df, subreddit_name, total_posts, columns_to_keep=('subreddit', 'selftext', 'title', 'created_utc')):
    """Keep only ``columns_to_keep`` from ``df``, save the result as a .csv, and return it.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw scraped posts. It is not modified.
    subreddit_name : str
        Used in the output filename.
    total_posts : int
        Used in the output filename.
    columns_to_keep : sequence of str, optional
        Columns to retain, in order. Raises KeyError if any is missing from ``df``.

    Returns
    -------
    pandas.DataFrame
        The reduced DataFrame, also written to ./data/Final/Clean_<subreddit>_data_<total_posts>_posts.csv.
    """
    # A tuple default avoids the shared-mutable-default pitfall. .loc would treat a tuple as a single
    # (MultiIndex) label, so convert it to a list; other selectors are passed through unchanged.
    if isinstance(columns_to_keep, tuple):
        columns_to_keep = list(columns_to_keep)

    clean_df = df.loc[:, columns_to_keep]

    # Create the output directory on first use so the save does not fail on a fresh checkout.
    out_dir = "./data/Final"
    os.makedirs(out_dir, exist_ok=True)
    clean_df.to_csv(f"{out_dir}/Clean_{subreddit_name}_data_{total_posts}_posts.csv", index=False)

    return clean_df
    

In [9]:
# ================================================================================================================
# This function calls scrape_reddit in a loop for a given list of subreddit names, then reduces each result to
# the columns of interest with clean_reddit_data.
#
# url, posts_per_request, total_posts_per_subreddit and save_every are forwarded to scrape_reddit for every
# subreddit. Warm-start settings are configured per subreddit inside the loop.
# ================================================================================================================
def scrape_multiple_subreddits(subreddit_names = ['wallstreetbets', 'CryptoCurrency'], url='https://api.pushshift.io/reddit/search/submission', posts_per_request=100,
                               total_posts_per_subreddit=2000, save_every=500):
    """Scrape and clean each subreddit in ``subreddit_names``.

    Parameters
    ----------
    subreddit_names : list of str
        Subreddits to scrape, in order. The list is not modified.
    url : str
        Pushshift search endpoint, passed through to scrape_reddit.
    posts_per_request : int
        Batch size for each API request.
    total_posts_per_subreddit : int
        Target number of posts for each subreddit.
    save_every : int
        Checkpoint interval forwarded to scrape_reddit.

    Returns
    -------
    dict
        ``'full_dfs'``: raw DataFrames; ``'clean_dfs'``: column-reduced DataFrames; ``'subreddits'``:
        the names the two lists correspond to (same order). Subreddits for which scrape_reddit
        returned no DataFrame are skipped.
    """
    clean_dfs = []
    full_dfs = []
    scraped_names = []

    for reddit_name in subreddit_names:

        # This section was added after the fact to add flexibility when collecting additional posts.
        # Edit the if/else below to start data collection before any desired timestamp, from any saved file,
        # for a given subreddit.
        # NOTE(review): as written, every name except 'NA' takes the else branch, i.e. warm-starts from the
        # crypto file before UTC 1510724913. Confirm this is intended before running.
        if reddit_name == 'NA':
            warm = True
            before = '1610672400'
            file = "./data/Intermediate/blank_warmstart_wallstreetbets.csv"
            completed = 0
        else:
            warm = True
            before = "1510724913"
            file = "./data/Intermediate/blank_warmstart_crypto.csv"
            completed = 0

        # Use the caller's url; it was previously ignored in favour of a hard-coded endpoint.
        df = scrape_reddit(subreddit_name = reddit_name, url=url, posts_per_request=posts_per_request,
                           total_posts=total_posts_per_subreddit, warm=warm, before=before, start_file=file, posts_completed=completed, save_every=save_every)

        # scrape_reddit returns -1 (not a DataFrame) when its very first request fails; cleaning it would crash.
        if not isinstance(df, pd.DataFrame):
            print(f"Skipping subreddit {reddit_name}: no posts were collected.")
            continue

        full_dfs.append(df)
        clean_dfs.append(clean_reddit_data(df, subreddit_name=reddit_name, total_posts=total_posts_per_subreddit))
        scraped_names.append(reddit_name)

    print("======================================== Scrape complete! ========================================")
    print(f"Successfully collected {total_posts_per_subreddit} from each of the following subreddits ---> {scraped_names}")
    print("==========================================================================================")

    all_data = {'full_dfs' : full_dfs,
                "clean_dfs" : clean_dfs,
                'subreddits' : scraped_names}

    return all_data

In [10]:
# Network scrape of r/wallstreetbets: 125,000 posts requested, with a checkpoint .csv every 25,000 posts.
# scrape_reddit sleeps 1 s between batches, so expect this cell to run for a long time.
scrape_dataframes = scrape_multiple_subreddits(
    subreddit_names=['wallstreetbets'],
    total_posts_per_subreddit=125_000,
    save_every=25_000,
)

warm start...0 posts have been scraped already, 125000 more to go

Successfully obtained 100 of 125000 posts so far in subreddit wallstreetbets.
First post in this batch created at UTC: 1610672395
Last post in this batch created at UTC: 1610670451
There are 124900 more posts to collect.


Successfully obtained 25100 of 125000 posts so far in subreddit wallstreetbets.
First post in this batch created at UTC: 1608740238
Last post in this batch created at UTC: 1608737264
There are 99900 more posts to collect.


Successfully obtained 50100 of 125000 posts so far in subreddit wallstreetbets.
First post in this batch created at UTC: 1606768021
Last post in this batch created at UTC: 1606764832
There are 74900 more posts to collect.


Successfully obtained 75100 of 125000 posts so far in subreddit wallstreetbets.
First post in this batch created at UTC: 1604692089
Last post in this batch created at UTC: 1604533851
There are 49900 more posts to collect.


Successfully obtained 100100 of 125000