From c9f019a1e7267c80eb3e157e82bea6c831eb688b Mon Sep 17 00:00:00 2001
From: Vlad Lialin
Date: Fri, 9 Oct 2020 14:46:46 -0400
Subject: [PATCH] fix attention issues for keys != queries

fix pooling issues
now all three types of unet layer have .maxpool
up uses maxpool to get key_padding_mask
rename _get_next_mask -> _get_shrinked_mask
add wandb and data-bin folders to .gitignore
---
 .gitignore                                 |  3 +-
 unet_transformer/unet_transformer.py       |  3 +-
 unet_transformer/unet_transformer_layer.py | 58 +++++++++++++++-------
 3 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/.gitignore b/.gitignore
index a5c7ea3..4f2c385 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 .idea
-data_bin
+data-bin
 checkpoints
+wandb
 
 # Created by https://www.gitignore.io/api/linux,macos,python,pycharm,visualstudiocode
 # Edit at https://www.gitignore.io/?templates=linux,macos,python,pycharm,visualstudiocode
diff --git a/unet_transformer/unet_transformer.py b/unet_transformer/unet_transformer.py
index b551c7f..eaa17b9 100644
--- a/unet_transformer/unet_transformer.py
+++ b/unet_transformer/unet_transformer.py
@@ -409,12 +409,13 @@ def forward(
 
         # down layers
         # required for 'up' layers to compute transposed conv output shape
-        layer_padding_masks = ([])
+        layer_padding_masks = []
         down_states = []
         for layer in self.down_layers:
             layer_padding_masks.append(
                 padding_mask
             )  # ignore the last padding_mask, we don't need it
+
             x, padding_mask = layer(x, padding_mask)
             down_states.append(x)
 
diff --git a/unet_transformer/unet_transformer_layer.py b/unet_transformer/unet_transformer_layer.py
index f31d073..e2f2643 100644
--- a/unet_transformer/unet_transformer_layer.py
+++ b/unet_transformer/unet_transformer_layer.py
@@ -51,7 +51,10 @@ def __init__(
         self.type_ = type_
 
         self.conv = None
-        self.maxpool = None
+
+        # shrinks the input in "down" layers
+        # not applied in "same" layers, but used in "up" layers to compute the key_padding_mask
+        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
 
         # use groups=1 for fast computation, use self.input_dim to be closer to the original paper
         groups = 1
@@ -68,7 +71,6 @@ def __init__(
             self.conv = nn.Conv1d(
                 self.input_dim, self.input_dim, kernel_size=3, padding=1, groups=groups
             )
-            self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
         elif type_ == "same":
             # keep size of output the same
             self.conv = nn.Conv1d(
@@ -90,7 +92,7 @@ def __init__(
             self.model_dim,
             args.encoder_attention_heads,
             dropout=args.attention_dropout,
-            self_attention=True,
+            self_attention=False,
             kdim=self.input_dim,
             vdim=self.input_dim,
         )
@@ -119,11 +121,11 @@ def upgrade_state_dict_named(self, state_dict, name):
                     state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                     del state_dict[k]
 
-    def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
+    def forward(self, x, key_padding_mask, attn_mask: Optional[Tensor] = None):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
-            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
+            key_padding_mask (ByteTensor): binary ByteTensor of shape
                 `(batch, src_len)` where padding elements are indicated by ``1``.
                 src_len = n_steps if type='same'
                 src_len = n_steps if type='down' and it is maxpooled inside the forward call
@@ -152,18 +154,18 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
         else:
             # the main problem with the 'up' layer is that we don't know the actual output size
             # e.g. if input seq_len=4, the corresponding down layer may have had input seq_len = 8 or 9
-            # encoder_padding_mask is supposed to have the right seq_len
+            # key_padding_mask is supposed to have the right seq_len
 
             output_size = None
-            if encoder_padding_mask is not None:
-                output_size = (batch_size, embed_dim, encoder_padding_mask.shape[1])
+            if key_padding_mask is not None:
+                output_size = (batch_size, embed_dim, key_padding_mask.shape[1])
 
             x = self.conv(x, output_size=output_size)
 
         x = self.conv_out(x.transpose(1, 2)).transpose(1, 2)  # input_dim -> model_dim
 
         if self.type_ == "down":
-            x = self.maxpool(x.contiguous())  # seq_len / 2
-            encoder_padding_mask = self._get_next_mask(encoder_padding_mask)
+            x = self.maxpool(x.contiguous())  # (batch_size, model_dim, seq_len / 2)
+            # key_padding_mask = self._get_next_mask(key_padding_mask)
 
         x = x.permute(2, 0, 1)  # (seq_len / 2, batch, embed_dim)
@@ -177,21 +179,30 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
         if attn_mask is not None:
             attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
 
+        _key_padding_mask = key_padding_mask
+        if self.type_ == "up":
+            # here key_padding_mask is the mask over the post-deconv positions,
+            # but the attention keys/values are pre-deconv, so shrink it back
+            _key_padding_mask = self._get_shrinked_mask(key_padding_mask)
+
         x, _ = self.self_attn(
             query=x,
             key=input_kv,
             value=input_kv,
-            key_padding_mask=encoder_padding_mask,
+            key_padding_mask=_key_padding_mask,
             attn_mask=attn_mask,
         )
+
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
 
         # FFN
+        if self.type_ == "down":
+            key_padding_mask = self._get_shrinked_mask(key_padding_mask)
 
-        if encoder_padding_mask is not None:
-            x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
+        if key_padding_mask is not None:
+            x = x.masked_fill(key_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
 
         residual = x
         x = self.activation_fn(self.fc1(x))
@@ -199,18 +210,27 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if encoder_padding_mask is not None:
-            x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
-        return x, encoder_padding_mask
 
-    def _get_next_mask(self, pad_mask):
+        if key_padding_mask is not None:
+            x = x.masked_fill(key_padding_mask.transpose(0, 1).unsqueeze(-1), 0.0)
+
+        assert (
+            x.shape[0] == key_padding_mask.shape[1]
+        ), f"output seq_len {x.shape[0]} != padding seq_len {key_padding_mask.shape[1]}"
+        assert (
+            x.shape[1] == key_padding_mask.shape[0]
+        ), f"output batch_size {x.shape[1]} != padding batch_size {key_padding_mask.shape[0]}"
+
+        return x, key_padding_mask
+
+    def _get_shrinked_mask(self, pad_mask):
         """
+        Computes the padding mask of the post-conv (pooled) sequence from the pre-conv padding mask
+
         :param pad_mask: torch.BoolTensor of shape (batch_size, seq_len)
         :param layer:
         :return:
         """
-        if self.maxpool is None:
-            return pad_mask
         pad_mask = pad_mask.unsqueeze(2).transpose(1, 2)
         non_pad_mask = self.maxpool((~pad_mask).float().contiguous()).bool()  # ~ is logical NOT
         pad_mask = ~non_pad_mask.squeeze(1)
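
Note (not part of the patch above): the mask shrinking performed by `_get_shrinked_mask` can be reproduced in isolation with the same `nn.MaxPool1d(kernel_size=3, stride=2, padding=1)` that every layer now owns. The sketch below is an illustrative standalone version under that assumption; `shrink_pad_mask` and the example tensor are hypothetical, not code from the repository.

import torch
import torch.nn as nn

# Same pooling configuration the patch assigns to self.maxpool in every layer type.
maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

def shrink_pad_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    """pad_mask: bool tensor of shape (batch_size, seq_len), True at padding positions."""
    non_pad = (~pad_mask).float().unsqueeze(1)  # (batch, 1, seq_len), 1.0 marks real tokens
    pooled = maxpool(non_pad).bool()            # (batch, 1, ceil(seq_len / 2))
    return ~pooled.squeeze(1)                   # padding only where the whole window was padding

pad_mask = torch.tensor([[False, False, False, True, True]])
print(shrink_pad_mask(pad_mask))  # tensor([[False, False,  True]])

Because a pooled position stays unmasked as long as any token in its window is real, the shrunk mask never hides a position that still carries information, which matches how the patch's "up" branch recomputes the mask for its pre-deconv keys and values.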