Skip to content

Commit

Permalink
add module tests 3
Browse files Browse the repository at this point in the history
  • Loading branch information
senwu committed Dec 1, 2019
1 parent e928995 commit ed087d2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/emmental/modules/rnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,15 @@ def forward(self, x: Tensor, x_mask: Optional[Tensor] = None) -> Tensor:
"""
Mean pooling
"""
if x_mask is None:
x_mask = x.new_full(x.size()[:2], fill_value=0, dtype=torch.uint8)
x_lens = x_mask.data.eq(0).long().sum(dim=1)
weights = torch.ones(x.size()) / x_lens.unsqueeze(1).float()
weights = (
output_word.new_ones(output_word.size())
/ x_lens.view(x_lens.size()[0], 1, 1).float()
)
weights.data.masked_fill_(x_mask.data.unsqueeze(dim=2), 0.0)
word_vectors = torch.bmm(
output_word.transpose(1, 2), weights.unsqueeze(2)
).squeeze(2)
word_vectors = torch.sum(output_word * weights, dim=1)
output = self.linear(word_vectors) if self.final_linear else word_vectors

return output
48 changes: 48 additions & 0 deletions tests/modules/test_rnn_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging

import torch

from emmental.modules.rnn_module import RNN
from emmental.utils.utils import pad_batch

logger = logging.getLogger(__name__)


def test_rnn_module(caplog):
"""Unit test of RNN Module"""

caplog.set_level(logging.INFO)

n_class = 2
emb_size = 10
lstm_hidden = 20
batch_size = 3
seq_len = 4

rnn = RNN(
num_classes=n_class,
emb_size=emb_size,
lstm_hidden=lstm_hidden,
attention=True,
dropout=0.2,
bidirectional=False,
)

assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == (3, n_class)

rnn = RNN(
num_classes=n_class,
emb_size=emb_size,
lstm_hidden=lstm_hidden,
attention=False,
dropout=0.2,
bidirectional=True,
)

_, input_mask = pad_batch(torch.randn(batch_size, seq_len))

assert rnn(torch.randn(batch_size, seq_len, emb_size)).size() == (3, n_class)
assert rnn(torch.randn(batch_size, seq_len, emb_size), input_mask).size() == (
3,
n_class,
)

0 comments on commit ed087d2

Please sign in to comment.