Skip to content

Commit

Permalink
Update lstm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
GokulNC committed Oct 24, 2021
1 parent 30cb9a1 commit ab5691e
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions openhands/models/detection/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@
# Paper: https://arxiv.org/abs/2008.04637
# Models ported from: https://github.com/google-research/google-research/tree/master/sign_language_detection

class SignDetectionLSTM(nn.Module):
def __init__(self, input_dim=25, input_dropout = 0.5, encoder_layers = 1, hidden_size = 2**6, encoder_bidirectional = False):
super(SignDetectionLSTM, self).__init__()
class SignDetectionRNN(nn.Module):
def __init__(self, rnn_type="LSTM", input_dim=25, input_dropout = 0.5, num_layers = 1, hidden_size = 2**6, bidirectional = False):
super().__init__()
self.input_dropout = nn.Dropout(p=input_dropout)
self.lstm = nn.LSTM(input_dim, hidden_size, encoder_layers, bidirectional=encoder_bidirectional, batch_first=True)
self.rnn = getattr(nn, rnn_type)(
input_size=input_dim,
hidden_size=hidden_size,
num_layers=num_layers,
bidirectional=bidirectional,
batch_first=True
)
self.hidden = None
self.linear = nn.Linear(hidden_size, 2)

def forward(self, x, hidden=None):
# input.shape = [batch_size, seq_length, input_dim]
x = self.input_dropout(x)
lstm_out, self.hidden = self.lstm(input, hidden)
rnn_out, self.hidden = self.rnn(x, hidden)
# shape = [batch_size, out_dim]
y_pred = self.linear(self.dropout(lstm_out[:,-1,:]))
return y_pred
y_pred = self.linear(rnn_out)
return y_pred#, self.hidden

0 comments on commit ab5691e

Please sign in to comment.