Classify learning
tribela committed Mar 17, 2016
1 parent e01db88 commit 2b3e7f8
Showing 6 changed files with 180 additions and 206 deletions.
15 changes: 7 additions & 8 deletions autotweet/command.py
@@ -14,8 +14,7 @@
 import tweepy
 
 from .daemons import answer_daemon, import_timeline, learning_daemon
-from .database import get_session
-from .learning import add_document, recalc_idfs, recreate_grams
+from .learning import DataCollection
 from .twitter import authorize, CONSUMER_KEY, CONSUMER_SECRET, OAuthToken
 
 
@@ -87,9 +86,9 @@ def add_command(args, config):
     question = args.question.decode('utf-8')
     answer = args.answer.decode('utf-8')
 
-    session = get_session(db_url)
+    data_collection = DataCollection(db_url)
 
-    add_document(session, question, answer)
+    data_collection.add_document(question, answer)
 
 
 def import_command(args, config):
@@ -101,14 +100,14 @@
 
 def recalc_command(args, config):
     db_url = config.get('database', 'db_url')
-    session = get_session(db_url)
-    recalc_idfs(session)
+    data_collection = DataCollection(db_url)
+    data_collection.recalc_idfs()
 
 
 def recreate_command(args, config):
     db_url = config.get('database', 'db_url')
-    session = get_session(db_url)
-    recreate_grams(session)
+    data_collection = DataCollection(db_url)
+    data_collection.recreate_grams()
 
 
 parser = argparse.ArgumentParser(prog='autotweet')
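The autotweet/learning.py half of this commit is not shown in this view, but the new call sites pin down the public surface the class must expose. Below is a minimal sketch of that implied interface, assuming it owns a SQLAlchemy session built from db_url; only the names and signatures come from the diff, while the constructor internals and method bodies are guesses, not the project's actual code.

# Sketch of the DataCollection interface implied by the call sites in this
# commit. Only the signatures are taken from the diff; the rest is assumed.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


class NoAnswerError(Exception):
    """Raised by get_best_answer when nothing matches the question."""


class DataCollection(object):
    def __init__(self, db_url):
        # The class owns its session; callers no longer pass one around.
        engine = create_engine(db_url)
        self.session = sessionmaker(bind=engine)()

    def add_document(self, question, answer):
        """Store a question/answer pair (replaces add_document(session, q, a))."""
        raise NotImplementedError

    def get_best_answer(self, question):
        """Return (answer, ratio) or raise NoAnswerError (replaces get_best_answer(session, q))."""
        raise NotImplementedError

    def recalc_idfs(self):
        """Recompute IDF scores (replaces recalc_idfs(session))."""
        raise NotImplementedError

    def recreate_grams(self):
        """Rebuild the gram tables (replaces recreate_grams(session))."""
        raise NotImplementedError

Under this reading, a CLI command only needs DataCollection(db_url).recalc_idfs() instead of threading a session through module-level functions, which is exactly the change visible in command.py above.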
34 changes: 11 additions & 23 deletions autotweet/daemons.py
@@ -9,10 +9,7 @@
 import time
 import tweepy
 
-from sqlalchemy.exc import OperationalError
-
-from .database import get_session
-from .learning import NoAnswerError, add_document, get_best_answer
+from .learning import NoAnswerError, DataCollection
 from .twitter import (CONSUMER_KEY, CONSUMER_SECRET, OAuthToken, expand_url,
                       strip_tweet)
 
@@ -42,8 +39,7 @@ class CollectorMentionListener(tweepy.streaming.StreamListener):
     def __init__(self, api, db_url):
         super(CollectorMentionListener, self).__init__()
         self.api = api
-        self.db_url = db_url
-        self.db_session = get_session(db_url)
+        self.data_collection = DataCollection(db_url)
         self.me = api.me()
 
     def on_status(self, status):
@@ -57,17 +53,13 @@ def on_status(self, status):
             answer = strip_tweet(expand_url(status.text), remove_url=False)
 
             if question and answer:
-                try:
-                    add_document(self.db_session, question, answer)
-                except OperationalError:
-                    self.db_session = get_session(self.db_url)
-                    add_document(self.db_session, question, answer)
+                self.data_collection.add_document(question, answer)
 
         return True
 
 
 def collector_polling_timeline(api, db_url):
-    db_session = get_session(db_url)
+    data_collection = DataCollection(db_url)
     me = api.me()
     last_id = me.status.id
 
@@ -94,11 +86,7 @@ def collector_polling_timeline(api, db_url):
                 answer = strip_tweet(status.text, remove_url=False)
 
                 if question and answer:
-                    try:
-                        add_document(db_session, question, answer)
-                    except OperationalError:
-                        db_session = get_session(db_url)
-                        add_document(db_session, question, answer)
+                    data_collection.add_document(question, answer)
 
 
 def import_timeline(token, db_url, count):
@@ -109,7 +97,7 @@
     auth.set_access_token(token.key, token.secret)
     api = tweepy.API(auth)
 
-    db_session = get_session(db_url)
+    data_collection = DataCollection(db_url)
     me = api.me()
 
     statuses = me.timeline(count=count)
@@ -130,7 +118,7 @@
             answer = strip_tweet(status.text, remove_url=False)
 
             if question and answer:
-                add_document(db_session, question, answer)
+                data_collection.add_document(question, answer)
 
 
 def learning_daemon(token, db_url, streaming=False):
@@ -178,7 +166,7 @@ class AnswerMentionListener(tweepy.streaming.StreamListener):
     def __init__(self, api, db_url, threshold=None):
         super(AnswerMentionListener, self).__init__()
         self.api = api
-        self.db_session = get_session(db_url)
+        self.data_collection = DataCollection(db_url)
         self.me = api.me()
 
         if threshold:
@@ -202,7 +190,7 @@ def on_status(self, status):
             status_id = status.id
 
             try:
-                answer, ratio = get_best_answer(self.db_session, question)
+                answer, ratio = self.data_collection.get_best_answer(question)
             except NoAnswerError:
                 return True
 
@@ -221,7 +209,7 @@
 
 
 def answer_polling_timeline(api, db_url, threshold=None):
-    db_session = get_session(db_url)
+    data_collection = DataCollection(db_url)
     me = api.me()
     threshold = threshold or DEFAULT_THRESHOLD
 
@@ -258,7 +246,7 @@ def answer_polling_timeline(api, db_url, threshold=None):
             mentions = get_mentions(status, friends)
 
             try:
-                (answer, ratio) = get_best_answer(db_session, question)
+                (answer, ratio) = data_collection.get_best_answer(question)
             except NoAnswerError:
                 pass
 
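Both daemons previously caught SQLAlchemy's OperationalError at every call site and rebuilt the session before retrying; that boilerplate disappears above, which suggests the reconnect-and-retry now lives inside DataCollection itself. A hedged sketch of how such a wrapper might look follows; the helper names and structure are assumptions for illustration, not the project's actual code.

# Hypothetical sketch only: the removed try/except blocks above suggest the
# retry on OperationalError moved into the class. Underscore-prefixed names
# are invented for this sketch.
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker


class DataCollection(object):
    def __init__(self, db_url):
        self.db_url = db_url
        self.session = self._make_session()

    def _make_session(self):
        # Build a fresh session; called again whenever the connection drops.
        engine = create_engine(self.db_url)
        return sessionmaker(bind=engine)()

    def _retrying(self, func, *args):
        # Retry once with a new session on OperationalError, mirroring the
        # inline handling the daemons used to do at each call site.
        try:
            return func(*args)
        except OperationalError:
            self.session = self._make_session()
            return func(*args)

    def add_document(self, question, answer):
        return self._retrying(self._add_document, question, answer)

    def _add_document(self, question, answer):
        # Actual persistence logic lives in autotweet/learning.py, not shown here.
        raise NotImplementedError

Centralizing the retry this way keeps the streaming and polling daemons free of SQLAlchemy error handling, which matches the simpler call sites in this diff.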
