-
Notifications
You must be signed in to change notification settings - Fork 25
/
attention_models.py
247 lines (200 loc) · 10 KB
/
attention_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Author: Hamdi Altaheri
"""
#%%
import math
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense
from tensorflow.keras.layers import multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from tensorflow.keras.layers import Dropout, MultiHeadAttention, LayerNormalization, Reshape
from tensorflow.keras import backend as K
#%% Create and apply the attention model
def attention_block(in_layer, attention_model, ratio=8, residual = False, apply_to_input=True):
    """Create the requested attention model and apply it to `in_layer`.

    Supported values of `attention_model`:
        'mha'  : multi-head self-attention (`mha_block` with vanilla=True)
        'mhla' : multi-head local self-attention (`mha_block` with vanilla=False)
        'se'   : squeeze-and-excitation (`se_block`)
        'cbam' : convolutional block attention module (`cbam_block`)

    The 'mha'/'mhla' blocks are fed a rank-3 tensor: a higher-rank input is
    reshaped to (in_sh[1], -1) first. The 'se'/'cbam' blocks are fed a rank-4
    tensor: a lower-rank input gets an extra axis inserted at position 2.
    After the block runs, the output is brought back to the rank of the input.

    Args:
        in_layer: input tensor; the code handles rank 3 and rank 4 inputs.
        attention_model: one of 'mha', 'mhla', 'se', 'cbam'.
        ratio: reduction ratio, forwarded to `se_block` / `cbam_block` only.
        residual: residual flag, forwarded to `se_block` / `cbam_block` only.
        apply_to_input: forwarded to `se_block` only.

    Returns:
        The output tensor of the selected attention block.

    Raises:
        ValueError: if `attention_model` is not one of the supported names.
    """
    # Fail fast on an unknown model name, before any layer is created.
    # ValueError is a subclass of Exception, so existing handlers still work.
    if attention_model not in ('mha', 'mhla', 'se', 'cbam'):
        raise ValueError("'{}' is not supported attention module!".format(attention_model))

    in_sh = in_layer.shape  # dimensions of the input tensor
    in_len = len(in_sh)
    expanded_axis = 2  # default = 2

    if attention_model in ('mha', 'mhla'):
        # Multi-head (local) self-attention works on a rank-3 tensor.
        if in_len > 3:
            in_layer = Reshape((in_sh[1], -1))(in_layer)
        out_layer = mha_block(in_layer, vanilla=(attention_model == 'mha'))
    else:
        # Squeeze-and-excitation and CBAM work on a rank-4 tensor.
        if in_len < 4:
            in_layer = tf.expand_dims(in_layer, axis=expanded_axis)
        if attention_model == 'se':
            out_layer = se_block(in_layer, ratio, residual, apply_to_input)
        else:
            out_layer = cbam_block(in_layer, ratio=ratio, residual=residual)

    # Restore the rank the caller passed in.
    if in_len == 3 and len(out_layer.shape) == 4:
        out_layer = tf.squeeze(out_layer, expanded_axis)
    elif in_len == 4 and len(out_layer.shape) == 3:
        out_layer = Reshape((in_sh[1], in_sh[2], in_sh[3]))(out_layer)
    return out_layer
#%% Multi-head self Attention (MHA) block
def mha_block(input_feature, key_dim=8, num_heads=2, dropout = 0.5, vanilla = True):
    """Multi-head self-attention (MHA) block with a skip connection.

    Pipeline: layer normalization -> attention -> Dropout(0.3) -> add the
    block input back.

    Two attention variants are available:
        vanilla=True  : the original multi-head self-attention,
                        'Attention Is All You Need' https://arxiv.org/abs/1706.03762
        vanilla=False : the multi-head local self-attention,
                        'Vision Transformer for Small-Size Datasets'
                        https://arxiv.org/abs/2112.13492v1

    Args:
        input_feature: input tensor; its axis 1 is used as the sequence length
            when building the local-attention mask.
        key_dim, num_heads, dropout: forwarded to the attention layer.
        vanilla: selects the attention variant (see above).

    Returns:
        The sum of `input_feature` and the attention branch output.
    """
    normalized = LayerNormalization(epsilon=1e-6)(input_feature)

    if not vanilla:
        # Local self-attention: the mask is one everywhere except on the
        # diagonal, wrapped in a list to add a leading axis, and cast to int8.
        sequence_length = input_feature.shape[1]
        off_diagonal = 1 - tf.eye(sequence_length)
        off_diagonal = tf.cast([off_diagonal], dtype=tf.int8)
        local_attention = MultiHeadAttention_LSA(
            key_dim=key_dim, num_heads=num_heads, dropout=dropout)
        attended = local_attention(
            normalized, normalized, attention_mask=off_diagonal)
    else:
        global_attention = MultiHeadAttention(
            key_dim=key_dim, num_heads=num_heads, dropout=dropout)
        attended = global_attention(normalized, normalized)

    attended = Dropout(0.3)(attended)

    # Skip connection around the whole attention branch.
    return Add()([input_feature, attended])
#%% Multi head self Attention (MHA) block: Locality Self Attention (LSA)
class MultiHeadAttention_LSA(tf.keras.layers.MultiHeadAttention):
    """Multi-head attention layer with Locality Self Attention (LSA).

    Locality Self Attention is described in https://arxiv.org/abs/2112.13492v1
    This implementation is taken from https://keras.io/examples/vision/vit_small_ds/

    The layer overrides the parent's attention computation so that the query
    is scaled by a trainable temperature `tau`.
    """

    def __init__(self, **kwargs):
        """Build the parent layer, then create the trainable temperature."""
        super().__init__(**kwargs)
        # Trainable temperature, initialised to sqrt(key_dim).
        initial_temperature = math.sqrt(float(self._key_dim))
        self.tau = tf.Variable(initial_temperature, trainable=True)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        """Compute the attention output and the attention scores.

        Returns:
            A tuple (attention_output, attention_scores); the scores are the
            masked-softmax values before dropout is applied.
        """
        # Scale the query by the learned temperature.
        scaled_query = tf.multiply(query, 1.0 / self.tau)
        raw_scores = tf.einsum(self._dot_product_equation, key, scaled_query)
        attention_probs = self._masked_softmax(raw_scores, attention_mask)
        dropped_probs = self._dropout_layer(attention_probs, training=training)
        context = tf.einsum(self._combine_equation, dropped_probs, value)
        return context, attention_probs
#%% Squeeze-and-excitation block
def se_block(input_feature, ratio=8, residual = False, apply_to_input=True):
    """Squeeze-and-Excitation (SE) block.

    As described in https://arxiv.org/abs/1709.01507
    The implementation is taken from https://github.com/kobiso/CBAM-keras

    Args:
        input_feature: rank-4 feature map (it is passed to
            GlobalAveragePooling2D).
        ratio: channel reduction ratio of the bottleneck Dense layer; when it
            is 0 the bottleneck layer is skipped.
        residual: when True, `input_feature` is added to the result.
        apply_to_input: when True, `input_feature` is multiplied by the
            channel weights; otherwise the weights themselves are passed on.

    Returns:
        The re-weighted feature map, or the channel weights when
        `apply_to_input` is False (plus `input_feature` when `residual`).
    """
    channels_first = K.image_data_format() == "channels_first"
    n_channels = input_feature.shape[1 if channels_first else -1]

    # Squeeze: one descriptor per channel.
    weights = GlobalAveragePooling2D()(input_feature)
    weights = Reshape((1, 1, n_channels))(weights)
    assert weights.shape[1:] == (1,1,n_channels)

    # Excitation: optional bottleneck followed by a sigmoid gate.
    if ratio != 0:
        bottleneck = Dense(n_channels // ratio,
                           activation='relu',
                           kernel_initializer='he_normal',
                           use_bias=True,
                           bias_initializer='zeros')
        weights = bottleneck(weights)
        assert weights.shape[1:] == (1,1,n_channels//ratio)

    gate = Dense(n_channels,
                 activation='sigmoid',
                 kernel_initializer='he_normal',
                 use_bias=True,
                 bias_initializer='zeros')
    weights = gate(weights)
    assert weights.shape[1:] == (1,1,n_channels)

    # Move the channel axis to position 1 for channels-first data.
    if K.image_data_format() == 'channels_first':
        weights = Permute((3, 1, 2))(weights)

    output = multiply([input_feature, weights]) if apply_to_input else weights

    # Residual connection
    if not residual:
        return output
    return Add()([output, input_feature])
#%% Convolutional block attention module
def cbam_block(input_feature, ratio=8, residual = False):
    """Convolutional Block Attention Module (CBAM) block.

    As described in https://arxiv.org/abs/1807.06521
    The implementation is taken from https://github.com/kobiso/CBAM-keras

    Channel attention is applied first and spatial attention second.

    Args:
        input_feature: feature map handed to `channel_attention`.
        ratio: reduction ratio forwarded to `channel_attention`.
        residual: when True, `input_feature` is added to the refined output.

    Returns:
        The refined feature map.
    """
    channel_refined = channel_attention(input_feature, ratio)
    refined = spatial_attention(channel_refined)

    # Residual connection
    if not residual:
        return refined
    return Add()([input_feature, refined])
def channel_attention(input_feature, ratio=8):
    """Channel attention part of CBAM.

    Global average-pooled and global max-pooled descriptors are passed
    through the same two Dense layers, summed, squashed with a sigmoid and
    multiplied with `input_feature`.

    Args:
        input_feature: rank-4 feature map (it is passed to the 2D global
            pooling layers).
        ratio: channel reduction ratio of the first shared Dense layer.

    Returns:
        `input_feature` multiplied by the channel attention weights.
    """
    channels_first = K.image_data_format() == "channels_first"
    n_channels = input_feature.shape[1 if channels_first else -1]
    n_reduced = n_channels // ratio

    # Both pooling branches share these two layers (and hence their weights).
    shared_reduce = Dense(n_reduced,
                          activation='relu',
                          kernel_initializer='he_normal',
                          use_bias=True,
                          bias_initializer='zeros')
    shared_expand = Dense(n_channels,
                          kernel_initializer='he_normal',
                          use_bias=True,
                          bias_initializer='zeros')

    def pooled_branch(pooling_layer_cls):
        # Pool, reshape to (1, 1, channels), then apply the shared layers.
        branch = pooling_layer_cls()(input_feature)
        branch = Reshape((1,1,n_channels))(branch)
        assert branch.shape[1:] == (1,1,n_channels)
        branch = shared_reduce(branch)
        assert branch.shape[1:] == (1,1,n_reduced)
        branch = shared_expand(branch)
        assert branch.shape[1:] == (1,1,n_channels)
        return branch

    avg_branch = pooled_branch(GlobalAveragePooling2D)
    max_branch = pooled_branch(GlobalMaxPooling2D)

    weights = Add()([avg_branch, max_branch])
    weights = Activation('sigmoid')(weights)

    # Move the channel axis to position 1 for channels-first data.
    if channels_first:
        weights = Permute((3, 1, 2))(weights)

    return multiply([input_feature, weights])
def spatial_attention(input_feature, kernel_size=7):
    """Spatial attention part of CBAM.

    The mean and the max over the channel axis are concatenated into a
    two-channel map, a single-filter convolution with a sigmoid activation
    turns it into an attention map, and that map is multiplied with
    `input_feature`.

    Args:
        input_feature: rank-4 feature map (axis 3 is reduced after the data
            has been arranged channels-last).
        kernel_size: kernel size of the convolution that produces the
            attention map. Defaults to 7, the value that was previously
            hard-coded here.

    Returns:
        `input_feature` multiplied by the spatial attention map.
    """
    channels_first = K.image_data_format() == "channels_first"

    # The pooling and concatenation below work on axis 3, so channels-first
    # data is permuted to channels-last first.
    if channels_first:
        cbam_feature = Permute((2,3,1))(input_feature)
    else:
        cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2

    cbam_feature = Conv2D(filters = 1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    # Permute the attention map back so it matches channels-first input.
    if channels_first:
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])