# WARNING: we cannot make any guarantees about the content of the memes. They were scraped automatically from r/wholesomememes. Proceed at your own discretion.

### Let's get some data to work with!

We'll be using the Pushshift API (through `psaw`, a Python wrapper for it) to scrape data off of r/wholesomememes.

In [None]:
# importing packages that we need to run the program
import psaw                     # Pushshift API wrapper used to scrape reddit posts from the internet
import pandas as pd             # pandas is a dataframe library that lets us handle data cleanly in an organized manner
from PIL import Image           # PIL is a python imaging library, we'll be using it to show images
import requests                 # lets us make HTTP requests to fetch the image from the URL
from io import BytesIO          # lets us handle bytestreams, which is the return result of the HTTP request
import numpy as np              # lets us handle array data very quickly through vectorization

In [None]:
# scraping the data from reddit's r/wholesomememes
# please don't run this cell more than once! you will be rate-limited

# making an instance of the scraper
# every query below goes through this one object
scraper = psaw.PushshiftAPI()

# querying reddit from r/wholesomememes and pulling info about the image url, gildings, number of comments, and number of crossposts
# `filter` lists the fields we want back for each post
# presumably sort_type='score' and limit=10000 ask for up to 10000 posts ordered by score -- check the psaw documentation to confirm
results = scraper.search_submissions(subreddit='wholesomememes', 
                                     filter=['url', 'gildings', 'num_comments', 'num_crossposts'],
                                     sort_type='score', limit=10000)

# the results object is a generator, so we need to pull the actual data values from the generator object
# for more on generators: https://wiki.python.org/moin/Generators
# we collect the `d_` attribute of every result (presumably a dictionary of the requested fields) into a list,
#   then hand that list to pandas to build the dataframe
memes = pd.DataFrame([item.d_ for item in results])

In [None]:
# filtering only image memes, then fetching the images
# regex=False makes pandas look for the text exactly as written (otherwise each '.' would match any character)
# na=False treats posts with a missing url as "not an image" instead of raising an error
# .copy() gives us an independent dataframe, so adding new columns to it later doesn't trigger a pandas warning
memes = memes[memes['url'].str.contains('https://i.redd.it/', regex=False, na=False)].copy()

def get_image(url):
  '''
  get the image from the specified url
  referenced from: https://datascience.stackexchange.com/questions/58351/how-to-retrieve-images-from-a-url-in-a-pandas-dataframe-and-store-them-as-pil-ob
  @params url -- the url of the image to get
  @return a PIL image object if the image could be downloaded and opened, else None
  '''
  try:
    # the timeout (in seconds) keeps one unresponsive server from hanging the whole scrape
    request_result = requests.get(url, timeout=10)
    request_result.raise_for_status()
    image = Image.open(BytesIO(request_result.content))
  except (requests.exceptions.RequestException, OSError):
    # RequestException covers connection problems, timeouts, invalid urls, and bad HTTP statuses
    # OSError covers downloads that PIL cannot open as an image
    # either way there is no usable image, so we report that with None instead of stopping the scrape
    return None

  return image

# apply is a pandas method that calls a function on every value of a column (here, on every url)
# it is a concise replacement for writing the loop over the rows by hand
# note: get_image still runs once per url and each call downloads an image, so this line can take a while
memes['image'] = memes['url'].apply(get_image)

In [None]:
# we'll be using the gildings as labels for the data
# gildings are reddit awards that cost currency, so we know it's a good metric of the quality
# for anything that does have gilding information, convert the dictionary to the total number of awards; we'll use that as the training data
# for anything that doesn't have gilding information, we'll use that as the testing data

def get_num_gildings(gildings):
  '''
  convert the gilding information of one post into a single number
  @param gildings: a dictionary mapping each award to how many times it was given,
                   or a placeholder such as NaN/None when the post has no gilding information
  @return the total number of awards, or -1 if there is no gilding information
  '''
  # missing gilding information shows up as NaN (a float) or None rather than a dictionary,
  #   so anything that is not a dictionary is treated as "no information"
  if not isinstance(gildings, dict):
    return -1

  # this is a pythonic way of writing the following code:
  # result = 0
  # for v in gildings.values():
  #   result += v
  # return result
  # where gildings.values() returns the value of each pair in the dictionary

  return sum(gildings.values())

# run get_num_gildings over every entry of the gildings column to get one numeric label per meme
memes['label'] = memes['gildings'].apply(get_num_gildings)

# carve the dataframe into the two datasets
# training set: rows whose label is 0 or more, meaning we know how many gildings they received
# testing set: rows whose label is exactly -1, the value get_num_gildings returns when there was no gilding information
train = memes.loc[memes['label'].ge(0)]
test = memes.loc[memes['label'].eq(-1)]

### Let's learn about the data!

We'll be using sklearn's KNeighborsClassifier to perform KNN classification.

If a cell or a line of code has `TODO` in it, make sure to substitute it with the correct code! We provided some links to documentation to see what you should fill in.

In [None]:
# importing the KNN classifier from sklearn, one of the most popular machine learning packages in python
# for documentation: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt  # lets us plot data

In [None]:
# creating a classifier with k=5
# the k you choose can affect your results! feel free to change this and try it out
classifier = # TODO: make a classifier with k=5! 
# check https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html for more info on how
# you don't need to set any of the other parameters

In [None]:
# separate each dataset into its data (features) and its labels
# by convention the features are named X and the labels are named y

# features: how many comments and how many crossposts each meme received (a numerical stand-in for its popularity)
#   we select those two columns and turn them into a numpy array for the classifier, filling missing values with 0
# labels: the label column, converted to a numpy array as well
X_train = train.loc[:, ['num_comments', 'num_crossposts']].to_numpy(na_value=0)
y_train = train.loc[:, 'label'].to_numpy()

# the testing set has no labels, so it only gets a feature array
X_test = test.loc[:, ['num_comments', 'num_crossposts']].to_numpy(na_value=0)

In [None]:
# let's take a look at what the data looks like in space
# each point is one training meme: x is its first feature (number of comments), y is its second feature (number of crossposts)
# c=y_train colors each point by its label, that is, by its total number of gildings
plt.scatter(X_train[:,0], X_train[:,1], c=y_train)
# the colorbar is the legend that maps each color back to a number of gildings
plt.colorbar()

plt.title('r/wholesomememes popularity space')
plt.xlabel('number of comments')
plt.ylabel('number of crossposts')

# when you take a look at the plot, pick any blank space on the graph and figure out what label a KNN would give that point for k=1,3,5
# also note that most posts get 0 crossposts -- this is pretty common on reddit
# number of comments will probably be the most informative metric

In [None]:
# fit the training data (that is, train the KNN classifier on the training data)
# TODO: write code to fit the training data! https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier.fit




In [None]:
# predicting the labels (number of gildings) for the test data
# TODO: write code to predict on the test data! https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier.predict
y_pred = 

In [None]:
# we don't have real labels for the images that didn't get gilded because they didn't actually get gilded on reddit
# but if we take a look at the meme itself, we can evaluate for ourselves whether we think the meme is good

def show_image(index, source='test'):
  '''
  outputs the source image given the index of the image from the array
  @param index: index within the array of the sample to show
  @param source: whether the image is a test or train image. defaults to test
  @return a PIL image object of the meme, or a message string if there is no image for that index
  '''
  # depending on where the image comes from, we look in a different dataframe
  data = test if source == 'test' else train

  # the feature arrays were built from these dataframes without reordering the rows,
  #   so row number `index` of the dataframe is the same meme as element `index` of the arrays
  # looking the row up by position (instead of searching for a row with matching feature values)
  #   avoids returning a different meme that happens to have the same number of comments and crossposts
  try:
    image = data['image'].iloc[index]
  except IndexError:
    # the index is past the end of the dataset
    image = None

  # the image is also None when it could not be downloaded for this meme
  if image is None:
    return 'Couldn\'t find this image, try another index'

  return image

In [None]:
# run this cell and the next cell to see the results for a meme of your choice
# index is the position of a meme in the testing set, counting from 0
index = 3 # TODO: change this to see different images and scores!

# source is left at its default, so this shows a meme from the testing set
show_image(index)

In [None]:
# report how many gildings the classifier predicted for the meme at the chosen index
print(f'The classifier thought that this meme would have gotten {y_pred[index]} gildings')