configs.py
train_config = """
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# attn_type = @TimeBinCausalAttention
attn_type = [
    @TimeBinCausalAttention,
    @TimeBinCausalAttention,
    @LSHCausalAttention,
    @TimeBinCausalAttention,
]
share_qk = False # LSHCausalAttention ignores this flag and always shares q & k
attn_kv = 128
n_layers = 12
dropout = 0.2
# MemoryEfficientCausalAttention: full attention
# (no hparams to vary between experiments)
# TimeBinCausalAttention: attend to nearby items
TimeBinCausalAttention.n_bins = 512
# LSHCausalAttention: locality-sensitive hashing (LSH) attention
LSHCausalAttention.n_bins = 256
LSHCausalAttention.n_buckets = 512 # Always 2 * n_bins
LSHCausalAttention.n_hashes = 2
LSHCausalAttention.drop_for_hash_rate = 0.0
# Parameters for MultifactorSchedule:
# ==============================================================================
# 0.03125 = 1024^-0.5 = d_model^-0.5
MultifactorSchedule.constant = 0.03125
MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
MultifactorSchedule.warmup_steps = 2000
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9
# Parameters for MemoryEfficientCausalAttention:
# ==============================================================================
MemoryEfficientCausalAttention.dropout = 0.0
MemoryEfficientCausalAttention.loop_stride = 256
MemoryEfficientCausalAttention.share_qk = %share_qk
# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.dropout = 0.2
# TimeBinCausalAttention.n_bins: see top
TimeBinCausalAttention.share_qk = %share_qk
# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
# LSHCausalAttention.n_bins: see top
# LSHCausalAttention.n_buckets: see top
# LSHCausalAttention.n_hashes: see top
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.2
# LSHCausalAttention.drop_for_hash_rate: see top
# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 1024
ReformerLM.d_ff = 2048
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = 65536
ReformerLM.mode = 'train'
ReformerLM.n_heads = 8
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 258 # Includes pad token and unused EOS token
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (128, 512)
ReformerLM.d_axial_pos_embs = (256, 768)  # Sums to d_model (1024)
"""
test_config = """
TimeBinCausalAttention.bin_length = 128
TimeBinCausalAttention.n_bins = None
LSHCausalAttention.n_hashes = 8
LSHCausalAttention.bucket_capacity_for_inference = 258
ReformerLM.dropout = 0.0
TimeBinCausalAttention.dropout = 0.0
LSHCausalAttention.dropout = 0.0
"""