In [None]:
!pip install pyspark
!pip install praw

Collecting praw
[?25l  Downloading https://files.pythonhosted.org/packages/48/a8/a2e2d0750ee17c7e3d81e4695a0338ad0b3f231853b8c3fa339ff2d25c7c/praw-7.2.0-py3-none-any.whl (159kB)
[K     |████████████████████████████████| 163kB 7.7MB/s 
[?25hCollecting prawcore<3,>=2
  Downloading https://files.pythonhosted.org/packages/7d/df/4a9106bea0d26689c4b309da20c926a01440ddaf60c09a5ae22684ebd35f/prawcore-2.0.0-py3-none-any.whl
Collecting update-checker>=0.18
  Downloading https://files.pythonhosted.org/packages/0c/ba/8dd7fa5f0b1c6a8ac62f8f57f7e794160c1f86f31c6d0fb00f582372a3e4/update_checker-0.18.0-py3-none-any.whl
Collecting websocket-client>=0.54.0
[?25l  Downloading https://files.pythonhosted.org/packages/08/33/80e0d4f60e84a1ddd9a03f340be1065a2a363c47ce65c4bd3bae65ce9631/websocket_client-0.58.0-py2.py3-none-any.whl (61kB)
[K     |████████████████████████████████| 61kB 5.0MB/s 
Installing collected packages: prawcore, update-checker, websocket-client, praw
Successfully installed praw-7.2.0 

In [None]:
from pyspark.rdd import RDD
from pyspark.sql import Row
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from pyspark.sql.functions import desc
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark import SparkContext as sc
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
# tools
import math
import json
import requests
import itertools
import numpy as np
import time
from datetime import datetime, timedelta

def init_spark():
    spark = SparkSession \
        .builder \
        .appName("Python Spark SQL basic example") \
        .config("spark.some.config.option", "some-value") \
        .getOrCreate()
    return spark
spark = init_spark()

In [None]:
"""
Function to make an HTTP request to the Pushshift API.
- max_retry: Nb of times the request is re-tried if failure occurs.
- returns: Python object with the content of the request. (index 'data' property)
"""
def get_request(uri, max_retry = 5):
  def get(uri):
    response = requests.get(uri)
    assert response.status_code == 200
    return json.loads(response.content)
  # Retry if request call failed
  retry = 1
  while retry < max_retry:
    try:
      response = get(uri)
      return response
    except:
      print(f"[{retry}] Request failed, re-trying...")
      # wait 1 second before retry
      time.sleep(1)
      retry += 1

In [None]:
# Testing get_request() with test uri. Should return a non-empty Python object.
obj = get_request("https://httpbin.org/get")
print(obj['url'])
# Returning posts from wallstreetbets. The posts are in the "data" property of the response.
# Use this to check format: https://jsonformatter.curiousconcept.com/#
# obj2 = get_request('https://api.pushshift.io/reddit/search/submission?subreddit=wallstreetbets')
# print(obj2)

https://httpbin.org/get


In [None]:
"""
Gets the all the posts from a given subreddit in the specific time range.
- subreddit: name of subreddit
- begin: timestamp (in unix) of start date
- end: timestamp (in unix) of end date
- returns: list of all the posts in the time interval. the posts are objects with properties "id", "title" and "creation_utc".
"""
def get_posts(subreddit, begin, end):
  # Max size of Pushshift API retrieve is 500 posts.
  SIZE = 100
  PUSHSHIFT_URI = r'https://api.pushshift.io/reddit/search/submission?subreddit={}&after={}&before={}&size={}'
  nb_requests_made = 1

  # Get the ids and creation time of the posts only. Can use later to get the actual posts with PRAW with these ids.
  # Alternatively, we could also directly use get_request() instead of this function, and get all the posts with their content.
  def filter_ids_time(uri, begin, end):
    full_posts = get_request(uri.format(subreddit, begin, end, SIZE))
    # Test prints
    #if nb_requests_made != 1:
    #  print(f"Retrieved full_posts {nb_requests_made} times", len(full_posts['data']))
    if full_posts is None:
      raise ValueError("Response is empty or none.")

    filtered = map(lambda post: {
        'id': post['id'],
        'title': post['title'],
        'created_utc': post['created_utc']
    }, full_posts['data'])
    return list(filtered)

  posts = filter_ids_time(PUSHSHIFT_URI, begin, end)
  posts_amount = len(posts)
  # If reached limit of 500 posts retrieved, make request again until 'end' time.
  while posts_amount == SIZE:
    # Timestamp of the last post we previously retrieved
    new_begin = posts[-1]['created_utc'] - 10
    more_posts = filter_ids_time(PUSHSHIFT_URI, new_begin, end)
    posts_amount = len(more_posts)
    posts.extend(more_posts)
    nb_requests_made += 1
  
  return posts

In [50]:
"""
Testing get_posts() function.
- Timestamp converter: https://www.unixtimestamp.com/index.php?ref=theredish.com%2Fweb
- Able to retrieve up till latest posts from 6 hours ago.
- To print till "now": math.ceil(datetime.utcnow().timestamp())
"""
# Posts from March 13th
#posts = get_posts('wallstreetbets', 1615687200, math.ceil(datetime.utcnow().timestamp()))
# All posts from nb_days_from_today
nb_days_from_today = 5
begin = math.ceil((datetime.utcnow() - timedelta(days=nb_days_from_today)).timestamp())
end = math.ceil(datetime.utcnow().timestamp())
print("Timestamps: ", begin, end)
posts = get_posts('wallstreetbets', begin, end)
unique_posts = np.unique([post['id'] for post in posts])
# Use np.unique to get rid of duplicates and filter posts only by id (only need id for praw).
print("Size: ", len(posts))
print("Size of uniques: ", len(unique_posts))
print("Example posts: ", unique_posts[:5])

Timestamps:  1615342779 1615774779
Size:  30737
Size of uniques:  29970
Example posts:  ['m1mx0q' 'm1mx54' 'm1mx6b' 'm1mx8k' 'm1mxo2']


In [41]:
import praw
import os
"""
Connect to Reddit API using environment variables. (values from reddit app)
- returns: reddit instance
"""
def connect_reddit():
  reddit = praw.Reddit(
    username = "fryingpannnnnn", #os.environ['REDDIT_NAME'],
    password = "wsbpan123", #os.environ['REDDIT_PASS'],
    client_secret = "GWFFdjptuoOsDPjpxERQsBY-xV7GfQ", #os.environ['REDDIT_SECRET'],
    client_id = "VWDfl0X3JFVwWQ", #os.environ['REDDIT_CLIENTWSB'],
    user_agent = "wsbscrape"
  )
  return reddit

In [42]:
# Testing connect_reddit() function
# Should print "<praw.reddit.Reddit object at ...>"
reddit = connect_reddit()
print(reddit)

<praw.reddit.Reddit object at 0x7fe1a9c60110>


In [44]:
"""
Function to retrieve all Reddit posts of given ids.
- unique_posts: the id of the posts you want to retrieve.
- returns: list of reddit submission instances (reddit posts) of the input ids.
"""
def get_reddit_posts(unique_posts):
  reddit = connect_reddit()
  PAUSE_BEFORE_NEXT_CALL = 0.3

  reddit_posts = []

  for post_id in unique_posts:
    reddit_posts.append(reddit.submission(id=post_id))
    time.sleep(PAUSE_BEFORE_NEXT_CALL)

  return reddit_posts

In [51]:
# Testing get_reddit_posts(): retrieving 1000 posts.
thousand_posts = unique_posts[:1000]
reddit_posts = get_reddit_posts(thousand_posts)
print("Size of posts: ", len(reddit_posts))
print("First 5 posts: ", reddit_posts[:5])

Size of posts:  1000
First 5 posts:  [Submission(id='m1mx0q'), Submission(id='m1mx54'), Submission(id='m1mx6b'), Submission(id='m1mx8k'), Submission(id='m1mxo2')]


In [58]:
reddit_posts[0].score

1