fix attention input shapes
specify kdim and vdim in case they are different from qdim
Guitaricet committed Oct 8, 2020
1 parent bd665b0 commit 69cfe03
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions unet_transformer/unet_transformer_layer.py
@@ -85,11 +85,14 @@ def __init__(
         self.conv_norm = LayerNorm(self.model_dim)

         self.self_attn_layer_norm = LayerNorm(self.model_dim)
+
         self.self_attn = MultiheadAttention(
             self.model_dim,
             args.encoder_attention_heads,
             dropout=args.attention_dropout,
             self_attention=True,
+            kdim=self.input_dim,
+            vdim=self.input_dim,
         )

         self.dropout = args.dropout
@@ -156,13 +159,13 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):

         x = self.conv(x, output_size=output_size)

-        x = self.conv_out(x.transpose(1, 2)).transpose(1, 2)
+        x = self.conv_out(x.transpose(1, 2)).transpose(1, 2)  # input_dim -> model_dim

         if self.type_ == "down":
-            x = self.maxpool(x.contiguous())
+            x = self.maxpool(x.contiguous())  # seq_len / 2
             encoder_padding_mask = self._get_next_mask(encoder_padding_mask)

-        x = x.permute(2, 0, 1)  # (seq_len, batch, embed_dim)
+        x = x.permute(2, 0, 1)  # (seq_len / 2, batch, embed_dim)

         # only possible if type_=='same'
         x = residual + x if self.conv_skip_connection else x
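For context on the kdim/vdim arguments the commit adds: like torch.nn.MultiheadAttention, the fairseq MultiheadAttention used by this layer accepts kdim and vdim for the case where keys and values have a different feature size than the queries, projecting them internally to the attention's embed_dim. The sketch below illustrates that idea with torch.nn.MultiheadAttention rather than the fairseq class; all sizes (model_dim=128, input_dim=96) are made up for illustration, not taken from the repository.

# Minimal sketch (not from this repository) of why kdim/vdim must be set
# when keys/values are a different width than the queries.
import torch
from torch import nn

model_dim, input_dim, num_heads = 128, 96, 8

attn = nn.MultiheadAttention(
    embed_dim=model_dim,   # width of the queries (and of the output)
    num_heads=num_heads,
    dropout=0.1,
    kdim=input_dim,        # keys are projected from input_dim to model_dim
    vdim=input_dim,        # values are projected from input_dim to model_dim
)

q = torch.randn(10, 4, model_dim)    # (seq_len, batch, model_dim)
kv = torch.randn(10, 4, input_dim)   # (seq_len, batch, input_dim)

out, attn_weights = attn(q, kv, kv)
print(out.shape)  # torch.Size([10, 4, 128])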

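The new inline comments in forward() track how the tensor shape evolves on the downsampling path. Below is a hedged sketch of that shape bookkeeping using stand-ins for self.conv_out and self.maxpool (a Linear and a MaxPool1d here; the repository's actual modules and sizes may differ).

# Shape bookkeeping sketch for the "down" path described by the new comments.
# conv_out and maxpool are stand-ins; batch/seq/channel sizes are made up.
import torch
from torch import nn

batch, seq_len, input_dim, model_dim = 4, 16, 96, 128

conv_out = nn.Linear(input_dim, model_dim)  # stand-in for self.conv_out
maxpool = nn.MaxPool1d(kernel_size=2)       # stand-in for self.maxpool

x = torch.randn(batch, input_dim, seq_len)       # (batch, input_dim, seq_len)
x = conv_out(x.transpose(1, 2)).transpose(1, 2)  # input_dim -> model_dim: (batch, model_dim, seq_len)
x = maxpool(x.contiguous())                      # seq_len / 2: (batch, model_dim, seq_len // 2)
x = x.permute(2, 0, 1)                           # (seq_len // 2, batch, model_dim)
print(x.shape)                                   # torch.Size([8, 4, 128])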