-
-
Notifications
You must be signed in to change notification settings - Fork 266
/
Copy pathrnn_lm.cpp
26 lines (22 loc) · 997 Bytes
/
rnn_lm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// Copyright 2020-present pytorch-cpp Authors
#include "rnn_lm.h"
#include <torch/torch.h>
#include <tuple>
// Language model: token embedding -> multi-layer LSTM -> vocab projection.
RNNLMImpl::RNNLMImpl(int64_t vocab_size, int64_t embed_size, int64_t hidden_size, int64_t num_layers)
    : embed(vocab_size, embed_size),
      lstm(torch::nn::LSTMOptions(embed_size, hidden_size)
               .num_layers(num_layers)
               .batch_first(true)),  // inputs arrive as (batch, seq, feature)
      linear(hidden_size, vocab_size) {
    // Register every submodule so its parameters participate in
    // optimization, serialization, and device transfer.
    register_module("embed", embed);
    register_module("lstm", lstm);
    register_module("linear", linear);
}
// Runs one forward pass of the language model.
//
// x:  integer token-id tensor — presumably (batch, seq_len), since the LSTM
//     was configured with batch_first(true); TODO confirm against callers.
// hx: initial LSTM hidden state as the pair (h_0, c_0).
//
// Returns {log-probabilities with shape (batch * seq_len, vocab_size),
//          final LSTM state (h_n, c_n)}.
std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> RNNLMImpl::forward(torch::Tensor x,
    std::tuple<torch::Tensor, torch::Tensor> hx) {
    // C++17 structured bindings replace the previous pattern of
    // default-constructing locals and filling them via std::tie.
    auto [output, state] = lstm->forward(embed->forward(x), hx);
    // Flatten (batch, seq_len, hidden) -> (batch * seq_len, hidden) so the
    // linear layer scores every time step in one matmul.
    output = output.reshape({-1, output.size(2)});
    output = linear->forward(output);
    // Normalize scores to log-probabilities along the vocabulary dimension.
    output = torch::nn::functional::log_softmax(output, /*dim=*/1);
    return {output, state};
}