# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import numpy as np

# Layers


## Scaled Dot Product Attention


$
\LARGE\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V
$

In [2]:
def scaled_dot_product_attention(
    q, k, v, mask=0):
    """Attention over the last two dimensions: softmax(q @ k^T / sqrt(d) + mask) @ v.

    Args:
        q, k, v: tensors whose trailing dims are (sequence, feature).
        mask: additive term applied to the scaled scores before the softmax
            (0 leaves a score unchanged, -inf removes it); must broadcast
            against the score matrix.

    Returns:
        (value, attention): the weighted values and the softmax weights.
    """
    depth = q.shape[-1]                                        # feature size of the queries
    logits = torch.matmul(q, k.transpose(-1, -2)) / (depth ** .5)
    weights = F.softmax(logits + mask, dim=-1)                 # normalise over the keys
    return torch.matmul(weights, v), weights

### Mask

#### Attention Mask

$
M = \begin{bmatrix}
0 & -\infty & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
0 & 0 & 0 & -\infty \\
0 & 0 & 0 & 0 \\
\end{bmatrix}
$

#### Padding Mask

$
M = \begin{bmatrix}
0 & 0 & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
0 & 0 & -\infty & -\infty \\
\end{bmatrix}
$

#### Code

In [3]:

def get_mask(l=4, v=(0,float('-inf'))):
    """Build an l x l lower-triangular additive mask.

    Entry (a, b) is v[0] when b <= a (key at or before the query) and v[1]
    otherwise.
    """
    rows = []
    for a in range(l):
        rows.append([v[0] if b <= a else v[1] for b in range(l)])
    return torch.tensor(rows)
def padding_mask(l=4, i=2, v=(0, float('-inf'))):
    """Build a length-l additive mask: the first i entries are v[0], the rest v[1].

    The result is truncated to l entries, so i > l yields l copies of v[0].
    """
    visible = [v[0]] * i
    hidden = [v[1]] * (l - i)     # empty when i >= l
    return torch.tensor(visible + hidden)[:l]


In [4]:
# Smoke test: run scaled_dot_product_attention on random data and print the output shapes.
shape = [2, 10, 3] # B S D
l = nn.Linear(shape[-1], shape[-1]*3)  # one projection whose output holds q, k and v side by side
x = torch.randn(shape)
Wqkv = l(x)
q, k, v = torch.chunk(Wqkv, chunks=3, dim=-1)  # split the last dimension into three equal parts
value, attention = scaled_dot_product_attention(q, k, v, mask=0)  # mask=0 adds nothing to the scores
print(f" value      {value.shape}\n attention  {attention.shape}")

 value      torch.Size([2, 10, 3])
 attention  torch.Size([2, 10, 10])


## Multi-Head Attention

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on scaled_dot_product_attention.

    Queries, keys and values are projected with separate linear layers, split
    into `num_heads` heads of size `embed_dim // num_heads`, attended per head,
    re-joined and passed through an output projection.
    """

    def __init__(s, embed_dim, num_heads):
        super().__init__()

        assert embed_dim % num_heads == 0, f"cant split {embed_dim} embed_dim to {num_heads} heads"

        # sizes: full embedding, number of heads, per-head width
        s.edim, s.head, s.hdim = embed_dim, num_heads, embed_dim // num_heads

        # projections for query / key / value and the final output
        s.Wq = nn.Linear(embed_dim, embed_dim)
        s.Wk = nn.Linear(embed_dim, embed_dim)
        s.Wv = nn.Linear(embed_dim, embed_dim)
        s.Wo = nn.Linear(embed_dim, embed_dim)

    def forward(s, xq, xk, xv, mask=0):
        """Return (output, attention) for query/key/value inputs of shape B, S, ED."""
        hq, hk, hv = (
            s.split_heads(project(inp))                       # B, H, S, HD
            for project, inp in ((s.Wq, xq), (s.Wk, xk), (s.Wv, xv))
        )

        per_head, attention = scaled_dot_product_attention(hq, hk, hv, mask)

        joined = s.combine_heads(per_head)                    # B, S, ED
        return s.Wo(joined), attention

    def split_heads(s, x):
        """B, S, ED -> B, H, S, HD."""
        per_head = x.reshape(*x.shape[:-1], s.head, s.hdim)   # B, S, H, HD
        return per_head.transpose(-3, -2)

    def combine_heads(s, x):
        """B, H, S, HD -> B, S, ED (inverse of split_heads)."""
        seq_first = x.transpose(-3, -2)                       # B, S, H, HD
        return seq_first.reshape(*seq_first.shape[:-2], s.edim)


In [6]:

# Smoke test: self-attention (the same tensor is used as query, key and value).
embed_dim = 9
num_heads = 3

m = MultiHeadAttention(embed_dim, num_heads)

# NOTE: batch_size and seq_len stay in the kernel namespace after this cell.
batch_size = 2
seq_len = 5

x = torch.randn(batch_size, seq_len, embed_dim)

y, a = m(x, x, x)  # y: projected output, a: attention weights
y.shape

torch.Size([2, 5, 9])

## Layer Normalization


$
\LARGE\mu = \frac{1}{N} \sum_{i=1}^{N} x_i
$

$
\LARGE\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2
$

$
\LARGE\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
$

$
\LARGE y_i = \gamma \hat{x}_i + \beta
$

In [7]:
class LayerNorm(nn.Module):
    """Layer normalisation with a learnable scale (gamma) and shift (beta).

    Args:
        shape: int or sequence of ints describing the trailing dimensions
            that are normalised together.
        eps: constant added to the variance before the square root.
    """

    def __init__(s, shape, eps=1e-5):
        super().__init__()
        s.reshape(shape)
        s.eps = eps

    def reshape(s, shape):
        """(Re)create gamma and beta for a new normalised shape."""
        s.shape = (shape,) if isinstance(shape, int) else tuple(shape)
        s.gamma = nn.Parameter(torch.ones(*s.shape))
        s.beta = nn.Parameter(torch.zeros(*s.shape))

    def forward(s, x):
        # Statistics are taken over every axis covered by `shape`, not just the
        # last one; for an int `shape` this is exactly dim=-1 as before.
        dims = tuple(range(-len(s.shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)  # population variance (1/N)
        x_norm = (x - mean) / torch.sqrt(var + s.eps)
        return x_norm * s.gamma + s.beta
        

## Position Wise Feed-Forward

In [8]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward block: linear -> ReLU -> dropout -> linear.

    Args:
        embed_dim: input and output feature size.
        d_ff: hidden size; a falsy value selects 4 * embed_dim.
        dropout: dropout probability applied after the activation.
    """

    def __init__(s, embed_dim, d_ff=0, dropout=0.1):
        super().__init__()
        hidden = d_ff or embed_dim * 4
        s.linear1 = nn.Linear(embed_dim, hidden)
        s.linear2 = nn.Linear(hidden, embed_dim)
        s.dropout = nn.Dropout(dropout)

    def forward(s, x):
        activated = F.relu(s.linear1(x))
        return s.linear2(s.dropout(activated))
        

In [9]:
# Smoke test: the feed-forward block keeps the input shape.
embed_dim = 10
d_ff = 20
m = PositionWiseFeedForward(embed_dim, d_ff)

x = torch.randn(2, 2, embed_dim)
y = m(x)

y.shape

torch.Size([2, 2, 10])

## Positional Encoding

In [10]:
class PositionalEncoding(nn.Module):
    """Adds a precomputed positional-encoding table to its input.

    Args:
        embed_dim: feature size of the inputs.
        seq_len: number of positions to precompute; inputs longer than this
            cannot be encoded.
    """

    def __init__(s, embed_dim, seq_len=5000):
        super().__init__()
        # Registered as a buffer so the table follows the module through
        # .to(device)/.cuda()/.half(); persistent=False keeps it out of
        # state_dict, so existing checkpoints still load.
        s.register_buffer('pe', positional_encoding(seq_len, embed_dim), persistent=False)

    def forward(s, x):
        # x: (..., S, ED); take the first S rows of the table and broadcast over the batch.
        return x + s.pe[:x.size(-2), :]

$
\LARGE \text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{\frac{2i}{d}}}\right)
$

$
\LARGE \text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{\frac{2i}{d}}}\right)
$

In [11]:

def positional_encoding(s, d, n=10000):
    """Return the (s, d) sinusoidal positional-encoding table as a float32 tensor.

    Column c of row k holds sin(k / n**(c/d)) for even c and
    cos(k / n**((c-1)/d)) for odd c.

    Args:
        s: number of positions (rows).
        d: embedding size (columns).
        n: base of the geometric frequency progression.
    """
    positions = np.arange(s, dtype=np.float64)[:, None]      # (s, 1)
    columns = np.arange(d)                                   # (d,)
    exponents = (columns - columns % 2) / d                  # odd columns reuse the previous even exponent
    angles = positions / np.power(float(n), exponents)       # (s, d), computed in float64
    table = np.where(columns % 2 == 0, np.sin(angles), np.cos(angles))
    return torch.from_numpy(table).to(torch.float32)


In [12]:
# Build a table matching the last two dims of x and add it (broadcast over the first dim).
x = torch.randn(10, 2, 3)
pe = positional_encoding(*x.shape[-2:])
x + pe

tensor([[[-0.4790,  1.2248, -0.9709],
         [ 0.6315,  0.3719, -0.1971]],

        [[ 0.6122,  1.2353, -0.7712],
         [ 1.9904, -0.8597, -1.3386]],

        [[-1.7526,  2.1024, -0.1130],
         [ 1.6807,  1.3510, -1.5821]],

        [[ 1.2562,  0.9885, -1.7097],
         [ 1.3562,  1.2714, -0.1139]],

        [[-0.2333,  0.8334, -2.4092],
         [ 1.1264,  0.7290,  0.6172]],

        [[ 0.5313,  1.3085, -0.6649],
         [ 0.4347,  0.8596,  1.0100]],

        [[ 2.3336,  0.2377,  0.7649],
         [ 1.4001, -0.6894, -1.6972]],

        [[-0.8829,  1.6883,  0.0192],
         [ 3.7390,  0.3130, -0.1684]],

        [[-0.7948,  1.4359,  0.5092],
         [ 1.8898,  1.3564,  0.6197]],

        [[-0.0524,  3.7284,  0.0861],
         [ 0.0645,  1.9532,  0.2532]]])

In [13]:
# Smoke test for the module wrapper; the table is built with the default seq_len of the class.
embed_dim = 4
seq_len = 2
pe = PositionalEncoding(embed_dim)
x = torch.zeros(5, seq_len, embed_dim)  # zeros, so y is just the encoding slice repeated over the batch
y = pe(x)
y.shape

torch.Size([5, 2, 4])

# Transformer

## Encoder Layer

In [14]:
class EncoderLayer(nn.Module):
    """Encoder layer: self-attention and feed-forward, each followed by
    dropout, a residual connection and layer normalisation.

    Args:
        embed_dim: feature size.
        num_heads: number of attention heads.
        d_ff: hidden size of the feed-forward block; falsy selects 4 * embed_dim.
        dropout: dropout probability, used for both residual branches and
            inside the feed-forward block.
    """

    def __init__(s, embed_dim, num_heads, d_ff=0, dropout=0.1):
        super().__init__()

        if not d_ff: d_ff = embed_dim * 4

        s.mha = MultiHeadAttention(embed_dim, num_heads)
        s.mha_dropout = nn.Dropout(dropout)
        s.mha_norm = LayerNorm(embed_dim)

        # `dropout` is forwarded so the feed-forward block honours the
        # caller's rate (it previously always used its own default).
        s.ffn = PositionWiseFeedForward(embed_dim, d_ff, dropout)
        s.ffn_dropout = nn.Dropout(dropout)
        s.ffn_norm = LayerNorm(embed_dim)

    def forward(s, x, _=0, mask=0):
        """Encode x. The second positional argument is accepted and ignored
        so the call signature matches the decoder layer's (x, ey, mask)."""

        v, _attn = s.mha(x, x, x, mask=mask)
        a = s.mha_dropout(v) + x      # residual around attention
        x = s.mha_norm(a)

        f = s.ffn(x)
        a = s.ffn_dropout(f) + x      # residual around feed-forward
        x = s.ffn_norm(a)

        return x

In [15]:
# Smoke test: an encoder layer keeps the input shape.
embed_dim = 9
num_heads = 3

e = EncoderLayer(embed_dim, num_heads)

x = torch.randn(2, 10, embed_dim)
y = e(x)
y.shape

torch.Size([2, 10, 9])

## Decoder Layer

In [16]:
class DecoderLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention and feed-forward, each
    followed by dropout, a residual connection and layer normalisation.

    Args:
        embed_dim: feature size.
        num_heads: number of attention heads.
        d_ff: hidden size of the feed-forward block; falsy selects 4 * embed_dim.
        dropout: dropout probability, used for all residual branches and
            inside the feed-forward block.
    """

    def __init__(s, embed_dim, num_heads, d_ff=0, dropout=0.1):
        super().__init__()

        if not d_ff: d_ff = embed_dim * 4

        s.m_mha = MultiHeadAttention(embed_dim, num_heads)
        s.m_mha_dropout = nn.Dropout(dropout)
        s.m_mha_norm = LayerNorm(embed_dim)

        s.c_mha = MultiHeadAttention(embed_dim, num_heads)
        s.c_mha_dropout = nn.Dropout(dropout)
        s.c_mha_norm = LayerNorm(embed_dim)

        # `dropout` is forwarded so the feed-forward block honours the
        # caller's rate (it previously always used its own default).
        s.p_ffn = PositionWiseFeedForward(embed_dim, d_ff, dropout)
        s.p_ffn_dropout = nn.Dropout(dropout)
        s.p_ffn_norm = LayerNorm(embed_dim)

    def forward(s, x, ey=None, mask=0):
        """Decode x against the memory `ey`; with ey=None the layer attends to x itself."""

        if ey is None: ey = x

        v, _attn = s.m_mha(x, x, x, mask=mask)
        a = s.m_mha_dropout(v) + x
        x = s.m_mha_norm(a)

        # NOTE(review): the same `mask` is reused for cross-attention. That only
        # lines up when `ey` has the same sequence length as `x` - confirm this
        # is intended before feeding encoder output of a different length.
        v, _attn = s.c_mha(x, ey, ey, mask=mask)
        a = s.c_mha_dropout(v) + x
        x = s.c_mha_norm(a)

        f = s.p_ffn(x)
        a = s.p_ffn_dropout(f) + x
        x = s.p_ffn_norm(a)

        return x

In [17]:
# Smoke test: with no memory argument the decoder layer attends to its own input.
embed_dim = 9
num_heads = 3

d = DecoderLayer(embed_dim, num_heads)

x = torch.randn(2, 10, embed_dim)
y = d(x)
y.shape

torch.Size([2, 10, 9])

## Transformer Layer Customized

In [18]:
def encoder_layer(*args):
    """Factory: a TransformerLayer made of the encoder blocks ('mha', 'ff')."""
    return TransformerLayer(*args, layer=TransformerLayer.encoder_layer)


def decoder_layer(*args):
    """Factory: a TransformerLayer made of the decoder blocks ('mha', 'cmha', 'ff')."""
    return TransformerLayer(*args, layer=TransformerLayer.decoder_layer)


class TransformerLayer(nn.Module):
    """Configurable transformer layer assembled from named blocks.

    `layer` is a list of block names; each block is followed by dropout, a
    residual connection and layer normalisation.

    Args:
        embed_dim: feature size.
        num_heads: number of attention heads.
        d_ff: feed-forward hidden size (falsy lets the feed-forward block pick its default).
        dropout: dropout probability.
        layer: list of block names; None selects TransformerLayer.decoder_layer.
    """

    def __init__(s, embed_dim, num_heads, d_ff=0, dropout=0.1, layer=None):
        super().__init__()
        # The old default was evaluated before the class attribute of the same
        # name existed, so it bound the module-level factory function instead
        # of the block list; resolve the default here instead.
        if layer is None: layer = TransformerLayer.decoder_layer
        s.layer = s.getLayer(layer=layer, embed_dim=embed_dim, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
        # Modules stored inside plain lists/dicts are invisible to nn.Module.
        # Registering them makes parameters(), train()/eval(), .to() and
        # state_dict() see them; `s.layer` keeps its original structure.
        s.registered = nn.ModuleList(
            module
            for block in s.layer
            for module in (block['layer'], block['postprocess']['dropout'], block['postprocess']['norm'])
        )

    encoder_layer = ['mha', 'ff']          # Encoder - MultiHeadAttention + PositionWiseFeedForward
    decoder_layer = ['mha', 'cmha', 'ff']  # Decoder - MultiHeadAttention + Cross-MultiHeadAttention + PositionWiseFeedForward

    # block name -> constructor; each picks the keyword arguments it needs
    blocks_init = {
        'mha'  : lambda *args, embed_dim, num_heads,      **kwargs : MultiHeadAttention(embed_dim, num_heads),
        'cmha' : lambda *args, embed_dim, num_heads,      **kwargs : MultiHeadAttention(embed_dim, num_heads),
        'ff'   : lambda *args, embed_dim, d_ff, dropout,  **kwargs : PositionWiseFeedForward(embed_dim, d_ff, dropout),

        'do'   : lambda *args, dropout,                   **kwargs : nn.Dropout(dropout),
        'ln'   : lambda *args, embed_dim,                 **kwargs : LayerNorm(embed_dim),
    }

    # block name -> how to call the block `f` on activation x / memory ey
    blocks_forward = {
        'mha'  : lambda *args, f, x, ey, mask, **kwargs : f(x, x, x, mask=mask)[0],
        'cmha' : lambda *args, f, x, ey, mask, **kwargs : f(x, ey, ey, mask=mask)[0],
        'ff'   : lambda *args, f, x, ey, mask, **kwargs : f(x),
    }

    def getLayer(s, layer, **kwargs):
        """Instantiate every block named in `layer` together with its dropout and norm."""
        init = TransformerLayer.blocks_init
        forward = TransformerLayer.blocks_forward
        return [
            {
                'name': b,
                'layer': init[b](**kwargs),
                'forward': forward[b],
                'postprocess': {
                    'dropout': init['do'](**kwargs),
                    'norm': init['ln'](**kwargs),
                }
            }
        for b in layer]

    def blockForward(s, **kwargs):
        """Run the blocks in order; expects keyword arguments x, ey and mask."""
        # Each block consumes the running activation and adds it back as the
        # residual. Previously every block saw the original input and the
        # residual started from 0.
        x = kwargs.pop('x')
        for b in s.layer:
            v = b['forward'](f=b['layer'], x=x, **kwargs)
            a = b['postprocess']['dropout'](v) + x
            x = b['postprocess']['norm'](a)
        return x

    def forward(s, x, ey=None, mask=0):
        """Apply the layer to x; ey is the cross-attention memory (defaults to x)."""
        if ey is None: ey = x
        return s.blockForward(x=x, ey=ey, mask=mask)



In [19]:
# Build a decoder-style TransformerLayer through the factory, run it once and inspect its blocks.
embed_dim = 9
num_heads = 3

d = decoder_layer(embed_dim, num_heads)

x = torch.randn(2, 10, embed_dim)
y = d(x)
d.layer  # list of dicts: block name, module, forward lambda and its dropout/norm

[{'name': 'mha',
  'layer': MultiHeadAttention(
    (Wq): Linear(in_features=9, out_features=9, bias=True)
    (Wk): Linear(in_features=9, out_features=9, bias=True)
    (Wv): Linear(in_features=9, out_features=9, bias=True)
    (Wo): Linear(in_features=9, out_features=9, bias=True)
  ),
  'forward': <function __main__.TransformerLayer.<lambda>(*args, f, x, ey, mask, **kwargs)>,
  'postprocess': {'dropout': Dropout(p=0.1, inplace=False),
   'norm': LayerNorm()}},
 {'name': 'cmha',
  'layer': MultiHeadAttention(
    (Wq): Linear(in_features=9, out_features=9, bias=True)
    (Wk): Linear(in_features=9, out_features=9, bias=True)
    (Wv): Linear(in_features=9, out_features=9, bias=True)
    (Wo): Linear(in_features=9, out_features=9, bias=True)
  ),
  'forward': <function __main__.TransformerLayer.<lambda>(*args, f, x, ey, mask, **kwargs)>,
  'postprocess': {'dropout': Dropout(p=0.1, inplace=False),
   'norm': LayerNorm()}},
 {'name': 'ff',
  'layer': PositionWiseFeedForward(
    (linear

In [20]:
# Same inspection as the previous cell, with literal sizes (embed_dim=9, num_heads=3).
tl = decoder_layer(9, 3)
tl.layer

[{'name': 'mha',
  'layer': MultiHeadAttention(
    (Wq): Linear(in_features=9, out_features=9, bias=True)
    (Wk): Linear(in_features=9, out_features=9, bias=True)
    (Wv): Linear(in_features=9, out_features=9, bias=True)
    (Wo): Linear(in_features=9, out_features=9, bias=True)
  ),
  'forward': <function __main__.TransformerLayer.<lambda>(*args, f, x, ey, mask, **kwargs)>,
  'postprocess': {'dropout': Dropout(p=0.1, inplace=False),
   'norm': LayerNorm()}},
 {'name': 'cmha',
  'layer': MultiHeadAttention(
    (Wq): Linear(in_features=9, out_features=9, bias=True)
    (Wk): Linear(in_features=9, out_features=9, bias=True)
    (Wv): Linear(in_features=9, out_features=9, bias=True)
    (Wo): Linear(in_features=9, out_features=9, bias=True)
  ),
  'forward': <function __main__.TransformerLayer.<lambda>(*args, f, x, ey, mask, **kwargs)>,
  'postprocess': {'dropout': Dropout(p=0.1, inplace=False),
   'norm': LayerNorm()}},
 {'name': 'ff',
  'layer': PositionWiseFeedForward(
    (linear

## Transformer Layers

In [21]:
class TransformerLayers(nn.Module):
    """Token embedding + positional encoding followed by a stack of layers.

    Args:
        vocab_size: number of token ids.
        seq_len: number of positions precomputed by the positional encoding.
        embed_dim: feature size.
        num_heads: number of attention heads.
        d_ff: feed-forward hidden size handed to each layer.
        dropout: dropout probability handed to each layer.
        layer: callable `layer(embed_dim, num_heads, d_ff, dropout)` returning
            a module that is called as `module(x, ey, mask)`.
        num_layers: how many layers to stack.
    """

    def __init__(s,
                 vocab_size, seq_len,
                 embed_dim, num_heads,
                 d_ff=0, dropout=0.1,
                 layer=encoder_layer, num_layers=1):
        super().__init__()

        # attribute name (including its spelling) is kept so saved models stay loadable
        s.embeding = nn.Embedding(vocab_size, embed_dim)
        s.pe = PositionalEncoding(embed_dim, seq_len)

        s.layers = nn.ModuleList([
            layer(embed_dim, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(s, x, ey=None, mask=0):
        """Embed the token ids x and run them through every layer.

        ey defaults to None (it used to be 0) so that layers with an
        `ey is None` fallback attend to their own input when no memory is
        given, instead of receiving the integer 0 as memory.
        """

        x = s.embeding(x)
        x = s.pe(x)

        for layer in s.layers:
            x = layer(x, ey, mask)

        return x

In [22]:
# Smoke test: a two-layer stack; the EncoderLayer class itself is passed as the layer factory.
vocab_size = 10
seq_len = 2
embed_dim = 9
num_heads = 3

encoder = TransformerLayers(
    vocab_size, seq_len, embed_dim, num_heads, 
    layer=EncoderLayer, num_layers=2
)

x = torch.randint(0, vocab_size, (3, 2))  # integer token ids, 3 sequences of length 2

encoder(x).shape

torch.Size([3, 2, 9])

## Transformer

In [23]:
class Transformer(nn.Module):
    """Encoder/decoder transformer with a linear projection to vocabulary logits.

    Args:
        vocab_size: number of token ids (shared by source and target).
        src_len, tgt_len: positions precomputed for encoder / decoder inputs.
        embed_dim, num_heads, d_ff, dropout: forwarded to both stacks.
        num_layers: layers per stack.
    """

    def __init__(s,
                 vocab_size, src_len, tgt_len,
                 embed_dim, num_heads,
                 d_ff=0, dropout=0.1, num_layers=2):
        super().__init__()

        s.encoder = TransformerLayers(
            vocab_size, src_len, embed_dim, num_heads,
            d_ff=d_ff, dropout=dropout,
            layer=encoder_layer, num_layers=num_layers
        )

        s.decoder = TransformerLayers(
            vocab_size, tgt_len, embed_dim, num_heads,
            d_ff=d_ff, dropout=dropout,
            layer=decoder_layer, num_layers=num_layers
        )

        s.linear = nn.Linear(embed_dim, vocab_size)

    def forward(s, src=None, tgt=None, src_mask=0, tgt_mask=0):
        """Run the encoder and/or decoder.

        Returns:
            - logits of shape (..., vocab_size) when `tgt` is given; no softmax
              is applied, which is what nn.CrossEntropyLoss expects;
            - the encoder output when only `src` is given;
            - None when neither is given.
        """
        x = None
        if src is not None: x = s.encoder(src, mask=src_mask)
        if tgt is not None:
            # causal mask is built on the CPU; move it next to the input
            causal = get_mask(tgt.size(-1)).to(tgt.device)
            # NOTE(review): this combined mask is also used for cross-attention
            # inside the decoder, which only fits when src and tgt have the
            # same length - confirm against the layer implementation.
            x = s.decoder(tgt, x, mask=tgt_mask + causal)
            x = s.linear(x)
        return x

    def generate(s, num_tokens=10, tgt=None, tgt_mask=0):
        """Greedy decoding with a sliding window.

        At each step the model scores the current window, the arg-max token at
        the last position is appended and the oldest token is dropped, so the
        window keeps its length. (Previously the whole window was replaced by
        the model's own re-prediction of it.)

        Args:
            num_tokens: number of tokens to generate.
            tgt: integer tensor (..., S) holding the seed window.
            tgt_mask: additive mask forwarded unchanged at every step.

        Returns:
            Tensor of shape (num_tokens + 1, ...): the last token of the seed
            followed by every generated token.

        Raises:
            ValueError: if `tgt` is None.
        """
        if tgt is None:
            raise ValueError("generate() needs a `tgt` tensor with the seed token ids")

        was_training = s.training
        s.eval()                       # no dropout while sampling
        windows = [tgt]
        try:
            with torch.no_grad():      # inference only, no autograd graph
                for _ in range(num_tokens):
                    logits = s(tgt=windows[-1], tgt_mask=tgt_mask)
                    next_token = torch.argmax(logits[..., -1:, :], dim=-1)   # (..., 1)
                    windows.append(torch.cat((windows[-1][..., 1:], next_token), dim=-1))
        finally:
            s.train(was_training)      # restore the caller's mode
        return torch.stack(windows)[..., :, -1]   # tokens, batch

In [24]:
# Toy end-to-end check: one optimisation step on random binary sequences, then greedy generation.
vocab_size = 2
seq_len = 4
embed_dim = 3
num_heads = 1

t = Transformer(vocab_size, seq_len, seq_len, embed_dim, num_heads, num_layers=2)

# Random batches of token ids, shape (b, seq_len); xg is defined but not used below.
xg = lambda b=10 : torch.randint(0, vocab_size, (b, seq_len,))
yg = lambda b=10 : torch.randint(0, vocab_size, (b, seq_len,))

# print( t(src=x, tgt=y).shape )
# print( t(src=x).shape )
# print( t(tgt=y).shape )


criterion = nn.CrossEntropyLoss()

# optimizer = optim.Adam(t.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
# optimizer = optim.Adam(t.parameters(), lr=0.000005)
optimizer = optim.Adam(t.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)


t.train()
for i in range(1):
    y = yg()
    optimizer.zero_grad()
    dy = t(tgt=y)  # no src is passed, so only the decoder path runs
    
    # flatten to (batch*seq, vocab) logits vs (batch*seq,) targets; the targets are the inputs themselves
    loss = criterion(dy.view(-1, vocab_size), y.view(-1))
    # loss = criterion(dy, y)
    loss.backward()
    optimizer.step()
    
    pred = torch.argmax(dy, dim=-1)  # computed for inspection; not used further in this cell

o = t.generate(tgt=yg(1))  # seed with a single random sequence
o


tensor([[0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])

# Tokenizer

In [25]:
import regex

def auto_map(fn, arr, *arg):
    """Apply fn(element, *arg) to every non-list element of `arr`, recursing
    into nested lists and keeping their nesting in the result.

    Only exact `list` instances are descended into; tuples and other
    containers are handed to fn as single elements.
    """
    mapped = []
    for e in arr:
        if type(e) is list:
            mapped.append(auto_map(fn, e, *arg))
        else:
            mapped.append(fn(e, *arg))
    return mapped

# [1,2,3] -> {(1,2):1, (2,3):1}
def count_pair(token, state=None):
    """Count adjacent pairs of `token` into `state` and return it.

    Args:
        token: sequence of ids.
        state: dict pair -> count to update in place; a fresh dict is created
            when omitted (the former `state={}` default was shared between
            calls and kept accumulating).

    Returns:
        The updated dict.
    """
    if state is None:
        state = {}
    for pair in zip(token, token[1:]): # (0, 1), (1, 2), (2, 3)
        state[pair] = state.get(pair, 0) + 1 # pair:count
    return state

# [1,2,3], (1,2), 257 -> [257,3]
def merge(token, pair, replace):
    """Return a copy of `token` in which every non-overlapping, left-to-right
    occurrence of the adjacent ids `pair` is replaced by the single id `replace`."""
    merged = []
    skip_next = False
    for pos, current in enumerate(token):
        if skip_next:                  # second half of a pair that was just merged
            skip_next = False
            continue
        following = pos + 1
        if current == pair[0] and following < len(token) and token[following] == pair[1]:
            merged.append(replace)
            skip_next = True
        else:
            merged.append(current)
    return merged

# [1,[2,[3,4]]] -> [1,2,3,4]
def flattern(arr):
    """Flatten arbitrarily nested lists inside the iterable `arr` into one list.

    Only exact `list` instances are expanded; tuples and other values are kept
    as elements. Empty input (or input that flattens to nothing) returns [] -
    the previous lambda recursed forever in that case.
    """
    flat = []
    for item in arr:
        if type(item) is list:
            flat.extend(flattern(item))
        else:
            flat.append(item)
    return flat

# [[[1,2]],[[3,4]]] -> [[1,2],[3,4]]
def keep_last_dim(v):
    """Merge leading dimensions of a nested list until it is a list of flat lists.

    The nesting depth is probed on the first element of the first non-empty
    row, so an empty `v` or an empty leading row no longer raises IndexError.
    """
    while True:
        probe = next((row[0] for row in v if len(row) > 0), None)
        if type(probe) is not list:
            return v
        v = [item for row in v for item in row]  # merges first dim


class BytePairEncoding:
    """Byte-pair-encoding style tokenizer.

    Text is split into chunks by a regular expression, every character becomes
    its code point (code points >= char_size are clamped to char_size - 1, so
    such characters do not survive a round trip), and `train` learns merges of
    the most frequent adjacent id pairs inside the chunks.
    """

    def __init__(s):
        pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        s.regex = regex.compile(pattern)
        s.char_size = 256 # base alphabet: code points 0-255
        s.vocab = {i:i for i in range(s.char_size)} # idx:values | 258: [1, 7, 9]
        s.merges = {} # pair:new_id | (1, 7): 257

    def split(s, text): return s.regex.findall(text) # ['Hello', ' world']

    # Inside these methods the bare names chr/ord refer to the builtins,
    # because class attributes are not in scope in a method body.
    def chr(s, idx): return chr(idx) # 65 -> 'A'

    def ord(s, text): return [ord(c) if ord(c) < s.char_size else s.char_size-1 for c in text] # 'a1A' -> [97, 49, 65]

    def tokenize_text(s, text): return auto_map(s.ord, s.split(text)) # [[0,6], [1,7,6], [5,8,4]]

    def str(s, ids): return ''.join([s.chr(idx) for idx in ids]) # [97, 98, 99] -> 'abc'

    def train(s, text, num_merges=3):
        """Learn up to `num_merges` merges from `text` (a string or a list of strings).

        New ids are assigned from char_size upwards; training stops early once
        no adjacent pair is left to merge.
        """
        cs = s.char_size
        tokens = s.tokenize(text) # [[0,6], [1,7,6], [5,8,4]]

        for i in range(cs, cs+num_merges):
            state = {} # count pairs eg. = (1,2):3, (2,3):1
            for token in tokens: count_pair(token, state) # update state
            # Nothing left to merge: later iterations would see the same
            # tokens, so stop instead of recounting them every time.
            if not state: break
            pair = max(state, key=state.get) # top pair | (1,2):3 | when 3 is max
            tokens = [merge(token, pair, i) for token in tokens] # merge

            s.merges[pair] = i
            s.vocab[i] = flattern((s.vocab[pair[0]], s.vocab[pair[1]])) # i: [c1, c2, c3...]

    def encode_text(s, text):
        """Encode one string; returns a tuple with one list of ids per regex chunk."""
        tokens = s.tokenize(text) # [[0,6], [1,7,6], [5,8,4]]
        for i in range(len(tokens)):
            while len(tokens[i]) > 1:
                pairs = set(zip(tokens[i], tokens[i][1:])) # (256, 257), (257, 258)
                known = [p for p in pairs if p in s.merges]
                if not known: break
                pair = min(known, key=s.merges.get) # earliest learned merge first | (1, 6): 256
                tokens[i] = merge(tokens[i], pair, s.merges[pair])
        return tuple(tokens)

    def decode_tokens(s, tokens):
        """Decode ids (possibly nested in lists) back to a string."""
        return s.str(flattern([s.vocab[idx] for idx in flattern(tokens)]))

    # wrapper functions: a list argument is mapped element-wise, anything else is handled directly

    def tokenize(s, o):
        fn = s.tokenize_text
        if type(o) is list: return keep_last_dim(auto_map(fn, o))
        else: return fn(o)

    def encode(s, o):
        fn = s.encode_text
        if type(o) is list: return auto_map(fn, o)
        else: return fn(o)

    def decode(s, o):
        fn = s.decode_tokens
        if type(o) is list: return auto_map(fn, o)
        else: return fn(o)

    # debug

    def tokens(s):
        """Return the decoded string of every learned (merged) vocabulary entry."""
        added_vocabs = list(s.vocab.values())[s.char_size:]
        return [s.decode((vocab,)) for vocab in added_vocabs]

    def test(s, text="test"):
        """Print whether decode(encode(text)) reproduces `text`; returns `text`."""
        d = s.decode(s.encode(text))
        print("Tokenizer Decode Pass" if d == text else "Tokenizer Decode Fail!")
        return text

In [26]:
# Round trip: train 4 merges on one sentence, encode it, decode it again.
a = "Hello hello hello world world"
t = BytePairEncoding()
t.train(a, 4)

e = (flattern(t.encode(a)),)  # merge the per-chunk id lists into one list, wrapped in a tuple
d = t.decode(e)
d

'Hello hello hello world world'

In [27]:
# Same round trip with list input; a list is mapped element-wise by train/encode/decode.
text = ["tetetete tetetet"]
t = BytePairEncoding()
t.train(text, 4)

e = t.encode(text)
t.decode(e)
t.test()  # round-trips the default text "test", prints pass/fail and returns the text

Tokenizer Decode Pass


'test'

# Dataset Import

In [28]:
import os

# Read every file of the dataset directory and split the combined text into chat messages.
path = "../dataset/dataset/discord/"

# sorted(): os.listdir returns entries in arbitrary order, which would make
# the order of `chats` differ between runs/machines.
files = sorted(os.listdir(path))

parts = []
for file in files:
    with open(os.path.join(path, file), 'r', encoding="UTF-8") as f:
        parts.append(f.read())

# Join once instead of growing a string inside the loop; the separator
# character is kept exactly as in the original cell.
chats = ''.join(parts).split('	')


if __name__ == '__main__':

    from random import randint as r

    # Show up to 10 consecutive messages from a random offset; the min()
    # keeps randint's range valid when fewer than 10 messages exist.
    n_examples = min(10, len(chats))
    rd = r(0, len(chats) - n_examples)

    print('files :')
    for name in files:
        print(name)
    print()
    print('chat examples :')
    for offset in range(n_examples):
        print(chats[rd+offset])


files :
[136542963336478720] [part 10].txt

chat examples :
kiirby: What concert?! @coffee
iwinalot7: Shadiverisity fucking sucks
bearnadette: @greyasashe meirl
iwinalot7: This is 2018 get a better mic
bearnadette: except the car thing
greyasashe: i know we all have your preconceptions about medieval snake people but what would they REALLY use
iwinalot7: He's also cuz incorrect a lot of the time
coffee: Kii what do you mean what concert
bearnadette: also known as: which human race is most similar to snakes based on my perceptions of foreigners,\nalso also known as: let's write a fantasy novel.
iwinalot7: Idk I just hate that guy lmao


# Usage

## File

In [29]:
import dill as pickle
import os

# Defaults used to build pickle file locations.
default = dict(
    directory='../dataset/',
    name='default',
    extension='.pickle',
)


def path(name, extension = default['extension']):
    """Return directory + name + extension.

    The directory is read from `default` at call time; the default extension
    is captured once, when this function is defined.
    """
    return ''.join((default['directory'], name, extension))


def load(name = default['name']):
    """Unpickle and return the object stored under `name`.

    Unpickling can execute arbitrary code - only load files you created
    yourself or otherwise trust.
    """
    with open(path(name), 'rb') as file:
        return pickle.load(file)


def save(dat, name = default['name']):
    """Pickle `dat` under `name`, creating the target directory if needed, and return `dat`."""
    target = path(name)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    with open(path(name), 'wb') as file:
        pickle.dump(dat, file)
    return dat

## Code

### Train Tokenizer

In [30]:
# Tokenizer Train 
# The training code below is wrapped in a triple-quoted string, so running this
# cell executes none of it (the string is only echoed as the cell output).
# To run it, put a `#` in front of the opening `"""` line: that line becomes a
# comment and the closing `#"""` already is one.
# "~30 mins" is presumably the expected run time - confirm before relying on it.
""" ~30 mins
merges = 10000
limit = merges*1//4
tokenizer = BytePairEncoding()
tokenizer.train(chats[:limit], merges)
save(tokenizer, 'tokenizer')
print(tokenizer.tokens())
#"""

" ~30 mins\nmerges = 10000\nlimit = merges*1//4\ntokenizer = BytePairEncoding()\ntokenizer.train(chats[:limit], merges)\nsave(tokenizer, 'tokenizer')\nprint(tokenizer.tokens())\n#"

In [31]:
# Load the saved tokenizer and round-trip a small batch of identical sentences.
t = load('tokenizer')
# NOTE: batch_size is not defined in this cell; it uses whatever value an earlier cell left in the kernel.
a = ["Hello world! how are you!"]*batch_size
e = [(flattern(i),) for i in t.encode(a)]  # one flat id list per sentence, each wrapped in a tuple
d = t.decode(e)
t.test(a)  # prints pass/fail for the round trip and returns its argument

Tokenizer Decode Pass


['Hello world! how are you!', 'Hello world! how are you!']

### Train Transformer

In [35]:
reset = 1  # truthy: build a fresh model; falsy: resume from the pickled 'model'

chats = chats

tokenizer = load('tokenizer')

vocab_size = len(tokenizer.vocab)
seq_len = 20
embed_dim = 16
num_heads = 2

# Right-pad (or truncate) a list of ids to seq_len; id 0 doubles as the padding id.
padd_seq = lambda ids: torch.tensor( (ids + [0]*(seq_len-len(ids)))[:seq_len] )
# Additive key-padding mask of shape (batch, 1, seq_len, seq_len) built from the token counts.
get_right_mask = lambda ids : torch.stack([padding_mask(seq_len, len(tokens)) for tokens in ids]).unsqueeze(-2).repeat(1, seq_len, 1).unsqueeze(-3)

model = Transformer(
    vocab_size=vocab_size, src_len=seq_len, tgt_len=seq_len,
    embed_dim=embed_dim, num_heads=num_heads,
    num_layers=1, dropout=0.1,
) if reset else load('model')

# padding positions (id 0) do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=0)


total_batches = 100 # 500
batch_size = 16
outputs = []

optimizer = optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.98), eps=1e-9)


# training loop: next-token prediction on the message text
model.train()
for i in range(total_batches):
    batch = chats[i*batch_size:i*batch_size+batch_size]
    # Keep everything after the first ':' (the author prefix). partition()
    # preserves colons inside the message, which the previous
    # ''.join(split(':')[1:]) silently deleted.
    chat = [message.partition(':')[2] for message in batch]
    # Empty messages carry no tokens and would break the flattening and the
    # preview below, so they are skipped.
    chat = [message for message in chat if message]
    if not chat: continue
    ids = [flattern(chunks) for chunks in tokenizer.encode(chat)]

    optimizer.zero_grad()

    xpadded = torch.stack([padd_seq(tokens[:-1]) for tokens in ids])   # inputs: all but the last token
    right_mask = get_right_mask(ids)
    output = model(tgt=xpadded, tgt_mask=right_mask)

    ypadded = torch.stack([padd_seq(tokens[1:]) for tokens in ids])    # targets: inputs shifted by one
    loss = criterion(output.view(-1, vocab_size), ypadded.view(-1))

    loss.backward()
    optimizer.step()

    predicted = torch.argmax(output, dim=-1)
    decoded = tokenizer.decode([tuple(pred) for pred in predicted.tolist()])

    print(f'({i}) Loss : {loss.item()}')
    if i % batch_size == 0:
        # '*' is the reference message, '-' its first character plus the model's predictions
        for j in range(len(decoded)):
            print('*', chat[j])
            print('-', chat[j][0] + decoded[j])
            print()


model = save(model, 'model')

print("TESTING...", "\n"*2)

model.eval()  # disable dropout for generation
for i in range(10):
    text_start = ["hi "]
    ids = [flattern(chunks) for chunks in tokenizer.encode(text_start)]
    # Feed the whole prompt; dropping the last token (as the training inputs
    # do) would hide it from the model.
    xpadded = torch.stack([padd_seq(tokens) for tokens in ids])
    with torch.no_grad():
        output = model.generate(num_tokens=10, tgt=xpadded, tgt_mask=get_right_mask(ids))
    decoded = tokenizer.decode((*output.tolist(),))
    print(decoded)

(0) Loss : 9.170742988586426
*  discord just told me to look be hind you while i loaded in.
-   arenallall aren lood aren aren Mand\nPatriots} Mand whis\nPatriotshie loodpping\nPatriots Mand Mand lood

*  Like kabouter plop\nIk word daar zoe move vaan
-   Cured Wut Hes serious serious Wutppingpping:Cpping belive:C:C happen:Cpping stuffing\nPatriots happen text

*  @Moose teleports behind you nothing personell kid
-   wheels\nPatriots\ndurr tele\nPatriots:heart\nPatriots ques\nPatriots\nPatriots\nPatriots\nPatriots\nPatriots\nPatriotsfortunately\nPatriots\nPatriots\nPatriots\nPatriots\nPatriots

*  scream
-   sat Lookthol Look emailshie Look\non\nonils tow\non\nIT\nIT sat tow\nonthol Dpping

*  Tfw neither my brother or I quite remember how to get to our great aunts house
-   DY COR tele embarrasingly lood lood lood ac teleppingpping tele\nPatriots spe quespping\nPatriots silverhie\nPatriots

*  Just wing it
-   fracture nah:C jo jophph:C Murfre shore fils fils jo fils NOT fils jo\nThey