Add template postprocess module for faster_tokenizer (#2516)
* Add template postprocessing

* add template json

* move template.cc to faster_tokenizers

* Add template processing pybind

Conflicts:
	faster_tokenizer/faster_tokenizer/src/core/CMakeLists.txt

* tmp template pybind

* Fix template processor pybind __init__

* Fix template processing overflowing bug

* fix from_json of template piece
joey12300 committed Jun 14, 2022
1 parent 1f446ff commit 6ccc0be
Showing 11 changed files with 853 additions and 7 deletions.
2 changes: 1 addition & 1 deletion faster_tokenizer/CMakeLists.txt
@@ -59,7 +59,7 @@ if(WITH_PYTHON)

add_subdirectory(python)
add_custom_target(build_tokenizers_bdist_wheel ALL
-COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
+COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel --plat-name=manylinux1_x86_64
DEPENDS copy_python_tokenizers)

else(WITH_PYTHON)
2 changes: 1 addition & 1 deletion faster_tokenizer/faster_tokenizer/include/core/tokenizer.h
@@ -213,7 +213,7 @@ class Tokenizer {
AddedVocabulary added_vocabulary_;
bool use_truncation_;
bool use_padding_;
// TODO(zhoushunjie): Implement Decoder later.

friend void to_json(nlohmann::json& j, const Tokenizer& tokenizer);
friend void from_json(const nlohmann::json& j, Tokenizer& tokenizer);
};
@@ -16,3 +16,4 @@ limitations under the License. */

#include "postprocessors/bert.h"
#include "postprocessors/postprocessor.h"
#include "postprocessors/template.h"
188 changes: 188 additions & 0 deletions faster_tokenizer/faster_tokenizer/include/postprocessors/template.h
@@ -0,0 +1,188 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "boost/variant.hpp"
#include "glog/logging.h"
#include "nlohmann/json.hpp"
#include "postprocessors/postprocessor.h"

namespace tokenizers {
namespace postprocessors {

enum SequenceType { SEQ_A, SEQ_B };
NLOHMANN_JSON_SERIALIZE_ENUM(SequenceType,
{
{SEQ_A, "A"}, {SEQ_B, "B"},
});
// Each template piece has the form `${Id}:${TypeId}`.
using TemplateSequence = std::pair<SequenceType, uint>;
using TemplateSpecialToken = std::pair<std::string, uint>;

using TemplatePiece = boost::variant<TemplateSequence, TemplateSpecialToken>;
void to_json(nlohmann::json& j, const TemplatePiece& template_piece);
void from_json(const nlohmann::json& j, TemplatePiece& template_piece);

void ParseIdFromString(const std::string& template_id_string,
TemplatePiece* template_piece);
void SetTypeId(uint type_id, TemplatePiece* template_piece);
void GetTemplatePieceFromString(const std::string& template_string,
TemplatePiece* template_piece);

struct SpecialToken {
std::string id_;
std::vector<uint> ids_;
std::vector<std::string> tokens_;
SpecialToken() = default;
SpecialToken(const std::string& id,
const std::vector<uint>& ids,
const std::vector<std::string>& tokens)
: id_(id), ids_(ids), tokens_(tokens) {}
SpecialToken(const std::string& token, uint id) {
id_ = token;
ids_.push_back(id);
tokens_.push_back(token);
}
friend void to_json(nlohmann::json& j, const SpecialToken& special_token);
friend void from_json(const nlohmann::json& j, SpecialToken& special_token);
};

struct Template {
std::vector<TemplatePiece> pieces_;
Template() = default;
explicit Template(const std::string& template_str) {
std::vector<std::string> pieces;

// Parse the pieces: split the template string on spaces,
// skipping runs of consecutive spaces.
size_t start = template_str.find_first_not_of(" ");
size_t pos;
while ((pos = template_str.find_first_of(" ", start)) !=
std::string::npos) {
pieces.push_back(template_str.substr(start, pos - start));
start = template_str.find_first_not_of(" ", pos);
}
if (start != std::string::npos) {
pieces.push_back(template_str.substr(start));
}
AddStringPiece(pieces);
}

explicit Template(const std::vector<TemplatePiece>& pieces)
: pieces_(pieces) {}
explicit Template(const std::vector<std::string>& pieces) {
AddStringPiece(pieces);
}

void GetPiecesFromVec(const std::vector<std::string>& pieces) {
AddStringPiece(pieces);
}

void GetPiecesFromStr(const std::string& template_str) {
std::vector<std::string> pieces;

// Parse the pieces: split the template string on spaces,
// skipping runs of consecutive spaces.
size_t start = template_str.find_first_not_of(" ");
size_t pos;
while ((pos = template_str.find_first_of(" ", start)) !=
std::string::npos) {
pieces.push_back(template_str.substr(start, pos - start));
start = template_str.find_first_not_of(" ", pos);
}
if (start != std::string::npos) {
pieces.push_back(template_str.substr(start));
}
AddStringPiece(pieces);
}

void Clean() { pieces_.clear(); }

private:
void AddStringPiece(const std::vector<std::string>& pieces) {
for (auto&& piece : pieces) {
TemplatePiece template_piece;
GetTemplatePieceFromString(piece, &template_piece);
if (boost::get<TemplateSequence>(&template_piece)) {
pieces_.push_back(boost::get<TemplateSequence>(template_piece));
} else {
pieces_.push_back(boost::get<TemplateSpecialToken>(template_piece));
}
}
}

friend void to_json(nlohmann::json& j, const Template& template_);
friend void from_json(const nlohmann::json& j, Template& template_);
};

struct SpecialTokensMap {
std::unordered_map<std::string, SpecialToken> tokens_map_;
SpecialTokensMap() = default;
explicit SpecialTokensMap(const std::vector<SpecialToken>& special_tokens) {
SetTokensMap(special_tokens);
}
void SetTokensMap(const std::vector<SpecialToken>& special_tokens) {
tokens_map_.clear();
for (const auto& special_token : special_tokens) {
tokens_map_.insert({special_token.id_, special_token});
}
}
friend void to_json(nlohmann::json& j, const SpecialTokensMap& tokens_map);
friend void from_json(const nlohmann::json& j, SpecialTokensMap& tokens_map);
};

struct TemplatePostProcessor : public PostProcessor {
TemplatePostProcessor();
TemplatePostProcessor(const Template&,
const Template&,
const std::vector<SpecialToken>&);

virtual void operator()(core::Encoding* encoding,
core::Encoding* pair_encoding,
bool add_special_tokens,
core::Encoding* result_encoding) const override;
virtual size_t AddedTokensNum(bool is_pair) const override;

void UpdateSinglePieces(const std::string& template_str);
void UpdateSinglePieces(const std::vector<std::string>& pieces);
void UpdatePairPieces(const std::string& template_str);
void UpdatePairPieces(const std::vector<std::string>& pieces);
void UpdateAddedTokensNum();
void SetTokensMap(const std::vector<SpecialToken>& special_tokens);
size_t CountAdded(Template* template_,
const SpecialTokensMap& special_tokens_map);
size_t DefaultAdded(bool is_single = true);
void ApplyTemplate(const Template& pieces,
core::Encoding* encoding,
core::Encoding* pair_encoding,
bool add_special_tokens,
core::Encoding* result_encoding) const;

friend void to_json(nlohmann::json& j,
const TemplatePostProcessor& template_postprocessor);
friend void from_json(const nlohmann::json& j,
TemplatePostProcessor& template_postprocessor);

Template single_;
Template pair_;
size_t added_single_;
size_t added_pair_;
SpecialTokensMap special_tokens_map_;
};

} // postprocessors
} // tokenizers
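
A minimal usage sketch of the header above (not part of the commit): it parses a BERT-style pair template and prints each piece. The template syntax is assumed to mirror HuggingFace's TemplateProcessing, which this module tracks: `$A`/`$B` select an input sequence, a `:N` suffix sets the type id, and anything else is treated as a special token.

#include <iostream>
#include "postprocessors/template.h"

using namespace tokenizers::postprocessors;

int main() {
  // Parse a typical BERT-style pair template into pieces.
  Template pair_template("[CLS] $A [SEP] $B:1 [SEP]");

  for (const auto& piece : pair_template.pieces_) {
    if (const auto* seq = boost::get<TemplateSequence>(&piece)) {
      // A sequence placeholder: which input sequence, plus its type id.
      std::cout << "sequence " << (seq->first == SEQ_A ? "A" : "B")
                << ", type id " << seq->second << std::endl;
    } else {
      // A special token such as [CLS] or [SEP].
      const auto& tok = boost::get<TemplateSpecialToken>(piece);
      std::cout << "special token " << tok.first << ", type id "
                << tok.second << std::endl;
    }
  }
  return 0;
}

Under the same assumptions, a full post-processor could then be built as TemplatePostProcessor(Template("[CLS] $A [SEP]"), pair_template, {{"[CLS]", 101}, {"[SEP]", 102}}), where 101/102 are illustrative BERT vocabulary ids, not values taken from this commit.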
2 changes: 1 addition & 1 deletion faster_tokenizer/faster_tokenizer/src/core/CMakeLists.txt
@@ -1,4 +1,4 @@
cc_library(added_vocabulary SRCS added_vocabulary.cc DEPS normalizers pretokenizers json)
-cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models)
+cc_library(tokenizer SRCS tokenizer.cc DEPS added_vocabulary json decoders trie models postprocessors)
cc_library(core SRCS encoding.cc DEPS json)
add_dependencies(tokenizer extern_boost)
12 changes: 12 additions & 0 deletions faster_tokenizer/faster_tokenizer/src/core/tokenizer.cc
@@ -517,6 +517,11 @@ void to_json(nlohmann::json& j, const Tokenizer& tokenizer) {
typeid(postprocessors::BertPostProcessor)) {
j["postprocessor"] = *dynamic_cast<postprocessors::BertPostProcessor*>(
tokenizer.post_processor_.get());
} else if (typeid(*tokenizer.post_processor_.get()) ==
typeid(postprocessors::TemplatePostProcessor)) {
j["postprocessor"] =
*dynamic_cast<postprocessors::TemplatePostProcessor*>(
tokenizer.post_processor_.get());
}
}

@@ -611,6 +616,10 @@ void from_json(const nlohmann::json& j, Tokenizer& tokenizer) {
postprocessors::BertPostProcessor bert_postprocessor;
post_processor.get_to(bert_postprocessor);
tokenizer.SetPostProcessor(bert_postprocessor);
} else if (post_processor.at("type") == "TemplateProcessing") {
postprocessors::TemplatePostProcessor template_postprocessor;
post_processor.get_to(template_postprocessor);
tokenizer.SetPostProcessor(template_postprocessor);
}
}

@@ -686,7 +695,10 @@ template void Tokenizer::SetModel(const models::FasterWordPiece&);
// Instantiate processors
template void Tokenizer::SetPostProcessor(
const postprocessors::BertPostProcessor&);
template void Tokenizer::SetPostProcessor(
const postprocessors::TemplatePostProcessor&);

// Instantiate Decoder
template void Tokenizer::SetDecoder(const decoders::WordPiece& decoder);
} // core
} // tokenizers
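
For reference, a sketch of the serialized "postprocessor" entry that the from_json dispatch above would accept. Only the "type" tag is confirmed by this diff; the remaining key names are assumptions based on HuggingFace's tokenizer.json layout, which this deserializer appears to follow.

#include <cassert>
#include "nlohmann/json.hpp"

int main() {
  // Hypothetical serialized entry: the "type" tag drives the dispatch in
  // from_json; "single", "pair" and "special_tokens" are assumed key names.
  auto post_processor = nlohmann::json::parse(R"({
    "type": "TemplateProcessing",
    "single": ["[CLS]:0", "$A:0", "[SEP]:0"],
    "pair": ["[CLS]:0", "$A:0", "[SEP]:0", "$B:1", "[SEP]:1"],
    "special_tokens": {
      "[CLS]": {"id": "[CLS]", "ids": [101], "tokens": ["[CLS]"]},
      "[SEP]": {"id": "[SEP]", "ids": [102], "tokens": ["[SEP]"]}
    }
  })");
  // Mirror the dispatch in from_json: route on the "type" tag.
  assert(post_processor.at("type") == "TemplateProcessing");
  return 0;
}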
@@ -1,2 +1,2 @@
-cc_library(models SRCS wordpiece.cc faster_wordpiece.cc DEPS core json trie failure)
+cc_library(models SRCS wordpiece.cc faster_wordpiece.cc DEPS core json trie failure icu)
add_dependencies(models extern_boost)
@@ -1 +1 @@
-cc_library(postprocessors SRCS bert.cc postprocessor.cc DEPS core json)
+cc_library(postprocessors SRCS bert.cc postprocessor.cc template.cc DEPS core json)