better handling of duplicate posts (i.e. edits)

Robert Meyer committed Feb 13, 2018
1 parent d17128b commit 0c67507
Showing 8 changed files with 109 additions and 44 deletions.
14 changes: 7 additions & 7 deletions integration_tests/bchain/getdata_test.py
@@ -63,11 +63,11 @@ def test_get_all_posts_between(steem):
 def test_scrape_date(steem, temp_dir):
     yesterday = (pd.datetime.utcnow() - pd.Timedelta(days=1)).date()
 
-    p1 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25)
+    p1 = tpbg.load_or_scrape_full_day(yesterday, steem, temp_dir, stop_after=25)
 
     assert len(os.listdir(temp_dir)) == 1
 
-    p2 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25)
+    p2 = tpbg.load_or_scrape_full_day(yesterday, steem, temp_dir, stop_after=25)
 
     assert len(os.listdir(temp_dir)) == 1
 
@@ -77,11 +77,11 @@ def test_scrape_date(steem, temp_dir):
 
 def test_scrape_or_load_data_parallel(temp_dir, steem_kwargs):
 
-    frame = tpbg.scrape_or_load_training_data(steem_kwargs,
-                                              temp_dir,
-                                              days=3,
-                                              stop_after=10,
-                                              ncores=5)
+    frame = tpbg.load_or_scrape_training_data(steem_kwargs,
+                                              temp_dir,
+                                              days=3,
+                                              stop_after=10,
+                                              ncores=5)
     assert len(frame) >= 30
 
 
25 changes: 25 additions & 0 deletions integration_tests/preprocessing_test.py
@@ -0,0 +1,25 @@
+import os
+
+from pandas.testing import assert_frame_equal
+import pandas as pd
+
+from integration_tests.model_test import temp_dir
+from tests.fixtures.random_data import create_n_random_posts
+import trufflepig.preprocessing as tppp
+
+
+def test_load_or_preproc(temp_dir):
+    filename = os.path.join(temp_dir, 'pptest.gz')
+
+    post_frame = pd.DataFrame(create_n_random_posts(10))
+
+    frame = tppp.load_or_preprocess(post_frame, filename,
+                                    ncores=5, chunksize=20)
+
+    assert len(os.listdir(temp_dir)) == 1
+
+    frame2 = tppp.load_or_preprocess(post_frame, filename,
+                                     ncores=5, chunksize=20)
+
+    assert len(os.listdir(temp_dir)) == 1
+    assert_frame_equal(frame, frame2)
2 changes: 1 addition & 1 deletion scripts/create_raw_data_fixture.py
@@ -10,7 +10,7 @@ def main():
     directory = os.path.join(config.PROJECT_DIRECTORY, 'scraped_data')
 
 
-    frames = tpbg.scrape_or_load_training_data(dict(nodes=[config.NODE_URL]),
+    frames = tpbg.load_or_scrape_training_data(dict(nodes=[config.NODE_URL]),
                                                directory,
                                                days=20,
                                                stop_after=100,
8 changes: 6 additions & 2 deletions scripts/do_cross_val.py
@@ -15,7 +15,10 @@ def main():
 
     steem = dict(nodes=[config.NODE_URL])
     current_datetime = '2018-02-01'
-    post_frame = tpgd.scrape_or_load_training_data(steem, directory,
+
+    crossval_filename = os.path.join(directory, 'xval_{}.gz'.format(current_datetime))
+
+    post_frame = tpgd.load_or_scrape_training_data(steem, directory,
                                                    current_datetime=current_datetime,
                                                    days=3,
                                                    offset_days=0)
@@ -26,7 +29,8 @@
 
     topic_kwargs = dict(num_topics=50, no_below=5, no_above=0.7)
 
-    post_frame = tppp.preprocess(post_frame, ncores=4, chunksize=20)
+    post_frame = tppp.load_or_preprocess(post_frame, crossval_filename,
+                                         ncores=4, chunksize=20)
 
     param_grid = {
         'feature_generation__topic_model__no_above':[0.2, 0.3],
3 changes: 2 additions & 1 deletion tests/filters/stylemeasures_test.py
@@ -7,7 +7,8 @@ def test_count_paragraphs():
 
 
 def test_detect_language():
-    result = tpsm.detect_language('die katze ist klein der hund auch')
+    detector = tpsm.LanguageDetector()
+    result = detector.detect_language('die katze ist klein der hund auch')
     assert result == 'de'
 
 
29 changes: 20 additions & 9 deletions trufflepig/bchain/getdata.py
@@ -171,7 +171,7 @@ def extract_authors_and_permalinks(operations):
             author = op[1]['author']
             permalink = op[1]['permlink']
             authors_and_permalinks.add((author, permalink))
-    return list(authors_and_permalinks)
+    return authors_and_permalinks
 
 
 def get_post_data(authors_and_permalinks, steem):
@@ -203,17 +203,20 @@ def get_post_data(authors_and_permalinks, steem):
     return posts
 
 
-def get_all_posts_from_block(block_num, steem):
+def get_all_posts_from_block(block_num, steem,
+                             exclude_authors_and_permalinks=None):
     operations = steem.get_ops_in_block(block_num, False)
     if operations:
         authors_and_permalinks = extract_authors_and_permalinks(operations)
+        if exclude_authors_and_permalinks:
+            authors_and_permalinks -= exclude_authors_and_permalinks
         if authors_and_permalinks:
-            return get_post_data(authors_and_permalinks, steem)
+            return get_post_data(authors_and_permalinks, steem), authors_and_permalinks
         else:
             logger.debug('Could not find any posts for block {}'.format(block_num))
     else:
         logger.warning('Could not find any operations for block {}'.format(block_num))
-    return []
+    return [], set()
 
 
 def get_all_posts_between(start_datetime, end_datetime, steem,
@@ -228,8 +231,12 @@
                                                        start_num,
                                                        block_end_datetime,
                                                        end_num))
+    exclude_authors_and_permalinks = set()
     for idx, block_num in enumerate(range(start_num, end_num+1)):
-        posts_in_block = get_all_posts_from_block(block_num, steem)
+        posts_in_block, authors_and_permalinks = get_all_posts_from_block(block_num,
+                                                                          steem,
+                                                                          exclude_authors_and_permalinks)
+        exclude_authors_and_permalinks |= authors_and_permalinks
         posts.extend(posts_in_block)
         if progressbar(idx, total, percentage_step=1, logger=logger):
             logger.info('Finished block {} '
@@ -246,8 +253,12 @@ def _get_all_posts_for_blocks_parallel(block_nums, steem_args,
                                        stop_after=None):
     steem = check_and_convert_steem(steem_args)
     posts = []
+    exclude_authors_and_permalinks = set()
     for block_num in block_nums:
-        posts_in_block = get_all_posts_from_block(block_num, steem)
+        posts_in_block, authors_and_permalinks = get_all_posts_from_block(block_num,
+                                                                          steem,
+                                                                          exclude_authors_and_permalinks)
+        exclude_authors_and_permalinks |= authors_and_permalinks
         posts.extend(posts_in_block)
         if stop_after is not None and len(posts) >= stop_after:
             break
@@ -300,7 +311,7 @@ def get_all_posts_between_parallel(start_datetime, end_datetime, steem_args,
     return posts
 
 
-def scrape_or_load_full_day(date, steem_or_args, directory, overwrite=False,
+def load_or_scrape_full_day(date, steem_or_args, directory, overwrite=False,
                             store=True, stop_after=None, ncores=1):
     start_datetime = pd.to_datetime(date)
     end_datetime = start_datetime + pd.Timedelta(days=1)
@@ -335,7 +346,7 @@ def config_mp_logging(level=logging.INFO):
     logging.basicConfig(level=level)
 
 
-def scrape_or_load_training_data(steem_or_args, directory,
+def load_or_scrape_training_data(steem_or_args, directory,
                                  days=20, offset_days=8,
                                  ncores=8,
                                  current_datetime=None,
@@ -351,7 +362,7 @@ def scrape_or_load_training_data(steem_or_args, directory,
     frames = []
     for day in range(days):
         next_date = (start_datetime + pd.Timedelta(days=day)).date()
-        frame = scrape_or_load_full_day(next_date, steem_or_args,
+        frame = load_or_scrape_full_day(next_date, steem_or_args,
                                         directory, overwrite=False,
                                         store=True, stop_after=stop_after,
                                         ncores=ncores)
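The change above threads a growing exclude_authors_and_permalinks set through the block loop, so a post that reappears in a later block (because its author edited it) is fetched only once. A minimal self-contained sketch of that pattern, with made-up operation tuples standing in for steem.get_ops_in_block output:

# Sketch of the deduplication pattern above (stub data, not trufflepig code).

def extract_authors_and_permalinks(operations):
    """Collect the (author, permlink) pairs found in a block's operations."""
    pairs = set()
    for op in operations:
        if op[0] == 'comment':
            pairs.add((op[1]['author'], op[1]['permlink']))
    return pairs


def collect_posts(blocks):
    """Walk blocks in order; an edit re-emits its post in a later block,
    so subtract everything already seen before processing."""
    posts = []
    exclude = set()
    for operations in blocks:
        pairs = extract_authors_and_permalinks(operations) - exclude
        exclude |= pairs
        posts.extend(sorted(pairs))  # the real code fetches full post data here
    return posts


blocks = [
    [('comment', {'author': 'alice', 'permlink': 'hello-world'})],
    [('comment', {'author': 'alice', 'permlink': 'hello-world'}),   # an edit
     ('comment', {'author': 'bob', 'permlink': 'second-post'})],
]
assert collect_posts(blocks) == [('alice', 'hello-world'), ('bob', 'second-post')]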
38 changes: 18 additions & 20 deletions trufflepig/filters/stylemeasures.py
@@ -9,26 +9,6 @@ def count_paragraphs(text):
     return text.count('\n\n') + 1
 
 
-def detect_language(text, max_length=5000):
-    """ Detexts text language, returns None in case of failure
-
-    Parameters
-    ----------
-    text: str
-    max_length: int
-        Up to max_length characters are considered for the detection
-
-    Returns
-    -------
-    str: language or None in case of failure
-
-    """
-    try:
-        return langdetect.detect(text[:max_length])
-    except Exception:
-        return None
-
-
 CAPS = "([A-Z])"
 PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]"
 SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)"
@@ -87,3 +67,21 @@ def count_mistakes(self, text):
         nerrors = len([x for x in self.checker])
         return nerrors
 
+
+class LanguageDetector(object):
+
+    def __init__(self, max_length=5000, seed=42):
+        self.max_length = max_length
+        self.factory = langdetect.DetectorFactory()
+        self.factory.set_seed(seed)
+        self.factory.load_profile(langdetect.PROFILES_DIRECTORY)
+
+    def detect_language(self, text):
+        try:
+            detector = self.factory.create()
+            if self.max_length:
+                text = text[:self.max_length]
+            detector.append(text)
+            return detector.detect()
+        except Exception as e:
+            return None
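Unlike the removed module-level helper, which relied on langdetect.detect() and its unseeded global factory, the class owns a DetectorFactory with a fixed seed, so detection results are reproducible across runs. A short usage sketch, assuming trufflepig and langdetect are importable; the printed outputs are expectations, not verified here:

# Usage sketch for the new LanguageDetector (assumed outputs in comments).
import trufflepig.filters.stylemeasures as tpsm

detector = tpsm.LanguageDetector(max_length=5000, seed=42)

print(detector.detect_language('die katze ist klein der hund auch'))  # 'de'
print(detector.detect_language('the cat is small, so is the dog'))    # 'en'
print(detector.detect_language(''))  # None: the exception is swallowed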
34 changes: 30 additions & 4 deletions trufflepig/preprocessing.py
@@ -1,14 +1,18 @@
 import logging
+import os
 import multiprocessing as mp
 
 import pandas as pd
 
 import trufflepig.filters.stylemeasures as tfsm
 import trufflepig.filters.textfilters as tftf
 
 logger = logging.getLogger(__name__)
 
 
 def filter_duplicates(frame):
-    filtered = frame.drop_duplicates(subset=['author', 'permalink'])
+    filtered = frame.drop_duplicates(subset=['author', 'permalink'],
+                                     keep='last')
     if len(filtered) < len(frame):
         logger.info('Filtered {} duplicates'.format(len(frame) - len(filtered)))
     return filtered
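keep='last' is the fix named in the commit title: an edited post appears once per edit in the scraped frame, and the last occurrence is the newest version, so that is the row worth keeping. A standalone illustration of the pandas behavior with made-up rows:

# Made-up example: alice edits post-1, so it appears twice in the frame.
import pandas as pd

frame = pd.DataFrame({
    'author':    ['alice',  'bob',    'alice'],
    'permalink': ['post-1', 'post-2', 'post-1'],
    'body':      ['v1',     'hello',  'v2'],
})

deduped = frame.drop_duplicates(subset=['author', 'permalink'], keep='last')
print(deduped['body'].tolist())  # ['hello', 'v2'] -- the edited body survives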
@@ -29,7 +33,8 @@ def apply_parallel(function, iterable, ncores, chunksize=1000):
     return results
 
 
-def preprocess(post_df, ncores=8, chunksize=100):
+def preprocess(post_df, ncores=8, chunksize=1000,
+               detect_seed=42, detect_max_length=3000):
     logger.info('Filtering duplicates')
     post_df = filter_duplicates(post_df)
 
@@ -65,7 +70,9 @@ def preprocess(post_df, ncores=8, chunksize=100):
                                                       len(post_df)))
 
     logger.info('Detecting language')
-    large_post_df.loc[:, 'language'] = apply_parallel(tfsm.detect_language,
+    detector = tfsm.LanguageDetector(seed=detect_seed,
+                                     max_length=detect_max_length)
+    large_post_df.loc[:, 'language'] = apply_parallel(detector.detect_language,
                                                       large_post_df.filtered_body,
                                                       ncores=ncores,
                                                       chunksize=chunksize)
@@ -129,4 +136,23 @@ def preprocess(post_df, ncores=8, chunksize=100):
     final_df = en_df.dropna()
     logger.info('Final data set has {} shape'.format(final_df.shape))
 
-    return final_df
+    return final_df
+
+
+def load_or_preprocess(post_frame, filename, *args, overwrite=False, store=True,
+                       **kwargs):
+    if os.path.isfile(filename) and not overwrite:
+        logger.info('Found file {} will load it'.format(filename))
+        post_frame = pd.read_pickle(filename, compression='gzip')
+    else:
+        logger.info('File {} not found, will start prepocessing'.format(filename))
+        post_frame = preprocess(post_frame, *args, **kwargs)
+        if store:
+            directory = os.path.dirname(filename)
+            if not os.path.isdir(directory):
+                os.makedirs(directory)
+            logger.info('Storing file {} to disk'.format(filename))
+            post_frame.to_pickle(filename, compression='gzip')
+    return post_frame
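load_or_preprocess follows the same load-or-compute idiom as the renamed load_or_scrape helpers in getdata.py: compute once, pickle gzipped to disk, and just reload on later calls. A generic self-contained sketch of that idiom (the helper name and cache path are illustrative, not trufflepig's API):

# Generic load-or-compute caching sketch (illustrative, not trufflepig code).
import os

import pandas as pd


def load_or_compute(frame, filename, compute, overwrite=False, store=True):
    # Reload a cached result if present, otherwise compute and store it.
    if os.path.isfile(filename) and not overwrite:
        return pd.read_pickle(filename, compression='gzip')
    frame = compute(frame)
    if store:
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
        frame.to_pickle(filename, compression='gzip')
    return frame


df = pd.DataFrame({'x': [1, 2, 3]})
first = load_or_compute(df, 'cache/example.gz', lambda f: f * 2)   # computes and stores
second = load_or_compute(df, 'cache/example.gz', lambda f: f * 2)  # loads from disk
assert first.equals(second)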

