-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
140 lines (115 loc) · 5.06 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import numpy as np
import scipy
import scipy.linalg
import theano
import theano.tensor as T
from theano.printing import Print
from theano_toolkit import utils as U
from theano_toolkit.parameters import Parameters

import controller
import head
def cosine_sim(k,M):
    """
    Cosine similarity between a key vector and every row of memory.

    Used for content-based addressing: returns a length-N vector whose
    i-th entry is cos(k, M[i]). A small epsilon keeps the divisions
    stable when a norm is near zero.
    """
    eps = 1e-5
    # Normalise the key and broadcast it across memory rows (1 x M).
    normalised_key = (k / (T.sqrt(T.sum(k ** 2)) + eps)).dimshuffle(('x', 0))
    normalised_key.name = "k_unit"
    # Normalise each memory row by its own length (N x 1 column of norms).
    row_norms = T.sqrt(T.sum(M ** 2, axis=1)).dimshuffle((0, 'x'))
    normalised_rows = M / (row_norms + eps)
    normalised_rows.name = "M_unit"
    # Row-wise dot product of unit vectors == cosine similarity per row.
    return T.sum(normalised_key * normalised_rows, axis=1)
def build_step(P,controller,controller_size,mem_size,mem_width,similarity=cosine_sim,shift_width=3):
    """
    Construct the per-timestep transition function of the NTM.

    Parameters
    ----------
    P : Parameters container; this function registers new parameters on it
        (memory_init, weight_init, and the head parameters via head.build).
    controller : callable (input_curr, read_prev) -> (output, hidden_layer)
    controller_size : size of the controller hidden layer fed to the heads
    mem_size : number of memory rows (N)
    mem_width : width of each memory row (M)
    similarity : similarity used for content addressing (default cosine_sim)
    shift_width : number of allowed rotation offsets; 3 gives shifts -1, 0, +1

    Returns
    -------
    (step, outputs_info) for theano.scan: step maps
    (input_curr, M_prev, weight_prev) -> (M_curr, weight_curr, output);
    outputs_info holds the initial memory, the initial softmaxed weight,
    and None for the non-recurrent output stream.
    """
    # Fix: `import scipy` at file level does not guarantee the `linalg`
    # subpackage is loaded; import it explicitly before using circulant.
    import scipy.linalg
    # Index matrix for circular convolution: row j holds the memory indices
    # that a shift of offset j maps each position to.
    # (for shift_width=3, have shift offsets of -1, 0, and +1)
    shift_conv = scipy.linalg.circulant(np.arange(mem_size)).T[np.arange(-(shift_width//2),(shift_width//2)+1)][::-1]
    # Initial N X M memory, uniform in (-1, 1): M_0
    P.memory_init = 2 * (np.random.rand(mem_size,mem_width) - 0.5)
    memory_init = P.memory_init
    # Initial N-dim weight vector (softmaxed so it sums to 1): w_0
    P.weight_init = np.random.randn(mem_size)
    weight_init = U.vector_softmax(P.weight_init)
    # heads is a function taking the hidden layer of the controller and
    # computes the key, key strength, interpolation gate,
    # sharpening factor, and erase and add vectors as outputs
    heads = head.build(P,controller_size,mem_width,mem_size,shift_width)
    def build_memory_curr(M_prev,erase_head,add_head,weight):
        """
        Update memory with write consisting of erase and add
        (described in section 3.2 in paper)
        """
        # Reshape for broadcasting: weight is N x 1, heads are 1 x M.
        weight = weight.dimshuffle((0,'x'))
        erase_head = erase_head.dimshuffle(('x',0))
        add_head = add_head.dimshuffle(('x',0))
        # Equation (3): erase each row in proportion to its weight.
        M_erased = M_prev * (1 - (weight * erase_head))
        # Equation (4): add the new content, again weighted per row.
        M_curr = M_erased + (weight * add_head)
        return M_curr
    def build_read(M_curr,weight_curr):
        """
        Obtain read vector r_t, the weight-averaged memory row
        (Equation (2) in paper)
        """
        return T.dot(weight_curr, M_curr)
    def shift_convolve(weight,shift):
        """
        Circular convolution of the weight with the shift distribution
        (Equation (8) in paper)
        """
        shift = shift.dimshuffle((0,'x'))
        # weight[shift_conv] gathers every rotated copy of the weight;
        # summing against the shift distribution blends them.
        return T.sum(shift * weight[shift_conv],axis=0)
    def build_head_curr(weight_prev,M_curr,head,input_curr):
        """
        Implement addressing mechanism shown in Figure 2 in paper.
        Also return add and erase vectors computed by head.
        """
        # input_curr is hidden layer from controller
        # this is passing the hidden layer into the heads layer
        # which computes key, beta, g, shift, gamma, erase, and add
        # as outputs (see head_params in head.py)
        key,beta,g,shift,gamma,erase,add = head(input_curr)
        # 3.3.1 Focusing by Content (Equation (5)):
        # softmax of key-strength-scaled similarity against each row.
        weight_c = U.vector_softmax(beta * similarity(key,M_curr))
        weight_c.name = "weight_c"
        # 3.3.2 Focusing by Location (Equation (7)):
        # gate g interpolates between content weight and previous weight.
        weight_g = g * weight_c + (1 - g) * weight_prev
        weight_g.name = "weight_g"
        # Equation (8): rotate by the shift distribution.
        weight_shifted = shift_convolve(weight_g,shift)
        # Equation (9): sharpen and renormalise to a distribution.
        weight_sharp = weight_shifted ** gamma
        weight_curr = weight_sharp / T.sum(weight_sharp)
        return weight_curr,erase,add
    def step(input_curr,M_prev,weight_prev):
        """
        Update the weights and memory from the previous time step
        given the current input
        """
        # Get read vector r_t from the previous state.
        read_prev = build_read(M_prev,weight_prev)
        # Feed current input and read input to controller to get
        # controller output and hidden layer of controller
        output,controller_hidden = controller(input_curr,read_prev)
        # Obtain new weight vector (as described in figure 2) and erase and add vectors
        weight_curr,erase,add = build_head_curr(weight_prev,M_prev,heads,controller_hidden)
        # Update memory with current weight, erase, and add vectors (Section 3.2 in paper)
        M_curr = build_memory_curr(M_prev,erase,add,weight_curr)
        return M_curr,weight_curr,output
    # None => scan does not feed `output` back as a recurrent input.
    return step,[memory_init,weight_init,None]
def build(P,mem_size,mem_width,controller_size,ctrl):
    """
    Assemble the full NTM prediction function.

    Registers the model's parameters on P (via build_step) and returns
    a `predict` closure that unrolls the recurrence over an input
    sequence with theano.scan.
    """
    # step_fn: one timestep of the NTM (memory, weight -> updated state).
    # init_state: initial memory, initial weight, and None for the output.
    step_fn, init_state = build_step(P, ctrl, controller_size, mem_size, mem_width)

    def predict(input_sequence):
        """
        Run the NTM over input_sequence and return the scanned outputs:
        the memory, weight, and controller output at every timestep.
        """
        sequence_outputs, _ = theano.scan(
            step_fn,
            sequences=[input_sequence],
            outputs_info=init_state,
        )
        return sequence_outputs

    return predict