From a8c1215ffac61c0572e1c6eaa9ed18cc9fc7e604 Mon Sep 17 00:00:00 2001 From: Robert Meyer Date: Sun, 11 Feb 2018 23:16:18 +0100 Subject: [PATCH] New parallel data loading --- integration_tests/bchain/getdata_test.py | 29 ++++++++++---- trufflepig/bchain/getdata.py | 50 ++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/integration_tests/bchain/getdata_test.py b/integration_tests/bchain/getdata_test.py index 7cd729c..49b8f79 100644 --- a/integration_tests/bchain/getdata_test.py +++ b/integration_tests/bchain/getdata_test.py @@ -4,6 +4,7 @@ import tempfile import pandas as pd +from pandas.testing import assert_frame_equal from steem import Steem from steem.blockchain import Blockchain @@ -20,6 +21,11 @@ def bchain(steem): return Blockchain(steem) +@pytest.fixture +def temp_dir(tmpdir_factory): + return tmpdir_factory.mktemp('test', numbered=True) + + def test_get_headers(steem, bchain): offset = bchain.get_current_block_num() now = pd.datetime.utcnow() @@ -51,16 +57,25 @@ def test_get_all_posts_between(steem): assert posts -def test_scrape_date(steem): +def test_scrape_date(steem, temp_dir): yesterday = (pd.datetime.utcnow() - pd.Timedelta(days=1)).date() - directory = tempfile.mkdtemp() - tpbg.scrape_or_load_full_day(yesterday, steem, directory, stop_after=25) + p1 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25) + + assert len(os.listdir(temp_dir)) == 1 + + p2 = tpbg.scrape_or_load_full_day(yesterday, steem, temp_dir, stop_after=25) - assert len(os.listdir(directory)) == 1 + assert len(os.listdir(temp_dir)) == 1 - tpbg.scrape_or_load_full_day(yesterday, steem, directory, stop_after=25) + assert_frame_equal(p1, p2) + assert len(p1) > 0 - assert len(os.listdir(directory)) == 1 - shutil.rmtree(directory) +def test_scrape_or_load_data_parallel(temp_dir): + frames = tpbg.scrape_or_load_training_data_parallel([config.NODE_URL], + temp_dir, + days=5, + stop_after=10, + ncores=5) + assert len(frames) == 5 diff 
--git a/trufflepig/bchain/getdata.py b/trufflepig/bchain/getdata.py index c0e2c56..495744d 100644 --- a/trufflepig/bchain/getdata.py +++ b/trufflepig/bchain/getdata.py @@ -1,8 +1,10 @@ import logging import os +import multiprocessing as mp from collections import OrderedDict import pandas as pd +from steem import Steem from steem.blockchain import Blockchain from steem.post import Post import json @@ -248,3 +250,51 @@ def scrape_or_load_full_day(date, steem, directory, overwrite=False, logger.info('Storing file {} to disk'.format(filename)) post_frame.to_pickle(filename, compression='gzip') return post_frame + + +def scrape_or_load_full_day_mp(date, node_urls, directory, overwrite=False, + store=True, + stop_after=None): + steem = Steem(nodes=node_urls) + return scrape_or_load_full_day(date=date, + steem=steem, + directory=directory, + overwrite=overwrite, + store=store, + stop_after=stop_after) + + +def config_mp_logging(): + logging.basicConfig(level=logging.INFO) + + +def scrape_or_load_training_data_parallel(node_urls, directory, + days=20, offset=8, + ncores=10, + current_datetime=None, + stop_after=None): + ctx = mp.get_context('fork') + pool = ctx.Pool(ncores, initializer=config_mp_logging) + + if current_datetime is None: + current_datetime = pd.datetime.utcnow() + + start_datetime = current_datetime - pd.Timedelta(days=days + offset) + + async_results = [] + for day in range(days): + next_date = (start_datetime + pd.Timedelta(days=day)).date() + result = pool.apply_async(scrape_or_load_full_day_mp, + args=(next_date, node_urls, directory, + False, True, stop_after)) + async_results.append(result) + + pool.close() + + frames = [] + for async_result in async_results: + frames.append(async_result.get(timeout=3600*6)) + + pool.join() + + return frames \ No newline at end of file