In [1]:
import torch
from torch import optim, nn
from transformer import Transformer
from collections import defaultdict
from typing import List

In [2]:
fname = './data/tiny-shakespeare.txt'
lines = []
with open(fname, 'r') as f:
    while (line := f.readline()):
        lines.append(line)

In [3]:
print(len(lines))
print(sum(l.endswith('\n') for l in lines))

40000
40000


In [4]:
data = "".join(lines)
print([c for c in data[:30]])

['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't', 'i', 'z', 'e', 'n', ':', '\n', 'B', 'e', 'f', 'o', 'r', 'e', ' ', 'w', 'e', ' ', 'p', 'r', 'o', 'c', 'e']


In [5]:
bytes = bytearray(data, 'utf8')

In [None]:
class TrieNode:
    def __init__(self) -> None:
        self.children = {}
        self.is_end = False
        self.tok = -1

class Trie:
    def __init__(self) -> None:
        self.root = TrieNode()
    
    def insert(self, bytes: bytearray, tok: int) -> None:
        dummy = self.root
        for b in bytes:
            if b not in dummy.children:
                dummy.children[b] = TrieNode()
            dummy = dummy.children[b]
        dummy.is_end = True
        dummy.tok = tok

class BytePairEncoder:
    def __init__(self) -> None:
        self.trie = Trie()
        self.bpe = {}
        self.tokens = {}
    
    def encode(self, bytes: bytearray) -> List[int]:
        # current 2n
        tokens = []
        cur = self.trie.root
        cur_tok = -1
        idx = -1

        i = 0
        n = len(bytes)

        while i < n:
            if cur.is_end:
                cur_tok = cur.tok
                idx = i

            if bytes[i] not in cur.children:
                tokens.append(cur_tok)
                cur = self.trie.root
                cur_tok = -1
                i = idx
            else:
                cur = cur.children[bytes[i]]
                i += 1
    
        return tokens

    def decode(self, tokens: List[int]) -> str:
        bytes = bytearray([b for tok in tokens for b in self.bytes[tok]])
        return str(bytes, encoding='utf8')

    def train(self, bytes: str, vocab_size: int) -> None:
        # TODO:
        #   add special tokens
        #   EOS/BOS not necessary for one doc
        #   bug right now
        #   how to place BOS and EOS?
        self.trie = Trie()
        self.tokens = {}
        self.bytes = {}
        for b in range(256):
            self.bytes[b] = [b]
            self.tokens[(b)] = b
            self.trie.insert([b], b)
        # for special_tok in ["<EOS>", "<BOS>", "<PAD>"]:
        #    tok = len(self.tokens)
        #    b = bytearray(special_tok, 'utf8')
        #    self.bytes[tok] = b
        #    self.tokens[tuple(b)] = tok
        if vocab_size < 256: raise ValueError("vocabulary size < 259")

        for _ in range(vocab_size - 256):
            freq = defaultdict(int)
            maxpair = None

            tokens = self.encode(bytes)
            n = len(tokens)
            for i in range(n - 1):
                tok1, tok2 = tokens[i], tokens[i + 1]
                freq[(tok1, tok2)] += 1
                if freq[(tok1, tok2)] > freq[maxpair]:
                    maxpair = (tok1, tok2)
            tok_bytes = self.bytes[maxpair[0]] + self.bytes[maxpair[1]]

            tok = len(self.tokens)
            self.trie.insert(tok_bytes, tok)
            self.bytes[tok] = tok_bytes
            self.tokens[tuple(tok_bytes)] = tok
        return

In [14]:
tokenizer = BytePairEncoder()
tokenizer.train(bytes, 280)

In [26]:
toks = tokenizer.encode(bytes)
decode = tokenizer.decode(toks)
print(len(data))
print(len(decode))
print(data[-2], decode[-2])

1115394
1115392
. n


In [46]:
def encode(data, trie):
    # 2 * n
    tokens = []
    cur = trie.root
    cur_tok = ""
    idx = -1

    i = 0
    n = len(data)

    while i < n:
        if cur.is_end:
            cur_tok = cur.s
            idx = i

        if data[i] not in cur.children:
            tokens.append(cur_tok)
            cur = trie.root
            i = idx
        else:
            cur = cur.children[data[i]]
            i += 1
    return tokens

In [56]:
def bpe(data, size):
    trie = Trie()
    bps = {}
    toks = {}
    for c in data:
        if c not in bps:
            toks[len(bps)] = c
            bps[c] = len(bps)
    if len(bps) > size: raise ValueError("vocabulary size < current size")
    for c in bps:
        trie.insert(c)

    for _ in range(size - len(bps)):
        tokens = encode(data, trie)
        freq = defaultdict(int)
        n = len(tokens)
        maxpair = None
        for i in range(n - 1):
            a, b = bps[tokens[i]], bps[tokens[i + 1]]
            freq[(a, b)] += 1
            if freq[(a, b)] > freq[maxpair]:
                maxpair = (a, b)
        tok = toks[maxpair[0]] + toks[maxpair[1]]
        trie.insert(tok)
        toks[len(bps)] = tok
        bps[tok] = len(bps)
    return bps, toks

print(bpe(data, 200))

{'F': 0, 'i': 1, 'r': 2, 's': 3, 't': 4, ' ': 5, 'C': 6, 'z': 7, 'e': 8, 'n': 9, ':': 10, '\n': 11, 'B': 12, 'f': 13, 'o': 14, 'w': 15, 'p': 16, 'c': 17, 'd': 18, 'a': 19, 'y': 20, 'u': 21, 'h': 22, ',': 23, 'm': 24, 'k': 25, '.': 26, 'A': 27, 'l': 28, 'S': 29, 'Y': 30, 'v': 31, '?': 32, 'R': 33, 'M': 34, 'W': 35, "'": 36, 'L': 37, 'I': 38, 'N': 39, 'g': 40, ';': 41, 'b': 42, '!': 43, 'O': 44, 'j': 45, 'V': 46, '-': 47, 'T': 48, 'H': 49, 'E': 50, 'U': 51, 'D': 52, 'P': 53, 'q': 54, 'x': 55, 'J': 56, 'G': 57, 'K': 58, 'Q': 59, '&': 60, 'Z': 61, 'X': 62, '3': 63, '$': 64}
({'F': 0, 'i': 1, 'r': 2, 's': 3, 't': 4, ' ': 5, 'C': 6, 'z': 7, 'e': 8, 'n': 9, ':': 10, '\n': 11, 'B': 12, 'f': 13, 'o': 14, 'w': 15, 'p': 16, 'c': 17, 'd': 18, 'a': 19, 'y': 20, 'u': 21, 'h': 22, ',': 23, 'm': 24, 'k': 25, '.': 26, 'A': 27, 'l': 28, 'S': 29, 'Y': 30, 'v': 31, '?': 32, 'R': 33, 'M': 34, 'W': 35, "'": 36, 'L': 37, 'I': 38, 'N': 39, 'g': 40, ';': 41, 'b': 42, '!': 43, 'O': 44, 'j': 45, 'V': 46, '-': 47

In [None]:
d = 1024        # num dimensions = dim(model)
L = 1           # num layers
H = 8           # num heads
MAXT = 4098     # max tokens
V = 4098        # vocab size
d_ff = 2048     # dim(FFN hidden layers)
p_drop = 0.1    # dropout prob

transformer = Transformer(d, L, H, MAXT, V, d_ff, p_drop)

In [3]:
X = torch.randint(V, (1, 256))
out = transformer(X)
print(out.shape)
print(out)

torch.Size([1, 256, 4098])
tensor([[[ 53.1798,  38.1477, -28.7207,  ..., -16.9216,   2.8838,  -1.6856],
         [ 27.5028, -17.9487,  30.4112,  ...,  83.4197,  31.8329, -35.6247],
         [ 44.7758,  42.3246, -23.0053,  ..., -47.0368, -24.6917,   1.8160],
         ...,
         [-17.7791,  95.5564,  21.2843,  ...,  38.1742,   9.4863,  15.7194],
         [ -1.5760, -93.9202, -34.0709,  ...,  75.0702,  12.7765, -14.5807],
         [-25.9921,  12.1569, -30.9109,  ...,  25.1669,  46.5577,   5.1974]]],
       grad_fn=<UnsafeViewBackward0>)
