Skip to content

Commit

Permalink
change pooling to mean pooling in latent space
Browse files Browse the repository at this point in the history
  • Loading branch information
ImKeTT committed Sep 28, 2022
1 parent 687682d commit c05f708
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def __init__(self, config):
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
first_token_tensor = hidden_states.mean(-1)
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
Expand Down

0 comments on commit c05f708

Please sign in to comment.