<a href="https://colab.research.google.com/github/AfsanehHabibi/reddit-conversation-quality/blob/main/RedditConversationQuality.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scraping Reddit Data  

In [None]:
# Mount Google Drive at /content/drive so the notebook can read and write
# files stored there (the paths used below live under this mount point).
from google.colab import drive
drive.mount('/content/drive')

#Pushshift API

In [None]:
# Prefix for every file the functions below read or write. It is concatenated
# directly with file names, so it must end with a slash.
base_path = "/content/drive/MyDrive/University/RedditData/"

#Create sampling timestamps

In [None]:
import random
import datetime
def create_timestamps(seed, num_values, period_length, year, month, day):
    """Pick random, non-overlapping sampling periods in the 7 days after a date.

    Every day is split into back-to-back slots of ``period_length`` minutes
    (a trailing partial slot is not used) and ``num_values`` distinct slots
    are drawn from the 7 days starting at the given date.

    Parameters:
    - seed: value passed to ``random.seed`` so the selection is reproducible.
    - num_values (int): number of distinct periods to return.
    - period_length (int): length of each period in minutes, between 1 and 1440.
    - year, month, day (int): date of the first sampled day, built as a naive
      ``datetime`` (so it is interpreted in the machine's local time zone).

    Returns:
    - list of ``(start, end)`` tuples of integer Unix timestamps, sorted by start.

    Raises:
    - ValueError: if ``period_length`` is outside 1..1440, or if more periods
      are requested than there are slots (the draw could never finish).
    """
    minutes_per_day = 24 * 60
    sampled_days = 7
    if not 0 < period_length <= minutes_per_day:
        raise ValueError("period_length must be between 1 and 1440 minutes")

    # Integer slot count: random.randrange/randint reject float bounds on
    # recent Python versions, and a fractional count has no meaning anyway.
    slots_per_day = int(minutes_per_day // period_length)
    if num_values > sampled_days * slots_per_day:
        raise ValueError(
            f"cannot pick {num_values} distinct periods out of "
            f"{sampled_days * slots_per_day} available slots"
        )

    random.seed(seed)
    start_date = datetime.datetime(year, month, day)
    minutes_after = set()
    while len(minutes_after) < num_values:
        # randrange excludes the upper bound, so days are 0..6 and a slot
        # always ends on the same day it starts.
        day_after = random.randrange(sampled_days)
        slot = random.randrange(slots_per_day)
        minutes_after.add(day_after * minutes_per_day + slot * period_length)

    timestamps = []
    for minute in sorted(minutes_after):
        timestamp = start_date + datetime.timedelta(minutes=minute)
        end_timestamp = timestamp + datetime.timedelta(minutes=period_length)
        timestamps.append((int(timestamp.timestamp()), int(end_timestamp.timestamp())))
    return timestamps

In [None]:
import os

def write_timestamps_to_file(seed, num_values, period_length, year, month, day):
    """Generate sampling periods and store them in a file under ``base_path``.

    The file name encodes all arguments. Each line of the file has the form
    ``start,end,0`` where the trailing ``0`` is the initial status flag.
    An existing file is never overwritten.

    Parameters:
    - seed, num_values, period_length, year, month, day: forwarded unchanged
      to ``create_timestamps``.
    """
    # Create the timestamps using the provided function
    timestamps = create_timestamps(seed, num_values, period_length, year, month, day)

    # Generate the file name based on the input arguments
    file_name = f"{base_path}timestamps_seed{seed}_num{num_values}_period{period_length}_date{year}-{month}-{day}.txt"

    # Keep an existing file untouched: its status flags may already have been updated
    if os.path.exists(file_name):
        print(f"File already exists: {file_name}")
        return
    with open(file_name, "w") as f:
        for start_time, end_time in timestamps:
            f.write(f"{start_time},{end_time},0\n")

    # Print the real location of the file; abspath leaves an absolute
    # file_name as it is and resolves a relative one against the working directory
    print(f"Timestamps file saved to {os.path.abspath(file_name)}")


In [None]:
# Create the sampling file: 20 periods of 10 minutes each, seeded with 2, starting from 2022-10-01.
write_timestamps_to_file(seed=2, num_values=20, period_length=10, year=2022, month=10, day=1)

##Functions

In [None]:
import json

def filter_json_objects(data, keys):
    """Reduce every mapping in ``data`` to the requested ``keys``.

    Parameters:
    - data: iterable of dict-like items.
    - keys: iterable of key names to keep; names missing from an item are skipped.

    Returns:
    - list with one new dict per item, holding only the requested keys in the
      order they appear in ``keys``.
    """
    return [
        {name: record[name] for name in keys if name in record}
        for record in data
    ]

In [None]:
def write_last_timestamp(last_timestamp, entry, id):
    """Save the progress checkpoint for one (entry, id) pair.

    Overwrites ``last_timestamp_<entry>_<id>.txt`` under ``base_path`` with the
    text form of ``last_timestamp``.
    """
    checkpoint_path = f"{base_path}last_timestamp_{entry}_{id}.txt"
    with open(checkpoint_path, "w") as checkpoint:
        checkpoint.write(str(last_timestamp))

In [None]:
def read_last_timestamp(default, entry, id):
    """Load the progress checkpoint for one (entry, id) pair.

    Returns the integer stored in ``last_timestamp_<entry>_<id>.txt`` under
    ``base_path``, or ``default`` when that file does not exist.
    """
    checkpoint_path = f"{base_path}last_timestamp_{entry}_{id}.txt"
    try:
        checkpoint = open(checkpoint_path, "r")
    except FileNotFoundError:
        # No checkpoint yet: the caller decides where to start.
        return default
    with checkpoint:
        return int(checkpoint.read())

In [None]:
def write_data_to_file(data, entry, id):
    """Append records to ``<entry>s_<id>.json`` under ``base_path``.

    Each element of ``data`` is serialized as JSON on its own line, so the
    file holds one JSON document per line.
    """
    output_path = f"{base_path}{entry}s_{id}.json"
    with open(output_path, "a") as sink:
        for record in data:
            json.dump(record, sink)
            sink.write("\n")

In [None]:
def create_fields_filter_in_url(fields):
    """
    Builds the ``filter`` query parameter for the API endpoint URL.

    Parameters:
    - fields (list): names of the fields the query results should contain.
      The single-element list ["all"] means "do not restrict the fields".

    Returns:
    - filter_query (str): "" when all fields are wanted, otherwise
      "&filter=" followed by the comma-separated field names.
    """
    wants_every_field = len(fields) == 1 and fields[0] == "all"
    if wants_every_field:
        return ""
    return "&filter=" + ','.join(fields)

In [None]:
import requests
import datetime
import json
import time

def extract_date_based_data_from_reddit(id, entry_type, keys, start_date, end_date, step):
    """Download Pushshift entries created between two dates, walking back in time.

    Starting at ``end_date`` (or at the checkpoint left by an earlier run), the
    function requests windows of at most ``step`` seconds, appends every
    returned entry to the output file and saves the timestamp it has reached,
    until the window reaches ``start_date``.

    Parameters:
    - id: identifier used in the names of the output and checkpoint files.
    - entry_type (str): entry type placed in the URL, "submission" or "comment".
    - keys (list): fields to request from the API; ["all"] requests every field.
    - start_date, end_date (datetime.datetime): bounds of the interval to collect.
    - step (int): size in seconds of the window used for each request.

    Raises:
    - requests.exceptions.HTTPError: for an HTTP error status other than 429/524.
    - the exception raised by ``Response.json()`` when a successful response
      body is not valid JSON.
    """
    start_timestamp = int(start_date.timestamp())

    # read the last UTC timestamp from file (defaults to the end of the interval)
    last_timestamp = read_last_timestamp(int(end_date.timestamp()), entry_type, id)
    if last_timestamp <= start_timestamp:
        return

    # define the API endpoint
    filter_query = create_fields_filter_in_url(keys)
    const_url = f"https://api.pushshift.io/reddit/search/{entry_type}/?size=500{filter_query}&sort=created_utc&"
    url_template = const_url + "after={}&before={}"

    # status codes that are retried, with the number of seconds to wait first
    retry_delays = {429: 30, 524: 60}

    # get the set of entries using pagination; the loop ends once the whole
    # interval has been covered, even if no entry falls exactly on start_date
    while last_timestamp > start_timestamp:
        window_start = max(last_timestamp - step, start_timestamp)
        new_url = url_template.format(window_start, last_timestamp)
        new_response = requests.get(new_url)

        # Check the status before parsing: an error response may carry a body
        # that is valid JSON but has no "data" key.
        retry_after = retry_delays.get(new_response.status_code)
        if retry_after is not None:
            print(f"Got HTTP error {new_response.status_code}, waiting {retry_after} seconds and retrying...")
            time.sleep(retry_after)
            continue
        new_response.raise_for_status()

        enteries = new_response.json()["data"]
        write_data_to_file(enteries, entry_type, id)
        print(len(enteries))
        if not enteries:
            # Nothing in this window: move on to the next older one.
            last_timestamp = window_start
        else:
            # NOTE(review): assumes the API returns entries newest-first and
            # treats `before` as exclusive, so the last entry is the oldest
            # one and is older than the current position - confirm.
            last_timestamp = enteries[-1]["created_utc"]
            print(datetime.datetime.fromtimestamp(last_timestamp))
        # save the reached UTC timestamp so an interrupted run can resume
        write_last_timestamp(last_timestamp, entry_type, id)

In [None]:
def update_timestamps_file(file_path, starting_timestamp, status):
    """Set the status flag of one period in a timestamps file.

    The file holds ``start,end,status`` lines. The first line whose start
    equals ``starting_timestamp`` gets its status replaced by ``status``;
    all other lines are written back unchanged.
    """
    # Load every line of the file
    with open(file_path, 'r') as source:
        rows = source.readlines()

    # Rewrite the first row that starts at the requested timestamp
    for position, row in enumerate(rows):
        begin, finish, _previous_status = row.strip().split(',')
        if int(begin) != starting_timestamp:
            continue
        rows[position] = f"{begin},{finish},{status}\n"
        break

    # Store the rows again, replacing the previous file content
    with open(file_path, 'w') as sink:
        sink.writelines(rows)
        sink.flush()

In [None]:
def extract(seed, num_values, period_length, year, month, day, entry_type, keys):
    """Collect data for every unfinished period listed in a timestamps file.

    The timestamps file is located from the first six arguments (the same
    naming scheme ``write_timestamps_to_file`` uses). Periods whose status is
    not 1 are processed one at a time and marked 1 once finished, so the
    function can be re-run after an interruption.

    Parameters:
    - seed, num_values, period_length, year, month, day: identify the file.
    - entry_type (str): "submission" or "comment".
    - keys (list): fields to request; ["all"] requests every field.

    Returns:
    - list of ``(start, end)`` timestamp tuples processed during this call.
    """
    file_path = f"{base_path}timestamps_seed{seed}_num{num_values}_period{period_length}_date{year}-{month}-{day}.txt"
    processed = []

    while True:
        # Re-read the file on every pass: the status flags change as we go.
        pending = None
        with open(file_path, "r") as f:
            for line in f:
                print(line)
                start, end, done_str = line.strip().split(",")
                if int(done_str) != 1:
                    pending = (int(start), int(end))
                    print(*pending)
                    break
        if pending is None:
            break
        start_date = datetime.datetime.fromtimestamp(pending[0])
        end_date = datetime.datetime.fromtimestamp(pending[1])
        # The period's start timestamp doubles as the id of its output files.
        extract_date_based_data_from_reddit(pending[0], entry_type, keys, start_date, end_date, 10)
        update_timestamps_file(file_path, pending[0], 1)
        processed.append(pending)
    return processed

In [None]:
# Collect submissions (all fields) for the periods listed in the timestamps file identified by these arguments.
extract(seed=2, num_values=20, period_length=10, year=2022, month=10, day=1, entry_type="submission", keys=["all"])

In [None]:
# Sanity check: count the lines of one output file (it holds one JSON record per line).
with open(f'{base_path}submissions_1664946000.json', 'r') as f:
    num_lines = sum(1 for _ in f)
print(f"The file contains {num_lines} lines.")

##Write submissions to file

In [None]:
# define the date range
start_date = datetime.datetime(2022, 3, 1)
end_date = datetime.datetime(2022, 3, 31)
# NOTE(review): this field list is not passed to the call below, which requests ["all"] fields
keys = ["id","subreddit","selftext","title","quarantine","is_original_content","is_meta","is_created_from_ads_ui","author_premium","is_self","subreddit_type","allow_live_comments","is_crosspostable","over_18","removed_by","distinguished","subreddit_id","author","discussion_type","num_comments","whitelist_status","subreddit_subscribers","created_utc","retrieved_utc","updated_utc","media_metadata"]

# collect submissions for the date range in 10-second windows; the start
# timestamp is used as the id in the output and checkpoint file names
extract_date_based_data_from_reddit(int(start_date.timestamp()), "submission", ["all"], start_date, end_date, 10)

##Write sample submission fields to file

In [None]:
import requests
import json

# define the API endpoint with the sample submission ID
submission_id = "1273r9g"
url = f"https://api.pushshift.io/reddit/submission/search/?ids={submission_id}"

# make the API request and get the list of matching submissions
response = requests.get(url)
submission_list = response.json()["data"]

# get the first submission from the list
# NOTE(review): raises IndexError if the API returns no match for this ID
submission = submission_list[0]

# write the submission's field names to a file, one per line
with open(f"{base_path}submission_keys.txt", "w") as f:
    for key in submission.keys():
        f.write(key + "\n")

##Write comments to file

###With Pushshift date search

In [None]:
# define the date range
start_date = datetime.datetime(2023, 3, 1)
end_date = datetime.datetime(2023, 3, 31)
keys = ['created_utc', 'id']

# The function takes the file id as its first argument; use the start
# timestamp, as the submissions cell does. Comments are collected in
# 2-second windows.
extract_date_based_data_from_reddit(int(start_date.timestamp()), "comment", keys, start_date, end_date, 2)

##Write sample comment fields to file

In [None]:
import requests
import json

# define the API endpoint with the sample comment ID
comment_id = "jece0zo"
url = f"https://api.pushshift.io/reddit/comment/search/?ids={comment_id}"

# make the API request and get the list of matching comments
response = requests.get(url)
comment_list = response.json()["data"]

# get the first comment from the list
# NOTE(review): raises IndexError if the API returns no match for this ID
comment = comment_list[0]

# write the comment's field names to a file, one per line
with open(f"{base_path}comment_keys.txt", "w") as f:
    for key in comment.keys():
        f.write(key + "\n")

In [None]:
def test_create_fields_filter_in_url():
    """Check the filter query built for ['all'], an empty list and several fields."""
    cases = (
        # (fields, expected query string)
        (['all'], ''),
        ([], '&filter='),
        (['field1', 'field2', 'field3'], '&filter=field1,field2,field3'),
    )
    for fields, expected_output in cases:
        assert create_fields_filter_in_url(fields) == expected_output

In [None]:
# Run the check; an AssertionError means the filter query is built incorrectly.
test_create_fields_filter_in_url()

In [None]:
def test_create_timestamps():
    """Check that generated periods do not overlap and stay inside the 7-day window."""
    # Ten 5-minute periods starting from 2022-01-01
    periods = create_timestamps(seed=123, num_values=10, period_length=5, year=2022, month=1, day=1)

    # Every pair of periods must be disjoint
    for position, (first_start, first_end) in enumerate(periods):
        for second_start, second_end in periods[position + 1:]:
            assert first_end <= second_start or second_end <= first_start, "Overlapping periods"

    # No period may begin before the starting day
    window_open = datetime.datetime(2022, 1, 1)
    for begin, _finish in periods:
        assert begin >= window_open.timestamp(), "Period starting before start day"

    # No period may finish later than 7 days after the starting day
    window_close = window_open + datetime.timedelta(days=7)
    for _begin, finish in periods:
        assert finish <= window_close.timestamp(), "Period ending more than 7 days after start day"