
Commit

[#41] add dynamic_rnn module
raymondng76 committed Dec 9, 2021
1 parent 132bbb4 commit ae77dc7
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions sgnlp/models/asgcn/modules/dynamic_rnn.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn


class DynamicLSTM(nn.Module):
"""
A dynamic LSTM class which can hold variable length sequence
"""
def __init__(
self,
input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=True,
dropout=0,
bidirectional=False,
only_use_last_hidden_state=False,
rnn_type='LSTM') -> None:
super(DynamicLSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.bidirectional = bidirectional
self.only_use_last_hidden_state = only_use_last_hidden_state
self.rnn_type = rnn_type
self.__init_rnn()

def __init_rnn(self) -> None:
if self.rnn_type == 'LSTM':
self.rnn = nn.LSTM(
input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
bias=self.bias,
batch_first=self.batch_first,
dropout=self.dropout,
bidirectional=self.bidirectional
)
elif self.rnn_type == 'GRU':
self.rnn = nn.GRU(
input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
bias=self.bias,
batch_first=self.batch_first,
dropout=self.dropout,
bidirectional=self.bidirectional
)
elif self.rnn_type == 'RNN':
self.rnn = nn.RNN(
input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
bias=self.bias,
batch_first=self.batch_first,
dropout=self.dropout,
bidirectional=self.bidirectional
            )
        else:
            raise ValueError(
                f"Unsupported rnn_type '{self.rnn_type}'. Expected one of 'LSTM', 'GRU' or 'RNN'."
            )

    def forward(self, x, x_len, h0=None):
        """
        Run the wrapped RNN over a padded batch of variable-length sequences.

        Args:
            x: padded input tensor, (batch, seq_len, input_size) when batch_first=True.
            x_len: tensor of true (unpadded) sequence lengths.
            h0: optional initial hidden state.

        Returns:
            The final hidden state if only_use_last_hidden_state is True, otherwise
            (output, (ht, ct)), where ct is None for non-LSTM rnn_type.
        """
        # Sort sequences by descending length and remember how to undo the sort
        x_sort_idx = torch.argsort(-x_len)
        x_unsort_idx = torch.argsort(x_sort_idx).long()
        x_len = x_len[x_sort_idx]
        x = x[x_sort_idx.long()]

        # Pack the sorted, padded batch; lengths must live on the CPU for pack_padded_sequence
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len.cpu(), batch_first=self.batch_first)

        if self.rnn_type == "LSTM":
            # LSTM returns both hidden and cell states; reuse h0 as the initial value for both when provided
            out_pack, (ht, ct) = self.rnn(x_emb_p, None) if h0 is None else self.rnn(x_emb_p, (h0, h0))
        else:
            # GRU/RNN return only a hidden state; keep ct as None so the return signature stays consistent
            out_pack, ht = self.rnn(x_emb_p, None) if h0 is None else self.rnn(x_emb_p, h0)
            ct = None

        # Unsort the final hidden state back to the original batch order
        # (num_layers * num_directions, batch, hidden_size) -> (batch, ...) for indexing, then back
        ht = torch.transpose(ht, 0, 1)[x_unsort_idx]
        ht = torch.transpose(ht, 0, 1)

if self.only_use_last_hidden_state:
return ht
else:
            # Unpack the output back into a padded tensor and restore the original batch order
            out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)  # (padded output, lengths)
            out = out[0]
            out = out[x_unsort_idx]

            # Unsort the cell state (LSTM only)
            if self.rnn_type == 'LSTM':
                # (num_layers * num_directions, batch, hidden_size) -> (batch, ...) for indexing, then back
                ct = torch.transpose(ct, 0, 1)[x_unsort_idx]
                ct = torch.transpose(ct, 0, 1)
return out, (ht, ct)
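
A quick usage sketch (not part of the commit) showing how the module can be driven, assuming the import path mirrors the file location; the shapes and hyperparameters below are purely illustrative:

import torch
from sgnlp.models.asgcn.modules.dynamic_rnn import DynamicLSTM

# Illustrative only: a padded batch of 3 sequences, max length 5, embedding size 10
x = torch.randn(3, 5, 10)
x_len = torch.tensor([5, 3, 2])  # true (unpadded) length of each sequence

lstm = DynamicLSTM(input_size=10, hidden_size=16, batch_first=True, bidirectional=True)
out, (ht, ct) = lstm(x, x_len)
# out: (3, 5, 32) padded outputs; ht, ct: (2, 3, 16) final hidden/cell states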
