Skip to content

Commit

Permalink
[#41] slight refactor for code reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondng76 committed Dec 9, 2021
1 parent ae77dc7 commit d19add5
Showing 1 changed file with 15 additions and 27 deletions.
42 changes: 15 additions & 27 deletions sgnlp/models/asgcn/modules/dynamic_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,24 @@ def __init__(
self.__init_rnn()

def __init_rnn(self) -> None:
    """
    Helper method to initialize the underlying RNN module.

    Builds the keyword arguments shared by all supported recurrent
    layers once, then instantiates the layer selected by
    ``self.rnn_type`` ('LSTM', 'GRU' or 'RNN') as ``self.rnn``.

    Raises:
        ValueError: If ``self.rnn_type`` is not one of the supported
            types. (Previously an unknown type silently left
            ``self.rnn`` unset, deferring the failure to the first
            use of ``self.rnn``.)
    """
    input_args = {
        "input_size": self.input_size,
        "hidden_size": self.hidden_size,
        "num_layers": self.num_layers,
        "bias": self.bias,
        "batch_first": self.batch_first,
        "dropout": self.dropout,
        "bidirectional": self.bidirectional,
    }
    # Dispatch table keeps the keyword expansion in exactly one place.
    rnn_classes = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}
    try:
        rnn_cls = rnn_classes[self.rnn_type]
    except KeyError:
        raise ValueError(
            f"Unsupported rnn_type '{self.rnn_type}'; "
            f"expected one of {sorted(rnn_classes)}."
        ) from None
    self.rnn = rnn_cls(**input_args)

def forward(self, x, x_len, h0=None):
# Sort
Expand Down

0 comments on commit d19add5

Please sign in to comment.