train_recurrent_atten.py
import trainer
from data_processing.qa_data import FixedParagraphQaTrainingData, Batcher
from doc_qa_models import Attention
from encoder import DocumentAndQuestionEncoder, SingleSpanAnswerEncoder
from evaluator import LossEvaluator, BoundedSpanEvaluator, SentenceSpanEvaluator
from nn.attention import StaticAttentionSelf
from nn.attention_recurrent_layers import RecurrentAttention
from nn.embedder import FixedWordEmbedder, CharWordEmbedder, LearnedCharEmbedder
from nn.layers import NullBiMapper, SequenceMapperSeq, FullyConnectedMerge, DropoutLayer
from nn.prediction_layers import ChainPredictor
from nn.recurrent_layers import BiRecurrentMapper, LstmCellSpec, RecurrentEncoder, EncodeOverTime
from nn.similarity_layers import DotProductProject, BiLinear
from trainer import SerializableOptimizer, TrainParams
from squad.squad import SquadCorpus
from utils import get_output_name_from_cli
"""
-> Increasing the dropout (to 0.75) hurts
-> Adding another LSTM between the self attention and prediction hurts
-> Adding a question encoding merge layer between the self attention and is neutral
-> MatchWord features (with shared encoders) makes training faster, but hurts in the long run
"""
def main():
    out = get_output_name_from_cli()

    train_params = TrainParams(SerializableOptimizer("Adadelta", dict(learning_rate=1.0)),
                               num_epochs=20, log_period=20, eval_period=1200, save_period=1200,
                               eval_samples=dict(dev=8000, train=8000))

    # Shared embedding mapper: dropout -> BiLSTM -> dropout
    enc = SequenceMapperSeq(
        DropoutLayer(0.8),
        BiRecurrentMapper(LstmCellSpec(80)),
        DropoutLayer(0.8),
    )

    model = Attention(
        encoder=DocumentAndQuestionEncoder(SingleSpanAnswerEncoder()),
        word_embed_layer=None,
        # Fixed (non-trainable) GloVe word vectors
        word_embed=FixedWordEmbedder(vec_name="glove.840B.300d", word_vec_init_scale=0, learn_unk=False),
        # Learned character embeddings, encoded per word with an LSTM
        char_embed=CharWordEmbedder(
            LearnedCharEmbedder(word_size_th=14, char_th=50, char_dim=15, init_scale=0.1),
            EncodeOverTime(RecurrentEncoder(LstmCellSpec(50), 'h'), mask=True),
            shared_parameters=True
        ),
        embed_mapper=enc,
        question_mapper=None,
        context_mapper=None,
        memory_builder=NullBiMapper(),
        # Recurrent attention over the question, using a bilinear similarity function
        attention=RecurrentAttention(LstmCellSpec(80, keep_probs=0.8), BiLinear(80, bias=True)),
        # Re-encode the attended context, then apply self-attention
        match_encoder=SequenceMapperSeq(
            BiRecurrentMapper(LstmCellSpec(80, keep_probs=0.8)),
            DropoutLayer(0.8),
            StaticAttentionSelf(DotProductProject(160, bias=True, scale=True, share_project=True),
                                FullyConnectedMerge(160)),
        ),
        # Chained BiLSTM layers predict the answer span's start and end boundaries
        predictor=ChainPredictor(
            start_layer=BiRecurrentMapper(LstmCellSpec(80, keep_probs=0.8)),
            end_layer=BiRecurrentMapper(LstmCellSpec(80, keep_probs=0.8))
        )
    )

    # Store this file's source with the model as a record of the configuration
    with open(__file__, "r") as f:
        notes = f.read()

    corpus = SquadCorpus()
    train_batching = Batcher(45, "bucket_context_words_3", True, False)
    eval_batching = Batcher(45, "context_words", False, False)
    data = FixedParagraphQaTrainingData(corpus, None, train_batching, eval_batching)

    # Named `evaluators` to avoid shadowing the built-in `eval`
    evaluators = [LossEvaluator(), BoundedSpanEvaluator(bound=[17]), SentenceSpanEvaluator()]

    trainer.start_training(data, model, train_params, evaluators, trainer.ModelDir(out), notes, False)


if __name__ == "__main__":
    main()
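# Usage sketch (an assumption: get_output_name_from_cli is presumed to read
# the model's output directory name from the command line):
#
#   python train_recurrent_atten.py <output-name>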