-
Notifications
You must be signed in to change notification settings - Fork 78
/
layerwise_attention.py
154 lines (136 loc) · 5.68 KB
/
layerwise_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# -*- coding: utf-8 -*-
# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Layer-Wise Attention Mechanism
================================
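Computes a parameterised scalar mixture of N tensors,
`mixture = gamma * sum(s_k * tensor_k)`
where `s = softmax(w)` (or `s = sparsemax(w)` when
`layer_transformation="sparsemax"`), with `w` and `gamma` learnable scalar
parameters.
If `layer_norm=True`, each tensor is first layer-normalized per batch
element, with the mean and variance computed only over the positions
selected by `mask`.
If `dropout > 0`, then during training each scalar weight has its
normalized weight mass set to 0 with the dropout probability (i.e., the
unnormalized weight is replaced by a very large negative value, -1e20,
which acts as -inf). This redistributes the dropped probability mass to
the remaining weights.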
Original implementation:
- https://github.com/Hyperparticle/udify
"""
from typing import List, Optional
import torch
from torch.nn import Parameter, ParameterList
class LayerwiseAttention(torch.nn.Module):
    """Learnable scalar mixture of the outputs of ``num_layers`` layers.

    ``forward`` returns ``gamma * sum_k(s_k * tensor_k)`` where ``s`` is the
    softmax (or sparsemax) of one learnable scalar weight per layer and
    ``gamma`` is a learnable global scale initialised to 1.0.

    Args:
        num_layers: Number of tensors every ``forward`` call must receive.
        layer_norm: If True, each input tensor is normalised per batch
            element (mean 0, variance 1, with statistics taken over the
            positions selected by ``mask``) before it is weighted.
        layer_weights: Initial unnormalised weight for each layer. Defaults
            to all zeros, i.e. uniform mixing at initialisation.
        dropout: If truthy, then during training each layer weight is
            replaced by ``-1e20`` with this probability, moving its mass to
            the remaining layers.
        layer_transformation: ``"sparsemax"`` normalises the weights with
            ``entmax.sparsemax`` (imported lazily); any other value uses
            ``torch.softmax``.

    Raises:
        ValueError: If ``layer_weights`` does not have ``num_layers`` entries.
    """

    def __init__(
        self,
        num_layers: int,
        layer_norm: bool = False,
        layer_weights: Optional[List[float]] = None,
        dropout: Optional[float] = None,
        layer_transformation: str = "softmax",
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.layer_norm = layer_norm
        self.dropout = dropout

        self.transform_fn = torch.softmax
        if layer_transformation == "sparsemax":
            # Optional dependency: only needed when sparsemax is requested.
            from entmax import sparsemax

            self.transform_fn = sparsemax

        if layer_weights is None:
            layer_weights = [0.0] * num_layers
        elif len(layer_weights) != num_layers:
            raise ValueError(
                "Length of layer_weights {} differs from num_layers {}".format(
                    layer_weights, num_layers
                )
            )

        # One learnable scalar (shape (1,)) per layer.
        self.scalar_parameters = ParameterList(
            [
                Parameter(
                    torch.tensor([weight], dtype=torch.float32),
                    requires_grad=True,
                )
                for weight in layer_weights
            ]
        )
        self.gamma = Parameter(
            torch.tensor([1.0], dtype=torch.float32), requires_grad=True
        )

        if self.dropout:
            # Registered as buffers so they follow the module's device and
            # stay in the state_dict. dropout_fill is the logit that stands
            # in for a dropped layer; dropout_mask is only a template for
            # drawing noise of the right shape/device in forward().
            self.register_buffer("dropout_mask", torch.zeros(num_layers))
            self.register_buffer("dropout_fill", torch.full((num_layers,), -1e20))

    @staticmethod
    def _masked_layer_norm(
        tensor: torch.Tensor, broadcast_mask: torch.Tensor
    ) -> torch.Tensor:
        """Normalise each batch element of ``tensor`` to mean 0, variance 1.

        Mean and (population) variance are computed only over positions where
        ``broadcast_mask`` is non-zero. The normalisation is then applied to
        every position of the unmasked input.

        Args:
            tensor: 3-D tensor, presumably ``(batch, seq_len, hidden)``.
            broadcast_mask: Float mask of shape ``(batch, seq_len, 1)``.

        Returns:
            Tensor with the same shape as ``tensor``.
        """
        input_dim = tensor.size(-1)
        # Number of unmasked scalars per batch element, shaped (batch, 1, 1).
        num_elements = (broadcast_mask.sum(dim=(1, 2)) * input_dim).view(-1, 1, 1)
        tensor_masked = tensor * broadcast_mask
        mean = tensor_masked.sum(dim=(1, 2), keepdim=True) / num_elements
        centered = (tensor_masked - mean) * broadcast_mask
        variance = (centered ** 2).sum(dim=(1, 2), keepdim=True) / num_elements
        return (tensor - mean) / torch.sqrt(variance + 1e-12)

    def forward(
        self,
        tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mix ``tensors`` into a single tensor.

        Args:
            tensors: Exactly ``num_layers`` tensors of matching shape. With
                ``layer_norm`` enabled they must be 3-D, presumably
                ``(batch, seq_len, hidden)``.
            mask: Shaped like the first two dimensions of each tensor; selects
                the positions used for the layer-norm statistics. Ignored when
                ``layer_norm`` is False. If None, all positions are used.

        Returns:
            ``gamma * sum_k(s_k * tensor_k)``, shaped like each input.

        Raises:
            ValueError: If ``len(tensors) != num_layers``.
        """
        if len(tensors) != self.num_layers:
            raise ValueError(
                "{} tensors were passed, but the module was initialized to "
                "mix {} tensors.".format(len(tensors), self.num_layers)
            )

        scalar_parameters = list(self.scalar_parameters)
        # Workaround for https://github.com/pytorch/pytorch/issues/36035:
        # parameters may not be copied correctly across GPUs, leaving the
        # ParameterList incomplete; fall back to plain values on the module.
        # NOTE(review): ``self.weights`` / ``self.gamma_value`` are not set in
        # this class and must be assigned externally before this path runs.
        if len(scalar_parameters) != self.num_layers:
            device = tensors[0].device
            weights = torch.tensor(self.weights, device=device)
            gamma = torch.tensor(self.gamma_value, device=device)
        else:
            weights = torch.cat(scalar_parameters)
            gamma = self.gamma

        if self.training and self.dropout:
            # Draw fresh noise instead of overwriting the registered buffer in
            # place, so random values are never written into checkpoints.
            keep = torch.rand_like(self.dropout_mask) > self.dropout
            weights = torch.where(keep, weights, self.dropout_fill)

        normed_weights = self.transform_fn(weights, dim=0)
        # One shape-(1,) weight per layer; broadcasts against each tensor.
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)

        if self.layer_norm:
            if mask is None:
                # No mask given: every position contributes to the statistics.
                mask = torch.ones(tensors[0].shape[:2], device=tensors[0].device)
            broadcast_mask = mask.float().unsqueeze(-1)
            tensors = [
                self._masked_layer_norm(tensor, broadcast_mask) for tensor in tensors
            ]

        return gamma * sum(weight * tensor for weight, tensor in zip(normed_weights, tensors))