Commit
fix attention issues for keys != queries
fix pooling issues
now all three types of unet layer have .maxpool
up uses maxpool to get key_padding_mask
rename _get_next_mask -> _get_shrinked_mask

add wandb and data-bin folders to .gitignore
Guitaricet committed Oct 9, 2020
1 parent 69cfe03 commit c9f019a
Showing 3 changed files with 43 additions and 21 deletions.
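
For context on the "keys != queries" fix below: the attention query is the conv output projected to model_dim, while keys and values keep the layer input's input_dim, so the module can no longer be flagged as self-attention (which generally assumes a single shared size). A minimal sketch of attention with mismatched key/value dimensions, using torch.nn.MultiheadAttention rather than the fairseq module this repo uses, with made-up sizes:

import torch
import torch.nn as nn

# Made-up sizes, for illustration only.
model_dim, input_dim, heads = 512, 256, 8
seq_len, batch = 10, 2

# kdim/vdim let keys and values live in a different space than the queries --
# the keys != queries situation the commit message refers to.
attn = nn.MultiheadAttention(embed_dim=model_dim, num_heads=heads,
                             kdim=input_dim, vdim=input_dim)

query = torch.randn(seq_len, batch, model_dim)      # e.g. conv output projected to model_dim
key_value = torch.randn(seq_len, batch, input_dim)  # e.g. the raw layer input

out, _ = attn(query, key_value, key_value)
print(out.shape)  # torch.Size([10, 2, 512]) -- output follows the query's embed_dim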
.gitignore (3 changes: 2 additions & 1 deletion)
@@ -1,6 +1,7 @@
.idea
data_bin
data-bin
checkpoints
wandb

# Created by https://www.gitignore.io/api/linux,macos,python,pycharm,visualstudiocode
# Edit at https://www.gitignore.io/?templates=linux,macos,python,pycharm,visualstudiocode
unet_transformer/unet_transformer.py (3 changes: 2 additions & 1 deletion)
@@ -409,12 +409,13 @@ def forward(

# down layers
# required for 'up' layers to compute transposed conv output shape
layer_padding_masks = ([])
layer_padding_masks = []
down_states = []
for layer in self.down_layers:
layer_padding_masks.append(
padding_mask
) # ignore the last padding_mask, we don't need it

x, padding_mask = layer(x, padding_mask)

down_states.append(x)
unet_transformer/unet_transformer_layer.py (58 changes: 39 additions & 19 deletions)
@@ -51,7 +51,10 @@ def __init__(

self.type_ = type_
self.conv = None
self.maxpool = None

# shrinks down input in "down" layers
# not used in "same", but used in "up" for computing the key_padding_mask
self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

# use groups=1 for fast computation, use self.input_dim to be closer to the original paper
groups = 1
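
Since maxpool now exists on every layer type, the features in "down" layers and the padding mask in both "down" and "up" (via _get_shrinked_mask, further below) are shrunk by exactly the same rule: with kernel_size=3, stride=2, padding=1 the pooled length is floor((L - 1) / 2) + 1 = ceil(L / 2). A quick standalone sketch of that length rule, not repo code:

import torch
import torch.nn as nn

maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

for seq_len in (7, 8, 9, 16):
    x = torch.randn(2, 4, seq_len)  # (batch, channels, seq_len)
    pooled = maxpool(x)
    # pooled length is ceil(seq_len / 2): 7 -> 4, 8 -> 4, 9 -> 5, 16 -> 8
    print(seq_len, "->", pooled.shape[-1])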
@@ -68,7 +71,6 @@ self.conv = nn.Conv1d(
self.conv = nn.Conv1d(
self.input_dim, self.input_dim, kernel_size=3, padding=1, groups=groups
)
self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
elif type_ == "same":
# keep size of output the same
self.conv = nn.Conv1d(
@@ -90,7 +92,7 @@ def __init__(
self.model_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
self_attention=False,
kdim=self.input_dim,
vdim=self.input_dim,
)
@@ -119,11 +121,11 @@ def upgrade_state_dict_named(self, state_dict, name):
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]

def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
def forward(self, x, key_padding_mask, attn_mask: Optional[Tensor] = None):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
key_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
src_len = n_steps if type='same'
src_len = n_steps if type='down' and it is maxpooled inside the forward call
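
The mask convention above (shape (batch, src_len), padding positions marked with 1/True) can be built from per-sequence lengths; a hypothetical helper, not part of the repo, just to make the convention concrete:

import torch

def lengths_to_pad_mask(lengths, max_len=None):
    # (batch, src_len) bool mask, True at padding positions.
    max_len = max_len or int(max(lengths))
    positions = torch.arange(max_len).unsqueeze(0)              # (1, src_len)
    return positions >= torch.as_tensor(lengths).unsqueeze(1)   # (batch, src_len)

print(lengths_to_pad_mask([5, 3], max_len=6))
# tensor([[False, False, False, False, False,  True],
#         [False, False, False,  True,  True,  True]])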
@@ -152,18 +154,18 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
else:
# the main problem with 'up' layer is that we don't know the actual output size
# e.g. if input seq_len=4, the corresponding down layer may have had input seq_len = 8 or 9
# encoder_padding_mask is supposed to have the right seq_len
# key_padding_mask is supposed to have the right seq_len
output_size = None
if encoder_padding_mask is not None:
output_size = (batch_size, embed_dim, encoder_padding_mask.shape[1])
if key_padding_mask is not None:
output_size = (batch_size, embed_dim, key_padding_mask.shape[1])

x = self.conv(x, output_size=output_size)

x = self.conv_out(x.transpose(1, 2)).transpose(1, 2) # input_dim -> model_dim

if self.type_ == "down":
x = self.maxpool(x.contiguous()) # seq_len / 2
encoder_padding_mask = self._get_next_mask(encoder_padding_mask)
x = self.maxpool(x.contiguous()) # (batch_size, model_dim, seq_len / 2)
# key_padding_mask = self._get_next_mask(key_padding_mask)

x = x.permute(2, 0, 1) # (seq_len / 2, batch, embed_dim)

@@ -177,40 +179,58 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

_key_padding_mask = key_padding_mask
if self.type_ == "up":
# in this case, key_padding_mask represents the mask over post-deconvs
# but in the attention we need a mask over pre-deconvs
_key_padding_mask = self._get_shrinked_mask(key_padding_mask)

x, _ = self.self_attn(
query=x,
key=input_kv,
value=input_kv,
key_padding_mask=encoder_padding_mask,
key_padding_mask=_key_padding_mask,
attn_mask=attn_mask,
)

x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)

# FFN
if self.type_ == "down":
key_padding_mask = self._get_shrinked_mask(key_padding_mask)

if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
if key_padding_mask is not None:
x = x.masked_fill(key_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)

residual = x
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if encoder_padding_mask is not None:
x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
return x, encoder_padding_mask

def _get_next_mask(self, pad_mask):
if key_padding_mask is not None:
x = x.masked_fill(key_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)

assert (
x.shape[0] == key_padding_mask.shape[1]
), f"output seq_len {x.shape[0]} != padding seq_len {key_padding_mask.shape[1]}"
assert (
x.shape[1] == key_padding_mask.shape[0]
), f"output batch_size {x.shape[1]} != padding batch_size {key_padding_mask.shape[0]}"

return x, key_padding_mask

def _get_shrinked_mask(self, pad_mask):
"""
Computes a mask of the post-conv inputs based on the pre-conv inputs mask
:param pad_mask: torch.BoolTensor of shape (batch_size, seq_len)
:param layer:
:return:
"""
if self.maxpool is None:
return pad_mask
pad_mask = pad_mask.unsqueeze(2).transpose(1, 2)
non_pad_mask = self.maxpool((~pad_mask).float().contiguous()).bool() # ~ is logical NOT
pad_mask = ~non_pad_mask.squeeze(1)
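
The idea behind _get_shrinked_mask: a pooled position is padding only if every position in its pooling window was padding, and max-pooling the inverted (non-pad) mask computes exactly that. A standalone sketch of the same trick, with a tiny example:

import torch
import torch.nn as nn

def shrink_pad_mask(pad_mask, maxpool):
    # pad_mask: (batch, seq_len) bool, True at padding positions.
    non_pad = (~pad_mask).float().unsqueeze(1)  # (batch, 1, seq_len)
    pooled_non_pad = maxpool(non_pad).bool()    # True where the window held any real token
    return ~pooled_non_pad.squeeze(1)           # (batch, ceil(seq_len / 2))

maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

# Two length-7 sequences; the second has only 3 real tokens.
pad_mask = torch.tensor([[False] * 7,
                         [False, False, False, True, True, True, True]])
print(shrink_pad_mask(pad_mask, maxpool))
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])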
