
Inefficient _layer_norm implementation in layerwise_attention.py #122

Closed
K024 opened this issue Mar 21, 2023 · 1 comment
Labels
bug Something isn't working

Comments


K024 commented Mar 21, 2023

🐛 Bug

The function iterates along the batch_size axis in Python, which is slow.

Code:

def _layer_norm(tensor, broadcast_mask, mask):
    tensor_masked = tensor * broadcast_mask
    batch_size, _, input_dim = tensors[0].size()
    num_elements_not_masked = torch.tensor(
        [mask[i].sum() * input_dim for i in range(batch_size)],
        device=tensor.device,
    )
    # mean for each sentence
    mean = torch.sum(torch.sum(tensor_masked, dim=2), dim=1)
    mean = mean / num_elements_not_masked
    variance = torch.vstack(
        [
            torch.sum(((tensor_masked[i] - mean[i]) * broadcast_mask[i]) ** 2)
            / num_elements_not_masked[i]
            for i in range(batch_size)
        ]
    )
    normalized_tensor = torch.vstack(
        [
            ((tensor[i] - mean[i]) / torch.sqrt(variance[i] + 1e-12)).unsqueeze(0)
            for i in range(batch_size)
        ]
    )
    return normalized_tensor
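
For a rough sense of the cost, here is a minimal timing sketch (my own harness, not from the issue; shapes are arbitrary). It isolates the variance term and compares the per-example loop against a single fused reduction doing the same arithmetic:

import time
import torch

B, L, D = 64, 256, 1024
x = torch.randn(B, L, D)
m = torch.ones(B, L, 1)               # broadcast mask, all ones for simplicity
mean = torch.zeros(B)                 # stand-in per-sentence means
n = torch.full((B,), float(L * D))    # elements per sentence

t0 = time.perf_counter()
v_loop = torch.vstack(
    [torch.sum(((x[i] - mean[i]) * m[i]) ** 2) / n[i] for i in range(B)]
)
t1 = time.perf_counter()
v_vec = (((x - mean.view(B, 1, 1)) * m) ** 2).view(B, -1).sum(1) / n
t2 = time.perf_counter()

print(torch.allclose(v_loop.squeeze(1), v_vec))            # expect True
print(f"loop: {t1 - t0:.4f}s  vectorized: {t2 - t1:.4f}s")

The arithmetic is identical; the gap comes from Python-level iteration and one kernel dispatch per example instead of a single fused reduction (on GPU, the torch.tensor([...]) construction in the original additionally forces host-device synchronization).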

This also prevents wmt20-comet-qe-da from being exported to ONNX format, since export traces the model and the Python loops over range(batch_size) are unrolled for one fixed batch size.

To Reproduce

N/A

Expected behaviour

Use .view to create a proper shape for broadcasting.

Preferred code:

def _layer_norm(tensor, broadcast_mask, mask):
    tensor_masked = tensor * broadcast_mask
    batch_size, _, input_dim = tensors[0].size()

    # mean for each sentence
    num_elements_not_masked = mask.sum(1) * input_dim
    mean = tensor_masked.view(batch_size, -1).sum(1)
    mean = (mean / num_elements_not_masked).view(batch_size, 1, 1)

    variance = (
        ((tensor_masked - mean) * broadcast_mask) ** 2
    ).view(batch_size, -1).sum(1) / num_elements_not_masked
    normalized_tensor = (tensor - mean) / torch.sqrt(variance + 1e-12).view(
        batch_size, 1, 1
    )
    return normalized_tensor
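
To check that the rewrite is a drop-in replacement, a quick equivalence sketch (assumptions: the two versions above are bound to the hypothetical names layer_norm_loop and layer_norm_vectorized, with tensors[0].size() replaced by tensor.size() so each runs standalone):

import torch

B, L, D = 8, 32, 16
tensor = torch.randn(B, L, D)
mask = (torch.rand(B, L) > 0.2).long()       # 1 = real token, 0 = padding
mask[:, 0] = 1                               # keep at least one token per row
broadcast_mask = mask.unsqueeze(-1).float()  # (B, L, 1), broadcasts to (B, L, D)

out_loop = layer_norm_loop(tensor, broadcast_mask, mask)
out_vec = layer_norm_vectorized(tensor, broadcast_mask, mask)
print(torch.allclose(out_loop, out_vec, atol=1e-5))   # expect True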

Screenshots

N/A

Environment

N/A

Additional context

N/A

@K024 K024 added the bug Something isn't working label Mar 21, 2023
ricardorei pushed a commit that referenced this issue Mar 21, 2023
@ricardorei
Collaborator

This is a good suggestion. I tested it and everything worked well; it's currently in master.
