sliding_chunks_no_overlap
ibeltagy committed Sep 12, 2020
1 parent 811dd10 commit 424e720
Showing 2 changed files with 63 additions and 7 deletions.
27 changes: 20 additions & 7 deletions longformer/longformer.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM


@@ -48,7 +49,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']


class LongformerSelfAttention(nn.Module):
@@ -80,8 +81,8 @@ def __init__(self, config, layer_id):
self.autoregressive = config.autoregressive
assert self.attention_window > 0
assert self.attention_dilation > 0
assert self.attention_mode in ['tvm', 'sliding_chunks']
if self.attention_mode == 'sliding_chunks':
assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
assert not self.autoregressive # not supported
assert self.attention_dilation == 1 # dilation is not supported

@@ -147,8 +148,12 @@ def forward(
q = q.float().contiguous()
k = k.float().contiguous()
attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
else: # "sliding_chunks"
elif self.attention_mode == "sliding_chunks":
attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
elif self.attention_mode == "sliding_chunks_no_overlap":
attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
else:
raise ValueError(f"unsupported attention_mode: {self.attention_mode}")
mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
@@ -162,10 +167,14 @@ def forward(
# diagonal mask with zeros everywhere and -inf inplace of padding
if self.attention_mode == 'tvm':
d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
else:
elif self.attention_mode == "sliding_chunks":
d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
elif self.attention_mode == "sliding_chunks_no_overlap":
d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)

attn_weights += d_mask
assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads, self.attention_window * 2 + 1]
assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]

# the extra attention
if extra_attention_mask is not None:
@@ -199,8 +208,12 @@ def forward(
if self.attention_mode == 'tvm':
v = v.float().contiguous()
attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
else: # "sliding_chunks"
elif self.attention_mode == "sliding_chunks":
attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
elif self.attention_mode == "sliding_chunks_no_overlap":
attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
else:
raise ValueError(f"unsupported attention_mode: {self.attention_mode}")

attn = attn.type_as(hidden_states)
assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
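For context, the new mode is selected the same way as the existing ones, through the config's attention_mode field. A minimal usage sketch follows; the local 'longformer-base-4096/' checkpoint path and the window value 170 are illustrative assumptions, not part of this commit:

from longformer.longformer import Longformer, LongformerConfig

config = LongformerConfig.from_pretrained('longformer-base-4096/')
config.attention_mode = 'sliding_chunks_no_overlap'
# per the 2/3 rule documented in sliding_chunks.py below, w=170 roughly matches w=256 of
# 'sliding_chunks'; a pretrained checkpoint was trained with its own window size, so
# overriding it here is only to illustrate the conversion
config.attention_window = [170] * config.num_hidden_layers
model = Longformer.from_pretrained('longformer-base-4096/', config=config)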
43 changes: 43 additions & 0 deletions longformer/sliding_chunks.py
@@ -131,3 +131,46 @@ def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
return input_ids, attention_mask


# ========= "sliding_chunks_no_overlap": alternative implementation of the sliding window attention =========
# This implementation uses non-overlapping chunks (or blocks) of size `w`, so each token attends to 3w local positions.
# (In "sliding_chunks" each token attends to roughly 2w positions, hence the 2/3 factor below.)
# To make this implementation comparable to "sliding_chunks", set w such that
# w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
# For example,
# w_of_sliding_chunks = 256 (this is one-sided; total attention size = 512)
# w_of_sliding_chunks_no_overlap = 170 (total attention size = 510)
# Performance:
# - Speed: 30% faster than "sliding_chunks"
# - Memory: 95% of the memory usage of "sliding_chunks"
# The windows are asymmetric: the number of attended positions on each side of a token ranges from w to 2w,
# while "sliding_chunks" uses a symmetric window around each token.


def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
bsz, seqlen, num_heads, head_dim = q.size()
assert seqlen % w == 0
assert q.size() == k.size()
# chunk seqlen into non-overlapping chunks of size w
chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
chunk_k_expanded = torch.stack((
F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
chunk_k,
F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
), dim=-1)
diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded))  # multiply each chunk of queries with the keys of the previous, current, and next chunk
return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)


def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
bsz, seqlen, num_heads, head_dim = v.size()
chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
chunk_v_extended = torch.stack((
F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
chunk_v,
F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
), dim=-1)
context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
return context.reshape(bsz, seqlen, num_heads, head_dim)
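As a sanity check on the shapes these two helpers produce, here is a toy round trip; the tensor dimensions are illustrative assumptions, and real use in longformer.py above also scales the queries and masks invalid locations before the softmax:

import torch
import torch.nn.functional as F
from longformer.sliding_chunks import (sliding_chunks_no_overlap_matmul_qk,
                                       sliding_chunks_no_overlap_matmul_pv)

bsz, seqlen, num_heads, head_dim, w = 2, 512, 12, 64, 128  # seqlen must be a multiple of w
q = torch.randn(bsz, seqlen, num_heads, head_dim)
k = torch.randn(bsz, seqlen, num_heads, head_dim)
v = torch.randn(bsz, seqlen, num_heads, head_dim)

scores = sliding_chunks_no_overlap_matmul_qk(q, k, w, padding_value=0)
assert scores.size() == (bsz, seqlen, num_heads, 3 * w)  # 3w attention scores per query
probs = F.softmax(scores, dim=-1)
context = sliding_chunks_no_overlap_matmul_pv(probs, v, w)
assert context.size() == (bsz, seqlen, num_heads, head_dim)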
