Author: Asher Farr

CSCI 4622


Sources:

Captioning Transformer with Stacked Attention Modules: https://www.mdpi.com/2076-3417/8/5/739

Show and Tell: Lessons learned from the 2015 MSCOCO Image Captioning Challenge: https://arxiv.org/abs/1609.06647

https://www.tensorflow.org/tutorials/text/image_captioning

In [None]:
import numpy as np
import pandas as pd
import praw
import tensorflow as tf
import os
import json
from psaw import PushshiftAPI
from profanity_check import predict as profanity_predict
import requests
import datetime as dt
import wget
import multiprocessing

In [None]:
class fetchRedditData():
    """Scrape r/RoastMe submissions and their top comments via Pushshift + PRAW.

    You shouldn't have to run this unless you need to get new information
    from Reddit; normally the JSON files it produces
    (``./reddit/image_list.json`` and ``./reddit/comment_list.json``) should
    be used as provided.
    """

    def __init__(self, client_id, client_secret, user_agent):
        """Create a PRAW client and wrap it in a Pushshift search API.

        Args:
            client_id, client_secret, user_agent: Reddit API credentials,
                passed straight through to ``praw.Reddit``.
        """
        reddit = praw.Reddit(client_id=client_id,
                             client_secret=client_secret,
                             user_agent=user_agent)
        self.api = PushshiftAPI(reddit)

    def pullSubmissions(self):
        """Return the matching r/RoastMe submissions as a list.

        Only non-NSFW, non-video posts with more than 20 comments are
        requested, and only the url/title/subreddit/id/selftext fields.
        Warning: this takes about 15 minutes.
        """
        return list(self.api.search_submissions(
            subreddit='roastme',
            filter=['url', 'title', 'subreddit', "id", "selftext"],
            num_comments='>20',
            over_18='false',
            is_video='false',
        ))

    def filterSubmissions(self, roastData):
        """Keep submissions with a non-profane title and empty self-text.

        Intended to get rid of deleted posts and profane posts. Note that the
        self-text check also drops any post that has body text.

        Args:
            roastData: iterable of submissions (e.g. from ``pullSubmissions``).

        Returns:
            A new list containing the submissions that pass both checks.
        """
        roastData = list(roastData)
        if not roastData:
            # The profanity model cannot predict on zero samples.
            return []
        # One batched model call instead of one call per title.
        profaneFlags = profanity_predict([sub.title for sub in roastData])
        return [sub for sub, profane in zip(roastData, profaneFlags)
                if not profane and sub.selftext == ""]

    def fetchComments(self, filteredRoastData, comLimit=5, banndAuthors=('roastbot', 'AutoModerator'), printStatus = True):
        """Collect image metadata and up to ``comLimit`` top comments per post.

        For every submission, comments are read in "top" order (at most 15
        are loaded). The loop skips MoreComments placeholders and comments by
        ``banndAuthors``, and stops at the first comment with score < 1.
        Warning: this takes at least 4 or 5 hours to run.

        Args:
            filteredRoastData: iterable of submissions.
            comLimit: maximum number of comments kept per submission.
            banndAuthors: authors whose comments are skipped. The misspelled
                name is kept so existing keyword callers keep working.
            printStatus: print a progress counter every 100 submissions.

        Returns:
            ``(images, comments)``. ``images`` holds one
            ``{"id", "url", "title"}`` dict per submission. ``comments``
            holds ``{"id", "title", "body"}`` dicts.
        """
        comments = []
        images = []

        for subIndex, sub in enumerate(filteredRoastData):
            if printStatus and subIndex % 100 == 0:
                print(subIndex)
            images.append({"id": sub.id, "url": sub.url, "title": sub.title})

            sub.comment_sort = "top"
            sub.comment_limit = 15
            comCount = 0

            for c in sub.comments:
                if isinstance(c, praw.models.MoreComments):
                    print("Hit More Comments")
                    continue
                if c.author in banndAuthors:
                    continue
                # Comments are in "top" order, so nothing after this scores higher.
                if c.score < 1:
                    print("low Score")
                    break

                comments.append({"id": sub.id, "title": sub.title, "body": c.body})
                comCount += 1
                if comCount >= comLimit:
                    break
            if comCount < comLimit:
                print("Less than limit")

        return (images, comments)

    def saveJson(self, images, comments):
        """Write ``images`` and ``comments`` as JSON under ``./reddit/``.

        The folder is created if needed. The files are ``image_list.json``
        and ``comment_list.json``, and existing files are overwritten.
        """
        redditFolder = os.path.join(os.path.abspath('.'), 'reddit')
        os.makedirs(redditFolder, exist_ok=True)

        with open(os.path.join(redditFolder, "image_list.json"), "w") as write_file:
            json.dump(images, write_file)
        with open(os.path.join(redditFolder, "comment_list.json"), "w") as write_file:
            json.dump(comments, write_file)

    def runAll(self):
        """Run the full pipeline: pull, filter, fetch comments, save JSON."""
        print("Gathering submissions")
        roastData = self.pullSubmissions()
        print(f"Gathered {len(roastData)} submissions")
        print("Filtering Submissions")
        filteredRoastData = self.filterSubmissions(roastData)
        print(f"Submissions after filtering: {len(filteredRoastData)}")
        print("Fetching Comments, hold on to your hat")
        images, comments = self.fetchComments(filteredRoastData)
        print("saving")
        self.saveJson(images, comments)
        

In [None]:
# Reddit API credentials come from the environment rather than being stored
# in the notebook. NOTE(review): the client secret that used to be hard-coded
# here is exposed in the notebook's history and should be revoked/rotated.
# The scrape takes hours and is normally unnecessary (the JSON files are
# provided), so it only runs when credentials are supplied.
reddit_info = {"client_id": os.environ.get("REDDIT_CLIENT_ID"),
               "client_secret": os.environ.get("REDDIT_CLIENT_SECRET"),
               "user_agent": 'script:roastbot:v0.1 (by u/the_ninja_99)'}
if reddit_info["client_id"] and reddit_info["client_secret"]:
    reddit = fetchRedditData(**reddit_info)
    reddit.runAll()
else:
    print("Skipping Reddit scrape: set REDDIT_CLIENT_ID and "
          "REDDIT_CLIENT_SECRET to fetch fresh data.")

In [None]:
def run_process(url, output_path):
    """Fetch ``url`` with ``wget`` and store it at ``output_path``.

    Used as the per-URL task submitted by ``downloadImages``.
    """
    destination = output_path
    wget.download(url, out=destination)

def downloadImages(imagePath):
    """Download every image listed in a JSON manifest into ``./images/``.

    Each manifest entry must provide ``'id'`` and ``'url'`` keys. The image is
    saved as ``<cwd>/images/<id>.jpg`` whatever format the URL serves. Files
    that already exist are skipped, so an interrupted run can be resumed.
    Download failures are reported and counted rather than discarded.

    Args:
        imagePath: path to the JSON manifest (e.g. ``reddit/image_list.json``).
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Threads rather than processes. Downloads are I/O-bound, and a process
    # pool has to pickle ``run_process``. Under the "spawn" start method
    # (default on Windows and macOS), workers cannot find a function defined
    # in a notebook.
    max_pool_size = 6
    workers = min(multiprocessing.cpu_count(), max_pool_size)

    # The parent of ./reddit, i.e. the current working directory.
    base_dir = os.path.dirname(os.path.abspath('./reddit'))
    image_dir = os.path.join(base_dir, 'images')
    os.makedirs(image_dir, exist_ok=True)

    with open(imagePath, 'r') as f:
        images = json.load(f)

    download_list = []
    for im in images:
        path = os.path.join(image_dir, im['id'] + ".jpg")
        if os.path.exists(path):
            # Already downloaded. wget does not overwrite an existing file; it
            # writes a renamed copy, which would not match the id-based name.
            continue
        download_list.append((im['url'], path))

    failures = 0
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = {pool.submit(run_process, url, path): url
                   for url, path in download_list}
        for future in as_completed(pending):
            try:
                future.result()
            except Exception as err:  # best-effort: one dead link must not abort the batch
                failures += 1
                print(f"Failed to download {pending[future]}: {err}")

    print(f"{len(download_list) - failures} downloaded, {failures} failed")
    print("finish")

# Fetch every image listed in the manifest under ./reddit/.
downloadImages(os.path.abspath(os.path.join('.', 'reddit', 'image_list.json')))

In [None]:
import imghdr
import shutil

In [None]:
def removeBadImages(imagePath, redditPath, outputPath, maxComments=5):
    """Copy readable images to ``outputPath`` and filter the Reddit JSON to match.

    A file in ``imagePath`` is kept when ``imghdr.what`` recognizes its
    format. Unrecognized files and sub-directories are skipped. Kept files
    are copied unchanged to ``outputPath/images/``. The file name up to its
    first "." is used as the submission id.

    For each kept image, its record from ``redditPath/image_list.json`` and
    at most ``maxComments`` of its comments from
    ``redditPath/comment_list.json`` are collected. The filtered lists are
    written to ``outputPath/textFiles/`` under the same two file names.

    Args:
        imagePath: directory holding the downloaded images.
        redditPath: directory holding ``image_list.json`` and
            ``comment_list.json``.
        outputPath: destination root; ``images/`` and ``textFiles/`` are
            created inside it if needed.
        maxComments: maximum number of comments kept per image.

    Returns:
        ``(finalImageJson, finalComments)``, the filtered lists that were
        written.
    """
    # final paths
    textPath = os.path.join(os.path.abspath(outputPath), 'textFiles')
    imPath = os.path.join(os.path.abspath(outputPath), 'images')
    os.makedirs(textPath, exist_ok=True)
    os.makedirs(imPath, exist_ok=True)

    with open(os.path.join(redditPath, 'image_list.json'), 'r') as f:
        images = json.load(f)
    with open(os.path.join(redditPath, 'comment_list.json'), 'r') as f:
        comments = json.load(f)

    # Index both lists by submission id once, instead of rescanning every
    # record for every image file.
    imageById = {}
    for im in images:
        imageById.setdefault(im['id'], im)  # the first record for an id wins
    commentsById = {}
    for c in comments:
        commentsById.setdefault(c['id'], []).append(c)

    finalImageJson = []
    finalComments = []
    for f in os.listdir(imagePath):
        src = os.path.join(imagePath, f)
        if not os.path.isfile(src) or imghdr.what(src) is None:
            continue
        imId = f.split('.')[0]

        shutil.copyfile(src, os.path.join(imPath, f))

        if imId in imageById:
            finalImageJson.append(imageById[imId])
        finalComments.extend(commentsById.get(imId, [])[:maxComments])

    with open(os.path.join(textPath, 'image_list.json'), 'w') as out:
        json.dump(finalImageJson, out)
    with open(os.path.join(textPath, 'comment_list.json'), 'w') as out:
        json.dump(finalComments, out)

    print("finish")
    return finalImageJson, finalComments

# Copy the images imghdr can identify from ./images/ into ./finals2/.
# NOTE(review): the resize cell below reads ./finals/images/, not
# ./finals2/ -- confirm which output directory is intended.
removeBadImages(
    os.path.abspath(os.path.join('.', 'images')),
    os.path.abspath(os.path.join('.', 'reddit')),
    os.path.abspath(os.path.join('.', 'finals2')),
)

In [None]:
from PIL import Image
def resizeImages(imagePath, savePath):
    """Resize each readable image in ``imagePath`` to 299x299 RGB JPEG.

    Output files keep their original names and are written to ``savePath``,
    which is created if needed. Files ``imghdr`` cannot identify, and
    sub-directories, are skipped. The directory and each processed file name
    are printed as progress.

    Args:
        imagePath: directory of source images.
        savePath: directory for the resized JPEGs.
    """
    print(imagePath)
    os.makedirs(savePath, exist_ok=True)
    for f in os.listdir(imagePath):
        src = os.path.join(imagePath, f)
        if not os.path.isfile(src) or imghdr.what(src) is None:
            continue
        print(f)
        # The context manager closes the source file once we are done with it.
        with Image.open(src) as im:
            # Convert before resizing. JPEG cannot store alpha/palette/16-bit
            # modes, and Pillow resizes "P"/"1" images with nearest-neighbour,
            # so converting first lets resampling interpolate real colours.
            rgb = im if im.mode == "RGB" else im.convert("RGB")
            resized = rgb.resize((299, 299))
        resized.save(os.path.join(savePath, f), format='jpeg')
        
        
# Write 299x299 JPEG copies of ./finals/images/ into ./finals/resized.
resizeImages(
    os.path.abspath(os.path.join('.', 'finals', 'images')),
    os.path.abspath(os.path.join('.', 'finals', 'resized')),
)