Mypy type annotations
EntilZha committed Jan 4, 2018
1 parent 54597a1 commit 437e996
Showing 4 changed files with 20 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -47,3 +47,4 @@ tb-logs/
 .terraform
 .ignore
 tags
+.mypy_cache/
5 changes: 5 additions & 0 deletions mypy.ini
@@ -0,0 +1,5 @@
+[mypy]
+ignore_missing_imports = True
+
+[mypy-pyspark.*]
+follow_imports = silent
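Note on the config: ignore_missing_imports = True keeps mypy from erroring on third-party packages that ship no type stubs, while the [mypy-pyspark.*] section makes mypy follow and analyze pyspark imports but stay silent about errors inside the library itself. A minimal sketch of the situation these settings address (illustrative only, not part of this commit):

# Hypothetical snippet, assuming pyfunctional and pyspark are installed.
from functional import seq   # ships no stubs; without ignore_missing_imports,
                             # mypy would report a missing module/stub error here
import pyspark               # followed and analyzed, but errors raised inside
                             # pyspark itself are suppressed (follow_imports = silent)

total = seq([1, 2, 3]).map(lambda x: x * 2).sum()  # code in this repo is still checked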
23 changes: 9 additions & 14 deletions qanta/datasets/quiz_bowl.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Iterable, Tuple
+from typing import List, Dict, Iterable, Tuple, Optional, Any, Set
 import pickle
 import csv
 import os
@@ -97,7 +97,7 @@ def add_text(self, sent, text):
     def flatten_text(self):
         return " ".join(self.text[x] for x in sorted(self.text))

-    def to_example(self, all_evidence=None) -> Tuple[List[QuestionText], Answer]:
+    def to_example(self, all_evidence: Optional[Dict[str, Dict[int, Any]]]=None) -> Tuple[List[QuestionText], Answer, Optional[Dict[str, Any]]]:
         sentence_list = [self.text[i] for i in range(len(self.text))]
         if all_evidence is not None and 'tagme' in all_evidence:
             evidence = {'tagme': all_evidence['tagme'][self.qnum]}
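The new signature spells out what was implicit before: an argument that defaults to None must be typed Optional[...], and the return type now records the third element of the tuple (the evidence dict or None). A standalone sketch of the same pattern, using hypothetical names rather than the real Question class:

from typing import Any, Dict, List, Optional, Tuple

def make_example(sentences: List[str], answer: str,
                 evidence: Optional[Dict[str, Any]] = None
                 ) -> Tuple[List[str], str, Optional[Dict[str, Any]]]:
    # Because the default is None, the parameter must be Optional; mypy then
    # requires a None check before the dict is indexed.
    if evidence is not None and 'tagme' in evidence:
        evidence = {'tagme': evidence['tagme']}
    return sentences, answer, evidence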
@@ -224,14 +224,11 @@ def questions_by_answer(self, answer):
             yield questions[ii]

     def questions_with_pages(self) -> Dict[str, List[Question]]:
-        page_map = {}
-
+        page_map = defaultdict(list) # type: Dict[str, List[Question]]
         questions = self.query('from questions where page != ""', ()).values()

         for q in questions:
             page = q.page
-            if page not in page_map:
-                page_map[page] = []
             page_map[page].append(q)
         return page_map

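Switching to defaultdict(list) removes the manual membership check, and the type comment supplies the key and value types that mypy cannot infer from a bare defaultdict(list). A self-contained sketch of the pattern (the Question stand-in below is hypothetical):

from collections import defaultdict
from typing import Dict, List

class Question:
    def __init__(self, page: str) -> None:
        self.page = page

def questions_by_page(questions: List[Question]) -> Dict[str, List[Question]]:
    # The type comment pins down key/value types; without it mypy asks for an
    # annotation, since defaultdict(list) alone gives it nothing to infer from.
    page_map = defaultdict(list)  # type: Dict[str, List[Question]]
    for q in questions:
        page_map[q.page].append(q)
    return page_map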
@@ -278,7 +275,7 @@ def all_answers(self):

 class QuizBowlDataset(AbstractDataset):
     def __init__(self, *, guesser_train=False, buzzer_train=False,
-                 qb_question_db: str=QB_QUESTION_DB, use_tagme_evidence=False):
+                 qb_question_db: str=QB_QUESTION_DB, use_tagme_evidence=False) -> None:
         """
         Initialize a new quiz bowl data set
         """
@@ -292,7 +289,7 @@ def __init__(self, *, guesser_train=False, buzzer_train=False,
         self.db = QuestionDatabase(qb_question_db)
         self.guesser_train = guesser_train
         self.buzzer_train = buzzer_train
-        self.training_folds = set()
+        self.training_folds = set() # type: Set[str]
         if self.guesser_train:
             self.training_folds.add(c.GUESSER_TRAIN_FOLD)
         if self.buzzer_train:
@@ -301,19 +298,17 @@ def __init__(self, *, guesser_train=False, buzzer_train=False,
         self.use_tagme_evidence = use_tagme_evidence
         if self.use_tagme_evidence:
             with open('output/tagme/tagme.pickle', 'rb') as f:
-                self.tagme_evidence = pickle.load(f)
+                self.tagme_evidence = pickle.load(f) # type: Optional[Dict[int, Any]]
         else:
-            self.tagme_evidence = None
+            self.tagme_evidence = None # type: Optional[Dict[int, Any]]

     def training_data(self) -> TrainingData:
         from functional import seq
         all_questions = seq(self.db.all_questions().values())
         if self.use_tagme_evidence:
-            all_evidence = {
-                'tagme': self.tagme_evidence
-            }
+            all_evidence = {'tagme': self.tagme_evidence} # type: Optional[Dict[str, Any]]
         else:
-            all_evidence = None
+            all_evidence = None # type: Optional[Dict[str, Any]]

         filtered_questions = all_questions\
             .filter(lambda q: q.fold in self.training_folds)\
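The remaining annotations in this file follow one rule: mypy fixes a variable's type from its first assignment, so an empty set() needs an explicit element type, and an attribute that is a dict on one branch and None on the other has to be declared Optional in both places. A hypothetical, self-contained version of the same fixes:

import pickle
from typing import Any, Dict, Optional, Set

class DatasetConfig:
    def __init__(self, use_evidence: bool, evidence_path: str) -> None:
        # A bare set() gives mypy nothing to infer, so the element type is stated.
        self.training_folds = set()  # type: Set[str]
        if use_evidence:
            with open(evidence_path, 'rb') as f:
                # Declared Optional so the None assignment in the else branch is
                # compatible with the type mypy records here.
                self.evidence = pickle.load(f)  # type: Optional[Dict[int, Any]]
        else:
            self.evidence = None  # type: Optional[Dict[int, Any]]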
13 changes: 5 additions & 8 deletions qanta/reporting/performance.py
@@ -119,14 +119,11 @@ def load_audit(audit_file: str, meta_file: str):
     with open(audit_file) as audit_f, open(meta_file) as meta_f:
         for a_line, m_line in zip(audit_f, meta_f):
             qid, evidence = a_line.split('\t')
-            a_qnum, a_sentence, a_token = qid.split('_')
-            a_qnum = int(a_qnum)
-            a_sentence = int(a_sentence)
-            a_token = int(a_token)
-            m_qnum, m_sentence, m_token, guess = m_line.split()
-            m_qnum = int(m_qnum)
-            m_sentence = int(m_sentence)
-            m_token = int(m_token)
+            a_qnum, a_sentence, a_token = [int(t) for t in qid.split('_')]
+            s_m_qnum, s_m_sentence, s_m_token, guess = m_line.split()
+            m_qnum = int(s_m_qnum)
+            m_sentence = int(s_m_sentence)
+            m_token = int(s_m_token)
             if a_qnum != m_qnum or a_sentence != m_sentence or a_token != m_token:
                 raise ValueError('Error occurred in audit and meta file alignment')
             audit_data[(a_qnum, a_sentence, a_token, guess)] = evidence.strip()
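The rewrite in load_audit is here because mypy gives each variable a single type for the whole scope: the old code bound the str pieces from split() and then rebound the same names to int, which mypy rejects as an incompatible assignment. The audit side now converts in one expression, and the meta side keeps separate names for the str and int forms. A hypothetical helper showing the same constraint:

from typing import Tuple

def parse_qid(qid: str) -> Tuple[int, int, int]:
    s_qnum, s_sentence, s_token = qid.split('_')  # all three are str
    # Rebinding, e.g. s_qnum = int(s_qnum), would be flagged by mypy as an
    # incompatible assignment (a str variable receiving an int), hence the
    # separate names / single-expression conversion in this commit.
    return int(s_qnum), int(s_sentence), int(s_token)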
