In [1]:
import bz2
import datetime
import itertools
import lzma as xz
import logging
import json
import os
import re
import requests

from collections import Counter
from dateutil.relativedelta import relativedelta
from functools import partial
from multiprocessing import Pool
from multiprocessing.dummy import Pool as dPool

from tqdm import tqdm_notebook as tqdm

In [2]:
# Root directory for the Reddit dumps. A REDDIT_DATA variable that is already set
# in the environment takes precedence over this machine-specific fallback path.
os.environ.setdefault('REDDIT_DATA', "/media/brian/ColdStore/Datasets/nlp/reddit")

In [3]:
# Remote location of the monthly comment dumps.
_REDDIT_COMMENT_BASE_URL = "https://files.pushshift.io/reddit/comments/"

# Dumps before _XZ_START_DATE are bz2-compressed; from that month on they are xz.
_BZ2_FILENAME_TEMPLATE = "RC_%Y-%m.bz2"
_XZ_FILENAME_TEMPLATE = "RC_%Y-%m.xz"
_DATA_START_DATE = datetime.date(year=2005, month=12, day=1)  # earliest valid month
_XZ_START_DATE = datetime.date(year=2017, month=12, day=1)    # first xz-compressed month

# Local storage: $REDDIT_DATA when set and non-empty, otherwise ~/reddit.
DEFAULT_REDDIT_DATA = os.environ.get('REDDIT_DATA') or os.path.expanduser("~/reddit")
DEFAULT_REDDIT_COMMENTS_DATA = os.path.join(DEFAULT_REDDIT_DATA, "comments")


def populate_reddit_comments_json(dest=DEFAULT_REDDIT_COMMENTS_DATA, processes=4):
    """Download every monthly Reddit comment dump from the first month to last month.

    Months run from ``_DATA_START_DATE`` through the month one month before
    today, the same upper bound ``_validate_reddit_comments_date`` enforces.

    Args:
        dest: directory to store the dumps in; created if it does not exist.
        processes: number of concurrent download threads.

    Returns:
        List of the months (``datetime.date`` on the 1st) whose download
        reported failure, in chronological order; empty if all succeeded.
    """
    # `download` only treats `dest` as a directory if it already exists (or
    # ends with '/'); otherwise the first dump would be written to a file
    # literally named `dest` and every later month skipped as "already there".
    os.makedirs(dest, exist_ok=True)

    end_date = datetime.date.today() + relativedelta(months=-1)
    dates = []
    curr_date = _DATA_START_DATE
    while curr_date <= end_date:
        dates.append(curr_date)
        curr_date += relativedelta(months=1)

    download_fn = partial(_download_reddit_comments_json, dest=dest)
    # Using too many processes causes "ERROR 429: Too Many Requests."
    results = list(multiproc_imap(download_fn,
                                  dates,
                                  processes=processes,
                                  thread_only=True,
                                  total=len(dates)))
    # imap preserves input order, so results line up with `dates`.
    return [date for date, ok in zip(dates, results) if not ok]


def download_reddit_comments_json(year, month, dest=DEFAULT_REDDIT_COMMENTS_DATA):
    """Download the Reddit comment dump for `year`-`month` into directory `dest`.

    Args:
        year: calendar year of the dump.
        month: month number (1-12).
        dest: directory to save the dump in; created if it does not exist.

    Returns:
        False if the month is outside the available range or the server
        answered with a non-OK status; otherwise the result of `download`
        (True when the file is already present or was downloaded).
    """
    url = get_reddit_comments_url(year, month)
    if not url:
        logging.warning(datetime.date(year, month, 1).strftime("No data exists for %Y-%m."))
        return False
    # `download` treats a non-existent `dest` without a trailing '/' as a file
    # path. Make sure it is an existing directory so the dump lands inside it.
    os.makedirs(dest, exist_ok=True)
    return download(url, dest=dest)


def load_reddit_comments_json(year, month, root=DEFAULT_REDDIT_COMMENTS_DATA):
    """Lazily yield the comments of the local `year`-`month` dump as parsed JSON.

    Args:
        year: calendar year of the dump.
        month: month number (1-12).
        root: directory holding the downloaded dumps.

    Yields:
        One object per line of the compressed dump, parsed with `json.loads`.

    This is a generator, so nothing (including the out-of-range warning)
    happens until iteration starts.
    """
    # Forward `root` so dumps stored outside the default directory can be read.
    path = get_reddit_comments_local(year, month, root=root)
    if not path:
        logging.warning(datetime.date(year, month, 1).strftime("No data exists for %Y-%m."))
        return None
    assert path.endswith('.bz2') or path.endswith('.xz'), (
        "Failed to load {}. Only bz2 and xz are supported.".format(path))
    # Pick the decompressor from the extension; lines are read as bytes.
    reader = bz2.BZ2File if path.endswith('.bz2') else xz.LZMAFile
    with reader(path, 'r') as fh:
        for line in fh:
            yield json.loads(line.decode())


def _download_reddit_comments_json(date, dest=DEFAULT_REDDIT_COMMENTS_DATA):
    """Single-argument adapter (for pool mapping): download the dump for `date`'s month."""
    year, month = date.year, date.month
    return download_reddit_comments_json(year, month, dest=dest)


def get_reddit_comments_url(year, month):
    """Return the remote URL of the `year`-`month` dump, or None if that month is out of range."""
    return _get_reddit_comments_path(datetime.date(year, month, 1), _REDDIT_COMMENT_BASE_URL)
    

def get_reddit_comments_local(year, month, root=DEFAULT_REDDIT_COMMENTS_DATA):
    """Return the local path of the `year`-`month` dump under `root`, or None if out of range."""
    return _get_reddit_comments_path(datetime.date(year, month, 1), root=root)


def download(url, dest='/tmp/'):
    """Download `url` to `dest`, skipping the request if the target file exists.

    Args:
        url: URL to fetch; its basename is the filename used when `dest` is a
            directory.
        dest: if it ends with a path separator or is an existing directory,
            the file is saved inside it (creating it if needed). Otherwise it
            is the output file path, and its parent directory is created if
            needed. An empty string means the current directory.

    Returns:
        True if the target file already existed or was downloaded completely;
        False if the server answered with a non-OK status. Exceptions raised by
        `requests` propagate.

    The body is streamed to "<target>.part" and only renamed to the target once
    complete, so an interrupted download never leaves a truncated file that a
    later call would mistake for a finished one. A leftover ".part" file is
    overwritten on the next attempt.
    """
    filename = os.path.basename(url)
    if not dest:
        dest = os.curdir
    if dest.endswith(('/', os.sep)) or os.path.isdir(dest):
        # exist_ok avoids a check-then-create race when several threads
        # download into the same new directory concurrently.
        os.makedirs(dest, exist_ok=True)
        dest = os.path.join(dest, filename)
    else:
        parent = os.path.dirname(dest)
        if parent:
            os.makedirs(parent, exist_ok=True)

    if os.path.isfile(dest):
        logging.info("{} already exist in {}.".format(url, dest))
        return True

    logging.info("Downloading {} to {}...".format(url, dest))
    resp = requests.get(url, stream=True)
    try:
        if not resp.ok:
            logging.warning("{}: {}".format(resp.reason, url))
            return False
        total_size = int(resp.headers.get('content-length', 0))
        block_size = 2**20
        # Round up so the final partial block is counted; None = unknown size.
        total_blocks = -(-total_size // block_size) or None
        partial_path = dest + '.part'
        with open(partial_path, 'wb') as fh:
            for data in tqdm(resp.iter_content(block_size),
                             unit="MB",
                             total=total_blocks):
                fh.write(data)
        os.replace(partial_path, dest)
    finally:
        # With stream=True the connection stays open until the response is closed.
        resp.close()
    return True


def multiproc_imap(func,
                   iterable,
                   processes=None,
                   thread_only=False,
                   total=None,
                   chunksize=1):
    """Map `func` over `iterable` in a worker pool, yielding results in input order.

    Args:
        func: callable applied to each item.
        iterable: items to process.
        processes: pool size; None lets the pool choose its default.
        thread_only: use a thread pool (multiprocessing.dummy) instead of processes.
        total: expected number of items, used only for the tqdm progress bar.
        chunksize: number of items sent to a worker per task.

    Returns:
        An iterator over the results with a tqdm progress bar. Work is submitted
        immediately. The pool is closed and joined once the iterator is
        exhausted, and terminated if iteration raises or is abandoned early.
    """
    pool_fn = dPool if thread_only else Pool
    pool = pool_fn(processes=processes)
    progress = tqdm(pool.imap(func, iterable, chunksize=chunksize), total=total)
    return _drain_and_release_pool(pool, progress)


def _drain_and_release_pool(pool, progress):
    """Yield every item of `progress`, then shut `pool` down so its workers do not leak."""
    try:
        yield from progress
    except BaseException:
        # Includes GeneratorExit when the consumer stops iterating early.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        progress.close()
        pool.join()


def _get_reddit_comments_path(date, root):
    """Join `root` with the dump filename for `date`; None when `date` fails validation."""
    if _validate_reddit_comments_date(date):
        return os.path.join(root, _get_reddit_comments_filename(date))
    return None


def _get_reddit_comments_filename(date):
    """Dump filename for `date`'s month: bz2 before _XZ_START_DATE, xz from then on."""
    template = _BZ2_FILENAME_TEMPLATE if date < _XZ_START_DATE else _XZ_FILENAME_TEMPLATE
    return date.strftime(template)


def _validate_reddit_comments_date(date):
    """Return True if `date` lies between _DATA_START_DATE and one month before today.

    Logs a warning with the allowed range and returns False otherwise.
    """
    earliest = _DATA_START_DATE
    latest = datetime.date.today() + relativedelta(months=-1)
    if earliest <= date <= latest:
        return True
    logging.warning("date must be between {} and {}: given {}".format(
        earliest.strftime("%Y-%m"),
        latest.strftime("%Y-%m"),
        date.strftime("%Y-%m")))
    return False

# Download Reddit Comments

The monthly dumps from 2005-12 through 2018-09 total about 450 GB, and the cell below also downloads any later months. Make sure you have enough disk space before running it.

In [5]:
# Fetch every monthly dump into the default comments directory (hundreds of GB).
populate_reddit_comments_json(dest=DEFAULT_REDDIT_COMMENTS_DATA)




# Generate N-grams

In [None]:
# Default maximum token length: n-grams containing a token longer than this
# many characters are excluded by the n-gram extractors below.
DEFAULT_TOKEN_MAX_CHARS = 25


def extract_reddit_comments_upto_ngram_strs(year, month, n):
    """Extract 1- to n-gram simultaneously because file load is the bottleneck.

    Args:
        year: calendar year of the local dump.
        month: month number (1-12) of the local dump.
        n: largest n-gram order to extract.

    Yields:
        For each comment, a list of length `n` whose element ``m - 1`` holds
        the m-gram strings of the comment body, for m = 1..n.
    """
    jsons = load_reddit_comments_json(year, month)
    texts = map(lambda d: d['body'], jsons)
    for text in texts:
        # Orders start at 1: range(n) would produce a degenerate 0-gram entry
        # and leave out the n-grams themselves.
        yield [extract_filtered_ngram_strs(text, m) for m in range(1, n + 1)]


def extract_reddit_comments_ngram_strs(year, month, n):
    """Lazily map each comment body of the `year`-`month` dump to its n-gram strings."""
    comments = load_reddit_comments_json(year, month)
    bodies = (comment['body'] for comment in comments)
    return map(lambda body: extract_filtered_ngram_strs(body, n), bodies)


def extract_filtered_ngram_strs(text, n, tok_max_chars=DEFAULT_TOKEN_MAX_CHARS):
    """Return the space-joined n-grams of `text`, skipping n-grams with a long token.

    Regex implementation producing the same strings as
    ``extract_filtered_ngram_strs_slow``: tokens are the pieces produced by
    ``str.split()``, and an n-gram is dropped if any of its tokens is longer
    than `tok_max_chars` characters.

    Args:
        text: input string; any run of whitespace separates tokens.
        n: n-gram order; values below 1 yield no n-grams.
        tok_max_chars: maximum allowed token length (must be >= 1).

    Returns:
        List of n-gram strings such as ``"the cat sat"``, ordered by the
        position of their first token.
    """
    if n < 1:
        return []
    # Collapse whitespace runs to single spaces and strip both ends, so the
    # tokens are exactly the space-separated pieces (same as str.split()).
    text_cleaned = ' '.join(text.split())
    token_match_str = "[^ ]{1,%d}" % tok_max_chars
    ngram_match_str = ' '.join([token_match_str] * n)
    # A match must start at the beginning of the text or just after a space,
    # and the n-gram must end at a space or the end of the text, so an
    # over-long first or last token excludes the n-gram instead of being
    # truncated. The zero-width lookahead lets consecutive n-grams overlap.
    pattern = r"(?:^|(?<= ))(?=(%s)(?![^ ]))" % ngram_match_str
    return re.findall(pattern, text_cleaned)


def extract_filtered_ngram_strs_slow(text, n, tok_max_chars=DEFAULT_TOKEN_MAX_CHARS):
    """Tokenize-and-filter reference implementation of n-gram extraction.

    Returns a lazy iterator of space-joined n-grams of `text`, omitting any
    n-gram that contains a token longer than `tok_max_chars` characters.
    """
    kept = (gram for gram in extract_ngrams(text, n)
            if not has_long_token(gram, tok_max_chars=tok_max_chars))
    return map(' '.join, kept)


def extract_ngrams(text, n):
    """Return an iterator of consecutive n-token tuples over the whitespace-split `text`."""
    words = text.split()
    shifted = (words[offset:] for offset in range(n))
    return zip(*shifted)


def has_long_token(tokens, tok_max_chars=DEFAULT_TOKEN_MAX_CHARS):
    """Return True if any token in `tokens` has more than `tok_max_chars` characters."""
    return any(len(token) > tok_max_chars for token in tokens)

**Benchmark n-gram extraction: regex vs. split-and-filter tokenization**

In [None]:
# Synthetic benchmark input: 1000 copies of a short token, each followed by a space.
test_string = "".join(["asdf "] * 1000)

In [None]:
%%timeit
# Regex-based extractor on the synthetic input; wrapped in list() like the slow variant below.
list(extract_filtered_ngram_strs(test_string, 5))

In [None]:
%%timeit
# Split-and-filter extractor; list() forces its lazily evaluated map result.
list(extract_filtered_ngram_strs_slow(test_string, 5))

**Benchmark file loading alone vs. file loading + n-gram extraction** (checks that loading, not n-gram extraction, is the bottleneck)

In [None]:
%%timeit -r 1 -n 1
# Baseline (single run, single loop): only decompress and parse the 2006-12 dump.
_ = list(load_reddit_comments_json(2006, 12))

In [None]:
%%timeit -r 1 -n 1
# Same load plus trigram extraction; the difference from the baseline is the n-gram cost.
_ = list(extract_reddit_comments_ngram_strs(2006, 12, 3))