-
Notifications
You must be signed in to change notification settings - Fork 4
/
memory.py
131 lines (100 loc) · 4.81 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
import os
import tensorflow as tf
import math
def linear(x, state_dim, name='linear', reuse=True):
    """Affine projection x @ W + b inside variable scope *name*.

    Args:
        x: 2-D tensor of shape [batch, in_dim].
        state_dim: width of the output (number of columns of W).
        name: variable-scope name holding the weight/bias pair.
        reuse: when True, reuse variables already created in the scope.

    Returns:
        Tensor of shape [batch, state_dim].
    """
    with tf.variable_scope(name) as vs:
        if reuse:
            vs.reuse_variables()
        # Small-stddev truncated normal init for the weight, zeros for bias.
        w = tf.get_variable(
            'weight',
            [x.get_shape()[-1], state_dim],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable(
            'bias', [state_dim], initializer=tf.constant_initializer(0))
        return tf.matmul(x, w) + b
# This class defines Memory architecture in DKVMN
class DSCMN_Memory():
    """One memory matrix (key or value) of the DKVMN-style architecture.

    The key matrix is used to compute correlation (attention) weights;
    the value matrix is read from and written to using those weights.
    """

    def __init__(self, memory_size, memory_state_dim, name):
        self.name = name
        self.memory_size = memory_size
        self.memory_state_dim = memory_state_dim

    def cor_weight(self, embedded, key_matrix):
        """Softmax attention of an embedded query over the key matrix.

        embedded   : [batch, d_k]
        key_matrix : [memory_size, d_k]
        returns    : correlation weights, [batch, memory_size]
        """
        # w = softmax(k . K(i)) over the memory slots.
        similarity = tf.matmul(embedded, tf.transpose(key_matrix))
        return tf.nn.softmax(similarity)

    def read(self, value_matrix, correlation_weight):
        """Attention-weighted read from the value matrix.

        value_matrix       : [batch, memory_size, d_v]
        correlation_weight : [batch, memory_size]
        returns            : read content, [batch, d_v]
        """
        # Flatten so every slot row pairs with its scalar weight.
        flat_values = tf.reshape(value_matrix, [-1, self.memory_state_dim])
        flat_weights = tf.reshape(correlation_weight, [-1, 1])
        weighted = tf.multiply(flat_values, flat_weights)
        per_slot = tf.reshape(
            weighted, [-1, self.memory_size, self.memory_state_dim])
        # Sum across memory slots -> one read vector per batch element.
        return tf.reduce_sum(per_slot, axis=1, keep_dims=False)

    def write(self, value_matrix, correlation_weight, qa_embedded, reuse=False):
        """Erase-then-add update of the value matrix (DKVMN write rule).

        value_matrix       : [batch, memory_size, d_v]
        correlation_weight : [batch, memory_size]
        qa_embedded        : embedded (q, r) pair, [batch, d_v]
        returns            : updated value matrix, [batch, memory_size, d_v]
        """
        # Erase signal e_t in (0, 1), add signal a_t in (-1, 1).
        erase_signal = tf.sigmoid(
            linear(qa_embedded, self.memory_state_dim,
                   name=self.name+'/Erase_Vector', reuse=reuse))
        add_signal = tf.tanh(
            linear(qa_embedded, self.memory_state_dim,
                   name=self.name+'/Add_Vector', reuse=reuse))
        # Broadcastable shapes: [batch, 1, d_v] against [batch, memory_size, 1].
        erase_b = tf.reshape(erase_signal, [-1, 1, self.memory_state_dim])
        weight_b = tf.reshape(correlation_weight, [-1, self.memory_size, 1])
        add_b = tf.reshape(add_signal, [-1, 1, self.memory_state_dim])
        # M_t(i) = M_{t-1}(i) * (1 - w_t(i) e_t) + w_t(i) a_t
        erased = value_matrix * (1 - tf.multiply(erase_b, weight_b))
        return erased + tf.multiply(add_b, weight_b)
# This class construct key matrix and value matrix
class DSCMN():
    """Pairs a key memory with a value memory and delegates the
    attention / read / write operations to them."""

    def __init__(self, memory_size, memory_key_state_dim,
                 memory_value_state_dim, init_memory_key, init_memory_value,
                 name='DSCMN'):
        print('Initializing memory..')
        self.name = name
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        # One memory object per matrix; the value memory is the one
        # mutated over time via write().
        self.key = DSCMN_Memory(self.memory_size,
                                self.memory_key_state_dim,
                                name=self.name+'_key_matrix')
        self.value = DSCMN_Memory(self.memory_size,
                                  self.memory_value_state_dim,
                                  name=self.name+'_value_matrix')
        self.memory_key = init_memory_key
        self.memory_value = init_memory_value

    def attention(self, q_embedded):
        """Correlation weights of the question embedding over the key memory."""
        return self.key.cor_weight(embedded=q_embedded,
                                   key_matrix=self.memory_key)

    def read(self, c_weight):
        """Read a summary vector from the value memory with the given weights."""
        return self.value.read(value_matrix=self.memory_value,
                               correlation_weight=c_weight)

    def write(self, c_weight, qa_embedded, reuse):
        """Apply the erase/add write to the value memory; stores and
        returns the updated matrix."""
        self.memory_value = self.value.write(value_matrix=self.memory_value,
                                             correlation_weight=c_weight,
                                             qa_embedded=qa_embedded,
                                             reuse=reuse)
        return self.memory_value