-
Notifications
You must be signed in to change notification settings - Fork 326
/
adapter_modeling.py
471 lines (363 loc) · 18.4 KB
/
adapter_modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
import math
import torch
from torch import nn
class Activation_Function_Class(nn.Module):
    """
    Wrap an element-wise activation function selected by name.

    Supported names (case-insensitive): "relu", "tanh", "swish", "gelu", "leakyrelu".

    Args:
        hidden_act (str): name of the activation function.

    Raises:
        ValueError: if ``hidden_act`` is not one of the supported names.
    """

    def __init__(self, hidden_act):
        # Initialise nn.Module bookkeeping before assigning any attribute.
        super().__init__()
        name = hidden_act.lower()
        if name == "relu":
            self.f = nn.functional.relu
        elif name == "tanh":
            self.f = torch.tanh
        elif name == "swish":
            self.f = self._swish
        elif name == "gelu":
            self.f = self._gelu_new
        elif name == "leakyrelu":
            self.f = nn.functional.leaky_relu
        else:
            # Fail fast: an unknown name used to leave `self.f` unset, so the error only
            # appeared later as an AttributeError inside `forward`.
            raise ValueError(
                f"Unsupported activation function: {hidden_act!r}. "
                "Expected one of 'relu', 'tanh', 'swish', 'gelu', 'leakyrelu'."
            )

    @staticmethod
    def _swish(x):
        """Swish / SiLU: ``x * sigmoid(x)``."""
        return x * torch.sigmoid(x)

    @staticmethod
    def _gelu_new(x):
        """
        Tanh approximation of GELU, as in the Google BERT repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
        """
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def forward(self, x):
        """Apply the selected activation element-wise to ``x``."""
        return self.f(x)
class Adapter(nn.Module):
    """
    A single bottleneck adapter block.

    The input is optionally layer-normed, projected down to ``down_sample`` features, passed
    through a non-linearity and projected back up to ``input_size`` features. ``forward`` then
    adds ``residual_input`` either before or after an optional output LayerNorm.

    Args:
        input_size (int): feature size of the adapter input and output.
        down_sample (int, optional): bottleneck size; defaults to ``input_size // 2``.
        non_linearity (str): activation name passed to ``Activation_Function_Class``.
        init_bert_weights (bool): if True, re-initialise the projection layers with
            ``Adapter.init_bert_weights``.
        add_layer_norm_before (bool): apply a LayerNorm to the input before the down projection.
        add_layer_norm_after (bool): apply a LayerNorm to the output.
        residual_before_ln (bool): add the residual before (True) or after (False) the output
            LayerNorm.
    """

    def __init__(
        self,
        input_size,
        down_sample=None,
        non_linearity="relu",
        init_bert_weights=True,
        add_layer_norm_before=True,
        add_layer_norm_after=False,
        residual_before_ln=True,
    ):
        super().__init__()
        self.input_size = input_size
        self.add_layer_norm_before = add_layer_norm_before
        self.add_layer_norm_after = add_layer_norm_after
        self.residual_before_ln = residual_before_ln

        # Modules of the down-projection path, in order; wrapped in nn.Sequential below.
        seq_list = []

        # Optional LayerNorm on the adapter input.
        if self.add_layer_norm_before:
            self.adapter_norm_before = nn.LayerNorm(self.input_size)
            seq_list.append(self.adapter_norm_before)

        # The bottleneck size defaults to half of the input size.
        self.down_sample = down_sample
        if down_sample is None:
            self.down_sample = self.input_size // 2

        # Down projection followed by the non-linearity.
        # TODO: accept an activation module directly instead of a string name.
        seq_list.append(nn.Linear(self.input_size, self.down_sample))
        self.non_linearity = Activation_Function_Class(non_linearity.lower())
        seq_list.append(self.non_linearity)
        self.adapter_down = nn.Sequential(*seq_list)

        # Up projection back to the input size; the residual connection is added in `forward`.
        self.adapter_up = nn.Linear(self.down_sample, self.input_size)

        # Optional LayerNorm on the output; `residual_before_ln` decides whether the residual
        # is added before or after it (see `forward`).
        if self.add_layer_norm_after:
            self.adapter_norm_after = nn.LayerNorm(self.input_size)

        # Here `init_bert_weights` is the boolean argument, while `self.init_bert_weights`
        # resolves to the static method of the same name.
        if init_bert_weights:
            self.adapter_down.apply(self.init_bert_weights)
            self.adapter_up.apply(self.init_bert_weights)

    def forward(self, x, residual_input):
        """
        Run the adapter.

        Args:
            x: adapter input (last axis of size ``input_size``).
            residual_input: tensor added to the up-projection as the residual connection.

        Returns:
            tuple ``(output, down, up)``: the final output, the bottleneck activation and the
            up-projection before the residual/LayerNorm.
        """
        down = self.adapter_down(x)
        up = self.adapter_up(down)
        output = up

        # residual_before_ln=True:  output = LN(up + residual)   (LN only if add_layer_norm_after)
        # residual_before_ln=False: output = LN(up) + residual
        if self.residual_before_ln:
            output = output + residual_input
        if self.add_layer_norm_after:
            output = self.adapter_norm_after(output)
        if not self.residual_before_ln:
            output = output + residual_input

        return output, down, up

    # Copied from the BERT model so that this class is self-contained.
    # TODO: pass the BERT model's initializer instead of duplicating it.
    @staticmethod
    def init_bert_weights(module):
        """
        Initialise ``module`` in place with the BERT scheme.

        Linear/Embedding weights ~ N(0, 0.02); Linear biases are zeroed; LayerNorm weight is set
        to 1 and bias to 0 (when the LayerNorm has learnable parameters).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            # TODO: std is fixed to 0.02 instead of config.initializer_range.
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class BertFusion(nn.Module):
    """
    Attention-based fusion of several adapter outputs (AdapterFusion).

    For every token, the query vector attends over the per-adapter key vectors, and the resulting
    softmax weights mix the per-adapter value vectors. Which of the query/key/value projections
    exist is controlled by ``config.adapter_fusion``.

    Args:
        config: object exposing ``hidden_size``, ``output_attentions``,
            ``attention_probs_dropout_prob`` and an ``adapter_fusion`` dict with the boolean
            entries ``query``, ``key``, ``value``, ``value_initialized``, ``temperature``,
            ``residual_before`` and ``value_before_softmax``.
    """

    def __init__(self, config):
        super(BertFusion, self).__init__()
        self.config = config
        self.output_attentions = config.output_attentions
        self.dense_size = int(config.hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        fusion = self.config.adapter_fusion
        if not fusion["query"] and not fusion["key"] and not fusion["value"]:
            # NOTE(review): `self.dense` is not used in `forward`; presumably kept so existing
            # checkpoints still load — confirm before removing.
            self.dense = nn.Linear(self.dense_size, 1)

        if fusion["query"]:
            self.query = nn.Linear(int(config.hidden_size), self.dense_size)
            self.query.apply(Adapter.init_bert_weights)

        if fusion["key"]:
            self.key = nn.Linear(self.dense_size, self.dense_size)
            self.key.apply(Adapter.init_bert_weights)

        if fusion["value"]:
            self.value = nn.Linear(int(config.hidden_size), int(config.hidden_size), bias=False)
            self.value.apply(Adapter.init_bert_weights)
            if fusion["value_initialized"]:
                # Start the value projection as (almost) the identity matrix: 1 on the diagonal,
                # 1e-6 elsewhere.
                self.value.weight.data = (
                    torch.zeros(int(config.hidden_size), int(config.hidden_size)) + 0.000001
                ).fill_diagonal_(1.0)

        # Softmax temperature: 50 if enabled (else 1), lowered by `reduction` on every forward
        # call until it reaches 1.
        if fusion["temperature"]:
            self.T = 50.0
        else:
            self.T = 1.0
        self.reduction = self.T / 1000.0

    def forward(self, query, key, value, residual):
        """
        Fuse per-adapter outputs with dot-product attention.

        Args:
            query: token representations, (batch, seq, features).
            key: per-adapter keys, (batch, seq, n_adapters, features).
            value: per-adapter values, (batch, seq, n_adapters, hidden). Not modified.
            residual: (batch, seq, hidden); added to every adapter value before attention when
                ``adapter_fusion["residual_before"]`` is set, otherwise to the fused result.

        Returns:
            The fused representation, (batch, seq, hidden).
        """
        fusion = self.config.adapter_fusion

        if fusion["residual_before"]:
            # Out-of-place add (broadcast over the adapter axis): the former in-place
            # `value += ...` silently mutated the caller's tensor.
            value = value + residual[:, :, None, :]

        query_layer = self.query(query) if fusion["query"] else query
        key_layer = self.key(key) if fusion["key"] else key
        if fusion["value"] and fusion["value_before_softmax"]:
            value_layer = self.value(value)
        else:
            value_layer = value

        # Dot product of each token's query with each adapter's key -> (batch, seq, n_adapters).
        attention_scores = torch.squeeze(
            torch.matmul(query_layer.unsqueeze(2), key_layer.transpose(-2, -1)), dim=2
        )
        attention_scores = self.dropout(attention_scores)

        # Normalize the scores to probabilities over the adapters.
        attention_probs = torch.softmax(attention_scores / self.T, dim=-1)
        # NOTE(review): the temperature is annealed on every call, including in eval mode.
        self.T = max(self.T - self.reduction, 1.0)

        if not self.training:
            # Keep the most recent attention weights for inspection.
            self.recent_attention = attention_probs.detach().cpu().numpy()

        context_layer = torch.squeeze(torch.matmul(attention_probs.unsqueeze(2), value_layer), dim=2)

        if fusion["value"] and not fusion["value_before_softmax"]:
            context_layer = self.value(context_layer)

        if not fusion["residual_before"]:
            context_layer = context_layer + residual

        return context_layer
class AdapterFusionSentLvlDynamic(nn.Module):
    """
    Sentence-level AdapterFusion: one set of fusion weights per sequence.

    Key (and, when used, query) token representations are averaged over the non-masked tokens
    of each sequence. The pooled vectors give one softmax distribution over adapters per
    sequence, which is then used to mix the per-token adapter ``value`` outputs.

    Supported ``adapter_fusion`` combinations (query, key, value): (T, F, F), (T, T, F),
    (T, T, T) and (F, F, F).

    Args:
        config: object exposing ``hidden_size``,
            ``text_task_adapter_config["reduction_factor"]`` and an ``adapter_fusion`` dict with
            the boolean entries ``query``, ``key``, ``value`` and ``temperature``.
        n_tasks: not used by this class.
    """

    def __init__(self, config, n_tasks):
        super(AdapterFusionSentLvlDynamic, self).__init__()
        self.config = config
        # Width of the key features and of the query projection: hidden size divided by the
        # adapter reduction factor.
        self.dense_size = int(config.hidden_size) // config.text_task_adapter_config["reduction_factor"]

        fusion = self.config.adapter_fusion
        if not fusion["query"] and not fusion["key"] and not fusion["value"]:
            # Without projections, a linear layer scores each adapter's pooled key directly.
            self.dense = nn.Linear(self.dense_size, 1)
        if fusion["query"]:
            self.query = nn.Linear(int(config.hidden_size), self.dense_size)
        if fusion["key"]:
            self.key = nn.Linear(self.dense_size, self.dense_size)
        if fusion["value"]:
            self.value = nn.Linear(int(config.hidden_size), int(config.hidden_size))

        # Softmax temperature: 50 if enabled (else 1), lowered by `reduction` on every forward
        # call until it reaches 1.
        if fusion["temperature"]:
            self.T = 50.0
        else:
            self.T = 1.0
        self.reduction = self.T / 1000.0

    @staticmethod
    def _masked_mean(x, keep):
        """Average ``x`` over axis 1 using only the positions where ``keep`` (batch, seq) is 1."""
        expand = (1,) * (x.dim() - 2)
        weights = keep.reshape(tuple(keep.shape) + expand)
        length = keep.sum(dim=1).reshape((-1,) + expand)
        # NOTE: a sequence with no kept tokens divides by zero (NaN), as before.
        return torch.sum(x * weights, dim=1) / length

    def forward(self, query, key, value, attention_mask):
        """
        Mix the adapter outputs with one attention distribution per sequence.

        Args:
            query: token representations, (batch, seq, hidden); only used when
                ``adapter_fusion["query"]`` is set.
            key: per-adapter keys, (batch, seq, n_adapters, dense_size).
            value: per-adapter outputs, (batch, seq, n_adapters, features).
            attention_mask: mask where 0 marks tokens to keep, holding batch * seq elements —
                presumably the (batch, 1, 1, seq) extended mask; TODO confirm with callers.

        Returns:
            (batch, seq, features): ``value`` mixed over the adapter axis.

        Raises:
            ValueError: for a query/key/value combination this module does not implement.
        """
        fusion = self.config.adapter_fusion
        use_query, use_key, use_value = fusion["query"], fusion["key"], fusion["value"]

        # 1.0 for tokens to keep, 0.0 otherwise, shaped (batch, seq). `reshape` (instead of an
        # unargumented `squeeze()`) keeps the batch axis when the batch size is 1.
        keep = (attention_mask == 0).float().to(key.device).reshape(key.size(0), -1)
        key_sent = self._masked_mean(key, keep)

        if use_query and not use_value:
            query_enc = self.query(self._masked_mean(query, keep))
            key_enc = self.key(key_sent) if use_key else key_sent
            scores = torch.matmul(key_enc, query_enc[:, :, None]).squeeze(-1)
            mixed_value = value
        elif use_query and use_key and use_value:
            query_enc = self.query(self._masked_mean(query, keep))
            key_enc = self.key(key_sent)
            scores = torch.matmul(key_enc, query_enc[:, :, None]).squeeze(-1)
            mixed_value = self.value(value)
        elif not (use_query or use_key or use_value):
            scores = self.dense(key_sent).squeeze(-1)
            mixed_value = value
        else:
            raise ValueError(
                "Unsupported adapter_fusion configuration: "
                f"query={use_query}, key={use_key}, value={use_value}"
            )

        # One distribution over adapters per sequence: (batch, n_adapters).
        probs = torch.softmax(scores / self.T, dim=-1)
        # (batch, 1, 1, n) @ (batch, seq, n, f) -> (batch, seq, 1, f) -> (batch, seq, f)
        result = torch.matmul(probs[:, None, None, :], mixed_value).squeeze(2)

        self.T = max(self.T - self.reduction, 1.0)
        return result
def get_subnet_constructor(non_linearity, reduction_factor):
    """
    Return a factory for the small MLPs used inside the coupling blocks.

    The returned callable takes ``(dims_in, dims_out)`` and builds
    ``Linear(dims_in, dims_in // reduction_factor) -> activation -> Linear(..., dims_out)``.

    Args:
        non_linearity (str): activation name passed to ``Activation_Function_Class``.
        reduction_factor (int): divisor applied to ``dims_in`` to get the hidden width.
    """

    def build_subnet(dims_in, dims_out):
        hidden = dims_in // reduction_factor
        layers = [
            nn.Linear(dims_in, hidden),
            Activation_Function_Class(non_linearity),
            nn.Linear(hidden, dims_out),
        ]
        return nn.Sequential(*layers)

    return build_subnet
class NICECouplingBlock(nn.Module):
    """
    Additive coupling block following the NICE design.

    The last (feature) axis of the input is split into ``x1`` (``split_len1`` features) and
    ``x2`` (``split_len2`` features). Each half is shifted by an MLP of the other half, optionally
    concatenated with condition tensors ``c`` on the feature axis. ``rev=True`` computes the exact
    inverse. Being purely additive, the transform's log-Jacobian determinant is 0.

    Args:
        dims_in: list whose first entry is the input shape; ``dims_in[0][0]`` is the feature size.
        dims_c: optional list of condition shapes; ``dims_c[i][0]`` is the feature size of
            condition ``i``.
        non_linearity (str): activation used inside the sub-networks.
        reduction_factor (int): hidden-width divisor for the sub-networks.
    """

    def __init__(self, dims_in, dims_c=None, non_linearity="relu", reduction_factor=2):
        super().__init__()
        # None instead of a shared mutable default list.
        dims_c = [] if dims_c is None else dims_c
        channels = dims_in[0][0]
        self.split_len1 = channels // 2
        self.split_len2 = channels - channels // 2

        # Compare as tuples so list- and tuple-typed shapes agree (as GLOWCouplingBlock does).
        assert all(
            tuple(d[1:]) == tuple(dims_in[0][1:]) for d in dims_c
        ), "Dimensions of input and one or more conditions don't agree."
        self.conditional = len(dims_c) > 0
        condition_length = sum(d[0] for d in dims_c)

        subnet_constructor = get_subnet_constructor(non_linearity, reduction_factor)
        # F shifts the first half from the second half; G shifts the second half from the first.
        self.F = subnet_constructor(self.split_len2 + condition_length, self.split_len1)
        self.G = subnet_constructor(self.split_len1 + condition_length, self.split_len2)

    def _with_conditions(self, h, c):
        """Append the condition tensors to ``h`` along the feature (last) axis when conditional."""
        return torch.cat([h, *c], -1) if self.conditional else h

    def forward(self, x, c=None, rev=False):
        """
        Apply the coupling transform (``rev=False``) or its inverse (``rev=True``).

        Args:
            x: input whose last axis holds ``split_len1 + split_len2`` features.
            c: optional list of condition tensors, concatenated on the feature axis.
            rev: compute the inverse transform.

        Returns:
            Tensor with the same shape as ``x``.
        """
        c = [] if c is None else c
        x1, x2 = x[..., : self.split_len1], x[..., self.split_len1 :]

        if not rev:
            y1 = x1 + self.F(self._with_conditions(x2, c))
            y2 = x2 + self.G(self._with_conditions(y1, c))
        else:
            # Undo the forward steps in reverse order.
            y2 = x2 - self.G(self._with_conditions(x1, c))
            y1 = x1 - self.F(self._with_conditions(y2, c))

        return torch.cat((y1, y2), -1)

    def jacobian(self, x, rev=False):
        """Log-determinant of the Jacobian; always 0 for this additive coupling."""
        return 0

    def output_dims(self, input_dims):
        """The block is shape-preserving; only a single input is supported."""
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims
class GLOWCouplingBlock(nn.Module):
    """
    Affine coupling block following the GLOW design.

    The only difference to the RealNVP coupling blocks is that a single sub-network jointly
    predicts ``[s_i, t_i]`` for each half instead of two separate sub-networks, which reduces
    computational cost. The multiplicative term is soft-clamped as
    ``exp(clamp * 0.636 * atan(s / clamp))`` (0.636 ~ 2/pi), so the amplification or attenuation
    of each dimension is at most about ``exp(clamp)``.

    The last (feature) axis of the input is split into ``x1`` (``split_len1`` features) and
    ``x2`` (``split_len2`` features).

    Args:
        dims_in: list whose first entry is the input shape; ``dims_in[0][0]`` is the feature size.
        dims_c: optional list of condition shapes; ``dims_c[i][0]`` is the feature size of
            condition ``i``.
        non_linearity (str): activation used inside the sub-networks.
        reduction_factor (int): hidden-width divisor for the sub-networks.
        clamp (float): soft-clamping constant for the scale term.
    """

    def __init__(self, dims_in, dims_c=None, non_linearity="relu", reduction_factor=2, clamp=5.0):
        super().__init__()
        # None instead of a shared mutable default list.
        dims_c = [] if dims_c is None else dims_c
        channels = dims_in[0][0]
        self.ndims = len(dims_in[0])
        self.split_len1 = channels // 2
        self.split_len2 = channels - channels // 2

        self.clamp = clamp
        self.max_s = math.exp(clamp)
        self.min_s = math.exp(-clamp)

        assert all(
            tuple(d[1:]) == tuple(dims_in[0][1:]) for d in dims_c
        ), f"Dimensions of input and one or more conditions don't agree: {dims_c} vs {dims_in}."
        self.conditional = len(dims_c) > 0
        condition_length = sum(d[0] for d in dims_c)

        subnet_constructor = get_subnet_constructor(non_linearity, reduction_factor)
        # s1: first half (+ conditions) -> [scale, shift] for the second half.
        # s2: second half (+ conditions) -> [scale, shift] for the first half.
        self.s1 = subnet_constructor(self.split_len1 + condition_length, self.split_len2 * 2)
        self.s2 = subnet_constructor(self.split_len2 + condition_length, self.split_len1 * 2)

    def e(self, s):
        """Soft-clamped multiplicative factor ``exp(log_e(s))``."""
        return torch.exp(self.clamp * 0.636 * torch.atan(s / self.clamp))

    def log_e(self, s):
        """Logarithm of ``e(s)``; contributes to the log-Jacobian."""
        return self.clamp * 0.636 * torch.atan(s / self.clamp)

    def _with_conditions(self, h, c):
        """Append the condition tensors to ``h`` along the feature (last) axis when conditional."""
        return torch.cat([h, *c], -1) if self.conditional else h

    @staticmethod
    def _split(r, n):
        """Split a sub-network output on the feature axis into (scale, shift)."""
        return r[..., :n], r[..., n:]

    @staticmethod
    def _sum_features(t):
        """Sum over every axis except the batch axis (one value per sample)."""
        return torch.sum(t, dim=tuple(range(1, t.dim())))

    def forward(self, x, c=None, rev=False):
        """
        Apply the coupling transform (``rev=False``) or its inverse (``rev=True``).

        Also stores the log-Jacobian determinant of the applied transform in ``self.last_jac``
        (one value per sample).

        Args:
            x: input whose last axis holds ``split_len1 + split_len2`` features.
            c: optional list of condition tensors, concatenated on the feature axis.
            rev: compute the inverse transform.

        Returns:
            A one-element list containing the output tensor (same shape as ``x``), the same
            container type the original implementation returned.
        """
        c = [] if c is None else c
        x1, x2 = x[..., : self.split_len1], x[..., self.split_len1 :]

        if not rev:
            # s2/t2 must come from `self.s2(x2)` (not copies of x1/x2) so that the rev=True
            # branch below inverts this transform exactly.
            s2, t2 = self._split(self.s2(self._with_conditions(x2, c)), self.split_len1)
            y1 = self.e(s2) * x1 + t2
            s1, t1 = self._split(self.s1(self._with_conditions(y1, c)), self.split_len2)
            y2 = self.e(s1) * x2 + t1
            self.last_jac = self._sum_features(self.log_e(s1)) + self._sum_features(self.log_e(s2))
        else:  # names of x and y are swapped!
            s1, t1 = self._split(self.s1(self._with_conditions(x1, c)), self.split_len2)
            y2 = (x2 - t1) / self.e(s1)
            s2, t2 = self._split(self.s2(self._with_conditions(y2, c)), self.split_len1)
            y1 = (x1 - t2) / self.e(s2)
            self.last_jac = -self._sum_features(self.log_e(s1)) - self._sum_features(self.log_e(s2))

        return [torch.cat((y1, y2), -1)]

    def jacobian(self, x, c=None, rev=False):
        """Return the log-Jacobian determinant stored by the most recent ``forward`` call."""
        return self.last_jac

    def output_dims(self, input_dims):
        """The block is shape-preserving."""
        return input_dims