-
Notifications
You must be signed in to change notification settings - Fork 0
/
fullTokenizer.go
108 lines (93 loc) · 2.08 KB
/
fullTokenizer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package go_bert_tokenizer
import (
"strings"
)
// Special tokens and markers used when assembling BERT model input.
const (
	// ClassToken is the classification token placed at the start of every
	// encoded sequence.
	ClassToken = "[CLS]"
	// SeparatorToken is appended after each sub-sequence.
	SeparatorToken = "[SEP]"
	// SequenceSeparator splits the raw input text into multiple sequences
	// (e.g. a sentence pair) before tokenization.
	SequenceSeparator = " ||| "
)
// Encode holds the result of tokenizing a text: the token strings plus the
// fixed-length (SeqLen) ID, attention-mask, and segment-type vectors that a
// BERT model consumes. Unfilled trailing positions keep their zero values
// (empty token, ID 0, mask 0, type 0), acting as padding.
type Encode struct {
	Text string // original input text, unmodified
	Tokens []string // tokens including [CLS] and [SEP]
	TokenIDs []int32 // vocabulary IDs corresponding to Tokens
	MaskIDs []int32 // 1 for real tokens, 0 for padding
	TypeIDs []int32 // segment (sequence) index of each token
}
// FullTokenizer combines a basic (whitespace/punctuation) tokenizer with a
// wordpiece tokenizer and produces fixed-length encodings of length SeqLen.
type FullTokenizer struct {
	Basic *BasicTokenizer // first pass: splits raw text into basic tokens
	Wordpiece *WordpieceTokenizer // second pass: splits basic tokens into wordpieces
	SeqLen int // fixed output length of every Encode
}
// NewFullTokenizer builds a FullTokenizer backed by the given vocabulary.
// seqLen fixes the length of every encoded output; lower controls whether
// the basic tokenizer lowercases its input.
func NewFullTokenizer(voc *Vocab, seqLen int, lower bool) *FullTokenizer {
	return &FullTokenizer{
		Basic:     NewBasicTokenizer(lower),
		Wordpiece: NewWordpieceTokenizer(voc),
		SeqLen:    seqLen,
	}
}
// Tokenize encodes text into fixed-length BERT model inputs. The input may
// contain several sequences separated by SequenceSeparator; each becomes its
// own segment. The output starts with a single [CLS] token and every segment
// is followed by its own [SEP]. Segments are truncated so that their tokens
// plus the special markers fit within SeqLen.
//
// NOTE(review): if SeqLen is too small to hold even the special tokens
// (SeqLen < len(sequences)+1) the writes below can index past the buffers —
// presumably callers choose SeqLen large enough; confirm upstream.
func (tkz *FullTokenizer) Tokenize(text string) *Encode {
	enc := &Encode{
		Text:     text,
		Tokens:   make([]string, tkz.SeqLen),
		TokenIDs: make([]int32, tkz.SeqLen),
		MaskIDs:  make([]int32, tkz.SeqLen),
		TypeIDs:  make([]int32, tkz.SeqLen),
	}

	pieces := strings.Split(text, SequenceSeparator)
	seqs := make([][]string, len(pieces))
	for i, p := range pieces {
		seqs[i] = tkz.tokenize(p)
	}
	// Reserve one slot for [CLS] plus one [SEP] per sequence.
	truncate(seqs, tkz.SeqLen-len(seqs)-1)

	voc := tkz.Wordpiece.Vocab
	pos := 0
	// put writes one token at the current position and advances it.
	put := func(tok string, segment int32) {
		enc.Tokens[pos] = tok
		enc.TokenIDs[pos] = voc.GetID(tok)
		enc.TypeIDs[pos] = segment
		enc.MaskIDs[pos] = 1
		pos++
	}

	put(ClassToken, 0)
	for sid, seq := range seqs {
		for _, tok := range seq {
			put(tok, int32(sid))
		}
		put(SeparatorToken, int32(sid))
	}
	return enc
}
// tokenize runs the basic tokenizer over text and then splits every basic
// token into wordpieces, returning the flattened wordpiece list. The result
// is nil when the text yields no tokens.
func (tkz *FullTokenizer) tokenize(text string) []string {
	basic := tkz.Basic.Tokenize(text)
	var pieces []string
	for _, b := range basic {
		pieces = append(pieces, tkz.Wordpiece.Tokenize(b)...)
	}
	return pieces
}
// truncate trims the tokenized sequences in place until their combined
// length is at most maxlen. Each pass removes the final token of the
// currently longest sequence; ties are resolved toward the higher index.
// If every sequence is already empty the function returns regardless of
// maxlen.
func truncate(seqs [][]string, maxlen int) {
	total := 0
	for _, s := range seqs {
		total += len(s)
	}
	for total > maxlen {
		longest, best := 0, 0
		// Scan back-to-front so the later of equally long sequences wins.
		for i := len(seqs) - 1; i >= 0; i-- {
			if l := len(seqs[i]); l > best {
				longest, best = i, l
			}
		}
		if best == 0 {
			return
		}
		seqs[longest] = seqs[longest][:best-1]
		total--
	}
}