This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

Support custom tokens #97

Open: wants to merge 7 commits into master
11 changes: 9 additions & 2 deletions README.md
@@ -60,6 +60,9 @@ test_text = "".join([random.choice("abcde ") for _ in range(100)])
# Training model
yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path)

# Training model with custom tokens
yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path, custom_tokens=[b"[CLS]", b"[MASK]"])

# Loading model
bpe = yttm.BPE(model=model_path)

@@ -71,7 +74,7 @@ print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD))
 
### Training model
```python
youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3, custom_tokens=[])
```
Trains BPE model and saves to file.

Expand All @@ -86,6 +89,7 @@ Trains BPE model and saves to file.
* `unk_id`: int, reserved id for unknown symbols
* `bos_id`: int, reserved id for begin of sentence token
* `eos_id`: int, reserved id for end of sentence token
* `custom_tokens`: List[bytes], tokens which will not be split into subwords

**Returns**: Class `youtokentome.BPE` with the loaded model.
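
For illustration, a minimal sketch of the new parameter end to end (file names are placeholders; the exact subword splits depend on the trained merges, only the custom tokens are guaranteed to stay whole):

```python
import youtokentome as yttm

yttm.BPE.train(data="train_data.txt", vocab_size=5000, model="model.yttm",
               custom_tokens=[b"[CLS]", b"[MASK]"])
bpe = yttm.BPE(model="model.yttm")

# Custom tokens come back as single pieces instead of bracket-by-bracket splits
print(bpe.encode(["[CLS] some text [MASK]"], output_type=yttm.OutputType.SUBWORD))
```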

@@ -191,7 +195,7 @@ Convert each id to subword and concatenate with space symbol.
### Example

```bash
$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000
$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 --custom_tokens "[CLS],[MASK]"
$ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA
```

@@ -234,6 +238,9 @@ Options:
--unk_id INTEGER Unknown token id. [default: 1]
--bos_id INTEGER 'Begin of sentence' token id. [default: 2]
--eos_id INTEGER 'End of sentence' token id. [default: 3]
--custom_tokens TEXT Tokens which will not be split into
subwords; multiple tokens should be
comma-separated.
--help Show this message and exit.
```

23 changes: 23 additions & 0 deletions tests/unit_tests/test_manual.py
@@ -73,3 +73,26 @@ def test_japanese():
assert tokenized_text == expected_result
print(tokenized_text)
os.remove(TRAIN_DATA_PATH)

def test_special_token():
    train_text = """
    [CLS] Lorem ipsum dolor sit amet, consectetur adipiscing elit,
    sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris
    nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in
    reprehenderit in voluptate velit [MASK] esse cillum dolore eu fugiat nulla
    pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
    culpa qui officia deserunt mollit <SEP> anim id est laborum.
    """
    test_text = "[CLS] Lorem ipsum [TOKEN] dolor <SEP> sit [MASK] amet"
    TRAIN_DATA_PATH = "train_data.txt"
    MODEL_PATH = "model.yttm"
    with open(TRAIN_DATA_PATH, "w") as fout:
        fout.write(train_text)
    model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100, custom_tokens=[b'[CLS]', b'[TOKEN]', b'<SEP>'])
    tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD)
    # Registered tokens stay whole (even [TOKEN], which never occurs in the
    # training text); [MASK] is not registered here, so it gets split.
    expected_result = [['▁', '[CLS]', '▁', 'L', 'or', 'e', 'm', '▁', 'ip', 's', 'um', '▁', '[TOKEN]', '▁dolor', '▁', '<SEP>', '▁s', 'it', '▁', '[', 'M', 'A', 'S', 'K', ']', '▁a', 'm', 'e', 't']]
    assert tokenized_text == expected_result
    print(tokenized_text)
    os.remove(TRAIN_DATA_PATH)
32 changes: 32 additions & 0 deletions youtokentome/cpp/bpe.cpp
@@ -1515,6 +1515,12 @@ void print_config(const string &input_path, const string &model_path,
std::cerr << " unk: " << bpe_config.special_tokens.unk_id << std::endl;
std::cerr << " bos: " << bpe_config.special_tokens.bos_id << std::endl;
std::cerr << " eos: " << bpe_config.special_tokens.eos_id << std::endl;
if (bpe_config.special_tokens.custom_tokens.size()) {
std::cerr << " custom_tokens: ";
for (auto token:bpe_config.special_tokens.custom_tokens)
std::cerr << token << " ";
std::cerr << std::endl;
}
std::cerr << std::endl;
}

@@ -1665,6 +1671,7 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,

uint32_t new_token_cur = new_tokens_start;
list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);
string utf8_text;

for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) {
if (bpe_state.char2id.count(*it_char_in_word) == 0) {
@@ -1674,15 +1681,31 @@

unrecognized_tokens[new_token_cur] =
encode_utf8({it_char_in_word, it_unrecognized_word});
if (custom_token2id.size())
utf8_text.append(unrecognized_tokens[new_token_cur]);
it_char_in_word = it_unrecognized_word;

list.emplace_back(new_token_cur, list.size());
new_token_cur++;
} else {
if (custom_token2id.size())
utf8_to_chars(*it_char_in_word, std::back_inserter(utf8_text));
list.emplace_back(bpe_state.char2id.at(*it_char_in_word), list.size());
++it_char_in_word;
}
}

if (custom_token2id.size() && custom_token2id.count(utf8_text)) {
if (output_type == ID) {
output_ids.push_back(bpe_state.char2id.at(SPACE_TOKEN));
output_ids.push_back(custom_token2id.find(utf8_text)->second);
} else {
output_pieces.push_back(encode_utf8({SPACE_TOKEN}));
output_pieces.push_back(utf8_text);
}
continue;
}

list.back().next = -1;


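In prose, the change above: while scanning a word, the encoder additionally accumulates the word's raw UTF-8 text, and if the accumulated word exactly matches a registered custom token it emits that token directly (the id for `ID` output, the raw text for subword output) and `continue`s past the BPE merge loop. A Python model of that control flow (not the PR's C++; `merge_subwords` is a hypothetical stand-in for the merge loop):

```python
def encode_word(word, custom_token2id, char2id, unk_id, merge_subwords):
    # Whole-word match short-circuits BPE entirely
    if word in custom_token2id:
        return [custom_token2id[word]]
    # Otherwise fall through to the usual character-then-merge path
    ids = [char2id.get(ch, unk_id) for ch in word]
    return merge_subwords(ids)
```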
@@ -1840,6 +1863,11 @@ void BaseEncoder::fill_from_state() {
}
reversed_recipe[BOS_TOKEN] = bpe_state.special_tokens.bos_id;
reversed_recipe[EOS_TOKEN] = bpe_state.special_tokens.eos_id;
uint32_t custom_id = bpe_state.special_tokens.max_predefined_id();
for (auto token : bpe_state.special_tokens.custom_tokens) {
++custom_id;
custom_token2id[token] = custom_id;
}
}

int BaseEncoder::vocab_size() const {
@@ -1947,6 +1975,10 @@ Status BaseEncoder::id_to_subword(int id, string *subword, bool replace_space) c
*subword = EOS_TOKEN;
return Status();
}
if (id <= bpe_state.special_tokens.max_id() && id > bpe_state.special_tokens.max_predefined_id()) {
*subword = bpe_state.special_tokens.custom_tokens[id - bpe_state.special_tokens.max_predefined_id() - 1];
return Status();
}

assert(recipe.count(id));
if (replace_space) {
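Together, `fill_from_state` and `id_to_subword` pin down the id layout: custom tokens take the ids immediately after the predefined specials, in list order. A sketch of the two mappings, assuming the default ids pad=0, unk=1, bos=2, eos=3:

```python
custom_tokens = [b"[CLS]", b"[TOKEN]", b"<SEP>"]
max_predefined_id = 3  # max(pad_id, unk_id, bos_id, eos_id)

# fill_from_state: ids are assigned sequentially after the predefined specials
custom_token2id = {tok: max_predefined_id + 1 + i
                   for i, tok in enumerate(custom_tokens)}
assert custom_token2id == {b"[CLS]": 4, b"[TOKEN]": 5, b"<SEP>": 6}

# id_to_subword: the new branch inverts that assignment
def custom_id_to_subword(token_id):
    return custom_tokens[token_id - max_predefined_id - 1]

assert custom_id_to_subword(5) == b"[TOKEN]"
```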
1 change: 1 addition & 0 deletions youtokentome/cpp/bpe.h
@@ -27,6 +27,7 @@ class BaseEncoder {
flat_hash_map<uint32_t, uint32_t> id2char;
flat_hash_map<uint32_t, std::vector<uint32_t>> recipe;
flat_hash_map<std::string, uint32_t> reversed_recipe;
flat_hash_map<std::string, uint32_t> custom_token2id;
flat_hash_map<uint64_t, int> rule2id;
int n_threads;

2 changes: 2 additions & 0 deletions youtokentome/cpp/utf8.h
@@ -8,6 +8,8 @@ constexpr static uint32_t INVALID_UNICODE = 0x0fffffff;

uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len);

void utf8_to_chars(const uint32_t x, const std::back_insert_iterator<std::string> it);

std::string encode_utf8(const std::vector<uint32_t> &utext);

std::vector<uint32_t> decode_utf8(const char *begin, const char *end);
20 changes: 16 additions & 4 deletions youtokentome/cpp/utils.cpp
@@ -10,15 +10,20 @@ using std::string;
using std::vector;

void SpecialTokens::dump(std::ofstream &fout) {
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
<< std::endl;
fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << " ";
for (auto token: custom_tokens) fout << token << " ";
fout << std::endl;
}

void SpecialTokens::load(std::ifstream &fin) {
fin >> unk_id >> pad_id >> bos_id >> eos_id;
std::string token;
while (fin >> token)
custom_tokens.push_back(token);
}
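
The serialization change keeps the single header line and appends the custom tokens after the four ids; `load` then reads whitespace-separated tokens until the stream ends, which assumes the special-token block is the last thing it is handed (and that no custom token contains whitespace, or it would not round-trip). A sketch of the format, with values assumed from the defaults:

```python
# Hypothetical header produced by dump() with unk=1, pad=0, bos=2, eos=3
header = "1 0 2 3 [CLS] [MASK]"

fields = header.split()
unk_id, pad_id, bos_id, eos_id = (int(f) for f in fields[:4])
custom_tokens = fields[4:]  # ['[CLS]', '[MASK]']
```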

uint32_t SpecialTokens::max_id() const {
uint32_t SpecialTokens::max_predefined_id() const {
int ret = 0;
ret = std::max(ret, unk_id);
ret = std::max(ret, pad_id);
@@ -27,8 +32,14 @@ uint32_t SpecialTokens::max_id() const {
return ret;
}

uint32_t SpecialTokens::max_id() const {
int ret = max_predefined_id();
ret += custom_tokens.size();
return ret;
}

bool SpecialTokens::taken_id(int id) const {
return id == unk_id || id == pad_id || id == bos_id || id == eos_id;
return id == unk_id || id == pad_id || id == bos_id || id == eos_id || (id > max_predefined_id() && id <= max_id());
}

uint64_t SpecialTokens::n_special_tokens() const {
@@ -37,6 +48,7 @@ uint64_t SpecialTokens::n_special_tokens() const {
cnt += (pad_id != -1);
cnt += (bos_id != -1);
cnt += (eos_id != -1);
cnt += custom_tokens.size();
return cnt;
}
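
A worked example of the three helpers above, with the default ids and two custom tokens:

```python
# Defaults plus custom_tokens=[b"[CLS]", b"[MASK]"]:
max_predefined_id = 3           # max(0, 1, 2, 3)
max_id = max_predefined_id + 2  # 5: custom ids are 4 and 5, both now taken_id()
n_special_tokens = 4 + 2        # 6: four predefined + two custom
```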

2 changes: 2 additions & 0 deletions youtokentome/cpp/utils.h
@@ -26,6 +26,7 @@ struct SpecialTokens {
int unk_id = -1;
int bos_id = -1;
int eos_id = -1;
std::vector<std::string> custom_tokens;

SpecialTokens() = default;

@@ -40,6 +41,7 @@
bool taken_id(int id) const;

uint64_t n_special_tokens() const;
uint32_t max_predefined_id() const;
};

struct BpeConfig {
3 changes: 3 additions & 0 deletions youtokentome/cpp/yttm.pyx
@@ -14,6 +14,7 @@ cdef extern from "bpe.h" namespace "vkcom":
int unk_id
int bos_id
int eos_id
vector[string] custom_tokens

cdef cppclass BpeConfig:
double character_coverage
@@ -67,6 +68,7 @@ cdef class BPE:
vocab_size,
coverage=1.0,
n_threads=-1,
custom_tokens=[],
pad_id=0,
unk_id=1,
bos_id=2,
@@ -79,6 +81,7 @@
bpe_config.special_tokens.unk_id = unk_id
bpe_config.special_tokens.bos_id = bos_id
bpe_config.special_tokens.eos_id = eos_id
bpe_config.special_tokens.custom_tokens = custom_tokens

cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config)
if status.code != 0:
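
One consequence of declaring `custom_tokens` as `vector[string]`: Cython coerces each element from Python `bytes`, so passing `str` values should fail with a TypeError unless they are encoded first. That is why the README and tests use `b"..."` literals (file names below are placeholders):

```python
# bytes, not str (encode first if you start from str)
yttm.BPE.train(data="train_data.txt", model="model.yttm", vocab_size=5000,
               custom_tokens=[t.encode("utf8") for t in ("[CLS]", "[MASK]")])
```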
2 changes: 2 additions & 0 deletions youtokentome/youtokentome.py
@@ -22,6 +22,7 @@ def train(
data: str,
model: str,
vocab_size: int,
custom_tokens: List[bytes] = [],
coverage: float = 1.0,
n_threads: int = -1,
pad_id: int = 0,
@@ -35,6 +36,7 @@
vocab_size=vocab_size,
n_threads=n_threads,
coverage=coverage,
custom_tokens=custom_tokens,
pad_id=pad_id,
unk_id=unk_id,
bos_id=bos_id,
10 changes: 9 additions & 1 deletion youtokentome/yttm_cli.py
@@ -57,7 +57,14 @@ def main():
default=3,
show_default=True,
)
def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id):
@click.option(
"--custom_tokens",
type=click.STRING,
help="Tokens which will not be split into subwords, muiltple tokens should be provided with comma seperated.",
default="",
show_default=True,
)
def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id, custom_tokens):
"""Train BPE model."""
yttmc.BPE.train(
data=data,
@@ -69,6 +76,7 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo
unk_id=unk_id,
bos_id=bos_id,
eos_id=eos_id,
custom_tokens=[t.encode("utf8") for t in custom_tokens.split(",") if t],  # filter empties: "".split(",") == [""]
)

