Skip to content

Commit

Permalink
Small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert Meyer committed Feb 13, 2018
1 parent 59d0712 commit d17128b
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 19 deletions.
40 changes: 40 additions & 0 deletions scripts/do_cross_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
import os

from steem import Steem

import trufflepig.model as tpmo
import trufflepig.preprocessing as tppp
import trufflepig.bchain.getdata as tpgd
from trufflepig import config


def main():
    """Run a hyperparameter cross-validation sweep of the TrufflePig model.

    Scrapes (or loads cached) Steem posts, preprocesses them, and then
    cross-validates the regressor over a small parameter grid.
    """
    logging.basicConfig(level=logging.INFO)

    # Local cache directory for scraped training data
    data_directory = os.path.join(config.PROJECT_DIRECTORY, 'scraped_data')

    # Connection arguments for the Steem node
    steem_kwargs = dict(nodes=[config.NODE_URL])

    # Scrape/load three days of posts ending at this date (no offset)
    cutoff = '2018-02-01'
    post_frame = tpgd.scrape_or_load_training_data(
        steem_kwargs,
        data_directory,
        current_datetime=cutoff,
        days=3,
        offset_days=0,
    )

    # Fixed base settings for the random-forest regressor
    regressor_kwargs = dict(
        n_estimators=20,
        max_leaf_nodes=100,
        max_features=0.1,
        n_jobs=-1,
        verbose=1,
        random_state=42,
    )

    # Fixed base settings for the topic model
    topic_kwargs = dict(num_topics=50, no_below=5, no_above=0.7)

    post_frame = tppp.preprocess(post_frame, ncores=4, chunksize=20)

    # Hyperparameters explored during cross-validation
    param_grid = {
        'feature_generation__topic_model__no_above': [0.2, 0.3],
        'regressor__max_leaf_nodes': [50, 100],
    }

    tpmo.cross_validate(post_frame, param_grid,
                        topic_kwargs=topic_kwargs,
                        regressor_kwargs=regressor_kwargs)


if __name__ == '__main__':
    main()
Empty file added tests/filters/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import trufflepig.stylemeasures as tpsm
import trufflepig.filters.stylemeasures as tpsm


def test_count_paragraphs():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import trufflepig.textfilters as tptf
import trufflepig.filters.textfilters as tptf


def test_filter_html_tags():
Expand Down
29 changes: 17 additions & 12 deletions trufflepig/bchain/getdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from steem import Steem
from steem.blockchain import Blockchain
from steem.post import Post
from steem.post import Post, PostDoesNotExist
import json
from json import JSONDecodeError

Expand Down Expand Up @@ -179,7 +179,12 @@ def get_post_data(authors_and_permalinks, steem):
for kdx, (author, permalink) in enumerate(authors_and_permalinks):
try:
p = Post('@{}/{}'.format(author, permalink), steem)
except Exception as e:
except PostDoesNotExist:
            # This happens too often; we will suppress it
logger.debug('Post {} by {} does not exist!'.format(author,
permalink))
continue
except Exception:
logger.exception('Error in loading post {} by {}'.format(author,
permalink))
continue
Expand Down Expand Up @@ -213,15 +218,15 @@ def get_all_posts_from_block(block_num, steem):

def get_all_posts_between(start_datetime, end_datetime, steem,
stop_after=None):
start_num, _ = find_nearest_block_num(start_datetime, steem)
end_num, _ = find_nearest_block_num(end_datetime, steem)
start_num, block_start_datetime = find_nearest_block_num(start_datetime, steem)
end_num, block_end_datetime = find_nearest_block_num(end_datetime, steem)

total = end_num - start_num
posts = []
logger.info('Querying all posts between '
'{} (block {}) and {} (block {})'.format(start_datetime,
'{} (block {}) and {} (block {})'.format(block_start_datetime,
start_num,
end_datetime,
block_end_datetime,
end_num))
for idx, block_num in enumerate(range(start_num, end_num+1)):
posts_in_block = get_all_posts_from_block(block_num, steem)
Expand Down Expand Up @@ -253,14 +258,14 @@ def get_all_posts_between_parallel(start_datetime, end_datetime, steem_args,
stop_after=None, ncores=8,
chunksize=20, timeout=3600):
steem = check_and_convert_steem(steem_args)
start_num, _ = find_nearest_block_num(start_datetime, steem)
end_num, _ = find_nearest_block_num(end_datetime, steem)
start_num, block_start_datetime = find_nearest_block_num(start_datetime, steem)
end_num, block_end_datetime = find_nearest_block_num(end_datetime, steem)

logger.info('Querying IN PARALLEL with {} cores all posts between'
logger.info('Querying IN PARALLEL with {} cores all posts between '
'{} (block {}) and {} (block {})'.format(ncores,
start_datetime,
block_start_datetime,
start_num,
end_datetime,
block_end_datetime,
end_num))
block_nums = list(range(start_num, end_num + 1))
chunks = [block_nums[irun: irun + chunksize]
Expand All @@ -284,7 +289,7 @@ def get_all_posts_between_parallel(start_datetime, end_datetime, steem_args,
try:
new_posts = async.get(timeout=timeout)
posts.extend(new_posts)
if progressbar(kdx, len(chunks), percentage_step=1, logger=logger):
if progressbar(kdx, len(chunks), percentage_step=5, logger=logger):
logger.info('Finished chunk {} '
'out of {} found so far {} '
'posts...'.format(kdx + 1, len(chunks), len(posts)))
Expand Down
Empty file added trufflepig/filters/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
7 changes: 2 additions & 5 deletions trufflepig/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import logging
import multiprocessing as mp

import pandas as pd

import trufflepig.textfilters as tftf
import trufflepig.stylemeasures as tfsm

import trufflepig.filters.stylemeasures as tfsm
import trufflepig.filters.textfilters as tftf

logger = logging.getLogger(__name__)

Expand Down

0 comments on commit d17128b

Please sign in to comment.