
Commit dd2af82

Add Language Model tutorial files, update cmake and readme
1 parent 5f28bb5

12 files changed: +345, -1 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ add_subdirectory("tutorials/intermediate/convolutional_neural_network")
 add_subdirectory("tutorials/intermediate/deep_residual_network")
 add_subdirectory("tutorials/intermediate/recurrent_neural_network")
 add_subdirectory("tutorials/intermediate/bidirectional_recurrent_neural_network")
+add_subdirectory("tutorials/intermediate/language_model")
 
 # The following code block is suggested to be used on Windows.
 # According to https://github.com/pytorch/pytorch/issues/25457,

README.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ $ ./scripts.sh build
 * [Deep Residual Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/deep_residual_network/src/main.cpp)
 * [Recurrent Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/recurrent_neural_network/src/main.cpp)
 * [Bidirectional Recurrent Neural Network](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/bidirectional_recurrent_neural_network/src/main.cpp)
-* [Language Model (RNN-LM)]()
+* [Language Model (RNN-LM)](https://github.com/prabhuomkar/pytorch-cpp/tree/master/tutorials/intermediate/language_model/src/main.cpp)
 
 #### 3. Advanced
 * [Generative Adversarial Networks]()
tutorials/intermediate/language_model/CMakeLists.txt

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)

project(language-model VERSION 1.0.0 LANGUAGES CXX)

# Files
set(SOURCES src/main.cpp
            src/rnn_lm.cpp
            src/corpus.cpp
            src/dictionary.cpp
            src/clip_grad_norm.cpp
)

set(HEADERS include/rnn_lm.h
            include/corpus.h
            include/dictionary.h
            include/clip_grad_norm.h
)

set(EXECUTABLE_NAME language-model)


add_executable(${EXECUTABLE_NAME} ${SOURCES} ${HEADERS})
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)

target_link_libraries(${EXECUTABLE_NAME} "${TORCH_LIBRARIES}")

set_target_properties(${EXECUTABLE_NAME} PROPERTIES
    CXX_STANDARD 11
    CXX_STANDARD_REQUIRED YES
)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# See https://pytorch.org/cppdocs/installing.html.
if (MSVC)
    file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
    add_custom_command(TARGET ${EXECUTABLE_NAME}
                       POST_BUILD
                       COMMAND ${CMAKE_COMMAND} -E copy_if_different
                       ${TORCH_DLLS}
                       $<TARGET_FILE_DIR:${EXECUTABLE_NAME}>)
endif (MSVC)
tutorials/intermediate/language_model/include/clip_grad_norm.h

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>
#include <vector>

namespace nn_utils {
void clip_grad_l2_norm(std::vector<torch::Tensor> parameters, double max_norm);
}  // namespace nn_utils
tutorials/intermediate/language_model/include/corpus.h

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>
#include <string>
#include "dictionary.h"

namespace data_utils {
class Corpus {
 public:
    explicit Corpus(const std::string& path) : path_(path) {}
    torch::Tensor get_data(int64_t batch_size);
    const Dictionary& get_dictionary() const { return dictionary_; }
 private:
    std::string path_;
    Dictionary dictionary_;
};
}  // namespace data_utils
tutorials/intermediate/language_model/include/dictionary.h

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <string>
#include <unordered_map>
#include <vector>

namespace data_utils {
class Dictionary {
 public:
    int64_t add_word(const std::string& word);
    std::string word_at_index(int64_t index) const { return idx2word_[index]; }
    size_t size() const { return word2idx_.size(); }
 private:
    std::unordered_map<std::string, size_t> word2idx_;
    std::vector<std::string> idx2word_;
};
}  // namespace data_utils
tutorials/intermediate/language_model/include/rnn_lm.h

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
// Copyright 2019 Markus Fleischhacker
#pragma once

#include <torch/torch.h>

class RNNLMImpl : public torch::nn::Module {
 public:
    RNNLMImpl(int64_t vocab_size, int64_t embed_size, int64_t hidden_size, int64_t num_layers);
    torch::nn::RNNOutput forward(torch::Tensor x, torch::Tensor h);

 private:
    torch::nn::Embedding embed;
    torch::nn::LSTM lstm;
    torch::nn::Linear linear;
};

TORCH_MODULE(RNNLM);
tutorials/intermediate/language_model/src/clip_grad_norm.cpp

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
// Copyright 2019 Markus Fleischhacker
#include "clip_grad_norm.h"
#include <torch/torch.h>
#include <vector>
#include <algorithm>

namespace nn_utils {
// Clips gradient norm of a vector of tensors
//
// Source (slightly modified):
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/utils/clip_grad.h
void clip_grad_l2_norm(std::vector<torch::Tensor> parameters, double max_norm) {
    std::vector<torch::Tensor> params_with_grad;

    for (const auto& param : parameters) {
        auto& grad = param.grad();
        if (grad.defined()) {
            params_with_grad.push_back(param);
        }
    }

    double total_norm = 0.0;

    for (const auto& param : params_with_grad) {
        auto param_norm = param.grad().data().norm(2.0);
        total_norm += std::pow(param_norm.item().toDouble(), 2.0);
    }
    total_norm = std::pow(total_norm, 1.0 / 2.0);

    auto clip_coef = max_norm / (total_norm + 1e-6);
    if (clip_coef < 1) {
        for (auto& param : params_with_grad) {
            param.grad().data().mul_(clip_coef);
        }
    }
}
}  // namespace nn_utils
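
For orientation, a minimal sketch of where this utility sits in a training step (illustration only, not part of this commit; model, loss, and optimizer are placeholders matching main.cpp below):

    // Sketch: clip gradients between the backward pass and the optimizer step.
    optimizer.zero_grad();
    loss.backward();                                        // populate parameter gradients
    nn_utils::clip_grad_l2_norm(model->parameters(), 0.5);  // rescale all grads if their total L2 norm exceeds 0.5
    optimizer.step();                                       // apply the (possibly clipped) gradients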
tutorials/intermediate/language_model/src/corpus.cpp

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
// Copyright 2019 Markus Fleischhacker

#include "corpus.h"
#include <torch/torch.h>
#include <fstream>
#include <sstream>
#include <exception>
#include <algorithm>

namespace data_utils {
torch::Tensor Corpus::get_data(int64_t batch_size) {
    if (std::ifstream file{path_}) {
        std::vector<int64_t> ids;

        for (std::string line; std::getline(file, line);) {
            std::istringstream line_stream(line);

            for (std::string word; line_stream >> word;) {
                ids.push_back(dictionary_.add_word(word));
            }
            // End of sequence marker
            ids.push_back(dictionary_.add_word("<eos>"));
        }

        int64_t num_batches = ids.size() / batch_size;
        return torch::from_blob(ids.data(), {batch_size, num_batches},
            torch::TensorOptions().dtype(torch::kInt64)).clone();
    } else {
        throw std::runtime_error("Could not read file at path: " + path_);
    }
}
}  // namespace data_utils
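
A note on the returned layout (illustrative, not part of the commit): get_data packs the whole token stream into a {batch_size, num_batches} tensor of int64 word ids, and the training code later cuts it into windows of sequence_length along dimension 1, with the target window shifted one position to the right, since the model learns to predict the next word. A minimal sketch, with variable names as in main.cpp:

    // Sketch: slicing the {batch_size, num_batches} id tensor into training windows.
    auto ids = corpus.get_data(/*batch_size=*/20);
    const int64_t sequence_length = 30;
    for (int64_t i = 0; i < ids.size(1) - sequence_length; i += sequence_length) {
        auto data   = ids.slice(1, i, i + sequence_length);          // w_i     ... w_{i+L-1}
        auto target = ids.slice(1, i + 1, i + 1 + sequence_length);  // w_{i+1} ... w_{i+L}
    }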
tutorials/intermediate/language_model/src/dictionary.cpp

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
// Copyright 2019 Markus Fleischhacker
#include "dictionary.h"

namespace data_utils {
int64_t Dictionary::add_word(const std::string& word) {
    auto it = word2idx_.find(word);

    if (it == word2idx_.end()) {
        idx2word_.push_back(word);

        auto new_index = idx2word_.size() - 1;
        word2idx_[word] = new_index;
        return new_index;
    }

    return it->second;
}
}  // namespace data_utils
tutorials/intermediate/language_model/src/main.cpp

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
// Copyright 2019 Markus Fleischhacker
#include <torch/torch.h>
#include <iostream>
#include <iomanip>
#include "rnn_lm.h"
#include "corpus.h"
#include "clip_grad_norm.h"

using data_utils::Corpus;
using nn_utils::clip_grad_l2_norm;

int main() {
    std::cout << "Language Model\n\n";

    // Device
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    // Hyper parameters
    const int64_t embed_size = 128;
    const int64_t hidden_size = 1024;
    const int64_t num_layers = 1;
    const int64_t num_epochs = 5;
    const int64_t num_samples = 1000;  // the number of words to be sampled
    const int64_t batch_size = 20;
    const int64_t sequence_length = 30;
    const double learning_rate = 0.002;

    // Load "Penn Treebank" dataset
    // See https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/data/
    const std::string penn_treebank_data_path = "../../../../tutorials/intermediate/language_model/data/train.txt";

    Corpus corpus(penn_treebank_data_path);

    auto ids = corpus.get_data(batch_size);
    int64_t vocab_size = corpus.get_dictionary().size();

    // Model
    RNNLM model(vocab_size, embed_size, hidden_size, num_layers);
    model->to(device);

    // Optimizer
    auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(learning_rate));

    // Set floating point output precision
    std::cout << std::fixed << std::setprecision(4);

    std::cout << "Training...\n";

    // Train the model
    for (size_t epoch = 0; epoch != num_epochs; ++epoch) {
        // Initialize running metrics
        float running_loss = 0.0;
        float running_perplexity = 0.0;
        size_t running_num_samples = 0;

        // Set initial hidden- and cell-states (stacked into one tensor)
        auto state = torch::zeros({2, num_layers, batch_size, hidden_size}).to(device).detach();

        for (size_t i = 0; i < ids.size(1) - sequence_length; i += sequence_length) {
            // Transfer data and target labels to device
            auto data = ids.slice(1, i, i + sequence_length).to(device);
            auto target = ids.slice(1, i + 1, i + 1 + sequence_length).to(device);

            // Forward pass
            auto rnn_output = model->forward(data, state);
            auto output = rnn_output.output;
            state = rnn_output.state.detach();

            // Calculate loss
            auto loss = torch::nll_loss(output, target.reshape(-1));

            // Update running metrics
            running_loss += loss.item().toFloat() * data.size(0);
            running_perplexity += torch::exp(loss).item().toFloat() * data.size(0);
            running_num_samples += data.size(0);

            // Backward pass and optimize
            optimizer.zero_grad();
            loss.backward();
            clip_grad_l2_norm(model->parameters(), 0.5);
            optimizer.step();
        }

        auto sample_mean_loss = running_loss / running_num_samples;
        auto sample_mean_perplexity = running_perplexity / running_num_samples;

        std::cout << "Epoch [" << (epoch + 1) << "/" << num_epochs << "], Trainset - Loss: "
            << sample_mean_loss << ", Perplexity: " << sample_mean_perplexity << '\n';
    }

    std::cout << "Training finished!\n\n";
    std::cout << "Generating samples...\n";

    const std::string sample_output_path = "../../../../tutorials/intermediate/language_model/data/sample.txt";

    // Generate samples
    model->eval();
    torch::NoGradGuard no_grad;

    std::ofstream sample_output_file(sample_output_path);

    // Set initial hidden- and cell-states (stacked into one tensor)
    auto state = torch::zeros({2, num_layers, 1, hidden_size}).to(device);

    // Select one word-id at random
    auto prob = torch::ones(vocab_size);
    auto data = prob.multinomial(1).unsqueeze(1).to(device);

    for (size_t i = 0; i != num_samples; ++i) {
        // Forward pass
        auto rnn_output = model->forward(data, state);
        auto out = rnn_output.output;
        state = rnn_output.state;

        // Sample one word id
        prob = out.exp();
        auto word_id = prob.multinomial(1).item();

        // Fill input data with sampled word id for the next time step
        data.fill_(word_id);

        // Write the word corresponding to the id to the file
        auto word = corpus.get_dictionary().word_at_index(word_id.toLong());
        word = (word == "<eos>") ? "\n" : word + " ";
        sample_output_file << word;
    }
    std::cout << "Finished generating samples!\nSaved output to " << sample_output_path << "\n";
}
tutorials/intermediate/language_model/src/rnn_lm.cpp

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
// Copyright 2019 Markus Fleischhacker
#include "rnn_lm.h"
#include <torch/torch.h>

RNNLMImpl::RNNLMImpl(int64_t vocab_size, int64_t embed_size, int64_t hidden_size, int64_t num_layers)
    : embed(vocab_size, embed_size),
      lstm(torch::nn::LSTMOptions(embed_size, hidden_size).layers(num_layers).batch_first(true)),
      linear(hidden_size, vocab_size) {
    register_module("embed", embed);
    register_module("lstm", lstm);
    register_module("linear", linear);
}

torch::nn::RNNOutput RNNLMImpl::forward(torch::Tensor x, torch::Tensor h) {
    auto lstm_out = lstm->forward(embed->forward(x), h);
    auto out = lstm_out.output;
    out = out.reshape({-1, out.size(2)});
    out = linear->forward(out);
    out = torch::log_softmax(out, 1);
    return {out, lstm_out.state};
}
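
For reference, a rough sketch of driving the module (illustration only, not part of the commit; the sizes are the hyperparameters from main.cpp): the forward pass takes a {batch, seq_len} tensor of word ids plus a stacked hidden/cell state of shape {2, num_layers, batch, hidden_size}, and returns log-probabilities flattened to {batch * seq_len, vocab_size}, ready for torch::nll_loss.

    // Sketch: one forward pass through RNNLM on the CPU, with an assumed vocab_size of 10000.
    RNNLM model(/*vocab_size=*/10000, /*embed_size=*/128, /*hidden_size=*/1024, /*num_layers=*/1);
    auto state = torch::zeros({2, /*num_layers=*/1, /*batch_size=*/20, /*hidden_size=*/1024});
    auto input = torch::randint(0, 10000, {20, 30},
        torch::TensorOptions().dtype(torch::kInt64));   // {batch, seq_len} word ids
    auto result = model->forward(input, state);
    // result.output: {20 * 30, 10000} log-probabilities; result.state: updated LSTM state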
