In [1]:
# Vocabulary: a character's position in this list is its token id.
CHARS = ["a", "b"]


def tokenize(s):
    """Return the list of token ids for the characters of ``s``.

    Each character is looked up with ``CHARS.index``, so a character that is
    not in ``CHARS`` raises ``ValueError``.
    """
    return list(map(CHARS.index, s))


def decode(token):
    """Return the character whose token id is ``token``."""
    return CHARS[token]

In [2]:
# Sanity check: encode a string, then decode token ids 1 and 0 (one value per line).
print(tokenize("aabaa"), decode(1), decode(0), sep="\n")

[0, 0, 1, 0, 0]
b
a


In [3]:
import numpy as np

def softmax(x):
    """Softmax over the last axis of ``x``.

    The maximum along the last axis is subtracted before exponentiating so
    that ``np.exp`` never receives a large positive argument; this shift does
    not change the mathematical result.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

# [m, in] @ [in, out] + [out] --> [m, out]
def linear(x, w, b):
    """Affine map: matrix-multiply ``x`` by ``w`` and add the bias ``b``."""
    projected = x @ w
    return projected + b

# query [nq, dk], key [nk, dk], value [nk, dv], mask [nq, nk] --> [nq, dv]
def attention(query, key, value, mask):
    """Scaled dot-product attention.

    The query/key dot products are divided by the square root of the query's
    last dimension, ``mask`` is added to those scores, the scores are
    softmax-normalised, and the result weights the rows of ``value``.
    """
    scale = np.sqrt(query.shape[-1])
    scores = query @ key.T / scale + mask
    return softmax(scores) @ value

# [n_seq, n_embd] --> [n_seq, n_embd]
def casual_self_attention(x, c_attn, c_proj):
    """Single-head causal self-attention over the sequence ``x``.

    ``c_attn`` and ``c_proj`` are unpacked as keyword arguments into
    ``linear`` (so they carry ``w`` and ``b``): ``c_attn`` is the fused
    q/k/v projection and ``c_proj`` is the output projection.
    ("casual" in the name is a misspelling of "causal"; the name is kept
    because callers refer to it.)
    """
    # Fused q/k/v projection: [n_seq, n_embd] --> [n_seq, 3 * n_embd]
    qkv = linear(x, **c_attn)

    # Cut into three equal parts along the feature axis: 3 x [n_seq, n_embd]
    q, k, v = np.split(qkv, 3, axis=-1)

    # Mask [n_seq, n_seq]: 0 on and below the diagonal, -1e10 above it, so
    # a position cannot attend to later positions.
    n_seq = qkv.shape[0]
    casual_mask = (1 - np.tri(n_seq, dtype=qkv.dtype)) * -1e10

    # Masked attention: q, k, v [n_seq, n_embd] --> [n_seq, n_embd]
    attended = attention(q, k, v, casual_mask)

    # Output projection: [n_seq, n_embd] @ [n_embd, n_embd] --> [n_seq, n_embd]
    return linear(attended, **c_proj)

def transformer_block(x, attn):
    """Apply one block: self-attention added back onto its input (residual).

    ``attn`` is unpacked as keyword arguments into ``casual_self_attention``.
    """
    attn_out = casual_self_attention(x, **attn)
    return x + attn_out

# [n_seq] --> [n_seq, n_vocab]
def gpt(inputs, wte, wpe, blocks):
    """Run the model on token ids ``inputs`` and return per-position logits.

    ``wte`` is indexed by token id and ``wpe`` by position (0 .. n_seq - 1);
    their sum is passed through every entry of ``blocks`` in order, and the
    result is projected back onto the vocabulary with the transposed ``wte``.
    """
    # Token + positional embeddings: [n_seq] --> [n_seq, n_embd]
    positions = range(len(inputs))
    hidden = wte[inputs] + wpe[positions]

    # Each block maps [n_seq, n_embd] --> [n_seq, n_embd]
    for block_params in blocks:
        hidden = transformer_block(hidden, **block_params)

    # Project to vocabulary: [n_seq, n_embd] --> [n_seq, n_vocab]
    return hidden @ wte.T

In [4]:
# Number of positions the model handles; equals the number of rows in
# MODEL["wpe"] and the literal 5 used in complete().
# NOTE(review): the name is misspelled ("legnth") and no cell visible here
# reads it — confirm before renaming.
context_legnth = 5
# Vocabulary size: one token id per entry of CHARS.
vocab = len(CHARS)
# Embedding width; the fused q/k/v bias in MODEL has length n_embd * 3.
n_embd = 8

In [5]:
# Large constant used in the query weights and in the output projection.
# NOTE(review): presumably large so that the softmax over attention scores and
# the final argmax behave like hard selections — confirm.
Lg = 1024
# Hand-written parameters; complete() passes them to gpt() as the keyword
# arguments wte, wpe and blocks.
MODEL = {
    # Token embeddings, one row per token id: id 0 is one-hot on dim 5,
    # id 1 is one-hot on dim 6.
    "wte": np.array(
    [
        [0,0,0,0,0,1,0,0],
        [0,0,0,0,0,0,1,0]
    ]
    ),
    # Position embeddings, one row per position 0..4: position i is one-hot
    # on dim i.
    "wpe": np.array(
    [
        [1,0,0,0,0,0,0,0],
        [0,1,0,0,0,0,0,0],
        [0,0,1,0,0,0,0,0],
        [0,0,0,1,0,0,0,0],
        [0,0,0,0,1,0,0,0]
    ]
    ),
    # One dict per transformer block; each is unpacked into transformer_block().
    "blocks": [
        {
            "attn": {
                # Fused q/k/v projection, unpacked into linear() as w and b.
                # w has 8 rows (input embedding dims) and 24 columns, written
                # below as three groups of 8 per row: q, k, v.
                #   * k copies the position one-hot (input dims 0..4).
                #   * q for position i has Lg on dims i-1 and i (only dim 0
                #     for position 0), so the unscaled q.k score is Lg for the
                #     same and the previous position and 0 elsewhere.
                #   * v writes +1 (token id 0, input dim 5) or -1 (token id 1,
                #     input dim 6) into v dim 7.
                #   * input dim 7 (last row) is all zeros.
                "c_attn": {
                    "b": np.zeros(n_embd * 3),
                    "w": np.array(
                        #fmt off
                        [
                            [Lg, 0., 0., 0., 0., 0., 0., 0.,  # q
                                1., 0., 0., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                            [Lg, Lg, 0., 0., 0., 0., 0., 0.,  # q
                                0., 1., 0., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                            [0., Lg, Lg, 0., 0., 0., 0., 0.,  # q
                                0., 0., 1., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                            [0., 0., Lg, Lg, 0., 0., 0., 0.,  # q
                                0., 0., 0., 1., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                            [0., 0., 0., Lg, Lg, 0., 0., 0.,  # q
                                0., 0., 0., 0., 1., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                            [0., 0., 0., 0., 0., 0., 0., 0.,  # q
                                0., 0., 0., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 1.], # v
                            [0., 0., 0., 0., 0., 0., 0., 0.,  # q
                                0., 0., 0., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., -1], # v
                            [0., 0., 0., 0., 0., 0., 0., 0.,  # q
                                0., 0., 0., 0., 0., 0., 0., 0.,  # k
                                    0., 0., 0., 0., 0., 0., 0., 0.], # v
                        ]
                        #fmt on
                    )},
                "c_proj": {  # weights to project attn result back to embedding space
                    # Bias adds Lg to dim 5. It is a plain list rather than an
                    # ndarray; linear() adds it to an ndarray, so NumPy
                    # broadcasts it.
                    "b": [0, 0, 0, 0, 0, Lg, 0, 0],
                    # Only the last row is non-zero: attention output dim 7
                    # (call it s) contributes -Lg*s to dim 5 and +Lg*s to
                    # dim 6, so with the bias dim 5 receives Lg*(1 - s) and
                    # dim 6 receives Lg*s.
                    "w": np.array([
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, -Lg, Lg, 0],
                    ]),
                },
            }
        }
    ] 
}

In [6]:
def complete(s, max_new_tokens = 10):
    """Greedily extend the prompt ``s`` using ``MODEL``.

    Args:
        s: Prompt string; every character must be in ``CHARS``.
        max_new_tokens: Number of tokens to generate after the prompt.

    Returns:
        ``s``, the separator ``" :: "``, and the generated characters,
        concatenated into one string.

    Raises:
        ValueError: From ``tokenize`` if ``s`` contains a character that is
            not in ``CHARS``.

    An empty prompt is not supported when ``max_new_tokens`` is positive,
    because the model would be called with an empty token array.
    """
    # gpt() looks up one row of ``wpe`` per input position, so the number of
    # rows is the longest window the model accepts. Deriving the window from
    # the model keeps it in sync with the weights instead of repeating a
    # literal 5 here.
    context_length = len(MODEL["wpe"])

    tokens = tokenize(s)
    prompt_length = len(tokens)
    while len(tokens) < prompt_length + max_new_tokens:
        # Only the most recent ``context_length`` tokens are fed to the model.
        window = np.array(tokens[-context_length:])
        logits = gpt(window, **MODEL)
        probs = softmax(logits)
        # Greedy decoding: take the most likely token at the last position.
        # int() keeps ``tokens`` a list of plain Python ints.
        tokens.append(int(np.argmax(probs[-1])))
    return s + " :: " + "".join(decode(t) for t in tokens[prompt_length:])



In [7]:
# Greedy completion of the one-character prompt "a" (default number of new tokens).
print(complete(s="a"))

a :: baabaabaab


In [11]:
# Same greedy completion of the prompt "a", run again; the output is deterministic.
print(complete(s="a"))

a :: baabaabaab


In [None]:
p