diff --git a/integration_tests/bchain/getdata_test.py b/integration_tests/bchain/getdata_test.py
index b3addbf..34e1c3b 100644
--- a/integration_tests/bchain/getdata_test.py
+++ b/integration_tests/bchain/getdata_test.py
@@ -63,11 +63,11 @@ def test_get_all_posts_between(steem):
 
 def test_scrape_date(steem, temp_dir):
     yesterday = (pd.datetime.utcnow() - pd.Timedelta(days=1)).date()
 
-    p1 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25)
+    p1 = tpbg.load_or_scrape_full_day(yesterday, steem, temp_dir, stop_after=25)
 
     assert len(os.listdir(temp_dir)) == 1
 
-    p2 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25)
+    p2 = tpbg.load_or_scrape_full_day(yesterday, steem, temp_dir, stop_after=25)
 
     assert len(os.listdir(temp_dir)) == 1
@@ -77,11 +77,11 @@ def test_scrape_date(steem, temp_dir):
 
 
 def test_scrape_or_load_data_parallel(temp_dir, steem_kwargs):
-    frame = tpbg.scrape_or_load_training_data(steem_kwargs,
-                                              temp_dir,
-                                              days=3,
-                                              stop_after=10,
-                                              ncores=5)
+    frame = tpbg.load_or_scrape_training_data(steem_kwargs,
+                                              temp_dir,
+                                              days=3,
+                                              stop_after=10,
+                                              ncores=5)
 
     assert len(frame) >= 30
diff --git a/integration_tests/preprocessing_test.py b/integration_tests/preprocessing_test.py
new file mode 100644
index 0000000..b017632
--- /dev/null
+++ b/integration_tests/preprocessing_test.py
@@ -0,0 +1,25 @@
+import os
+
+from pandas.testing import assert_frame_equal
+import pandas as pd
+
+from integration_tests.model_test import temp_dir
+from tests.fixtures.random_data import create_n_random_posts
+import trufflepig.preprocessing as tppp
+
+
+def test_load_or_preproc(temp_dir):
+    filename = os.path.join(temp_dir, 'pptest.gz')
+
+    post_frame = pd.DataFrame(create_n_random_posts(10))
+
+    frame = tppp.load_or_preprocess(post_frame, filename,
+                                    ncores=5, chunksize=20)
+
+    assert len(os.listdir(temp_dir)) == 1
+
+    frame2 = tppp.load_or_preprocess(post_frame, filename,
+                                     ncores=5, chunksize=20)
+
+    assert len(os.listdir(temp_dir)) == 1
+    assert_frame_equal(frame, frame2)
\ No newline at end of file
diff --git a/scripts/create_raw_data_fixture.py b/scripts/create_raw_data_fixture.py
index 4684293..7f29b0a 100644
--- a/scripts/create_raw_data_fixture.py
+++ b/scripts/create_raw_data_fixture.py
@@ -10,7 +10,7 @@ def main():
 
     directory = os.path.join(config.PROJECT_DIRECTORY, 'scraped_data')
 
-    frames = tpbg.scrape_or_load_training_data(dict(nodes=[config.NODE_URL]),
+    frames = tpbg.load_or_scrape_training_data(dict(nodes=[config.NODE_URL]),
                                                directory,
                                                days=20,
                                                stop_after=100,
diff --git a/scripts/do_cross_val.py b/scripts/do_cross_val.py
index b8c6995..545ce8d 100644
--- a/scripts/do_cross_val.py
+++ b/scripts/do_cross_val.py
@@ -15,7 +15,10 @@ def main():
     steem = dict(nodes=[config.NODE_URL])
     current_datetime = '2018-02-01'
 
-    post_frame = tpgd.scrape_or_load_training_data(steem, directory,
+
+    crossval_filename = os.path.join(directory, 'xval_{}.gz'.format(current_datetime))
+
+    post_frame = tpgd.load_or_scrape_training_data(steem, directory,
                                                    current_datetime=current_datetime,
                                                    days=3,
                                                    offset_days=0)
@@ -26,7 +29,8 @@ def main():
 
     topic_kwargs = dict(num_topics=50, no_below=5, no_above=0.7)
 
-    post_frame = tppp.preprocess(post_frame, ncores=4, chunksize=20)
+    post_frame = tppp.load_or_preprocess(post_frame, crossval_filename,
+                                         ncores=4, chunksize=20)
 
     param_grid = {
         'feature_generation__topic_model__no_above':[0.2, 0.3],
diff --git a/tests/filters/stylemeasures_test.py b/tests/filters/stylemeasures_test.py
index 5cac9fe..98234b4 100644
--- a/tests/filters/stylemeasures_test.py
+++ b/tests/filters/stylemeasures_test.py
@@ -7,7 +7,8 @@ def test_count_paragraphs():
 
 
 def test_detect_language():
-    result = tpsm.detect_language('die katze ist klein der hund auch')
+    detector = tpsm.LanguageDetector()
+    result = detector.detect_language('die katze ist klein der hund auch')
 
     assert result == 'de'
diff --git a/trufflepig/bchain/getdata.py b/trufflepig/bchain/getdata.py
index fd50af3..9d9d146 100644
--- a/trufflepig/bchain/getdata.py
+++ b/trufflepig/bchain/getdata.py
@@ -171,7 +171,7 @@ def extract_authors_and_permalinks(operations):
             author = op[1]['author']
             permalink = op[1]['permlink']
             authors_and_permalinks.add((author, permalink))
-    return list(authors_and_permalinks)
+    return authors_and_permalinks
 
 
 def get_post_data(authors_and_permalinks, steem):
@@ -203,17 +203,20 @@ def get_post_data(authors_and_permalinks, steem):
     return posts
 
 
-def get_all_posts_from_block(block_num, steem):
+def get_all_posts_from_block(block_num, steem,
+                             exclude_authors_and_permalinks=None):
     operations = steem.get_ops_in_block(block_num, False)
     if operations:
         authors_and_permalinks = extract_authors_and_permalinks(operations)
+        if exclude_authors_and_permalinks:
+            authors_and_permalinks -= exclude_authors_and_permalinks
         if authors_and_permalinks:
-            return get_post_data(authors_and_permalinks, steem)
+            return get_post_data(authors_and_permalinks, steem), authors_and_permalinks
         else:
             logger.debug('Could not find any posts for block {}'.format(block_num))
     else:
         logger.warning('Could not find any operations for block {}'.format(block_num))
-    return []
+    return [], set()
 
 
 def get_all_posts_between(start_datetime, end_datetime, steem,
@@ -228,8 +231,12 @@ def get_all_posts_between(start_datetime, end_datetime, steem,
                                                                  start_num,
                                                                  block_end_datetime,
                                                                  end_num))
+    exclude_authors_and_permalinks = set()
     for idx, block_num in enumerate(range(start_num, end_num+1)):
-        posts_in_block = get_all_posts_from_block(block_num, steem)
+        posts_in_block, authors_and_permalinks = get_all_posts_from_block(block_num,
+                                                                          steem,
+                                                                          exclude_authors_and_permalinks)
+        exclude_authors_and_permalinks |= authors_and_permalinks
         posts.extend(posts_in_block)
         if progressbar(idx, total, percentage_step=1, logger=logger):
             logger.info('Finished block {} '
@@ -246,8 +253,12 @@ def _get_all_posts_for_blocks_parallel(block_nums, steem_args,
                                        stop_after=None):
     steem = check_and_convert_steem(steem_args)
     posts = []
+    exclude_authors_and_permalinks = set()
     for block_num in block_nums:
-        posts_in_block = get_all_posts_from_block(block_num, steem)
+        posts_in_block, authors_and_permalinks = get_all_posts_from_block(block_num,
+                                                                          steem,
+                                                                          exclude_authors_and_permalinks)
+        exclude_authors_and_permalinks |= authors_and_permalinks
         posts.extend(posts_in_block)
         if stop_after is not None and len(posts) >= stop_after:
             break
@@ -300,7 +311,7 @@ def get_all_posts_between_parallel(start_datetime, end_datetime, steem_args,
     return posts
 
 
-def scrape_or_load_full_day(date, steem_or_args, directory, overwrite=False,
+def load_or_scrape_full_day(date, steem_or_args, directory, overwrite=False,
                             store=True, stop_after=None, ncores=1):
     start_datetime = pd.to_datetime(date)
     end_datetime = start_datetime + pd.Timedelta(days=1)
@@ -335,7 +346,7 @@ def config_mp_logging(level=logging.INFO):
     logging.basicConfig(level=level)
 
-def scrape_or_load_training_data(steem_or_args, directory,
+def load_or_scrape_training_data(steem_or_args, directory,
                                  days=20,
                                  offset_days=8,
                                  ncores=8,
                                  current_datetime=None,
@@ -351,7 +362,7 @@ def scrape_or_load_training_data(steem_or_args, directory,
     frames = []
     for day in range(days):
         next_date = (start_datetime + pd.Timedelta(days=day)).date()
-        frame = scrape_or_load_full_day(next_date, steem_or_args,
+        frame = load_or_scrape_full_day(next_date, steem_or_args,
                                         directory, overwrite=False,
                                         store=True, stop_after=stop_after,
                                         ncores=ncores)
diff --git a/trufflepig/filters/stylemeasures.py b/trufflepig/filters/stylemeasures.py
index fe74d69..198e5c0 100644
--- a/trufflepig/filters/stylemeasures.py
+++ b/trufflepig/filters/stylemeasures.py
@@ -9,26 +9,6 @@ def count_paragraphs(text):
     return text.count('\n\n') + 1
 
 
-def detect_language(text, max_length=5000):
-    """ Detexts text language, returns None in case of failure
-
-    Parameters
-    ----------
-    text: str
-    max_length: int
-        Up to max_length characters are considered for the detection
-
-    Returns
-    -------
-    str: language or None in case of failure
-
-    """
-    try:
-        return langdetect.detect(text[:max_length])
-    except Exception:
-        return None
-
-
 CAPS = "([A-Z])"
 PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]"
 SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)"
@@ -87,3 +67,21 @@ def count_mistakes(self, text):
         nerrors = len([x for x in self.checker])
 
         return nerrors
+
+class LanguageDetector(object):
+
+    def __init__(self, max_length=5000, seed=42):
+        self.max_length = max_length
+        self.factory = langdetect.DetectorFactory()
+        self.factory.set_seed(seed)
+        self.factory.load_profile(langdetect.PROFILES_DIRECTORY)
+
+    def detect_language(self, text):
+        try:
+            detector = self.factory.create()
+            if self.max_length:
+                text = text[:self.max_length]
+            detector.append(text)
+            return detector.detect()
+        except Exception as e:
+            return None
\ No newline at end of file
diff --git a/trufflepig/preprocessing.py b/trufflepig/preprocessing.py
index 4311576..a56c6d1 100644
--- a/trufflepig/preprocessing.py
+++ b/trufflepig/preprocessing.py
@@ -1,6 +1,9 @@
 import logging
+import os
 import multiprocessing as mp
 
+import pandas as pd
+
 import trufflepig.filters.stylemeasures as tfsm
 import trufflepig.filters.textfilters as tftf
 
@@ -8,7 +11,8 @@
 
 
 def filter_duplicates(frame):
-    filtered = frame.drop_duplicates(subset=['author', 'permalink'])
+    filtered = frame.drop_duplicates(subset=['author', 'permalink'],
+                                     keep='last')
     if len(filtered) < len(frame):
         logger.info('Filtered {} duplicates'.format(len(frame) - len(filtered)))
     return filtered
@@ -29,7 +33,8 @@ def apply_parallel(function, iterable, ncores, chunksize=1000):
     return results
 
 
-def preprocess(post_df, ncores=8, chunksize=100):
+def preprocess(post_df, ncores=8, chunksize=1000,
+               detect_seed=42, detect_max_length=3000):
     logger.info('Filtering duplicates')
     post_df = filter_duplicates(post_df)
 
@@ -65,7 +70,9 @@ def preprocess(post_df, ncores=8, chunksize=100):
                                                          len(post_df)))
 
     logger.info('Detecting language')
-    large_post_df.loc[:, 'language'] = apply_parallel(tfsm.detect_language,
+    detector = tfsm.LanguageDetector(seed=detect_seed,
+                                     max_length=detect_max_length)
+    large_post_df.loc[:, 'language'] = apply_parallel(detector.detect_language,
                                                       large_post_df.filtered_body,
                                                       ncores=ncores,
                                                       chunksize=chunksize)
@@ -129,4 +136,23 @@ def preprocess(post_df, ncores=8, chunksize=100):
     final_df = en_df.dropna()
 
     logger.info('Final data set has {} shape'.format(final_df.shape))
-    return final_df
\ No newline at end of file
+    return final_df
+
+
+def load_or_preprocess(post_frame, filename, *args, overwrite=False, store=True,
+                       **kwargs):
+    if os.path.isfile(filename) and not overwrite:
+        logger.info('Found file {} will load it'.format(filename))
+        post_frame = pd.read_pickle(filename, compression='gzip')
+    else:
+        logger.info('File {} not found, will start preprocessing'.format(filename))
+        post_frame = preprocess(post_frame, *args, **kwargs)
+        if store:
+            directory = os.path.dirname(filename)
+            if not os.path.isdir(directory):
+                os.makedirs(directory)
+            logger.info('Storing file {} to disk'.format(filename))
+            post_frame.to_pickle(filename, compression='gzip')
+    return post_frame
+
+